diff --git a/PyCTBN/PyCTBN/utility/json_importer.py b/PyCTBN/PyCTBN/utility/json_importer.py index edff212..2cadff4 100644 --- a/PyCTBN/PyCTBN/utility/json_importer.py +++ b/PyCTBN/PyCTBN/utility/json_importer.py @@ -49,10 +49,10 @@ class JsonImporter(AbstractImporter): super(JsonImporter, self).__init__(file_path) self._raw_data = self.read_json_file() - def import_data(self, indx: int) -> None: + def import_data(self, indx: int = 0) -> None: """Implements the abstract method of :class:`AbstractImporter`. - :param indx: the index of the outer JsonArray to extract the data from + :param indx: the index of the outer JsonArray to extract the data from, default to 0 :type indx: int """ self._array_indx = indx @@ -101,7 +101,12 @@ class JsonImporter(AbstractImporter): """ with open(self._file_path) as f: data = json.load(f) - return data + + if (isinstance(data,list)): + return data + else: + return [data] + def one_level_normalizing(self, raw_data: typing.List, indx: int, key: str) -> pd.DataFrame: """Extracts the one-level nested data in the list ``raw_data`` at the index ``indx`` at the key ``key``. diff --git a/PyCTBN/tests/utility/test_json_importer.py b/PyCTBN/tests/utility/test_json_importer.py index 877805b..a4bb4c0 100644 --- a/PyCTBN/tests/utility/test_json_importer.py +++ b/PyCTBN/tests/utility/test_json_importer.py @@ -39,7 +39,7 @@ class TestJsonImporter(unittest.TestCase): path = os.getcwd() path = path + '/data.json' j1 = JsonImporter(path, '', '', '', '', '') - self.assertTrue(self.ordered(data_set) == self.ordered(j1._raw_data)) + self.assertTrue(self.ordered([data_set]) == self.ordered(j1._raw_data)) os.remove('data.json') def test_read_json_file_not_found(self): @@ -155,7 +155,7 @@ class TestJsonImporter(unittest.TestCase): self.assertEqual(j1.file_path, "./PyCTBN/test_data/networks_and_trajectories_binary_data_01_3.json") def test_import_data(self): - j1 = JsonImporter("./PyCTBN/test_data/networks_and_trajectories_binary_data_01_3.json", 'samples', 'dyn.str', 'variables', 'Time', 'Name') + j1 = JsonImporter("./PyCTBN/test_data/networks_and_trajectories_binary_data_02_10_1.json", 'samples', 'dyn.str', 'variables', 'Time', 'Name') j1.import_data(0) self.assertEqual(list(j1.variables[j1._variables_key]), list(j1.concatenated_samples.columns.values[1:len(j1.variables[j1._variables_key]) + 1]))