|
|
|
@ -3,16 +3,49 @@ use crate::node; |
|
|
|
|
use crate::params; |
|
|
|
|
use crate::params::ParamsTrait; |
|
|
|
|
use ndarray::prelude::*; |
|
|
|
|
use rand_chacha::ChaCha8Rng; |
|
|
|
|
use rand_chacha::rand_core::SeedableRng; |
|
|
|
|
use rand_chacha::ChaCha8Rng; |
|
|
|
|
|
|
|
|
|
pub struct Trajectory { |
|
|
|
|
pub time: Array1<f64>, |
|
|
|
|
pub events: Array2<usize>, |
|
|
|
|
time: Array1<f64>, |
|
|
|
|
events: Array2<usize>, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl Trajectory { |
|
|
|
|
pub fn init(time: Array1<f64>, events: Array2<usize>) -> Trajectory { |
|
|
|
|
if time.shape()[0] != events.shape()[0] { |
|
|
|
|
panic!("time.shape[0] must be equal to events.shape[0]"); |
|
|
|
|
} |
|
|
|
|
Trajectory { time, events } |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn get_time(&self) -> &Array1<f64> { |
|
|
|
|
&self.time |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn get_events(&self) -> &Array2<usize> { |
|
|
|
|
&self.events |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct Dataset { |
|
|
|
|
pub trajectories: Vec<Trajectory>, |
|
|
|
|
trajectories: Vec<Trajectory>, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl Dataset { |
|
|
|
|
pub fn init(trajectories: Vec<Trajectory>) -> Dataset { |
|
|
|
|
if trajectories |
|
|
|
|
.iter() |
|
|
|
|
.any(|x| trajectories[0].get_events().shape()[1] != x.get_events().shape()[1]) |
|
|
|
|
{ |
|
|
|
|
panic!("All the trajectories mus represents the same number of variables"); |
|
|
|
|
} |
|
|
|
|
Dataset { trajectories } |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn get_trajectories(&self) -> &Vec<Trajectory> { |
|
|
|
|
&self.trajectories |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn trajectory_generator<T: network::Network>( |
|
|
|
@ -21,10 +54,8 @@ pub fn trajectory_generator<T: network::Network>( |
|
|
|
|
t_end: f64, |
|
|
|
|
seed: Option<u64>, |
|
|
|
|
) -> Dataset { |
|
|
|
|
let mut dataset = Dataset { |
|
|
|
|
trajectories: Vec::new(), |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
let mut trajectories: Vec<Trajectory> = Vec::new(); |
|
|
|
|
let seed = seed.unwrap_or_else(rand::random); |
|
|
|
|
|
|
|
|
|
let mut rng = ChaCha8Rng::seed_from_u64(seed); |
|
|
|
@ -115,14 +146,14 @@ pub fn trajectory_generator<T: network::Network>( |
|
|
|
|
); |
|
|
|
|
time.push(t_end.clone()); |
|
|
|
|
|
|
|
|
|
dataset.trajectories.push(Trajectory { |
|
|
|
|
time: Array::from_vec(time), |
|
|
|
|
events: Array2::from_shape_vec( |
|
|
|
|
trajectories.push(Trajectory::init( |
|
|
|
|
Array::from_vec(time), |
|
|
|
|
Array2::from_shape_vec( |
|
|
|
|
(events.len(), current_state.len()), |
|
|
|
|
events.iter().flatten().cloned().collect(), |
|
|
|
|
) |
|
|
|
|
.unwrap(), |
|
|
|
|
}); |
|
|
|
|
)); |
|
|
|
|
} |
|
|
|
|
dataset |
|
|
|
|
Dataset::init(trajectories) |
|
|
|
|
} |
|
|
|
|