diff --git a/Cargo.toml b/Cargo.toml index 4779b47..553e294 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] - ndarray = {version="*", features=["approx"]} thiserror = "*" rand = "*" diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 5270d9e..6fff9d1 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -153,3 +153,20 @@ impl ParameterLearning for BayesianApproach { return (CIM, M, T); } } + + +pub struct Cache<P: ParameterLearning> { + parameter_learning: P, + dataset: tools::Dataset, +} + +impl<P: ParameterLearning> Cache<P> {
+ pub fn fit<T: network::Network>( + &mut self, + net: &T, + node: usize, + parent_set: Option<BTreeSet<usize>>, + ) -> (Array3<f64>, Array3<usize>, Array2<f64>) { + self.parameter_learning.fit(net, &self.dataset, node, parent_set) + } +} diff --git a/src/structure_learning.rs b/src/structure_learning.rs index 8ba91df..b7db7ed 100644 --- a/src/structure_learning.rs +++ b/src/structure_learning.rs @@ -1,5 +1,7 @@ pub mod score_function; pub mod score_based_algorithm; +pub mod constraint_based_algorithm; +pub mod hypothesis_test; use crate::network; use crate::tools; diff --git a/src/structure_learning/constraint_based_algorithm.rs b/src/structure_learning/constraint_based_algorithm.rs new file mode 100644 index 0000000..0d8b655 --- /dev/null +++ b/src/structure_learning/constraint_based_algorithm.rs @@ -0,0 +1,5 @@ + +//pub struct CTPC { +// +//} + diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs new file mode 100644 index 0000000..86500e5 --- /dev/null +++ b/src/structure_learning/hypothesis_test.rs @@ -0,0 +1,135 @@ +use ndarray::Array2; +use ndarray::Array3; +use ndarray::Axis; +use statrs::distribution::{ChiSquared, ContinuousCDF}; + +use crate::network; +use crate::parameter_learning; +use crate::params::ParamsTrait; +use std::collections::BTreeSet; + +pub trait HypothesisTest { + + fn call<T, P>( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet<usize>, + cache: &mut parameter_learning::Cache<P>
+ ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning; + +} + + +pub struct ChiSquare { + alpha: f64, +} + +pub struct F { + +} + +impl ChiSquare { + pub fn new( alpha: f64) -> ChiSquare { + ChiSquare { + alpha + } + } + pub fn compare_matrices( + &self, + i: usize, + M1: &Array3<usize>, + j: usize, + M2: &Array3<usize> + ) -> bool { + // Bregoli, A., Scutari, M. and Stella, F., 2021. + // A constraint-based algorithm for the structural learning of + // continuous-time Bayesian networks. + // International Journal of Approximate Reasoning, 138, pp.105-122. + // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm + // + // M = M M = M + // 1 xx'|s 2 xx'|y,s + let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); + let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); + // __________________ + // / === + // / \ M + // / / xx'|s + // / === + // / x'ϵVal /X \ + // / \ i/ 1 + //K = / ------------------ L = - + // / === K + // / \ M + // / / xx'|y,s + // / === + // / x'ϵVal /X \ + // \ / \ i/ + // \/ + let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); + let K = K.mapv(f64::sqrt); + // Reshape to column vector. + let K = { + let n = K.len(); + K.into_shape((n, 1)).unwrap() + }; + let L = 1.0 / &K; + // ===== 2 + // \ (K . M - L . 
M) + // \ 2 1 + // / --------------- + // / M + M + // ===== 2 1 + // x'ϵVal /X \ + // \ i/ + let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1); + println!("M1: {:?}", M1); + println!("M2: {:?}", M2); + println!("L*M1: {:?}", (L * &M1)); + println!("K*M2: {:?}", (K * &M2)); + println!("X_2: {:?}", X_2); + X_2.diag_mut().fill(0.0); + let X_2 = X_2.sum_axis(Axis(1)); + let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); + println!("CHI^2: {:?}", n); + println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); + X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)) + } +} + +impl HypothesisTest for ChiSquare { + fn call<T, P>( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet<usize>, + cache: &mut parameter_learning::Cache<P>
+ ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning { + // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM + // di dimensione nxn + // (CIM, M, T) + let ( _, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone())); + // + let mut extended_separation_set = separation_set.clone(); + extended_separation_set.insert(parent_node); + let ( _, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone())); + // Commentare qui + let partial_cardinality_product:usize = extended_separation_set.iter().take_while(|x| **x != parent_node).map(|x| net.get_node(*x).get_reserved_space_as_parent()).product(); + for idx_M_big in 0..M_big.shape()[0] { + let idx_M_small: usize = idx_M_big%partial_cardinality_product + (idx_M_big/(partial_cardinality_product*net.get_node(parent_node).get_reserved_space_as_parent()))*partial_cardinality_product; + if ! self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) { + return false; + } + } + return true; + } +} diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index c91f508..2c9645b 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -7,6 +7,7 @@ use reCTBN::network::Network; use reCTBN::params; use reCTBN::structure_learning::score_function::*; use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm}; +use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::tools::*; use std::collections::BTreeSet; @@ -315,3 +316,75 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() let hl = HillClimbing::new(bic, Some(1)); learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } + +#[test] +pub fn chi_square_compare_matrices () { + let i: usize = 1; + let M1 = arr3(&[ + [[ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0]], + [[0, 12, 90], + [ 3, 0, 40], + [ 6, 40, 0]], + [[ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0]] + ]); + let j: usize = 0; + let M2 = arr3(&[ + [[ 
0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); + let chi_sq = ChiSquare::new(0.1); + assert!(!chi_sq.compare_matrices( i, &M1, j, &M2)); +} + +#[test] +pub fn chi_square_compare_matrices_2 () { + let i: usize = 1; + let M1 = arr3(&[ + [[ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0]], + [[0, 20, 30], + [ 40, 0, 60], + [ 70, 80, 0]], + [[ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0]] + ]); + let j: usize = 0; + let M2 = arr3(&[ + [[ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); + let chi_sq = ChiSquare::new(0.1); + assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); +} + +#[test] +pub fn chi_square_compare_matrices_3 () { + let i: usize = 1; + let M1 = arr3(&[ + [[ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0]], + [[0, 21, 31], + [ 41, 0, 59], + [ 71, 79, 0]], + [[ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0]] + ]); + let j: usize = 0; + let M2 = arr3(&[ + [[ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); + let chi_sq = ChiSquare::new(0.1); + assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); +}