Merge pull request #59 from AlessandroBregoli/10-feature-importance-sampling-for-ctbn
10 feature importance sampling for ctbnpull/61/head
commit
2f5a80dccb
@ -0,0 +1,111 @@ |
||||
use crate::{ |
||||
network::Network, |
||||
params::{self, ParamsTrait}, |
||||
}; |
||||
use rand::SeedableRng; |
||||
use rand_chacha::ChaCha8Rng; |
||||
|
||||
pub trait Sampler: Iterator { |
||||
fn reset(&mut self); |
||||
} |
||||
|
||||
pub struct ForwardSampler<'a, T> |
||||
where |
||||
T: Network, |
||||
{ |
||||
net: &'a T, |
||||
rng: ChaCha8Rng, |
||||
current_time: f64, |
||||
current_state: Vec<params::StateType>, |
||||
next_transitions: Vec<Option<f64>>, |
||||
} |
||||
|
||||
impl<'a, T: Network> ForwardSampler<'a, T> { |
||||
pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> { |
||||
let rng: ChaCha8Rng = match seed { |
||||
//If a seed is present use it to initialize the random generator.
|
||||
Some(seed) => SeedableRng::seed_from_u64(seed), |
||||
//Otherwise create a new random generator using the method `from_entropy`
|
||||
None => SeedableRng::from_entropy(), |
||||
}; |
||||
let mut fs = ForwardSampler { |
||||
net: net, |
||||
rng: rng, |
||||
current_time: 0.0, |
||||
current_state: vec![], |
||||
next_transitions: vec![], |
||||
}; |
||||
fs.reset(); |
||||
return fs; |
||||
} |
||||
} |
||||
|
||||
impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { |
||||
type Item = (f64, Vec<params::StateType>); |
||||
|
||||
fn next(&mut self) -> Option<Self::Item> { |
||||
let ret_time = self.current_time.clone(); |
||||
let ret_state = self.current_state.clone(); |
||||
|
||||
for (idx, val) in self.next_transitions.iter_mut().enumerate() { |
||||
if let None = val { |
||||
*val = Some( |
||||
self.net |
||||
.get_node(idx) |
||||
.get_random_residence_time( |
||||
self.net |
||||
.get_node(idx) |
||||
.state_to_index(&self.current_state[idx]), |
||||
self.net.get_param_index_network(idx, &self.current_state), |
||||
&mut self.rng, |
||||
) |
||||
.unwrap() |
||||
+ self.current_time, |
||||
); |
||||
} |
||||
} |
||||
|
||||
let next_node_transition = self |
||||
.next_transitions |
||||
.iter() |
||||
.enumerate() |
||||
.min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) |
||||
.unwrap() |
||||
.0; |
||||
|
||||
self.current_time = self.next_transitions[next_node_transition].unwrap().clone(); |
||||
|
||||
self.current_state[next_node_transition] = self |
||||
.net |
||||
.get_node(next_node_transition) |
||||
.get_random_state( |
||||
self.net |
||||
.get_node(next_node_transition) |
||||
.state_to_index(&self.current_state[next_node_transition]), |
||||
self.net |
||||
.get_param_index_network(next_node_transition, &self.current_state), |
||||
&mut self.rng, |
||||
) |
||||
.unwrap(); |
||||
|
||||
self.next_transitions[next_node_transition] = None; |
||||
|
||||
for child in self.net.get_children_set(next_node_transition) { |
||||
self.next_transitions[child] = None; |
||||
} |
||||
|
||||
Some((ret_time, ret_state)) |
||||
} |
||||
} |
||||
|
||||
impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { |
||||
fn reset(&mut self) { |
||||
self.current_time = 0.0; |
||||
self.current_state = self |
||||
.net |
||||
.get_node_indices() |
||||
.map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) |
||||
.collect(); |
||||
self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect(); |
||||
} |
||||
} |
Loading…
Reference in new issue