Implemented matrices comparison function in chi square

pull/47/head
Meliurwen 2 years ago
parent 2605bf3816
commit 4b35ae6310
  1. 2
      Cargo.toml
  2. 28
      src/structure_learning/hypothesis_test.rs
  3. 80
      tests/structure_learning.rs

@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
ndarray = {version="*", features=["approx-0_5"]} ndarray = {version="*", features=["approx"]}
thiserror = "*" thiserror = "*"
rand = "*" rand = "*"
bimap = "*" bimap = "*"

@ -1,6 +1,7 @@
use ndarray::Array2; use ndarray::Array2;
use ndarray::Array3; use ndarray::Array3;
use ndarray::Axis; use ndarray::Axis;
use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::network; use crate::network;
use crate::parameter_learning; use crate::parameter_learning;
@ -24,7 +25,7 @@ pub trait HypothesisTest {
pub struct ChiSquare { pub struct ChiSquare {
pub alpha: f64, alpha: f64,
} }
pub struct F { pub struct F {
@ -32,6 +33,11 @@ pub struct F {
} }
impl ChiSquare { impl ChiSquare {
pub fn new( alpha: f64) -> ChiSquare {
ChiSquare {
alpha
}
}
pub fn compare_matrices( pub fn compare_matrices(
&self, i: usize, &self, i: usize,
M1: &Array3<usize>, M1: &Array3<usize>,
@ -42,6 +48,7 @@ impl ChiSquare {
// A constraint-based algorithm for the structural learning of // A constraint-based algorithm for the structural learning of
// continuous-time Bayesian networks. // continuous-time Bayesian networks.
// International Journal of Approximate Reasoning, 138, pp.105-122. // International Journal of Approximate Reasoning, 138, pp.105-122.
// Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
// //
// M = M M = M // M = M M = M
// 1 xx'|s 2 xx'|y,s // 1 xx'|s 2 xx'|y,s
@ -70,17 +77,26 @@ impl ChiSquare {
K.into_shape((n, 1)).unwrap() K.into_shape((n, 1)).unwrap()
}; };
let L = 1.0 / &K; let L = 1.0 / &K;
// ===== // ===== 2
// \ K . M - L . M // \ (K . M - L . M)
// \ 2 1 // \ 2 1
// / --------------- // / ---------------
// / M + M // / M + M
// ===== 2 1 // ===== 2 1
// x'ϵVal /X \ // x'ϵVal /X \
// \ i/ // \ i/
let X_2 = (( K * &M2 - L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1)).sum_axis(Axis(1)); let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1);
println!("X_2: {:?}", X_2); println!("M1: {:?}", M1);
true println!("M2: {:?}", M2);
println!("L*M1: {:?}", (L * &M1));
println!("K*M2: {:?}", (K * &M2));
println!("X_2: {:?}", X_2);
X_2.diag_mut().fill(0.0);
let X_2 = X_2.sum_axis(Axis(1));
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
println!("CHI^2: {:?}", n);
println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha))
} }
} }

@ -321,20 +321,70 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
pub fn chi_square_compare_matrices () { pub fn chi_square_compare_matrices () {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[ 1, 2, 3], [[ 0, 2, 3],
[ 4, 5, 6]], [ 4, 0, 6],
[[ 22, 12, 90], [ 7, 8, 0]],
[3, 20, 40]], [[0, 12, 90],
[[ 1, 2, 3], [ 3, 0, 40],
[ 4, 5, 6]], [ 6, 40, 0]],
[[ 7, 8, 9], [[ 0, 2, 3],
[10, 11, 12]] [ 4, 0, 6],
[ 44, 66, 0]]
]); ]);
let j: usize = 1; let j: usize = 0;
let M2 = arr3(&[[[ 1, 2, 3], // -- 2 rows \_ let M2 = arr3(&[
[ 4, 5, 6]], [[ 0, 200, 300],
[[ 7, 8, 9], [ 400, 0, 600],
[10, 11, 12]]]); [ 700, 800, 0]]
let chi_sq = ChiSquare {alpha: 0.5}; ]);
chi_sq.compare_matrices( i, &M1, j, &M2); let chi_sq = ChiSquare::new(0.1);
assert!(!chi_sq.compare_matrices( i, &M1, j, &M2));
}
#[test]
pub fn chi_square_compare_matrices_2 () {
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 20, 30],
[ 40, 0, 60],
[ 70, 80, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2));
}
#[test]
pub fn chi_square_compare_matrices_3 () {
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 21, 31],
[ 41, 0, 59],
[ 71, 79, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2));
} }

Loading…
Cancel
Save