Compare commits

...

5 Commits

Author SHA1 Message Date
AlessandroBregoli 87da2a7c9e Added doctest for parameter_learning 2 years ago
AlessandroBregoli 21ce0ffcb0 Added doc to: params, process, parameter_learning, ctmp, ctbn 2 years ago
Alessandro Bregoli d66173b961 Added doctest for ctmp 2 years ago
Alessandro Bregoli f176dd4fae Added logging to ctmp 2 years ago
AlessandroBregoli bfec2c7c60 Added log to params 2 years ago
  1. 1
      reCTBN/Cargo.toml
  2. 195
      reCTBN/src/parameter_learning.rs
  3. 69
      reCTBN/src/params.rs
  4. 14
      reCTBN/src/process.rs
  5. 8
      reCTBN/src/process/ctbn.rs
  6. 77
      reCTBN/src/process/ctmp.rs

@ -15,6 +15,7 @@ statrs = "~0.16"
rand_chacha = "~0.3"
itertools = "~0.10"
rayon = "~1.6"
log = "~0.4"
[dev-dependencies]
approx = { package = "approx", version = "~0.5" }

@ -7,7 +7,17 @@ use ndarray::prelude::*;
use crate::params::*;
use crate::{process, tools::Dataset};
/// It defines the required methods for learning the `Parameters` from data.
pub trait ParameterLearning: Sync {
/// Fit the parameter of the `node` over a `dataset` given a `parent_set`
///
/// # Arguments
///
/// * `net`: a `NetworkProcess` instance
/// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics
/// * `node`: the node index for which we want to compute the sufficient statistics
/// * `parent_set`: an `Option` containing the parent set used for computing the parameters of
/// `node`. If `None`, the parent set defined in `net` will be used.
fn fit<T: process::NetworkProcess>(
&self,
net: &T,
@ -17,6 +27,19 @@ pub trait ParameterLearning: Sync {
) -> Params;
}
/// Compute the sufficient statistics of the parameters of `node` from a dataset
///
/// # Arguments
///
/// * `net`: a `NetworkProcess` instance
/// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics
/// * `node`: the node index for which we want to compute the sufficient statistics
/// * `parent_set`: the set of nodes (identified by indices) we want to use as parents of `node`
///
/// # Return
///
/// * A tuple containing the number of transitions (`Array3<usize>`) and the residence time
/// (`Array2<f64>`).
pub fn sufficient_statistics<T: process::NetworkProcess>(
net: &T,
dataset: &Dataset,
@ -70,6 +93,89 @@ pub fn sufficient_statistics<T: process::NetworkProcess>(
return (M, T);
}
/// Maximum Likelihood Estimation method for learning the parameters given a dataset.
///
/// `MLE` has no fields: build it with `MLE {}` and call `fit` from the
/// `ParameterLearning` trait, as shown in the example below.
///
/// # Example
/// ```rust
///
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
/// use ndarray::arr3;
/// use reCTBN::tools::*;
/// use reCTBN::parameter_learning::*;
/// use approx::AbsDiffEq;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Add the CIMs for each node
/// match &mut net.get_node_mut(X1) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
/// }
/// }
/// match &mut net.get_node_mut(X2) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(
/// Ok(()),
/// param.set_cim(arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]))
/// );
/// }
/// }
///
/// //Generate a synthetic dataset from net
/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
///
/// //Initialize the `struct MLE`
/// let pl = MLE{};
///
/// // Fit the parameters for X2
/// let p = match pl.fit(&net, &data, X2, None) {
/// params::Params::DiscreteStatesContinousTime(p) => p,
/// };
///
/// // Check the shape of the CIM
/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
///
/// // Check if the learned parameters are close enough to the real ones.
/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
/// &arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]),
/// 0.1
/// ));
/// ```
pub struct MLE {}
impl ParameterLearning for MLE {
@ -114,6 +220,95 @@ impl ParameterLearning for MLE {
}
}
/// Bayesian Approach for learning the parameters given a dataset.
///
/// # Arguments
///
/// * `alpha`: hyperparameter for the prior over the number of transitions.
/// * `tau`: hyperparameter for the prior over the residence time.
///
/// # Example
///
/// ```rust
///
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
/// use ndarray::arr3;
/// use reCTBN::tools::*;
/// use reCTBN::parameter_learning::*;
/// use approx::AbsDiffEq;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Add the CIMs for each node
/// match &mut net.get_node_mut(X1) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
/// }
/// }
/// match &mut net.get_node_mut(X2) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(
/// Ok(()),
/// param.set_cim(arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]))
/// );
/// }
/// }
///
/// //Generate a synthetic dataset from net
/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
///
/// //Initialize the `struct BayesianApproach`
/// let pl = BayesianApproach{alpha: 1, tau: 1.0};
///
/// // Fit the parameters for X2
/// let p = match pl.fit(&net, &data, X2, None) {
/// params::Params::DiscreteStatesContinousTime(p) => p,
/// };
///
/// // Check the shape of the CIM
/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
///
/// // Check if the learned parameters are close enough to the real ones.
/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
/// &arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]),
/// 0.1
/// ));
/// ```
pub struct BayesianApproach {
pub alpha: usize,
pub tau: f64,

@ -3,6 +3,7 @@
use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch;
use log::{debug, error, trace, warn};
use ndarray::prelude::*;
use rand::Rng;
use rand_chacha::ChaCha8Rng;
@ -29,6 +30,7 @@ pub enum StateType {
/// methods required to describe a generic node.
#[enum_dispatch(Params)]
pub trait ParamsTrait {
///Reset the parameters
fn reset_params(&mut self);
/// Randomly generate a possible state of the node disregarding the state of the node and its
@ -98,6 +100,7 @@ pub struct DiscreteStatesContinousTimeParams {
impl DiscreteStatesContinousTimeParams {
pub fn new(label: String, domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
debug!("Creation of node {}", label);
DiscreteStatesContinousTimeParams {
label,
domain,
@ -109,6 +112,7 @@ impl DiscreteStatesContinousTimeParams {
/// Getter function for CIM
pub fn get_cim(&self) -> &Option<Array3<f64>> {
debug!("Getting cim from node {}", self.label);
&self.cim
}
@ -119,10 +123,12 @@ impl DiscreteStatesContinousTimeParams {
/// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns
/// `ParamsError`.
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
debug!("Setting cim for node {}", self.label);
self.cim = Some(cim);
match self.validate_params() {
Ok(()) => Ok(()),
Err(e) => {
warn!("Validation cim faild for node {}", self.label);
self.cim = None;
Err(e)
}
@ -131,39 +137,54 @@ impl DiscreteStatesContinousTimeParams {
/// Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
debug!("Setting cim (unchecked) for node {}", self.label);
self.cim = Some(cim);
}
/// Getter function for transitions.
pub fn get_transitions(&self) -> &Option<Array3<usize>> {
debug!("Get transitions from node {}", self.label);
&self.transitions
}
/// Setter function for transitions.
pub fn set_transitions(&mut self, transitions: Array3<usize>) {
debug!("Set transitions for node {}", self.label);
self.transitions = Some(transitions);
}
/// Getter function for residence_time.
pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
debug!("Get residence time from node {}", self.label);
&self.residence_time
}
/// Setter function for residence_time.
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
debug!("Set residence time for node {}", self.label);
self.residence_time = Some(residence_time);
}
}
impl ParamsTrait for DiscreteStatesContinousTimeParams {
fn reset_params(&mut self) {
debug!(
"Setting cim, transitions and residence_time to None for node {}",
self.label
);
self.cim = Option::None;
self.transitions = Option::None;
self.residence_time = Option::None;
}
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType {
StateType::Discrete(rng.gen_range(0..(self.domain.len())))
let state = StateType::Discrete(rng.gen_range(0..(self.domain.len())));
trace!(
"Generate random state uniform. Node: {} - State: {:?}",
self.get_label(),
&state
);
return state;
}
fn get_random_residence_time(
@ -179,11 +200,20 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
Option::Some(cim) => {
let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..=1.0);
Ok(-x.ln() / lambda)
let ret = -x.ln() / lambda;
trace!(
"Generate random residence time. Node: {} - Time: {}",
self.get_label(),
ret
);
Ok(ret)
}
Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
Option::None => {
warn!("Cim not initialized for node {}", self.get_label());
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
))),
)))
}
}
}
@ -220,11 +250,21 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
next_state.0 + 1
};
Ok(StateType::Discrete(next_state))
let next_state = StateType::Discrete(next_state);
trace!(
"Generate random state. Node: {} - State: {:?}",
self.get_label(),
next_state
);
Ok(next_state)
}
Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
Option::None => {
warn!("Cim not initialized for node {}", self.get_label());
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
))),
)))
}
}
}
@ -243,6 +283,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
// Check if the cim is initialized
if let None = self.cim {
warn!("Cim not initialized for node {}", self.get_label());
return Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
)));
@ -250,11 +291,13 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
let cim = self.cim.as_ref().unwrap();
// Check if the inner dimensions of the cim are equal to the cardinality of the variable
if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size {
return Err(ParamsError::InvalidCIM(format!(
let message = format!(
"Incompatible shape {:?} with domain {:?}",
cim.shape(),
domain_size
)));
);
warn!("{}", message);
return Err(ParamsError::InvalidCIM(message));
}
// Check if the diagonal of each cim is non-positive
@ -262,6 +305,10 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
.axis_iter(Axis(0))
.any(|x| x.diag().iter().any(|x| x >= &0.0))
{
warn!(
"The diagonal of each cim for node {} must be non-positive",
self.get_label()
);
return Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive",
)));
@ -273,6 +320,10 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
.iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
{
warn!(
"The sum of each row of the cim for node {} must be 0",
self.get_label()
);
return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0",
)));

