Merge pull request #26 from AlessandroBregoli/14-feature

14 feature
pull/30/head
AlessandroBregoli 3 years ago committed by GitHub
commit 2d7e52f8f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 104
      src/params.rs
  2. 36
      tests/parameter_learning.rs
  3. 72
      tests/params.rs
  4. 4
      tests/tools.rs
  5. 2
      tests/utils.rs

@ -1,16 +1,18 @@
use enum_dispatch::enum_dispatch;
use ndarray::prelude::*; use ndarray::prelude::*;
use rand::Rng; use rand::Rng;
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeSet, HashMap};
use thiserror::Error; use thiserror::Error;
use enum_dispatch::enum_dispatch;
/// Error types for trait Params /// Error types for trait Params
#[derive(Error, Debug)] #[derive(Error, Debug, PartialEq)]
pub enum ParamsError { pub enum ParamsError {
#[error("Unsupported method")] #[error("Unsupported method")]
UnsupportedMethod(String), UnsupportedMethod(String),
#[error("Paramiters not initialized")] #[error("Paramiters not initialized")]
ParametersNotInitialized(String), ParametersNotInitialized(String),
#[error("Invalid cim for parameter")]
InvalidCIM(String),
} }
/// Allowed type of states /// Allowed type of states
@ -43,6 +45,9 @@ pub trait ParamsTrait {
/// Index used by discrete node to represents their states as usize. /// Index used by discrete node to represents their states as usize.
fn state_to_index(&self, state: &StateType) -> usize; fn state_to_index(&self, state: &StateType) -> usize;
/// Validate parameters against domain
fn validate_params(&self) -> Result<(), ParamsError>;
} }
/// The Params enum is the core element for building different types of nodes. The goal is to /// The Params enum is the core element for building different types of nodes. The goal is to
@ -52,7 +57,6 @@ pub enum Params {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
} }
/// DiscreteStatesContinousTime. /// DiscreteStatesContinousTime.
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the /// This represents the parameters of a classical discrete node for ctbn and it's composed by the
/// following elements: /// following elements:
@ -65,21 +69,65 @@ pub enum Params {
/// - **residence_time**: permanence time in each possible states given a specific /// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set /// realization of the parent set
pub struct DiscreteStatesContinousTimeParams { pub struct DiscreteStatesContinousTimeParams {
pub domain: BTreeSet<String>, domain: BTreeSet<String>,
pub cim: Option<Array3<f64>>, cim: Option<Array3<f64>>,
pub transitions: Option<Array3<u64>>, transitions: Option<Array3<u64>>,
pub residence_time: Option<Array2<f64>>, residence_time: Option<Array2<f64>>,
} }
impl DiscreteStatesContinousTimeParams { impl DiscreteStatesContinousTimeParams {
pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams { pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams {
domain: domain, domain,
cim: Option::None, cim: Option::None,
transitions: Option::None, transitions: Option::None,
residence_time: Option::None, residence_time: Option::None,
} }
} }
///Getter function for CIM
pub fn get_cim(&self) -> &Option<Array3<f64>> {
&self.cim
}
///Setter function for CIM.\\
///This function check if the cim is valid using the validate_params method.
///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(())
///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError>{
self.cim = Some(cim);
match self.validate_params() {
Ok(()) => Ok(()),
Err(e) => {
self.cim = None;
Err(e)
}
}
}
///Getter function for transitions
pub fn get_transitions(&self) -> &Option<Array3<u64>> {
&self.transitions
}
///Setter function for transitions
pub fn set_transitions(&mut self, transitions: Array3<u64>) {
self.transitions = Some(transitions);
}
///Getter function for residence_time
pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
&self.residence_time
}
///Setter function for residence_time
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
self.residence_time = Some(residence_time);
}
} }
impl ParamsTrait for DiscreteStatesContinousTimeParams { impl ParamsTrait for DiscreteStatesContinousTimeParams {
@ -157,5 +205,45 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
StateType::Discrete(val) => val.clone() as usize, StateType::Discrete(val) => val.clone() as usize,
} }
} }
fn validate_params(&self) -> Result<(), ParamsError> {
let domain_size = self.domain.len();
// Check if the cim is initialized
if let None = self.cim {
return Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
)));
}
let cim = self.cim.as_ref().unwrap();
// Check if the inner dimensions of the cim are equal to the cardinality of the variable
if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size {
return Err(ParamsError::InvalidCIM(format!(
"Incompatible shape {:?} with domain {:?}",
cim.shape(),
domain_size
)));
}
// Check if the diagonal of each cim is non-positive
if cim
.axis_iter(Axis(0))
.any(|x| x.diag().iter().any(|x| x >= &0.0))
{
return Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive",
)));
}
// Check if each row sum up to 0
if cim.sum_axis(Axis(2)).iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
{
return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0",
)));
} }
return Ok(());
}
}

