Defined the `compare_matrices` function for the F-test

pull/79/head
Meliurwen 2 years ago
parent 713b8a8013
commit ec72a6a2f9
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 58
      reCTBN/src/structure_learning/hypothesis_test.rs

@ -3,7 +3,7 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use ndarray::{Array3, Axis}; use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
use crate::params::*; use crate::params::*;
use crate::{network, parameter_learning}; use crate::{network, parameter_learning};
@ -37,7 +37,61 @@ pub struct ChiSquare {
alpha: f64, alpha: f64,
} }
pub struct F {} pub struct F {
alpha: f64,
}
impl F {
pub fn new(alpha: f64) -> F {
F { alpha }
}
pub fn compare_matrices(
&self,
i: usize,
M1: &Array3<usize>,
cim_1: &Array3<f64>,
j: usize,
M2: &Array3<usize>,
cim_2: &Array3<f64>,
) -> bool {
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
let cim_1 = cim_1.index_axis(Axis(0), i);
let cim_2 = cim_2.index_axis(Axis(0), j);
let r1 = M1.sum_axis(Axis(1));
let r2 = M2.sum_axis(Axis(1));
let q1 = cim_1.diag();
let q2 = cim_2.diag();
for idx in 0..r1.shape()[0] {
let s = q2[idx] / q1[idx];
let F = FisherSnedecor::new(r1[idx], r2[idx]);
let lim_sx = F.as_ref().expect("REASON").cdf(self.alpha / 2.0);
let lim_dx = F.as_ref().expect("REASON").cdf(1.0 - (self.alpha / 2.0));
if s < lim_sx || s > lim_dx {
return false;
}
}
true
}
}
impl HypothesisTest for F {
fn call<T, P>(
&self,
net: &T,
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
cache: &mut parameter_learning::Cache<P>,
) -> bool
where
T: network::Network,
P: parameter_learning::ParameterLearning,
{
true
}
}
impl ChiSquare { impl ChiSquare {
pub fn new(alpha: f64) -> ChiSquare { pub fn new(alpha: f64) -> ChiSquare {

Loading…
Cancel
Save