1
0
Fork 0

Refactor in Structure Estimator

parallel_struct_est
philpMartin 4 years ago
parent 1b61c70f39
commit 5c6aa186db
  1. 23
      main_package/classes/structure_estimator.py

@ -30,11 +30,8 @@ class StructureEstimator:
def __init__(self, sample_path: sp.SamplePath, exp_test_alfa: float, chi_test_alfa: float): def __init__(self, sample_path: sp.SamplePath, exp_test_alfa: float, chi_test_alfa: float):
self.sample_path = sample_path self.sample_path = sample_path
self.nodes = np.array(self.sample_path.structure.nodes_labels) self.nodes = np.array(self.sample_path.structure.nodes_labels)
#print("NODES", self.nodes)
self.nodes_vals = self.sample_path.structure.nodes_values self.nodes_vals = self.sample_path.structure.nodes_values
self.nodes_indxs = self.sample_path.structure.nodes_indexes self.nodes_indxs = self.sample_path.structure.nodes_indexes
#self.nodes_indxs = np.array(range(0,4))
#print("INDXS", self.nodes_indxs)
self.complete_graph = self.build_complete_graph(self.sample_path.structure.nodes_labels) self.complete_graph = self.build_complete_graph(self.sample_path.structure.nodes_labels)
self.exp_test_sign = exp_test_alfa self.exp_test_sign = exp_test_alfa
self.chi_test_alfa = chi_test_alfa self.chi_test_alfa = chi_test_alfa
@ -81,12 +78,6 @@ class StructureEstimator:
g1 = ng.NetworkGraph(s1) g1 = ng.NetworkGraph(s1)
#g1.init_graph() #g1.init_graph()
g1.fast_init(test_child) g1.fast_init(test_child)
#print("M Vector", g1.transition_scalar_indexing_structure)
#print("Time Vecotr", g1.time_scalar_indexing_strucure)
#print("Time Filter", g1.time_filtering)
#print("M Filter", g1.transition_filtering)
#print("G1 NODES", g1.get_nodes())
#print("G1 Edges", g1.get_edges())
p1 = pe.ParametersEstimator(self.sample_path, g1) p1 = pe.ParametersEstimator(self.sample_path, g1)
#p1.init_sets_cims_container() #p1.init_sets_cims_container()
p1.fast_init(test_child) p1.fast_init(test_child)
@ -130,12 +121,6 @@ class StructureEstimator:
g2 = ng.NetworkGraph(s2) g2 = ng.NetworkGraph(s2)
#g2.init_graph() #g2.init_graph()
g2.fast_init(test_child) g2.fast_init(test_child)
#print("M Vector", g2.transition_scalar_indexing_structure)
#print("Time Vecotr", g2.time_scalar_indexing_strucure)
#print("Time Filter", g2.time_filtering)
#print("M Filter", g2.transition_filtering)
#print("G2 Nodes", g2.get_nodes())
#print("G2 Edges", g2.get_edges())
p2 = pe.ParametersEstimator(self.sample_path, g2) p2 = pe.ParametersEstimator(self.sample_path, g2)
#p2.init_sets_cims_container() #p2.init_sets_cims_container()
p2.fast_init(test_child) p2.fast_init(test_child)
@ -144,10 +129,6 @@ class StructureEstimator:
#if p_set: #if p_set:
#set_p_set = set(p_set) #set_p_set = set(p_set)
self.cache.put(set(p_set), sofc2) self.cache.put(set(p_set), sofc2)
#start = 0
#end = self.sample_path.structure.get_states_number(test_parent)
#print("SOFC2", sofc2.actual_cims)
#print("Sofc2 pcomb", sofc2.p_combs)
for cim1, p_comb in zip(sofc1.actual_cims, sofc1.p_combs): for cim1, p_comb in zip(sofc1.actual_cims, sofc1.p_combs):
#print("GETTING THIS P COMB", p_comb) #print("GETTING THIS P COMB", p_comb)
#if len(parent_set) > 1: #if len(parent_set) > 1:
@ -260,9 +241,5 @@ class StructureEstimator:
def ctpc_algorithm(self): def ctpc_algorithm(self):
ctpc_algo = self.one_iteration_of_CTPC_algorithm ctpc_algo = self.one_iteration_of_CTPC_algorithm
total_vars_numb = self.sample_path.total_variables_count total_vars_numb = self.sample_path.total_variables_count
#for node_id in self.sample_path.structure.list_of_nodes_labels():
#print("TESTING VAR:", node_id)
#self.one_iteration_of_CTPC_algorithm(node_id)
#print(self.complete_graph_frame)
[ctpc_algo(n, total_vars_numb) for n in self.nodes] [ctpc_algo(n, total_vars_numb) for n in self.nodes]