|
|
|
@ -3,9 +3,11 @@ import itertools |
|
|
|
|
import json |
|
|
|
|
import typing |
|
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
import networkx as nx |
|
|
|
|
import numpy as np |
|
|
|
|
from networkx.readwrite import json_graph |
|
|
|
|
#import pygraphviz |
|
|
|
|
from scipy.stats import chi2 as chi2_dist |
|
|
|
|
from scipy.stats import f as f_dist |
|
|
|
|
from tqdm import tqdm |
|
|
|
@ -229,11 +231,54 @@ class StructureEstimator(object): |
|
|
|
|
json.dump(res, f) |
|
|
|
|
|
|
|
|
|
def adjacency_matrix(self) -> np.ndarray: |
|
|
|
|
"""Converts the estimated structrure ``_complete_graph`` to a boolean adjacency matrix representation. |
|
|
|
|
"""Converts the estimated structure ``_complete_graph`` to a boolean adjacency matrix representation. |
|
|
|
|
|
|
|
|
|
:return: The adjacency matrix of the graph ``_complete_graph`` |
|
|
|
|
:rtype: numpy.ndArray |
|
|
|
|
""" |
|
|
|
|
return nx.adj_matrix(self._complete_graph).toarray().astype(bool) |
|
|
|
|
|
|
|
|
|
def spurious_edges(self) -> typing.List: |
|
|
|
|
"""Return the spurious edges present in the estimated structure, if a prior net structure is present in |
|
|
|
|
``_sample_path.structure``. |
|
|
|
|
|
|
|
|
|
:return: A list containing the spurious edges |
|
|
|
|
:rtype: List |
|
|
|
|
""" |
|
|
|
|
if not self._sample_path.has_prior_net_structure: |
|
|
|
|
raise RuntimeError("Can not compute spurious edges with no prior net structure!") |
|
|
|
|
real_graph = nx.DiGraph() |
|
|
|
|
real_graph.add_nodes_from(self._sample_path.structure.nodes_labels) |
|
|
|
|
real_graph.add_edges_from(self._sample_path.structure.edges) |
|
|
|
|
return nx.difference(real_graph, self._complete_graph).edges |
|
|
|
|
|
|
|
|
|
def plot_estimated_structure_graph(self) -> None: |
|
|
|
|
"""Plot the estimated structure in a graphical model style. |
|
|
|
|
Spurious edges are colored in red. |
|
|
|
|
""" |
|
|
|
|
graph_to_draw = nx.DiGraph() |
|
|
|
|
spurious_edges = self.spurious_edges() |
|
|
|
|
non_spurious_edges = list(set(self._complete_graph.edges) - set(spurious_edges)) |
|
|
|
|
edges_colors = ['red' if edge in spurious_edges else 'black' for edge in self._complete_graph.edges] |
|
|
|
|
graph_to_draw.add_edges_from(spurious_edges) |
|
|
|
|
graph_to_draw.add_edges_from(non_spurious_edges) |
|
|
|
|
pos = nx.spring_layout(graph_to_draw) |
|
|
|
|
options = { |
|
|
|
|
"node_size": 2500, |
|
|
|
|
'linewidths':2, |
|
|
|
|
"with_labels":True, |
|
|
|
|
"font_size":13, |
|
|
|
|
'connectionstyle': 'arc3, rad = 0.1', |
|
|
|
|
"arrowsize": 15, |
|
|
|
|
"arrowstyle": 'fancy', |
|
|
|
|
"width": 1, |
|
|
|
|
"edge_color":edges_colors, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
nx.draw(graph_to_draw, pos, **options) |
|
|
|
|
plt.show() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|