diff --git a/src/tools.rs b/src/tools.rs index 048a77b..1b0d8ba 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -23,12 +23,14 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 for _ in 0..n_trajectories { let mut t = 0.0; let mut time: Vec = Vec::new(); - let mut events: Vec> = Vec::new(); + let mut events: Vec> = Vec::new(); let mut current_state: Vec = node_idx.iter().map(|x| { net.get_node(*x).get_random_state_uniform() }).collect(); let mut next_transitions: Vec> = (0..node_idx.len()).map(|_| Option::None).collect(); - events.push(current_state.clone()); + events.push(current_state.iter().map(|x| match x { + params::StateType::Discrete(state) => state.clone() + }).collect()); time.push(t.clone()); while t < t_end { next_transitions.iter_mut().enumerate().map(|(idx, val)| { @@ -60,17 +62,27 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 net.get_param_index_network(next_node_transition, ¤t_state)) .unwrap(); - events.push(current_state.clone()); + + events.push(current_state.iter().map(|x| match x { + params::StateType::Discrete(state) => state.clone() + }).collect()); next_transitions[next_node_transition] = None; for child in net.get_children_set(next_node_transition){ next_transitions[child] = None } - - } + events.push(current_state.iter().map(|x| match x { + params::StateType::Discrete(state) => state.clone() + }).collect()); + time.push(t.clone()); + + dataset.trajectories.push(Trajectory { + time: array![time], + events: array![events] + }); }