Added log to params

72-feature-add-logging-and-documentation
AlessandroBregoli 2 years ago
parent e638a627bb
commit bfec2c7c60
  1. 1
      reCTBN/Cargo.toml
  2. 67
      reCTBN/src/params.rs

@ -15,6 +15,7 @@ statrs = "~0.16"
rand_chacha = "~0.3" rand_chacha = "~0.3"
itertools = "~0.10" itertools = "~0.10"
rayon = "~1.6" rayon = "~1.6"
log = "~0.4"
[dev-dependencies] [dev-dependencies]
approx = { package = "approx", version = "~0.5" } approx = { package = "approx", version = "~0.5" }

@ -3,6 +3,7 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch; use enum_dispatch::enum_dispatch;
use log::{debug, error, info, trace, warn};
use ndarray::prelude::*; use ndarray::prelude::*;
use rand::Rng; use rand::Rng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
@ -29,6 +30,7 @@ pub enum StateType {
/// methods required to describes a generic node. /// methods required to describes a generic node.
#[enum_dispatch(Params)] #[enum_dispatch(Params)]
pub trait ParamsTrait { pub trait ParamsTrait {
///Reset the parameters
fn reset_params(&mut self); fn reset_params(&mut self);
/// Randomly generate a possible state of the node disregarding the state of the node and it's /// Randomly generate a possible state of the node disregarding the state of the node and it's
@ -98,6 +100,7 @@ pub struct DiscreteStatesContinousTimeParams {
impl DiscreteStatesContinousTimeParams { impl DiscreteStatesContinousTimeParams {
pub fn new(label: String, domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams { pub fn new(label: String, domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
debug!("Creation of node {}", label);
DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams {
label, label,
domain, domain,
@ -109,6 +112,7 @@ impl DiscreteStatesContinousTimeParams {
/// Getter function for CIM /// Getter function for CIM
pub fn get_cim(&self) -> &Option<Array3<f64>> { pub fn get_cim(&self) -> &Option<Array3<f64>> {
debug!("Getting cim from node {}", self.label);
&self.cim &self.cim
} }
@ -119,10 +123,12 @@ impl DiscreteStatesContinousTimeParams {
/// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns /// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns
/// `ParamsError`. /// `ParamsError`.
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> { pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
debug!("Setting cim for node {}", self.label);
self.cim = Some(cim); self.cim = Some(cim);
match self.validate_params() { match self.validate_params() {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(e) => { Err(e) => {
warn!("Validation cim faild for node {}", self.label);
self.cim = None; self.cim = None;
Err(e) Err(e)
} }
@ -131,39 +137,54 @@ impl DiscreteStatesContinousTimeParams {
/// Unchecked version of the setter function for CIM. /// Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) { pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
debug!("Setting cim (unchecked) for node {}", self.label);
self.cim = Some(cim); self.cim = Some(cim);
} }
/// Getter function for transitions. /// Getter function for transitions.
pub fn get_transitions(&self) -> &Option<Array3<usize>> { pub fn get_transitions(&self) -> &Option<Array3<usize>> {
debug!("Get transitions from node {}", self.label);
&self.transitions &self.transitions
} }
/// Setter function for transitions. /// Setter function for transitions.
pub fn set_transitions(&mut self, transitions: Array3<usize>) { pub fn set_transitions(&mut self, transitions: Array3<usize>) {
debug!("Set transitions for node {}", self.label);
self.transitions = Some(transitions); self.transitions = Some(transitions);
} }
/// Getter function for residence_time. /// Getter function for residence_time.
pub fn get_residence_time(&self) -> &Option<Array2<f64>> { pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
debug!("Get residence time from node {}", self.label);
&self.residence_time &self.residence_time
} }
/// Setter function for residence_time. /// Setter function for residence_time.
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) { pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
debug!("Set residence time for node {}", self.label);
self.residence_time = Some(residence_time); self.residence_time = Some(residence_time);
} }
} }
impl ParamsTrait for DiscreteStatesContinousTimeParams { impl ParamsTrait for DiscreteStatesContinousTimeParams {
fn reset_params(&mut self) { fn reset_params(&mut self) {
debug!(
"Setting cim, transitions and residence_time to None for node {}",
self.label
);
self.cim = Option::None; self.cim = Option::None;
self.transitions = Option::None; self.transitions = Option::None;
self.residence_time = Option::None; self.residence_time = Option::None;
} }
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType { fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType {
StateType::Discrete(rng.gen_range(0..(self.domain.len()))) let state = StateType::Discrete(rng.gen_range(0..(self.domain.len())));
trace!(
"Generate random state uniform. Node: {} - State: {:?}",
self.get_label(),
&state
);
return state;
} }
fn get_random_residence_time( fn get_random_residence_time(
@ -179,11 +200,20 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
Option::Some(cim) => { Option::Some(cim) => {
let lambda = cim[[u, state, state]] * -1.0; let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..=1.0); let x: f64 = rng.gen_range(0.0..=1.0);
Ok(-x.ln() / lambda) let ret = -x.ln() / lambda;
trace!(
"Generate random residence time. Node: {} - Time: {}",
self.get_label(),
ret
);
Ok(ret)
}
Option::None => {
warn!("Cim not initialized for node {}", self.get_label());
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
)))
} }
Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
))),
} }
} }
@ -220,11 +250,21 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
next_state.0 + 1 next_state.0 + 1
}; };
Ok(StateType::Discrete(next_state)) let next_state = StateType::Discrete(next_state);
trace!(
"Generate random state. Node: {} - State: {:?}",
self.get_label(),
next_state
);
Ok(next_state)
} }
Option::None => Err(ParamsError::ParametersNotInitialized(String::from( Option::None => {
warn!("Cim not initialized for node {}", self.get_label());
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized", "CIM not initialized",
))), )))
}
} }
} }
@ -243,6 +283,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
// Check if the cim is initialized // Check if the cim is initialized
if let None = self.cim { if let None = self.cim {
warn!("Cim not initialized for node {}", self.get_label());
return Err(ParamsError::ParametersNotInitialized(String::from( return Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized", "CIM not initialized",
))); )));
@ -250,18 +291,21 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
let cim = self.cim.as_ref().unwrap(); let cim = self.cim.as_ref().unwrap();
// Check if the inner dimensions of the cim are equal to the cardinality of the variable // Check if the inner dimensions of the cim are equal to the cardinality of the variable
if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size { if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size {
return Err(ParamsError::InvalidCIM(format!( let message = format!(
"Incompatible shape {:?} with domain {:?}", "Incompatible shape {:?} with domain {:?}",
cim.shape(), cim.shape(),
domain_size domain_size
))); );
warn!("{}", message);
return Err(ParamsError::InvalidCIM(message));
} }
// Check if the diagonal of each cim is non-positive // Check if the diagonal of each cim is non-positive
if cim if cim
.axis_iter(Axis(0)) .axis_iter(Axis(0))
.any(|x| x.diag().iter().any(|x| x >= &0.0)) .any(|x| x.diag().iter().any(|x| x >= &0.0))
{ {
warn!("The diagonal of each cim for node {} must be non-positive", self.get_label());
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive", "The diagonal of each cim must be non-positive",
))); )));
@ -273,6 +317,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
.iter() .iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
{ {
warn!("The sum of each row of the cim for node {} must be 0", self.get_label());
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0", "The sum of each row must be 0",
))); )));

Loading…
Cancel
Save