1
0
Fork 0

Removed trajectories copies in ParametersEstimator class

parallel_struct_est
Filippo Martini 4 years ago
parent a589797c8c
commit 81eb5ac0f9
  1. 20
      PyCTBN/PyCTBN/parameters_estimator.py
  2. 14
      PyCTBN/PyCTBN/structure_estimator.py

@ -17,11 +17,11 @@ class ParametersEstimator:
:_single_set_of_cims: the set of cims object that will hold the cims of the node :_single_set_of_cims: the set of cims object that will hold the cims of the node
""" """
def __init__(self, times, trajectories, net_graph: NetworkGraph): def __init__(self, net_graph: NetworkGraph):
"""Constructor Method """Constructor Method
""" """
self._times = times #self._times = times
self._trajectories = trajectories #self._trajectories = trajectories
self._net_graph = net_graph self._net_graph = net_graph
self._single_set_of_cims = None self._single_set_of_cims = None
@ -35,7 +35,7 @@ class ParametersEstimator:
node_states_number = self._net_graph.get_states_number(node_id) node_states_number = self._net_graph.get_states_number(node_id)
self._single_set_of_cims = SetOfCims(node_id, p_vals, node_states_number, self._net_graph.p_combs) self._single_set_of_cims = SetOfCims(node_id, p_vals, node_states_number, self._net_graph.p_combs)
def compute_parameters_for_node(self, node_id: str) -> SetOfCims: def compute_parameters_for_node(self, node_id: str, times, trajectories) -> SetOfCims:
"""Compute the CIMS of the node identified by the label ``node_id``. """Compute the CIMS of the node identified by the label ``node_id``.
:param node_id: the node label :param node_id: the node label
@ -46,20 +46,21 @@ class ParametersEstimator:
node_indx = self._net_graph.get_node_indx(node_id) node_indx = self._net_graph.get_node_indx(node_id)
state_res_times = self._single_set_of_cims._state_residence_times state_res_times = self._single_set_of_cims._state_residence_times
transition_matrices = self._single_set_of_cims._transition_matrices transition_matrices = self._single_set_of_cims._transition_matrices
self.compute_state_res_time_for_node(node_indx, self._times, self.compute_state_res_time_for_node(node_indx, times,
self._trajectories, trajectories,
self._net_graph.time_filtering, self._net_graph.time_filtering,
self._net_graph.time_scalar_indexing_strucure, self._net_graph.time_scalar_indexing_strucure,
state_res_times) state_res_times)
self.compute_state_transitions_for_a_node(node_indx, self.compute_state_transitions_for_a_node(node_indx,
self._trajectories, trajectories,
self._net_graph.transition_filtering, self._net_graph.transition_filtering,
self._net_graph.transition_scalar_indexing_structure, self._net_graph.transition_scalar_indexing_structure,
transition_matrices) transition_matrices)
self._single_set_of_cims.build_cims(state_res_times, transition_matrices) self._single_set_of_cims.build_cims(state_res_times, transition_matrices)
return self._single_set_of_cims return self._single_set_of_cims
def compute_state_res_time_for_node(self, node_indx: int, times: np.ndarray, trajectory: np.ndarray, @staticmethod
def compute_state_res_time_for_node(node_indx: int, times: np.ndarray, trajectory: np.ndarray,
cols_filter: np.ndarray, scalar_indexes_struct: np.ndarray, cols_filter: np.ndarray, scalar_indexes_struct: np.ndarray,
T: np.ndarray) -> None: T: np.ndarray) -> None:
"""Compute the state residence times for a node and fill the matrix ``T`` with the results """Compute the state residence times for a node and fill the matrix ``T`` with the results
@ -82,7 +83,8 @@ class ParametersEstimator:
times, times,
minlength=scalar_indexes_struct[-1]).reshape(-1, T.shape[1]) minlength=scalar_indexes_struct[-1]).reshape(-1, T.shape[1])
def compute_state_transitions_for_a_node(self, node_indx: int, trajectory: np.ndarray, cols_filter: np.ndarray, @staticmethod
def compute_state_transitions_for_a_node(node_indx: int, trajectory: np.ndarray, cols_filter: np.ndarray,
scalar_indexing: np.ndarray, M: np.ndarray): scalar_indexing: np.ndarray, M: np.ndarray):
"""Compute the state residence times for a node and fill the matrices ``M`` with the results. """Compute the state residence times for a node and fill the matrices ``M`` with the results.

