|
|
|
@ -79,6 +79,190 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Continuous-Time Peter Clark algorithm.
|
|
|
|
|
///
|
|
|
|
|
/// A method to learn the structure of the network.
|
|
|
|
|
///
|
|
|
|
|
/// # Arguments
|
|
|
|
|
///
|
|
|
|
|
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
|
|
|
|
|
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test.
|
|
|
|
|
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test.
|
|
|
|
|
/// # Example
|
|
|
|
|
///
|
|
|
|
|
/// ```rust
|
|
|
|
|
/// # use std::collections::BTreeSet;
|
|
|
|
|
/// # use ndarray::{arr1, arr2, arr3};
|
|
|
|
|
/// # use reCTBN::params;
|
|
|
|
|
/// # use reCTBN::tools::trajectory_generator;
|
|
|
|
|
/// # use reCTBN::process::NetworkProcess;
|
|
|
|
|
/// # use reCTBN::process::ctbn::CtbnNetwork;
|
|
|
|
|
/// use reCTBN::parameter_learning::BayesianApproach;
|
|
|
|
|
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
|
|
|
|
|
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
|
|
|
|
|
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
|
|
|
|
|
/// #
|
|
|
|
|
/// # // Create the domain for a discrete node
|
|
|
|
|
/// # let mut domain = BTreeSet::new();
|
|
|
|
|
/// # domain.insert(String::from("A"));
|
|
|
|
|
/// # domain.insert(String::from("B"));
|
|
|
|
|
/// # domain.insert(String::from("C"));
|
|
|
|
|
/// # // Create the parameters for a discrete node using the domain
|
|
|
|
|
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
|
|
|
|
|
/// # //Create the node n1 using the parameters
|
|
|
|
|
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
|
|
|
|
|
/// #
|
|
|
|
|
/// # let mut domain = BTreeSet::new();
|
|
|
|
|
/// # domain.insert(String::from("D"));
|
|
|
|
|
/// # domain.insert(String::from("E"));
|
|
|
|
|
/// # domain.insert(String::from("F"));
|
|
|
|
|
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
|
|
|
|
|
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
|
|
|
|
|
/// #
|
|
|
|
|
/// # let mut domain = BTreeSet::new();
|
|
|
|
|
/// # domain.insert(String::from("G"));
|
|
|
|
|
/// # domain.insert(String::from("H"));
|
|
|
|
|
/// # domain.insert(String::from("I"));
|
|
|
|
|
/// # domain.insert(String::from("F"));
|
|
|
|
|
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
|
|
|
|
|
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
|
|
|
|
|
/// #
|
|
|
|
|
/// # // Initialize a ctbn
|
|
|
|
|
/// # let mut net = CtbnNetwork::new();
|
|
|
|
|
/// #
|
|
|
|
|
/// # // Add the nodes and their edges
|
|
|
|
|
/// # let n1 = net.add_node(n1).unwrap();
|
|
|
|
|
/// # let n2 = net.add_node(n2).unwrap();
|
|
|
|
|
/// # let n3 = net.add_node(n3).unwrap();
|
|
|
|
|
/// # net.add_edge(n1, n2);
|
|
|
|
|
/// # net.add_edge(n1, n3);
|
|
|
|
|
/// # net.add_edge(n2, n3);
|
|
|
|
|
/// #
|
|
|
|
|
/// # match &mut net.get_node_mut(n1) {
|
|
|
|
|
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
|
|
|
|
/// # assert_eq!(
|
|
|
|
|
/// # Ok(()),
|
|
|
|
|
/// # param.set_cim(arr3(&[
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-3.0, 2.0, 1.0],
|
|
|
|
|
/// # [1.5, -2.0, 0.5],
|
|
|
|
|
/// # [0.4, 0.6, -1.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # ]))
|
|
|
|
|
/// # );
|
|
|
|
|
/// # }
|
|
|
|
|
/// # }
|
|
|
|
|
/// #
|
|
|
|
|
/// # match &mut net.get_node_mut(n2) {
|
|
|
|
|
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
|
|
|
|
/// # assert_eq!(
|
|
|
|
|
/// # Ok(()),
|
|
|
|
|
/// # param.set_cim(arr3(&[
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-1.0, 0.5, 0.5],
|
|
|
|
|
/// # [3.0, -4.0, 1.0],
|
|
|
|
|
/// # [0.9, 0.1, -1.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-6.0, 2.0, 4.0],
|
|
|
|
|
/// # [1.5, -2.0, 0.5],
|
|
|
|
|
/// # [3.0, 1.0, -4.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-1.0, 0.1, 0.9],
|
|
|
|
|
/// # [2.0, -2.5, 0.5],
|
|
|
|
|
/// # [0.9, 0.1, -1.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # ]))
|
|
|
|
|
/// # );
|
|
|
|
|
/// # }
|
|
|
|
|
/// # }
|
|
|
|
|
/// #
|
|
|
|
|
/// # match &mut net.get_node_mut(n3) {
|
|
|
|
|
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
|
|
|
|
/// # assert_eq!(
|
|
|
|
|
/// # Ok(()),
|
|
|
|
|
/// # param.set_cim(arr3(&[
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-1.0, 0.5, 0.3, 0.2],
|
|
|
|
|
/// # [0.5, -4.0, 2.5, 1.0],
|
|
|
|
|
/// # [2.5, 0.5, -4.0, 1.0],
|
|
|
|
|
/// # [0.7, 0.2, 0.1, -1.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-6.0, 2.0, 3.0, 1.0],
|
|
|
|
|
/// # [1.5, -3.0, 0.5, 1.0],
|
|
|
|
|
/// # [2.0, 1.3, -5.0, 1.7],
|
|
|
|
|
/// # [2.5, 0.5, 1.0, -4.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-1.3, 0.3, 0.1, 0.9],
|
|
|
|
|
/// # [1.4, -4.0, 0.5, 2.1],
|
|
|
|
|
/// # [1.0, 1.5, -3.0, 0.5],
|
|
|
|
|
/// # [0.4, 0.3, 0.1, -0.8]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-2.0, 1.0, 0.7, 0.3],
|
|
|
|
|
/// # [1.3, -5.9, 2.7, 1.9],
|
|
|
|
|
/// # [2.0, 1.5, -4.0, 0.5],
|
|
|
|
|
/// # [0.2, 0.7, 0.1, -1.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-6.0, 1.0, 2.0, 3.0],
|
|
|
|
|
/// # [0.5, -3.0, 1.0, 1.5],
|
|
|
|
|
/// # [1.4, 2.1, -4.3, 0.8],
|
|
|
|
|
/// # [0.5, 1.0, 2.5, -4.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-1.3, 0.9, 0.3, 0.1],
|
|
|
|
|
/// # [0.1, -1.3, 0.2, 1.0],
|
|
|
|
|
/// # [0.5, 1.0, -3.0, 1.5],
|
|
|
|
|
/// # [0.1, 0.4, 0.3, -0.8]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-2.0, 1.0, 0.6, 0.4],
|
|
|
|
|
/// # [2.6, -7.1, 1.4, 3.1],
|
|
|
|
|
/// # [5.0, 1.0, -8.0, 2.0],
|
|
|
|
|
/// # [1.4, 0.4, 0.2, -2.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-3.0, 1.0, 1.5, 0.5],
|
|
|
|
|
/// # [3.0, -6.0, 1.0, 2.0],
|
|
|
|
|
/// # [0.3, 0.5, -1.9, 1.1],
|
|
|
|
|
/// # [5.0, 1.0, 2.0, -8.0]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # [
|
|
|
|
|
/// # [-2.6, 0.6, 0.2, 1.8],
|
|
|
|
|
/// # [2.0, -6.0, 3.0, 1.0],
|
|
|
|
|
/// # [0.1, 0.5, -1.3, 0.7],
|
|
|
|
|
/// # [0.8, 0.6, 0.2, -1.6]
|
|
|
|
|
/// # ],
|
|
|
|
|
/// # ]))
|
|
|
|
|
/// # );
|
|
|
|
|
/// # }
|
|
|
|
|
/// # }
|
|
|
|
|
/// #
|
|
|
|
|
/// # // Generate the trajectory
|
|
|
|
|
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
|
|
|
|
|
///
|
|
|
|
|
/// // Initialize the hypothesis tests to pass to the CTPC with their
|
|
|
|
|
/// // respective significance level `alpha`
|
|
|
|
|
/// let f = F::new(1e-6);
|
|
|
|
|
/// let chi_sq = ChiSquare::new(1e-4);
|
|
|
|
|
/// // Use the bayesian approach to learn the parameters
|
|
|
|
|
/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
|
|
|
|
|
///
|
|
|
|
|
/// //Initialize CTPC
|
|
|
|
|
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
|
|
|
|
|
///
|
|
|
|
|
/// // Learn the structure of the network from the generated trajectory
|
|
|
|
|
/// let net = ctpc.fit_transform(net, &data);
|
|
|
|
|
/// #
|
|
|
|
|
/// # // Compare the generated network with the original one
|
|
|
|
|
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
|
|
|
|
|
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
|
|
|
|
|
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
|
|
|
|
|
/// ```
|
|
|
|
|
pub struct CTPC<P: ParameterLearning> { |
|
|
|
|
parameter_learning: P, |
|
|
|
|
Ftest: F, |
|
|
|
|