Implemented matrices comparison function in chi square

pull/47/head
Meliurwen 2 years ago
parent 2605bf3816
commit 4b35ae6310
  1. 2
      Cargo.toml
  2. 26
      src/structure_learning/hypothesis_test.rs
  3. 80
      tests/structure_learning.rs

@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ndarray = {version="*", features=["approx-0_5"]}
ndarray = {version="*", features=["approx"]}
thiserror = "*"
rand = "*"
bimap = "*"

@ -1,6 +1,7 @@
use ndarray::Array2;
use ndarray::Array3;
use ndarray::Axis;
use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::network;
use crate::parameter_learning;
@ -24,7 +25,7 @@ pub trait HypothesisTest {
pub struct ChiSquare {
pub alpha: f64,
alpha: f64,
}
pub struct F {
@ -32,6 +33,11 @@ pub struct F {
}
impl ChiSquare {
pub fn new( alpha: f64) -> ChiSquare {
ChiSquare {
alpha
}
}
pub fn compare_matrices(
&self, i: usize,
M1: &Array3<usize>,
@ -42,6 +48,7 @@ impl ChiSquare {
// A constraint-based algorithm for the structural learning of
// continuous-time Bayesian networks.
// International Journal of Approximate Reasoning, 138, pp.105-122.
// Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
//
// M = M M = M
// 1 xx'|s 2 xx'|y,s
@ -70,17 +77,26 @@ impl ChiSquare {
K.into_shape((n, 1)).unwrap()
};
let L = 1.0 / &K;
// =====
// \ K . M - L . M
// ===== 2
// \ (K . M - L . M)
// \ 2 1
// / ---------------
// / M + M
// ===== 2 1
// x'ϵVal /X \
// \ i/
let X_2 = (( K * &M2 - L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1)).sum_axis(Axis(1));
let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1);
println!("M1: {:?}", M1);
println!("M2: {:?}", M2);
println!("L*M1: {:?}", (L * &M1));
println!("K*M2: {:?}", (K * &M2));
println!("X_2: {:?}", X_2);
true
X_2.diag_mut().fill(0.0);
let X_2 = X_2.sum_axis(Axis(1));
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
println!("CHI^2: {:?}", n);
println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha))
}
}

@ -321,20 +321,70 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
pub fn chi_square_compare_matrices () {
let i: usize = 1;
let M1 = arr3(&[
[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 22, 12, 90],
[3, 20, 40]],
[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]
[[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 12, 90],
[ 3, 0, 40],
[ 6, 40, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]);
let j: usize = 1;
let M2 = arr3(&[[[ 1, 2, 3], // -- 2 rows \_
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]]);
let chi_sq = ChiSquare {alpha: 0.5};
chi_sq.compare_matrices( i, &M1, j, &M2);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1);
assert!(!chi_sq.compare_matrices( i, &M1, j, &M2));
}
#[test]
pub fn chi_square_compare_matrices_2 () {
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 20, 30],
[ 40, 0, 60],
[ 70, 80, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2));
}
#[test]
pub fn chi_square_compare_matrices_3 () {
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 21, 31],
[ 41, 0, 59],
[ 71, 79, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2));
}

Loading…
Cancel
Save