Added docstrings for the F-test and removed some comments

pull/83/head
Meliurwen 2 years ago
parent 7ec56914d9
commit c2df26c3e6
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 66
      reCTBN/src/structure_learning/hypothesis_test.rs

@ -39,6 +39,17 @@ pub struct ChiSquare {
alpha: f64, alpha: f64,
} }
/// Does the F-test.
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - the significance level, i.e. the probability of rejecting a true null hypothesis;
/// in other words, the risk of concluding that an association between the variables exists
/// when there is no actual association.
pub struct F { pub struct F {
alpha: f64, alpha: f64,
} }
@ -48,6 +59,20 @@ impl F {
F { alpha } F { alpha }
} }
/// Compare two matrices extracted from two 3rd-order tensors.
///
/// # Arguments
///
/// * `i` - Position of the matrix of `M1` to compare with `M2`.
/// * `M1` - 3rd-order tensor 1.
/// * `j` - Position of the matrix of `M2` to compare with `M1`.
/// * `M2` - 3rd-order tensor 2.
///
/// # Returns
///
/// * `true` - when the matrices extracted from `M1` and `M2` are very similar, hence **independent**.
/// * `false` - when the matrices extracted from `M1` and `M2` are too different, hence **dependent**.
pub fn compare_matrices( pub fn compare_matrices(
&self, &self,
i: usize, i: usize,
@ -164,26 +189,8 @@ impl ChiSquare {
// continuous-time Bayesian networks. // continuous-time Bayesian networks.
// International Journal of Approximate Reasoning, 138, pp.105-122. // International Journal of Approximate Reasoning, 138, pp.105-122.
// Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
//
// M_1 = M_{xx'|s}, \quad M_2 = M_{xx'|y,s}
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
// K = \sqrt{\frac{\sum_{x' \in Val(X_i)} M_{xx'|s}}{\sum_{x' \in Val(X_i)} M_{xx'|y,s}}},
// \quad L = \frac{1}{K}
let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
let K = K.mapv(f64::sqrt); let K = K.mapv(f64::sqrt);
// Reshape to column vector. // Reshape to column vector.
@ -191,34 +198,16 @@ impl ChiSquare {
let n = K.len(); let n = K.len();
K.into_shape((n, 1)).unwrap() K.into_shape((n, 1)).unwrap()
}; };
//println!("K: {:?}", K);
let L = 1.0 / &K; let L = 1.0 / &K;
// \chi^2 = \sum_{x' \in Val(X_i)} \frac{(K \cdot M_2 - L \cdot M_1)^2}{M_2 + M_1}
let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
//println!("M1: {:?}", M1);
//println!("M2: {:?}", M2);
//println!("L*M1: {:?}", (L * &M1));
//println!("K*M2: {:?}", (K * &M2));
//println!("X_2: {:?}", X_2);
X_2.diag_mut().fill(0.0); X_2.diag_mut().fill(0.0);
let X_2 = X_2.sum_axis(Axis(1)); let X_2 = X_2.sum_axis(Axis(1));
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
//println!("CHI^2: {:?}", n);
//println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)); let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
//println!("test: {:?}", ret);
ret ret
} }
} }
// returns false when they are dependent and true when they are independent
impl HypothesisTest for ChiSquare { impl HypothesisTest for ChiSquare {
fn call<T, P>( fn call<T, P>(
&self, &self,
@ -233,13 +222,9 @@ impl HypothesisTest for ChiSquare {
T: process::NetworkProcess, T: process::NetworkProcess,
P: parameter_learning::ParameterLearning, P: parameter_learning::ParameterLearning,
{ {
// Fetch the learned parameters from the cache, which would be a CIM
// of size n x n
// (CIM, M, T)
let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node, Params::DiscreteStatesContinousTime(node) => node,
}; };
//
let mut extended_separation_set = separation_set.clone(); let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node); extended_separation_set.insert(parent_node);
@ -251,7 +236,6 @@ impl HypothesisTest for ChiSquare {
) { ) {
Params::DiscreteStatesContinousTime(node) => node, Params::DiscreteStatesContinousTime(node) => node,
}; };
// TODO: add a comment here
let partial_cardinality_product: usize = extended_separation_set let partial_cardinality_product: usize = extended_separation_set
.iter() .iter()
.take_while(|x| **x != parent_node) .take_while(|x| **x != parent_node)

Loading…
Cancel
Save