Refactor of sampling

pull/75/head
AlessandroBregoli 2 years ago
parent fd3b1ecfea
commit 1878f687d6
  1. 13
      reCTBN/src/sampling.rs
  2. 12
      reCTBN/src/tools.rs

@ -7,10 +7,17 @@ use crate::{
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
pub trait Sampler: Iterator {
pub struct Sample {
pub t: f64,
pub state: Vec<params::StateType>
}
pub trait Sampler: Iterator<Item = Sample> {
fn reset(&mut self);
}
pub struct ForwardSampler<'a, T>
where
T: NetworkProcess,
@ -43,7 +50,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
}
impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
type Item = (f64, Vec<params::StateType>);
type Item = Sample;
fn next(&mut self) -> Option<Self::Item> {
let ret_time = self.current_time.clone();
@ -96,7 +103,7 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
self.next_transitions[child] = None;
}
Some((ret_time, ret_state))
Some(Sample{t: ret_time, state: ret_state})
}
}

@ -72,15 +72,15 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
let mut events: Vec<Vec<params::StateType>> = Vec::new();
//Current Time and Current State
let (mut t, mut current_state) = sampler.next().unwrap();
let mut sample = sampler.next().unwrap();
//Generate new samples until ending time is reached.
while t < t_end {
time.push(t);
events.push(current_state);
(t, current_state) = sampler.next().unwrap();
while sample.t < t_end {
time.push(sample.t);
events.push(sample.state);
sample = sampler.next().unwrap();
}
current_state = events.last().unwrap().clone();
let current_state = events.last().unwrap().clone();
events.push(current_state);
//Add t_end as last time.

Loading…
Cancel
Save