A new, blazing-fast learning engine for Continuous Time Bayesian Networks. Written in pure Rust. 🦀
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
reCTBN/src/tools.rs

94 lines
3.3 KiB

3 years ago
use ndarray::prelude::*;
use crate::network;
use crate::node;
use crate::params;
use crate::params::Params;
3 years ago
/// A single sampled trajectory: a sequence of time-stamped joint states of the network.
///
/// As built by `trajectory_generator`, row `i` of `events` is the state of every node at time
/// `time[i]`, so `events.nrows() == time.len()`.
///
/// The fields are public so that code outside this module (e.g. parameter learners) can read
/// the generated data.
pub struct Trajectory {
    /// Time stamp of each recorded row, in the order the rows were recorded.
    pub time: Array1<f64>,
    /// Event matrix: one row per time stamp, one column per node (discrete state values).
    pub events: Array2<u32>,
}
/// A collection of independently sampled trajectories, as returned by `trajectory_generator`.
///
/// The field is public so that code outside this module can iterate over the trajectories.
pub struct Dataset {
    /// The sampled trajectories, in generation order.
    pub trajectories: Vec<Trajectory>,
}
pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
3 years ago
let mut dataset = Dataset{
trajectories: Vec::new()
};
let node_idx: Vec<_> = net.get_node_indices().collect();
3 years ago
for _ in 0..n_trajectories {
let mut t = 0.0;
3 years ago
let mut time: Vec<f64> = Vec::new();
3 years ago
let mut events: Vec<Array1<u32>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
net.get_node(*x).get_random_state_uniform()
}).collect();
let mut next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect();
events.push(current_state.iter().map(|x| match x {
params::StateType::Discrete(state) => state.clone()
}).collect());
time.push(t.clone());
while t < t_end {
3 years ago
for (idx, val) in next_transitions.iter_mut().enumerate(){
if let None = val {
*val = Some(net.get_node(idx)
.get_random_residence_time(net.get_node(idx).state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state)).unwrap() + t);
}
3 years ago
};
let next_node_transition = next_transitions
.iter()
.enumerate()
.min_by(|x, y|
x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
.unwrap().0;
if next_transitions[next_node_transition].unwrap() > t_end {
break
}
t = next_transitions[next_node_transition].unwrap().clone();
time.push(t.clone());
current_state[next_node_transition] = net.get_node(next_node_transition)
.get_random_state(
net.get_node(next_node_transition).
state_to_index(
&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state))
.unwrap();
3 years ago
3 years ago
events.push(Array::from_vec(current_state.iter().map(|x| match x {
params::StateType::Discrete(state) => state.clone()
3 years ago
}).collect()));
next_transitions[next_node_transition] = None;
for child in net.get_children_set(next_node_transition){
next_transitions[child] = None
}
}
events.push(current_state.iter().map(|x| match x {
params::StateType::Discrete(state) => state.clone()
}).collect());
time.push(t.clone());
3 years ago
dataset.trajectories.push(Trajectory {
3 years ago
time: Array::from_vec(time),
events: Array2::from_shape_vec((events.len(), current_state.len()), events.iter().flatten().cloned().collect()).unwrap()
});
3 years ago
}
dataset
}