1
0
Fork 0

Modify plot graph method

master
Filippo Martini 4 years ago
parent c825739490
commit bf1c154658
  1. 1
      .gitignore
  2. 62
      PyCTBN/PyCTBN/estimators/structure_estimator.py
  3. 20
      PyCTBN/setup.py
  4. 2
      setup.py

1
.gitignore vendored

@ -5,4 +5,5 @@ __pycache__
**/PyCTBN.egg-info
**/results_data
**/.scannerwork
.idea

@ -142,40 +142,34 @@ class StructureEstimator(object):
return nx.difference(real_graph, self._complete_graph).edges
def save_plot_estimated_structure_graph(self, file_path: str) -> None:
"""Plot the estimated structure in a graphical model style, use .png extension.
Spurious edges are colored in red if a prior structure is present.
:param file_path: path to save the file to
:type: string
"""
graph_to_draw = nx.DiGraph()
spurious_edges = self.spurious_edges()
non_spurious_edges = list(set(self._complete_graph.edges) - set(spurious_edges))
edges_colors = ['red' if edge in spurious_edges else 'black' for edge in self._complete_graph.edges]
graph_to_draw.add_edges_from(spurious_edges)
graph_to_draw.add_edges_from(non_spurious_edges)
pos = nx.spring_layout(graph_to_draw, k=0.5*1/np.sqrt(len(graph_to_draw.nodes())), iterations=50,scale=10)
options = {
"node_size": 2000,
"node_color": "white",
"edgecolors": "black",
'linewidths':2,
"with_labels":True,
"font_size":13,
'connectionstyle': 'arc3, rad = 0.1',
"arrowsize": 15,
"arrowstyle": '<|-',
"width": 1,
"edge_color":edges_colors,
}
nx.draw(graph_to_draw, pos, **options)
ax = plt.gca()
ax.margins(0.20)
plt.axis("off")
plt.savefig(file_path)
plt.clf()
print("Estimated Structure Plot Saved At: ", os.path.abspath(file_path))
"""Plot the estimated structure in a graphical model style, use .png extension.
:param file_path: path to save the file to
:type: string
"""
pos = nx.spring_layout(self._complete_graph, k=0.5*1/np.sqrt(len(self._complete_graph.nodes)),
iterations=50,scale=10)
options = {
"node_size": 2000,
"node_color": "white",
"edgecolors": "black",
'linewidths':2,
"with_labels":True,
"font_size":13,
'connectionstyle': 'arc3, rad = 0.1',
"arrowsize": 15,
#"arrowstyle": '<|-',
"width": 1,
#"edge_color":edges_colors,
}
nx.draw(self._complete_graph, pos, **options)
ax = plt.gca()
ax.margins(0.20)
plt.axis("off")
plt.savefig(file_path)
plt.clf()
print("Estimated Structure Plot Saved At: ", os.path.abspath(file_path))

@ -1,20 +0,0 @@
from setuptools import setup, find_packages
setup(name='PyCTBN',
version='1.0',
url='https://github.com/philipMartini/PyCTBN',
license='MIT',
author=['Alessandro Bregoli', 'Filippo Martini','Luca Moretti'],
author_email=['a.bregoli1@campus.unimib.it', 'f.martini@campus.unimib.it','lucamoretti96@gmail.com'],
description='A Continuous Time Bayesian Networks Library',
packages=find_packages('.', exclude=['tests']),
#packages=['PyCTBN.PyCTBN'],
install_requires=[
'numpy', 'pandas', 'networkx', 'scipy', 'matplotlib', 'tqdm'],
dependency_links=['https://github.com/numpy/numpy', 'https://github.com/pandas-dev/pandas',
'https://github.com/networkx/networkx', 'https://github.com/scipy/scipy',
'https://github.com/tqdm/tqdm'],
#long_description=open('../README.md').read(),
zip_safe=False,
python_requires='>=3.6')

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(name='PyCTBN',
version='1.0',
version='2.1',
url='https://github.com/philipMartini/PyCTBN',
license='MIT',
author=['Alessandro Bregoli', 'Filippo Martini','Luca Moretti'],