diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 6474155..7534eaf 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -3,7 +3,7 @@ use std::collections::BTreeSet; use ndarray::{Array3, Axis}; -use statrs::distribution::{ChiSquared, ContinuousCDF}; +use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; use crate::{network, parameter_learning}; @@ -37,7 +37,61 @@ pub struct ChiSquare { alpha: f64, } -pub struct F {} +pub struct F { + alpha: f64, +} + +impl F { + pub fn new(alpha: f64) -> F { + F { alpha } + } + + pub fn compare_matrices( + &self, + i: usize, + M1: &Array3, + cim_1: &Array3, + j: usize, + M2: &Array3, + cim_2: &Array3, + ) -> bool { + let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); + let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); + let cim_1 = cim_1.index_axis(Axis(0), i); + let cim_2 = cim_2.index_axis(Axis(0), j); + let r1 = M1.sum_axis(Axis(1)); + let r2 = M2.sum_axis(Axis(1)); + let q1 = cim_1.diag(); + let q2 = cim_2.diag(); + for idx in 0..r1.shape()[0] { + let s = q2[idx] / q1[idx]; + let F = FisherSnedecor::new(r1[idx], r2[idx]); + let lim_sx = F.as_ref().expect("REASON").cdf(self.alpha / 2.0); + let lim_dx = F.as_ref().expect("REASON").cdf(1.0 - (self.alpha / 2.0)); + if s < lim_sx || s > lim_dx { + return false; + } + } + true + } +} + +impl HypothesisTest for F { + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + cache: &mut parameter_learning::Cache

, + ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning, + { + true + } +} impl ChiSquare { pub fn new(alpha: f64) -> ChiSquare {