1
0
Fork 0

Updated json importer and tests

master
Luca Moretti 4 years ago
parent e160b88167
commit 4728e99fa3
  1. 9
      PyCTBN/PyCTBN/utility/json_importer.py
  2. 4
      PyCTBN/tests/utility/test_json_importer.py

@ -49,10 +49,10 @@ class JsonImporter(AbstractImporter):
super(JsonImporter, self).__init__(file_path) super(JsonImporter, self).__init__(file_path)
self._raw_data = self.read_json_file() self._raw_data = self.read_json_file()
def import_data(self, indx: int) -> None: def import_data(self, indx: int = 0) -> None:
"""Implements the abstract method of :class:`AbstractImporter`. """Implements the abstract method of :class:`AbstractImporter`.
:param indx: the index of the outer JsonArray to extract the data from :param indx: the index of the outer JsonArray to extract the data from, default to 0
:type indx: int :type indx: int
""" """
self._array_indx = indx self._array_indx = indx
@ -101,7 +101,12 @@ class JsonImporter(AbstractImporter):
""" """
with open(self._file_path) as f: with open(self._file_path) as f:
data = json.load(f) data = json.load(f)
if (isinstance(data,list)):
return data return data
else:
return [data]
def one_level_normalizing(self, raw_data: typing.List, indx: int, key: str) -> pd.DataFrame: def one_level_normalizing(self, raw_data: typing.List, indx: int, key: str) -> pd.DataFrame:
"""Extracts the one-level nested data in the list ``raw_data`` at the index ``indx`` at the key ``key``. """Extracts the one-level nested data in the list ``raw_data`` at the index ``indx`` at the key ``key``.

@ -39,7 +39,7 @@ class TestJsonImporter(unittest.TestCase):
path = os.getcwd() path = os.getcwd()
path = path + '/data.json' path = path + '/data.json'
j1 = JsonImporter(path, '', '', '', '', '') j1 = JsonImporter(path, '', '', '', '', '')
self.assertTrue(self.ordered(data_set) == self.ordered(j1._raw_data)) self.assertTrue(self.ordered([data_set]) == self.ordered(j1._raw_data))
os.remove('data.json') os.remove('data.json')
def test_read_json_file_not_found(self): def test_read_json_file_not_found(self):
@ -155,7 +155,7 @@ class TestJsonImporter(unittest.TestCase):
self.assertEqual(j1.file_path, "./PyCTBN/test_data/networks_and_trajectories_binary_data_01_3.json") self.assertEqual(j1.file_path, "./PyCTBN/test_data/networks_and_trajectories_binary_data_01_3.json")
def test_import_data(self): def test_import_data(self):
j1 = JsonImporter("./PyCTBN/test_data/networks_and_trajectories_binary_data_01_3.json", 'samples', 'dyn.str', 'variables', 'Time', 'Name') j1 = JsonImporter("./PyCTBN/test_data/networks_and_trajectories_binary_data_02_10_1.json", 'samples', 'dyn.str', 'variables', 'Time', 'Name')
j1.import_data(0) j1.import_data(0)
self.assertEqual(list(j1.variables[j1._variables_key]), self.assertEqual(list(j1.variables[j1._variables_key]),
list(j1.concatenated_samples.columns.values[1:len(j1.variables[j1._variables_key]) + 1])) list(j1.concatenated_samples.columns.values[1:len(j1.variables[j1._variables_key]) + 1]))