Added doctest for parameter_learning

72-feature-add-logging-and-documentation
AlessandroBregoli 2 years ago
parent 21ce0ffcb0
commit 87da2a7c9e
  1. 163
      reCTBN/src/parameter_learning.rs
  2. 2
      reCTBN/src/process/ctmp.rs

@ -95,6 +95,87 @@ pub fn sufficient_statistics<T: process::NetworkProcess>(
/// Maximum Likelihood Estimation method for learning the parameters given a dataset. /// Maximum Likelihood Estimation method for learning the parameters given a dataset.
///
/// # Example
/// ```rust
///
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
/// use ndarray::arr3;
/// use reCTBN::tools::*;
/// use reCTBN::parameter_learning::*;
/// use approx::AbsDiffEq;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Add the CIMs for each node
/// match &mut net.get_node_mut(X1) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
/// }
/// }
/// match &mut net.get_node_mut(X2) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(
/// Ok(()),
/// param.set_cim(arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]))
/// );
/// }
/// }
///
/// //Generate a synthetic dataset from net
/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
///
/// //Initialize the `struct MLE`
/// let pl = MLE{};
///
/// // Fit the parameters for X2
/// let p = match pl.fit(&net, &data, X2, None) {
/// params::Params::DiscreteStatesContinousTime(p) => p,
/// };
///
/// // Check the shape of the CIM
/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
///
/// // Check if the learned parameters are close enough to the real ones.
/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
/// &arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]),
/// 0.1
/// ));
/// ```
pub struct MLE {} pub struct MLE {}
impl ParameterLearning for MLE { impl ParameterLearning for MLE {
@ -146,6 +227,88 @@ impl ParameterLearning for MLE {
/// ///
/// `alpha`: hyperparameter for the priori over the number of transitions. /// `alpha`: hyperparameter for the priori over the number of transitions.
/// `tau`: hyperparameter for the priori over the residence time. /// `tau`: hyperparameter for the priori over the residence time.
///
/// # Example
///
/// ```rust
///
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
/// use ndarray::arr3;
/// use reCTBN::tools::*;
/// use reCTBN::parameter_learning::*;
/// use approx::AbsDiffEq;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Add the CIMs for each node
/// match &mut net.get_node_mut(X1) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
/// }
/// }
/// match &mut net.get_node_mut(X2) {
/// params::Params::DiscreteStatesContinousTime(param) => {
/// assert_eq!(
/// Ok(()),
/// param.set_cim(arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]))
/// );
/// }
/// }
///
/// //Generate a synthetic dataset from net
/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
///
/// //Initialize the `struct BayesianApproach`
/// let pl = BayesianApproach{alpha: 1, tau: 1.0};
///
/// // Fit the parameters for X2
/// let p = match pl.fit(&net, &data, X2, None) {
/// params::Params::DiscreteStatesContinousTime(p) => p,
/// };
///
/// // Check the shape of the CIM
/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
///
/// // Check if the learned parameters are close enough to the real ones.
/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
/// &arr3(&[
/// [[-0.01, 0.01], [5.0, -5.0]],
/// [[-5.0, 5.0], [0.01, -0.01]]
/// ]),
/// 0.1
/// ));
/// ```
pub struct BayesianApproach { pub struct BayesianApproach {
pub alpha: usize, pub alpha: usize,
pub tau: f64, pub tau: f64,

@ -107,7 +107,7 @@ impl NetworkProcess for CtmpProcess {
} }
fn add_edge(&mut self, _parent: usize, _child: usize) { fn add_edge(&mut self, _parent: usize, _child: usize) {
warn!("A CTMP cannot have endges"); warn!("A CTMP cannot have edges");
unimplemented!("CtmpProcess has only one node") unimplemented!("CtmpProcess has only one node")
} }

Loading…
Cancel
Save