Added doc to: params, process, parameter_learning, ctmp, ctbn

72-feature-add-logging-and-documentation
AlessandroBregoli 2 years ago
parent d66173b961
commit 21ce0ffcb0
  1. 32
      reCTBN/src/parameter_learning.rs
  2. 16
      reCTBN/src/params.rs
  3. 14
      reCTBN/src/process.rs
  4. 8
      reCTBN/src/process/ctbn.rs
  5. 7
      reCTBN/src/process/ctmp.rs

@ -7,7 +7,17 @@ use ndarray::prelude::*;
use crate::params::*; use crate::params::*;
use crate::{process, tools::Dataset}; use crate::{process, tools::Dataset};
/// It defines the required methods for learn the `Parameters` from data.
pub trait ParameterLearning: Sync { pub trait ParameterLearning: Sync {
/// Fit the parameter of the `node` over a `dataset` given a `parent_set`
///
/// # Arguments
///
/// * `net`: a `NetworkProcess` instance
/// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics
/// * `node`: the node index for which we want to compute the sufficient statistics
/// * `parent_set`: an `Option` containing the parent set used for computing the parameters of
/// `node`. If `None`, the parent set defined in `net` will be used.
fn fit<T: process::NetworkProcess>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
@ -17,6 +27,19 @@ pub trait ParameterLearning: Sync {
) -> Params; ) -> Params;
} }
/// Compute the sufficient statistics of a parameters computed from a dataset
///
/// # Arguments
///
/// * `net`: a `NetworkProcess` instance
/// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics
/// * `node`: the node index for which we want to compute the sufficient statistics
/// * `parent_set`: the set of nodes (identified by indices) we want to use as parents of `node`
///
/// # Return
///
/// * A tuple containing the number of transitions (`Array3<usize>`) and the residence time
/// (`Array2<f64>`).
pub fn sufficient_statistics<T: process::NetworkProcess>( pub fn sufficient_statistics<T: process::NetworkProcess>(
net: &T, net: &T,
dataset: &Dataset, dataset: &Dataset,
@ -70,6 +93,8 @@ pub fn sufficient_statistics<T: process::NetworkProcess>(
return (M, T); return (M, T);
} }
/// Maximum Likelihood Estimation method for learning the parameters given a dataset.
pub struct MLE {} pub struct MLE {}
impl ParameterLearning for MLE { impl ParameterLearning for MLE {
@ -114,6 +139,13 @@ impl ParameterLearning for MLE {
} }
} }
/// Bayesian Approach for learning the parameters given a dataset.
///
/// # Arguments
///
/// `alpha`: hyperparameter for the priori over the number of transitions.
/// `tau`: hyperparameter for the priori over the residence time.
pub struct BayesianApproach { pub struct BayesianApproach {
pub alpha: usize, pub alpha: usize,
pub tau: f64, pub tau: f64,

@ -3,7 +3,7 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch; use enum_dispatch::enum_dispatch;
use log::{debug, error, info, trace, warn}; use log::{debug, error, trace, warn};
use ndarray::prelude::*; use ndarray::prelude::*;
use rand::Rng; use rand::Rng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
@ -262,8 +262,8 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
Option::None => { Option::None => {
warn!("Cim not initialized for node {}", self.get_label()); warn!("Cim not initialized for node {}", self.get_label());
Err(ParamsError::ParametersNotInitialized(String::from( Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized", "CIM not initialized",
))) )))
} }
} }
} }
@ -305,7 +305,10 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
.axis_iter(Axis(0)) .axis_iter(Axis(0))
.any(|x| x.diag().iter().any(|x| x >= &0.0)) .any(|x| x.diag().iter().any(|x| x >= &0.0))
{ {
warn!("The diagonal of each cim for node {} must be non-positive", self.get_label()); warn!(
"The diagonal of each cim for node {} must be non-positive",
self.get_label()
);
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive", "The diagonal of each cim must be non-positive",
))); )));
@ -317,7 +320,10 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
.iter() .iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
{ {
warn!("The sum of each row of the cim for node {} must be 0", self.get_label()); warn!(
"The sum of each row of the cim for node {} must be 0",
self.get_label()
);
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0", "The sum of each row must be 0",
))); )));

