Added tests to chi square call function and added a constructor to cache

pull/64/head
Meliurwen 2 years ago
parent 68b2afff59
commit 8953471570
  1. 6
      src/parameter_learning.rs
  2. 8
      src/structure_learning/hypothesis_test.rs
  3. 21
      tests/structure_learning.rs

@ -169,6 +169,12 @@ pub struct Cache<P: ParameterLearning> {
} }
impl<P: ParameterLearning> Cache<P> { impl<P: ParameterLearning> Cache<P> {
pub fn new(parameter_learning: P, dataset: tools::Dataset) -> Cache<P> {
Cache {
parameter_learning,
dataset,
}
}
pub fn fit<T: network::Network>( pub fn fit<T: network::Network>(
&mut self, &mut self,
net: &T, net: &T,

@ -30,6 +30,8 @@ impl ChiSquare {
pub fn new(alpha: f64) -> ChiSquare { pub fn new(alpha: f64) -> ChiSquare {
ChiSquare { alpha } ChiSquare { alpha }
} }
// Restituisce true quando le matrici sono molto simili, quindi indipendenti
// false quando sono diverse, quindi dipendenti
pub fn compare_matrices( pub fn compare_matrices(
&self, &self,
i: usize, i: usize,
@ -69,6 +71,7 @@ impl ChiSquare {
let n = K.len(); let n = K.len();
K.into_shape((n, 1)).unwrap() K.into_shape((n, 1)).unwrap()
}; };
println!("K: {:?}", K);
let L = 1.0 / &K; let L = 1.0 / &K;
// ===== 2 // ===== 2
// \ (K . M - L . M) // \ (K . M - L . M)
@ -89,10 +92,13 @@ impl ChiSquare {
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
println!("CHI^2: {:?}", n); println!("CHI^2: {:?}", n);
println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)) let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
println!("test: {:?}", ret);
ret
} }
} }
// ritorna false quando sono dipendenti e false quando sono indipendenti
impl HypothesisTest for ChiSquare { impl HypothesisTest for ChiSquare {
fn call<T, P>( fn call<T, P>(
&self, &self,

@ -6,6 +6,8 @@ use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3}; use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::parameter_learning::BayesianApproach;
use reCTBN::parameter_learning::Cache;
use reCTBN::params; use reCTBN::params;
use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::score_based_algorithm::*; use reCTBN::structure_learning::score_based_algorithm::*;
@ -455,3 +457,22 @@ pub fn chi_square_compare_matrices_3() {
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
} }
#[test]
pub fn chi_square_call() {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let N3: usize = 2;
let N2: usize = 1;
let N1: usize = 0;
let separation_set = BTreeSet::new();
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let mut cache = Cache::new(parameter_learning, data);
let chi_sq = ChiSquare::new(0.0001);
assert!(chi_sq.call(&net, N1, N3, &separation_set, &mut cache));
assert!(!chi_sq.call(&net, N3, N1, &separation_set, &mut cache));
assert!(!chi_sq.call(&net, N3, N2, &separation_set, &mut cache));
assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache));
}

Loading…
Cancel
Save