10 feature importance sampling for ctbn #59

Merged
AlessandroBregoli merged 4 commits from 10-feature-importance-sampling-for-ctbn into dev 2 years ago
  1. 1
      src/lib.rs
  2. 111
      src/sampling.rs
  3. 121
      src/tools.rs
  4. 2
      tests/parameter_learning.rs

@ -6,5 +6,6 @@ pub mod ctbn;
pub mod network; pub mod network;
pub mod parameter_learning; pub mod parameter_learning;
pub mod params; pub mod params;
pub mod sampling;
pub mod structure_learning; pub mod structure_learning;
pub mod tools; pub mod tools;

@ -0,0 +1,111 @@
use crate::{
network::Network,
params::{self, ParamsTrait},
};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
pub trait Sampler: Iterator {
fn reset(&mut self);
}
pub struct ForwardSampler<'a, T>
where
T: Network,
{
net: &'a T,
rng: ChaCha8Rng,
current_time: f64,
current_state: Vec<params::StateType>,
next_transitions: Vec<Option<f64>>,
}
impl<'a, T: Network> ForwardSampler<'a, T> {
pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> {
let rng: ChaCha8Rng = match seed {
//If a seed is present use it to initialize the random generator.
Some(seed) => SeedableRng::seed_from_u64(seed),
//Otherwise create a new random generator using the method `from_entropy`
None => SeedableRng::from_entropy(),
};
let mut fs = ForwardSampler {
net: net,
rng: rng,
current_time: 0.0,
current_state: vec![],
next_transitions: vec![],
};
fs.reset();
return fs;
}
}
impl<'a, T: Network> Iterator for ForwardSampler<'a, T> {
type Item = (f64, Vec<params::StateType>);
fn next(&mut self) -> Option<Self::Item> {
let ret_time = self.current_time.clone();
let ret_state = self.current_state.clone();
for (idx, val) in self.next_transitions.iter_mut().enumerate() {
if let None = val {
*val = Some(
self.net
.get_node(idx)
.get_random_residence_time(
self.net
.get_node(idx)
.state_to_index(&self.current_state[idx]),
self.net.get_param_index_network(idx, &self.current_state),
&mut self.rng,
)
.unwrap()
+ self.current_time,
);
}
}
let next_node_transition = self
.next_transitions
.iter()
.enumerate()
.min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
.unwrap()
.0;
self.current_time = self.next_transitions[next_node_transition].unwrap().clone();
self.current_state[next_node_transition] = self
.net
.get_node(next_node_transition)
.get_random_state(
self.net
.get_node(next_node_transition)
.state_to_index(&self.current_state[next_node_transition]),
self.net
.get_param_index_network(next_node_transition, &self.current_state),
&mut self.rng,
)
.unwrap();
self.next_transitions[next_node_transition] = None;
for child in self.net.get_children_set(next_node_transition) {
self.next_transitions[child] = None;
}
Some((ret_time, ret_state))
}
}
impl<'a, T: Network> Sampler for ForwardSampler<'a, T> {
fn reset(&mut self) {
self.current_time = 0.0;
self.current_state = self
.net
.get_node_indices()
.map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng))
.collect();
self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect();
}
}