@ -23,8 +23,20 @@ pub type NetworkProcessState = Vec<params::StateType>;
/// as a CTBN). /// as a CTBN).
pub trait NetworkProcess: Sync { pub trait NetworkProcess: Sync {
fn initialize_adj_matrix(&mut self); fn initialize_adj_matrix(&mut self);
/// Add a **node** to the network
///
/// # Arguments
///
/// * `n` - instantiation of the `enum params::Params` describing a node
///
/// # Return
///
/// * A `Result` containing the `node_idx` automatically assigned if everything is fine,
/// or a `NetworkError` if something went wrong.
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add an **directed edge** between a two nodes of the network.
/// Add a **directed edge** between a two nodes of the network.
/// ///
/// # Arguments /// # Arguments
/// ///

@ -2,6 +2,7 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use log::info;
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
@ -77,6 +78,7 @@ impl CtbnNetwork {
/// ///
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
pub fn amalgamation(&self) -> CtmpProcess { pub fn amalgamation(&self) -> CtmpProcess {
info!("Network Amalgamation Started");
let variables_domain = let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
@ -141,14 +143,12 @@ impl CtbnNetwork {
} }
impl process::NetworkProcess for CtbnNetwork { impl process::NetworkProcess for CtbnNetwork {
/// Initialize an Adjacency matrix.
fn initialize_adj_matrix(&mut self) { fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros( self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(), (self.nodes.len(), self.nodes.len()).f(),
)); ));
} }
/// Add a new node.
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> { fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
n.reset_params(); n.reset_params();
self.adj_matrix = Option::None; self.adj_matrix = Option::None;
@ -156,7 +156,6 @@ impl process::NetworkProcess for CtbnNetwork {
Ok(self.nodes.len() - 1) Ok(self.nodes.len() - 1)
} }
/// Connect two nodes with a new edge.
fn add_edge(&mut self, parent: usize, child: usize) { fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix { if let None = self.adj_matrix {
self.initialize_adj_matrix(); self.initialize_adj_matrix();
@ -172,7 +171,6 @@ impl process::NetworkProcess for CtbnNetwork {
0..self.nodes.len() 0..self.nodes.len()
} }
/// Get the number of nodes of the network.
fn get_number_of_nodes(&self) -> usize { fn get_number_of_nodes(&self) -> usize {
self.nodes.len() self.nodes.len()
} }
@ -217,7 +215,6 @@ impl process::NetworkProcess for CtbnNetwork {
.0 .0
} }
/// Get all the parents of the given node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix self.adj_matrix
.as_ref() .as_ref()
@ -229,7 +226,6 @@ impl process::NetworkProcess for CtbnNetwork {
.collect() .collect()
} }
/// Get all the children of the given node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix self.adj_matrix
.as_ref() .as_ref()

@ -6,7 +6,7 @@ use crate::{
}; };
use super::{NetworkProcess, NetworkProcessState}; use super::{NetworkProcess, NetworkProcessState};
use log::{debug, error, info, trace, warn}; use log::warn;
/// This structure represents a Continuous Time Markov process /// This structure represents a Continuous Time Markov process
/// ///
@ -100,8 +100,9 @@ impl NetworkProcess for CtmpProcess {
Some(_) => { Some(_) => {
warn!("A CTMP do not support more than one node"); warn!("A CTMP do not support more than one node");
Err(process::NetworkError::NodeInsertionError( Err(process::NetworkError::NodeInsertionError(
"CtmpProcess has only one node".to_string(), "CtmpProcess has only one node".to_string(),
))} ))
}
} }
} }

Loading…
Cancel
Save