@ -27,16 +27,16 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]], [[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]], [[-6.0, 6.0], [2.0, -2.0]],
])); ])));
} }
} }
@ -77,19 +77,19 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])); [0.4, 0.6, -1.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])); ])));
} }
} }
@ -132,19 +132,19 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])); [0.4, 0.6, -1.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])); ])));
} }
} }
@ -192,26 +192,26 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])); [0.4, 0.6, -1.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])); ])));
} }
} }
match &mut net.get_node_mut(n3).params { match &mut net.get_node_mut(n3).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]],
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]],
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]],
@ -223,12 +223,12 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
])); ])));
} }
} }
let data = trajectory_generator(&net, 300, 200.0); let data = trajectory_generator(&net, 300, 300.0);
let (CIM, M, T) = pl.fit(&net, &data, 2, None); let (CIM, M, T) = pl.fit(&net, &data, 2, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [9, 4, 4]); assert_eq!(CIM.shape(), [9, 4, 4]);

@ -1,5 +1,5 @@
use rustyCTBN::params::*;
use ndarray::prelude::*; use ndarray::prelude::*;
use rustyCTBN::params::*;
use std::collections::BTreeSet; use std::collections::BTreeSet;
mod utils; mod utils;
@ -7,13 +7,12 @@ mod utils;
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut params = utils::generate_discrete_time_continous_param(3); let mut params = utils::generate_discrete_time_continous_param(3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
params.cim = Some(cim); params.set_cim(cim);
params params
} }
@ -62,3 +61,68 @@ fn test_random_generation_residence_time() {
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
} }
#[test]
fn test_validate_params_valid_cim() {
let param = create_ternary_discrete_time_continous_param();
assert_eq!(Ok(()), param.validate_params());
}
#[test]
fn test_validate_params_valid_cim_with_huge_values() {
let mut param = utils::generate_discrete_time_continous_param(3);
let cim = array![[[-2e10, 1e10, 1e10], [1.5e10, -3e10, 1.5e10], [1e10, 1e10, -2e10]]];
let result = param.set_cim(cim);
assert_eq!(Ok(()), result);
}
#[test]
fn test_validate_params_cim_not_initialized() {
let param = utils::generate_discrete_time_continous_param(3);
assert_eq!(
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
))),
param.validate_params()
);
}
#[test]
fn test_validate_params_wrong_shape() {
let mut param = utils::generate_discrete_time_continous_param(4);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
let result = param.set_cim(cim);
assert_eq!(
Err(ParamsError::InvalidCIM(String::from(
"Incompatible shape [1, 3, 3] with domain 4"
))),
result
);
}
#[test]
fn test_validate_params_positive_diag() {
let mut param = utils::generate_discrete_time_continous_param(3);
let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
let result = param.set_cim(cim);
assert_eq!(
Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive",
))),
result
);
}
#[test]
fn test_validate_params_row_not_sum_to_zero() {
let mut param = utils::generate_discrete_time_continous_param(3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
let result = param.set_cim(cim);
assert_eq!(
Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0"
))),
result
);
}

@ -23,14 +23,14 @@ fn run_sampling() {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); param.set_cim(arr3(&[[[-3.0,3.0],[2.0,-2.0]]]));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[ param.set_cim(arr3(&[
[[-1.0,1.0],[4.0,-4.0]], [[-1.0,1.0],[4.0,-4.0]],
[[-6.0,6.0],[2.0,-2.0]]])); [[-6.0,6.0],[2.0,-2.0]]]));
} }

@ -8,7 +8,7 @@ pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
let mut domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect(); let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
params::DiscreteStatesContinousTimeParams::init(domain) params::DiscreteStatesContinousTimeParams::init(domain)
} }

Loading…
Cancel
Save