|
|
|
@ -1,6 +1,7 @@ |
|
|
|
|
use ndarray::Array2; |
|
|
|
|
use ndarray::Array3; |
|
|
|
|
use ndarray::Axis; |
|
|
|
|
use statrs::distribution::{ChiSquared, ContinuousCDF}; |
|
|
|
|
|
|
|
|
|
use crate::network; |
|
|
|
|
use crate::parameter_learning; |
|
|
|
@ -24,7 +25,7 @@ pub trait HypothesisTest { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub struct ChiSquare { |
|
|
|
|
pub alpha: f64, |
|
|
|
|
alpha: f64, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct F { |
|
|
|
@ -32,6 +33,11 @@ pub struct F { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl ChiSquare { |
|
|
|
|
pub fn new( alpha: f64) -> ChiSquare { |
|
|
|
|
ChiSquare { |
|
|
|
|
alpha |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
pub fn compare_matrices( |
|
|
|
|
&self, i: usize, |
|
|
|
|
M1: &Array3<usize>, |
|
|
|
@ -42,6 +48,7 @@ impl ChiSquare { |
|
|
|
|
// A constraint-based algorithm for the structural learning of
|
|
|
|
|
// continuous-time Bayesian networks.
|
|
|
|
|
// International Journal of Approximate Reasoning, 138, pp.105-122.
|
|
|
|
|
// Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
|
|
|
|
|
//
|
|
|
|
|
// M = M M = M
|
|
|
|
|
// 1 xx'|s 2 xx'|y,s
|
|
|
|
@ -70,17 +77,26 @@ impl ChiSquare { |
|
|
|
|
K.into_shape((n, 1)).unwrap() |
|
|
|
|
}; |
|
|
|
|
let L = 1.0 / &K; |
|
|
|
|
// =====
|
|
|
|
|
// \ K . M - L . M
|
|
|
|
|
// ===== 2
|
|
|
|
|
// \ (K . M - L . M)
|
|
|
|
|
// \ 2 1
|
|
|
|
|
// / ---------------
|
|
|
|
|
// / M + M
|
|
|
|
|
// ===== 2 1
|
|
|
|
|
// x'ϵVal /X \
|
|
|
|
|
// \ i/
|
|
|
|
|
let X_2 = (( K * &M2 - L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1)).sum_axis(Axis(1)); |
|
|
|
|
let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1); |
|
|
|
|
println!("M1: {:?}", M1); |
|
|
|
|
println!("M2: {:?}", M2); |
|
|
|
|
println!("L*M1: {:?}", (L * &M1)); |
|
|
|
|
println!("K*M2: {:?}", (K * &M2)); |
|
|
|
|
println!("X_2: {:?}", X_2); |
|
|
|
|
true |
|
|
|
|
X_2.diag_mut().fill(0.0); |
|
|
|
|
let X_2 = X_2.sum_axis(Axis(1)); |
|
|
|
|
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); |
|
|
|
|
println!("CHI^2: {:?}", n); |
|
|
|
|
println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); |
|
|
|
|
X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|