|
|
@ -49,7 +49,7 @@ class StructureEstimator: |
|
|
|
self._complete_graph = self.build_complete_graph(sample_path.structure.nodes_labels) |
|
|
|
self._complete_graph = self.build_complete_graph(sample_path.structure.nodes_labels) |
|
|
|
self._exp_test_sign = exp_test_alfa |
|
|
|
self._exp_test_sign = exp_test_alfa |
|
|
|
self._chi_test_alfa = chi_test_alfa |
|
|
|
self._chi_test_alfa = chi_test_alfa |
|
|
|
self._caches = [Cache() for _ in range(len(self._nodes))] |
|
|
|
#self._caches = [Cache() for _ in range(len(self._nodes))] |
|
|
|
self._result_graph = None |
|
|
|
self._result_graph = None |
|
|
|
|
|
|
|
|
|
|
|
def build_complete_graph(self, node_ids: typing.List) -> nx.DiGraph: |
|
|
|
def build_complete_graph(self, node_ids: typing.List) -> nx.DiGraph: |
|
|
@ -251,6 +251,7 @@ class StructureEstimator: |
|
|
|
"""Compute the CTPC algorithm over the entire net. |
|
|
|
"""Compute the CTPC algorithm over the entire net. |
|
|
|
""" |
|
|
|
""" |
|
|
|
ctpc_algo = StructureEstimator.one_iteration_of_CTPC_algorithm |
|
|
|
ctpc_algo = StructureEstimator.one_iteration_of_CTPC_algorithm |
|
|
|
|
|
|
|
#SHM INIT |
|
|
|
shm_times = multiprocessing.shared_memory.SharedMemory(name='sh_times', create=True, |
|
|
|
shm_times = multiprocessing.shared_memory.SharedMemory(name='sh_times', create=True, |
|
|
|
size=self._times.nbytes) |
|
|
|
size=self._times.nbytes) |
|
|
|
shm_trajectories = multiprocessing.shared_memory.SharedMemory(name='sh_traj', create=True, |
|
|
|
shm_trajectories = multiprocessing.shared_memory.SharedMemory(name='sh_traj', create=True, |
|
|
@ -262,21 +263,26 @@ class StructureEstimator: |
|
|
|
trajectories_arr = np.ndarray(self._trajectories.shape, |
|
|
|
trajectories_arr = np.ndarray(self._trajectories.shape, |
|
|
|
self._trajectories.dtype, shm_trajectories.buf) |
|
|
|
self._trajectories.dtype, shm_trajectories.buf) |
|
|
|
trajectories_arr[:] = self._trajectories[:] |
|
|
|
trajectories_arr[:] = self._trajectories[:] |
|
|
|
total_vars_numb_list = [self._tot_vars_number] * len(self._nodes) |
|
|
|
|
|
|
|
|
|
|
|
nodes = np.copy(self._nodes) |
|
|
|
|
|
|
|
nodes_vals = np.copy(self._nodes_vals) |
|
|
|
|
|
|
|
nodes_numb = len(self._nodes) |
|
|
|
|
|
|
|
caches = [Cache() for _ in range(nodes_numb)] |
|
|
|
|
|
|
|
total_vars_numb_list = [self._tot_vars_number] * nodes_numb |
|
|
|
parents_list = [list(self._complete_graph.predecessors(var_id)) for var_id in self._nodes] |
|
|
|
parents_list = [list(self._complete_graph.predecessors(var_id)) for var_id in self._nodes] |
|
|
|
nodes_array_list = [self._nodes] * len(self._nodes) |
|
|
|
nodes_array_list = [self._nodes] * len(self._nodes) |
|
|
|
nodes_indxs_array_list = [self._nodes_indxs] * len(self._nodes) |
|
|
|
nodes_indxs_array_list = [self._nodes_indxs] * nodes_numb |
|
|
|
nodes_vals_array_list = [self._nodes_vals] * len(self._nodes) |
|
|
|
nodes_vals_array_list = [self._nodes_vals] * nodes_numb |
|
|
|
tests_alfa_dims_list = [(self._exp_test_sign, self._chi_test_alfa)] * len(self._nodes) |
|
|
|
tests_alfa_dims_list = [(self._exp_test_sign, self._chi_test_alfa)] * nodes_numb |
|
|
|
datas_dims_list = [[self._times.shape, self._trajectories.shape]] * len(self._nodes) |
|
|
|
datas_dims_list = [[self._times.shape, self._trajectories.shape]] * nodes_numb |
|
|
|
if multi_processing: |
|
|
|
if multi_processing: |
|
|
|
cpu_count = multiprocessing.cpu_count() |
|
|
|
cpu_count = multiprocessing.cpu_count() |
|
|
|
else: |
|
|
|
else: |
|
|
|
cpu_count = 1 |
|
|
|
cpu_count = 1 |
|
|
|
print("CPU COUNT", cpu_count) |
|
|
|
print("CPU COUNT", cpu_count) |
|
|
|
with multiprocessing.Pool(processes=cpu_count) as pool: |
|
|
|
with multiprocessing.Pool(processes=cpu_count) as pool: |
|
|
|
parent_sets = pool.starmap(ctpc_algo, zip(self._nodes, self._nodes_vals, parents_list, |
|
|
|
parent_sets = pool.starmap(ctpc_algo, zip(nodes, nodes_vals, parents_list, |
|
|
|
self._caches, total_vars_numb_list, |
|
|
|
caches, total_vars_numb_list, |
|
|
|
nodes_array_list, nodes_indxs_array_list, nodes_vals_array_list, |
|
|
|
nodes_array_list, nodes_indxs_array_list, nodes_vals_array_list, |
|
|
|
tests_alfa_dims_list, datas_dims_list)) |
|
|
|
tests_alfa_dims_list, datas_dims_list)) |
|
|
|
|
|
|
|
|
|
|
|