|
|
|
@ -41,4 +41,17 @@ class TestTrajectoryGenerator(unittest.TestCase): |
|
|
|
|
traj = Trajectory(self.j1.build_list_of_samples_array(sigma), len(self.j1.sorter) + 1) |
|
|
|
|
self.assertEqual(len(traj.times), n_tr + 1) |
|
|
|
|
|
|
|
|
|
def test_multi_trajectory(self): |
|
|
|
|
tg = TrajectoryGenerator(self.j1) |
|
|
|
|
max_trs = [random.randint(5, 100) for i in range(10)] |
|
|
|
|
trajectories = tg.multi_trajectory(max_trs = max_trs) |
|
|
|
|
self.assertEqual(len(trajectories), len(max_trs)) |
|
|
|
|
for i, trajectory in enumerate(trajectories): |
|
|
|
|
self.assertEqual(len(trajectory), max_trs[i]) |
|
|
|
|
t_ends = [random.randint(100, 500) for i in range(10)] |
|
|
|
|
trajectories = tg.multi_trajectory(t_ends = t_ends) |
|
|
|
|
self.assertEqual(len(trajectories), len(t_ends)) |
|
|
|
|
for i, trajectory in enumerate(trajectories): |
|
|
|
|
self.assertEqual(len(trajectory), t_ends[i]) |
|
|
|
|
|
|
|
|
|
unittest.main() |