1
0
Fork 0

Add plot_graph method

better_develop
Filippo Martini 4 years ago
parent b88288f8db
commit 5688e6103b
  1. 5
      PyCTBN/classes/sample_path.py
  2. 47
      PyCTBN/classes/structure_estimator.py
  3. 10
      PyCTBN/tests/test_structure_estimator.py

@ -66,6 +66,11 @@ class SamplePath(object):
def total_variables_count(self) -> int:
return self._total_variables_count
@property
def has_prior_net_structure(self) -> bool:
return bool(self._structure.edges)

@ -3,9 +3,11 @@ import itertools
import json
import typing
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from networkx.readwrite import json_graph
#import pygraphviz
from scipy.stats import chi2 as chi2_dist
from scipy.stats import f as f_dist
from tqdm import tqdm
@ -229,11 +231,54 @@ class StructureEstimator(object):
json.dump(res, f)
def adjacency_matrix(self) -> np.ndarray:
"""Converts the estimated structrure ``_complete_graph`` to a boolean adjacency matrix representation.
"""Converts the estimated structure ``_complete_graph`` to a boolean adjacency matrix representation.
:return: The adjacency matrix of the graph ``_complete_graph``
:rtype: numpy.ndArray
"""
return nx.adj_matrix(self._complete_graph).toarray().astype(bool)
def spurious_edges(self) -> typing.List:
"""Return the spurious edges present in the estimated structure, if a prior net structure is present in
``_sample_path.structure``.
:return: A list containing the spurious edges
:rtype: List
"""
if not self._sample_path.has_prior_net_structure:
raise RuntimeError("Can not compute spurious edges with no prior net structure!")
real_graph = nx.DiGraph()
real_graph.add_nodes_from(self._sample_path.structure.nodes_labels)
real_graph.add_edges_from(self._sample_path.structure.edges)
return nx.difference(real_graph, self._complete_graph).edges
def plot_estimated_structure_graph(self) -> None:
"""Plot the estimated structure in a graphical model style.
Spurious edges are colored in red.
"""
graph_to_draw = nx.DiGraph()
spurious_edges = self.spurious_edges()
non_spurious_edges = list(set(self._complete_graph.edges) - set(spurious_edges))
edges_colors = ['red' if edge in spurious_edges else 'black' for edge in self._complete_graph.edges]
graph_to_draw.add_edges_from(spurious_edges)
graph_to_draw.add_edges_from(non_spurious_edges)
pos = nx.spring_layout(graph_to_draw)
options = {
"node_size": 2500,
'linewidths':2,
"with_labels":True,
"font_size":13,
'connectionstyle': 'arc3, rad = 0.1',
"arrowsize": 15,
"arrowstyle": 'fancy',
"width": 1,
"edge_color":edges_colors,
}
nx.draw(graph_to_draw, pos, **options)
plt.show()

@ -20,7 +20,7 @@ class TestStructureEstimator(unittest.TestCase):
def setUpClass(cls):
cls.read_files = glob.glob(os.path.join('./data', "*.json"))
cls.importer = JsonImporter(cls.read_files[0], 'samples', 'dyn.str', 'variables', 'Time', 'Name')
cls.importer.import_data(0)
cls.importer.import_data(3)
cls.s1 = SamplePath(cls.importer)
cls.s1.build_trajectories()
cls.s1.build_structure()
@ -72,12 +72,8 @@ class TestStructureEstimator(unittest.TestCase):
print("Execution Time: ", exec_time)
for ed in self.s1.structure.edges:
self.assertIn(tuple(ed), se1._complete_graph.edges)
tuples_edges = [tuple(rec) for rec in self.s1.structure.edges]
spurious_edges = []
for ed in se1._complete_graph.edges:
if not(ed in tuples_edges):
spurious_edges.append(ed)
print("Spurious Edges:",spurious_edges)
print("Spurious Edges:", se1.spurious_edges())
se1.plot_estimated_structure_graph()
def test_save_results(self):
se1 = StructureEstimator(self.s1, 0.1, 0.1)