From dc8013d6352b2498e7112a7ad11c2a111cd30834 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 23 Mar 2022 16:28:17 +0100 Subject: [PATCH 1/6] Added method validate_params to ParamsTrait. --- src/params.rs | 54 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/src/params.rs b/src/params.rs index c5a9acf..3efdc6f 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,16 +1,18 @@ +use enum_dispatch::enum_dispatch; use ndarray::prelude::*; use rand::Rng; use std::collections::{BTreeSet, HashMap}; use thiserror::Error; -use enum_dispatch::enum_dispatch; /// Error types for trait Params -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] pub enum ParamsError { #[error("Unsupported method")] UnsupportedMethod(String), #[error("Paramiters not initialized")] ParametersNotInitialized(String), + #[error("Invalid cim for parameter")] + InvalidCIM(String), } /// Allowed type of states @@ -43,6 +45,9 @@ pub trait ParamsTrait { /// Index used by discrete node to represents their states as usize. fn state_to_index(&self, state: &StateType) -> usize; + + /// Validate parameters against domain + fn validate_params(&self) -> Result<(), ParamsError>; } /// The Params enum is the core element for building different types of nodes. The goal is to @@ -52,7 +57,6 @@ pub enum Params { DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), } - /// DiscreteStatesContinousTime. /// This represents the parameters of a classical discrete node for ctbn and it's composed by the /// following elements: @@ -157,5 +161,47 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { StateType::Discrete(val) => val.clone() as usize, } } -} + fn validate_params(&self) -> Result<(), ParamsError> { + let domain_size = self.domain.len(); + + // Check if the cim is initialized + if let None = self.cim { + return Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))); + } + let cim = self.cim.as_ref().unwrap(); + // Check if the inner dimensions of the cim are equal to the cardinality of the variable + if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size { + return Err(ParamsError::InvalidCIM(format!( + "Incompatible shape {:?} with domain {:?}", + cim.shape(), + domain_size + ))); + } + + // Check if the diagonal of each cim is non-positive + if cim + .axis_iter(Axis(0)) + .any(|x| x.diag().iter().any(|x| x >= &0.0)) + { + return Err(ParamsError::InvalidCIM(String::from( + "The diagonal of each cim must be non-positive", + ))); + } + + // Check if each row sum up to 0 + let zeros = Array::zeros(domain_size); + if cim + .axis_iter(Axis(0)) + .any(|x| !x.sum_axis(Axis(1)).abs_diff_eq(&zeros, f64::MIN_POSITIVE)) + { + return Err(ParamsError::InvalidCIM(String::from( + "The sum of each row must be 0", + ))); + } + + return Ok(()); + } +} -- 2.36.3 From 331c2006e9633c1b1e6630eded47ae9515f0c01e Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 23 Mar 2022 16:29:02 +0100 Subject: [PATCH 2/6] Tested validate_params implemetation for DiscreteStateContinousTimeParams --- tests/params.rs | 65 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/tests/params.rs b/tests/params.rs index ed601b2..6901293 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,5 +1,5 @@ -use rustyCTBN::params::*; use ndarray::prelude::*; +use rustyCTBN::params::*; use std::collections::BTreeSet; mod utils; @@ -7,11 +7,10 @@ mod utils; #[macro_use] extern crate approx; - fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { let mut params = utils::generate_discrete_time_continous_param(3); - let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; params.cim = Some(cim); params @@ -62,3 +61,63 @@ fn test_random_generation_residence_time() { assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); } + +#[test] +fn test_validate_params_valid_cim() { + let param = create_ternary_discrete_time_continous_param(); + + assert_eq!(Ok(()), param.validate_params()); +} + +#[test] +fn test_validate_params_cim_not_initialized() { + let param = utils::generate_discrete_time_continous_param(3); + assert_eq!( + Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))), + param.validate_params() + ); +} + + +#[test] +fn test_validate_params_wrong_shape() { + let mut param = utils::generate_discrete_time_continous_param(4); + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + param.cim = Some(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "Incompatible shape [1, 3, 3] with domain 4" + ))), + param.validate_params() + ); +} + + +#[test] +fn test_validate_params_positive_diag() { + let mut param = utils::generate_discrete_time_continous_param(3); + let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + param.cim = Some(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "The diagonal of each cim must be non-positive", + ))), + param.validate_params() + ); +} + + +#[test] +fn test_validate_params_row_not_sum_to_zero() { + let mut param = utils::generate_discrete_time_continous_param(3); + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; + param.cim = Some(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "The sum of each row must be 0" + ))), + param.validate_params() + ); +} -- 2.36.3 From c178862664e0aa216f28da010573b31ca617893d Mon Sep 17 00:00:00 2001 From: Alessandro Bregoli Date: Sat, 26 Mar 2022 10:33:30 +0100 Subject: [PATCH 3/6] Enforced correct set of parameters (cim) when inserted manually --- src/params.rs | 50 ++++++++++++++++++++++++++++++------- tests/parameter_learning.rs | 38 ++++++++++++++-------------- tests/params.rs | 14 +++++------ tests/tools.rs | 4 +-- tests/utils.rs | 2 +- 5 files changed, 70 insertions(+), 38 deletions(-) diff --git a/src/params.rs b/src/params.rs index 3efdc6f..154fd12 100644 --- a/src/params.rs +++ b/src/params.rs @@ -69,21 +69,55 @@ pub enum Params { /// - **residence_time**: permanence time in each possible states given a specific /// realization of the parent set pub struct DiscreteStatesContinousTimeParams { - pub domain: BTreeSet, - pub cim: Option>, - pub transitions: Option>, - pub residence_time: Option>, + domain: BTreeSet, + cim: Option>, + transitions: Option>, + residence_time: Option>, } impl DiscreteStatesContinousTimeParams { pub fn init(domain: BTreeSet) -> DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams { - domain: domain, + domain, cim: Option::None, transitions: Option::None, residence_time: Option::None, } } + + pub fn get_cim(&self) -> &Option> { + &self.cim + } + + pub fn set_cim(&mut self, cim: Array3) -> Result<(), ParamsError>{ + self.cim = Some(cim); + match self.validate_params() { + Ok(()) => Ok(()), + Err(e) => { + self.cim = None; + Err(e) + } + } + } + + pub fn get_transitions(&self) -> &Option> { + &self.transitions + } + + + pub fn set_transitions(&mut self, transitions: Array3) { + self.transitions = Some(transitions); + } + + pub fn get_residence_time(&self) -> &Option> { + &self.residence_time + } + + + pub fn set_residence_time(&mut self, residence_time: Array2) { + self.residence_time = Some(residence_time); + } + } impl ParamsTrait for DiscreteStatesContinousTimeParams { @@ -192,10 +226,8 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { } // Check if each row sum up to 0 - let zeros = Array::zeros(domain_size); - if cim - .axis_iter(Axis(0)) - .any(|x| !x.sum_axis(Axis(1)).abs_diff_eq(&zeros, f64::MIN_POSITIVE)) + if cim.sum_axis(Axis(2)).iter() + .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index a5cca51..d6b8fd2 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -27,16 +27,16 @@ fn learn_binary_cim (pl: T) { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]], - ])); + ]))); } } @@ -77,19 +77,19 @@ fn learn_ternary_cim (pl: T) { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); + [0.4, 0.6, -1.0]]]))); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); + ]))); } } @@ -132,19 +132,19 @@ fn learn_ternary_cim_no_parents (pl: T) { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); + [0.4, 0.6, -1.0]]]))); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); + ]))); } } @@ -192,28 +192,28 @@ fn learn_mixed_discrete_cim (pl: T) { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); + [0.4, 0.6, -1.0]]]))); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); + ]))); } } match &mut net.get_node_mut(n3).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], + [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], @@ -223,12 +223,12 @@ fn learn_mixed_discrete_cim (pl: T) { [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ])); + ]))); } } - let data = trajectory_generator(&net, 300, 200.0); + let data = trajectory_generator(&net, 300, 300.0); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); diff --git a/tests/params.rs b/tests/params.rs index 6901293..9bc9b49 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -12,7 +12,7 @@ fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTime let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; - params.cim = Some(cim); + params.set_cim(cim); params } @@ -85,12 +85,12 @@ fn test_validate_params_cim_not_initialized() { fn test_validate_params_wrong_shape() { let mut param = utils::generate_discrete_time_continous_param(4); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; - param.cim = Some(cim); + let result = param.set_cim(cim); assert_eq!( Err(ParamsError::InvalidCIM(String::from( "Incompatible shape [1, 3, 3] with domain 4" ))), - param.validate_params() + result ); } @@ -99,12 +99,12 @@ fn test_validate_params_wrong_shape() { fn test_validate_params_positive_diag() { let mut param = utils::generate_discrete_time_continous_param(3); let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; - param.cim = Some(cim); + let result = param.set_cim(cim); assert_eq!( Err(ParamsError::InvalidCIM(String::from( "The diagonal of each cim must be non-positive", ))), - param.validate_params() + result ); } @@ -113,11 +113,11 @@ fn test_validate_params_positive_diag() { fn test_validate_params_row_not_sum_to_zero() { let mut param = utils::generate_discrete_time_continous_param(3); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; - param.cim = Some(cim); + let result = param.set_cim(cim); assert_eq!( Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0" ))), - param.validate_params() + result ); } diff --git a/tests/tools.rs b/tests/tools.rs index 802e2fe..257c957 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -23,14 +23,14 @@ fn run_sampling() { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); + param.set_cim(arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some (arr3(&[ + param.set_cim(arr3(&[ [[-1.0,1.0],[4.0,-4.0]], [[-6.0,6.0],[2.0,-2.0]]])); } diff --git a/tests/utils.rs b/tests/utils.rs index a973926..290fa2e 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -8,7 +8,7 @@ pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) - pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ - let mut domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); + let domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); params::DiscreteStatesContinousTimeParams::init(domain) } -- 2.36.3 From 490fe4e010ea6d1cc7841778bc30a65cbd00ac4b Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 29 Mar 2022 15:53:50 +0200 Subject: [PATCH 4/6] Comments --- src/params.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/params.rs b/src/params.rs index 154fd12..5c4347e 100644 --- a/src/params.rs +++ b/src/params.rs @@ -84,11 +84,16 @@ impl DiscreteStatesContinousTimeParams { residence_time: Option::None, } } - + + ///Getter function for CIM pub fn get_cim(&self) -> &Option> { &self.cim } + ///Setter function for CIM + ///This function check if the cim is valid using the validate_params method. + ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(()) + ///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError pub fn set_cim(&mut self, cim: Array3) -> Result<(), ParamsError>{ self.cim = Some(cim); match self.validate_params() { -- 2.36.3 From 16714b48b5a759ba44c14a2522270327624e0fbe Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 29 Mar 2022 15:56:34 +0200 Subject: [PATCH 5/6] Comments --- src/params.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/params.rs b/src/params.rs index 5c4347e..019e281 100644 --- a/src/params.rs +++ b/src/params.rs @@ -90,7 +90,7 @@ impl DiscreteStatesContinousTimeParams { &self.cim } - ///Setter function for CIM + ///Setter function for CIM.\\ ///This function check if the cim is valid using the validate_params method. ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(()) ///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError @@ -105,20 +105,25 @@ impl DiscreteStatesContinousTimeParams { } } + + ///Getter function for transitions pub fn get_transitions(&self) -> &Option> { &self.transitions } + ///Setter function for transitions pub fn set_transitions(&mut self, transitions: Array3) { self.transitions = Some(transitions); } + ///Getter function for residence_time pub fn get_residence_time(&self) -> &Option> { &self.residence_time } + ///Setter function for residence_time pub fn set_residence_time(&mut self, residence_time: Array2) { self.residence_time = Some(residence_time); } -- 2.36.3 From 82e3c779a04a3a6cfd8472bc29fc26b558ad75d6 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 29 Mar 2022 16:03:27 +0200 Subject: [PATCH 6/6] Added test: test_validate_params_valid_cim_with_huge_values --- tests/params.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/params.rs b/tests/params.rs index 9bc9b49..cbc7636 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -69,6 +69,14 @@ fn test_validate_params_valid_cim() { assert_eq!(Ok(()), param.validate_params()); } +#[test] +fn test_validate_params_valid_cim_with_huge_values() { + let mut param = utils::generate_discrete_time_continous_param(3); + let cim = array![[[-2e10, 1e10, 1e10], [1.5e10, -3e10, 1.5e10], [1e10, 1e10, -2e10]]]; + let result = param.set_cim(cim); + assert_eq!(Ok(()), result); +} + #[test] fn test_validate_params_cim_not_initialized() { let param = utils::generate_discrete_time_continous_param(3); @@ -80,7 +88,6 @@ fn test_validate_params_cim_not_initialized() { ); } - #[test] fn test_validate_params_wrong_shape() { let mut param = utils::generate_discrete_time_continous_param(4); @@ -88,13 +95,12 @@ fn test_validate_params_wrong_shape() { let result = param.set_cim(cim); assert_eq!( Err(ParamsError::InvalidCIM(String::from( - "Incompatible shape [1, 3, 3] with domain 4" - ))), + "Incompatible shape [1, 3, 3] with domain 4" + ))), result ); } - #[test] fn test_validate_params_positive_diag() { let mut param = utils::generate_discrete_time_continous_param(3); @@ -102,13 +108,12 @@ fn test_validate_params_positive_diag() { let result = param.set_cim(cim); assert_eq!( Err(ParamsError::InvalidCIM(String::from( - "The diagonal of each cim must be non-positive", - ))), + "The diagonal of each cim must be non-positive", + ))), result ); } - #[test] fn test_validate_params_row_not_sum_to_zero() { let mut param = utils::generate_discrete_time_continous_param(3); @@ -117,7 +122,7 @@ fn test_validate_params_row_not_sum_to_zero() { assert_eq!( Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0" - ))), + ))), result ); } -- 2.36.3