@@ -7,7 +7,17 @@ use ndarray::prelude::*;
use crate::params::*;
use crate::{process, tools::Dataset};

/// It defines the required methods for learning the `Parameters` from data.
pub trait ParameterLearning: Sync {
    /// Fit the parameters of the `node` over a `dataset` given a `parent_set`.
    ///
    /// # Arguments
    ///
    /// * `net`: a `NetworkProcess` instance
    /// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics
    /// * `node`: the node index for which we want to compute the sufficient statistics
    /// * `parent_set`: an `Option` containing the parent set used for computing the parameters of
    ///   `node`. If `None`, the parent set defined in `net` will be used.
    fn fit<T: process::NetworkProcess>(
        &self,
        net: &T,
@@ -17,6 +27,19 @@ pub trait ParameterLearning: Sync {
    ) -> Params;
}
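
// A small usage sketch (editor's illustration, not part of this diff): because `fit`
// takes the estimator through `&self`, any `ParameterLearning` implementation can be
// driven generically. The helper below and its `n_nodes` argument are hypothetical;
// it only relies on the trait method declared above.
fn fit_all_nodes<P: ParameterLearning, T: process::NetworkProcess>(
    estimator: &P,
    net: &T,
    dataset: &Dataset,
    n_nodes: usize,
) -> Vec<Params> {
    // Fit every node using the parent set already encoded in `net` (`parent_set = None`).
    (0..n_nodes)
        .map(|node| estimator.fit(net, dataset, node, None))
        .collect()
}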

/// Compute the sufficient statistics of the parameters of a node from a dataset.
///
/// # Arguments
///
/// * `net`: a `NetworkProcess` instance
/// * `dataset`: a dataset compatible with `net` used for computing the sufficient statistics
/// * `node`: the node index for which we want to compute the sufficient statistics
/// * `parent_set`: the set of nodes (identified by indices) we want to use as parents of `node`
///
/// # Return
///
/// * A tuple containing the number of transitions (`Array3<usize>`) and the residence time
///   (`Array2<f64>`).
pub fn sufficient_statistics<T: process::NetworkProcess>(
    net: &T,
    dataset: &Dataset,
@@ -70,6 +93,89 @@ pub fn sufficient_statistics<T: process::NetworkProcess>(
    return (M, T);
}
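
// A hedged usage sketch (editor's illustration, the helper name is hypothetical): the
// `parent_set` parameter type is assumed to be `&BTreeSet<usize>`, matching the argument
// list documented above. `M[u, x, x']` counts the transitions of `node` from state x to
// x' while its parents are in configuration u; `T[u, x]` is the residence time in x under u.
fn total_transitions<N: process::NetworkProcess>(
    net: &N,
    dataset: &Dataset,
    node: usize,
    parent_set: &std::collections::BTreeSet<usize>,
) -> usize {
    let (m, t) = sufficient_statistics(net, dataset, node, parent_set);
    // One matrix slice per parent configuration in both statistics.
    assert_eq!(m.shape()[0], t.shape()[0]);
    m.sum()
}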

/// Maximum Likelihood Estimation method for learning the parameters given a dataset.
///
/// # Example
///
/// ```rust
///
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
/// use ndarray::arr3;
/// use reCTBN::tools::*;
/// use reCTBN::parameter_learning::*;
/// use approx::AbsDiffEq;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Add the CIMs for each node
/// match &mut net.get_node_mut(X1) {
///     params::Params::DiscreteStatesContinousTime(param) => {
///         assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
///     }
/// }
/// match &mut net.get_node_mut(X2) {
///     params::Params::DiscreteStatesContinousTime(param) => {
///         assert_eq!(
///             Ok(()),
///             param.set_cim(arr3(&[
///                 [[-0.01, 0.01], [5.0, -5.0]],
///                 [[-5.0, 5.0], [0.01, -0.01]]
///             ]))
///         );
///     }
/// }
///
/// //Generate a synthetic dataset from net
/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
///
/// //Initialize the `struct MLE`
/// let pl = MLE{};
///
/// // Fit the parameters for X2
/// let p = match pl.fit(&net, &data, X2, None) {
///     params::Params::DiscreteStatesContinousTime(p) => p,
/// };
///
/// // Check the shape of the CIM
/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
///
/// // Check if the learned parameters are close enough to the real ones.
/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
///     &arr3(&[
///         [[-0.01, 0.01], [5.0, -5.0]],
///         [[-5.0, 5.0], [0.01, -0.01]]
///     ]),
///     0.1
/// ));
/// ```
pub struct MLE {}

impl ParameterLearning for MLE {
@@ -114,6 +220,95 @@ impl ParameterLearning for MLE {
    }
}
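
// A sketch of the standard CTBN maximum-likelihood estimate that `MLE` is based on
// (an editor's assumption, not code lifted from `MLE::fit`): the exit rate of state x
// under parent configuration u is q_{x|u} = sum_{x'} M[u, x, x'] / T[u, x], and the
// transition probabilities are theta_{x x'|u} = M[u, x, x'] / sum_{x''} M[u, x, x''].
fn mle_exit_rates(m: &Array3<usize>, t: &Array2<f64>) -> Array2<f64> {
    // Collapse the destination-state axis to obtain M[u, x], then divide elementwise by T[u, x].
    m.sum_axis(Axis(2)).mapv(|v| v as f64) / t
}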

/// Bayesian Approach for learning the parameters given a dataset.
///
/// # Arguments
///
/// * `alpha`: hyperparameter for the prior over the number of transitions.
/// * `tau`: hyperparameter for the prior over the residence time.
///
/// # Example
///
/// ```rust
///
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
/// use ndarray::arr3;
/// use reCTBN::tools::*;
/// use reCTBN::parameter_learning::*;
/// use approx::AbsDiffEq;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Add the CIMs for each node
/// match &mut net.get_node_mut(X1) {
///     params::Params::DiscreteStatesContinousTime(param) => {
///         assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
///     }
/// }
/// match &mut net.get_node_mut(X2) {
///     params::Params::DiscreteStatesContinousTime(param) => {
///         assert_eq!(
///             Ok(()),
///             param.set_cim(arr3(&[
///                 [[-0.01, 0.01], [5.0, -5.0]],
///                 [[-5.0, 5.0], [0.01, -0.01]]
///             ]))
///         );
///     }
/// }
///
/// //Generate a synthetic dataset from net
/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
///
/// //Initialize the `struct BayesianApproach`
/// let pl = BayesianApproach{alpha: 1, tau: 1.0};
///
/// // Fit the parameters for X2
/// let p = match pl.fit(&net, &data, X2, None) {
///     params::Params::DiscreteStatesContinousTime(p) => p,
/// };
///
/// // Check the shape of the CIM
/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
///
/// // Check if the learned parameters are close enough to the real ones.
/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
///     &arr3(&[
///         [[-0.01, 0.01], [5.0, -5.0]],
///         [[-5.0, 5.0], [0.01, -0.01]]
///     ]),
///     0.1
/// ));
/// ```
pub struct BayesianApproach {
    pub alpha: usize,
    pub tau: f64,