diff --git a/src/params.rs b/src/params.rs index c5a9acf..019e281 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,16 +1,18 @@ +use enum_dispatch::enum_dispatch; use ndarray::prelude::*; use rand::Rng; use std::collections::{BTreeSet, HashMap}; use thiserror::Error; -use enum_dispatch::enum_dispatch; /// Error types for trait Params -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] pub enum ParamsError { #[error("Unsupported method")] UnsupportedMethod(String), #[error("Paramiters not initialized")] ParametersNotInitialized(String), + #[error("Invalid cim for parameter")] + InvalidCIM(String), } /// Allowed type of states @@ -43,6 +45,9 @@ pub trait ParamsTrait { /// Index used by discrete node to represents their states as usize. fn state_to_index(&self, state: &StateType) -> usize; + + /// Validate parameters against domain + fn validate_params(&self) -> Result<(), ParamsError>; } /// The Params enum is the core element for building different types of nodes. The goal is to @@ -52,7 +57,6 @@ pub enum Params { DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), } - /// DiscreteStatesContinousTime. /// This represents the parameters of a classical discrete node for ctbn and it's composed by the /// following elements: @@ -65,21 +69,65 @@ pub enum Params { /// - **residence_time**: permanence time in each possible states given a specific /// realization of the parent set pub struct DiscreteStatesContinousTimeParams { - pub domain: BTreeSet, - pub cim: Option>, - pub transitions: Option>, - pub residence_time: Option>, + domain: BTreeSet, + cim: Option>, + transitions: Option>, + residence_time: Option>, } impl DiscreteStatesContinousTimeParams { pub fn init(domain: BTreeSet) -> DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams { - domain: domain, + domain, cim: Option::None, transitions: Option::None, residence_time: Option::None, } } + + ///Getter function for CIM + pub fn get_cim(&self) -> &Option> { + &self.cim + } + + ///Setter function for CIM.\\ + ///This function check if the cim is valid using the validate_params method. + ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(()) + ///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError + pub fn set_cim(&mut self, cim: Array3) -> Result<(), ParamsError>{ + self.cim = Some(cim); + match self.validate_params() { + Ok(()) => Ok(()), + Err(e) => { + self.cim = None; + Err(e) + } + } + } + + + ///Getter function for transitions + pub fn get_transitions(&self) -> &Option> { + &self.transitions + } + + + ///Setter function for transitions + pub fn set_transitions(&mut self, transitions: Array3) { + self.transitions = Some(transitions); + } + + ///Getter function for residence_time + pub fn get_residence_time(&self) -> &Option> { + &self.residence_time + } + + + ///Setter function for residence_time + pub fn set_residence_time(&mut self, residence_time: Array2) { + self.residence_time = Some(residence_time); + } + } impl ParamsTrait for DiscreteStatesContinousTimeParams { @@ -157,5 +205,45 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { StateType::Discrete(val) => val.clone() as usize, } } -} + fn validate_params(&self) -> Result<(), ParamsError> { + let domain_size = self.domain.len(); + + // Check if the cim is initialized + if let None = self.cim { + return Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))); + } + let cim = self.cim.as_ref().unwrap(); + // Check if the inner dimensions of the cim are equal to the cardinality of the variable + if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size { + return Err(ParamsError::InvalidCIM(format!( + "Incompatible shape {:?} with domain {:?}", + cim.shape(), + domain_size + ))); + } + + // Check if the diagonal of each cim is non-positive + if cim + .axis_iter(Axis(0)) + .any(|x| x.diag().iter().any(|x| x >= &0.0)) + { + return Err(ParamsError::InvalidCIM(String::from( + "The diagonal of each cim must be non-positive", + ))); + } + + // Check if each row sum up to 0 + if cim.sum_axis(Axis(2)).iter() + .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) + { + return Err(ParamsError::InvalidCIM(String::from( + "The sum of each row must be 0", + ))); + } + + return Ok(()); + } +} diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index a5cca51..d6b8fd2 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -27,16 +27,16 @@ fn learn_binary_cim (pl: T) { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]], - ])); + ]))); } } @@ -77,19 +77,19 @@ fn learn_ternary_cim (pl: T) { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); + [0.4, 0.6, -1.0]]]))); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); + ]))); } } @@ -132,19 +132,19 @@ fn learn_ternary_cim_no_parents (pl: T) { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); + [0.4, 0.6, -1.0]]]))); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); + ]))); } } @@ -192,28 +192,28 @@ fn learn_mixed_discrete_cim (pl: T) { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); + [0.4, 0.6, -1.0]]]))); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); + ]))); } } match &mut net.get_node_mut(n3).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ + assert_eq!(Ok(()), param.set_cim(arr3(&[ [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], + [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], @@ -223,12 +223,12 @@ fn learn_mixed_discrete_cim (pl: T) { [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ])); + ]))); } } - let data = trajectory_generator(&net, 300, 200.0); + let data = trajectory_generator(&net, 300, 300.0); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); diff --git a/tests/params.rs b/tests/params.rs index ed601b2..cbc7636 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,5 +1,5 @@ -use rustyCTBN::params::*; use ndarray::prelude::*; +use rustyCTBN::params::*; use std::collections::BTreeSet; mod utils; @@ -7,13 +7,12 @@ mod utils; #[macro_use] extern crate approx; - fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { let mut params = utils::generate_discrete_time_continous_param(3); - let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; - params.cim = Some(cim); + params.set_cim(cim); params } @@ -62,3 +61,68 @@ fn test_random_generation_residence_time() { assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); } + +#[test] +fn test_validate_params_valid_cim() { + let param = create_ternary_discrete_time_continous_param(); + + assert_eq!(Ok(()), param.validate_params()); +} + +#[test] +fn test_validate_params_valid_cim_with_huge_values() { + let mut param = utils::generate_discrete_time_continous_param(3); + let cim = array![[[-2e10, 1e10, 1e10], [1.5e10, -3e10, 1.5e10], [1e10, 1e10, -2e10]]]; + let result = param.set_cim(cim); + assert_eq!(Ok(()), result); +} + +#[test] +fn test_validate_params_cim_not_initialized() { + let param = utils::generate_discrete_time_continous_param(3); + assert_eq!( + Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))), + param.validate_params() + ); +} + +#[test] +fn test_validate_params_wrong_shape() { + let mut param = utils::generate_discrete_time_continous_param(4); + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + let result = param.set_cim(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "Incompatible shape [1, 3, 3] with domain 4" + ))), + result + ); +} + +#[test] +fn test_validate_params_positive_diag() { + let mut param = utils::generate_discrete_time_continous_param(3); + let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + let result = param.set_cim(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "The diagonal of each cim must be non-positive", + ))), + result + ); +} + +#[test] +fn test_validate_params_row_not_sum_to_zero() { + let mut param = utils::generate_discrete_time_continous_param(3); + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; + let result = param.set_cim(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "The sum of each row must be 0" + ))), + result + ); +} diff --git a/tests/tools.rs b/tests/tools.rs index 802e2fe..257c957 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -23,14 +23,14 @@ fn run_sampling() { match &mut net.get_node_mut(n1).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); + param.set_cim(arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); } } match &mut net.get_node_mut(n2).params { params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some (arr3(&[ + param.set_cim(arr3(&[ [[-1.0,1.0],[4.0,-4.0]], [[-6.0,6.0],[2.0,-2.0]]])); } diff --git a/tests/utils.rs b/tests/utils.rs index a973926..290fa2e 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -8,7 +8,7 @@ pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) - pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ - let mut domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); + let domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); params::DiscreteStatesContinousTimeParams::init(domain) }