diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 0662994..3bc0c6f 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -7,10 +7,17 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -pub trait Sampler: Iterator { +pub struct Sample { + pub t: f64, + pub state: Vec +} + +pub trait Sampler: Iterator { fn reset(&mut self); } + + pub struct ForwardSampler<'a, T> where T: NetworkProcess, @@ -43,7 +50,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { } impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { - type Item = (f64, Vec); + type Item = Sample; fn next(&mut self) -> Option { let ret_time = self.current_time.clone(); @@ -96,7 +103,7 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some((ret_time, ret_state)) + Some(Sample{t: ret_time, state: ret_state}) } } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 2e727e8..e749d69 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -72,15 +72,15 @@ pub fn trajectory_generator( let mut events: Vec> = Vec::new(); //Current Time and Current State - let (mut t, mut current_state) = sampler.next().unwrap(); + let mut sample = sampler.next().unwrap(); //Generate new samples until ending time is reached. - while t < t_end { - time.push(t); - events.push(current_state); - (t, current_state) = sampler.next().unwrap(); + while sample.t < t_end { + time.push(sample.t); + events.push(sample.state); + sample = sampler.next().unwrap(); } - current_state = events.last().unwrap().clone(); + let current_state = events.last().unwrap().clone(); events.push(current_state); //Add t_end as last time.