From 7ec56914d91c38f27862b138728d9938651188a2 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 17 Jan 2023 21:43:54 +0100 Subject: [PATCH] Added doctest for CTPC --- .../constraint_based_algorithm.rs | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index f49b194..f9cd820 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -79,6 +79,190 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { } } +/// Continuous-Time Peter Clark algorithm. +/// +/// A method to learn the structure of the network. +/// +/// # Arguments +/// +/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters. +/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test. +/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test. +/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// use reCTBN::parameter_learning::BayesianApproach; +/// use reCTBN::structure_learning::StructureLearningAlgorithm; +/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare}; +/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC; +/// # +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("A")); +/// # domain.insert(String::from("B")); +/// # domain.insert(String::from("C")); +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain); +/// # //Create the node n1 using the parameters +/// # let n1 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("D")); +/// # domain.insert(String::from("E")); +/// # domain.insert(String::from("F")); +/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain); +/// # let n2 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("G")); +/// # domain.insert(String::from("H")); +/// # domain.insert(String::from("I")); +/// # domain.insert(String::from("F")); +/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain); +/// # let n3 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # // Initialize a ctbn +/// # let mut net = CtbnNetwork::new(); +/// # +/// # // Add the nodes and their edges +/// # let n1 = net.add_node(n1).unwrap(); +/// # let n2 = net.add_node(n2).unwrap(); +/// # let n3 = net.add_node(n3).unwrap(); +/// # net.add_edge(n1, n2); +/// # net.add_edge(n1, n3); +/// # net.add_edge(n2, n3); +/// # +/// # match &mut net.get_node_mut(n1) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-3.0, 2.0, 1.0], +/// # [1.5, -2.0, 0.5], +/// # [0.4, 0.6, -1.0] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # match &mut net.get_node_mut(n2) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-1.0, 0.5, 0.5], +/// # [3.0, -4.0, 1.0], +/// # [0.9, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 2.0, 4.0], +/// # [1.5, -2.0, 0.5], +/// # [3.0, 1.0, -4.0] +/// # ], +/// # [ +/// # [-1.0, 0.1, 0.9], +/// # [2.0, -2.5, 0.5], +/// # [0.9, 0.1, -1.0] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # match &mut net.get_node_mut(n3) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-1.0, 0.5, 0.3, 0.2], +/// # [0.5, -4.0, 2.5, 1.0], +/// # [2.5, 0.5, -4.0, 1.0], +/// # [0.7, 0.2, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 2.0, 3.0, 1.0], +/// # [1.5, -3.0, 0.5, 1.0], +/// # [2.0, 1.3, -5.0, 1.7], +/// # [2.5, 0.5, 1.0, -4.0] +/// # ], +/// # [ +/// # [-1.3, 0.3, 0.1, 0.9], +/// # [1.4, -4.0, 0.5, 2.1], +/// # [1.0, 1.5, -3.0, 0.5], +/// # [0.4, 0.3, 0.1, -0.8] +/// # ], +/// # [ +/// # [-2.0, 1.0, 0.7, 0.3], +/// # [1.3, -5.9, 2.7, 1.9], +/// # [2.0, 1.5, -4.0, 0.5], +/// # [0.2, 0.7, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 1.0, 2.0, 3.0], +/// # [0.5, -3.0, 1.0, 1.5], +/// # [1.4, 2.1, -4.3, 0.8], +/// # [0.5, 1.0, 2.5, -4.0] +/// # ], +/// # [ +/// # [-1.3, 0.9, 0.3, 0.1], +/// # [0.1, -1.3, 0.2, 1.0], +/// # [0.5, 1.0, -3.0, 1.5], +/// # [0.1, 0.4, 0.3, -0.8] +/// # ], +/// # [ +/// # [-2.0, 1.0, 0.6, 0.4], +/// # [2.6, -7.1, 1.4, 3.1], +/// # [5.0, 1.0, -8.0, 2.0], +/// # [1.4, 0.4, 0.2, -2.0] +/// # ], +/// # [ +/// # [-3.0, 1.0, 1.5, 0.5], +/// # [3.0, -6.0, 1.0, 2.0], +/// # [0.3, 0.5, -1.9, 1.1], +/// # [5.0, 1.0, 2.0, -8.0] +/// # ], +/// # [ +/// # [-2.6, 0.6, 0.2, 1.8], +/// # [2.0, -6.0, 3.0, 1.0], +/// # [0.1, 0.5, -1.3, 0.7], +/// # [0.8, 0.6, 0.2, -1.6] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # // Generate the trajectory +/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873)); +/// +/// // Initialize the hypothesis tests to pass to the CTPC with their +/// // respective significance level `alpha` +/// let f = F::new(1e-6); +/// let chi_sq = ChiSquare::new(1e-4); +/// // Use the bayesian approach to learn the parameters +/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; +/// +/// //Initialize CTPC +/// let ctpc = CTPC::new(parameter_learning, f, chi_sq); +/// +/// // Learn the structure of the network from the generated trajectory +/// let net = ctpc.fit_transform(net, &data); +/// # +/// # // Compare the generated network with the original one +/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); +/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +/// ``` pub struct CTPC { parameter_learning: P, Ftest: F,