|
|
|
@ -1,16 +1,18 @@ |
|
|
|
|
use enum_dispatch::enum_dispatch; |
|
|
|
|
use ndarray::prelude::*; |
|
|
|
|
use rand::Rng; |
|
|
|
|
use std::collections::{BTreeSet, HashMap}; |
|
|
|
|
use thiserror::Error; |
|
|
|
|
use enum_dispatch::enum_dispatch; |
|
|
|
|
|
|
|
|
|
/// Error types for trait Params
|
|
|
|
|
#[derive(Error, Debug)] |
|
|
|
|
#[derive(Error, Debug, PartialEq)] |
|
|
|
|
pub enum ParamsError { |
|
|
|
|
#[error("Unsupported method")] |
|
|
|
|
UnsupportedMethod(String), |
|
|
|
|
#[error("Paramiters not initialized")] |
|
|
|
|
ParametersNotInitialized(String), |
|
|
|
|
#[error("Invalid cim for parameter")] |
|
|
|
|
InvalidCIM(String), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Allowed type of states
|
|
|
|
@ -43,6 +45,9 @@ pub trait ParamsTrait { |
|
|
|
|
|
|
|
|
|
/// Index used by discrete node to represents their states as usize.
|
|
|
|
|
fn state_to_index(&self, state: &StateType) -> usize; |
|
|
|
|
|
|
|
|
|
/// Validate parameters against domain
|
|
|
|
|
fn validate_params(&self) -> Result<(), ParamsError>; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// The Params enum is the core element for building different types of nodes. The goal is to
|
|
|
|
@ -52,7 +57,6 @@ pub enum Params { |
|
|
|
|
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// DiscreteStatesContinousTime.
|
|
|
|
|
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
|
|
|
|
|
/// following elements:
|
|
|
|
@ -157,5 +161,47 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { |
|
|
|
|
StateType::Discrete(val) => val.clone() as usize, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn validate_params(&self) -> Result<(), ParamsError> { |
|
|
|
|
let domain_size = self.domain.len(); |
|
|
|
|
|
|
|
|
|
// Check if the cim is initialized
|
|
|
|
|
if let None = self.cim { |
|
|
|
|
return Err(ParamsError::ParametersNotInitialized(String::from( |
|
|
|
|
"CIM not initialized", |
|
|
|
|
))); |
|
|
|
|
} |
|
|
|
|
let cim = self.cim.as_ref().unwrap(); |
|
|
|
|
// Check if the inner dimensions of the cim are equal to the cardinality of the variable
|
|
|
|
|
if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size { |
|
|
|
|
return Err(ParamsError::InvalidCIM(format!( |
|
|
|
|
"Incompatible shape {:?} with domain {:?}", |
|
|
|
|
cim.shape(), |
|
|
|
|
domain_size |
|
|
|
|
))); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Check if the diagonal of each cim is non-positive
|
|
|
|
|
if cim |
|
|
|
|
.axis_iter(Axis(0)) |
|
|
|
|
.any(|x| x.diag().iter().any(|x| x >= &0.0)) |
|
|
|
|
{ |
|
|
|
|
return Err(ParamsError::InvalidCIM(String::from( |
|
|
|
|
"The diagonal of each cim must be non-positive", |
|
|
|
|
))); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Check if each row sum up to 0
|
|
|
|
|
let zeros = Array::zeros(domain_size); |
|
|
|
|
if cim |
|
|
|
|
.axis_iter(Axis(0)) |
|
|
|
|
.any(|x| !x.sum_axis(Axis(1)).abs_diff_eq(&zeros, f64::MIN_POSITIVE)) |
|
|
|
|
{ |
|
|
|
|
return Err(ParamsError::InvalidCIM(String::from( |
|
|
|
|
"The sum of each row must be 0", |
|
|
|
|
))); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return Ok(()); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|