diff --git a/main_package/classes/utility/sample_importer.py b/main_package/classes/utility/sample_importer.py index 7c5fa49..3c049ef 100644 --- a/main_package/classes/utility/sample_importer.py +++ b/main_package/classes/utility/sample_importer.py @@ -43,7 +43,7 @@ class SampleImporter(ai.AbstractImporter): def import_data(self, header_column = None): - if header_column is None: + if header_column is not None: self._sorter = header_column else: self._sorter = self.build_sorter(self._df_samples_list[0]) diff --git a/main_package/tests/optimizers/test_tabu_search.py b/main_package/tests/optimizers/test_tabu_search.py index d900421..cbde628 100644 --- a/main_package/tests/optimizers/test_tabu_search.py +++ b/main_package/tests/optimizers/test_tabu_search.py @@ -46,7 +46,7 @@ class TestTabuSearch(unittest.TestCase): prior_net_structure=prior_net_structure ) - cls.importer.import_data(0) + cls.importer.import_data() cls.s1 = sp.SamplePath(cls.importer) #cls.traj = cls.s1.concatenated_samples diff --git a/main_package/tests/utility/test_sample_importer.py b/main_package/tests/utility/test_sample_importer.py new file mode 100644 index 0000000..00cb847 --- /dev/null +++ b/main_package/tests/utility/test_sample_importer.py @@ -0,0 +1,81 @@ +import sys +sys.path.append("../../classes/") +import unittest +import os +import glob +import numpy as np +import pandas as pd +import utility.sample_importer as si +import structure_graph.sample_path as sp + +import json + + + +class TestSampleImporter(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + with open("../../data/networks_and_trajectories_binary_data_01_3.json") as f: + raw_data = json.load(f) + + trajectory_list_raw= raw_data[0]["samples"] + + cls.trajectory_list = [pd.DataFrame(sample) for sample in trajectory_list_raw] + + cls.variables= pd.DataFrame(raw_data[0]["variables"]) + cls.prior_net_structure = pd.DataFrame(raw_data[0]["dyn.str"]) + + + def test_init(self): + sample_importer = si.SampleImporter( + trajectory_list=self.trajectory_list, + variables=self.variables, + prior_net_structure=self.prior_net_structure + ) + + sample_importer.import_data() + + s1 = sp.SamplePath(sample_importer) + s1.build_trajectories() + s1.build_structure() + s1.clear_memory() + + self.assertEqual(len(s1._importer._df_samples_list), 300) + self.assertIsInstance(s1._importer._df_samples_list,list) + self.assertIsInstance(s1._importer._df_samples_list[0],pd.DataFrame) + self.assertEqual(len(s1._importer._df_variables), 3) + self.assertIsInstance(s1._importer._df_variables,pd.DataFrame) + self.assertEqual(len(s1._importer._df_structure), 2) + self.assertIsInstance(s1._importer._df_structure,pd.DataFrame) + + def test_order(self): + sample_importer = si.SampleImporter( + trajectory_list=self.trajectory_list, + variables=self.variables, + prior_net_structure=self.prior_net_structure + ) + + sample_importer.import_data() + + s1 = sp.SamplePath(sample_importer) + s1.build_trajectories() + s1.build_structure() + s1.clear_memory() + + for count,var in enumerate(s1._importer._df_samples_list[0].columns[1:]): + self.assertEqual(s1._importer._sorter[count],var) + + + + def ordered(self, obj): + if isinstance(obj, dict): + return sorted((k, self.ordered(v)) for k, v in obj.items()) + if isinstance(obj, list): + return sorted(self.ordered(x) for x in obj) + else: + return obj + + +if __name__ == '__main__': + unittest.main()