From 21ce0ffcb0d2f45785f6d94c57b4590e47aa7896 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 22 Feb 2023 11:18:40 +0100 Subject: [PATCH] Added doc to: params, process, parameter_learning, ctmp, ctbn --- reCTBN/src/parameter_learning.rs | 32 ++++++++++++++++++++++++++++++++ reCTBN/src/params.rs | 18 ++++++++++++------ reCTBN/src/process.rs | 14 +++++++++++++- reCTBN/src/process/ctbn.rs | 8 ++------ reCTBN/src/process/ctmp.rs | 7 ++++--- 5 files changed, 63 insertions(+), 16 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 3c34d06..bc06952 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -7,7 +7,17 @@ use ndarray::prelude::*; use crate::params::*; use crate::{process, tools::Dataset}; +/// Defines the methods required for learning the `Parameters` from data. pub trait ParameterLearning: Sync { + /// Fit the parameters of the `node` over a `dataset` given a `parent_set` + /// + /// # Arguments + /// + /// * `net`: a `NetworkProcess` instance + /// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics + /// * `node`: the index of the node whose parameters we want to fit + /// * `parent_set`: an `Option` containing the parent set used for computing the parameters of + /// `node`. If `None`, the parent set defined in `net` will be used. 
fn fit( &self, net: &T, @@ -17,6 +27,19 @@ pub trait ParameterLearning: Sync { ) -> Params; } +/// Compute the sufficient statistics for the parameters of `node` from a dataset +/// +/// # Arguments +/// +/// * `net`: a `NetworkProcess` instance +/// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics +/// * `node`: the node index for which we want to compute the sufficient statistics +/// * `parent_set`: the set of nodes (identified by indices) we want to use as parents of `node` +/// +/// # Return +/// +/// * A tuple containing the number of transitions (`Array3`) and the residence time +/// (`Array2`). pub fn sufficient_statistics( net: &T, dataset: &Dataset, @@ -70,6 +93,8 @@ pub fn sufficient_statistics( return (M, T); } + +/// Maximum Likelihood Estimation method for learning the parameters given a dataset. pub struct MLE {} impl ParameterLearning for MLE { @@ -114,6 +139,13 @@ impl ParameterLearning for MLE { } } + +/// Bayesian Approach for learning the parameters given a dataset. +/// +/// # Arguments +/// +/// * `alpha`: hyperparameter for the prior over the number of transitions. +/// * `tau`: hyperparameter for the prior over the residence time. 
pub struct BayesianApproach { pub alpha: usize, pub tau: f64, diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 119e13a..71395af 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -3,7 +3,7 @@ use std::collections::BTreeSet; use enum_dispatch::enum_dispatch; -use log::{debug, error, info, trace, warn}; +use log::{debug, error, trace, warn}; use ndarray::prelude::*; use rand::Rng; use rand_chacha::ChaCha8Rng; @@ -262,8 +262,8 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { Option::None => { warn!("Cim not initialized for node {}", self.get_label()); Err(ParamsError::ParametersNotInitialized(String::from( - "CIM not initialized", - ))) + "CIM not initialized", + ))) } } } @@ -304,8 +304,11 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { if cim .axis_iter(Axis(0)) .any(|x| x.diag().iter().any(|x| x >= &0.0)) - { - warn!("The diagonal of each cim for node {} must be non-positive", self.get_label()); + { + warn!( + "The diagonal of each cim for node {} must be non-positive", + self.get_label() + ); return Err(ParamsError::InvalidCIM(String::from( "The diagonal of each cim must be non-positive", ))); @@ -317,7 +320,10 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { .iter() .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) { - warn!("The sum of each row of the cim for node {} must be 0", self.get_label()); + warn!( + "The sum of each row of the cim for node {} must be 0", + self.get_label() + ); return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", ))); diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs index 45c5e0a..ecc1391 100644 --- a/reCTBN/src/process.rs +++ b/reCTBN/src/process.rs @@ -23,8 +23,20 @@ pub type NetworkProcessState = Vec; /// as a CTBN). 
pub trait NetworkProcess: Sync { fn initialize_adj_matrix(&mut self); + + /// Add a **node** to the network + /// + /// # Arguments + /// + /// * `n` - instantiation of the `enum params::Params` describing a node + /// + /// # Return + /// + /// * A `Result` containing the `node_idx` automatically assigned if everything is fine, + /// or a `NetworkError` if something went wrong. fn add_node(&mut self, n: params::Params) -> Result; - /// Add an **directed edge** between a two nodes of the network. + + /// Add a **directed edge** between two nodes of the network. /// /// # Arguments /// diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 162345e..bc4b28d 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -2,6 +2,7 @@ use std::collections::BTreeSet; +use log::info; use ndarray::prelude::*; use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; @@ -77,6 +78,7 @@ impl CtbnNetwork { /// /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { + info!("Network Amalgamation Started"); let variables_domain = Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); @@ -141,14 +143,12 @@ impl CtbnNetwork { } impl process::NetworkProcess for CtbnNetwork { - /// Initialize an Adjacency matrix. fn initialize_adj_matrix(&mut self) { self.adj_matrix = Some(Array2::::zeros( (self.nodes.len(), self.nodes.len()).f(), )); } - /// Add a new node. fn add_node(&mut self, mut n: Params) -> Result { n.reset_params(); self.adj_matrix = Option::None; @@ -156,7 +156,6 @@ impl process::NetworkProcess for CtbnNetwork { Ok(self.nodes.len() - 1) } - /// Connect two nodes with a new edge. fn add_edge(&mut self, parent: usize, child: usize) { if let None = self.adj_matrix { self.initialize_adj_matrix(); @@ -172,7 +171,6 @@ impl process::NetworkProcess for CtbnNetwork { 0..self.nodes.len() } - /// Get the number of nodes of the network. 
fn get_number_of_nodes(&self) -> usize { self.nodes.len() } @@ -217,7 +215,6 @@ impl process::NetworkProcess for CtbnNetwork { .0 } - /// Get all the parents of the given node. fn get_parent_set(&self, node: usize) -> BTreeSet { self.adj_matrix .as_ref() @@ -229,7 +226,6 @@ impl process::NetworkProcess for CtbnNetwork { .collect() } - /// Get all the children of the given node. fn get_children_set(&self, node: usize) -> BTreeSet { self.adj_matrix .as_ref() diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index 9fdca1b..592c757 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -6,7 +6,7 @@ use crate::{ }; use super::{NetworkProcess, NetworkProcessState}; -use log::{debug, error, info, trace, warn}; +use log::warn; /// This structure represents a Continuous Time Markov process /// @@ -100,8 +100,9 @@ impl NetworkProcess for CtmpProcess { Some(_) => { warn!("A CTMP do not support more than one node"); Err(process::NetworkError::NodeInsertionError( - "CtmpProcess has only one node".to_string(), - ))} + "CtmpProcess has only one node".to_string(), + )) + } } }