|
|
@ -8,6 +8,8 @@ import numpy as np |
|
|
|
from networkx.readwrite import json_graph |
|
|
|
from networkx.readwrite import json_graph |
|
|
|
from scipy.stats import chi2 as chi2_dist |
|
|
|
from scipy.stats import chi2 as chi2_dist |
|
|
|
from scipy.stats import f as f_dist |
|
|
|
from scipy.stats import f as f_dist |
|
|
|
|
|
|
|
import multiprocessing |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .cache import Cache |
|
|
|
from .cache import Cache |
|
|
|
from .conditional_intensity_matrix import ConditionalIntensityMatrix |
|
|
|
from .conditional_intensity_matrix import ConditionalIntensityMatrix |
|
|
@ -43,7 +45,8 @@ class StructureEstimator: |
|
|
|
self._complete_graph = self.build_complete_graph(self._sample_path.structure.nodes_labels) |
|
|
|
self._complete_graph = self.build_complete_graph(self._sample_path.structure.nodes_labels) |
|
|
|
self._exp_test_sign = exp_test_alfa |
|
|
|
self._exp_test_sign = exp_test_alfa |
|
|
|
self._chi_test_alfa = chi_test_alfa |
|
|
|
self._chi_test_alfa = chi_test_alfa |
|
|
|
self._cache = Cache() |
|
|
|
self._caches = [Cache() for _ in range(len(self._nodes))] |
|
|
|
|
|
|
|
self._result_graph = None |
|
|
|
|
|
|
|
|
|
|
|
def build_complete_graph(self, node_ids: typing.List) -> nx.DiGraph: |
|
|
|
def build_complete_graph(self, node_ids: typing.List) -> nx.DiGraph: |
|
|
|
"""Builds a complete directed graph (no self loops) given the nodes labels in the list ``node_ids``: |
|
|
|
"""Builds a complete directed graph (no self loops) given the nodes labels in the list ``node_ids``: |
|
|
@ -58,8 +61,21 @@ class StructureEstimator: |
|
|
|
complete_graph.add_edges_from(itertools.permutations(node_ids, 2)) |
|
|
|
complete_graph.add_edges_from(itertools.permutations(node_ids, 2)) |
|
|
|
return complete_graph |
|
|
|
return complete_graph |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_edges_for_node(self, child_id: str, parents_ids: typing.List): |
|
|
|
|
|
|
|
edges = [(parent, child_id) for parent in parents_ids] |
|
|
|
|
|
|
|
return edges |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_result_graph(self, nodes_ids: typing.List, parent_sets: typing.List): |
|
|
|
|
|
|
|
edges = [] |
|
|
|
|
|
|
|
for node_id, parent_set in zip(nodes_ids, parent_sets): |
|
|
|
|
|
|
|
edges += self.build_edges_for_node(node_id, parent_set) |
|
|
|
|
|
|
|
result_graph = nx.DiGraph() |
|
|
|
|
|
|
|
result_graph.add_nodes_from(nodes_ids) |
|
|
|
|
|
|
|
result_graph.add_edges_from(edges) |
|
|
|
|
|
|
|
return result_graph |
|
|
|
|
|
|
|
|
|
|
|
def complete_test(self, test_parent: str, test_child: str, parent_set: typing.List, child_states_numb: int, |
|
|
|
def complete_test(self, test_parent: str, test_child: str, parent_set: typing.List, child_states_numb: int, |
|
|
|
tot_vars_count: int) -> bool: |
|
|
|
tot_vars_count: int, cache: Cache) -> bool: |
|
|
|
"""Performs a complete independence test on the directed graphs G1 = {test_child U parent_set} |
|
|
|
"""Performs a complete independence test on the directed graphs G1 = {test_child U parent_set} |
|
|
|
G2 = {G1 U test_parent} (added as an additional parent of the test_child). |
|
|
|
G2 = {G1 U test_parent} (added as an additional parent of the test_child). |
|
|
|
Generates all the necessary structures and datas to perform the tests. |
|
|
|
Generates all the necessary structures and datas to perform the tests. |
|
|
@ -85,7 +101,7 @@ class StructureEstimator: |
|
|
|
parents = np.append(parents, test_parent) |
|
|
|
parents = np.append(parents, test_parent) |
|
|
|
sorted_parents = self._nodes[np.isin(self._nodes, parents)] |
|
|
|
sorted_parents = self._nodes[np.isin(self._nodes, parents)] |
|
|
|
cims_filter = sorted_parents != test_parent |
|
|
|
cims_filter = sorted_parents != test_parent |
|
|
|
sofc1 = self._cache.find(set(p_set)) |
|
|
|
sofc1 = cache.find(set(p_set)) |
|
|
|
|
|
|
|
|
|
|
|
if not sofc1: |
|
|
|
if not sofc1: |
|
|
|
bool_mask1 = np.isin(self._nodes, complete_info) |
|
|
|
bool_mask1 = np.isin(self._nodes, complete_info) |
|
|
@ -99,11 +115,11 @@ class StructureEstimator: |
|
|
|
p1 = ParametersEstimator(self._sample_path.trajectories, g1) |
|
|
|
p1 = ParametersEstimator(self._sample_path.trajectories, g1) |
|
|
|
p1.fast_init(test_child) |
|
|
|
p1.fast_init(test_child) |
|
|
|
sofc1 = p1.compute_parameters_for_node(test_child) |
|
|
|
sofc1 = p1.compute_parameters_for_node(test_child) |
|
|
|
self._cache.put(set(p_set), sofc1) |
|
|
|
cache.put(set(p_set), sofc1) |
|
|
|
sofc2 = None |
|
|
|
sofc2 = None |
|
|
|
p_set.insert(0, test_parent) |
|
|
|
p_set.insert(0, test_parent) |
|
|
|
if p_set: |
|
|
|
if p_set: |
|
|
|
sofc2 = self._cache.find(set(p_set)) |
|
|
|
sofc2 = cache.find(set(p_set)) |
|
|
|
if not sofc2: |
|
|
|
if not sofc2: |
|
|
|
complete_info.append(test_parent) |
|
|
|
complete_info.append(test_parent) |
|
|
|
bool_mask2 = np.isin(self._nodes, complete_info) |
|
|
|
bool_mask2 = np.isin(self._nodes, complete_info) |
|
|
@ -117,7 +133,7 @@ class StructureEstimator: |
|
|
|
p2 = ParametersEstimator(self._sample_path.trajectories, g2) |
|
|
|
p2 = ParametersEstimator(self._sample_path.trajectories, g2) |
|
|
|
p2.fast_init(test_child) |
|
|
|
p2.fast_init(test_child) |
|
|
|
sofc2 = p2.compute_parameters_for_node(test_child) |
|
|
|
sofc2 = p2.compute_parameters_for_node(test_child) |
|
|
|
self._cache.put(set(p_set), sofc2) |
|
|
|
cache.put(set(p_set), sofc2) |
|
|
|
for cim1, p_comb in zip(sofc1.actual_cims, sofc1.p_combs): |
|
|
|
for cim1, p_comb in zip(sofc1.actual_cims, sofc1.p_combs): |
|
|
|
cond_cims = sofc2.filter_cims_with_mask(cims_filter, p_comb) |
|
|
|
cond_cims = sofc2.filter_cims_with_mask(cims_filter, p_comb) |
|
|
|
for cim2 in cond_cims: |
|
|
|
for cim2 in cond_cims: |
|
|
@ -165,7 +181,7 @@ class StructureEstimator: |
|
|
|
return False |
|
|
|
return False |
|
|
|
return True |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def one_iteration_of_CTPC_algorithm(self, var_id: str, tot_vars_count: int) -> None: |
|
|
|
def one_iteration_of_CTPC_algorithm(self, var_id: str, cache: Cache, tot_vars_count: int) -> typing.List: |
|
|
|
"""Performs an iteration of the CTPC algorithm using the node ``var_id`` as ``test_child``. |
|
|
|
"""Performs an iteration of the CTPC algorithm using the node ``var_id`` as ``test_child``. |
|
|
|
|
|
|
|
|
|
|
|
:param var_id: the node label of the test child |
|
|
|
:param var_id: the node label of the test child |
|
|
@ -183,15 +199,17 @@ class StructureEstimator: |
|
|
|
S = self.generate_possible_sub_sets_of_size(u, b, u[parent_indx]) |
|
|
|
S = self.generate_possible_sub_sets_of_size(u, b, u[parent_indx]) |
|
|
|
test_parent = u[parent_indx] |
|
|
|
test_parent = u[parent_indx] |
|
|
|
for parents_set in S: |
|
|
|
for parents_set in S: |
|
|
|
if self.complete_test(test_parent, var_id, parents_set, child_states_numb, tot_vars_count): |
|
|
|
if self.complete_test(test_parent, var_id, parents_set, child_states_numb, tot_vars_count, cache): |
|
|
|
self._complete_graph.remove_edge(test_parent, var_id) |
|
|
|
#self._complete_graph.remove_edge(test_parent, var_id) |
|
|
|
u.remove(test_parent) |
|
|
|
u.remove(test_parent) |
|
|
|
removed = True |
|
|
|
removed = True |
|
|
|
break |
|
|
|
break |
|
|
|
if not removed: |
|
|
|
if not removed: |
|
|
|
parent_indx += 1 |
|
|
|
parent_indx += 1 |
|
|
|
b += 1 |
|
|
|
b += 1 |
|
|
|
self._cache.clear() |
|
|
|
print("Parent set for node ", var_id, " : ", u) |
|
|
|
|
|
|
|
cache.clear() |
|
|
|
|
|
|
|
return u |
|
|
|
|
|
|
|
|
|
|
|
def generate_possible_sub_sets_of_size(self, u: typing.List, size: int, parent_label: str) -> \ |
|
|
|
def generate_possible_sub_sets_of_size(self, u: typing.List, size: int, parent_label: str) -> \ |
|
|
|
typing.Iterator: |
|
|
|
typing.Iterator: |
|
|
@ -216,13 +234,20 @@ class StructureEstimator: |
|
|
|
""" |
|
|
|
""" |
|
|
|
ctpc_algo = self.one_iteration_of_CTPC_algorithm |
|
|
|
ctpc_algo = self.one_iteration_of_CTPC_algorithm |
|
|
|
total_vars_numb = self._sample_path.total_variables_count |
|
|
|
total_vars_numb = self._sample_path.total_variables_count |
|
|
|
[ctpc_algo(n, total_vars_numb) for n in tqdm(self._nodes)] |
|
|
|
total_vars_numb_list = [total_vars_numb for _ in range(total_vars_numb)] |
|
|
|
|
|
|
|
cpu_count = multiprocessing.cpu_count() |
|
|
|
|
|
|
|
print("CPU COUNT", cpu_count) |
|
|
|
|
|
|
|
with multiprocessing.Pool(processes=cpu_count) as pool: |
|
|
|
|
|
|
|
parent_sets = pool.starmap(ctpc_algo, zip(self._nodes, self._caches, total_vars_numb_list)) |
|
|
|
|
|
|
|
#parent_sets = [ctpc_algo(n, c, total_vars_numb) for n, c in tqdm(zip(self._nodes, self._caches))] |
|
|
|
|
|
|
|
print(parent_sets) |
|
|
|
|
|
|
|
self._result_graph = self.build_result_graph(self._nodes, parent_sets) |
|
|
|
|
|
|
|
|
|
|
|
def save_results(self) -> None: |
|
|
|
def save_results(self) -> None: |
|
|
|
"""Save the estimated Structure to a .json file in the path where the data are loaded from. |
|
|
|
"""Save the estimated Structure to a .json file in the path where the data are loaded from. |
|
|
|
The file is named as the input dataset but the `results_` word is appended to the results file. |
|
|
|
The file is named as the input dataset but the `results_` word is appended to the results file. |
|
|
|
""" |
|
|
|
""" |
|
|
|
res = json_graph.node_link_data(self._complete_graph) |
|
|
|
res = json_graph.node_link_data(self._result_graph_graph) |
|
|
|
name = self._sample_path._importer.file_path.rsplit('/', 1)[-1] + str(self._sample_path._importer.dataset_id()) |
|
|
|
name = self._sample_path._importer.file_path.rsplit('/', 1)[-1] + str(self._sample_path._importer.dataset_id()) |
|
|
|
name = 'results_' + name |
|
|
|
name = 'results_' + name |
|
|
|
with open(name, 'w') as f: |
|
|
|
with open(name, 'w') as f: |
|
|
@ -234,6 +259,6 @@ class StructureEstimator: |
|
|
|
:return: The adjacency matrix of the graph ``_complete_graph`` |
|
|
|
:return: The adjacency matrix of the graph ``_complete_graph`` |
|
|
|
:rtype: numpy.ndArray |
|
|
|
:rtype: numpy.ndArray |
|
|
|
""" |
|
|
|
""" |
|
|
|
return nx.adj_matrix(self._complete_graph).toarray().astype(bool) |
|
|
|
return nx.adj_matrix(self._result_graph).toarray().astype(bool) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|