|
|
|
@ -7,7 +7,7 @@ import matplotlib.pyplot as plt |
|
|
|
|
import networkx as nx |
|
|
|
|
import numpy as np |
|
|
|
|
from networkx.readwrite import json_graph |
|
|
|
|
#import pygraphviz |
|
|
|
|
import os |
|
|
|
|
from scipy.stats import chi2 as chi2_dist |
|
|
|
|
from scipy.stats import f as f_dist |
|
|
|
|
from tqdm import tqdm |
|
|
|
@ -252,7 +252,7 @@ class StructureEstimator(object): |
|
|
|
|
real_graph.add_edges_from(self._sample_path.structure.edges) |
|
|
|
|
return nx.difference(real_graph, self._complete_graph).edges |
|
|
|
|
|
|
|
|
|
def plot_estimated_structure_graph(self) -> None: |
|
|
|
|
def save_plot_estimated_structure_graph(self) -> None: |
|
|
|
|
"""Plot the estimated structure in a graphical model style. |
|
|
|
|
Spurious edges are colored in red. |
|
|
|
|
""" |
|
|
|
@ -262,21 +262,31 @@ class StructureEstimator(object): |
|
|
|
|
edges_colors = ['red' if edge in spurious_edges else 'black' for edge in self._complete_graph.edges] |
|
|
|
|
graph_to_draw.add_edges_from(spurious_edges) |
|
|
|
|
graph_to_draw.add_edges_from(non_spurious_edges) |
|
|
|
|
pos = nx.spring_layout(graph_to_draw) |
|
|
|
|
pos = nx.spring_layout(graph_to_draw, k=5, scale=3) |
|
|
|
|
options = { |
|
|
|
|
"node_size": 2500, |
|
|
|
|
"node_size": 2000, |
|
|
|
|
"node_color": "white", |
|
|
|
|
"edgecolors": "black", |
|
|
|
|
'linewidths':2, |
|
|
|
|
"with_labels":True, |
|
|
|
|
"font_size":13, |
|
|
|
|
'connectionstyle': 'arc3, rad = 0.1', |
|
|
|
|
'connectionstyle': 'arc3, rad = 0.', |
|
|
|
|
"arrowsize": 15, |
|
|
|
|
"arrowstyle": 'fancy', |
|
|
|
|
"arrowstyle": '<|-', |
|
|
|
|
"width": 1, |
|
|
|
|
"edge_color":edges_colors, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
nx.draw(graph_to_draw, pos, **options) |
|
|
|
|
plt.show() |
|
|
|
|
ax = plt.gca() |
|
|
|
|
ax.margins(0.20) |
|
|
|
|
plt.axis("off") |
|
|
|
|
name = self._sample_path._importer.file_path.rsplit('/', 1)[-1] |
|
|
|
|
name = name.split('.', 1)[0] |
|
|
|
|
name += '_' + str(self._sample_path._importer.dataset_id()) |
|
|
|
|
name += '.png' |
|
|
|
|
plt.savefig(name) |
|
|
|
|
print("Estimated Structure Plot Saved At: ", os.path.abspath(name)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|