From 7ec56914d91c38f27862b138728d9938651188a2 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 17 Jan 2023 21:43:54 +0100 Subject: [PATCH 1/2] Added doctest for CTPC --- .../constraint_based_algorithm.rs | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index f49b194..f9cd820 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -79,6 +79,190 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { } } +/// Continuous-Time Peter Clark algorithm. +/// +/// A method to learn the structure of the network. +/// +/// # Arguments +/// +/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters. +/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hypothesis test. +/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test. 
+/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// use reCTBN::parameter_learning::BayesianApproach; +/// use reCTBN::structure_learning::StructureLearningAlgorithm; +/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare}; +/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC; +/// # +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("A")); +/// # domain.insert(String::from("B")); +/// # domain.insert(String::from("C")); +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain); +/// # //Create the node n1 using the parameters +/// # let n1 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("D")); +/// # domain.insert(String::from("E")); +/// # domain.insert(String::from("F")); +/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain); +/// # let n2 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("G")); +/// # domain.insert(String::from("H")); +/// # domain.insert(String::from("I")); +/// # domain.insert(String::from("F")); +/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain); +/// # let n3 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # // Initialize a ctbn +/// # let mut net = CtbnNetwork::new(); +/// # +/// # // Add the nodes and their edges +/// # let n1 = net.add_node(n1).unwrap(); +/// # let n2 = net.add_node(n2).unwrap(); +/// # let n3 = net.add_node(n3).unwrap(); +/// 
# net.add_edge(n1, n2); +/// # net.add_edge(n1, n3); +/// # net.add_edge(n2, n3); +/// # +/// # match &mut net.get_node_mut(n1) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-3.0, 2.0, 1.0], +/// # [1.5, -2.0, 0.5], +/// # [0.4, 0.6, -1.0] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # match &mut net.get_node_mut(n2) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-1.0, 0.5, 0.5], +/// # [3.0, -4.0, 1.0], +/// # [0.9, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 2.0, 4.0], +/// # [1.5, -2.0, 0.5], +/// # [3.0, 1.0, -4.0] +/// # ], +/// # [ +/// # [-1.0, 0.1, 0.9], +/// # [2.0, -2.5, 0.5], +/// # [0.9, 0.1, -1.0] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # match &mut net.get_node_mut(n3) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-1.0, 0.5, 0.3, 0.2], +/// # [0.5, -4.0, 2.5, 1.0], +/// # [2.5, 0.5, -4.0, 1.0], +/// # [0.7, 0.2, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 2.0, 3.0, 1.0], +/// # [1.5, -3.0, 0.5, 1.0], +/// # [2.0, 1.3, -5.0, 1.7], +/// # [2.5, 0.5, 1.0, -4.0] +/// # ], +/// # [ +/// # [-1.3, 0.3, 0.1, 0.9], +/// # [1.4, -4.0, 0.5, 2.1], +/// # [1.0, 1.5, -3.0, 0.5], +/// # [0.4, 0.3, 0.1, -0.8] +/// # ], +/// # [ +/// # [-2.0, 1.0, 0.7, 0.3], +/// # [1.3, -5.9, 2.7, 1.9], +/// # [2.0, 1.5, -4.0, 0.5], +/// # [0.2, 0.7, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 1.0, 2.0, 3.0], +/// # [0.5, -3.0, 1.0, 1.5], +/// # [1.4, 2.1, -4.3, 0.8], +/// # [0.5, 1.0, 2.5, -4.0] +/// # ], +/// # [ +/// # [-1.3, 0.9, 0.3, 0.1], +/// # [0.1, -1.3, 0.2, 1.0], +/// # [0.5, 1.0, -3.0, 1.5], +/// # [0.1, 0.4, 0.3, -0.8] +/// # ], +/// # [ +/// # [-2.0, 1.0, 0.6, 0.4], +/// # [2.6, -7.1, 1.4, 3.1], +/// # [5.0, 1.0, -8.0, 2.0], +/// # [1.4, 0.4, 
0.2, -2.0] +/// # ], +/// # [ +/// # [-3.0, 1.0, 1.5, 0.5], +/// # [3.0, -6.0, 1.0, 2.0], +/// # [0.3, 0.5, -1.9, 1.1], +/// # [5.0, 1.0, 2.0, -8.0] +/// # ], +/// # [ +/// # [-2.6, 0.6, 0.2, 1.8], +/// # [2.0, -6.0, 3.0, 1.0], +/// # [0.1, 0.5, -1.3, 0.7], +/// # [0.8, 0.6, 0.2, -1.6] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # // Generate the trajectory +/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873)); +/// +/// // Initialize the hypothesis tests to pass to the CTPC with their +/// // respective significance level `alpha` +/// let f = F::new(1e-6); +/// let chi_sq = ChiSquare::new(1e-4); +/// // Use the bayesian approach to learn the parameters +/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; +/// +/// //Initialize CTPC +/// let ctpc = CTPC::new(parameter_learning, f, chi_sq); +/// +/// // Learn the structure of the network from the generated trajectory +/// let net = ctpc.fit_transform(net, &data); +/// # +/// # // Compare the generated network with the original one +/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); +/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +/// ``` pub struct CTPC { parameter_learning: P, Ftest: F, From c2df26c3e6835d87f8ea2a47fefad29cffe26fba Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 18 Jan 2023 14:18:24 +0100 Subject: [PATCH 2/2] Added docstrings for the F-test and removed some comments --- .../src/structure_learning/hypothesis_test.rs | 66 +++++++------------ 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 4c02929..311ec47 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -39,6 +39,17 @@ pub struct ChiSquare { alpha: f64, } +/// Does the F-test. 
+/// +/// Used to determine if a difference between two sets of data is due to chance, or if it is due to +/// a relationship (dependence) between the variables. +/// +/// # Arguments +/// +/// * `alpha` - is the significance level, the probability to reject a true null hypothesis; +/// in other words is the risk of concluding that an association between the variables exists +/// when there is no actual association. + pub struct F { alpha: f64, } @@ -48,6 +59,20 @@ impl F { F { alpha } } + /// Compare two matrices extracted from two 3rd-order tensors. + /// + /// # Arguments + /// + /// * `i` - Position of the matrix of `M1` to compare with `M2`. + /// * `M1` - 3rd-order tensor 1. + /// * `j` - Position of the matrix of `M2` to compare with `M1`. + /// * `M2` - 3rd-order tensor 2. + /// + /// # Returns + /// + /// * `true` - when the matrices `M1` and `M2` are very similar, then **independent**. + /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**. + pub fn compare_matrices( &self, i: usize, M1: &Array3<usize>, j: usize, M2: &Array3<usize>, @@ -164,26 +189,8 @@ impl ChiSquare { // continuous-time Bayesian networks. // International Journal of Approximate Reasoning, 138, pp.105-122. // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm - // - // M = M M = M - // 1 xx'|s 2 xx'|y,s let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); - // __________________ - // / === - // / \ M - // / / xx'|s - // / === - // / x'ϵVal /X \ - // / \ i/ 1 - //K = / ------------------ L = - - // / === K - // / \ M - // / / xx'|y,s - // / === - // / x'ϵVal /X \ - // \ / \ i/ - // \/ let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); let K = K.mapv(f64::sqrt); // Reshape to column vector. let K = { let n = K.len(); K.into_shape((n, 1)).unwrap() }; - //println!("K: {:?}", K); let L = 1.0 / &K; - // ===== 2 - // \ (K . M - L . 
M) - // \ 2 1 - // / --------------- - // / M + M - // ===== 2 1 - // x'ϵVal /X \ - // \ i/ let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); - //println!("M1: {:?}", M1); - //println!("M2: {:?}", M2); - //println!("L*M1: {:?}", (L * &M1)); - //println!("K*M2: {:?}", (K * &M2)); - //println!("X_2: {:?}", X_2); X_2.diag_mut().fill(0.0); let X_2 = X_2.sum_axis(Axis(1)); let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); - //println!("CHI^2: {:?}", n); - //println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)); - //println!("test: {:?}", ret); ret } } -// ritorna false quando sono dipendenti e false quando sono indipendenti impl HypothesisTest for ChiSquare { fn call( &self, @@ -233,13 +222,9 @@ impl HypothesisTest for ChiSquare { T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { - // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM - // di dimensione nxn - // (CIM, M, T) let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { Params::DiscreteStatesContinousTime(node) => node, }; - // let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); @@ -251,7 +236,6 @@ impl HypothesisTest for ChiSquare { ) { Params::DiscreteStatesContinousTime(node) => node, }; - // Commentare qui let partial_cardinality_product: usize = extended_separation_set .iter() .take_while(|x| **x != parent_node)