Refactoring tests

pull/19/head
AlessandroBregoli 3 years ago
parent dc53e5167e
commit 4adfbfa4e4
  1. 109
      src/ctbn.rs
  2. 109
      src/parameter_learning.rs
  3. 63
      src/params.rs
  4. 48
      src/tools.rs
  5. 99
      tests/ctbn.rs
  6. 95
      tests/parameter_learning.rs
  7. 64
      tests/params.rs
  8. 43
      tests/tools.rs
  9. 16
      tests/utils.rs

@ -162,112 +162,3 @@ impl network::Network for CtbnNetwork {
}
#[cfg(test)]
mod tests {
use super::*;
use crate::network::Network;
use crate::node;
use crate::params;
use std::collections::BTreeSet;
fn define_binary_node(name: String) -> node::Node {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
let param = params::DiscreteStatesContinousTimeParams::init(domain) ;
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n;
}
#[test]
fn define_simpe_ctbn() {
let _ = CtbnNetwork::init();
assert!(true);
}
#[test]
fn add_node_to_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
assert_eq!(String::from("n1"), net.get_node(n1).label);
}
#[test]
fn add_edge_to_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(&n2, cs.iter().next().unwrap());
}
#[test]
fn children_and_parents() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(&n2, cs.iter().next().unwrap());
let ps = net.get_parent_set(n2);
assert_eq!(&n1, ps.iter().next().unwrap());
}
#[test]
fn compute_index_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
let n3 = net.add_node(define_binary_node(String::from("n3"))).unwrap();
net.add_edge(n1, n2);
net.add_edge(n3, n2);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(1),
params::StateType::Discrete(1),
params::StateType::Discrete(1)]);
assert_eq!(3, idx);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(0),
params::StateType::Discrete(1),
params::StateType::Discrete(1)]);
assert_eq!(2, idx);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(1),
params::StateType::Discrete(1),
params::StateType::Discrete(0)]);
assert_eq!(1, idx);
}
#[test]
fn compute_index_from_custom_parent_set() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
let n3 = net.add_node(define_binary_node(String::from("n3"))).unwrap();
let idx = net.get_param_index_from_custom_parent_set(&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(0),
params::StateType::Discrete(1)],
&BTreeSet::from([1]));
assert_eq!(0, idx);
let idx = net.get_param_index_from_custom_parent_set(&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(0),
params::StateType::Discrete(1)],
&BTreeSet::from([1,2]));
assert_eq!(2, idx);
}
}

@ -86,112 +86,3 @@ pub fn MLE(
return (CIM, M, T);
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ctbn::*;
use crate::network::Network;
use crate::node;
use crate::params;
use ndarray::arr3;
use std::collections::BTreeSet;
use tools::*;
fn define_binary_node(name: String) -> node::Node {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
let param = params::DiscreteStatesContinousTimeParams::init(domain);
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n;
}
fn define_ternary_node(name: String) -> node::Node {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
domain.insert(String::from("C"));
let param = params::DiscreteStatesContinousTimeParams::init(domain);
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n;
}
#[test]
fn learn_binary_cim_MLE() {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(define_binary_node(String::from("n1")))
.unwrap();
let n2 = net
.add_node(define_binary_node(String::from("n2")))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
]));
}
}
let data = trajectory_generator(Box::new(&net), 10, 100.0);
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2);
assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2);
}
#[test]
fn learn_ternary_cim_MLE() {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(define_ternary_node(String::from("n1")))
.unwrap();
let n2 = net
.add_node(define_ternary_node(String::from("n2")))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]));
}
}
let data = trajectory_generator(Box::new(&net), 100, 200.0);
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2);
assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2);
}
}

@ -157,66 +157,3 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
}
}
#[cfg(test)]
mod tests {
use super::*;
//use ndarray::prelude::*;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
domain.insert(String::from("C"));
let mut params = DiscreteStatesContinousTimeParams::init(domain);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]];
params.cim = Some(cim);
params
}
#[test]
fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<usize>::zeros(10000);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state_uniform() {
val
} else {
panic!()
}
});
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01);
}
#[test]
fn test_random_generation_state() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<usize>::zeros(10000);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
val
} else {
panic!()
}
});
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0;
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01);
assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01);
}
#[test]
fn test_random_generation_residence_time() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<f64>::zeros(10000);
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap());
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
}
}

