|
|
|
@ -3,6 +3,7 @@ use std::collections::{HashMap, BTreeSet}; |
|
|
|
|
use rand::Rng; |
|
|
|
|
use thiserror::Error; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Error, Debug)] |
|
|
|
|
pub enum ParamsError { |
|
|
|
|
#[error("Unsupported method")] |
|
|
|
@ -73,21 +74,22 @@ impl Params for DiscreteStatesContinousTimeParams { |
|
|
|
|
Option::Some(cim) => {
|
|
|
|
|
let mut rng = rand::thread_rng(); |
|
|
|
|
let lambda = cim[[u, state, state]] * -1.0; |
|
|
|
|
let x = rng.gen_range(0.0..1.0); |
|
|
|
|
let x: f64 = rng.gen_range(0.0..1.0); |
|
|
|
|
|
|
|
|
|
let state = (cim.slice(s![u,state,..])).iter().scan((0, 0.0), |acc, &x| { |
|
|
|
|
if x > 0.0 && acc.1 < x { |
|
|
|
|
acc.0 += 1; |
|
|
|
|
let next_state = cim.slice(s![u,state,..]).map(|x| x / lambda).iter().fold((0, 0.0), |mut acc, ele| {
|
|
|
|
|
if &acc.1 + ele < x && ele > &0.0{ |
|
|
|
|
acc.1 += x; |
|
|
|
|
return Some(*acc); |
|
|
|
|
} else if acc.1 < x { |
|
|
|
|
acc.0 += 1; |
|
|
|
|
return Some(*acc); |
|
|
|
|
}
|
|
|
|
|
None |
|
|
|
|
acc}); |
|
|
|
|
|
|
|
|
|
}).last(); |
|
|
|
|
Ok(StateType::Discrete(state.unwrap().0)) |
|
|
|
|
let next_state = if next_state.0 < state { |
|
|
|
|
next_state.0 |
|
|
|
|
} else { |
|
|
|
|
next_state.0 + 1 |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
Ok(StateType::Discrete(next_state as u32)) |
|
|
|
|
|
|
|
|
|
}, |
|
|
|
|
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized"))) |
|
|
|
@ -105,3 +107,70 @@ impl Params for DiscreteStatesContinousTimeParams { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[cfg(test)] |
|
|
|
|
mod tests { |
|
|
|
|
use super::*; |
|
|
|
|
use ndarray::prelude::*; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { |
|
|
|
|
let mut domain = BTreeSet::new(); |
|
|
|
|
domain.insert(String::from("A")); |
|
|
|
|
domain.insert(String::from("B")); |
|
|
|
|
domain.insert(String::from("C")); |
|
|
|
|
let mut params = DiscreteStatesContinousTimeParams::init(domain); |
|
|
|
|
|
|
|
|
|
let cim = array![ |
|
|
|
|
[ |
|
|
|
|
[-3.0, 2.0, 1.0], |
|
|
|
|
[1.0, -5.0, 4.0], |
|
|
|
|
[3.2, 1.7, -4.0] |
|
|
|
|
]]; |
|
|
|
|
|
|
|
|
|
params.cim = Some(cim); |
|
|
|
|
params |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn test_uniform_generation() { |
|
|
|
|
let param = create_ternary_discrete_time_continous_param(); |
|
|
|
|
let mut states = Array1::<u32>::zeros(10000); |
|
|
|
|
|
|
|
|
|
states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state_uniform() { |
|
|
|
|
val |
|
|
|
|
} else {panic!()}); |
|
|
|
|
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
|
|
|
|
|
|
|
|
|
assert_relative_eq!(1.0/3.0, zero_freq, epsilon=0.01); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn test_random_generation_state() { |
|
|
|
|
let param = create_ternary_discrete_time_continous_param(); |
|
|
|
|
let mut states = Array1::<u32>::zeros(10000); |
|
|
|
|
|
|
|
|
|
states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { |
|
|
|
|
val |
|
|
|
|
} else {panic!()}); |
|
|
|
|
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; |
|
|
|
|
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
|
|
|
|
|
|
|
|
|
assert_relative_eq!(4.0/5.0, two_freq, epsilon=0.01); |
|
|
|
|
assert_relative_eq!(1.0/5.0, zero_freq, epsilon=0.01); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn test_random_generation_residence_time() { |
|
|
|
|
let param = create_ternary_discrete_time_continous_param(); |
|
|
|
|
let mut states = Array1::<f64>::zeros(10000); |
|
|
|
|
|
|
|
|
|
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap() ); |
|
|
|
|
|
|
|
|
|
assert_relative_eq!(1.0/5.0, states.mean().unwrap(), epsilon=0.01); |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|