|
|
|
@ -95,6 +95,87 @@ pub fn sufficient_statistics<T: process::NetworkProcess>( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Maximum Likelihood Estimation method for learning the parameters given a dataset.
|
|
|
|
|
///
|
|
|
|
|
/// # Example
|
|
|
|
|
/// ```rust
|
|
|
|
|
///
|
|
|
|
|
/// use std::collections::BTreeSet;
|
|
|
|
|
/// use reCTBN::process::NetworkProcess;
|
|
|
|
|
/// use reCTBN::params;
|
|
|
|
|
/// use reCTBN::process::ctbn::*;
|
|
|
|
|
/// use ndarray::arr3;
|
|
|
|
|
/// use reCTBN::tools::*;
|
|
|
|
|
/// use reCTBN::parameter_learning::*;
|
|
|
|
|
/// use approx::AbsDiffEq;
|
|
|
|
|
///
|
|
|
|
|
/// //Create the domain for a discrete node
|
|
|
|
|
/// let mut domain = BTreeSet::new();
|
|
|
|
|
/// domain.insert(String::from("A"));
|
|
|
|
|
/// domain.insert(String::from("B"));
|
|
|
|
|
///
|
|
|
|
|
/// //Create the parameters for a discrete node using the domain
|
|
|
|
|
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
|
|
|
|
///
|
|
|
|
|
/// //Create the node using the parameters
|
|
|
|
|
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
|
|
|
|
///
|
|
|
|
|
/// let mut domain = BTreeSet::new();
|
|
|
|
|
/// domain.insert(String::from("A"));
|
|
|
|
|
/// domain.insert(String::from("B"));
|
|
|
|
|
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
|
|
|
|
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
|
|
|
|
///
|
|
|
|
|
/// //Initialize a ctbn
|
|
|
|
|
/// let mut net = CtbnNetwork::new();
|
|
|
|
|
///
|
|
|
|
|
/// //Add nodes
|
|
|
|
|
/// let X1 = net.add_node(X1).unwrap();
|
|
|
|
|
/// let X2 = net.add_node(X2).unwrap();
|
|
|
|
|
///
|
|
|
|
|
/// //Add an edge
|
|
|
|
|
/// net.add_edge(X1, X2);
|
|
|
|
|
///
|
|
|
|
|
/// //Add the CIMs for each node
|
|
|
|
|
/// match &mut net.get_node_mut(X1) {
|
|
|
|
|
/// params::Params::DiscreteStatesContinousTime(param) => {
|
|
|
|
|
/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
|
|
|
|
|
/// }
|
|
|
|
|
/// }
|
|
|
|
|
/// match &mut net.get_node_mut(X2) {
|
|
|
|
|
/// params::Params::DiscreteStatesContinousTime(param) => {
|
|
|
|
|
/// assert_eq!(
|
|
|
|
|
/// Ok(()),
|
|
|
|
|
/// param.set_cim(arr3(&[
|
|
|
|
|
/// [[-0.01, 0.01], [5.0, -5.0]],
|
|
|
|
|
/// [[-5.0, 5.0], [0.01, -0.01]]
|
|
|
|
|
/// ]))
|
|
|
|
|
/// );
|
|
|
|
|
/// }
|
|
|
|
|
/// }
|
|
|
|
|
///
|
|
|
|
|
/// //Generate a synthetic dataset from net
|
|
|
|
|
/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
|
|
|
|
|
///
|
|
|
|
|
/// //Initialize the `struct MLE`
|
|
|
|
|
/// let pl = MLE{};
|
|
|
|
|
///
|
|
|
|
|
/// // Fit the parameters for X2
|
|
|
|
|
/// let p = match pl.fit(&net, &data, X2, None) {
|
|
|
|
|
/// params::Params::DiscreteStatesContinousTime(p) => p,
|
|
|
|
|
/// };
|
|
|
|
|
///
|
|
|
|
|
/// // Check the shape of the CIM
|
|
|
|
|
/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
|
|
|
|
|
///
|
|
|
|
|
/// // Check if the learned parameters are close enough to the real ones.
|
|
|
|
|
/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
|
|
|
|
|
/// &arr3(&[
|
|
|
|
|
/// [[-0.01, 0.01], [5.0, -5.0]],
|
|
|
|
|
/// [[-5.0, 5.0], [0.01, -0.01]]
|
|
|
|
|
/// ]),
|
|
|
|
|
/// 0.1
|
|
|
|
|
/// ));
|
|
|
|
|
/// ```
|
|
|
|
|
pub struct MLE {} |
|
|
|
|
|
|
|
|
|
impl ParameterLearning for MLE { |
|
|
|
@ -146,6 +227,88 @@ impl ParameterLearning for MLE { |
|
|
|
|
///
|
|
|
|
|
/// `alpha`: hyperparameter for the priori over the number of transitions.
|
|
|
|
|
/// `tau`: hyperparameter for the priori over the residence time.
|
|
|
|
|
///
|
|
|
|
|
/// # Example
|
|
|
|
|
///
|
|
|
|
|
/// ```rust
|
|
|
|
|
///
|
|
|
|
|
/// use std::collections::BTreeSet;
|
|
|
|
|
/// use reCTBN::process::NetworkProcess;
|
|
|
|
|
/// use reCTBN::params;
|
|
|
|
|
/// use reCTBN::process::ctbn::*;
|
|
|
|
|
/// use ndarray::arr3;
|
|
|
|
|
/// use reCTBN::tools::*;
|
|
|
|
|
/// use reCTBN::parameter_learning::*;
|
|
|
|
|
/// use approx::AbsDiffEq;
|
|
|
|
|
///
|
|
|
|
|
/// //Create the domain for a discrete node
|
|
|
|
|
/// let mut domain = BTreeSet::new();
|
|
|
|
|
/// domain.insert(String::from("A"));
|
|
|
|
|
/// domain.insert(String::from("B"));
|
|
|
|
|
///
|
|
|
|
|
/// //Create the parameters for a discrete node using the domain
|
|
|
|
|
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
|
|
|
|
///
|
|
|
|
|
/// //Create the node using the parameters
|
|
|
|
|
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
|
|
|
|
///
|
|
|
|
|
/// let mut domain = BTreeSet::new();
|
|
|
|
|
/// domain.insert(String::from("A"));
|
|
|
|
|
/// domain.insert(String::from("B"));
|
|
|
|
|
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
|
|
|
|
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
|
|
|
|
///
|
|
|
|
|
/// //Initialize a ctbn
|
|
|
|
|
/// let mut net = CtbnNetwork::new();
|
|
|
|
|
///
|
|
|
|
|
/// //Add nodes
|
|
|
|
|
/// let X1 = net.add_node(X1).unwrap();
|
|
|
|
|
/// let X2 = net.add_node(X2).unwrap();
|
|
|
|
|
///
|
|
|
|
|
/// //Add an edge
|
|
|
|
|
/// net.add_edge(X1, X2);
|
|
|
|
|
///
|
|
|
|
|
/// //Add the CIMs for each node
|
|
|
|
|
/// match &mut net.get_node_mut(X1) {
|
|
|
|
|
/// params::Params::DiscreteStatesContinousTime(param) => {
|
|
|
|
|
/// assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
|
|
|
|
|
/// }
|
|
|
|
|
/// }
|
|
|
|
|
/// match &mut net.get_node_mut(X2) {
|
|
|
|
|
/// params::Params::DiscreteStatesContinousTime(param) => {
|
|
|
|
|
/// assert_eq!(
|
|
|
|
|
/// Ok(()),
|
|
|
|
|
/// param.set_cim(arr3(&[
|
|
|
|
|
/// [[-0.01, 0.01], [5.0, -5.0]],
|
|
|
|
|
/// [[-5.0, 5.0], [0.01, -0.01]]
|
|
|
|
|
/// ]))
|
|
|
|
|
/// );
|
|
|
|
|
/// }
|
|
|
|
|
/// }
|
|
|
|
|
///
|
|
|
|
|
/// //Generate a synthetic dataset from net
|
|
|
|
|
/// let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
|
|
|
|
|
///
|
|
|
|
|
/// //Initialize the `struct BayesianApproach`
|
|
|
|
|
/// let pl = BayesianApproach{alpha: 1, tau: 1.0};
|
|
|
|
|
///
|
|
|
|
|
/// // Fit the parameters for X2
|
|
|
|
|
/// let p = match pl.fit(&net, &data, X2, None) {
|
|
|
|
|
/// params::Params::DiscreteStatesContinousTime(p) => p,
|
|
|
|
|
/// };
|
|
|
|
|
///
|
|
|
|
|
/// // Check the shape of the CIM
|
|
|
|
|
/// assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
|
|
|
|
|
///
|
|
|
|
|
/// // Check if the learned parameters are close enough to the real ones.
|
|
|
|
|
/// assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
|
|
|
|
|
/// &arr3(&[
|
|
|
|
|
/// [[-0.01, 0.01], [5.0, -5.0]],
|
|
|
|
|
/// [[-5.0, 5.0], [0.01, -0.01]]
|
|
|
|
|
/// ]),
|
|
|
|
|
/// 0.1
|
|
|
|
|
/// ));
|
|
|
|
|
/// ```
|
|
|
|
|
pub struct BayesianApproach { |
|
|
|
|
pub alpha: usize, |
|
|
|
|
pub tau: f64, |
|
|
|
|