From 8953471570a37695915e4d7a58ac97723198b953 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 5 Aug 2022 21:21:45 +0200 Subject: [PATCH] Added tests to chi square call function and added a constructor to cache --- src/parameter_learning.rs | 6 ++++++ src/structure_learning/hypothesis_test.rs | 8 +++++++- tests/structure_learning.rs | 21 +++++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 10f0257..bdb5d4a 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -169,6 +169,12 @@ pub struct Cache { } impl Cache

{ + pub fn new(parameter_learning: P, dataset: tools::Dataset) -> Cache

{ + Cache { + parameter_learning, + dataset, + } + } pub fn fit( &mut self, net: &T, diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs index 5ddcc51..4f2ce18 100644 --- a/src/structure_learning/hypothesis_test.rs +++ b/src/structure_learning/hypothesis_test.rs @@ -30,6 +30,8 @@ impl ChiSquare { pub fn new(alpha: f64) -> ChiSquare { ChiSquare { alpha } } + // Restituisce true quando le matrici sono molto simili, quindi indipendenti + // false quando sono diverse, quindi dipendenti pub fn compare_matrices( &self, i: usize, @@ -69,6 +71,7 @@ impl ChiSquare { let n = K.len(); K.into_shape((n, 1)).unwrap() }; + println!("K: {:?}", K); let L = 1.0 / &K; // ===== 2 // \ (K . M - L . M) @@ -89,10 +92,13 @@ impl ChiSquare { let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); println!("CHI^2: {:?}", n); println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); - X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)) + let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)); + println!("test: {:?}", ret); + ret } } +// ritorna false quando sono dipendenti e false quando sono indipendenti impl HypothesisTest for ChiSquare { fn call( &self, diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 81a4ed3..a1667c2 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -6,6 +6,8 @@ use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; use reCTBN::ctbn::*; use reCTBN::network::Network; +use reCTBN::parameter_learning::BayesianApproach; +use reCTBN::parameter_learning::Cache; use reCTBN::params; use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::structure_learning::score_based_algorithm::*; @@ -455,3 +457,22 @@ pub fn chi_square_compare_matrices_3() { let chi_sq = ChiSquare::new(0.1); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } + + +#[test] +pub fn chi_square_call() { + + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let N3: usize = 2; + let N2: usize = 1; + let N1: usize = 0; + let separation_set = BTreeSet::new(); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let mut cache = Cache::new(parameter_learning, data); + let chi_sq = ChiSquare::new(0.0001); + + assert!(chi_sq.call(&net, N1, N3, &separation_set, &mut cache)); + assert!(!chi_sq.call(&net, N3, N1, &separation_set, &mut cache)); + assert!(!chi_sq.call(&net, N3, N2, &separation_set, &mut cache)); + assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache)); +} -- 2.36.3