diff --git a/Cargo.toml b/Cargo.toml index b5d70a3..c1e3a53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,6 @@ ndarray = "*" thiserror = "*" rand = "*" bimap = "*" + +[dev-dependencies] +approx = "*" diff --git a/src/lib.rs b/src/lib.rs index ee48d80..b2e9365 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,7 @@ +#[cfg(test)] +#[macro_use] +extern crate approx; + pub mod node; pub mod params; pub mod network; diff --git a/src/params.rs b/src/params.rs index f3793a3..8825e86 100644 --- a/src/params.rs +++ b/src/params.rs @@ -3,6 +3,7 @@ use std::collections::{HashMap, BTreeSet}; use rand::Rng; use thiserror::Error; + #[derive(Error, Debug)] pub enum ParamsError { #[error("Unsupported method")] @@ -73,21 +74,22 @@ impl Params for DiscreteStatesContinousTimeParams { Option::Some(cim) => { let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; - let x = rng.gen_range(0.0..1.0); + let x: f64 = rng.gen_range(0.0..1.0); - let state = (cim.slice(s![u,state,..])).iter().scan((0, 0.0), |acc, &x| { - if x > 0.0 && acc.1 < x { - acc.0 += 1; + let next_state = cim.slice(s![u,state,..]).map(|x| x / lambda).iter().fold((0, 0.0), |mut acc, ele| { + if &acc.1 + ele < x && ele > &0.0{ acc.1 += x; - return Some(*acc); - } else if acc.1 < x { acc.0 += 1; - return Some(*acc); - } - None + } + acc}); + + let next_state = if next_state.0 < state { + next_state.0 + } else { + next_state.0 + 1 + }; - }).last(); - Ok(StateType::Discrete(state.unwrap().0)) + Ok(StateType::Discrete(next_state as u32)) }, Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized"))) @@ -105,3 +107,70 @@ impl Params for DiscreteStatesContinousTimeParams { } } } + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::prelude::*; + + + fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { + let mut domain = BTreeSet::new(); + domain.insert(String::from("A")); + domain.insert(String::from("B")); + domain.insert(String::from("C")); + let mut params = DiscreteStatesContinousTimeParams::init(domain); + + let cim = array![ + [ + [-3.0, 2.0, 1.0], + [1.0, -5.0, 4.0], + [3.2, 1.7, -4.0] + ]]; + + params.cim = Some(cim); + params + } + + #[test] + fn test_uniform_generation() { + let param = create_ternary_discrete_time_continous_param(); + let mut states = Array1::::zeros(10000); + + states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state_uniform() { + val + } else {panic!()}); + let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; + + assert_relative_eq!(1.0/3.0, zero_freq, epsilon=0.01); + } + + + #[test] + fn test_random_generation_state() { + let param = create_ternary_discrete_time_continous_param(); + let mut states = Array1::::zeros(10000); + + states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { + val + } else {panic!()}); + let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; + let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; + + assert_relative_eq!(4.0/5.0, two_freq, epsilon=0.01); + assert_relative_eq!(1.0/5.0, zero_freq, epsilon=0.01); + } + + + #[test] + fn test_random_generation_residence_time() { + let param = create_ternary_discrete_time_continous_param(); + let mut states = Array1::::zeros(10000); + + states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap() ); + + assert_relative_eq!(1.0/5.0, states.mean().unwrap(), epsilon=0.01); + + } + +}