diff --git a/PyCTBN/PyCTBN/structure_graph/network_generator.py b/PyCTBN/PyCTBN/structure_graph/network_generator.py index 823bb13..57599ef 100644 --- a/PyCTBN/PyCTBN/structure_graph/network_generator.py +++ b/PyCTBN/PyCTBN/structure_graph/network_generator.py @@ -21,6 +21,9 @@ class NetworkGenerator(object): self._graph.add_edges(s.edges) def generate_cims(self, min_val, max_val): + if self._graph is None: + return + self._cims = {} for i, l in enumerate(self._labels): diff --git a/test_networkgenerator.py b/test_networkgenerator.py new file mode 100644 index 0000000..d4f4882 --- /dev/null +++ b/test_networkgenerator.py @@ -0,0 +1,38 @@ +import unittest +import random +import numpy as np + +from PyCTBN.PyCTBN.structure_graph.network_generator import NetworkGenerator + +class TestNetworkGenerator(unittest.TestCase): + def test_generate_graph(self): + labels = ["U", "V", "W", "X", "Y", "Z"] + card = 3 + vals = [card for l in labels] + ng = NetworkGenerator(labels, vals) + ng.generate_graph() + self.assertEqual(len(labels), len(ng.graph.nodes)) + self.assertEqual(len([edge for edge in ng.graph.edges if edge[0] == edge[1]]), 0) + + def test_generate_cims(self): + labels = ["U", "V", "W", "X", "Y", "Z"] + card = 3 + vals = [card for l in labels] + cim_min = random.uniform(0.5, 5) + cim_max = random.uniform(0.5, 5) + cim_min + ng = NetworkGenerator(labels, vals) + ng.generate_graph() + ng.generate_cims(cim_min, cim_max) + self.assertEqual(len(ng.cims), len(labels)) + self.assertListEqual(list(ng.cims.keys()), labels) + for key in ng.cims: + p_card = ng.graph.get_ordered_by_indx_set_of_parents(key)[2] + comb = ng.graph.build_p_comb_structure_for_a_node(p_card) + self.assertEqual(len(ng.cims[key].actual_cims), len(comb)) + for cim in ng.cims[key].actual_cims: + self.assertEqual(sum(c > 0 for c in cim.cim.diagonal()), 0) + for i, row in enumerate(cim.cim): + self.assertEqual(round(sum(row) - row[i], 8), round(-1 * row[i], 8)) + self.assertEqual(sum(c < 0 for c in np.delete(cim.cim[i], i)), 0) + +unittest.main() \ No newline at end of file