1
0
Fork 0

Modify Graph Layout in Plot

better_develop
Filippo Martini 4 years ago
parent 5688e6103b
commit 23e3ced196
  1. 24
      PyCTBN/classes/structure_estimator.py
  2. 4
      PyCTBN/tests/test_structure_estimator.py

@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from networkx.readwrite import json_graph
#import pygraphviz
import os
from scipy.stats import chi2 as chi2_dist
from scipy.stats import f as f_dist
from tqdm import tqdm
@ -252,7 +252,7 @@ class StructureEstimator(object):
real_graph.add_edges_from(self._sample_path.structure.edges)
return nx.difference(real_graph, self._complete_graph).edges
def plot_estimated_structure_graph(self) -> None:
def save_plot_estimated_structure_graph(self) -> None:
"""Plot the estimated structure in a graphical model style.
Spurious edges are colored in red.
"""
@ -262,21 +262,31 @@ class StructureEstimator(object):
edges_colors = ['red' if edge in spurious_edges else 'black' for edge in self._complete_graph.edges]
graph_to_draw.add_edges_from(spurious_edges)
graph_to_draw.add_edges_from(non_spurious_edges)
pos = nx.spring_layout(graph_to_draw)
pos = nx.spring_layout(graph_to_draw, k=5, scale=3)
options = {
"node_size": 2500,
"node_size": 2000,
"node_color": "white",
"edgecolors": "black",
'linewidths':2,
"with_labels":True,
"font_size":13,
'connectionstyle': 'arc3, rad = 0.1',
'connectionstyle': 'arc3, rad = 0.',
"arrowsize": 15,
"arrowstyle": 'fancy',
"arrowstyle": '<|-',
"width": 1,
"edge_color":edges_colors,
}
nx.draw(graph_to_draw, pos, **options)
plt.show()
ax = plt.gca()
ax.margins(0.20)
plt.axis("off")
name = self._sample_path._importer.file_path.rsplit('/', 1)[-1]
name = name.split('.', 1)[0]
name += '_' + str(self._sample_path._importer.dataset_id())
name += '.png'
plt.savefig(name)
print("Estimated Structure Plot Saved At: ", os.path.abspath(name))

@ -20,7 +20,7 @@ class TestStructureEstimator(unittest.TestCase):
def setUpClass(cls):
cls.read_files = glob.glob(os.path.join('./data', "*.json"))
cls.importer = JsonImporter(cls.read_files[0], 'samples', 'dyn.str', 'variables', 'Time', 'Name')
cls.importer.import_data(3)
cls.importer.import_data(1)
cls.s1 = SamplePath(cls.importer)
cls.s1.build_trajectories()
cls.s1.build_structure()
@ -73,7 +73,7 @@ class TestStructureEstimator(unittest.TestCase):
for ed in self.s1.structure.edges:
self.assertIn(tuple(ed), se1._complete_graph.edges)
print("Spurious Edges:", se1.spurious_edges())
se1.plot_estimated_structure_graph()
se1.save_plot_estimated_structure_graph()
def test_save_results(self):
se1 = StructureEstimator(self.s1, 0.1, 0.1)