ctbn unit tests

main
Alessandro Bregoli 3 years ago
parent 323c03e4e1
commit 9b98ecae23
  1. 92
      src/ctbn.rs
  2. 18
      src/lib.rs
  3. 7
      src/node.rs
  4. 11
      src/tools.rs

@ -12,6 +12,15 @@ pub struct CtbnNetwork {
nodes: Vec<node::Node>
}
impl CtbnNetwork {
pub fn init() -> CtbnNetwork {
CtbnNetwork {
adj_matrix: None,
nodes: Vec::new()
}
}
}
impl network::Network for CtbnNetwork {
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f()));
@ -85,3 +94,86 @@ impl network::Network for CtbnNetwork {
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::network::Network;
use crate::node;
use crate::params;
use std::collections::BTreeSet;
fn define_binary_node(name: String) -> node::Node {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
let params = params::DiscreteStatesContinousTimeParams::init(domain);
let n = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),name);
return n;
}
#[test]
fn define_simpe_ctbn() {
let _ = CtbnNetwork::init();
assert!(true);
}
#[test]
fn add_node_to_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
assert_eq!(String::from("n1"), net.get_node(n1).label);
}
#[test]
fn add_edge_to_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(n2, cs[0]);
}
#[test]
fn children_and_parents() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(n2, cs[0]);
let ps = net.get_parent_set(n2);
assert_eq!(n1, ps[0]);
}
#[test]
fn compute_index_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
let n3 = net.add_node(define_binary_node(String::from("n3"))).unwrap();
net.add_edge(n1, n2);
net.add_edge(n3, n2);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(1),
params::StateType::Discrete(1),
params::StateType::Discrete(1)]);
assert_eq!(3, idx);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(0),
params::StateType::Discrete(1),
params::StateType::Discrete(1)]);
assert_eq!(2, idx);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(1),
params::StateType::Discrete(1),
params::StateType::Discrete(0)]);
assert_eq!(1, idx);
}
}

@ -1,14 +1,6 @@
mod node;
mod params;
mod network;
mod ctbn;
mod tools;
pub mod node;
pub mod params;
pub mod network;
pub mod ctbn;
pub mod tools;
#[cfg(test)]
mod tests {
#[test]
fn it_works() {
let result = 2 + 2;
assert_eq!(result, 4);
}
}

@ -11,6 +11,13 @@ pub struct Node {
}
impl Node {
pub fn init(params: NodeType, label: String) -> Node {
Node{
params: params,
label:label
}
}
pub fn reset_params(&mut self) {
match &mut self.params {
NodeType::DiscreteStatesContinousTime(params) => {params.reset_params();}

@ -23,7 +23,7 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
for _ in 0..n_trajectories {
let mut t = 0.0;
let mut time: Vec<f64> = Vec::new();
let mut events: Vec<Vec<u32>> = Vec::new();
let mut events: Vec<Array1<u32>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
net.get_node(*x).get_random_state_uniform()
}).collect();
@ -63,9 +63,9 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
.unwrap();
events.push(current_state.iter().map(|x| match x {
events.push(Array::from_vec(current_state.iter().map(|x| match x {
params::StateType::Discrete(state) => state.clone()
}).collect());
}).collect()));
next_transitions[next_node_transition] = None;
for child in net.get_children_set(next_node_transition){
@ -79,9 +79,10 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
}).collect());
time.push(t.clone());
dataset.trajectories.push(Trajectory {
time: array![time],
events: array![events]
time: Array::from_vec(time),
events: Array2::from_shape_vec((events.len(), current_state.len()), events.iter().flatten().cloned().collect()).unwrap()
});

Loading…
Cancel
Save