|
|
@ -1,5 +1,5 @@ |
|
|
|
use rustyCTBN::params::*; |
|
|
|
|
|
|
|
use ndarray::prelude::*; |
|
|
|
use ndarray::prelude::*; |
|
|
|
|
|
|
|
use rustyCTBN::params::*; |
|
|
|
use std::collections::BTreeSet; |
|
|
|
use std::collections::BTreeSet; |
|
|
|
|
|
|
|
|
|
|
|
mod utils; |
|
|
|
mod utils; |
|
|
@ -7,11 +7,10 @@ mod utils; |
|
|
|
#[macro_use] |
|
|
|
#[macro_use] |
|
|
|
extern crate approx; |
|
|
|
extern crate approx; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { |
|
|
|
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { |
|
|
|
let mut params = utils::generate_discrete_time_continous_param(3); |
|
|
|
let mut params = utils::generate_discrete_time_continous_param(3); |
|
|
|
|
|
|
|
|
|
|
|
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; |
|
|
|
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; |
|
|
|
|
|
|
|
|
|
|
|
params.cim = Some(cim); |
|
|
|
params.cim = Some(cim); |
|
|
|
params |
|
|
|
params |
|
|
@ -62,3 +61,63 @@ fn test_random_generation_residence_time() { |
|
|
|
|
|
|
|
|
|
|
|
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); |
|
|
|
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
|
|
|
fn test_validate_params_valid_cim() { |
|
|
|
|
|
|
|
let param = create_ternary_discrete_time_continous_param(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert_eq!(Ok(()), param.validate_params()); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
|
|
|
fn test_validate_params_cim_not_initialized() { |
|
|
|
|
|
|
|
let param = utils::generate_discrete_time_continous_param(3); |
|
|
|
|
|
|
|
assert_eq!( |
|
|
|
|
|
|
|
Err(ParamsError::ParametersNotInitialized(String::from( |
|
|
|
|
|
|
|
"CIM not initialized", |
|
|
|
|
|
|
|
))), |
|
|
|
|
|
|
|
param.validate_params() |
|
|
|
|
|
|
|
); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
|
|
|
fn test_validate_params_wrong_shape() { |
|
|
|
|
|
|
|
let mut param = utils::generate_discrete_time_continous_param(4); |
|
|
|
|
|
|
|
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; |
|
|
|
|
|
|
|
param.cim = Some(cim); |
|
|
|
|
|
|
|
assert_eq!( |
|
|
|
|
|
|
|
Err(ParamsError::InvalidCIM(String::from( |
|
|
|
|
|
|
|
"Incompatible shape [1, 3, 3] with domain 4" |
|
|
|
|
|
|
|
))), |
|
|
|
|
|
|
|
param.validate_params() |
|
|
|
|
|
|
|
); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
|
|
|
fn test_validate_params_positive_diag() { |
|
|
|
|
|
|
|
let mut param = utils::generate_discrete_time_continous_param(3); |
|
|
|
|
|
|
|
let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; |
|
|
|
|
|
|
|
param.cim = Some(cim); |
|
|
|
|
|
|
|
assert_eq!( |
|
|
|
|
|
|
|
Err(ParamsError::InvalidCIM(String::from( |
|
|
|
|
|
|
|
"The diagonal of each cim must be non-positive", |
|
|
|
|
|
|
|
))), |
|
|
|
|
|
|
|
param.validate_params() |
|
|
|
|
|
|
|
); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
|
|
|
fn test_validate_params_row_not_sum_to_zero() { |
|
|
|
|
|
|
|
let mut param = utils::generate_discrete_time_continous_param(3); |
|
|
|
|
|
|
|
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; |
|
|
|
|
|
|
|
param.cim = Some(cim); |
|
|
|
|
|
|
|
assert_eq!( |
|
|
|
|
|
|
|
Err(ParamsError::InvalidCIM(String::from( |
|
|
|
|
|
|
|
"The sum of each row must be 0" |
|
|
|
|
|
|
|
))), |
|
|
|
|
|
|
|
param.validate_params() |
|
|
|
|
|
|
|
); |
|
|
|
|
|
|
|
} |
|
|
|