Refactor of sampling

pull/75/head
AlessandroBregoli 2 years ago
parent fd3b1ecfea
commit 1878f687d6
  1. 13
      reCTBN/src/sampling.rs
  2. 12
      reCTBN/src/tools.rs

@ -7,10 +7,17 @@ use crate::{
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
pub trait Sampler: Iterator { pub struct Sample {
pub t: f64,
pub state: Vec<params::StateType>
}
pub trait Sampler: Iterator<Item = Sample> {
fn reset(&mut self); fn reset(&mut self);
} }
pub struct ForwardSampler<'a, T> pub struct ForwardSampler<'a, T>
where where
T: NetworkProcess, T: NetworkProcess,
@ -43,7 +50,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
} }
impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
type Item = (f64, Vec<params::StateType>); type Item = Sample;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let ret_time = self.current_time.clone(); let ret_time = self.current_time.clone();
@ -96,7 +103,7 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
self.next_transitions[child] = None; self.next_transitions[child] = None;
} }
Some((ret_time, ret_state)) Some(Sample{t: ret_time, state: ret_state})
} }
} }

@ -72,15 +72,15 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
let mut events: Vec<Vec<params::StateType>> = Vec::new(); let mut events: Vec<Vec<params::StateType>> = Vec::new();
//Current Time and Current State //Current Time and Current State
let (mut t, mut current_state) = sampler.next().unwrap(); let mut sample = sampler.next().unwrap();
//Generate new samples until ending time is reached. //Generate new samples until ending time is reached.
while t < t_end { while sample.t < t_end {
time.push(t); time.push(sample.t);
events.push(current_state); events.push(sample.state);
(t, current_state) = sampler.next().unwrap(); sample = sampler.next().unwrap();
} }
current_state = events.last().unwrap().clone(); let current_state = events.last().unwrap().clone();
events.push(current_state); events.push(current_state);
//Add t_end as last time. //Add t_end as last time.

Loading…
Cancel
Save