@ -23,8 +23,20 @@ pub type NetworkProcessState = Vec<params::StateType>;
/// as a CTBN).
pub trait NetworkProcess: Sync {
fn initialize_adj_matrix(&mut self);
/// Add a **node** to the network
///
/// # Arguments
///
/// * `n` - instantiation of the `enum params::Params` describing a node
///
/// # Return
///
/// * A `Result` containing the `node_idx` automatically assigned if everything is fine,
/// or a `NetworkError` if something went wrong.
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add an **directed edge** between a two nodes of the network.
/// Add a **directed edge** between two nodes of the network.
///
/// # Arguments
///

@ -2,6 +2,7 @@
use std::collections::BTreeSet;
use log::info;
use ndarray::prelude::*;
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
@ -77,6 +78,7 @@ impl CtbnNetwork {
///
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
pub fn amalgamation(&self) -> CtmpProcess {
info!("Network Amalgamation Started");
let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
@ -141,14 +143,12 @@ impl CtbnNetwork {
}
impl process::NetworkProcess for CtbnNetwork {
/// (Re)create the adjacency matrix as an all-zero square matrix.
///
/// The side length is the current number of nodes. Any matrix already
/// stored in `adj_matrix` is replaced.
fn initialize_adj_matrix(&mut self) {
    let n_nodes = self.nodes.len();
    // `.f()` keeps the shape builder option used by the original code.
    let empty = Array2::<u16>::zeros((n_nodes, n_nodes).f());
    self.adj_matrix = Some(empty);
}
/// Add a new node.
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
@ -156,7 +156,6 @@ impl process::NetworkProcess for CtbnNetwork {
Ok(self.nodes.len() - 1)
}
/// Connect two nodes with a new edge.
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
@ -172,7 +171,6 @@ impl process::NetworkProcess for CtbnNetwork {
0..self.nodes.len()
}
/// Get the number of nodes of the network.
///
/// Returns the current length of `self.nodes`.
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
@ -217,7 +215,6 @@ impl process::NetworkProcess for CtbnNetwork {
.0
}
/// Get all the parents of the given node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
@ -229,7 +226,6 @@ impl process::NetworkProcess for CtbnNetwork {
.collect()
}
/// Get all the children of the given node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()

@ -6,6 +6,75 @@ use crate::{
};
use super::{NetworkProcess, NetworkProcessState};
use log::warn;
/// This structure represents a Continuous Time Markov process
///
/// # Arguments
///
/// * `param` - An Option containing the parameters of the process
///
///```rust
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
/// use ndarray::arr3;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
/// match &mut net.get_node_mut(X1) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
/// }
/// }
///
/// match &mut net.get_node_mut(X2) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(
/// Ok(()),
/// param.set_cim(arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]))
/// );
/// }
/// }
/// //Amalgamate the ctbn into a CtmpProcess
/// let ctmp = net.amalgamation();
///
/// //Extract the amalgamated params from the ctmp
/// let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
/// let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
///
/// //The shape of the params for an amalgamated ctmp can be computed as a Cartesian product of the
/// //domains variables of the ctbn
/// assert_eq!(p_ctmp.shape()[1], 4);
///```
pub struct CtmpProcess {
param: Option<Params>,
@ -28,13 +97,17 @@ impl NetworkProcess for CtmpProcess {
self.param = Some(n);
Ok(0)
}
Some(_) => Err(process::NetworkError::NodeInsertionError(
Some(_) => {
warn!("A CTMP do not support more than one node");
Err(process::NetworkError::NodeInsertionError(
"CtmpProcess has only one node".to_string(),
)),
))
}
}
}
/// Edges are not supported by this process: calling this method always panics.
///
/// A warning is logged first, then `unimplemented!` is hit. Both `_parent`
/// and `_child` are ignored.
///
/// # Panics
///
/// Panics unconditionally with the message "CtmpProcess has only one node".
fn add_edge(&mut self, _parent: usize, _child: usize) {
warn!("A CTMP cannot have edges");
unimplemented!("CtmpProcess has only one node")
}

Loading…
Cancel
Save