From 23e3ced19631bcb70539cb87011bd276a26cd775 Mon Sep 17 00:00:00 2001 From: Filippo Martini Date: Tue, 22 Dec 2020 19:46:55 +0100 Subject: [PATCH] Modify Graph Layout in Plot --- PyCTBN/classes/structure_estimator.py | 24 +++++++++++++++++------- PyCTBN/tests/test_structure_estimator.py | 4 ++-- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/PyCTBN/classes/structure_estimator.py b/PyCTBN/classes/structure_estimator.py index 4b4db27..48b137c 100644 --- a/PyCTBN/classes/structure_estimator.py +++ b/PyCTBN/classes/structure_estimator.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import networkx as nx import numpy as np from networkx.readwrite import json_graph -#import pygraphviz +import os from scipy.stats import chi2 as chi2_dist from scipy.stats import f as f_dist from tqdm import tqdm @@ -252,7 +252,7 @@ class StructureEstimator(object): real_graph.add_edges_from(self._sample_path.structure.edges) return nx.difference(real_graph, self._complete_graph).edges - def plot_estimated_structure_graph(self) -> None: + def save_plot_estimated_structure_graph(self) -> None: """Plot the estimated structure in a graphical model style. Spurious edges are colored in red. """ @@ -262,21 +262,31 @@ class StructureEstimator(object): edges_colors = ['red' if edge in spurious_edges else 'black' for edge in self._complete_graph.edges] graph_to_draw.add_edges_from(spurious_edges) graph_to_draw.add_edges_from(non_spurious_edges) - pos = nx.spring_layout(graph_to_draw) + pos = nx.spring_layout(graph_to_draw, k=5, scale=3) options = { - "node_size": 2500, + "node_size": 2000, + "node_color": "white", + "edgecolors": "black", 'linewidths':2, "with_labels":True, "font_size":13, - 'connectionstyle': 'arc3, rad = 0.1', + 'connectionstyle': 'arc3, rad = 0.', "arrowsize": 15, - "arrowstyle": 'fancy', + "arrowstyle": '<|-', "width": 1, "edge_color":edges_colors, } nx.draw(graph_to_draw, pos, **options) - plt.show() + ax = plt.gca() + ax.margins(0.20) + plt.axis("off") + name = self._sample_path._importer.file_path.rsplit('/', 1)[-1] + name = name.split('.', 1)[0] + name += '_' + str(self._sample_path._importer.dataset_id()) + name += '.png' + plt.savefig(name) + print("Estimated Structure Plot Saved At: ", os.path.abspath(name)) diff --git a/PyCTBN/tests/test_structure_estimator.py b/PyCTBN/tests/test_structure_estimator.py index fd3b9ea..0283ce0 100644 --- a/PyCTBN/tests/test_structure_estimator.py +++ b/PyCTBN/tests/test_structure_estimator.py @@ -20,7 +20,7 @@ class TestStructureEstimator(unittest.TestCase): def setUpClass(cls): cls.read_files = glob.glob(os.path.join('./data', "*.json")) cls.importer = JsonImporter(cls.read_files[0], 'samples', 'dyn.str', 'variables', 'Time', 'Name') - cls.importer.import_data(3) + cls.importer.import_data(1) cls.s1 = SamplePath(cls.importer) cls.s1.build_trajectories() cls.s1.build_structure() @@ -73,7 +73,7 @@ class TestStructureEstimator(unittest.TestCase): for ed in self.s1.structure.edges: self.assertIn(tuple(ed), se1._complete_graph.edges) print("Spurious Edges:", se1.spurious_edges()) - se1.plot_estimated_structure_graph() + se1.save_plot_estimated_structure_graph() def test_save_results(self): se1 = StructureEstimator(self.s1, 0.1, 0.1)