1
0
Fork 0

Add refactors in Structure Estimator class

parallel_struct_est
philpMartin 4 years ago
parent f373a9fffc
commit a020b9c004
  1. 13
      main_package/classes/network_graph.py
  2. 4
      main_package/classes/structure.py
  3. 12
      main_package/classes/structure_estimator.py
  4. 2
      main_package/tests/test_structure_estimator.py

@ -3,7 +3,6 @@ import networkx as nx
import numpy as np
class NetworkGraph():
"""
Rappresenta il grafo che contiene i nodi e gli archi presenti nell'oggetto Structure graph_struct.
@ -196,16 +195,4 @@ class NetworkGraph():
def transition_filtering(self):
return self._transition_filtering
"""def remove_node(self, node_id):
node_indx = self.get_node_indx(node_id)
self.graph_struct.remove_node(node_id)
self.graph.remove_node(node_id)
del self._fancy_indexing[node_indx]
del self._time_filtering[node_indx]
del self._nodes_labels[node_indx]
del self._transition_scalar_indexing_structure[node_indx]
del self._transition_filtering[node_indx]
del self._time_scalar_indexing_structure[node_indx]
del self.aggregated_info_about_nodes_parents[node_indx]
del self._nodes_indexes[node_indx]"""

@ -66,7 +66,3 @@ class Structure:
self.variables_frame.equals(other.variables_frame)
return NotImplemented
"""def remove_node(self, node_id):
self.variables_frame = self.variables_frame[self.variables_frame.Name != node_id]
self.structure_frame = self.structure_frame[(self.structure_frame.From != node_id) &
(self.structure_frame.To != node_id)]"""

@ -15,7 +15,6 @@ class StructureEstimator:
def __init__(self, sample_path, exp_test_alfa, chi_test_alfa):
self.sample_path = sample_path
#self.complete_graph_frame = self.build_complete_graph_frame(self.sample_path.structure.list_of_nodes_labels())
self.nodes = np.array(self.sample_path.structure.list_of_nodes_labels())
self.nodes_vals = self.sample_path.structure.nodes_vals_arr
self.nodes_indxs = self.sample_path.structure.nodes_indexes_arr
@ -34,7 +33,7 @@ class StructureEstimator:
complete_graph.add_nodes_from(node_ids)
complete_graph.add_edges_from(itertools.permutations(node_ids, 2))
return complete_graph
#TODO Tutti i valori che riguardano il test child possono essere settati una volta sola
def complete_test(self, test_parent, test_child, parent_set, child_states_numb, tot_vars_count):
p_set = parent_set[:]
complete_info = parent_set[:]
@ -50,7 +49,6 @@ class StructureEstimator:
indxs1 = self.nodes_indxs[bool_mask1]
vals1 = self.nodes_vals[bool_mask1]
eds1 = list(itertools.product(parent_set,test_child))
#TODO il numero di variabili puo essere passato dall'esterno
s1 = st.Structure(l1, indxs1, vals1, eds1, tot_vars_count)
g1 = ng.NetworkGraph(s1)
g1.init_graph()
@ -116,10 +114,10 @@ class StructureEstimator:
C1 = cim1.cim
C2 = cim2.cim
F_stats = C2.diagonal() / C1.diagonal()
#child_states_numb = self.sample_path.structure.get_states_number(tested_child)
exp_alfa = self.exp_test_sign
for val in range(0, child_states_numb):
if F_stats[val] < f_dist.ppf(self.exp_test_sign / 2, r1s[val], r2s[val]) or \
F_stats[val] > f_dist.ppf(1 - self.exp_test_sign / 2, r1s[val], r2s[val]):
if F_stats[val] < f_dist.ppf(exp_alfa / 2, r1s[val], r2s[val]) or \
F_stats[val] > f_dist.ppf(1 - exp_alfa / 2, r1s[val], r2s[val]):
#print("CONDITIONALLY DEPENDENT EXP")
return False
#M1_no_diag = self.remove_diagonal_elements(cim1.state_transition_matrix)
@ -151,7 +149,7 @@ class StructureEstimator:
return True
def one_iteration_of_CTPC_algorithm(self, var_id, tot_vars_count):
print("TESTING VAR", var_id)
#print("TESTING VAR", var_id)
u = list(self.complete_graph.predecessors(var_id))
#tests_parents_numb = len(u)
#complete_frame = self.complete_graph_frame

@ -30,7 +30,7 @@ class TestStructureEstimator(unittest.TestCase):
lp_wrapper()
lp.print_stats()
#se1.ctpc_algorithm()
#print(se1.complete_graph.edges)
print(se1.complete_graph.edges)
def aux_test_complete_test(self, estimator, test_par, test_child, p_set):
estimator.complete_test(test_par, test_child, p_set)