Merge branch '8-feature-constraint-based-structure-learning-algorithm-for-ctbn' into 'dev'

Added chi-squared test rust-tests and constructor for cache
pull/67/head
Meliurwen 2 years ago
commit c247da6bc0
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 6
      reCTBN/src/parameter_learning.rs
  2. 8
      reCTBN/src/structure_learning/hypothesis_test.rs
  3. 21
      reCTBN/tests/structure_learning.rs

@ -169,6 +169,12 @@ pub struct Cache<P: ParameterLearning> {
}
impl<P: ParameterLearning> Cache<P> {
pub fn new(parameter_learning: P, dataset: tools::Dataset) -> Cache<P> {
Cache {
parameter_learning,
dataset,
}
}
pub fn fit<T: network::Network>(
&mut self,
net: &T,

@ -30,6 +30,8 @@ impl ChiSquare {
pub fn new(alpha: f64) -> ChiSquare {
ChiSquare { alpha }
}
// Restituisce true quando le matrici sono molto simili, quindi indipendenti
// false quando sono diverse, quindi dipendenti
pub fn compare_matrices(
&self,
i: usize,
@ -69,6 +71,7 @@ impl ChiSquare {
let n = K.len();
K.into_shape((n, 1)).unwrap()
};
println!("K: {:?}", K);
let L = 1.0 / &K;
// ===== 2
// \ (K . M - L . M)
@ -89,10 +92,13 @@ impl ChiSquare {
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
println!("CHI^2: {:?}", n);
println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha))
let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
println!("test: {:?}", ret);
ret
}
}
// ritorna false quando sono dipendenti e false quando sono indipendenti
impl HypothesisTest for ChiSquare {
fn call<T, P>(
&self,

@ -6,6 +6,8 @@ use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::parameter_learning::BayesianApproach;
use reCTBN::parameter_learning::Cache;
use reCTBN::params;
use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::score_based_algorithm::*;
@ -455,3 +457,22 @@ pub fn chi_square_compare_matrices_3() {
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
}
#[test]
pub fn chi_square_call() {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let N3: usize = 2;
let N2: usize = 1;
let N1: usize = 0;
let separation_set = BTreeSet::new();
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let mut cache = Cache::new(parameter_learning, data);
let chi_sq = ChiSquare::new(0.0001);
assert!(chi_sq.call(&net, N1, N3, &separation_set, &mut cache));
assert!(!chi_sq.call(&net, N3, N1, &separation_set, &mut cache));
assert!(!chi_sq.call(&net, N3, N2, &separation_set, &mut cache));
assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache));
}

Loading…
Cancel
Save