From bb2fe52a39b1f556526665399a05d5f1acadca8c Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Fri, 5 Aug 2022 17:08:24 +0200 Subject: [PATCH 1/4] Implemented ForwardSampler --- src/lib.rs | 1 + src/sampling.rs | 102 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 src/sampling.rs diff --git a/src/lib.rs b/src/lib.rs index 8c57af2..d40776a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,3 +8,4 @@ pub mod parameter_learning; pub mod params; pub mod structure_learning; pub mod tools; +pub mod sampling; diff --git a/src/sampling.rs b/src/sampling.rs new file mode 100644 index 0000000..a0bfaaf --- /dev/null +++ b/src/sampling.rs @@ -0,0 +1,102 @@ +use crate::{ + network::{self, Network}, + params::{self, ParamsTrait}, +}; +use rand::SeedableRng; +use rand_chacha::ChaCha8Rng; + +trait Sampler: Iterator { + fn reset(&mut self); +} + +pub struct ForwardSampler<'a, T> +where + T: Network, +{ + net: &'a T, + rng: ChaCha8Rng, + current_time: f64, + current_state: Vec, + next_transitions: Vec>, +} + +impl<'a, T: Network> ForwardSampler<'a, T> { + pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { + let mut rng: ChaCha8Rng = match seed { + //If a seed is present use it to initialize the random generator. + Some(seed) => SeedableRng::seed_from_u64(seed), + //Otherwise create a new random generator using the method `from_entropy` + None => SeedableRng::from_entropy(), + }; + let mut fs = ForwardSampler { + net: net, + rng: rng, + current_time: 0.0, + current_state: vec![], + next_transitions: vec![], + }; + fs.reset(); + return fs; + } +} + +impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { + type Item = (f64, Vec); + + fn next(&mut self) -> Option { + let ret_time = self.current_time.clone(); + let ret_state = self.current_state.clone(); + + for (idx, val) in self.next_transitions.iter_mut().enumerate() { + if let None = val { + *val = Some( + self.net + .get_node(idx) + .get_random_residence_time( + self.net + .get_node(idx) + .state_to_index(&self.current_state[idx]), + self.net.get_param_index_network(idx, &self.current_state), + &mut self.rng, + ) + .unwrap() + + self.current_time, + ); + } + } + + let next_node_transition = self.next_transitions + .iter() + .enumerate() + .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) + .unwrap() + .0; + + self.current_time = self.next_transitions[next_node_transition].unwrap().clone(); + + self.current_state[next_node_transition] = self.net + .get_node(next_node_transition) + .get_random_state( + self.net.get_node(next_node_transition) + .state_to_index(&self.current_state[next_node_transition]), + self.net.get_param_index_network(next_node_transition, &self.current_state), + &mut self.rng, + ) + .unwrap(); + + + Some((ret_time, ret_state)) + } +} + +impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { + fn reset(&mut self) { + self.current_time = 0.0; + self.current_state = self + .net + .get_node_indices() + .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) + .collect(); + self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect(); + } +} From 8163bfb2b0317b1f0fcd5b0a1b63588a8d7fd7c7 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 21 Sep 2022 10:39:01 +0200 Subject: [PATCH 2/4] Update trajectory generator --- src/sampling.rs | 8 ++- src/tools.rs | 118 ++++++------------------------------ tests/parameter_learning.rs | 2 +- 3 files changed, 28 insertions(+), 100 deletions(-) diff --git a/src/sampling.rs b/src/sampling.rs index a0bfaaf..9bbf569 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -5,7 +5,7 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -trait Sampler: Iterator { +pub trait Sampler: Iterator { fn reset(&mut self); } @@ -84,6 +84,12 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { ) .unwrap(); + self.next_transitions[next_node_transition] = None; + + for child in self.net.get_children_set(next_node_transition) { + self.next_transitions[child] = None; + } + Some((ret_time, ret_state)) } diff --git a/src/tools.rs b/src/tools.rs index 448b26f..0dfea9e 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -3,6 +3,7 @@ use rand_chacha::rand_core::SeedableRng; use rand_chacha::ChaCha8Rng; use crate::params::ParamsTrait; +use crate::sampling::{Sampler, ForwardSampler}; use crate::{network, params}; pub struct Trajectory { @@ -61,114 +62,28 @@ pub fn trajectory_generator( let mut trajectories: Vec = Vec::new(); //Random Generator object - let mut rng: ChaCha8Rng = match seed { - //If a seed is present use it to initialize the random generator. - Some(seed) => SeedableRng::seed_from_u64(seed), - //Otherwise create a new random generator using the method `from_entropy` - None => SeedableRng::from_entropy(), - }; + let mut sampler = ForwardSampler::new(net, seed); //Each iteration generate one trajectory for _ in 0..n_trajectories { - //Current time of the sampling process - let mut t = 0.0; //History of all the moments in which something changed let mut time: Vec = Vec::new(); //Configuration of the process variables at time t initialized with an uniform //distribution. - let mut current_state: Vec = net - .get_node_indices() - .map(|x| net.get_node(x).get_random_state_uniform(&mut rng)) - .collect(); - //History of all the configurations of the process variables. - let mut events: Vec> = Vec::new(); - //Vector containing to time to the next transition for each variable. - let mut next_transitions: Vec> = - net.get_node_indices().map(|_| Option::None).collect(); + let mut events: Vec> = Vec::new(); - //Add the starting time for the trajectory. - time.push(t.clone()); - //Add the starting configuration of the trajectory. - events.push( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - ); + //Current Time and Current State + let (mut t, mut current_state) = sampler.next().unwrap(); //Generate new samples until ending time is reached. while t < t_end { - //Generate the next transition time for each uninitialized variable. - for (idx, val) in next_transitions.iter_mut().enumerate() { - if let None = val { - *val = Some( - net.get_node(idx) - .get_random_residence_time( - net.get_node(idx).state_to_index(¤t_state[idx]), - net.get_param_index_network(idx, ¤t_state), - &mut rng, - ) - .unwrap() - + t, - ); - } - } - - //Get the variable with the smallest transition time. - let next_node_transition = next_transitions - .iter() - .enumerate() - .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) - .unwrap() - .0; - //Check if the next transition take place after the ending time. - if next_transitions[next_node_transition].unwrap() > t_end { - break; - } - //Get the time in which the next transition occurs. - t = next_transitions[next_node_transition].unwrap().clone(); - //Add the transition time to next - time.push(t.clone()); - - //Compute the new state of the transitioning variable. - current_state[next_node_transition] = net - .get_node(next_node_transition) - .get_random_state( - net.get_node(next_node_transition) - .state_to_index(¤t_state[next_node_transition]), - net.get_param_index_network(next_node_transition, ¤t_state), - &mut rng, - ) - .unwrap(); - - //Add the new state to events - events.push(Array::from_vec( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - )); - //Reset the next transition time for the transitioning node. - next_transitions[next_node_transition] = None; - - //Reset the next transition time for each child of the transitioning node. - for child in net.get_children_set(next_node_transition) { - next_transitions[child] = None - } + time.push(t); + events.push(current_state); + (t, current_state) = sampler.next().unwrap(); } - //Add current_state as last state. - events.push( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - ); + current_state = events.last().unwrap().clone(); + events.push(current_state); + //Add t_end as last time. time.push(t_end.clone()); @@ -176,11 +91,18 @@ pub fn trajectory_generator( trajectories.push(Trajectory::new( Array::from_vec(time), Array2::from_shape_vec( - (events.len(), current_state.len()), - events.iter().flatten().cloned().collect(), + (events.len(), events.last().unwrap().len()), + events + .iter() + .flatten() + .map(|x| match x { + params::StateType::Discrete(x) => x.clone(), + }) + .collect(), ) .unwrap(), )); + sampler.reset(); } //Return a dataset object with the sampled trajectories. Dataset::new(trajectories) diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 5de02d7..1e19ce7 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -427,7 +427,7 @@ fn learn_mixed_discrete_cim(pl: T) { [0.8, 0.6, 0.2, -1.6] ], ]), - 0.1 + 0.2 )); } From a07dc214ce72f25538eb0b86ab16b44f0422cae1 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 21 Sep 2022 10:50:12 +0200 Subject: [PATCH 3/4] Clippy errors solved --- src/sampling.rs | 4 ++-- src/tools.rs | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/sampling.rs b/src/sampling.rs index 9bbf569..19b2409 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -1,5 +1,5 @@ use crate::{ - network::{self, Network}, + network::{Network}, params::{self, ParamsTrait}, }; use rand::SeedableRng; @@ -22,7 +22,7 @@ where impl<'a, T: Network> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { - let mut rng: ChaCha8Rng = match seed { + let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. Some(seed) => SeedableRng::seed_from_u64(seed), //Otherwise create a new random generator using the method `from_entropy` diff --git a/src/tools.rs b/src/tools.rs index 0dfea9e..671a55a 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,8 +1,5 @@ use ndarray::prelude::*; -use rand_chacha::rand_core::SeedableRng; -use rand_chacha::ChaCha8Rng; -use crate::params::ParamsTrait; use crate::sampling::{Sampler, ForwardSampler}; use crate::{network, params}; From 622cd305d0b2ac2d376d74318f0c3aa6a5c1e487 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 21 Sep 2022 11:24:26 +0200 Subject: [PATCH 4/4] Formatting --- src/lib.rs | 2 +- src/sampling.rs | 21 ++++++++++++--------- src/tools.rs | 2 +- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d40776a..280bd21 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,6 @@ pub mod ctbn; pub mod network; pub mod parameter_learning; pub mod params; +pub mod sampling; pub mod structure_learning; pub mod tools; -pub mod sampling; diff --git a/src/sampling.rs b/src/sampling.rs index 19b2409..0660939 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -1,5 +1,5 @@ use crate::{ - network::{Network}, + network::Network, params::{self, ParamsTrait}, }; use rand::SeedableRng; @@ -65,32 +65,35 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { } } - let next_node_transition = self.next_transitions + let next_node_transition = self + .next_transitions .iter() .enumerate() .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) .unwrap() .0; - + self.current_time = self.next_transitions[next_node_transition].unwrap().clone(); - - self.current_state[next_node_transition] = self.net + + self.current_state[next_node_transition] = self + .net .get_node(next_node_transition) .get_random_state( - self.net.get_node(next_node_transition) + self.net + .get_node(next_node_transition) .state_to_index(&self.current_state[next_node_transition]), - self.net.get_param_index_network(next_node_transition, &self.current_state), + self.net + .get_param_index_network(next_node_transition, &self.current_state), &mut self.rng, ) .unwrap(); self.next_transitions[next_node_transition] = None; - + for child in self.net.get_children_set(next_node_transition) { self.next_transitions[child] = None; } - Some((ret_time, ret_state)) } } diff --git a/src/tools.rs b/src/tools.rs index 671a55a..70bbf76 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,6 +1,6 @@ use ndarray::prelude::*; -use crate::sampling::{Sampler, ForwardSampler}; +use crate::sampling::{ForwardSampler, Sampler}; use crate::{network, params}; pub struct Trajectory {