diff --git a/main_package/classes/estimators/structure_constraint_based_estimator.py b/main_package/classes/estimators/structure_constraint_based_estimator.py index fba11e9..174ebf9 100644 --- a/main_package/classes/estimators/structure_constraint_based_estimator.py +++ b/main_package/classes/estimators/structure_constraint_based_estimator.py @@ -19,6 +19,8 @@ import structure_graph.sample_path as sp import structure_graph.structure as st import optimizers.constraint_based_optimizer as optimizer +import concurrent.futures + from utility.decorators import timing,timing_write import multiprocessing @@ -243,11 +245,11 @@ class StructureConstraintBasedEstimator(se.StructureEstimator): 'Estimate the best parents for each node' #with multiprocessing.Pool(processes=cpu_count) as pool: - with get_context("spawn").Pool(processes=cpu_count) as pool: - - list_edges_partial = pool.starmap(ctpc_algo, zip( + #with get_context("spawn").Pool(processes=cpu_count) as pool: + with concurrent.futures.ProcessPoolExecutor(max_workers=cpu_count) as executor: + list_edges_partial = executor.map(ctpc_algo, self.nodes, - total_vars_numb_array)) + total_vars_numb_array) #list_edges_partial = [ctpc_algo(n,total_vars_numb) for n in self.nodes] return set(itertools.chain.from_iterable(list_edges_partial)) diff --git a/main_package/tests/estimators/test_structure_constraint_based_estimator_server.py b/main_package/tests/estimators/test_structure_constraint_based_estimator_server.py new file mode 100644 index 0000000..1f37831 --- /dev/null +++ b/main_package/tests/estimators/test_structure_constraint_based_estimator_server.py @@ -0,0 +1,76 @@ +import sys +sys.path.append("../../classes/") +import glob +import math +import os +import unittest + +import networkx as nx +import numpy as np +import psutil +from line_profiler import LineProfiler + +import utility.cache as ch +import structure_graph.sample_path as sp +import estimators.structure_constraint_based_estimator as se +import utility.json_importer as ji + +from multiprocessing import set_start_method + +import copy + + +class TestStructureConstraintBasedEstimator(unittest.TestCase): + @classmethod + def setUpClass(cls): + pass + + def test_structure(self): + #cls.read_files = glob.glob(os.path.join('../../data', "*.json")) + self.importer = ji.JsonImporter("/home/alessandro/Documents/ctbn_cba/data/networks_and_trajectories_ternary_data_15.json", 'samples', 'dyn.str', 'variables', 'Time', 'Name') + self.s1 = sp.SamplePath(self.importer) + self.s1.build_trajectories() + self.s1.build_structure() + + true_edges = copy.deepcopy(self.s1.structure.edges) + true_edges = set(map(tuple, true_edges)) + + + se1 = se.StructureConstraintBasedEstimator(self.s1,0.1,0.1) + edges = se1.estimate_structure( + max_parents = None, + iterations_number = 100, + patience = 35, + tabu_length = 15, + tabu_rules_duration = 15, + optimizer = 'tabu', + disable_multiprocessing=False + ) + + + self.importer = ji.JsonImporter("/home/alessandro/Documents/ctbn_cba/data/networks_and_trajectories_ternary_data_15.json", 'samples', 'dyn.str', 'variables', 'Time', 'Name') + self.s1 = sp.SamplePath(self.importer) + self.s1.build_trajectories() + self.s1.build_structure() + + true_edges = copy.deepcopy(self.s1.structure.edges) + true_edges = set(map(tuple, true_edges)) + + + se1 = se.StructureConstraintBasedEstimator(self.s1,0.1,0.1) + edges = se1.estimate_structure( + max_parents = None, + iterations_number = 100, + patience = 35, + tabu_length = 15, + tabu_rules_duration = 15, + optimizer = 'tabu', + disable_multiprocessing=True + ) + + + + self.assertEqual(edges, true_edges) + +if __name__ == '__main__': + unittest.main()