diff --git a/PyCTBN/PyCTBN/structure_estimator.py b/PyCTBN/PyCTBN/structure_estimator.py index 0e8af09..82dde17 100644 --- a/PyCTBN/PyCTBN/structure_estimator.py +++ b/PyCTBN/PyCTBN/structure_estimator.py @@ -49,7 +49,7 @@ class StructureEstimator: self._complete_graph = self.build_complete_graph(sample_path.structure.nodes_labels) self._exp_test_sign = exp_test_alfa self._chi_test_alfa = chi_test_alfa - self._caches = [Cache() for _ in range(len(self._nodes))] + #self._caches = [Cache() for _ in range(len(self._nodes))] self._result_graph = None def build_complete_graph(self, node_ids: typing.List) -> nx.DiGraph: @@ -251,6 +251,7 @@ class StructureEstimator: """Compute the CTPC algorithm over the entire net. """ ctpc_algo = StructureEstimator.one_iteration_of_CTPC_algorithm + #SHM INIT shm_times = multiprocessing.shared_memory.SharedMemory(name='sh_times', create=True, size=self._times.nbytes) shm_trajectories = multiprocessing.shared_memory.SharedMemory(name='sh_traj', create=True, @@ -262,21 +263,26 @@ class StructureEstimator: trajectories_arr = np.ndarray(self._trajectories.shape, self._trajectories.dtype, shm_trajectories.buf) trajectories_arr[:] = self._trajectories[:] - total_vars_numb_list = [self._tot_vars_number] * len(self._nodes) + + nodes = np.copy(self._nodes) + nodes_vals = np.copy(self._nodes_vals) + nodes_numb = len(self._nodes) + caches = [Cache() for _ in range(nodes_numb)] + total_vars_numb_list = [self._tot_vars_number] * nodes_numb parents_list = [list(self._complete_graph.predecessors(var_id)) for var_id in self._nodes] nodes_array_list = [self._nodes] * len(self._nodes) - nodes_indxs_array_list = [self._nodes_indxs] * len(self._nodes) - nodes_vals_array_list = [self._nodes_vals] * len(self._nodes) - tests_alfa_dims_list = [(self._exp_test_sign, self._chi_test_alfa)] * len(self._nodes) - datas_dims_list = [[self._times.shape, self._trajectories.shape]] * len(self._nodes) + nodes_indxs_array_list = [self._nodes_indxs] * nodes_numb + nodes_vals_array_list = [self._nodes_vals] * nodes_numb + tests_alfa_dims_list = [(self._exp_test_sign, self._chi_test_alfa)] * nodes_numb + datas_dims_list = [[self._times.shape, self._trajectories.shape]] * nodes_numb if multi_processing: cpu_count = multiprocessing.cpu_count() else: cpu_count = 1 print("CPU COUNT", cpu_count) with multiprocessing.Pool(processes=cpu_count) as pool: - parent_sets = pool.starmap(ctpc_algo, zip(self._nodes, self._nodes_vals, parents_list, - self._caches, total_vars_numb_list, + parent_sets = pool.starmap(ctpc_algo, zip(nodes, nodes_vals, parents_list, + caches, total_vars_numb_list, nodes_array_list, nodes_indxs_array_list, nodes_vals_array_list, tests_alfa_dims_list, datas_dims_list))