@ -1,8 +1,6 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
use crate::params::ParamsTrait; use crate::sampling::{ForwardSampler, Sampler};
use crate::{network, params}; use crate::{network, params};
pub struct Trajectory { pub struct Trajectory {
@ -61,114 +59,28 @@ pub fn trajectory_generator<T: network::Network>(
let mut trajectories: Vec<Trajectory> = Vec::new(); let mut trajectories: Vec<Trajectory> = Vec::new();
//Random Generator object //Random Generator object
let mut rng: ChaCha8Rng = match seed {
//If a seed is present use it to initialize the random generator.
Some(seed) => SeedableRng::seed_from_u64(seed),
//Otherwise create a new random generator using the method `from_entropy`
None => SeedableRng::from_entropy(),
};
let mut sampler = ForwardSampler::new(net, seed);
//Each iteration generate one trajectory //Each iteration generate one trajectory
for _ in 0..n_trajectories { for _ in 0..n_trajectories {
//Current time of the sampling process
let mut t = 0.0;
//History of all the moments in which something changed //History of all the moments in which something changed
let mut time: Vec<f64> = Vec::new(); let mut time: Vec<f64> = Vec::new();
//Configuration of the process variables at time t initialized with an uniform //Configuration of the process variables at time t initialized with an uniform
//distribution. //distribution.
let mut current_state: Vec<params::StateType> = net let mut events: Vec<Vec<params::StateType>> = Vec::new();
.get_node_indices()
.map(|x| net.get_node(x).get_random_state_uniform(&mut rng))
.collect();
//History of all the configurations of the process variables.
let mut events: Vec<Array1<usize>> = Vec::new();
//Vector containing to time to the next transition for each variable.
let mut next_transitions: Vec<Option<f64>> =
net.get_node_indices().map(|_| Option::None).collect();
//Add the starting time for the trajectory. //Current Time and Current State
time.push(t.clone()); let (mut t, mut current_state) = sampler.next().unwrap();
//Add the starting configuration of the trajectory.
events.push(
current_state
.iter()
.map(|x| match x {
params::StateType::Discrete(state) => state.clone(),
})
.collect(),
);
//Generate new samples until ending time is reached. //Generate new samples until ending time is reached.
while t < t_end { while t < t_end {
//Generate the next transition time for each uninitialized variable. time.push(t);
for (idx, val) in next_transitions.iter_mut().enumerate() { events.push(current_state);
if let None = val { (t, current_state) = sampler.next().unwrap();
*val = Some(
net.get_node(idx)
.get_random_residence_time(
net.get_node(idx).state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state),
&mut rng,
)
.unwrap()
+ t,
);
}
}
//Get the variable with the smallest transition time.
let next_node_transition = next_transitions
.iter()
.enumerate()
.min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
.unwrap()
.0;
//Check if the next transition take place after the ending time.
if next_transitions[next_node_transition].unwrap() > t_end {
break;
}
//Get the time in which the next transition occurs.
t = next_transitions[next_node_transition].unwrap().clone();
//Add the transition time to next
time.push(t.clone());
//Compute the new state of the transitioning variable.
current_state[next_node_transition] = net
.get_node(next_node_transition)
.get_random_state(
net.get_node(next_node_transition)
.state_to_index(&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state),
&mut rng,
)
.unwrap();
//Add the new state to events
events.push(Array::from_vec(
current_state
.iter()
.map(|x| match x {
params::StateType::Discrete(state) => state.clone(),
})
.collect(),
));
//Reset the next transition time for the transitioning node.
next_transitions[next_node_transition] = None;
//Reset the next transition time for each child of the transitioning node.
for child in net.get_children_set(next_node_transition) {
next_transitions[child] = None
}
} }
//Add current_state as last state. current_state = events.last().unwrap().clone();
events.push( events.push(current_state);
current_state
.iter()
.map(|x| match x {
params::StateType::Discrete(state) => state.clone(),
})
.collect(),
);
//Add t_end as last time. //Add t_end as last time.
time.push(t_end.clone()); time.push(t_end.clone());
@ -176,11 +88,18 @@ pub fn trajectory_generator<T: network::Network>(
trajectories.push(Trajectory::new( trajectories.push(Trajectory::new(
Array::from_vec(time), Array::from_vec(time),
Array2::from_shape_vec( Array2::from_shape_vec(
(events.len(), current_state.len()), (events.len(), events.last().unwrap().len()),
events.iter().flatten().cloned().collect(), events
.iter()
.flatten()
.map(|x| match x {
params::StateType::Discrete(x) => x.clone(),
})
.collect(),
) )
.unwrap(), .unwrap(),
)); ));
sampler.reset();
} }
//Return a dataset object with the sampled trajectories. //Return a dataset object with the sampled trajectories.
Dataset::new(trajectories) Dataset::new(trajectories)

@ -427,7 +427,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
[0.8, 0.6, 0.2, -1.6] [0.8, 0.6, 0.2, -1.6]
], ],
]), ]),
0.1 0.2
)); ));
} }

Loading…
Cancel
Save