From a2fb25912407ca1895e070908ac3b865624c13f5 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 2 Mar 2022 08:57:25 +0100 Subject: [PATCH] Params update: From Box to Enum --- Cargo.toml | 1 + src/ctbn.rs | 14 ++-- src/node.rs | 4 +- src/params.rs | 211 ++++++++++++++++++++++++++++++++------------------ src/tools.rs | 2 +- 5 files changed, 145 insertions(+), 87 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c1e3a53..3b1bb3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ ndarray = "*" thiserror = "*" rand = "*" bimap = "*" +enum_dispatch = "*" [dev-dependencies] approx = "*" diff --git a/src/ctbn.rs b/src/ctbn.rs index 8997b29..b57180d 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, BTreeSet}; use ndarray::prelude::*; use crate::node; -use crate::params::StateType; +use crate::params::{StateType, Params, ParamsTrait}; use crate::network; @@ -29,16 +29,16 @@ use crate::network; /// domain.insert(String::from("B")); /// /// //Create the parameters for a discrete node using the domain -/// let params = params::DiscreteStatesContinousTimeParams::init(domain); +/// let param = params::DiscreteStatesContinousTimeParams::init(domain); /// /// //Create the node using the parameters -/// let X1 = node::Node::init(Box::from(params),String::from("X1")); +/// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1")); /// /// let mut domain = BTreeSet::new(); /// domain.insert(String::from("A")); /// domain.insert(String::from("B")); -/// let params = params::DiscreteStatesContinousTimeParams::init(domain); -/// let X2 = node::Node::init(Box::from(params), String::from("X2")); +/// let param = params::DiscreteStatesContinousTimeParams::init(domain); +/// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2")); /// /// //Initialize a ctbn /// let mut net = CtbnNetwork::init(); @@ -156,8 +156,8 @@ mod tests { let mut domain = BTreeSet::new(); domain.insert(String::from("A")); domain.insert(String::from("B")); - let params = params::DiscreteStatesContinousTimeParams::init(domain); - let n = node::Node::init(Box::from(params), name); + let param = params::DiscreteStatesContinousTimeParams::init(domain) ; + let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name); return n; } diff --git a/src/node.rs b/src/node.rs index 61987f1..7ed21ba 100644 --- a/src/node.rs +++ b/src/node.rs @@ -2,12 +2,12 @@ use crate::params::*; pub struct Node { - pub params: Box, + pub params: Params, pub label: String } impl Node { - pub fn init(params: Box, label: String) -> Node { + pub fn init(params: Params, label: String) -> Node { Node{ params: params, label:label diff --git a/src/params.rs b/src/params.rs index 8a32e43..d77ef2a 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,28 +1,28 @@ use ndarray::prelude::*; -use std::collections::{HashMap, BTreeSet}; use rand::Rng; +use std::collections::{BTreeSet, HashMap}; use thiserror::Error; +use enum_dispatch::enum_dispatch; - -/// Error types for trait Params +/// Error types for trait Params #[derive(Error, Debug)] pub enum ParamsError { #[error("Unsupported method")] UnsupportedMethod(String), #[error("Paramiters not initialized")] - ParametersNotInitialized(String) + ParametersNotInitialized(String), } /// Allowed type of states #[derive(Clone)] pub enum StateType { - Discrete(u32) + Discrete(u32), } /// Parameters /// The Params trait is the core element for building different types of nodes. The goal is to /// define the set of method required to describes a generic node. -pub trait Params { +pub trait ParamsTrait { fn reset_params(&mut self); /// Randomly generate a possible state of the node disregarding the state of the node and it's @@ -33,10 +33,9 @@ pub trait Params { /// and its parent set. fn get_random_residence_time(&self, state: usize, u: usize) -> Result; - /// Randomly generate a possible state for the given node taking into account the node state /// and its parent set. - fn get_random_state(&self, state: usize, u:usize) -> Result; + fn get_random_state(&self, state: usize, u: usize) -> Result; /// Used by childern of the node described by this parameters to reserve spaces in their CIMs. fn get_reserved_space_as_parent(&self) -> usize; @@ -45,13 +44,69 @@ pub trait Params { fn state_to_index(&self, state: &StateType) -> usize; } +/// The Params enum is the core element for building different types of nodes. The goal is to +/// define all the supported type of parameters. +pub enum Params { + DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), +} + +impl ParamsTrait for Params { + fn reset_params(&mut self) { + match self { + Params::DiscreteStatesContinousTime(p) => p.reset_params() + } + } -/// Parameters for a discrete node in continous time. It contains. This represents the parameters -/// of a classical discrete node for ctbn and it's composed by the following elements: + fn get_random_state_uniform(&self) -> StateType{ + match self { + Params::DiscreteStatesContinousTime(p) => p.get_random_state_uniform() + } + } + + + /// Randomly generate a residence time for the given node taking into account the node state + /// and its parent set. + fn get_random_residence_time(&self, state: usize, u: usize) -> Result { + match self { + Params::DiscreteStatesContinousTime(p) => p.get_random_residence_time(state, u) + } + } + + + /// Randomly generate a possible state for the given node taking into account the node state + /// and its parent set. + fn get_random_state(&self, state: usize, u: usize) -> Result { + match self { + Params::DiscreteStatesContinousTime(p) => p.get_random_state(state, u) + } + } + + + /// Used by childern of the node described by this parameters to reserve spaces in their CIMs. + fn get_reserved_space_as_parent(&self) -> usize { + match self { + Params::DiscreteStatesContinousTime(p) => p.get_reserved_space_as_parent() + } + } + + + /// Index used by discrete node to represents their states as usize. + fn state_to_index(&self, state: &StateType) -> usize { + match self { + Params::DiscreteStatesContinousTime(p) => p.state_to_index(state) + } + } + + +} + +/// DiscreteStatesContinousTime. +/// This represents the parameters of a classical discrete node for ctbn and it's composed by the +/// following elements: /// - **domain**: an ordered and exhaustive set of possible states /// - **cim**: Conditional Intensity Matrix /// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter -/// l earning task and are composed by: +/// learning task and are composed by: /// - **transitions**: number of transitions from one state to another given a specific /// realization of the parent set /// - **residence_time**: permanence time in each possible states given a specific @@ -60,21 +115,21 @@ pub struct DiscreteStatesContinousTimeParams { domain: BTreeSet, cim: Option>, transitions: Option>, - residence_time: Option> + residence_time: Option>, } -impl DiscreteStatesContinousTimeParams { +impl DiscreteStatesContinousTimeParams { pub fn init(domain: BTreeSet) -> DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams { domain: domain, cim: Option::None, transitions: Option::None, - residence_time: Option::None + residence_time: Option::None, } } } -impl Params for DiscreteStatesContinousTimeParams { +impl ParamsTrait for DiscreteStatesContinousTimeParams { fn reset_params(&mut self) { self.cim = Option::None; self.transitions = Option::None; @@ -86,60 +141,65 @@ impl Params for DiscreteStatesContinousTimeParams { StateType::Discrete(rng.gen_range(0..(self.domain.len() as u32))) } - fn get_random_residence_time(&self, state: usize, u:usize) -> Result { + fn get_random_residence_time(&self, state: usize, u: usize) -> Result { // Generate a random residence time given the current state of the node and its parent set. // The method used is described in: // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates - match &self.cim { - Option::Some(cim) => { - let mut rng = rand::thread_rng(); - let lambda = cim[[u, state, state]] * -1.0; - let x:f64 = rng.gen_range(0.0..1.0); - Ok(-x.ln()/lambda) - }, - Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized"))) - } + match &self.cim { + Option::Some(cim) => { + let mut rng = rand::thread_rng(); + let lambda = cim[[u, state, state]] * -1.0; + let x: f64 = rng.gen_range(0.0..1.0); + Ok(-x.ln() / lambda) + } + Option::None => Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))), + } } - - fn get_random_state(&self, state: usize, u:usize) -> Result{ + fn get_random_state(&self, state: usize, u: usize) -> Result { // Generate a random transition given the current state of the node and its parent set. // The method used is described in: // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution - match &self.cim { - Option::Some(cim) => { - let mut rng = rand::thread_rng(); - let lambda = cim[[u, state, state]] * -1.0; - let x: f64 = rng.gen_range(0.0..1.0); - - let next_state = cim.slice(s![u,state,..]).map(|x| x / lambda).iter().fold((0, 0.0), |mut acc, ele| { - if &acc.1 + ele < x && ele > &0.0{ - acc.1 += x; - acc.0 += 1; - } - acc}); - - let next_state = if next_state.0 < state { - next_state.0 - } else { - next_state.0 + 1 - }; - - Ok(StateType::Discrete(next_state as u32)) - - }, - Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized"))) - } + match &self.cim { + Option::Some(cim) => { + let mut rng = rand::thread_rng(); + let lambda = cim[[u, state, state]] * -1.0; + let x: f64 = rng.gen_range(0.0..1.0); + + let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold( + (0, 0.0), + |mut acc, ele| { + if &acc.1 + ele < x && ele > &0.0 { + acc.1 += x; + acc.0 += 1; + } + acc + }, + ); + + let next_state = if next_state.0 < state { + next_state.0 + } else { + next_state.0 + 1 + }; + + Ok(StateType::Discrete(next_state as u32)) + } + Option::None => Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))), + } } - fn get_reserved_space_as_parent(&self) -> usize { self.domain.len() } fn state_to_index(&self, state: &StateType) -> usize { match state { - StateType::Discrete(val) => val.clone() as usize + StateType::Discrete(val) => val.clone() as usize, } } } @@ -147,8 +207,7 @@ impl Params for DiscreteStatesContinousTimeParams { #[cfg(test)] mod tests { use super::*; - use ndarray::prelude::*; - + //use ndarray::prelude::*; fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { let mut domain = BTreeSet::new(); @@ -157,56 +216,54 @@ mod tests { domain.insert(String::from("C")); let mut params = DiscreteStatesContinousTimeParams::init(domain); - let cim = array![ - [ - [-3.0, 2.0, 1.0], - [1.0, -5.0, 4.0], - [3.2, 1.7, -4.0] - ]]; + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; params.cim = Some(cim); params } - #[test] fn test_uniform_generation() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); - states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state_uniform() { - val - } else {panic!()}); + states.mapv_inplace(|_| { + if let StateType::Discrete(val) = param.get_random_state_uniform() { + val + } else { + panic!() + } + }); let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; - assert_relative_eq!(1.0/3.0, zero_freq, epsilon=0.01); + assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01); } - #[test] fn test_random_generation_state() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); - states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { - val - } else {panic!()}); + states.mapv_inplace(|_| { + if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { + val + } else { + panic!() + } + }); let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; - assert_relative_eq!(4.0/5.0, two_freq, epsilon=0.01); - assert_relative_eq!(1.0/5.0, zero_freq, epsilon=0.01); + assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01); + assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01); } - #[test] fn test_random_generation_residence_time() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); - states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap() ); - - assert_relative_eq!(1.0/5.0, states.mean().unwrap(), epsilon=0.01); + states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap()); + assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); } - } diff --git a/src/tools.rs b/src/tools.rs index 921d558..349f88e 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -2,7 +2,7 @@ use ndarray::prelude::*; use crate::network; use crate::node; use crate::params; -use crate::params::Params; +use crate::params::ParamsTrait; pub struct Trajectory { time: Array1,