Added doc to: params, process, parameter_learning, ctmp, ctbn

72-feature-add-logging-and-documentation
AlessandroBregoli 2 years ago
parent d66173b961
commit 21ce0ffcb0
  1. 32
      reCTBN/src/parameter_learning.rs
  2. 12
      reCTBN/src/params.rs
  3. 14
      reCTBN/src/process.rs
  4. 8
      reCTBN/src/process/ctbn.rs
  5. 5
      reCTBN/src/process/ctmp.rs

@ -7,7 +7,17 @@ use ndarray::prelude::*;
use crate::params::*;
use crate::{process, tools::Dataset};
/// It defines the required methods for learn the `Parameters` from data.
pub trait ParameterLearning: Sync {
/// Fit the parameter of the `node` over a `dataset` given a `parent_set`
///
/// # Arguments
///
/// * `net`: a `NetworkProcess` instance
/// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics
/// * `node`: the node index for which we want to compute the sufficient statistics
/// * `parent_set`: an `Option` containing the parent set used for computing the parameters of
/// `node`. If `None`, the parent set defined in `net` will be used.
fn fit<T: process::NetworkProcess>(
&self,
net: &T,
@ -17,6 +27,19 @@ pub trait ParameterLearning: Sync {
) -> Params;
}
/// Compute the sufficient statistics of a parameters computed from a dataset
///
/// # Arguments
///
/// * `net`: a `NetworkProcess` instance
/// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics
/// * `node`: the node index for which we want to compute the sufficient statistics
/// * `parent_set`: the set of nodes (identified by indices) we want to use as parents of `node`
///
/// # Return
///
/// * A tuple containing the number of transitions (`Array3<usize>`) and the residence time
/// (`Array2<f64>`).
pub fn sufficient_statistics<T: process::NetworkProcess>(
net: &T,
dataset: &Dataset,
@ -70,6 +93,8 @@ pub fn sufficient_statistics<T: process::NetworkProcess>(
return (M, T);
}
/// Maximum Likelihood Estimation method for learning the parameters given a dataset.
pub struct MLE {}
impl ParameterLearning for MLE {
@ -114,6 +139,13 @@ impl ParameterLearning for MLE {
}
}
/// Bayesian Approach for learning the parameters given a dataset.
///
/// # Arguments
///
/// `alpha`: hyperparameter for the priori over the number of transitions.
/// `tau`: hyperparameter for the priori over the residence time.
pub struct BayesianApproach {
pub alpha: usize,
pub tau: f64,

@ -3,7 +3,7 @@
use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch;
use log::{debug, error, info, trace, warn};
use log::{debug, error, trace, warn};
use ndarray::prelude::*;
use rand::Rng;
use rand_chacha::ChaCha8Rng;
@ -305,7 +305,10 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
.axis_iter(Axis(0))
.any(|x| x.diag().iter().any(|x| x >= &0.0))
{
warn!("The diagonal of each cim for node {} must be non-positive", self.get_label());
warn!(
"The diagonal of each cim for node {} must be non-positive",
self.get_label()
);
return Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive",
)));
@ -317,7 +320,10 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
.iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
{
warn!("The sum of each row of the cim for node {} must be 0", self.get_label());
warn!(
"The sum of each row of the cim for node {} must be 0",
self.get_label()
);
return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0",
)));

@ -23,8 +23,20 @@ pub type NetworkProcessState = Vec<params::StateType>;
/// as a CTBN).
pub trait NetworkProcess: Sync {
fn initialize_adj_matrix(&mut self);
/// Add a **node** to the network
///
/// # Arguments
///
/// * `n` - instantiation of the `enum params::Params` describing a node
///
/// # Return
///
/// * A `Result` containing the `node_idx` automatically assigned if everything is fine,
/// or a `NetworkError` if something went wrong.
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add an **directed edge** between a two nodes of the network.
/// Add a **directed edge** between a two nodes of the network.
///
/// # Arguments
///

@ -2,6 +2,7 @@
use std::collections::BTreeSet;
use log::info;
use ndarray::prelude::*;
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
@ -77,6 +78,7 @@ impl CtbnNetwork {
///
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
pub fn amalgamation(&self) -> CtmpProcess {
info!("Network Amalgamation Started");
let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
@ -141,14 +143,12 @@ impl CtbnNetwork {
}
impl process::NetworkProcess for CtbnNetwork {
/// Initialize an Adjacency matrix.
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
}
/// Add a new node.
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
@ -156,7 +156,6 @@ impl process::NetworkProcess for CtbnNetwork {
Ok(self.nodes.len() - 1)
}
/// Connect two nodes with a new edge.
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
@ -172,7 +171,6 @@ impl process::NetworkProcess for CtbnNetwork {
0..self.nodes.len()
}
/// Get the number of nodes of the network.
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
@ -217,7 +215,6 @@ impl process::NetworkProcess for CtbnNetwork {
.0
}
/// Get all the parents of the given node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
@ -229,7 +226,6 @@ impl process::NetworkProcess for CtbnNetwork {
.collect()
}
/// Get all the children of the given node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()

@ -6,7 +6,7 @@ use crate::{
};
use super::{NetworkProcess, NetworkProcessState};
use log::{debug, error, info, trace, warn};
use log::warn;
/// This structure represents a Continuous Time Markov process
///
@ -101,7 +101,8 @@ impl NetworkProcess for CtmpProcess {
warn!("A CTMP do not support more than one node");
Err(process::NetworkError::NodeInsertionError(
"CtmpProcess has only one node".to_string(),
))}
))
}
}
}

Loading…
Cancel
Save