1
0
Fork 0

Add plot_graph method

better_develop
Filippo Martini 4 years ago
parent b88288f8db
commit 5688e6103b
  1. 5
      PyCTBN/classes/sample_path.py
  2. 47
      PyCTBN/classes/structure_estimator.py
  3. 10
      PyCTBN/tests/test_structure_estimator.py

@ -66,6 +66,11 @@ class SamplePath(object):
def total_variables_count(self) -> int: def total_variables_count(self) -> int:
return self._total_variables_count return self._total_variables_count
@property
def has_prior_net_structure(self) -> bool:
return bool(self._structure.edges)

@ -3,9 +3,11 @@ import itertools
import json import json
import typing import typing
import matplotlib.pyplot as plt
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
#import pygraphviz
from scipy.stats import chi2 as chi2_dist from scipy.stats import chi2 as chi2_dist
from scipy.stats import f as f_dist from scipy.stats import f as f_dist
from tqdm import tqdm from tqdm import tqdm
@ -229,11 +231,54 @@ class StructureEstimator(object):
json.dump(res, f) json.dump(res, f)
def adjacency_matrix(self) -> np.ndarray: def adjacency_matrix(self) -> np.ndarray:
"""Converts the estimated structrure ``_complete_graph`` to a boolean adjacency matrix representation. """Converts the estimated structure ``_complete_graph`` to a boolean adjacency matrix representation.
:return: The adjacency matrix of the graph ``_complete_graph`` :return: The adjacency matrix of the graph ``_complete_graph``
:rtype: numpy.ndArray :rtype: numpy.ndArray
""" """
return nx.adj_matrix(self._complete_graph).toarray().astype(bool) return nx.adj_matrix(self._complete_graph).toarray().astype(bool)
def spurious_edges(self) -> typing.List:
"""Return the spurious edges present in the estimated structure, if a prior net structure is present in
``_sample_path.structure``.
:return: A list containing the spurious edges
:rtype: List
"""
if not self._sample_path.has_prior_net_structure:
raise RuntimeError("Can not compute spurious edges with no prior net structure!")
real_graph = nx.DiGraph()
real_graph.add_nodes_from(self._sample_path.structure.nodes_labels)
real_graph.add_edges_from(self._sample_path.structure.edges)
return nx.difference(real_graph, self._complete_graph).edges
def plot_estimated_structure_graph(self) -> None:
"""Plot the estimated structure in a graphical model style.
Spurious edges are colored in red.
"""
graph_to_draw = nx.DiGraph()
spurious_edges = self.spurious_edges()
non_spurious_edges = list(set(self._complete_graph.edges) - set(spurious_edges))
edges_colors = ['red' if edge in spurious_edges else 'black' for edge in self._complete_graph.edges]
graph_to_draw.add_edges_from(spurious_edges)
graph_to_draw.add_edges_from(non_spurious_edges)
pos = nx.spring_layout(graph_to_draw)
options = {
"node_size": 2500,
'linewidths':2,
"with_labels":True,
"font_size":13,
'connectionstyle': 'arc3, rad = 0.1',
"arrowsize": 15,
"arrowstyle": 'fancy',
"width": 1,
"edge_color":edges_colors,
}
nx.draw(graph_to_draw, pos, **options)
plt.show()

@ -20,7 +20,7 @@ class TestStructureEstimator(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
cls.read_files = glob.glob(os.path.join('./data', "*.json")) cls.read_files = glob.glob(os.path.join('./data', "*.json"))
cls.importer = JsonImporter(cls.read_files[0], 'samples', 'dyn.str', 'variables', 'Time', 'Name') cls.importer = JsonImporter(cls.read_files[0], 'samples', 'dyn.str', 'variables', 'Time', 'Name')
cls.importer.import_data(0) cls.importer.import_data(3)
cls.s1 = SamplePath(cls.importer) cls.s1 = SamplePath(cls.importer)
cls.s1.build_trajectories() cls.s1.build_trajectories()
cls.s1.build_structure() cls.s1.build_structure()
@ -72,12 +72,8 @@ class TestStructureEstimator(unittest.TestCase):
print("Execution Time: ", exec_time) print("Execution Time: ", exec_time)
for ed in self.s1.structure.edges: for ed in self.s1.structure.edges:
self.assertIn(tuple(ed), se1._complete_graph.edges) self.assertIn(tuple(ed), se1._complete_graph.edges)
tuples_edges = [tuple(rec) for rec in self.s1.structure.edges] print("Spurious Edges:", se1.spurious_edges())
spurious_edges = [] se1.plot_estimated_structure_graph()
for ed in se1._complete_graph.edges:
if not(ed in tuples_edges):
spurious_edges.append(ed)
print("Spurious Edges:",spurious_edges)
def test_save_results(self): def test_save_results(self):
se1 = StructureEstimator(self.s1, 0.1, 0.1) se1 = StructureEstimator(self.s1, 0.1, 0.1)