From f5756f71d31e86aeddaa179991b7d615f84b5ebf Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Fri, 29 Jul 2022 10:36:53 +0200 Subject: [PATCH] Changed the return type for ParameterLearning::fit from tuple to Param --- src/parameter_learning.rs | 41 +++++++++++++++--- src/params.rs | 15 +++++-- src/structure_learning/hypothesis_test.rs | 53 +++++++++++++---------- tests/parameter_learning.rs | 36 ++++++++------- 4 files changed, 95 insertions(+), 50 deletions(-) diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 6fff9d1..bf5b96a 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -12,7 +12,7 @@ pub trait ParameterLearning{ dataset: &tools::Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2); + ) -> Params; } pub fn sufficient_statistics( @@ -84,8 +84,7 @@ impl ParameterLearning for MLE { dataset: &tools::Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2) { - //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes + ) -> Params { //Use parent_set from parameter if present. Otherwise use parent_set from network. let parent_set = match parent_set { @@ -107,7 +106,21 @@ impl ParameterLearning for MLE { .for_each(|(mut C, diag)| { C.diag_mut().assign(&diag); }); - return (CIM, M, T); + + + + let mut n: Params = net.get_node(node).clone(); + + match n { + Params::DiscreteStatesContinousTime(ref mut dsct) => { + dsct.set_cim_unchecked(CIM); + dsct.set_transitions(M); + dsct.set_residence_time(T); + + + } + }; + return n; } } @@ -123,7 +136,7 @@ impl ParameterLearning for BayesianApproach { dataset: &tools::Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2) { + ) -> Params { //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes //Use parent_set from parameter if present. Otherwise use parent_set from network. @@ -150,7 +163,21 @@ impl ParameterLearning for BayesianApproach { .for_each(|(mut C, diag)| { C.diag_mut().assign(&diag); }); - return (CIM, M, T); + + + + let mut n: Params = net.get_node(node).clone(); + + match n { + Params::DiscreteStatesContinousTime(ref mut dsct) => { + dsct.set_cim_unchecked(CIM); + dsct.set_transitions(M); + dsct.set_residence_time(T); + + + } + }; + return n; } } @@ -166,7 +193,7 @@ impl Cache

{ net: &T, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2) { + ) -> Params { self.parameter_learning.fit(net, &self.dataset, node, parent_set) } } diff --git a/src/params.rs b/src/params.rs index e632b1b..d9f307f 100644 --- a/src/params.rs +++ b/src/params.rs @@ -55,7 +55,8 @@ pub trait ParamsTrait { } /// The Params enum is the core element for building different types of nodes. The goal is to -/// define all the supported type of parameters. +/// define all the supported type of Parameters +#[derive(Clone)] #[enum_dispatch] pub enum Params { DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), @@ -72,11 +73,12 @@ pub enum Params { /// realization of the parent set /// - **residence_time**: permanence time in each possible states given a specific /// realization of the parent set +#[derive(Clone)] pub struct DiscreteStatesContinousTimeParams { label: String, domain: BTreeSet, cim: Option>, - transitions: Option>, + transitions: Option>, residence_time: Option>, } @@ -112,14 +114,19 @@ impl DiscreteStatesContinousTimeParams { } + ///Unchecked version of the setter function for CIM. + pub fn set_cim_unchecked(&mut self, cim: Array3) { + self.cim = Some(cim); + } + ///Getter function for transitions - pub fn get_transitions(&self) -> &Option> { + pub fn get_transitions(&self) -> &Option> { &self.transitions } ///Setter function for transitions - pub fn set_transitions(&mut self, transitions: Array3) { + pub fn set_transitions(&mut self, transitions: Array3) { self.transitions = Some(transitions); } diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs index 86500e5..eb6b570 100644 --- a/src/structure_learning/hypothesis_test.rs +++ b/src/structure_learning/hypothesis_test.rs @@ -5,46 +5,39 @@ use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::network; use crate::parameter_learning; -use crate::params::ParamsTrait; +use crate::params::*; use std::collections::BTreeSet; pub trait HypothesisTest { - fn call( &self, net: &T, child_node: usize, parent_node: usize, separation_set: &BTreeSet, - cache: &mut parameter_learning::Cache

+ cache: &mut parameter_learning::Cache

, ) -> bool where T: network::Network, P: parameter_learning::ParameterLearning; - } - pub struct ChiSquare { alpha: f64, } -pub struct F { - -} +pub struct F {} impl ChiSquare { - pub fn new( alpha: f64) -> ChiSquare { - ChiSquare { - alpha - } + pub fn new(alpha: f64) -> ChiSquare { + ChiSquare { alpha } } pub fn compare_matrices( &self, i: usize, M1: &Array3, j: usize, - M2: &Array3 + M2: &Array3, ) -> bool { // Bregoli, A., Scutari, M. and Stella, F., 2021. // A constraint-based algorithm for the structural learning of @@ -87,7 +80,7 @@ impl ChiSquare { // ===== 2 1 // x'ϵVal /X \ // \ i/ - let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1); + let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); println!("M1: {:?}", M1); println!("M2: {:?}", M2); println!("L*M1: {:?}", (L * &M1)); @@ -109,24 +102,38 @@ impl HypothesisTest for ChiSquare { child_node: usize, parent_node: usize, separation_set: &BTreeSet, - cache: &mut parameter_learning::Cache

+ cache: &mut parameter_learning::Cache

, ) -> bool where T: network::Network, - P: parameter_learning::ParameterLearning { + P: parameter_learning::ParameterLearning, + { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // di dimensione nxn // (CIM, M, T) - let ( _, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone())); - // + let P_small = match cache.fit(net, child_node, Some(separation_set.clone())){ + Params::DiscreteStatesContinousTime(node) => node + }; + // let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); - let ( _, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone())); + + let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())){ + Params::DiscreteStatesContinousTime(node) => node + }; // Commentare qui - let partial_cardinality_product:usize = extended_separation_set.iter().take_while(|x| **x != parent_node).map(|x| net.get_node(*x).get_reserved_space_as_parent()).product(); - for idx_M_big in 0..M_big.shape()[0] { - let idx_M_small: usize = idx_M_big%partial_cardinality_product + (idx_M_big/(partial_cardinality_product*net.get_node(parent_node).get_reserved_space_as_parent()))*partial_cardinality_product; - if ! self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) { + let partial_cardinality_product: usize = extended_separation_set + .iter() + .take_while(|x| **x != parent_node) + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] { + let idx_M_small: usize = idx_M_big % partial_cardinality_product + + (idx_M_big + / (partial_cardinality_product + * net.get_node(parent_node).get_reserved_space_as_parent())) + * partial_cardinality_product; + if !self.compare_matrices(idx_M_small, P_small.get_transitions().as_ref().unwrap(), idx_M_big, P_big.get_transitions().as_ref().unwrap()) { return false; } } diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index cd980d0..1ce5d51 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -40,10 +40,11 @@ fn learn_binary_cim(pl: T) { } let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); - let (CIM, M, T) = pl.fit(&net, &data, 1, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [2, 2, 2]); - assert!(CIM.abs_diff_eq( + let p = match pl.fit(&net, &data, 1, None) { + params::Params::DiscreteStatesContinousTime(p) => p + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( &arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), 0.1 )); @@ -98,10 +99,11 @@ fn learn_ternary_cim(pl: T) { } let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); - let (CIM, M, T) = pl.fit(&net, &data, 1, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [3, 3, 3]); - assert!(CIM.abs_diff_eq( + let p = match pl.fit(&net, &data, 1, None){ + params::Params::DiscreteStatesContinousTime(p) => p + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( &arr3(&[ [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], @@ -160,10 +162,11 @@ fn learn_ternary_cim_no_parents(pl: T) { } let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); - let (CIM, M, T) = pl.fit(&net, &data, 0, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [1, 3, 3]); - assert!(CIM.abs_diff_eq( + let p = match pl.fit(&net, &data, 0, None){ + params::Params::DiscreteStatesContinousTime(p) => p + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( &arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), 0.1 )); @@ -288,10 +291,11 @@ fn learn_mixed_discrete_cim(pl: T) { } let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); - let (CIM, M, T) = pl.fit(&net, &data, 2, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [9, 4, 4]); - assert!(CIM.abs_diff_eq( + let p = match pl.fit(&net, &data, 2, None){ + params::Params::DiscreteStatesContinousTime(p) => p + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( &arr3(&[ [ [-1.0, 0.5, 0.3, 0.2],