diff --git a/PyCTBN/PyCTBN/structure_estimator.py b/PyCTBN/PyCTBN/structure_estimator.py index 49735e0..0e8af09 100644 --- a/PyCTBN/PyCTBN/structure_estimator.py +++ b/PyCTBN/PyCTBN/structure_estimator.py @@ -247,7 +247,7 @@ class StructureEstimator: list_without_test_parent.remove(parent_label) return map(list, itertools.combinations(list_without_test_parent, size)) - def ctpc_algorithm(self) -> None: + def ctpc_algorithm(self, multi_processing: bool) -> None: """Compute the CTPC algorithm over the entire net. """ ctpc_algo = StructureEstimator.one_iteration_of_CTPC_algorithm @@ -269,10 +269,10 @@ class StructureEstimator: nodes_vals_array_list = [self._nodes_vals] * len(self._nodes) tests_alfa_dims_list = [(self._exp_test_sign, self._chi_test_alfa)] * len(self._nodes) datas_dims_list = [[self._times.shape, self._trajectories.shape]] * len(self._nodes) - #if multi_processing: - cpu_count = multiprocessing.cpu_count() - #else: - #cpu_count = 1 + if multi_processing: + cpu_count = multiprocessing.cpu_count() + else: + cpu_count = 1 print("CPU COUNT", cpu_count) with multiprocessing.Pool(processes=cpu_count) as pool: parent_sets = pool.starmap(ctpc_algo, zip(self._nodes, self._nodes_vals, parents_list, diff --git a/PyCTBN/tests/test_structure_estimator.py b/PyCTBN/tests/test_structure_estimator.py index f1c5265..6e642fb 100644 --- a/PyCTBN/tests/test_structure_estimator.py +++ b/PyCTBN/tests/test_structure_estimator.py @@ -79,14 +79,14 @@ class TestStructureEstimator(unittest.TestCase): def test_time(self): se1 = StructureEstimator(self.s1, 0.1, 0.1) lp = LineProfiler() - MULTI_PROCESSING = False ###### MODIFICARE QUI SINGLE/MULTI PROCESS + MULTI_PROCESSING = True ###### MODIFICARE QUI SINGLE/MULTI PROCESS lp_wrapper = lp(se1.ctpc_algorithm) lp_wrapper(MULTI_PROCESSING) lp.print_stats() #paralell_time = timeit.timeit(se1.ctpc_algorithm, MULTI_PROCESSING, number=1) #print("EXEC TIME:", paralell_time) print(se1._result_graph.edges) - print(self.s1.structure.edges) + #print(self.s1.structure.edges) for ed in self.s1.structure.edges: self.assertIn(tuple(ed), se1._result_graph.edges) tuples_edges = [tuple(rec) for rec in self.s1.structure.edges]