Params update: From Box to Enum

pull/19/head
AlessandroBregoli 3 years ago
parent 212f4aef4b
commit a2fb259124
  1. 1
      Cargo.toml
  2. 14
      src/ctbn.rs
  3. 4
      src/node.rs
  4. 133
      src/params.rs
  5. 2
      src/tools.rs

@ -11,6 +11,7 @@ ndarray = "*"
thiserror = "*"
rand = "*"
bimap = "*"
enum_dispatch = "*"
[dev-dependencies]
approx = "*"

@ -1,7 +1,7 @@
use std::collections::{HashMap, BTreeSet};
use ndarray::prelude::*;
use crate::node;
use crate::params::StateType;
use crate::params::{StateType, Params, ParamsTrait};
use crate::network;
@ -29,16 +29,16 @@ use crate::network;
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let params = params::DiscreteStatesContinousTimeParams::init(domain);
/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
///
/// //Create the node using the parameters
/// let X1 = node::Node::init(Box::from(params),String::from("X1"));
/// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1"));
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let params = params::DiscreteStatesContinousTimeParams::init(domain);
/// let X2 = node::Node::init(Box::from(params), String::from("X2"));
/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
/// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2"));
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::init();
@ -156,8 +156,8 @@ mod tests {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
let params = params::DiscreteStatesContinousTimeParams::init(domain);
let n = node::Node::init(Box::from(params), name);
let param = params::DiscreteStatesContinousTimeParams::init(domain) ;
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n;
}