@ -124,9 +124,9 @@ class StructureEstimator:
s1 = Structure(l1, indxs1, vals1, eds1, tot_vars_count) s1 = Structure(l1, indxs1, vals1, eds1, tot_vars_count)
g1 = NetworkGraph(s1) g1 = NetworkGraph(s1)
g1.fast_init(test_child) g1.fast_init(test_child)
p1 = ParametersEstimator(self._times, self._trajectories, g1) p1 = ParametersEstimator(g1)
p1.fast_init(test_child) p1.fast_init(test_child)
sofc1 = p1.compute_parameters_for_node(test_child) sofc1 = p1.compute_parameters_for_node(test_child, self._times, self._trajectories)
cache.put(set(p_set), sofc1) cache.put(set(p_set), sofc1)
sofc2 = None sofc2 = None
p_set.insert(0, test_parent) p_set.insert(0, test_parent)
@ -142,9 +142,9 @@ class StructureEstimator:
s2 = Structure(l2, indxs2, vals2, eds2, tot_vars_count) s2 = Structure(l2, indxs2, vals2, eds2, tot_vars_count)
g2 = NetworkGraph(s2) g2 = NetworkGraph(s2)
g2.fast_init(test_child) g2.fast_init(test_child)
p2 = ParametersEstimator(self._times, self._trajectories, g2) p2 = ParametersEstimator(g2)
p2.fast_init(test_child) p2.fast_init(test_child)
sofc2 = p2.compute_parameters_for_node(test_child) sofc2 = p2.compute_parameters_for_node(test_child, self._times, self._trajectories)
cache.put(set(p_set), sofc2) cache.put(set(p_set), sofc2)
for cim1, p_comb in zip(sofc1.actual_cims, sofc1.p_combs): for cim1, p_comb in zip(sofc1.actual_cims, sofc1.p_combs):
cond_cims = sofc2.filter_cims_with_mask(cims_filter, p_comb) cond_cims = sofc2.filter_cims_with_mask(cims_filter, p_comb)
@ -204,7 +204,7 @@ class StructureEstimator:
u = list(self._complete_graph.predecessors(var_id)) u = list(self._complete_graph.predecessors(var_id))
#child_states_numb = self._sample_path.structure.get_states_number(var_id) #child_states_numb = self._sample_path.structure.get_states_number(var_id)
child_states_numb = self._nodes_vals[np.where(self._nodes == var_id)][0] child_states_numb = self._nodes_vals[np.where(self._nodes == var_id)][0]
print("Child States Numb", child_states_numb) #print("Child States Numb", child_states_numb)
b = 0 b = 0
while b < len(u): while b < len(u):
parent_indx = 0 parent_indx = 0
@ -221,7 +221,7 @@ class StructureEstimator:
if not removed: if not removed:
parent_indx += 1 parent_indx += 1
b += 1 b += 1
print("Parent set for node ", var_id, " : ", u) #print("Parent set for node ", var_id, " : ", u)
cache.clear() cache.clear()
return u return u
@ -251,7 +251,7 @@ class StructureEstimator:
total_vars_numb_list = [self._tot_vars_number] * len(self._nodes) total_vars_numb_list = [self._tot_vars_number] * len(self._nodes)
cpu_count = multiprocessing.cpu_count() cpu_count = multiprocessing.cpu_count()
print("CPU COUNT", cpu_count) print("CPU COUNT", cpu_count)
with multiprocessing.Pool(processes=cpu_count) as pool: with multiprocessing.Pool(processes=1) as pool:
parent_sets = pool.starmap(ctpc_algo, zip(self._nodes, self._caches, total_vars_numb_list)) parent_sets = pool.starmap(ctpc_algo, zip(self._nodes, self._caches, total_vars_numb_list))
#parent_sets = [ctpc_algo(n, c, total_vars_numb) for n, c in tqdm(zip(self._nodes, self._caches))] #parent_sets = [ctpc_algo(n, c, total_vars_numb) for n, c in tqdm(zip(self._nodes, self._caches))]
print(parent_sets) print(parent_sets)