Changed the return type for ParameterLearning::fit from tuple to Param #48

Merged
AlessandroBregoli merged 1 commits from 46-feature-parameter-learning-generalization into dev 2 years ago
  1. 41
      src/parameter_learning.rs
  2. 15
      src/params.rs
  3. 51
      src/structure_learning/hypothesis_test.rs
  4. 36
      tests/parameter_learning.rs

@ -12,7 +12,7 @@ pub trait ParameterLearning{
dataset: &tools::Dataset, dataset: &tools::Dataset,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>); ) -> Params;
} }
pub fn sufficient_statistics<T:network::Network>( pub fn sufficient_statistics<T:network::Network>(
@ -84,8 +84,7 @@ impl ParameterLearning for MLE {
dataset: &tools::Dataset, dataset: &tools::Dataset,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { ) -> Params {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network. //Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set { let parent_set = match parent_set {
@ -107,7 +106,21 @@ impl ParameterLearning for MLE {
.for_each(|(mut C, diag)| { .for_each(|(mut C, diag)| {
C.diag_mut().assign(&diag); C.diag_mut().assign(&diag);
}); });
return (CIM, M, T);
let mut n: Params = net.get_node(node).clone();
match n {
Params::DiscreteStatesContinousTime(ref mut dsct) => {
dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M);
dsct.set_residence_time(T);
}
};
return n;
} }
} }
@ -123,7 +136,7 @@ impl ParameterLearning for BayesianApproach {
dataset: &tools::Dataset, dataset: &tools::Dataset,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { ) -> Params {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network. //Use parent_set from parameter if present. Otherwise use parent_set from network.
@ -150,7 +163,21 @@ impl ParameterLearning for BayesianApproach {
.for_each(|(mut C, diag)| { .for_each(|(mut C, diag)| {
C.diag_mut().assign(&diag); C.diag_mut().assign(&diag);
}); });
return (CIM, M, T);
let mut n: Params = net.get_node(node).clone();
match n {
Params::DiscreteStatesContinousTime(ref mut dsct) => {
dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M);
dsct.set_residence_time(T);
}
};
return n;
} }
} }
@ -166,7 +193,7 @@ impl<P: ParameterLearning> Cache<P> {
net: &T, net: &T,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { ) -> Params {
self.parameter_learning.fit(net, &self.dataset, node, parent_set) self.parameter_learning.fit(net, &self.dataset, node, parent_set)
} }
} }

