parameters unit tests

main
Alessandro Bregoli 3 years ago
parent 9b98ecae23
commit c30c496b6b
  1. 3
      Cargo.toml
  2. 4
      src/lib.rs
  3. 91
      src/params.rs

@ -11,3 +11,6 @@ ndarray = "*"
thiserror = "*" thiserror = "*"
rand = "*" rand = "*"
bimap = "*" bimap = "*"
[dev-dependencies]
approx = "*"

@ -1,3 +1,7 @@
#[cfg(test)]
#[macro_use]
extern crate approx;
pub mod node; pub mod node;
pub mod params; pub mod params;
pub mod network; pub mod network;

@ -3,6 +3,7 @@ use std::collections::{HashMap, BTreeSet};
use rand::Rng; use rand::Rng;
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ParamsError { pub enum ParamsError {
#[error("Unsupported method")] #[error("Unsupported method")]
@ -73,21 +74,22 @@ impl Params for DiscreteStatesContinousTimeParams {
Option::Some(cim) => { Option::Some(cim) => {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let lambda = cim[[u, state, state]] * -1.0; let lambda = cim[[u, state, state]] * -1.0;
let x = rng.gen_range(0.0..1.0); let x: f64 = rng.gen_range(0.0..1.0);
let state = (cim.slice(s![u,state,..])).iter().scan((0, 0.0), |acc, &x| { let next_state = cim.slice(s![u,state,..]).map(|x| x / lambda).iter().fold((0, 0.0), |mut acc, ele| {
if x > 0.0 && acc.1 < x { if &acc.1 + ele < x && ele > &0.0{
acc.0 += 1;
acc.1 += x; acc.1 += x;
return Some(*acc);
} else if acc.1 < x {
acc.0 += 1; acc.0 += 1;
return Some(*acc); }
} acc});
None
let next_state = if next_state.0 < state {
next_state.0
} else {
next_state.0 + 1
};
}).last(); Ok(StateType::Discrete(next_state as u32))
Ok(StateType::Discrete(state.unwrap().0))
}, },
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized"))) Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized")))
@ -105,3 +107,70 @@ impl Params for DiscreteStatesContinousTimeParams {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use ndarray::prelude::*;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
domain.insert(String::from("C"));
let mut params = DiscreteStatesContinousTimeParams::init(domain);
let cim = array![
[
[-3.0, 2.0, 1.0],
[1.0, -5.0, 4.0],
[3.2, 1.7, -4.0]
]];
params.cim = Some(cim);
params
}
#[test]
fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<u32>::zeros(10000);
states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state_uniform() {
val
} else {panic!()});
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(1.0/3.0, zero_freq, epsilon=0.01);
}
#[test]
fn test_random_generation_state() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<u32>::zeros(10000);
states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
val
} else {panic!()});
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0;
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(4.0/5.0, two_freq, epsilon=0.01);
assert_relative_eq!(1.0/5.0, zero_freq, epsilon=0.01);
}
#[test]
fn test_random_generation_residence_time() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<f64>::zeros(10000);
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap() );
assert_relative_eq!(1.0/5.0, states.mean().unwrap(), epsilon=0.01);
}
}

Loading…
Cancel
Save