From 87da2a7c9e4cacb264f66286c1d6d19f24a6e6fb Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 22 Feb 2023 14:38:21 +0100 Subject: [PATCH] Added doctest for parameter_learning --- reCTBN/src/parameter_learning.rs | 163 +++++++++++++++++++++++++++++++ reCTBN/src/process/ctmp.rs | 2 +- 2 files changed, 164 insertions(+), 1 deletion(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index bc06952..77734aa 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -95,6 +95,87 @@ pub fn sufficient_statistics( /// Maximum Likelihood Estimation method for learning the parameters given a dataset. +/// +/// # Example +/// ```rust +/// +/// use std::collections::BTreeSet; +/// use reCTBN::process::NetworkProcess; +/// use reCTBN::params; +/// use reCTBN::process::ctbn::*; +/// use ndarray::arr3; +/// use reCTBN::tools::*; +/// use reCTBN::parameter_learning::*; +/// use approx::AbsDiffEq; +/// +/// //Create the domain for a discrete node +/// let mut domain = BTreeSet::new(); +/// domain.insert(String::from("A")); +/// domain.insert(String::from("B")); +/// +/// //Create the parameters for a discrete node using the domain +/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain); +/// +/// //Create the node using the parameters +/// let X1 = params::Params::DiscreteStatesContinousTime(param); +/// +/// let mut domain = BTreeSet::new(); +/// domain.insert(String::from("A")); +/// domain.insert(String::from("B")); +/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain); +/// let X2 = params::Params::DiscreteStatesContinousTime(param); +/// +/// //Initialize a ctbn +/// let mut net = CtbnNetwork::new(); +/// +/// //Add nodes +/// let X1 = net.add_node(X1).unwrap(); +/// let X2 = net.add_node(X2).unwrap(); +/// +/// //Add an edge +/// net.add_edge(X1, X2); +/// +/// //Add the CIMs for each node +/// match &mut net.get_node_mut(X1) { +/// params::Params::DiscreteStatesContinousTime(param) => { +/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); +/// } +/// } +/// match &mut net.get_node_mut(X2) { +/// params::Params::DiscreteStatesContinousTime(param) => { +/// assert_eq!( +/// Ok(()), +/// param.set_cim(arr3(&[ +/// [[-0.01, 0.01], [5.0, -5.0]], +/// [[-5.0, 5.0], [0.01, -0.01]] +/// ])) +/// ); +/// } +/// } +/// +/// //Generate a synthetic dataset from net +/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); +/// +/// //Initialize the `struct MLE` +/// let pl = MLE{}; +/// +/// // Fit the parameters for X2 +/// let p = match pl.fit(&net, &data, X2, None) { +/// params::Params::DiscreteStatesContinousTime(p) => p, +/// }; +/// +/// // Check the shape of the CIM +/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); +/// +/// // Check if the learned parameters are close enough to the real ones. +/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( +/// &arr3(&[ +/// [[-0.01, 0.01], [5.0, -5.0]], +/// [[-5.0, 5.0], [0.01, -0.01]] +/// ]), +/// 0.1 +/// )); +/// ``` pub struct MLE {} impl ParameterLearning for MLE { @@ -146,6 +227,88 @@ impl ParameterLearning for MLE { /// /// `alpha`: hyperparameter for the priori over the number of transitions. /// `tau`: hyperparameter for the priori over the residence time. +/// +/// # Example +/// +/// ```rust +/// +/// use std::collections::BTreeSet; +/// use reCTBN::process::NetworkProcess; +/// use reCTBN::params; +/// use reCTBN::process::ctbn::*; +/// use ndarray::arr3; +/// use reCTBN::tools::*; +/// use reCTBN::parameter_learning::*; +/// use approx::AbsDiffEq; +/// +/// //Create the domain for a discrete node +/// let mut domain = BTreeSet::new(); +/// domain.insert(String::from("A")); +/// domain.insert(String::from("B")); +/// +/// //Create the parameters for a discrete node using the domain +/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain); +/// +/// //Create the node using the parameters +/// let X1 = params::Params::DiscreteStatesContinousTime(param); +/// +/// let mut domain = BTreeSet::new(); +/// domain.insert(String::from("A")); +/// domain.insert(String::from("B")); +/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain); +/// let X2 = params::Params::DiscreteStatesContinousTime(param); +/// +/// //Initialize a ctbn +/// let mut net = CtbnNetwork::new(); +/// +/// //Add nodes +/// let X1 = net.add_node(X1).unwrap(); +/// let X2 = net.add_node(X2).unwrap(); +/// +/// //Add an edge +/// net.add_edge(X1, X2); +/// +/// //Add the CIMs for each node +/// match &mut net.get_node_mut(X1) { +/// params::Params::DiscreteStatesContinousTime(param) => { +/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); +/// } +/// } +/// match &mut net.get_node_mut(X2) { +/// params::Params::DiscreteStatesContinousTime(param) => { +/// assert_eq!( +/// Ok(()), +/// param.set_cim(arr3(&[ +/// [[-0.01, 0.01], [5.0, -5.0]], +/// [[-5.0, 5.0], [0.01, -0.01]] +/// ])) +/// ); +/// } +/// } +/// +/// //Generate a synthetic dataset from net +/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); +/// +/// //Initialize the `struct BayesianApproach` +/// let pl = BayesianApproach{alpha: 1, tau: 1.0}; +/// +/// // Fit the parameters for X2 +/// let p = match pl.fit(&net, &data, X2, None) { +/// params::Params::DiscreteStatesContinousTime(p) => p, +/// }; +/// +/// // Check the shape of the CIM +/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); +/// +/// // Check if the learned parameters are close enough to the real ones. +/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( +/// &arr3(&[ +/// [[-0.01, 0.01], [5.0, -5.0]], +/// [[-5.0, 5.0], [0.01, -0.01]] +/// ]), +/// 0.1 +/// )); +/// ``` pub struct BayesianApproach { pub alpha: usize, pub tau: f64, diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index 592c757..6c53899 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -107,7 +107,7 @@ impl NetworkProcess for CtmpProcess { } fn add_edge(&mut self, _parent: usize, _child: usize) { - warn!("A CTMP cannot have endges"); + warn!("A CTMP cannot have edges"); unimplemented!("CtmpProcess has only one node") }