Expanded Hypothesis test

pull/47/head
Meliurwen 2 years ago
parent 4b35ae6310
commit 68ada89c04
  1. src/parameter_learning.rs (4 changed lines)
  2. src/structure_learning/hypothesis_test.rs (26 changed lines)

@@ -157,16 +157,16 @@ impl ParameterLearning for BayesianApproach {
 pub struct Cache<P: ParameterLearning> {
     parameter_learning: P,
+    dataset: tools::Dataset,
 }
 
 impl<P: ParameterLearning> Cache<P> {
     pub fn fit<T: network::Network>(
         &mut self,
         net: &T,
-        dataset: &tools::Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
-        self.parameter_learning.fit(net, dataset, node, parent_set)
+        self.parameter_learning.fit(net, &self.dataset, node, parent_set)
     }
 }
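With this change the Cache owns the dataset it was built with, so callers no longer thread a &tools::Dataset through every fit call. A minimal sketch of the new call pattern, assuming it lives in a module of the same crate; the counts_for helper is illustrative only and is not part of this commit:

use std::collections::BTreeSet;

use crate::network;
use crate::parameter_learning::{Cache, ParameterLearning};

// Illustrative helper (not in the diff): fetch the count matrices for `node`
// given `parents`, relying on the dataset stored inside the cache.
fn counts_for<T, P>(cache: &mut Cache<P>, net: &T, node: usize, parents: &BTreeSet<usize>)
where
    T: network::Network,
    P: ParameterLearning,
{
    // After this commit `fit` reads `&self.dataset`; there is no per-call dataset argument.
    let (_cim, m, _t) = cache.fit(net, node, Some(parents.clone()));
    println!("transition-count tensor shape: {:?}", m.shape());
}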

@@ -5,6 +5,7 @@ use statrs::distribution::{ChiSquared, ContinuousCDF};
 use crate::network;
 use crate::parameter_learning;
+use crate::params::ParamsTrait;
 use std::collections::BTreeSet;
 
 pub trait HypothesisTest {
@@ -15,7 +16,7 @@ pub trait HypothesisTest {
         child_node: usize,
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
-        cache: parameter_learning::Cache<P>
+        cache: &mut parameter_learning::Cache<P>
     ) -> bool
     where
         T: network::Network,
@@ -39,7 +40,8 @@ impl ChiSquare {
         }
     }
     pub fn compare_matrices(
-        &self, i: usize,
+        &self,
+        i: usize,
         M1: &Array3<usize>,
         j: usize,
         M2: &Array3<usize>
@@ -107,11 +109,27 @@ impl HypothesisTest for ChiSquare {
         child_node: usize,
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
-        cache: parameter_learning::Cache<P>
+        cache: &mut parameter_learning::Cache<P>
     ) -> bool
     where
         T: network::Network,
         P: parameter_learning::ParameterLearning {
-        todo!()
+        // Fetch the parameter-learning result for child_node from the cache,
+        // conditioned on the separation set alone; the returned tuple is
+        // (CIM, M, T), where each CIM is an n x n matrix and M holds the counts.
+        let (_, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone()));
+
+        // Fetch it again with the candidate parent added to the separation set.
+        let mut extended_separation_set = separation_set.clone();
+        extended_separation_set.insert(parent_node);
+        let (_, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone()));
+
+        // Product of the cardinalities of the separation-set nodes that precede
+        // parent_node: used to drop parent_node's digit from the mixed-radix index
+        // of each parent configuration in M_big.
+        let partial_cardinality_product: usize = extended_separation_set
+            .iter()
+            .take_while(|x| **x != parent_node)
+            .map(|x| net.get_node(*x).get_reserved_space_as_parent())
+            .product();
+
+        // Compare every conditional count matrix in M_big with the M_small matrix
+        // for the same configuration of the separation-set nodes, i.e. ignoring
+        // the state of parent_node.
+        for idx_M_big in 0..M_big.shape()[0] {
+            let idx_M_small: usize = idx_M_big % partial_cardinality_product
+                + (idx_M_big
+                    / (partial_cardinality_product
+                        * net.get_node(parent_node).get_reserved_space_as_parent()))
+                    * partial_cardinality_product;
+            if !self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) {
+                return false;
+            }
+        }
+        return true;
     }
 }
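The index arithmetic above maps each row of M_big (one configuration of the extended separation set) to the row of M_small that holds the same configuration of the original separation set, by dropping the digit that parent_node contributes to the mixed-radix index. A standalone sketch of that mapping with made-up cardinalities, assuming, as the formula in the diff does, that the nodes preceding parent_node occupy the low-order digits:

// Illustrative only: drop the candidate parent's digit from a mixed-radix index.
// `prefix_product` is the product of the cardinalities of the nodes that precede
// the candidate parent; `parent_card` is the candidate parent's cardinality.
fn drop_parent_digit(idx_big: usize, prefix_product: usize, parent_card: usize) -> usize {
    idx_big % prefix_product + (idx_big / (prefix_product * parent_card)) * prefix_product
}

fn main() {
    // Hypothetical cardinalities: one node with 2 states before the parent,
    // the candidate parent with 2 states, one node with 3 states after it.
    let (prefix_product, parent_card, suffix_card) = (2, 2, 3);
    for idx_big in 0..(prefix_product * parent_card * suffix_card) {
        let idx_small = drop_parent_digit(idx_big, prefix_product, parent_card);
        println!("M_big row {} -> M_small row {}", idx_big, idx_small);
    }
}

Each M_small row is visited exactly parent_card times, once per state of the candidate parent, which is what compare_matrices needs in order to test whether conditioning on parent_node changes the counts.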
