Merge branch '66-feature-add-a-first-wave-of-docstrings-to-core-parts-of-the-library' into 'dev'

Add a first wave of docstrings to core parts of the library
pull/70/head
Meliurwen 2 years ago
commit ce139afdb6
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. README.md (8 lines changed)
  2. reCTBN/src/ctbn.rs (24 lines changed)
  3. reCTBN/src/lib.rs (1 line changed)
  4. reCTBN/src/network.rs (90 lines changed)
  5. reCTBN/src/parameter_learning.rs (2 lines changed)
  6. reCTBN/src/params.rs (55 lines changed)
  7. reCTBN/src/sampling.rs (2 lines changed)
  8. reCTBN/src/structure_learning.rs (2 lines changed)
  9. reCTBN/src/structure_learning/constraint_based_algorithm.rs (2 lines changed)
  10. reCTBN/src/structure_learning/hypothesis_test.rs (48 lines changed)
  11. reCTBN/src/structure_learning/score_based_algorithm.rs (2 lines changed)
  12. reCTBN/src/structure_learning/score_function.rs (2 lines changed)
  13. reCTBN/src/tools.rs (2 lines changed)

@ -56,3 +56,11 @@ To check the **formatting**:
```sh
cargo fmt --all -- --check
```
## Documentation
To generate the **documentation**:
```sh
cargo rustdoc --package reCTBN --open -- --default-theme=ayu
```

@ -1,3 +1,5 @@
//! Continuous Time Bayesian Network
use std::collections::BTreeSet;
use ndarray::prelude::*;
@ -5,16 +7,18 @@ use ndarray::prelude::*;
use crate::network;
use crate::params::{Params, ParamsTrait, StateType};
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
///composed by the following elements:
///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
///- **nodes**: a vector containing all the nodes and their parameters.
///The index of a node inside the vector is also used as index for the adj_matrix.
/// It represents both the structure and the parameters of a CTBN.
///
/// # Arguments
///
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
/// * `nodes` - A vector containing all the nodes and their parameters.
///
///# Examples
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
///
///```
/// # Example
///
/// ```rust
/// use std::collections::BTreeSet;
/// use reCTBN::network::Network;
/// use reCTBN::params;
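The new doc example above is cut off by the diff view. As a rough sketch, not part of this commit, of how a `CtbnNetwork` could be assembled using only the `Network` methods documented here; the `CtbnNetwork::new()` constructor and the `DiscreteStatesContinousTimeParams::new(label, domain)` signature are assumptions that do not appear in this diff:

```rust
use std::collections::BTreeSet;

use reCTBN::ctbn::CtbnNetwork;
use reCTBN::network::Network;
use reCTBN::params::{DiscreteStatesContinousTimeParams, Params};

// Hypothetical helper: the `DiscreteStatesContinousTimeParams::new(label, domain)`
// constructor signature is assumed, it is not shown in this diff.
fn make_node(label: &str, states: &[&str]) -> Params {
    let domain: BTreeSet<String> = states.iter().map(|s| s.to_string()).collect();
    Params::DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams::new(
        label.to_string(),
        domain,
    ))
}

fn main() {
    // `CtbnNetwork::new()` is also assumed; only the `Network` trait methods
    // used below are documented in this commit.
    let mut net = CtbnNetwork::new();
    let x1 = net.add_node(make_node("X1", &["A", "B"])).unwrap();
    let x2 = net.add_node(make_node("X2", &["A", "B"])).unwrap();
    net.add_edge(x1, x2); // X1 becomes a parent of X2
    assert_eq!(net.get_number_of_nodes(), 2);
    assert!(net.get_parent_set(x2).contains(&x1));
}
```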
@ -66,12 +70,14 @@ impl CtbnNetwork {
}
impl network::Network for CtbnNetwork {
/// Initialize the adjacency matrix.
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
}
/// Add a new node.
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
@ -79,6 +85,7 @@ impl network::Network for CtbnNetwork {
Ok(self.nodes.len() - 1)
}
/// Connect two nodes with a new edge.
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
@ -94,6 +101,7 @@ impl network::Network for CtbnNetwork {
0..self.nodes.len()
}
/// Get the number of nodes of the network.
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
@ -138,6 +146,7 @@ impl network::Network for CtbnNetwork {
.0
}
/// Get all the parents of the given node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
@ -149,6 +158,7 @@ impl network::Network for CtbnNetwork {
.collect()
}
/// Get all the children of the given node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()

@ -1,3 +1,4 @@
#![doc = include_str!("../../README.md")]
#![allow(non_snake_case)]
#[cfg(test)]
extern crate approx;

@ -1,3 +1,5 @@
//! Defines methods for dealing with Probabilistic Graphical Models like CTBNs.
use std::collections::BTreeSet;
use thiserror::Error;
@ -11,33 +13,103 @@ pub enum NetworkError {
NodeInsertionError(String),
}
///Network
///The Network trait define the required methods for a structure used as pgm (such as ctbn).
/// It defines the required methods for a structure used as a Probabilistic Graphical Model (such
/// as a CTBN).
pub trait Network {
fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add a **directed edge** between two nodes of the network.
///
/// # Arguments
///
/// * `parent` - parent node.
/// * `child` - child node.
fn add_edge(&mut self, parent: usize, child: usize);
///Get all the indices of the nodes contained inside the network
/// Get all the indices of the nodes contained inside the network.
fn get_node_indices(&self) -> std::ops::Range<usize>;
/// Get the number of nodes contained in the network.
fn get_number_of_nodes(&self) -> usize;
/// Get the **node param**.
///
/// # Arguments
///
/// * `node_idx` - node index value.
///
/// # Return
///
/// * The selected **node param**.
fn get_node(&self, node_idx: usize) -> &params::Params;
/// Get the mutable **node param**.
///
/// # Arguments
///
/// * `node_idx` - node index value.
///
/// # Return
///
/// * The selected mutable **node param**.
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are
///the ones in the parent set of the *node*.
/// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network.
///
/// Usually, the only values really used in `current_state` are the ones in the parent set of
/// the `node`.
///
/// # Arguments
///
/// * `node` - selected node.
/// * `current_state` - current configuration of the network.
///
/// # Return
///
/// * The index to access the parameters of the `node` in the given network configuration.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
-> usize;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network and a generic parent_set. Usually, the only values really used
///in *current_state* are the ones in the parent set of the *node*.
/// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network and a generic `parent_set`.
///
/// Usually, the only values really used in `current_state` are the ones in the parent set of
/// the `node`.
///
/// # Arguments
///
/// * `current_state` - current configuration of the network.
/// * `parent_set` - parent set of the selected `node`.
///
/// # Return
///
/// * The index to access the parameters of the `node` in the given configuration and `parent_set`.
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<params::StateType>,
parent_set: &BTreeSet<usize>,
) -> usize;
/// Get the **parent set** of a given **node**.
///
/// # Arguments
///
/// * `node` - node index value.
///
/// # Return
///
/// * The **parent set** of the selected node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
/// Get the **children set** of a given **node**.
///
/// # Arguments
///
/// * `node` - node index value.
///
/// # Return
///
/// * The **children set** of the selected node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}
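The two `get_param_index_*` methods above only declare their contract; their implementations are not part of this diff. The following is a free-standing sketch of the kind of mixed-radix encoding such an index typically corresponds to, not the crate's actual code: plain `usize` states stand in for `params::StateType`, and the function name is illustrative.

```rust
use std::collections::BTreeSet;

// Illustrative only: encode the states of the parents as a single mixed-radix
// index. `current_state[i]` is the state of node `i`, `cardinality[i]` its
// number of states; nodes outside `parent_set` are ignored.
fn param_index_from_parent_set(
    current_state: &[usize],
    cardinality: &[usize],
    parent_set: &BTreeSet<usize>,
) -> usize {
    let mut index = 0;
    let mut multiplier = 1;
    for &parent in parent_set {
        index += current_state[parent] * multiplier;
        multiplier *= cardinality[parent];
    }
    index
}

fn main() {
    // Parents {0, 2}, both binary: configuration (X0 = 1, X2 = 0) maps to
    // index 1, configuration (X0 = 1, X2 = 1) maps to index 3.
    let parent_set: BTreeSet<usize> = [0, 2].into_iter().collect();
    assert_eq!(param_index_from_parent_set(&[1, 0, 0], &[2, 3, 2], &parent_set), 1);
    assert_eq!(param_index_from_parent_set(&[1, 0, 1], &[2, 3, 2], &parent_set), 3);
}
```

Whether the crate orders parents this way or in reverse is not visible in this diff; the sketch only illustrates why values of `current_state` outside the parent set do not affect the result.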

@ -1,3 +1,5 @@
//! Module containing methods used to learn the parameters.
use std::collections::BTreeSet;
use ndarray::prelude::*;

@ -1,3 +1,5 @@
//! Module containing methods to define different types of nodes.
use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch;
@ -23,9 +25,8 @@ pub enum StateType {
Discrete(usize),
}
/// Parameters
/// The Params trait is the core element for building different types of nodes. The goal is to
/// define the set of method required to describes a generic node.
/// This is a core element for building different types of nodes; the goal is to define the set of
/// methods required to describe a generic node.
#[enum_dispatch(Params)]
pub trait ParamsTrait {
fn reset_params(&mut self);
@ -65,25 +66,27 @@ pub trait ParamsTrait {
fn get_label(&self) -> &String;
}
/// The Params enum is the core element for building different types of nodes. The goal is to
/// define all the supported type of Parameters
/// It is a core element for building different types of nodes; the goal is to define all the
/// supported types of parameters.
#[derive(Clone)]
#[enum_dispatch]
pub enum Params {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
/// DiscreteStatesContinousTime.
/// This represents the parameters of a classical discrete node for a CTBN and it is composed of the
/// following elements:
/// - **domain**: an ordered and exhaustive set of possible states
/// - **cim**: Conditional Intensity Matrix
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
/// learning task and are composed by:
/// - **transitions**: number of transitions from one state to another given a specific
/// realization of the parent set
/// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set
/// following elements.
///
/// # Arguments
///
/// * `label` - node's variable name.
/// * `domain` - an ordered and exhaustive set of possible states.
/// * `cim` - Conditional Intensity Matrix.
/// * `transitions` - number of transitions from one state to another given a specific realization
///   of the parent set; it is a sufficient statistic, mainly used during the parameter learning
///   task.
/// * `residence_time` - residence time in each possible state, given a specific realization of the
///   parent set; it is a sufficient statistic, mainly used during the parameter learning task.
#[derive(Clone)]
pub struct DiscreteStatesContinousTimeParams {
label: String,
@ -104,15 +107,17 @@ impl DiscreteStatesContinousTimeParams {
}
}
///Getter function for CIM
/// Getter function for CIM.
pub fn get_cim(&self) -> &Option<Array3<f64>> {
&self.cim
}
///Setter function for CIM.\\
///This function check if the cim is valid using the validate_params method.
///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(())
///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError
/// Setter function for CIM.
///
/// This function checks if the CIM is valid using the [`validate_params`](self::ParamsTrait::validate_params) method:
/// * **Valid CIM inserted** - it substitutes the CIM in `self.cim` and returns `Ok(())`.
/// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns
/// `ParamsError`.
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
self.cim = Some(cim);
match self.validate_params() {
@ -124,27 +129,27 @@ impl DiscreteStatesContinousTimeParams {
}
}
///Unchecked version of the setter function for CIM.
/// Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
self.cim = Some(cim);
}
///Getter function for transitions
/// Getter function for transitions.
pub fn get_transitions(&self) -> &Option<Array3<usize>> {
&self.transitions
}
///Setter function for transitions
/// Setter function for transitions.
pub fn set_transitions(&mut self, transitions: Array3<usize>) {
self.transitions = Some(transitions);
}
///Getter function for residence_time
/// Getter function for residence_time.
pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
&self.residence_time
}
///Setter function for residence_time
/// Setter function for residence_time.
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
self.residence_time = Some(residence_time);
}
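A hedged usage sketch of the getters and setters above, exercising the `set_cim` validation behaviour described in the docstring. The `new(label, domain)` constructor signature and the (parent configurations × states × states) layout of the CIM tensor are assumptions not shown in this diff.

```rust
use std::collections::BTreeSet;

use ndarray::prelude::*;
use reCTBN::params::DiscreteStatesContinousTimeParams;

fn main() {
    let domain: BTreeSet<String> = ["A", "B"].iter().map(|s| s.to_string()).collect();
    // Constructor signature assumed (not shown in this diff).
    let mut node = DiscreteStatesContinousTimeParams::new(String::from("X1"), domain);

    // A well-formed CIM: non-negative off-diagonal rates, each row summing to zero.
    let cim = array![[[-3.0, 3.0], [2.0, -2.0]]];
    assert!(node.set_cim(cim).is_ok());
    assert!(node.get_cim().is_some());

    // Per the docstring above, an invalid CIM (rows not summing to zero) makes
    // `set_cim` return a `ParamsError` and clears `self.cim`.
    let bad_cim = array![[[-3.0, 1.0], [2.0, -2.0]]];
    assert!(node.set_cim(bad_cim).is_err());
    assert!(node.get_cim().is_none());
}
```

`set_cim_unchecked`, `set_transitions` and `set_residence_time` skip this validation and simply store the given arrays.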

@ -1,3 +1,5 @@
//! Module containing methods for sampling.
use crate::{
network::Network,
params::{self, ParamsTrait},

@ -1,3 +1,5 @@
//! Learn the structure of the network.
pub mod constraint_based_algorithm;
pub mod hypothesis_test;
pub mod score_based_algorithm;

@ -1,3 +1,5 @@
//! Module containing constraint based algorithms like CTPC and Hiton.
//pub struct CTPC {
//
//}

@ -1,3 +1,5 @@
//! Module for constraint based algorithms, containing hypothesis tests like the chi-squared test, F test, etc...
use std::collections::BTreeSet;
use ndarray::{Array3, Axis};
@ -20,6 +22,17 @@ pub trait HypothesisTest {
P: parameter_learning::ParameterLearning;
}
/// Performs the chi-squared test (χ² test).
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - the significance level, i.e. the probability of rejecting a true null hypothesis;
///   in other words, the risk of concluding that an association between the variables exists
///   when there is no actual association.
pub struct ChiSquare {
alpha: f64,
}
@ -30,8 +43,21 @@ impl ChiSquare {
pub fn new(alpha: f64) -> ChiSquare {
ChiSquare { alpha }
}
// Returns true when the matrices are very similar, hence independent,
// false when they are different, hence dependent
/// Compare two matrices extracted from two 3rd-order tensors.
///
/// # Arguments
///
/// * `i` - Position of the matrix in `M1` to compare with `M2`.
/// * `M1` - 3rd-order tensor 1.
/// * `j` - Position of the matrix in `M2` to compare with `M1`.
/// * `M2` - 3rd-order tensor 2.
///
/// # Returns
///
/// * `true` - when the compared matrices of `M1` and `M2` are very similar, hence **independent**.
/// * `false` - when the compared matrices of `M1` and `M2` are too different, hence **dependent**.
pub fn compare_matrices(
&self,
i: usize,
@ -71,7 +97,7 @@ impl ChiSquare {
let n = K.len();
K.into_shape((n, 1)).unwrap()
};
println!("K: {:?}", K);
//println!("K: {:?}", K);
let L = 1.0 / &K;
// X² = Σ_{x' ∈ Val(X_i)} (K·M2 - L·M1)² / (M2 + M1)
@ -82,18 +108,18 @@
let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
println!("M1: {:?}", M1);
println!("M2: {:?}", M2);
println!("L*M1: {:?}", (L * &M1));
println!("K*M2: {:?}", (K * &M2));
println!("X_2: {:?}", X_2);
//println!("M1: {:?}", M1);
//println!("M2: {:?}", M2);
//println!("L*M1: {:?}", (L * &M1));
//println!("K*M2: {:?}", (K * &M2));
//println!("X_2: {:?}", X_2);
X_2.diag_mut().fill(0.0);
let X_2 = X_2.sum_axis(Axis(1));
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
println!("CHI^2: {:?}", n);
println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
//println!("CHI^2: {:?}", n);
//println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
println!("test: {:?}", ret);
//println!("test: {:?}", ret);
ret
}
}
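A hedged sketch of the decision rule that closes `compare_matrices` in the hunk above: the caller builds the test with `ChiSquare::new(alpha)`, and independence is accepted only if every per-row statistic X² stays below the (1 - alpha) quantile of a chi-squared distribution with (number of states - 1) degrees of freedom. The helper name and input shapes are illustrative; the `statrs` `ChiSquared`/`ContinuousCDF` imports assume a recent `statrs`, the same distribution API already used above.

```rust
use ndarray::prelude::*;
use statrs::distribution::{ChiSquared, ContinuousCDF};

// Illustrative helper mirroring the final lines of `compare_matrices`:
// accept independence only when every per-state statistic falls below the
// (1 - alpha) quantile of the chi-squared distribution.
fn accept_independence(x_2: &Array1<f64>, alpha: f64) -> bool {
    let dist = ChiSquared::new((x_2.dim() - 1) as f64).unwrap();
    x_2.iter().all(|&x| dist.cdf(x) < (1.0 - alpha))
}

fn main() {
    // Small statistics: the matrices look similar, so independence is reported.
    assert!(accept_independence(&array![0.4, 1.1, 0.7], 0.1));
    // A single large statistic is enough to reject independence.
    assert!(!accept_independence(&array![0.4, 25.0, 0.7], 0.1));
}
```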

@ -1,3 +1,5 @@
//! Module containing score based algorithms like Hill Climbing and Tabu Search.
use std::collections::BTreeSet;
use crate::structure_learning::score_function::ScoreFunction;

@ -1,3 +1,5 @@
//! Module for score based algorithms, containing score functions like Log Likelihood, BIC, etc...
use std::collections::BTreeSet;
use ndarray::prelude::*;

@ -1,3 +1,5 @@
//! Contains methods commonly used across the crate.
use ndarray::prelude::*;
use crate::sampling::{ForwardSampler, Sampler};
