Params update: From Box to Enum

pull/19/head
AlessandroBregoli 3 years ago
parent 212f4aef4b
commit a2fb259124
  1. 1
      Cargo.toml
  2. 14
      src/ctbn.rs
  3. 4
      src/node.rs
  4. 157
      src/params.rs
  5. 2
      src/tools.rs

@ -11,6 +11,7 @@ ndarray = "*"
thiserror = "*" thiserror = "*"
rand = "*" rand = "*"
bimap = "*" bimap = "*"
enum_dispatch = "*"
[dev-dependencies] [dev-dependencies]
approx = "*" approx = "*"

@ -1,7 +1,7 @@
use std::collections::{HashMap, BTreeSet}; use std::collections::{HashMap, BTreeSet};
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::node; use crate::node;
use crate::params::StateType; use crate::params::{StateType, Params, ParamsTrait};
use crate::network; use crate::network;
@ -29,16 +29,16 @@ use crate::network;
/// domain.insert(String::from("B")); /// domain.insert(String::from("B"));
/// ///
/// //Create the parameters for a discrete node using the domain /// //Create the parameters for a discrete node using the domain
/// let params = params::DiscreteStatesContinousTimeParams::init(domain); /// let param = params::DiscreteStatesContinousTimeParams::init(domain);
/// ///
/// //Create the node using the parameters /// //Create the node using the parameters
/// let X1 = node::Node::init(Box::from(params),String::from("X1")); /// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1"));
/// ///
/// let mut domain = BTreeSet::new(); /// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A")); /// domain.insert(String::from("A"));
/// domain.insert(String::from("B")); /// domain.insert(String::from("B"));
/// let params = params::DiscreteStatesContinousTimeParams::init(domain); /// let param = params::DiscreteStatesContinousTimeParams::init(domain);
/// let X2 = node::Node::init(Box::from(params), String::from("X2")); /// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2"));
/// ///
/// //Initialize a ctbn /// //Initialize a ctbn
/// let mut net = CtbnNetwork::init(); /// let mut net = CtbnNetwork::init();
@ -156,8 +156,8 @@ mod tests {
let mut domain = BTreeSet::new(); let mut domain = BTreeSet::new();
domain.insert(String::from("A")); domain.insert(String::from("A"));
domain.insert(String::from("B")); domain.insert(String::from("B"));
let params = params::DiscreteStatesContinousTimeParams::init(domain); let param = params::DiscreteStatesContinousTimeParams::init(domain) ;
let n = node::Node::init(Box::from(params), name); let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n; return n;
} }

