From 3987f1316505e95f5a396ad29ab11c02642013ad Mon Sep 17 00:00:00 2001 From: philpMartin Date: Fri, 24 Jul 2020 17:51:24 +0200 Subject: [PATCH] Add tests ParameterEstimator class --- main_package/classes/json_importer.py | 10 +++ main_package/classes/network_graph.py | 2 +- main_package/classes/parameters_estimator.py | 8 +-- .../tests/test_parameters_estimator.py | 68 +++++++++++++++++++ 4 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 main_package/tests/test_parameters_estimator.py diff --git a/main_package/classes/json_importer.py b/main_package/classes/json_importer.py index efcbec4..3595a2b 100644 --- a/main_package/classes/json_importer.py +++ b/main_package/classes/json_importer.py @@ -150,6 +150,16 @@ class JsonImporter(AbstractImporter): for indx in range(len(self.df_samples_list)): # Le singole traj non servono piĆ¹ self.df_samples_list[indx] = self.df_samples_list[indx].iloc[0:0] + def import_sampled_cims(self, raw_data, indx, cims_key): + cims_for_all_vars = {} + for var in raw_data[indx][cims_key]: + sampled_cims_list = [] + cims_for_all_vars[var] = sampled_cims_list + for p_comb in raw_data[indx][cims_key][var]: + cims_for_all_vars[var].append(pd.DataFrame(raw_data[indx][cims_key][var][p_comb]).to_numpy()) + return cims_for_all_vars + + @property def concatenated_samples(self): return self._concatenated_samples diff --git a/main_package/classes/network_graph.py b/main_package/classes/network_graph.py index ae4a04d..c383359 100644 --- a/main_package/classes/network_graph.py +++ b/main_package/classes/network_graph.py @@ -159,7 +159,7 @@ class NetworkGraph(): #if p_indxs.size == 0: #self._time_filtering.append(np.append(p_indxs, np.array([node_indx], dtype=np.int))) #else: - self._time_filtering.append(np.append(np.array([node_indx], dtype=np.int), p_indxs)) + self._time_filtering.append(np.append(np.array([node_indx], dtype=np.int), p_indxs).astype(np.int)) def build_transition_columns_filtering_structure(self): parents_indexes_list = self._fancy_indexing diff --git a/main_package/classes/parameters_estimator.py b/main_package/classes/parameters_estimator.py index 0162c82..ab998aa 100644 --- a/main_package/classes/parameters_estimator.py +++ b/main_package/classes/parameters_estimator.py @@ -16,9 +16,9 @@ class ParametersEstimator: self.net_graph = net_graph self.sets_of_cims_struct = None - def init_amalgamated_cims_struct(self): - self.sets_of_cims_struct = acims.SetsOfCimsContainer(self.net_graph.get_states_number_of_all_nodes_sorted(), - self.net_graph.get_nodes(), + def init_sets_cims_container(self): + self.sets_of_cims_struct = acims.SetsOfCimsContainer(self.net_graph.get_nodes(), + self.net_graph.get_states_number_of_all_nodes_sorted(), self.net_graph.get_ordered_by_indx_parents_values_for_all_nodes()) @@ -26,7 +26,7 @@ class ParametersEstimator: #print(self.net_graph.get_nodes()) #print(self.amalgamated_cims_struct.sets_of_cims) #enumerate(zip(self.net_graph.get_nodes(), self.amalgamated_cims_struct.sets_of_cims)) - for indx, aggr in enumerate(zip(self.net_graph.get_nodes(), self.amalgamated_cims_struct.sets_of_cims)): + for indx, aggr in enumerate(zip(self.net_graph.get_nodes(), self.sets_of_cims_struct.sets_of_cims)): #print(self.net_graph.time_filtering[indx]) #print(self.net_graph.time_scalar_indexing_strucure[indx]) self.compute_state_res_time_for_node(self.net_graph.get_node_indx(aggr[0]), self.sample_path.trajectories.times, diff --git a/main_package/tests/test_parameters_estimator.py b/main_package/tests/test_parameters_estimator.py new file mode 100644 index 0000000..10a7256 --- /dev/null +++ b/main_package/tests/test_parameters_estimator.py @@ -0,0 +1,68 @@ +import unittest +import numpy as np + +import network_graph as ng +import sample_path as sp +import sets_of_cims_container as scc +import parameters_estimator as pe +import json_importer as ji + + +class TestParametersEstimatior(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.s1 = sp.SamplePath('../data', 'samples', 'dyn.str', 'variables', 'Time', 'Name') + cls.s1.build_trajectories() + cls.s1.build_structure() + cls.g1 = ng.NetworkGraph(cls.s1.structure) + cls.g1.init_graph() + + def test_init(self): + self.aux_test_init(self.s1, self.g1) + + def test_init_sets_of_cims_container(self): + self.aux_test_init_sets_cims_container(self.s1, self.g1) + + def aux_test_init(self, sample_p, graph): + pe1 = pe.ParametersEstimator(sample_p, graph) + self.assertEqual(sample_p, pe1.sample_path) + self.assertEqual(graph, pe1.net_graph) + self.assertIsNone(pe1.sets_of_cims_struct) + + def aux_test_init_sets_cims_container(self, sample_p, graph): + pe1 = pe.ParametersEstimator(sample_p, graph) + pe1.init_sets_cims_container() + self.assertIsInstance(pe1.sets_of_cims_struct, scc.SetsOfCimsContainer) + + def test_compute_parameters(self): + self.aux_test_compute_parameters(self.s1, self.g1) + + def aux_test_compute_parameters(self, sample_p, graph): + pe1 = pe.ParametersEstimator(sample_p, graph) + pe1.init_sets_cims_container() + pe1.compute_parameters() + samples_cims = self.aux_import_sampled_cims('dyn.cims') + for indx, sc in enumerate(samples_cims.values()): + self.equality_of_cims_of_node(sc, pe1.sets_of_cims_struct.get_set_of_cims(indx).get_cims()) + + def equality_of_cims_of_node(self, sampled_cims, estimated_cims): + self.assertEqual(len(sampled_cims), len(estimated_cims)) + for c1, c2 in zip(sampled_cims, estimated_cims): + self.cim_equality_test(c1, c2.cim) + + def cim_equality_test(self, cim1, cim2): + for r1, r2 in zip(cim1, cim2): + self.assertTrue(np.all(np.isclose(r1, r2, 1e-01, 1e-01) == True)) + + def aux_import_sampled_cims(self, cims_label): + i1 = ji.JsonImporter('../data', '', '', '', '', '') + raw_data = i1.read_json_file() + return i1.import_sampled_cims(raw_data, 0, cims_label) + + + + + +if __name__ == '__main__': + unittest.main()