diff --git a/README.md b/README.md
index f928955..6a60dff 100644
--- a/README.md
+++ b/README.md
@@ -56,3 +56,11 @@ To check the **formatting**:
 ```sh
 cargo fmt --all -- --check
 ```
+
+## Documentation
+
+To generate the **documentation**:
+
+```sh
+cargo rustdoc --package reCTBN --open -- --default-theme=ayu
+```
diff --git a/reCTBN/src/ctbn.rs b/reCTBN/src/ctbn.rs
index e2f5dd7..2b01d14 100644
--- a/reCTBN/src/ctbn.rs
+++ b/reCTBN/src/ctbn.rs
@@ -1,3 +1,5 @@
+//! Continuous Time Bayesian Network
+
 use std::collections::BTreeSet;
 
 use ndarray::prelude::*;
@@ -5,16 +7,18 @@ use ndarray::prelude::*;
 use crate::network;
 use crate::params::{Params, ParamsTrait, StateType};
 
-///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
-///composed by the following elements:
-///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
-///- **nodes**: a vector containing all the nodes and their parameters.
-///The index of a node inside the vector is also used as index for the adj_matrix.
+/// It represents both the structure and the parameters of a CTBN.
+///
+/// # Arguments
+///
+/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix.
+/// * `nodes` - A vector containing all the nodes and their parameters.
+///
+/// The index of a node inside the vector is also used as the index for the `adj_matrix`.
 ///
-///# Examples
+/// # Example
 ///
-///```
-///
+/// ```rust
 /// use std::collections::BTreeSet;
 /// use reCTBN::network::Network;
 /// use reCTBN::params;
@@ -66,12 +70,14 @@ impl CtbnNetwork {
 }
 
 impl network::Network for CtbnNetwork {
+    /// Initialize the adjacency matrix.
     fn initialize_adj_matrix(&mut self) {
         self.adj_matrix = Some(Array2::::zeros(
             (self.nodes.len(), self.nodes.len()).f(),
         ));
     }
 
+    /// Add a new node.
     fn add_node(&mut self, mut n: Params) -> Result {
         n.reset_params();
         self.adj_matrix = Option::None;
@@ -79,6 +85,7 @@ impl network::Network for CtbnNetwork {
         Ok(self.nodes.len() - 1)
     }
 
+    /// Connect two nodes with a new edge.
     fn add_edge(&mut self, parent: usize, child: usize) {
         if let None = self.adj_matrix {
             self.initialize_adj_matrix();
@@ -94,6 +101,7 @@
         0..self.nodes.len()
     }
 
+    /// Get the number of nodes of the network.
     fn get_number_of_nodes(&self) -> usize {
         self.nodes.len()
     }
@@ -138,6 +146,7 @@
             .0
     }
 
+    /// Get all the parents of the given node.
     fn get_parent_set(&self, node: usize) -> BTreeSet {
         self.adj_matrix
             .as_ref()
@@ -149,6 +158,7 @@
             .collect()
     }
 
+    /// Get all the children of the given node.
     fn get_children_set(&self, node: usize) -> BTreeSet {
         self.adj_matrix
             .as_ref()
diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs
index 280bd21..db33ae4 100644
--- a/reCTBN/src/lib.rs
+++ b/reCTBN/src/lib.rs
@@ -1,3 +1,4 @@
+#![doc = include_str!("../../README.md")]
 #![allow(non_snake_case)]
 #[cfg(test)]
 extern crate approx;
diff --git a/reCTBN/src/network.rs b/reCTBN/src/network.rs
index cbae339..fbdd2e6 100644
--- a/reCTBN/src/network.rs
+++ b/reCTBN/src/network.rs
@@ -1,3 +1,5 @@
+//! Defines methods for dealing with Probabilistic Graphical Models like CTBNs
+
 use std::collections::BTreeSet;
 
 use thiserror::Error;
@@ -11,33 +13,103 @@ pub enum NetworkError {
     NodeInsertionError(String),
 }
 
-///Network
-///The Network trait define the required methods for a structure used as pgm (such as ctbn).
+/// It defines the required methods for a structure used as a Probabilistic Graphical Model (such
+/// as a CTBN).
 pub trait Network {
     fn initialize_adj_matrix(&mut self);
     fn add_node(&mut self, n: params::Params) -> Result;
+    /// Add a **directed edge** between two nodes of the network.
+    ///
+    /// # Arguments
+    ///
+    /// * `parent` - parent node.
+    /// * `child` - child node.
     fn add_edge(&mut self, parent: usize, child: usize);
-    ///Get all the indices of the nodes contained inside the network
+    /// Get all the indices of the nodes contained inside the network.
     fn get_node_indices(&self) -> std::ops::Range;
+
+    /// Get the number of nodes contained in the network.
     fn get_number_of_nodes(&self) -> usize;
+
+    /// Get the **node param**.
+    ///
+    /// # Arguments
+    ///
+    /// * `node_idx` - node index value.
+    ///
+    /// # Returns
+    ///
+    /// * The selected **node param**.
    fn get_node(&self, node_idx: usize) -> &params::Params;
+
+    /// Get the mutable **node param**.
+    ///
+    /// # Arguments
+    ///
+    /// * `node_idx` - node index value.
+    ///
+    /// # Returns
+    ///
+    /// * The selected mutable **node param**.
    fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
-    ///Compute the index that must be used to access the parameters of a node given a specific
-    ///configuration of the network. Usually, the only values really used in *current_state* are
-    ///the ones in the parent set of the *node*.
+    /// Compute the index that must be used to access the parameters of a `node`, given a specific
+    /// configuration of the network.
+    ///
+    /// Usually, the only values really used in `current_state` are the ones in the parent set of
+    /// the `node`.
+    ///
+    /// # Arguments
+    ///
+    /// * `node` - selected node.
+    /// * `current_state` - current configuration of the network.
+    ///
+    /// # Returns
+    ///
+    /// * Index used to access the parameters of the `node` for the given configuration.
    fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize;
-    ///Compute the index that must be used to access the parameters of a node given a specific
-    ///configuration of the network and a generic parent_set. Usually, the only values really used
-    ///in *current_state* are the ones in the parent set of the *node*.
+    /// Compute the index that must be used to access the parameters of a `node`, given a specific
+    /// configuration of the network and a generic `parent_set`.
+    ///
+    /// Usually, the only values really used in `current_state` are the ones in the parent set of
+    /// the `node`.
+    ///
+    /// # Arguments
+    ///
+    /// * `current_state` - current configuration of the network.
+    /// * `parent_set` - parent set of the selected `node`.
+    ///
+    /// # Returns
+    ///
+    /// * Index used to access the parameters of the `node` for the given configuration.
    fn get_param_index_from_custom_parent_set(
        &self,
        current_state: &Vec,
        parent_set: &BTreeSet,
    ) -> usize;
+
+    /// Get the **parent set** of a given **node**.
+    ///
+    /// # Arguments
+    ///
+    /// * `node` - node index value.
+    ///
+    /// # Returns
+    ///
+    /// * The **parent set** of the selected node.
    fn get_parent_set(&self, node: usize) -> BTreeSet;
+
+    /// Get the **children set** of a given **node**.
+    ///
+    /// # Arguments
+    ///
+    /// * `node` - node index value.
+    ///
+    /// # Returns
+    ///
+    /// * The **children set** of the selected node.
    fn get_children_set(&self, node: usize) -> BTreeSet;
 }
diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs
index bdb5d4a..61d4dca 100644
--- a/reCTBN/src/parameter_learning.rs
+++ b/reCTBN/src/parameter_learning.rs
@@ -1,3 +1,5 @@
+//! Module containing methods used to learn the parameters.
+
 use std::collections::BTreeSet;
 
 use ndarray::prelude::*;
diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs
index f994b99..070c997 100644
--- a/reCTBN/src/params.rs
+++ b/reCTBN/src/params.rs
@@ -1,3 +1,5 @@
+//! Module containing methods to define different types of nodes.
+
 use std::collections::BTreeSet;
 
 use enum_dispatch::enum_dispatch;
@@ -23,9 +25,8 @@ pub enum StateType {
     Discrete(usize),
 }
 
-/// Parameters
-/// The Params trait is the core element for building different types of nodes. The goal is to
-/// define the set of method required to describes a generic node.
+/// This is a core element for building different types of nodes; the goal is to define the set of
+/// methods required to describe a generic node.
 #[enum_dispatch(Params)]
 pub trait ParamsTrait {
     fn reset_params(&mut self);
@@ -65,25 +66,27 @@ pub trait ParamsTrait {
     fn get_label(&self) -> &String;
 }
 
-/// The Params enum is the core element for building different types of nodes. The goal is to
-/// define all the supported type of Parameters
+/// It is a core element for building different types of nodes; the goal is to define all the
+/// supported types of parameters.
 #[derive(Clone)]
 #[enum_dispatch]
 pub enum Params {
     DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
 }
 
-/// DiscreteStatesContinousTime.
 /// This represents the parameters of a classical discrete node for ctbn and it's composed by the
-/// following elements:
-/// - **domain**: an ordered and exhaustive set of possible states
-/// - **cim**: Conditional Intensity Matrix
-/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
-/// learning task and are composed by:
-/// - **transitions**: number of transitions from one state to another given a specific
-/// realization of the parent set
-/// - **residence_time**: permanence time in each possible states given a specific
-/// realization of the parent set
+/// following elements.
+///
+/// # Arguments
+///
+/// * `label` - node's variable name.
+/// * `domain` - an ordered and exhaustive set of possible states.
+/// * `cim` - Conditional Intensity Matrix.
+/// * `transitions` - number of transitions from one state to another given a specific realization
+///   of the parent set; it is a sufficient statistic mainly used during the parameter learning
+///   task.
+/// * `residence_time` - residence time in each possible state, given a specific realization of the
+///   parent set; it is a sufficient statistic mainly used during the parameter learning task.
 #[derive(Clone)]
 pub struct DiscreteStatesContinousTimeParams {
     label: String,
@@ -104,15 +107,17 @@ impl DiscreteStatesContinousTimeParams {
         }
     }
 
-    ///Getter function for CIM
+    /// Getter function for CIM.
     pub fn get_cim(&self) -> &Option> {
         &self.cim
     }
 
-    ///Setter function for CIM.\\
-    ///This function check if the cim is valid using the validate_params method.
-    ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(())
-    ///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError
+    /// Setter function for CIM.
+    ///
+    /// This function checks if the CIM is valid using the [`validate_params`](self::ParamsTrait::validate_params) method:
+    /// * **Valid CIM inserted** - it substitutes the CIM in `self.cim` and returns `Ok(())`.
+    /// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns a
+    /// `ParamsError`.
     pub fn set_cim(&mut self, cim: Array3) -> Result<(), ParamsError> {
         self.cim = Some(cim);
         match self.validate_params() {
@@ -124,27 +129,27 @@ impl DiscreteStatesContinousTimeParams {
         }
     }
 
-    ///Unchecked version of the setter function for CIM.
+    /// Unchecked version of the setter function for CIM.
     pub fn set_cim_unchecked(&mut self, cim: Array3) {
         self.cim = Some(cim);
     }
 
-    ///Getter function for transitions
+    /// Getter function for transitions.
     pub fn get_transitions(&self) -> &Option> {
         &self.transitions
     }
 
-    ///Setter function for transitions
+    /// Setter function for transitions.
     pub fn set_transitions(&mut self, transitions: Array3) {
         self.transitions = Some(transitions);
     }
 
-    ///Getter function for residence_time
+    /// Getter function for residence_time.
     pub fn get_residence_time(&self) -> &Option> {
         &self.residence_time
     }
 
-    ///Setter function for residence_time
+    /// Setter function for residence_time.
     pub fn set_residence_time(&mut self, residence_time: Array2) {
         self.residence_time = Some(residence_time);
     }
diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs
index 0660939..d435634 100644
--- a/reCTBN/src/sampling.rs
+++ b/reCTBN/src/sampling.rs
@@ -1,3 +1,5 @@
+//! Module containing methods for sampling.
+
 use crate::{
     network::Network,
     params::{self, ParamsTrait},
diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs
index 8b90cdf..57fed1e 100644
--- a/reCTBN/src/structure_learning.rs
+++ b/reCTBN/src/structure_learning.rs
@@ -1,3 +1,5 @@
+//! Learn the structure of the network.
+
 pub mod constraint_based_algorithm;
 pub mod hypothesis_test;
 pub mod score_based_algorithm;
diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs
index b3fc3e1..670c8ed 100644
--- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs
+++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs
@@ -1,3 +1,5 @@
+//! Module containing constraint-based algorithms like CTPC and Hiton.
+
 //pub struct CTPC {
 //
 //}
diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs
index 4f2ce18..6474155 100644
--- a/reCTBN/src/structure_learning/hypothesis_test.rs
+++ b/reCTBN/src/structure_learning/hypothesis_test.rs
@@ -1,3 +1,5 @@
+//! Module for constraint-based algorithms containing hypothesis tests like the chi-squared test, the F test, etc.
+
 use std::collections::BTreeSet;
 
 use ndarray::{Array3, Axis};
@@ -20,6 +22,17 @@ pub trait HypothesisTest {
         P: parameter_learning::ParameterLearning;
 }
 
+/// Performs the chi-squared test (χ² test).
+///
+/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
+/// a relationship (dependence) between the variables.
+///
+/// # Arguments
+///
+/// * `alpha` - the significance level, i.e. the probability of rejecting a true null hypothesis;
+/// in other words, the risk of concluding that an association between the variables exists
+/// when there is no actual association.
+
 pub struct ChiSquare {
     alpha: f64,
 }
@@ -30,8 +43,21 @@ impl ChiSquare {
     pub fn new(alpha: f64) -> ChiSquare {
         ChiSquare { alpha }
     }
-    // Restituisce true quando le matrici sono molto simili, quindi indipendenti
-    // false quando sono diverse, quindi dipendenti
+
+    /// Compare two matrices extracted from two 3rd-order tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `i` - Position of the matrix of `M1` to compare with `M2`.
+    /// * `M1` - 3rd-order tensor 1.
+    /// * `j` - Position of the matrix of `M2` to compare with `M1`.
+    /// * `M2` - 3rd-order tensor 2.
+    ///
+    /// # Returns
+    ///
+    /// * `true` - when the matrices `M1` and `M2` are very similar, hence **independent**.
+    /// * `false` - when the matrices `M1` and `M2` are too different, hence **dependent**.
+
     pub fn compare_matrices(
         &self,
         i: usize,
@@ -71,7 +97,7 @@
             let n = K.len();
             K.into_shape((n, 1)).unwrap()
         };
-        println!("K: {:?}", K);
+        //println!("K: {:?}", K);
         let L = 1.0 / &K;
         // =====                   2
         // \     (K . M  - L . M)
@@ -82,18 +108,18 @@
         //  x'ϵVal /X  \
         //          \ i/
         let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
-        println!("M1: {:?}", M1);
-        println!("M2: {:?}", M2);
-        println!("L*M1: {:?}", (L * &M1));
-        println!("K*M2: {:?}", (K * &M2));
-        println!("X_2: {:?}", X_2);
+        //println!("M1: {:?}", M1);
+        //println!("M2: {:?}", M2);
+        //println!("L*M1: {:?}", (L * &M1));
+        //println!("K*M2: {:?}", (K * &M2));
+        //println!("X_2: {:?}", X_2);
         X_2.diag_mut().fill(0.0);
         let X_2 = X_2.sum_axis(Axis(1));
         let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
-        println!("CHI^2: {:?}", n);
-        println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
+        //println!("CHI^2: {:?}", n);
+        //println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
         let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
-        println!("test: {:?}", ret);
+        //println!("test: {:?}", ret);
         ret
     }
 }
diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs
index cc8541a..9e329eb 100644
--- a/reCTBN/src/structure_learning/score_based_algorithm.rs
+++ b/reCTBN/src/structure_learning/score_based_algorithm.rs
@@ -1,3 +1,5 @@
+//! Module containing score-based algorithms like Hill Climbing and Tabu Search.
+
 use std::collections::BTreeSet;
 
 use crate::structure_learning::score_function::ScoreFunction;
diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs
index b3b1597..cb6ad7b 100644
--- a/reCTBN/src/structure_learning/score_function.rs
+++ b/reCTBN/src/structure_learning/score_function.rs
@@ -1,3 +1,5 @@
+//! Module for score-based algorithms containing score functions like Log Likelihood, BIC, etc.
+
 use std::collections::BTreeSet;
 
 use ndarray::prelude::*;
diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs
index 70bbf76..aa48883 100644
--- a/reCTBN/src/tools.rs
+++ b/reCTBN/src/tools.rs
@@ -1,3 +1,5 @@
+//! Contains methods commonly used across the crate.
+
 use ndarray::prelude::*;
 
 use crate::sampling::{ForwardSampler, Sampler};
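
For reference, the chi-squared comparison documented above for `ChiSquare::compare_matrices` boils down to the statistic sketched below. This is an illustrative, self-contained sketch and not the crate's API: the function name `matrices_look_independent` is made up for the example, the definition of the scaling factor `k` from the row totals is an assumption (the hunks above do not show how `K` is computed), and a recent `ndarray`/`statrs` (co-broadcasting and the `ContinuousCDF` trait) is assumed.

```rust
// Illustrative sketch only; mirrors the documented test, not the crate's internals.
use ndarray::prelude::*;
use statrs::distribution::{ChiSquared, ContinuousCDF};

/// Returns `true` when the two count matrices are similar enough that
/// independence is not rejected at significance level `alpha`.
fn matrices_look_independent(m1: &Array2<f64>, m2: &Array2<f64>, alpha: f64) -> bool {
    // Per-row scaling factors (assumed): k = sqrt(n1 / n2), l = 1 / k,
    // where n1 and n2 are the row totals of M1 and M2.
    let k = (m1.sum_axis(Axis(1)) / m2.sum_axis(Axis(1)))
        .mapv(f64::sqrt)
        .insert_axis(Axis(1));
    let l = 1.0 / &k;

    // X²_x = Σ_{x'} (k·M2[x,x'] − l·M1[x,x'])² / (M2[x,x'] + M1[x,x'])
    let mut x2 = (&k * m2 - &l * m1).mapv(|v| v.powi(2)) / (m2 + m1);
    x2.diag_mut().fill(0.0);
    let x2 = x2.sum_axis(Axis(1));

    // Independence is kept only if every row statistic stays below the
    // (1 − alpha) quantile of a χ² with (number of states − 1) degrees of freedom.
    let dist = ChiSquared::new((x2.len() - 1) as f64).unwrap();
    x2.into_iter().all(|v| dist.cdf(v) < 1.0 - alpha)
}

fn main() {
    // Two toy transition-count matrices over a 3-state domain; the second is a
    // scaled copy of the first, so the test should report independence.
    let m1 = arr2(&[[0.0, 6.0, 4.0], [5.0, 0.0, 5.0], [3.0, 7.0, 0.0]]);
    let m2 = arr2(&[[0.0, 12.0, 8.0], [10.0, 0.0, 10.0], [6.0, 14.0, 0.0]]);
    println!("independent: {}", matrices_look_independent(&m1, &m2, 0.1));
}
```

Here `m1` and `m2` play the role of two matrices extracted from the 3rd-order tensors `M1` and `M2` at positions `i` and `j`, as described in the doc comment above.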