Merge branch '82-feature-create-doctest-for-ctpc' into 'dev'

Added doctest for CTPC, some docstrings for F-test and removed some comments in `structure_learning/hypothesis_test.rs`
88-feature-add-benchmarks
Meliurwen 2 years ago
commit 49c2c55f61
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 184
      reCTBN/src/structure_learning/constraint_based_algorithm.rs
  2. 66
      reCTBN/src/structure_learning/hypothesis_test.rs

@ -79,6 +79,190 @@ impl<'a, P: ParameterLearning> Cache<'a, P> {
} }
} }
/// Continuous-Time Peter Clark algorithm.
///
/// A method to learn the structure of the network.
///
/// # Arguments
///
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test.
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test.
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// use reCTBN::parameter_learning::BayesianApproach;
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
/// #
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("A"));
/// # domain.insert(String::from("B"));
/// # domain.insert(String::from("C"));
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
/// # //Create the node n1 using the parameters
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("D"));
/// # domain.insert(String::from("E"));
/// # domain.insert(String::from("F"));
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("G"));
/// # domain.insert(String::from("H"));
/// # domain.insert(String::from("I"));
/// # domain.insert(String::from("F"));
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # // Initialize a ctbn
/// # let mut net = CtbnNetwork::new();
/// #
/// # // Add the nodes and their edges
/// # let n1 = net.add_node(n1).unwrap();
/// # let n2 = net.add_node(n2).unwrap();
/// # let n3 = net.add_node(n3).unwrap();
/// # net.add_edge(n1, n2);
/// # net.add_edge(n1, n3);
/// # net.add_edge(n2, n3);
/// #
/// # match &mut net.get_node_mut(n1) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-3.0, 2.0, 1.0],
/// # [1.5, -2.0, 0.5],
/// # [0.4, 0.6, -1.0]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # match &mut net.get_node_mut(n2) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-1.0, 0.5, 0.5],
/// # [3.0, -4.0, 1.0],
/// # [0.9, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 2.0, 4.0],
/// # [1.5, -2.0, 0.5],
/// # [3.0, 1.0, -4.0]
/// # ],
/// # [
/// # [-1.0, 0.1, 0.9],
/// # [2.0, -2.5, 0.5],
/// # [0.9, 0.1, -1.0]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # match &mut net.get_node_mut(n3) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-1.0, 0.5, 0.3, 0.2],
/// # [0.5, -4.0, 2.5, 1.0],
/// # [2.5, 0.5, -4.0, 1.0],
/// # [0.7, 0.2, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 2.0, 3.0, 1.0],
/// # [1.5, -3.0, 0.5, 1.0],
/// # [2.0, 1.3, -5.0, 1.7],
/// # [2.5, 0.5, 1.0, -4.0]
/// # ],
/// # [
/// # [-1.3, 0.3, 0.1, 0.9],
/// # [1.4, -4.0, 0.5, 2.1],
/// # [1.0, 1.5, -3.0, 0.5],
/// # [0.4, 0.3, 0.1, -0.8]
/// # ],
/// # [
/// # [-2.0, 1.0, 0.7, 0.3],
/// # [1.3, -5.9, 2.7, 1.9],
/// # [2.0, 1.5, -4.0, 0.5],
/// # [0.2, 0.7, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 1.0, 2.0, 3.0],
/// # [0.5, -3.0, 1.0, 1.5],
/// # [1.4, 2.1, -4.3, 0.8],
/// # [0.5, 1.0, 2.5, -4.0]
/// # ],
/// # [
/// # [-1.3, 0.9, 0.3, 0.1],
/// # [0.1, -1.3, 0.2, 1.0],
/// # [0.5, 1.0, -3.0, 1.5],
/// # [0.1, 0.4, 0.3, -0.8]
/// # ],
/// # [
/// # [-2.0, 1.0, 0.6, 0.4],
/// # [2.6, -7.1, 1.4, 3.1],
/// # [5.0, 1.0, -8.0, 2.0],
/// # [1.4, 0.4, 0.2, -2.0]
/// # ],
/// # [
/// # [-3.0, 1.0, 1.5, 0.5],
/// # [3.0, -6.0, 1.0, 2.0],
/// # [0.3, 0.5, -1.9, 1.1],
/// # [5.0, 1.0, 2.0, -8.0]
/// # ],
/// # [
/// # [-2.6, 0.6, 0.2, 1.8],
/// # [2.0, -6.0, 3.0, 1.0],
/// # [0.1, 0.5, -1.3, 0.7],
/// # [0.8, 0.6, 0.2, -1.6]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # // Generate the trajectory
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
///
/// // Initialize the hypothesis tests to pass to the CTPC with their
/// // respective significance level `alpha`
/// let f = F::new(1e-6);
/// let chi_sq = ChiSquare::new(1e-4);
/// // Use the bayesian approach to learn the parameters
/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
///
/// //Initialize CTPC
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
///
/// // Learn the structure of the network from the generated trajectory
/// let net = ctpc.fit_transform(net, &data);
/// #
/// # // Compare the generated network with the original one
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
/// ```
pub struct CTPC<P: ParameterLearning> { pub struct CTPC<P: ParameterLearning> {
parameter_learning: P, parameter_learning: P,
Ftest: F, Ftest: F,

@ -39,6 +39,17 @@ pub struct ChiSquare {
alpha: f64, alpha: f64,
} }
/// Does the F-test.
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
/// in other words is the risk of concluding that an association between the variables exists
/// when there is no actual association.
pub struct F { pub struct F {
alpha: f64, alpha: f64,
} }
@ -48,6 +59,20 @@ impl F {
F { alpha } F { alpha }
} }
/// Compare two matrices extracted from two 3rd-orer tensors.
///
/// # Arguments
///
/// * `i` - Position of the matrix of `M1` to compare with `M2`.
/// * `M1` - 3rd-order tensor 1.
/// * `j` - Position of the matrix of `M2` to compare with `M1`.
/// * `M2` - 3rd-order tensor 2.
///
/// # Returns
///
/// * `true` - when the matrices `M1` and `M2` are very similar, then **independendent**.
/// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
pub fn compare_matrices( pub fn compare_matrices(
&self, &self,
i: usize, i: usize,
@ -164,26 +189,8 @@ impl ChiSquare {
// continuous-time Bayesian networks. // continuous-time Bayesian networks.
// International Journal of Approximate Reasoning, 138, pp.105-122. // International Journal of Approximate Reasoning, 138, pp.105-122.
// Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
//
// M = M M = M
// 1 xx'|s 2 xx'|y,s
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
// __________________
// / ===
// / \ M
// / / xx'|s
// / ===
// / x'ϵVal /X \
// / \ i/ 1
//K = / ------------------ L = -
// / === K
// / \ M
// / / xx'|y,s
// / ===
// / x'ϵVal /X \
// \ / \ i/
// \/
let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
let K = K.mapv(f64::sqrt); let K = K.mapv(f64::sqrt);
// Reshape to column vector. // Reshape to column vector.
@ -191,34 +198,16 @@ impl ChiSquare {
let n = K.len(); let n = K.len();
K.into_shape((n, 1)).unwrap() K.into_shape((n, 1)).unwrap()
}; };
//println!("K: {:?}", K);
let L = 1.0 / &K; let L = 1.0 / &K;
// ===== 2
// \ (K . M - L . M)
// \ 2 1
// / ---------------
// / M + M
// ===== 2 1
// x'ϵVal /X \
// \ i/
let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
//println!("M1: {:?}", M1);
//println!("M2: {:?}", M2);
//println!("L*M1: {:?}", (L * &M1));
//println!("K*M2: {:?}", (K * &M2));
//println!("X_2: {:?}", X_2);
X_2.diag_mut().fill(0.0); X_2.diag_mut().fill(0.0);
let X_2 = X_2.sum_axis(Axis(1)); let X_2 = X_2.sum_axis(Axis(1));
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
//println!("CHI^2: {:?}", n);
//println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)); let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
//println!("test: {:?}", ret);
ret ret
} }
} }
// ritorna false quando sono dipendenti e false quando sono indipendenti
impl HypothesisTest for ChiSquare { impl HypothesisTest for ChiSquare {
fn call<T, P>( fn call<T, P>(
&self, &self,
@ -233,13 +222,9 @@ impl HypothesisTest for ChiSquare {
T: process::NetworkProcess, T: process::NetworkProcess,
P: parameter_learning::ParameterLearning, P: parameter_learning::ParameterLearning,
{ {
// Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM
// di dimensione nxn
// (CIM, M, T)
let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node, Params::DiscreteStatesContinousTime(node) => node,
}; };
//
let mut extended_separation_set = separation_set.clone(); let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node); extended_separation_set.insert(parent_node);
@ -251,7 +236,6 @@ impl HypothesisTest for ChiSquare {
) { ) {
Params::DiscreteStatesContinousTime(node) => node, Params::DiscreteStatesContinousTime(node) => node,
}; };
// Commentare qui
let partial_cardinality_product: usize = extended_separation_set let partial_cardinality_product: usize = extended_separation_set
.iter() .iter()
.take_while(|x| **x != parent_node) .take_while(|x| **x != parent_node)

Loading…
Cancel
Save