@ -55,7 +55,8 @@ pub trait ParamsTrait {
} }
/// The Params enum is the core element for building different types of nodes. The goal is to /// The Params enum is the core element for building different types of nodes. The goal is to
/// define all the supported type of parameters. /// define all the supported type of Parameters
#[derive(Clone)]
#[enum_dispatch] #[enum_dispatch]
pub enum Params { pub enum Params {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
@ -72,11 +73,12 @@ pub enum Params {
/// realization of the parent set /// realization of the parent set
/// - **residence_time**: permanence time in each possible states given a specific /// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set /// realization of the parent set
#[derive(Clone)]
pub struct DiscreteStatesContinousTimeParams { pub struct DiscreteStatesContinousTimeParams {
label: String, label: String,
domain: BTreeSet<String>, domain: BTreeSet<String>,
cim: Option<Array3<f64>>, cim: Option<Array3<f64>>,
transitions: Option<Array3<u64>>, transitions: Option<Array3<usize>>,
residence_time: Option<Array2<f64>>, residence_time: Option<Array2<f64>>,
} }
@ -112,14 +114,19 @@ impl DiscreteStatesContinousTimeParams {
} }
///Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
self.cim = Some(cim);
}
///Getter function for transitions ///Getter function for transitions
pub fn get_transitions(&self) -> &Option<Array3<u64>> { pub fn get_transitions(&self) -> &Option<Array3<usize>> {
&self.transitions &self.transitions
} }
///Setter function for transitions ///Setter function for transitions
pub fn set_transitions(&mut self, transitions: Array3<u64>) { pub fn set_transitions(&mut self, transitions: Array3<usize>) {
self.transitions = Some(transitions); self.transitions = Some(transitions);
} }

@ -5,46 +5,39 @@ use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::network; use crate::network;
use crate::parameter_learning; use crate::parameter_learning;
use crate::params::ParamsTrait; use crate::params::*;
use std::collections::BTreeSet; use std::collections::BTreeSet;
pub trait HypothesisTest { pub trait HypothesisTest {
fn call<T, P>( fn call<T, P>(
&self, &self,
net: &T, net: &T,
child_node: usize, child_node: usize,
parent_node: usize, parent_node: usize,
separation_set: &BTreeSet<usize>, separation_set: &BTreeSet<usize>,
cache: &mut parameter_learning::Cache<P> cache: &mut parameter_learning::Cache<P>,
) -> bool ) -> bool
where where
T: network::Network, T: network::Network,
P: parameter_learning::ParameterLearning; P: parameter_learning::ParameterLearning;
} }
pub struct ChiSquare { pub struct ChiSquare {
alpha: f64, alpha: f64,
} }
pub struct F { pub struct F {}
}
impl ChiSquare { impl ChiSquare {
pub fn new( alpha: f64) -> ChiSquare { pub fn new(alpha: f64) -> ChiSquare {
ChiSquare { ChiSquare { alpha }
alpha
}
} }
pub fn compare_matrices( pub fn compare_matrices(
&self, &self,
i: usize, i: usize,
M1: &Array3<usize>, M1: &Array3<usize>,
j: usize, j: usize,
M2: &Array3<usize> M2: &Array3<usize>,
) -> bool { ) -> bool {
// Bregoli, A., Scutari, M. and Stella, F., 2021. // Bregoli, A., Scutari, M. and Stella, F., 2021.
// A constraint-based algorithm for the structural learning of // A constraint-based algorithm for the structural learning of
@ -87,7 +80,7 @@ impl ChiSquare {
// ===== 2 1 // ===== 2 1
// x'ϵVal /X \ // x'ϵVal /X \
// \ i/ // \ i/
let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1); let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
println!("M1: {:?}", M1); println!("M1: {:?}", M1);
println!("M2: {:?}", M2); println!("M2: {:?}", M2);
println!("L*M1: {:?}", (L * &M1)); println!("L*M1: {:?}", (L * &M1));
@ -109,24 +102,38 @@ impl HypothesisTest for ChiSquare {
child_node: usize, child_node: usize,
parent_node: usize, parent_node: usize,
separation_set: &BTreeSet<usize>, separation_set: &BTreeSet<usize>,
cache: &mut parameter_learning::Cache<P> cache: &mut parameter_learning::Cache<P>,
) -> bool ) -> bool
where where
T: network::Network, T: network::Network,
P: parameter_learning::ParameterLearning { P: parameter_learning::ParameterLearning,
{
// Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM
// di dimensione nxn // di dimensione nxn
// (CIM, M, T) // (CIM, M, T)
let ( _, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone())); let P_small = match cache.fit(net, child_node, Some(separation_set.clone())){
Params::DiscreteStatesContinousTime(node) => node
};
// //
let mut extended_separation_set = separation_set.clone(); let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node); extended_separation_set.insert(parent_node);
let ( _, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone()));
let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())){
Params::DiscreteStatesContinousTime(node) => node
};
// Commentare qui // Commentare qui
let partial_cardinality_product:usize = extended_separation_set.iter().take_while(|x| **x != parent_node).map(|x| net.get_node(*x).get_reserved_space_as_parent()).product(); let partial_cardinality_product: usize = extended_separation_set
for idx_M_big in 0..M_big.shape()[0] { .iter()
let idx_M_small: usize = idx_M_big%partial_cardinality_product + (idx_M_big/(partial_cardinality_product*net.get_node(parent_node).get_reserved_space_as_parent()))*partial_cardinality_product; .take_while(|x| **x != parent_node)
if ! self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) { .map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product();
for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
let idx_M_small: usize = idx_M_big % partial_cardinality_product
+ (idx_M_big
/ (partial_cardinality_product
* net.get_node(parent_node).get_reserved_space_as_parent()))
* partial_cardinality_product;
if !self.compare_matrices(idx_M_small, P_small.get_transitions().as_ref().unwrap(), idx_M_big, P_big.get_transitions().as_ref().unwrap()) {
return false; return false;
} }
} }

@ -40,10 +40,11 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
} }
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 1, None); let p = match pl.fit(&net, &data, 1, None) {
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); params::Params::DiscreteStatesContinousTime(p) => p
assert_eq!(CIM.shape(), [2, 2, 2]); };
assert!(CIM.abs_diff_eq( assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), &arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]),
0.1 0.1
)); ));
@ -98,10 +99,11 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
} }
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 1, None); let p = match pl.fit(&net, &data, 1, None){
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); params::Params::DiscreteStatesContinousTime(p) => p
assert_eq!(CIM.shape(), [3, 3, 3]); };
assert!(CIM.abs_diff_eq( assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[ &arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
@ -160,10 +162,11 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
} }
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 0, None); let p = match pl.fit(&net, &data, 0, None){
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); params::Params::DiscreteStatesContinousTime(p) => p
assert_eq!(CIM.shape(), [1, 3, 3]); };
assert!(CIM.abs_diff_eq( assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), &arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]),
0.1 0.1
)); ));
@ -288,10 +291,11 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
} }
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 2, None); let p = match pl.fit(&net, &data, 2, None){
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); params::Params::DiscreteStatesContinousTime(p) => p
assert_eq!(CIM.shape(), [9, 4, 4]); };
assert!(CIM.abs_diff_eq( assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[ &arr3(&[
[ [
[-1.0, 0.5, 0.3, 0.2], [-1.0, 0.5, 0.3, 0.2],

Loading…
Cancel
Save