1
0
Fork 0

Removed trajectories copies in ParametersEstimator class

parallel_struct_est
Filippo Martini 4 years ago
parent a589797c8c
commit 81eb5ac0f9
  1. 20
      PyCTBN/PyCTBN/parameters_estimator.py
  2. 14
      PyCTBN/PyCTBN/structure_estimator.py

@ -17,11 +17,11 @@ class ParametersEstimator:
:_single_set_of_cims: the set of cims object that will hold the cims of the node
"""
def __init__(self, times, trajectories, net_graph: NetworkGraph):
def __init__(self, net_graph: NetworkGraph):
"""Constructor Method
"""
self._times = times
self._trajectories = trajectories
#self._times = times
#self._trajectories = trajectories
self._net_graph = net_graph
self._single_set_of_cims = None
@ -35,7 +35,7 @@ class ParametersEstimator:
node_states_number = self._net_graph.get_states_number(node_id)
self._single_set_of_cims = SetOfCims(node_id, p_vals, node_states_number, self._net_graph.p_combs)
def compute_parameters_for_node(self, node_id: str) -> SetOfCims:
def compute_parameters_for_node(self, node_id: str, times, trajectories) -> SetOfCims:
"""Compute the CIMS of the node identified by the label ``node_id``.
:param node_id: the node label
@ -46,20 +46,21 @@ class ParametersEstimator:
node_indx = self._net_graph.get_node_indx(node_id)
state_res_times = self._single_set_of_cims._state_residence_times
transition_matrices = self._single_set_of_cims._transition_matrices
self.compute_state_res_time_for_node(node_indx, self._times,
self._trajectories,
self.compute_state_res_time_for_node(node_indx, times,
trajectories,
self._net_graph.time_filtering,
self._net_graph.time_scalar_indexing_strucure,
state_res_times)
self.compute_state_transitions_for_a_node(node_indx,
self._trajectories,
trajectories,
self._net_graph.transition_filtering,
self._net_graph.transition_scalar_indexing_structure,
transition_matrices)
self._single_set_of_cims.build_cims(state_res_times, transition_matrices)
return self._single_set_of_cims
def compute_state_res_time_for_node(self, node_indx: int, times: np.ndarray, trajectory: np.ndarray,
@staticmethod
def compute_state_res_time_for_node(node_indx: int, times: np.ndarray, trajectory: np.ndarray,
cols_filter: np.ndarray, scalar_indexes_struct: np.ndarray,
T: np.ndarray) -> None:
"""Compute the state residence times for a node and fill the matrix ``T`` with the results
@ -82,7 +83,8 @@ class ParametersEstimator:
times,
minlength=scalar_indexes_struct[-1]).reshape(-1, T.shape[1])
def compute_state_transitions_for_a_node(self, node_indx: int, trajectory: np.ndarray, cols_filter: np.ndarray,
@staticmethod
def compute_state_transitions_for_a_node(node_indx: int, trajectory: np.ndarray, cols_filter: np.ndarray,
scalar_indexing: np.ndarray, M: np.ndarray):
"""Compute the state residence times for a node and fill the matrices ``M`` with the results.

@ -124,9 +124,9 @@ class StructureEstimator:
s1 = Structure(l1, indxs1, vals1, eds1, tot_vars_count)
g1 = NetworkGraph(s1)
g1.fast_init(test_child)
p1 = ParametersEstimator(self._times, self._trajectories, g1)
p1 = ParametersEstimator(g1)
p1.fast_init(test_child)
sofc1 = p1.compute_parameters_for_node(test_child)
sofc1 = p1.compute_parameters_for_node(test_child, self._times, self._trajectories)
cache.put(set(p_set), sofc1)
sofc2 = None
p_set.insert(0, test_parent)
@ -142,9 +142,9 @@ class StructureEstimator:
s2 = Structure(l2, indxs2, vals2, eds2, tot_vars_count)
g2 = NetworkGraph(s2)
g2.fast_init(test_child)
p2 = ParametersEstimator(self._times, self._trajectories, g2)
p2 = ParametersEstimator(g2)
p2.fast_init(test_child)
sofc2 = p2.compute_parameters_for_node(test_child)
sofc2 = p2.compute_parameters_for_node(test_child, self._times, self._trajectories)
cache.put(set(p_set), sofc2)
for cim1, p_comb in zip(sofc1.actual_cims, sofc1.p_combs):
cond_cims = sofc2.filter_cims_with_mask(cims_filter, p_comb)
@ -204,7 +204,7 @@ class StructureEstimator:
u = list(self._complete_graph.predecessors(var_id))
#child_states_numb = self._sample_path.structure.get_states_number(var_id)
child_states_numb = self._nodes_vals[np.where(self._nodes == var_id)][0]
print("Child States Numb", child_states_numb)
#print("Child States Numb", child_states_numb)
b = 0
while b < len(u):
parent_indx = 0
@ -221,7 +221,7 @@ class StructureEstimator:
if not removed:
parent_indx += 1
b += 1
print("Parent set for node ", var_id, " : ", u)
#print("Parent set for node ", var_id, " : ", u)
cache.clear()
return u
@ -251,7 +251,7 @@ class StructureEstimator:
total_vars_numb_list = [self._tot_vars_number] * len(self._nodes)
cpu_count = multiprocessing.cpu_count()
print("CPU COUNT", cpu_count)
with multiprocessing.Pool(processes=cpu_count) as pool:
with multiprocessing.Pool(processes=1) as pool:
parent_sets = pool.starmap(ctpc_algo, zip(self._nodes, self._caches, total_vars_numb_list))
#parent_sets = [ctpc_algo(n, c, total_vars_numb) for n, c in tqdm(zip(self._nodes, self._caches))]
print(parent_sets)