Added method validate_params to ParamsTrait.

pull/26/head
AlessandroBregoli 3 years ago
parent 4217654c3a
commit dc8013d635
  1. 52
      src/params.rs

@ -1,16 +1,18 @@
use enum_dispatch::enum_dispatch;
use ndarray::prelude::*;
use rand::Rng;
use std::collections::{BTreeSet, HashMap};
use thiserror::Error;
use enum_dispatch::enum_dispatch;
/// Error types for trait Params
#[derive(Error, Debug)]
#[derive(Error, Debug, PartialEq)]
pub enum ParamsError {
#[error("Unsupported method")]
UnsupportedMethod(String),
#[error("Paramiters not initialized")]
ParametersNotInitialized(String),
#[error("Invalid cim for parameter")]
InvalidCIM(String),
}
/// Allowed type of states
@ -43,6 +45,9 @@ pub trait ParamsTrait {
/// Index used by discrete node to represents their states as usize.
fn state_to_index(&self, state: &StateType) -> usize;
/// Validate parameters against domain
fn validate_params(&self) -> Result<(), ParamsError>;
}
/// The Params enum is the core element for building different types of nodes. The goal is to
@ -52,7 +57,6 @@ pub enum Params {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
/// DiscreteStatesContinousTime.
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
/// following elements:
@ -157,5 +161,47 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
StateType::Discrete(val) => val.clone() as usize,
}
}
fn validate_params(&self) -> Result<(), ParamsError> {
let domain_size = self.domain.len();
// Check if the cim is initialized
if let None = self.cim {
return Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
)));
}
let cim = self.cim.as_ref().unwrap();
// Check if the inner dimensions of the cim are equal to the cardinality of the variable
if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size {
return Err(ParamsError::InvalidCIM(format!(
"Incompatible shape {:?} with domain {:?}",
cim.shape(),
domain_size
)));
}
// Check if the diagonal of each cim is non-positive
if cim
.axis_iter(Axis(0))
.any(|x| x.diag().iter().any(|x| x >= &0.0))
{
return Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive",
)));
}
// Check if each row sum up to 0
let zeros = Array::zeros(domain_size);
if cim
.axis_iter(Axis(0))
.any(|x| !x.sum_axis(Axis(1)).abs_diff_eq(&zeros, f64::MIN_POSITIVE))
{
return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0",
)));
}
return Ok(());
}
}

Loading…
Cancel
Save