diff --git a/PyCTBN/classes/sample_path.py b/PyCTBN/classes/sample_path.py index 2f7fe9b..93d1b72 100644 --- a/PyCTBN/classes/sample_path.py +++ b/PyCTBN/classes/sample_path.py @@ -66,6 +66,11 @@ class SamplePath(object): def total_variables_count(self) -> int: return self._total_variables_count + @property + def has_prior_net_structure(self) -> bool: + return bool(self._structure.edges) + + diff --git a/PyCTBN/classes/structure_estimator.py b/PyCTBN/classes/structure_estimator.py index 0ade97b..4b4db27 100644 --- a/PyCTBN/classes/structure_estimator.py +++ b/PyCTBN/classes/structure_estimator.py @@ -3,9 +3,11 @@ import itertools import json import typing +import matplotlib.pyplot as plt import networkx as nx import numpy as np from networkx.readwrite import json_graph +#import pygraphviz from scipy.stats import chi2 as chi2_dist from scipy.stats import f as f_dist from tqdm import tqdm @@ -229,11 +231,54 @@ class StructureEstimator(object): json.dump(res, f) def adjacency_matrix(self) -> np.ndarray: - """Converts the estimated structrure ``_complete_graph`` to a boolean adjacency matrix representation. + """Converts the estimated structure ``_complete_graph`` to a boolean adjacency matrix representation. :return: The adjacency matrix of the graph ``_complete_graph`` :rtype: numpy.ndArray """ return nx.adj_matrix(self._complete_graph).toarray().astype(bool) + def spurious_edges(self) -> typing.List: + """Return the spurious edges present in the estimated structure, if a prior net structure is present in + ``_sample_path.structure``. + + :return: A list containing the spurious edges + :rtype: List + """ + if not self._sample_path.has_prior_net_structure: + raise RuntimeError("Can not compute spurious edges with no prior net structure!") + real_graph = nx.DiGraph() + real_graph.add_nodes_from(self._sample_path.structure.nodes_labels) + real_graph.add_edges_from(self._sample_path.structure.edges) + return nx.difference(real_graph, self._complete_graph).edges + + def plot_estimated_structure_graph(self) -> None: + """Plot the estimated structure in a graphical model style. + Spurious edges are colored in red. + """ + graph_to_draw = nx.DiGraph() + spurious_edges = self.spurious_edges() + non_spurious_edges = list(set(self._complete_graph.edges) - set(spurious_edges)) + edges_colors = ['red' if edge in spurious_edges else 'black' for edge in self._complete_graph.edges] + graph_to_draw.add_edges_from(spurious_edges) + graph_to_draw.add_edges_from(non_spurious_edges) + pos = nx.spring_layout(graph_to_draw) + options = { + "node_size": 2500, + 'linewidths':2, + "with_labels":True, + "font_size":13, + 'connectionstyle': 'arc3, rad = 0.1', + "arrowsize": 15, + "arrowstyle": 'fancy', + "width": 1, + "edge_color":edges_colors, + } + + nx.draw(graph_to_draw, pos, **options) + plt.show() + + + + diff --git a/PyCTBN/tests/test_structure_estimator.py b/PyCTBN/tests/test_structure_estimator.py index 5e98fa0..fd3b9ea 100644 --- a/PyCTBN/tests/test_structure_estimator.py +++ b/PyCTBN/tests/test_structure_estimator.py @@ -20,7 +20,7 @@ class TestStructureEstimator(unittest.TestCase): def setUpClass(cls): cls.read_files = glob.glob(os.path.join('./data', "*.json")) cls.importer = JsonImporter(cls.read_files[0], 'samples', 'dyn.str', 'variables', 'Time', 'Name') - cls.importer.import_data(0) + cls.importer.import_data(3) cls.s1 = SamplePath(cls.importer) cls.s1.build_trajectories() cls.s1.build_structure() @@ -72,12 +72,8 @@ class TestStructureEstimator(unittest.TestCase): print("Execution Time: ", exec_time) for ed in self.s1.structure.edges: self.assertIn(tuple(ed), se1._complete_graph.edges) - tuples_edges = [tuple(rec) for rec in self.s1.structure.edges] - spurious_edges = [] - for ed in se1._complete_graph.edges: - if not(ed in tuples_edges): - spurious_edges.append(ed) - print("Spurious Edges:",spurious_edges) + print("Spurious Edges:", se1.spurious_edges()) + se1.plot_estimated_structure_graph() def test_save_results(self): se1 = StructureEstimator(self.s1, 0.1, 0.1)