@ -2,12 +2,12 @@ use crate::params::*;
pub struct Node {
pub params: Box<dyn Params>,
pub params: Params,
pub label: String
}
impl Node {
pub fn init(params: Box<dyn Params>, label: String) -> Node {
pub fn init(params: Params, label: String) -> Node {
Node{
params: params,
label:label

@ -1,8 +1,8 @@
use ndarray::prelude::*;
use std::collections::{HashMap, BTreeSet};
use rand::Rng;
use std::collections::{BTreeSet, HashMap};
use thiserror::Error;
use enum_dispatch::enum_dispatch;
/// Error types for trait Params
#[derive(Error, Debug)]
@ -10,19 +10,19 @@ pub enum ParamsError {
#[error("Unsupported method")]
UnsupportedMethod(String),
#[error("Paramiters not initialized")]
ParametersNotInitialized(String)
ParametersNotInitialized(String),
}
/// Allowed type of states
#[derive(Clone)]
pub enum StateType {
Discrete(u32)
Discrete(u32),
}
/// Parameters
/// The Params trait is the core element for building different types of nodes. The goal is to
/// define the set of method required to describes a generic node.
pub trait Params {
pub trait ParamsTrait {
fn reset_params(&mut self);
/// Randomly generate a possible state of the node disregarding the state of the node and it's
@ -33,7 +33,6 @@ pub trait Params {
/// and its parent set.
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>;
/// Randomly generate a possible state for the given node taking into account the node state
/// and its parent set.
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError>;
@ -45,9 +44,65 @@ pub trait Params {
fn state_to_index(&self, state: &StateType) -> usize;
}
/// The Params enum is the core element for building different types of nodes. The goal is to
/// define all the supported type of parameters.
pub enum Params {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
impl ParamsTrait for Params {
fn reset_params(&mut self) {
match self {
Params::DiscreteStatesContinousTime(p) => p.reset_params()
}
}
fn get_random_state_uniform(&self) -> StateType{
match self {
Params::DiscreteStatesContinousTime(p) => p.get_random_state_uniform()
}
}
/// Randomly generate a residence time for the given node taking into account the node state
/// and its parent set.
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> {
match self {
Params::DiscreteStatesContinousTime(p) => p.get_random_residence_time(state, u)
}
}
/// Parameters for a discrete node in continous time. It contains. This represents the parameters
/// of a classical discrete node for ctbn and it's composed by the following elements:
/// Randomly generate a possible state for the given node taking into account the node state
/// and its parent set.
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> {
match self {
Params::DiscreteStatesContinousTime(p) => p.get_random_state(state, u)
}
}
/// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
fn get_reserved_space_as_parent(&self) -> usize {
match self {
Params::DiscreteStatesContinousTime(p) => p.get_reserved_space_as_parent()
}
}
/// Index used by discrete node to represents their states as usize.
fn state_to_index(&self, state: &StateType) -> usize {
match self {
Params::DiscreteStatesContinousTime(p) => p.state_to_index(state)
}
}
}
/// DiscreteStatesContinousTime.
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
/// following elements:
/// - **domain**: an ordered and exhaustive set of possible states
/// - **cim**: Conditional Intensity Matrix
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
@ -60,7 +115,7 @@ pub struct DiscreteStatesContinousTimeParams {
domain: BTreeSet<String>,
cim: Option<Array3<f64>>,
transitions: Option<Array3<u64>>,
residence_time: Option<Array2<f64>>
residence_time: Option<Array2<f64>>,
}
impl DiscreteStatesContinousTimeParams {
@ -69,12 +124,12 @@ impl DiscreteStatesContinousTimeParams {
domain: domain,
cim: Option::None,
transitions: Option::None,
residence_time: Option::None
residence_time: Option::None,
}
}
}
impl Params for DiscreteStatesContinousTimeParams {
impl ParamsTrait for DiscreteStatesContinousTimeParams {
fn reset_params(&mut self) {
self.cim = Option::None;
self.transitions = Option::None;
@ -96,11 +151,12 @@ impl Params for DiscreteStatesContinousTimeParams {
let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..1.0);
Ok(-x.ln() / lambda)
},
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized")))
}
Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
))),
}
}
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> {
// Generate a random transition given the current state of the node and its parent set.
@ -112,12 +168,16 @@ impl Params for DiscreteStatesContinousTimeParams {
let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..1.0);
let next_state = cim.slice(s![u,state,..]).map(|x| x / lambda).iter().fold((0, 0.0), |mut acc, ele| {
let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold(
(0, 0.0),
|mut acc, ele| {
if &acc.1 + ele < x && ele > &0.0 {
acc.1 += x;
acc.0 += 1;
}
acc});
acc
},
);
let next_state = if next_state.0 < state {
next_state.0
@ -126,12 +186,12 @@ impl Params for DiscreteStatesContinousTimeParams {
};
Ok(StateType::Discrete(next_state as u32))
},
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized")))
}
Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
))),
}
}
fn get_reserved_space_as_parent(&self) -> usize {
self.domain.len()
@ -139,7 +199,7 @@ impl Params for DiscreteStatesContinousTimeParams {
fn state_to_index(&self, state: &StateType) -> usize {
match state {
StateType::Discrete(val) => val.clone() as usize
StateType::Discrete(val) => val.clone() as usize,
}
}
}
@ -147,8 +207,7 @@ impl Params for DiscreteStatesContinousTimeParams {
#[cfg(test)]
mod tests {
use super::*;
use ndarray::prelude::*;
//use ndarray::prelude::*;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut domain = BTreeSet::new();
@ -157,39 +216,40 @@ mod tests {
domain.insert(String::from("C"));
let mut params = DiscreteStatesContinousTimeParams::init(domain);
let cim = array![
[
[-3.0, 2.0, 1.0],
[1.0, -5.0, 4.0],
[3.2, 1.7, -4.0]
]];
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]];
params.cim = Some(cim);
params
}
#[test]
fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<u32>::zeros(10000);
states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state_uniform() {
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state_uniform() {
val
} else {panic!()});
} else {
panic!()
}
});
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01);
}
#[test]
fn test_random_generation_state() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<u32>::zeros(10000);
states.mapv_inplace(|_| if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
val
} else {panic!()});
} else {
panic!()
}
});
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0;
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
@ -197,7 +257,6 @@ mod tests {
assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01);
}
#[test]
fn test_random_generation_residence_time() {
let param = create_ternary_discrete_time_continous_param();
@ -206,7 +265,5 @@ mod tests {
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap());
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
}
}

@ -2,7 +2,7 @@ use ndarray::prelude::*;
use crate::network;
use crate::node;
use crate::params;
use crate::params::Params;
use crate::params::ParamsTrait;
pub struct Trajectory {
time: Array1<f64>,

Loading…
Cancel
Save