@ -93,51 +93,3 @@ pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64
}
#[cfg(test)]
mod tests {
use super::*;
use crate::network::Network;
use crate::ctbn::*;
use crate::node;
use crate::params;
use std::collections::BTreeSet;
use ndarray::arr3;
fn define_binary_node(name: String) -> node::Node {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
let param = params::DiscreteStatesContinousTimeParams::init(domain) ;
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n;
}
#[test]
fn run_sampling() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[
[[-1.0,1.0],[4.0,-4.0]],
[[-6.0,6.0],[2.0,-2.0]]]));
}
}
let data = trajectory_generator(Box::new(&net), 4, 1.0);
assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);
}
}

@ -0,0 +1,99 @@
mod utils;
use utils::generate_discrete_time_continous_node;
use rustyCTBN::network::Network;
use rustyCTBN::node;
use rustyCTBN::params;
use std::collections::BTreeSet;
use rustyCTBN::ctbn::*;
#[test]
fn define_simpe_ctbn() {
let _ = CtbnNetwork::init();
assert!(true);
}
#[test]
fn add_node_to_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
assert_eq!(String::from("n1"), net.get_node(n1).label);
}
#[test]
fn add_edge_to_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(&n2, cs.iter().next().unwrap());
}
#[test]
fn children_and_parents() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(&n2, cs.iter().next().unwrap());
let ps = net.get_parent_set(n2);
assert_eq!(&n1, ps.iter().next().unwrap());
}
#[test]
fn compute_index_ctbn() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap();
net.add_edge(n1, n2);
net.add_edge(n3, n2);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(1),
params::StateType::Discrete(1),
params::StateType::Discrete(1)]);
assert_eq!(3, idx);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(0),
params::StateType::Discrete(1),
params::StateType::Discrete(1)]);
assert_eq!(2, idx);
let idx = net.get_param_index_network(n2, &vec![
params::StateType::Discrete(1),
params::StateType::Discrete(1),
params::StateType::Discrete(0)]);
assert_eq!(1, idx);
}
#[test]
fn compute_index_from_custom_parent_set() {
let mut net = CtbnNetwork::init();
let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap();
let idx = net.get_param_index_from_custom_parent_set(&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(0),
params::StateType::Discrete(1)],
&BTreeSet::from([1]));
assert_eq!(0, idx);
let idx = net.get_param_index_from_custom_parent_set(&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(0),
params::StateType::Discrete(1)],
&BTreeSet::from([1,2]));
assert_eq!(2, idx);
}

@ -0,0 +1,95 @@
mod utils;
use utils::*;
use rustyCTBN::parameter_learning::*;
use rustyCTBN::ctbn::*;
use rustyCTBN::network::Network;
use rustyCTBN::node;
use rustyCTBN::params;
use rustyCTBN::tools::*;
use ndarray::arr3;
use std::collections::BTreeSet;
#[macro_use]
extern crate approx;
#[test]
fn learn_binary_cim_MLE() {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"),2))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
]));
}
}
let data = trajectory_generator(Box::new(&net), 100, 100.0);
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2);
assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2);
}
#[test]
fn learn_ternary_cim_MLE() {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"),3))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]));
}
}
let data = trajectory_generator(Box::new(&net), 100, 200.0);
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2);
assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2);
}

@ -0,0 +1,64 @@
use rustyCTBN::params::*;
use ndarray::prelude::*;
use std::collections::BTreeSet;
mod utils;
#[macro_use]
extern crate approx;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut params = utils::generate_discrete_time_continous_param(3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]];
params.cim = Some(cim);
params
}
#[test]
fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<usize>::zeros(10000);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state_uniform() {
val
} else {
panic!()
}
});
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01);
}
#[test]
fn test_random_generation_state() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<usize>::zeros(10000);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
val
} else {
panic!()
}
});
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0;
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01);
assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01);
}
#[test]
fn test_random_generation_residence_time() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<f64>::zeros(10000);
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap());
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
}

@ -0,0 +1,43 @@
use rustyCTBN::tools::*;
use rustyCTBN::network::Network;
use rustyCTBN::ctbn::*;
use rustyCTBN::node;
use rustyCTBN::params;
use std::collections::BTreeSet;
use ndarray::arr3;
#[macro_use]
extern crate approx;
mod utils;
#[test]
fn run_sampling() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[
[[-1.0,1.0],[4.0,-4.0]],
[[-6.0,6.0],[2.0,-2.0]]]));
}
}
let data = trajectory_generator(Box::new(&net), 4, 1.0);
assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);
}

@ -0,0 +1,16 @@
use rustyCTBN::params;
use rustyCTBN::node;
use std::collections::BTreeSet;
pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node {
node::Node::init(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name)
}
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
let mut domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
params::DiscreteStatesContinousTimeParams::init(domain)
}
Loading…
Cancel
Save