From bfec2c7c60c05861a0b0ce4fb1910d4b2669dbb4 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 16 Feb 2023 16:32:51 +0100 Subject: [PATCH] Added log to params --- reCTBN/Cargo.toml | 1 + reCTBN/src/params.rs | 67 ++++++++++++++++++++++++++++++++++++-------- 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml index 4749b23..8cd97c4 100644 --- a/reCTBN/Cargo.toml +++ b/reCTBN/Cargo.toml @@ -15,6 +15,7 @@ statrs = "~0.16" rand_chacha = "~0.3" itertools = "~0.10" rayon = "~1.6" +log = "~0.4" [dev-dependencies] approx = { package = "approx", version = "~0.5" } diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index ccbb750..119e13a 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -3,6 +3,7 @@ use std::collections::BTreeSet; use enum_dispatch::enum_dispatch; +use log::{debug, error, info, trace, warn}; use ndarray::prelude::*; use rand::Rng; use rand_chacha::ChaCha8Rng; @@ -29,6 +30,7 @@ pub enum StateType { /// methods required to describes a generic node. #[enum_dispatch(Params)] pub trait ParamsTrait { + ///Reset the parameters fn reset_params(&mut self); /// Randomly generate a possible state of the node disregarding the state of the node and it's @@ -98,6 +100,7 @@ pub struct DiscreteStatesContinousTimeParams { impl DiscreteStatesContinousTimeParams { pub fn new(label: String, domain: BTreeSet) -> DiscreteStatesContinousTimeParams { + debug!("Creation of node {}", label); DiscreteStatesContinousTimeParams { label, domain, @@ -109,6 +112,7 @@ impl DiscreteStatesContinousTimeParams { /// Getter function for CIM pub fn get_cim(&self) -> &Option> { + debug!("Getting cim from node {}", self.label); &self.cim } @@ -119,10 +123,12 @@ impl DiscreteStatesContinousTimeParams { /// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns /// `ParamsError`. pub fn set_cim(&mut self, cim: Array3) -> Result<(), ParamsError> { + debug!("Setting cim for node {}", self.label); self.cim = Some(cim); match self.validate_params() { Ok(()) => Ok(()), Err(e) => { + warn!("Validation cim faild for node {}", self.label); self.cim = None; Err(e) } @@ -131,39 +137,54 @@ impl DiscreteStatesContinousTimeParams { /// Unchecked version of the setter function for CIM. pub fn set_cim_unchecked(&mut self, cim: Array3) { + debug!("Setting cim (unchecked) for node {}", self.label); self.cim = Some(cim); } /// Getter function for transitions. pub fn get_transitions(&self) -> &Option> { + debug!("Get transitions from node {}", self.label); &self.transitions } /// Setter function for transitions. pub fn set_transitions(&mut self, transitions: Array3) { + debug!("Set transitions for node {}", self.label); self.transitions = Some(transitions); } /// Getter function for residence_time. pub fn get_residence_time(&self) -> &Option> { + debug!("Get residence time from node {}", self.label); &self.residence_time } /// Setter function for residence_time. pub fn set_residence_time(&mut self, residence_time: Array2) { + debug!("Set residence time for node {}", self.label); self.residence_time = Some(residence_time); } } impl ParamsTrait for DiscreteStatesContinousTimeParams { fn reset_params(&mut self) { + debug!( + "Setting cim, transitions and residence_time to None for node {}", + self.label + ); self.cim = Option::None; self.transitions = Option::None; self.residence_time = Option::None; } fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType { - StateType::Discrete(rng.gen_range(0..(self.domain.len()))) + let state = StateType::Discrete(rng.gen_range(0..(self.domain.len()))); + trace!( + "Generate random state uniform. Node: {} - State: {:?}", + self.get_label(), + &state + ); + return state; } fn get_random_residence_time( @@ -179,11 +200,20 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { Option::Some(cim) => { let lambda = cim[[u, state, state]] * -1.0; let x: f64 = rng.gen_range(0.0..=1.0); - Ok(-x.ln() / lambda) + let ret = -x.ln() / lambda; + trace!( + "Generate random residence time. Node: {} - Time: {}", + self.get_label(), + ret + ); + Ok(ret) + } + Option::None => { + warn!("Cim not initialized for node {}", self.get_label()); + Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))) } - Option::None => Err(ParamsError::ParametersNotInitialized(String::from( - "CIM not initialized", - ))), } } @@ -220,11 +250,21 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { next_state.0 + 1 }; - Ok(StateType::Discrete(next_state)) + let next_state = StateType::Discrete(next_state); + trace!( + "Generate random state. Node: {} - State: {:?}", + self.get_label(), + next_state + ); + + Ok(next_state) } - Option::None => Err(ParamsError::ParametersNotInitialized(String::from( + Option::None => { + warn!("Cim not initialized for node {}", self.get_label()); + Err(ParamsError::ParametersNotInitialized(String::from( "CIM not initialized", - ))), + ))) + } } } @@ -243,6 +283,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { // Check if the cim is initialized if let None = self.cim { + warn!("Cim not initialized for node {}", self.get_label()); return Err(ParamsError::ParametersNotInitialized(String::from( "CIM not initialized", ))); @@ -250,18 +291,21 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { let cim = self.cim.as_ref().unwrap(); // Check if the inner dimensions of the cim are equal to the cardinality of the variable if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size { - return Err(ParamsError::InvalidCIM(format!( + let message = format!( "Incompatible shape {:?} with domain {:?}", cim.shape(), domain_size - ))); + ); + warn!("{}", message); + return Err(ParamsError::InvalidCIM(message)); } // Check if the diagonal of each cim is non-positive if cim .axis_iter(Axis(0)) .any(|x| x.diag().iter().any(|x| x >= &0.0)) - { + { + warn!("The diagonal of each cim for node {} must be non-positive", self.get_label()); return Err(ParamsError::InvalidCIM(String::from( "The diagonal of each cim must be non-positive", ))); @@ -273,6 +317,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { .iter() .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) { + warn!("The sum of each row of the cim for node {} must be 0", self.get_label()); return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", )));