Merge pull request #47 from AlessandroBregoli/8-feature-constraint-based-structure-learning-algorithm-for-ctbn

Added chi-square
Meliurwen authored 2 years ago · committed by GitHub
commit e091cc4d2e
  1. Cargo.toml (1 line changed)
  2. src/parameter_learning.rs (17 lines changed)
  3. src/structure_learning.rs (2 lines changed)
  4. src/structure_learning/constraint_based_algorithm.rs (5 lines changed)
  5. src/structure_learning/hypothesis_test.rs (135 lines changed)
  6. tests/structure_learning.rs (73 lines changed)

@@ -6,7 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ndarray = {version="*", features=["approx"]}
thiserror = "*"
rand = "*"

@@ -153,3 +153,20 @@ impl ParameterLearning for BayesianApproach {
return (CIM, M, T);
}
}
pub struct Cache<P: ParameterLearning> {
parameter_learning: P,
dataset: tools::Dataset,
}
impl<P: ParameterLearning> Cache<P> {
pub fn fit<T:network::Network>(
&mut self,
net: &T,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
self.parameter_learning.fit(net, &self.dataset, node, parent_set)
}
}
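
The new Cache couples a ParameterLearning implementation with a dataset so the hypothesis tests below can request sufficient statistics through a single entry point (no memoization is implemented yet in this commit). A minimal sketch of how it could be driven, assuming some `learner` implementing ParameterLearning plus a `net` and `dataset` that are not part of this diff; the fields are private and no constructor is added here, so the struct literal only works inside the parameter_learning module:

    // Sketch only: `learner`, `net` and `dataset` are hypothetical values.
    let mut cache = Cache { parameter_learning: learner, dataset };
    // (CIM, M, T) for node 0 given the candidate parent set {1, 2}.
    let (cim, m, t) = cache.fit(&net, 0, Some(BTreeSet::from([1, 2])));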

@@ -1,5 +1,7 @@
pub mod score_function;
pub mod score_based_algorithm;
pub mod constraint_based_algorithm;
pub mod hypothesis_test;
use crate::network;
use crate::tools;

@@ -0,0 +1,135 @@
use ndarray::Array2;
use ndarray::Array3;
use ndarray::Axis;
use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::network;
use crate::parameter_learning;
use crate::params::ParamsTrait;
use std::collections::BTreeSet;
pub trait HypothesisTest {
fn call<T, P>(
&self,
net: &T,
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
cache: &mut parameter_learning::Cache<P>
) -> bool
where
T: network::Network,
P: parameter_learning::ParameterLearning;
}
pub struct ChiSquare {
alpha: f64,
}
// Placeholder, presumably for an F-test; not implemented in this commit.
pub struct F {}
impl ChiSquare {
pub fn new( alpha: f64) -> ChiSquare {
ChiSquare {
alpha
}
}
pub fn compare_matrices(
&self,
i: usize,
M1: &Array3<usize>,
j: usize,
M2: &Array3<usize>
) -> bool {
// Bregoli, A., Scutari, M. and Stella, F., 2021.
// A constraint-based algorithm for the structural learning of
// continuous-time Bayesian networks.
// International Journal of Approximate Reasoning, 138, pp.105-122.
// Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
//
// M1 = M_{xx'|s}      (counts conditioned on the separation set s)
// M2 = M_{xx'|y,s}    (counts conditioned on s plus the candidate parent y)
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
// K = sqrt( Σ_{x'∈Val(X_i)} M_{xx'|s}  /  Σ_{x'∈Val(X_i)} M_{xx'|y,s} )
// L = 1 / K
let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
let K = K.mapv(f64::sqrt);
// Reshape to column vector.
let K = {
let n = K.len();
K.into_shape((n, 1)).unwrap()
};
let L = 1.0 / &K;
// X² = Σ_{x'∈Val(X_i)} (K·M_{xx'|y,s} - L·M_{xx'|s})² / (M_{xx'|y,s} + M_{xx'|s})
let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1);
println!("M1: {:?}", M1);
println!("M2: {:?}", M2);
println!("L*M1: {:?}", (L * &M1));
println!("K*M2: {:?}", (K * &M2));
println!("X_2: {:?}", X_2);
X_2.diag_mut().fill(0.0);
let X_2 = X_2.sum_axis(Axis(1));
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
println!("CHI^2: {:?}", n);
println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha))
}
}
impl HypothesisTest for ChiSquare {
fn call<T, P>(
&self,
net: &T,
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
cache: &mut parameter_learning::Cache<P>
) -> bool
where
T: network::Network,
P: parameter_learning::ParameterLearning {
// Fetch the learned parameters from the cache as the tuple (CIM, M, T),
// where each CIM is an n x n matrix; only the transition counts M,
// conditioned on the separation set alone, are needed here.
let ( _, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone()));
// Repeat the fit with the separation set extended by the candidate parent.
let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node);
let ( _, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone()));
// `partial_cardinality_product` is the product of the cardinalities of the
// separation-set members ordered before `parent_node`, i.e. the place value of
// the candidate parent's digit in the mixed-radix parent-configuration index.
let partial_cardinality_product: usize = extended_separation_set.iter().take_while(|x| **x != parent_node).map(|x| net.get_node(*x).get_reserved_space_as_parent()).product();
for idx_M_big in 0..M_big.shape()[0] {
// Strip that digit to map an index of M_big onto the matching index of M_small.
let idx_M_small: usize = idx_M_big % partial_cardinality_product + (idx_M_big / (partial_cardinality_product * net.get_node(parent_node).get_reserved_space_as_parent())) * partial_cardinality_product;
if ! self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) {
return false;
}
}
return true;
}
}
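
Taken together, `call` returns true only when, for every configuration of the extended separation set, the transition counts with and without the candidate parent are statistically indistinguishable at level alpha, i.e. the candidate parent looks redundant given the separation set. A hedged usage sketch, not part of this diff (`net`, `cache`, `child` and `parent` are placeholder names), of how a constraint-based search might consume it:

    let chi_sq = ChiSquare::new(0.1);
    let sep_set: BTreeSet<usize> = BTreeSet::new();
    // `true` means no significant difference was detected, so the edge
    // parent -> child is a candidate for removal given this separation set.
    if chi_sq.call(&net, child, parent, &sep_set, &mut cache) {
        // drop `parent` from `child`'s parent set
    }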

@@ -7,6 +7,7 @@ use reCTBN::network::Network;
use reCTBN::params;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm};
use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::tools::*;
use std::collections::BTreeSet;
@@ -315,3 +316,75 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
let hl = HillClimbing::new(bic, Some(1));
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
}
#[test]
pub fn chi_square_compare_matrices () {
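// Slice 1 of M1 is far from proportional to slice 0 of M2 (e.g. row 1 is
// [0, 12, 90] against [0, 200, 300]), so compare_matrices is expected to
// detect a significant difference and return false.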
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 12, 90],
[ 3, 0, 40],
[ 6, 40, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1);
assert!(!chi_sq.compare_matrices( i, &M1, j, &M2));
}
#[test]
pub fn chi_square_compare_matrices_2 () {
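// Slice 1 of M1 is exactly slice 0 of M2 scaled by 1/10, so the statistic is
// zero everywhere and compare_matrices is expected to return true
// (no significant difference).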
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 20, 30],
[ 40, 0, 60],
[ 70, 80, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2));
}
#[test]
pub fn chi_square_compare_matrices_3 () {
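// Slice 1 of M1 is a small perturbation of slice 0 of M2 scaled by 1/10, so
// the difference should not be significant at alpha = 0.1 and
// compare_matrices is expected to return true.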
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 21, 31],
[ 41, 0, 59],
[ 71, 79, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2));
}
