From 331c2006e9633c1b1e6630eded47ae9515f0c01e Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 23 Mar 2022 16:29:02 +0100 Subject: [PATCH] Tested validate_params implemetation for DiscreteStateContinousTimeParams --- tests/params.rs | 65 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/tests/params.rs b/tests/params.rs index ed601b2..6901293 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,5 +1,5 @@ -use rustyCTBN::params::*; use ndarray::prelude::*; +use rustyCTBN::params::*; use std::collections::BTreeSet; mod utils; @@ -7,11 +7,10 @@ mod utils; #[macro_use] extern crate approx; - fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { let mut params = utils::generate_discrete_time_continous_param(3); - let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; params.cim = Some(cim); params @@ -62,3 +61,63 @@ fn test_random_generation_residence_time() { assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); } + +#[test] +fn test_validate_params_valid_cim() { + let param = create_ternary_discrete_time_continous_param(); + + assert_eq!(Ok(()), param.validate_params()); +} + +#[test] +fn test_validate_params_cim_not_initialized() { + let param = utils::generate_discrete_time_continous_param(3); + assert_eq!( + Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))), + param.validate_params() + ); +} + + +#[test] +fn test_validate_params_wrong_shape() { + let mut param = utils::generate_discrete_time_continous_param(4); + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + param.cim = Some(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "Incompatible shape [1, 3, 3] with domain 4" + ))), + param.validate_params() + ); +} + + +#[test] +fn test_validate_params_positive_diag() { + let mut param = utils::generate_discrete_time_continous_param(3); + let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + param.cim = Some(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "The diagonal of each cim must be non-positive", + ))), + param.validate_params() + ); +} + + +#[test] +fn test_validate_params_row_not_sum_to_zero() { + let mut param = utils::generate_discrete_time_continous_param(3); + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; + param.cim = Some(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "The sum of each row must be 0" + ))), + param.validate_params() + ); +}