From dc8013d6352b2498e7112a7ad11c2a111cd30834 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 23 Mar 2022 16:28:17 +0100 Subject: [PATCH] Added method validate_params to ParamsTrait. --- src/params.rs | 54 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/src/params.rs b/src/params.rs index c5a9acf..3efdc6f 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,16 +1,18 @@ +use enum_dispatch::enum_dispatch; use ndarray::prelude::*; use rand::Rng; use std::collections::{BTreeSet, HashMap}; use thiserror::Error; -use enum_dispatch::enum_dispatch; /// Error types for trait Params -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] pub enum ParamsError { #[error("Unsupported method")] UnsupportedMethod(String), #[error("Paramiters not initialized")] ParametersNotInitialized(String), + #[error("Invalid cim for parameter")] + InvalidCIM(String), } /// Allowed type of states @@ -43,6 +45,9 @@ pub trait ParamsTrait { /// Index used by discrete node to represents their states as usize. fn state_to_index(&self, state: &StateType) -> usize; + + /// Validate parameters against domain + fn validate_params(&self) -> Result<(), ParamsError>; } /// The Params enum is the core element for building different types of nodes. The goal is to @@ -52,7 +57,6 @@ pub enum Params { DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), } - /// DiscreteStatesContinousTime. /// This represents the parameters of a classical discrete node for ctbn and it's composed by the /// following elements: @@ -157,5 +161,47 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { StateType::Discrete(val) => val.clone() as usize, } } -} + fn validate_params(&self) -> Result<(), ParamsError> { + let domain_size = self.domain.len(); + + // Check if the cim is initialized + if let None = self.cim { + return Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))); + } + let cim = self.cim.as_ref().unwrap(); + // Check if the inner dimensions of the cim are equal to the cardinality of the variable + if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size { + return Err(ParamsError::InvalidCIM(format!( + "Incompatible shape {:?} with domain {:?}", + cim.shape(), + domain_size + ))); + } + + // Check if the diagonal of each cim is non-positive + if cim + .axis_iter(Axis(0)) + .any(|x| x.diag().iter().any(|x| x >= &0.0)) + { + return Err(ParamsError::InvalidCIM(String::from( + "The diagonal of each cim must be non-positive", + ))); + } + + // Check if each row sum up to 0 + let zeros = Array::zeros(domain_size); + if cim + .axis_iter(Axis(0)) + .any(|x| !x.sum_axis(Axis(1)).abs_diff_eq(&zeros, f64::MIN_POSITIVE)) + { + return Err(ParamsError::InvalidCIM(String::from( + "The sum of each row must be 0", + ))); + } + + return Ok(()); + } +}