|
|
|
@ -5,6 +5,7 @@ use statrs::distribution::{ChiSquared, ContinuousCDF}; |
|
|
|
|
|
|
|
|
|
use crate::network; |
|
|
|
|
use crate::parameter_learning; |
|
|
|
|
use crate::params::ParamsTrait; |
|
|
|
|
use std::collections::BTreeSet; |
|
|
|
|
|
|
|
|
|
pub trait HypothesisTest { |
|
|
|
@ -15,7 +16,7 @@ pub trait HypothesisTest { |
|
|
|
|
child_node: usize, |
|
|
|
|
parent_node: usize, |
|
|
|
|
separation_set: &BTreeSet<usize>, |
|
|
|
|
cache: parameter_learning::Cache<P> |
|
|
|
|
cache: &mut parameter_learning::Cache<P> |
|
|
|
|
) -> bool |
|
|
|
|
where |
|
|
|
|
T: network::Network, |
|
|
|
@ -39,7 +40,8 @@ impl ChiSquare { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
pub fn compare_matrices( |
|
|
|
|
&self, i: usize, |
|
|
|
|
&self, |
|
|
|
|
i: usize, |
|
|
|
|
M1: &Array3<usize>, |
|
|
|
|
j: usize, |
|
|
|
|
M2: &Array3<usize> |
|
|
|
@ -107,11 +109,27 @@ impl HypothesisTest for ChiSquare { |
|
|
|
|
child_node: usize, |
|
|
|
|
parent_node: usize, |
|
|
|
|
separation_set: &BTreeSet<usize>, |
|
|
|
|
cache: parameter_learning::Cache<P> |
|
|
|
|
cache: &mut parameter_learning::Cache<P> |
|
|
|
|
) -> bool |
|
|
|
|
where |
|
|
|
|
T: network::Network, |
|
|
|
|
P: parameter_learning::ParameterLearning { |
|
|
|
|
todo!() |
|
|
|
|
// Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM
|
|
|
|
|
// di dimensione nxn
|
|
|
|
|
// (CIM, M, T)
|
|
|
|
|
let ( _, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone())); |
|
|
|
|
//
|
|
|
|
|
let mut extended_separation_set = separation_set.clone(); |
|
|
|
|
extended_separation_set.insert(parent_node); |
|
|
|
|
let ( _, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone())); |
|
|
|
|
// Commentare qui
|
|
|
|
|
let partial_cardinality_product:usize = extended_separation_set.iter().take_while(|x| **x != parent_node).map(|x| net.get_node(*x).get_reserved_space_as_parent()).product(); |
|
|
|
|
for idx_M_big in 0..M_big.shape()[0] { |
|
|
|
|
let idx_M_small: usize = idx_M_big%partial_cardinality_product + (idx_M_big/(partial_cardinality_product*net.get_node(parent_node).get_reserved_space_as_parent()))*partial_cardinality_product; |
|
|
|
|
if ! self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) { |
|
|
|
|
return false; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return true; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|