Tested validate_params implemetation for DiscreteStateContinousTimeParams

pull/26/head
AlessandroBregoli 3 years ago
parent dc8013d635
commit 331c2006e9
  1. 65
      tests/params.rs

@ -1,5 +1,5 @@
use rustyCTBN::params::*;
use ndarray::prelude::*;
use rustyCTBN::params::*;
use std::collections::BTreeSet;
mod utils;
@ -7,11 +7,10 @@ mod utils;
#[macro_use]
extern crate approx;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut params = utils::generate_discrete_time_continous_param(3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]];
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
params.cim = Some(cim);
params
@ -62,3 +61,63 @@ fn test_random_generation_residence_time() {
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
}
#[test]
fn test_validate_params_valid_cim() {
let param = create_ternary_discrete_time_continous_param();
assert_eq!(Ok(()), param.validate_params());
}
#[test]
fn test_validate_params_cim_not_initialized() {
let param = utils::generate_discrete_time_continous_param(3);
assert_eq!(
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
))),
param.validate_params()
);
}
#[test]
fn test_validate_params_wrong_shape() {
let mut param = utils::generate_discrete_time_continous_param(4);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
param.cim = Some(cim);
assert_eq!(
Err(ParamsError::InvalidCIM(String::from(
"Incompatible shape [1, 3, 3] with domain 4"
))),
param.validate_params()
);
}
#[test]
fn test_validate_params_positive_diag() {
let mut param = utils::generate_discrete_time_continous_param(3);
let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
param.cim = Some(cim);
assert_eq!(
Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive",
))),
param.validate_params()
);
}
#[test]
fn test_validate_params_row_not_sum_to_zero() {
let mut param = utils::generate_discrete_time_continous_param(3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
param.cim = Some(cim);
assert_eq!(
Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0"
))),
param.validate_params()
);
}

Loading…
Cancel
Save