@ -2,12 +2,12 @@ use crate::params::*;
pub struct Node { pub struct Node {
pub params: Box<dyn Params>, pub params: Params,
pub label: String pub label: String
} }
impl Node { impl Node {
pub fn init(params: Box<dyn Params>, label: String) -> Node { pub fn init(params: Params, label: String) -> Node {
Node{ Node{
params: params, params: params,
label:label label:label

@ -1,8 +1,8 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use std::collections::{HashMap, BTreeSet};
use rand::Rng; use rand::Rng;
use std::collections::{BTreeSet, HashMap};
use thiserror::Error; use thiserror::Error;
use enum_dispatch::enum_dispatch;
/// Error types for trait Params /// Error types for trait Params
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -10,19 +10,19 @@ pub enum ParamsError {
#[error("Unsupported method")] #[error("Unsupported method")]
UnsupportedMethod(String), UnsupportedMethod(String),
#[error("Paramiters not initialized")] #[error("Paramiters not initialized")]
ParametersNotInitialized(String) ParametersNotInitialized(String),
} }
/// Allowed type of states /// Allowed type of states
#[derive(Clone)] #[derive(Clone)]
pub enum StateType { pub enum StateType {
Discrete(u32) Discrete(u32),
} }
/// Parameters /// Parameters
/// The Params trait is the core element for building different types of nodes. The goal is to /// The Params trait is the core element for building different types of nodes. The goal is to
/// define the set of method required to describes a generic node. /// define the set of method required to describes a generic node.
pub trait Params { pub trait ParamsTrait {
fn reset_params(&mut self); fn reset_params(&mut self);
/// Randomly generate a possible state of the node disregarding the state of the node and it's /// Randomly generate a possible state of the node disregarding the state of the node and it's
@ -33,10 +33,9 @@ pub trait Params {
/// and its parent set. /// and its parent set.
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>; fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>;
/// Randomly generate a possible state for the given node taking into account the node state /// Randomly generate a possible state for the given node taking into account the node state
/// and its parent set. /// and its parent set.
fn get_random_state(&self, state: usize, u:usize) -> Result<StateType, ParamsError>; fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError>;
/// Used by childern of the node described by this parameters to reserve spaces in their CIMs. /// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
fn get_reserved_space_as_parent(&self) -> usize; fn get_reserved_space_as_parent(&self) -> usize;
@ -45,13 +44,69 @@ pub trait Params {
fn state_to_index(&self, state: &StateType) -> usize; fn state_to_index(&self, state: &StateType) -> usize;
} }
/// The Params enum is the core element for building different types of nodes. The goal is to
/// define all the supported type of parameters.
pub enum Params {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
impl ParamsTrait for Params {
fn reset_params(&mut self) {
match self {
Params::DiscreteStatesContinousTime(p) => p.reset_params()
}
}
fn get_random_state_uniform(&self) -> StateType{
match self {
Params::DiscreteStatesContinousTime(p) => p.get_random_state_uniform()
}
}
/// Randomly generate a residence time for the given node taking into account the node state
/// and its parent set.
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> {
match self {
Params::DiscreteStatesContinousTime(p) => p.get_random_residence_time(state, u)
}
}
/// Randomly generate a possible state for the given node taking into account the node state
/// and its parent set.
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> {
match self {
Params::DiscreteStatesContinousTime(p) => p.get_random_state(state, u)
}
}
/// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
fn get_reserved_space_as_parent(&self) -> usize {
match self {
Params::DiscreteStatesContinousTime(p) => p.get_reserved_space_as_parent()
}
}
/// Index used by discrete node to represents their states as usize.
fn state_to_index(&self, state: &StateType) -> usize {
match self {
Params::DiscreteStatesContinousTime(p) => p.state_to_index(state)
}
}
/// Parameters for a discrete node in continous time. It contains. This represents the parameters }
/// of a classical discrete node for ctbn and it's composed by the following elements:
/// DiscreteStatesContinousTime.
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
/// following elements:
/// - **domain**: an ordered and exhaustive set of possible states /// - **domain**: an ordered and exhaustive set of possible states
/// - **cim**: Conditional Intensity Matrix /// - **cim**: Conditional Intensity Matrix
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter /// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
/// l earning task and are composed by: /// learning task and are composed by:
/// - **transitions**: number of transitions from one state to another given a specific /// - **transitions**: number of transitions from one state to another given a specific
/// realization of the parent set /// realization of the parent set
/// - **residence_time**: permanence time in each possible states given a specific /// - **residence_time**: permanence time in each possible states given a specific
@ -60,7 +115,7 @@ pub struct DiscreteStatesContinousTimeParams {
domain: BTreeSet<String>, domain: BTreeSet<String>,
cim: Option<Array3<f64>>, cim: Option<Array3<f64>>,
transitions: Option<Array3<u64>>, transitions: Option<Array3<u64>>,
residence_time: Option<Array2<f64>> residence_time: Option<Array2<f64>>,
} }
impl DiscreteStatesContinousTimeParams { impl DiscreteStatesContinousTimeParams {
@ -69,12 +124,12 @@ impl DiscreteStatesContinousTimeParams {
domain: domain, domain: domain,
cim: Option::None, cim: Option::None,
transitions: Option::None, transitions: Option::None,
residence_time: Option::None residence_time: Option::None,
} }
} }
} }
impl Params for DiscreteStatesContinousTimeParams {
impl ParamsTrait for DiscreteStatesContinousTimeParams {
fn reset_params(&mut self) { fn reset_params(&mut self) {
self.cim = Option::None; self.cim = Option::None;
self.transitions = Option::None; self.transitions = Option::None;
@ -86,7 +141,7 @@ impl Params for DiscreteStatesContinousTimeParams {
StateType::Discrete(rng.gen_range(0..(self.domain.len() as u32))) StateType::Discrete(rng.gen_range(0..(self.domain.len() as u32)))
} }
fn get_random_residence_time(&self, state: usize, u:usize) -> Result<f64, ParamsError> { fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> {
// Generate a random residence time given the current state of the node and its parent set. // Generate a random residence time given the current state of the node and its parent set.
// The method used is described in: // The method used is described in:
// https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
@ -94,15 +149,16 @@ impl Params for DiscreteStatesContinousTimeParams {
Option::Some(cim) => { Option::Some(cim) => {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let lambda = cim[[u, state, state]] * -1.0; let lambda = cim[[u, state, state]] * -1.0;
let x:f64 = rng.gen_range(0.0..1.0); let x: f64 = rng.gen_range(0.0..1.0);
Ok(-x.ln()/lambda) Ok(-x.ln() / lambda)
}, }
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized"))) Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
))),
} }
} }
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> {
fn get_random_state(&self, state: usize, u:usize) -> Result<StateType, ParamsError>{
// Generate a random transition given the current state of the node and its parent set. // Generate a random transition given the current state of the node and its parent set.
// The method used is described in: // The method used is described in:
// https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
@ -112,12 +168,16 @@ impl Params for DiscreteStatesContinousTimeParams {
let lambda = cim[[u, state, state]] * -1.0; let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..1.0); let x: f64 = rng.gen_range(0.0..1.0);
let next_state = cim.slice(s![u,state,..]).map(|x| x / lambda).iter().fold((0, 0.0), |mut acc, ele| { let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold(
if &acc.1 + ele < x && ele > &0.0{ (0, 0.0),
|mut acc, ele| {
if &acc.1 + ele < x && ele > &0.0 {
acc.1 += x; acc.1 += x;
acc.0 += 1; acc.0 += 1;
} }
acc}); acc
},
);
let next_state = if next_state.0 < state { let next_state = if next_state.0 < state {
next_state.0 next_state.0
@ -126,12 +186,12 @@ impl Params for DiscreteStatesContinousTimeParams {
}; };
Ok(StateType::Discrete(next_state as u32)) Ok(StateType::Discrete(next_state as u32))
}
}, Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized"))) "CIM not initialized",
))),
} }
} }
fn get_reserved_space_as_parent(&self) -> usize { fn get_reserved_space_as_parent(&self) -> usize {
self.domain.len() self.domain.len()
@ -139,7 +199,7 @@ impl Params for DiscreteStatesContinousTimeParams {
fn state_to_index(&self, state: &StateType) -> usize { fn state_to_index(&self, state: &StateType) -> usize {
match state { match state {
StateType::Discrete(val) => val.clone() as usize StateType::Discrete(val) => val.clone() as usize,
} }
} }
} }
@ -147,8 +207,7 @@ impl Params for DiscreteStatesContinousTimeParams {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use ndarray::prelude::*; //use ndarray::prelude::*;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut domain = BTreeSet::new(); let mut domain = BTreeSet::new();
@ -157,56 +216,54 @@ mod tests {
domain.insert(String::from("C")); domain.insert(String::from("C"));
let mut params = DiscreteStatesContinousTimeParams::init(domain); let mut params = DiscreteStatesContinousTimeParams::init(domain);
let cim = array![ let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]];
[
[-3.0, 2.0, 1.0],
[1.0, -5.0, 4.0],
[3.2, 1.7, -4.0]
]];
params.cim = Some(cim); params.cim = Some(cim);
params params
} }
#[test] #[test]
fn test_uniform_generation() { fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param(); let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<u32>::zeros(10000); let mut states = Array1::<u32>::zeros(10000);
states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state_uniform() { states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state_uniform() {
val val
} else {panic!()}); } else {
panic!()
}
});
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(1.0/3.0, zero_freq, epsilon=0.01); assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01);
} }
#[test] #[test]
fn test_random_generation_state() { fn test_random_generation_state() {
let param = create_ternary_discrete_time_continous_param(); let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<u32>::zeros(10000); let mut states = Array1::<u32>::zeros(10000);
states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
val val
} else {panic!()}); } else {
panic!()
}
});
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0;
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(4.0/5.0, two_freq, epsilon=0.01); assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01);
assert_relative_eq!(1.0/5.0, zero_freq, epsilon=0.01); assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01);
} }
#[test] #[test]
fn test_random_generation_residence_time() { fn test_random_generation_residence_time() {
let param = create_ternary_discrete_time_continous_param(); let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<f64>::zeros(10000); let mut states = Array1::<f64>::zeros(10000);
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap() ); states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap());
assert_relative_eq!(1.0/5.0, states.mean().unwrap(), epsilon=0.01);
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
} }
} }

@ -2,7 +2,7 @@ use ndarray::prelude::*;
use crate::network; use crate::network;
use crate::node; use crate::node;
use crate::params; use crate::params;
use crate::params::Params; use crate::params::ParamsTrait;
pub struct Trajectory { pub struct Trajectory {
time: Array1<f64>, time: Array1<f64>,

Loading…
Cancel
Save