Merge branch '66-feature-add-a-first-wave-of-docstrings-to-core-parts-of-the-library' into 'dev'

Add a first wave of docstrings to core parts of the library
pull/70/head
Meliurwen 2 years ago
commit ce139afdb6
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. README.md (8)
  2. reCTBN/src/ctbn.rs (24)
  3. reCTBN/src/lib.rs (1)
  4. reCTBN/src/network.rs (90)
  5. reCTBN/src/parameter_learning.rs (2)
  6. reCTBN/src/params.rs (51)
  7. reCTBN/src/sampling.rs (2)
  8. reCTBN/src/structure_learning.rs (2)
  9. reCTBN/src/structure_learning/constraint_based_algorithm.rs (2)
  10. reCTBN/src/structure_learning/hypothesis_test.rs (48)
  11. reCTBN/src/structure_learning/score_based_algorithm.rs (2)
  12. reCTBN/src/structure_learning/score_function.rs (2)
  13. reCTBN/src/tools.rs (2)

@@ -56,3 +56,11 @@ To check the **formatting**:
```sh
cargo fmt --all -- --check
```
+## Documentation
+To generate the **documentation**:
+```sh
+cargo rustdoc --package reCTBN --open -- --default-theme=ayu
+```

@@ -1,3 +1,5 @@
+//! Continuous Time Bayesian Network
use std::collections::BTreeSet;
use ndarray::prelude::*;
@@ -5,16 +7,18 @@ use ndarray::prelude::*;
use crate::network;
use crate::params::{Params, ParamsTrait, StateType};
-///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
-///composed by the following elements:
-///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
-///- **nodes**: a vector containing all the nodes and their parameters.
-///The index of a node inside the vector is also used as index for the adj_matrix.
-///
-///# Examples
-///
-///```
+/// It represents both the structure and the parameters of a CTBN.
+///
+/// # Arguments
+///
+/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
+/// * `nodes` - A vector containing all the nodes and their parameters.
+///
+/// The index of a node inside the vector is also used as index for the `adj_matrix`.
+///
+/// # Example
+///
+/// ```rust
/// use std::collections::BTreeSet;
/// use reCTBN::network::Network;
/// use reCTBN::params;
@@ -66,12 +70,14 @@ impl CtbnNetwork {
}
impl network::Network for CtbnNetwork {
+    /// Initialize an Adjacency matrix.
    fn initialize_adj_matrix(&mut self) {
        self.adj_matrix = Some(Array2::<u16>::zeros(
            (self.nodes.len(), self.nodes.len()).f(),
        ));
    }
+    /// Add a new node.
    fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
        n.reset_params();
        self.adj_matrix = Option::None;
@@ -79,6 +85,7 @@ impl network::Network for CtbnNetwork {
        Ok(self.nodes.len() - 1)
    }
+    /// Connect two nodes with a new edge.
    fn add_edge(&mut self, parent: usize, child: usize) {
        if let None = self.adj_matrix {
            self.initialize_adj_matrix();
@@ -94,6 +101,7 @@ impl network::Network for CtbnNetwork {
        0..self.nodes.len()
    }
+    /// Get the number of nodes of the network.
    fn get_number_of_nodes(&self) -> usize {
        self.nodes.len()
    }
@@ -138,6 +146,7 @@ impl network::Network for CtbnNetwork {
            .0
    }
+    /// Get all the parents of the given node.
    fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
        self.adj_matrix
            .as_ref()
@@ -149,6 +158,7 @@ impl network::Network for CtbnNetwork {
            .collect()
    }
+    /// Get all the children of the given node.
    fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
        self.adj_matrix
            .as_ref()
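The doc example added in the hunk above is truncated here. For illustration only, a minimal sketch of how the `CtbnNetwork` API described by these docstrings might be exercised; the `CtbnNetwork::new()` and `DiscreteStatesContinousTimeParams::new(label, domain)` constructors are assumptions, as neither appears in this diff:

```rust
use std::collections::BTreeSet;
use reCTBN::ctbn::CtbnNetwork;
use reCTBN::network::Network;
use reCTBN::params;

// Two binary variables; the constructors used below are assumptions not shown in this diff.
let domain = BTreeSet::from(["A".to_string(), "B".to_string()]);
let n1 = params::Params::DiscreteStatesContinousTime(
    params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain.clone()),
);
let n2 = params::Params::DiscreteStatesContinousTime(
    params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain),
);

let mut net = CtbnNetwork::new();
let n1 = net.add_node(n1).unwrap(); // returns the index of the inserted node
let n2 = net.add_node(n2).unwrap();
net.add_edge(n1, n2); // n1 becomes a parent of n2

assert_eq!(BTreeSet::from([n1]), net.get_parent_set(n2));
assert_eq!(BTreeSet::from([n2]), net.get_children_set(n1));
```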

@@ -1,3 +1,4 @@
+#![doc = include_str!("../../README.md")]
#![allow(non_snake_case)]
#[cfg(test)]
extern crate approx;

@@ -1,3 +1,5 @@
+//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
use std::collections::BTreeSet;
use thiserror::Error;
@@ -11,33 +13,103 @@ pub enum NetworkError {
    NodeInsertionError(String),
}
-///Network
-///The Network trait define the required methods for a structure used as pgm (such as ctbn).
+/// It defines the required methods for a structure used as a Probabilistic Graphical Model (such
+/// as a CTBN).
pub trait Network {
    fn initialize_adj_matrix(&mut self);
    fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
+    /// Add a **directed edge** between two nodes of the network.
+    ///
+    /// # Arguments
+    ///
+    /// * `parent` - parent node.
+    /// * `child` - child node.
    fn add_edge(&mut self, parent: usize, child: usize);
-    ///Get all the indices of the nodes contained inside the network
+    /// Get all the indices of the nodes contained inside the network.
    fn get_node_indices(&self) -> std::ops::Range<usize>;
+    /// Get the number of nodes contained in the network.
    fn get_number_of_nodes(&self) -> usize;
+    /// Get the **node param**.
+    ///
+    /// # Arguments
+    ///
+    /// * `node_idx` - node index value.
+    ///
+    /// # Return
+    ///
+    /// * The selected **node param**.
    fn get_node(&self, node_idx: usize) -> &params::Params;
+    /// Get the mutable **node param**.
+    ///
+    /// # Arguments
+    ///
+    /// * `node_idx` - node index value.
+    ///
+    /// # Return
+    ///
+    /// * The selected mutable **node param**.
    fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
-    ///Compute the index that must be used to access the parameters of a node given a specific
-    ///configuration of the network. Usually, the only values really used in *current_state* are
-    ///the ones in the parent set of the *node*.
+    /// Compute the index that must be used to access the parameters of a `node`, given a specific
+    /// configuration of the network.
+    ///
+    /// Usually, the only values really used in `current_state` are the ones in the parent set of
+    /// the `node`.
+    ///
+    /// # Arguments
+    ///
+    /// * `node` - selected node.
+    /// * `current_state` - current configuration of the network.
+    ///
+    /// # Return
+    ///
+    /// * Index of the parameters of the `node`, given the current network configuration.
    fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
        -> usize;
-    ///Compute the index that must be used to access the parameters of a node given a specific
-    ///configuration of the network and a generic parent_set. Usually, the only values really used
-    ///in *current_state* are the ones in the parent set of the *node*.
+    /// Compute the index that must be used to access the parameters of a `node`, given a specific
+    /// configuration of the network and a generic `parent_set`.
+    ///
+    /// Usually, the only values really used in `current_state` are the ones in the parent set of
+    /// the `node`.
+    ///
+    /// # Arguments
+    ///
+    /// * `current_state` - current configuration of the network.
+    /// * `parent_set` - parent set of the selected `node`.
+    ///
+    /// # Return
+    ///
+    /// * Index of the parameters of the `node`, given the current network configuration and the
+    ///   given `parent_set`.
    fn get_param_index_from_custom_parent_set(
        &self,
        current_state: &Vec<params::StateType>,
        parent_set: &BTreeSet<usize>,
    ) -> usize;
+    /// Get the **parent set** of a given **node**.
+    ///
+    /// # Arguments
+    ///
+    /// * `node` - node index value.
+    ///
+    /// # Return
+    ///
+    /// * The **parent set** of the selected node.
    fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
+    /// Get the **children set** of a given **node**.
+    ///
+    /// # Arguments
+    ///
+    /// * `node` - node index value.
+    ///
+    /// # Return
+    ///
+    /// * The **children set** of the selected node.
    fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}
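The parameter-index methods documented above typically boil down to a mixed-radix encoding of the parent states. The following is a standalone sketch of that idea, not the reCTBN implementation; `param_index`, its arguments, and the example cardinalities are hypothetical:

```rust
// Standalone sketch: a parameter index for a node is usually a mixed-radix
// encoding of the current states of its parents.
fn param_index(parent_states: &[usize], parent_cardinalities: &[usize]) -> usize {
    assert_eq!(parent_states.len(), parent_cardinalities.len());
    parent_states
        .iter()
        .zip(parent_cardinalities)
        .fold((0usize, 1usize), |(idx, multiplier), (&state, &card)| {
            // Each parent contributes state * (product of the cardinalities seen so far).
            (idx + state * multiplier, multiplier * card)
        })
        .0
}

// Two parents: a binary one in state 1 and a ternary one in state 2.
// Index = 1 * 1 + 2 * 2 = 5, out of 2 * 3 = 6 possible parent configurations.
assert_eq!(param_index(&[1, 2], &[2, 3]), 5);
```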

@@ -1,3 +1,5 @@
+//! Module containing methods used to learn the parameters.
use std::collections::BTreeSet;
use ndarray::prelude::*;

@@ -1,3 +1,5 @@
+//! Module containing methods to define different types of nodes.
use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch;
@@ -23,9 +25,8 @@ pub enum StateType {
    Discrete(usize),
}
-/// Parameters
-/// The Params trait is the core element for building different types of nodes. The goal is to
-/// define the set of method required to describes a generic node.
+/// This is a core element for building different types of nodes; the goal is to define the set of
+/// methods required to describe a generic node.
#[enum_dispatch(Params)]
pub trait ParamsTrait {
    fn reset_params(&mut self);
@@ -65,25 +66,27 @@ pub trait ParamsTrait {
    fn get_label(&self) -> &String;
}
-/// The Params enum is the core element for building different types of nodes. The goal is to
-/// define all the supported type of Parameters
+/// A core element for building different types of nodes; the goal is to define all the supported
+/// types of parameters.
#[derive(Clone)]
#[enum_dispatch]
pub enum Params {
    DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
-/// DiscreteStatesContinousTime.
-/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
-/// following elements:
-/// - **domain**: an ordered and exhaustive set of possible states
-/// - **cim**: Conditional Intensity Matrix
-/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
-///   learning task and are composed by:
-///   - **transitions**: number of transitions from one state to another given a specific
-///     realization of the parent set
-///   - **residence_time**: permanence time in each possible states given a specific
-///     realization of the parent set
+/// This represents the parameters of a classical discrete node for CTBNs and it is composed of the
+/// following elements.
+///
+/// # Arguments
+///
+/// * `label` - node's variable name.
+/// * `domain` - an ordered and exhaustive set of possible states.
+/// * `cim` - Conditional Intensity Matrix.
+/// * `transitions` - number of transitions from one state to another given a specific realization
+///   of the parent set; a sufficient statistic, mainly used during the parameter learning task.
+/// * `residence_time` - residence time in each possible state, given a specific realization of the
+///   parent set; a sufficient statistic, mainly used during the parameter learning task.
#[derive(Clone)]
pub struct DiscreteStatesContinousTimeParams {
    label: String,
@@ -109,10 +112,12 @@ impl DiscreteStatesContinousTimeParams {
        &self.cim
    }
-    ///Setter function for CIM.\\
-    ///This function check if the cim is valid using the validate_params method.
-    ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(())
-    ///- **Invalid cim inserted**: it replace the self.cim value with None and it return ParamsError
+    /// Setter function for CIM.
+    ///
+    /// This function checks if the CIM is valid using the [`validate_params`](self::ParamsTrait::validate_params) method:
+    /// * **Valid CIM inserted** - it substitutes the CIM in `self.cim` and returns `Ok(())`.
+    /// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns
+    ///   `ParamsError`.
    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
        self.cim = Some(cim);
        match self.validate_params() {
@@ -129,22 +134,22 @@ impl DiscreteStatesContinousTimeParams {
        self.cim = Some(cim);
    }
-    ///Getter function for transitions
+    /// Getter function for transitions.
    pub fn get_transitions(&self) -> &Option<Array3<usize>> {
        &self.transitions
    }
-    ///Setter function for transitions
+    /// Setter function for transitions.
    pub fn set_transitions(&mut self, transitions: Array3<usize>) {
        self.transitions = Some(transitions);
    }
-    ///Getter function for residence_time
+    /// Getter function for residence_time.
    pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
        &self.residence_time
    }
-    ///Setter function for residence_time
+    /// Setter function for residence_time.
    pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
        self.residence_time = Some(residence_time);
    }

@@ -1,3 +1,5 @@
+//! Module containing methods for the sampling.
use crate::{
    network::Network,
    params::{self, ParamsTrait},

@@ -1,3 +1,5 @@
+//! Learn the structure of the network.
pub mod constraint_based_algorithm;
pub mod hypothesis_test;
pub mod score_based_algorithm;

@@ -1,3 +1,5 @@
+//! Module containing constraint based algorithms like CTPC and Hiton.
//pub struct CTPC {
//
//}

@@ -1,3 +1,5 @@
+//! Module for constraint based algorithms containing hypothesis test algorithms like chi-squared test, F test, etc...
use std::collections::BTreeSet;
use ndarray::{Array3, Axis};
@@ -20,6 +22,17 @@ pub trait HypothesisTest {
        P: parameter_learning::ParameterLearning;
}
+/// Performs the chi-squared test (χ2 test).
+///
+/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
+/// a relationship (dependence) between the variables.
+///
+/// # Arguments
+///
+/// * `alpha` - the significance level, i.e. the probability of rejecting a true null hypothesis;
+///   in other words, the risk of concluding that an association between the variables exists
+///   when there is no actual association.
pub struct ChiSquare {
    alpha: f64,
}
@@ -30,8 +43,21 @@ impl ChiSquare {
    pub fn new(alpha: f64) -> ChiSquare {
        ChiSquare { alpha }
    }
-    // Returns true when the matrices are very similar, hence independent;
-    // false when they are different, hence dependent
+    /// Compare two matrices extracted from two 3rd-order tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `i` - Position of the matrix of `M1` to compare with `M2`.
+    /// * `M1` - 3rd-order tensor 1.
+    /// * `j` - Position of the matrix of `M2` to compare with `M1`.
+    /// * `M2` - 3rd-order tensor 2.
+    ///
+    /// # Returns
+    ///
+    /// * `true` - when the matrices `M1` and `M2` are very similar, hence **independent**.
+    /// * `false` - when the matrices `M1` and `M2` are too different, hence **dependent**.
    pub fn compare_matrices(
        &self,
        i: usize,
@@ -71,7 +97,7 @@ impl ChiSquare {
            let n = K.len();
            K.into_shape((n, 1)).unwrap()
        };
-        println!("K: {:?}", K);
+        //println!("K: {:?}", K);
        let L = 1.0 / &K;
        // X_2 = sum over x' in Val(X_i) of (K * M2 - L * M1)^2 / (M2 + M1)
@@ -82,18 +108,18 @@ impl ChiSquare {
        let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
-        println!("M1: {:?}", M1);
-        println!("M2: {:?}", M2);
-        println!("L*M1: {:?}", (L * &M1));
-        println!("K*M2: {:?}", (K * &M2));
-        println!("X_2: {:?}", X_2);
+        //println!("M1: {:?}", M1);
+        //println!("M2: {:?}", M2);
+        //println!("L*M1: {:?}", (L * &M1));
+        //println!("K*M2: {:?}", (K * &M2));
+        //println!("X_2: {:?}", X_2);
        X_2.diag_mut().fill(0.0);
        let X_2 = X_2.sum_axis(Axis(1));
        let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
-        println!("CHI^2: {:?}", n);
-        println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
+        //println!("CHI^2: {:?}", n);
+        //println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
        let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
-        println!("test: {:?}", ret);
+        //println!("test: {:?}", ret);
        ret
    }
}
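A rough usage sketch of the test documented above. The `&Array3<usize>` argument types of `compare_matrices` and the sample counts are assumptions, since the full signature is truncated in this diff:

```rust
use ndarray::arr3;
use reCTBN::structure_learning::hypothesis_test::ChiSquare;

// Significance level alpha = 0.05.
let chi_sq = ChiSquare::new(0.05);

// Two 3rd-order tensors of transition counts (hypothetical data); the
// &Array3<usize> argument type is an assumption, not shown in this diff.
let m1 = arr3(&[[[0usize, 12, 9], [10, 0, 11], [8, 13, 0]]]);
let m2 = arr3(&[[[0usize, 11, 10], [11, 0, 10], [9, 12, 0]]]);

// true  -> the counts are compatible, the variables look independent;
// false -> the counts differ significantly, the variables look dependent.
let independent = chi_sq.compare_matrices(0, &m1, 0, &m2);
println!("independent: {}", independent);
```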

@@ -1,3 +1,5 @@
+//! Module containing score based algorithms like Hill Climbing and Tabu Search.
use std::collections::BTreeSet;
use crate::structure_learning::score_function::ScoreFunction;

@@ -1,3 +1,5 @@
+//! Module for score based algorithms containing score function algorithms like Log Likelihood, BIC, etc...
use std::collections::BTreeSet;
use ndarray::prelude::*;

@@ -1,3 +1,5 @@
+//! Contains commonly used methods shared across the crate.
use ndarray::prelude::*;
use crate::sampling::{ForwardSampler, Sampler};
