From 19856195c39e5428e906fbb0b7ab0ecb8e9e6394 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 21 Dec 2022 11:41:26 +0100 Subject: [PATCH] Refactored cache laying grounds for its node-centered implementation changing also its signature, propagated this change and refactored CTPC tests --- reCTBN/src/parameter_learning.rs | 28 ++--- reCTBN/src/structure_learning.rs | 4 +- .../constraint_based_algorithm.rs | 20 ++-- .../src/structure_learning/hypothesis_test.rs | 23 +++- .../score_based_algorithm.rs | 4 +- reCTBN/tests/structure_learning.rs | 112 ++++-------------- 6 files changed, 72 insertions(+), 119 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 2aa518c..f8a7664 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -5,13 +5,13 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use crate::params::*; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub trait ParameterLearning { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params; @@ -19,7 +19,7 @@ pub trait ParameterLearning { pub fn sufficient_statistics( net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: &BTreeSet, ) -> (Array3, Array2) { @@ -76,7 +76,7 @@ impl ParameterLearning for MLE { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { @@ -123,7 +123,7 @@ impl ParameterLearning for BayesianApproach { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { @@ -165,25 +165,21 @@ impl ParameterLearning for BayesianApproach { } } -pub struct Cache { - parameter_learning: P, - dataset: tools::Dataset, +pub struct Cache<'a, P: ParameterLearning> { + parameter_learning: &'a P, } -impl Cache
<P>
{ - pub fn new(parameter_learning: P, dataset: tools::Dataset) -> Cache
<P>
{ - Cache { - parameter_learning, - dataset, - } +impl<'a, P: ParameterLearning> Cache<'a, P> { + pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { + Cache { parameter_learning } } pub fn fit( &mut self, net: &T, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { - self.parameter_learning - .fit(net, &self.dataset, node, parent_set) + self.parameter_learning.fit(net, dataset, node, parent_set) } } diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index d119ab2..a4c6ea1 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -4,10 +4,10 @@ pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; pub mod score_function; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub trait StructureLearningAlgorithm { - fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &Dataset) -> T where T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 8949aa5..6d54fe7 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -6,27 +6,28 @@ use std::usize; use super::hypothesis_test::*; use crate::parameter_learning::{Cache, ParameterLearning}; +use crate::process; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{process, tools}; +use crate::tools::Dataset; pub struct CTPC { + parameter_learning: P, Ftest: F, Chi2test: ChiSquare, - cache: Cache
<P>
, } impl<P: ParameterLearning> CTPC
<P>
{ - pub fn new(Ftest: F, Chi2test: ChiSquare, cache: Cache
<P>
) -> CTPC
<P>
{ + pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC
<P>
{ CTPC { - Chi2test, + parameter_learning, Ftest, - cache, + Chi2test, } } } impl StructureLearningAlgorithm for CTPC
<P>
{ - fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T where T: process::NetworkProcess, { @@ -41,6 +42,7 @@ impl StructureLearningAlgorithm for CTPC
<P>
{ net.initialize_adj_matrix(); for child_node in net.get_node_indices() { + let mut cache = Cache::new(&self.parameter_learning); let mut candidate_parent_set: BTreeSet = net .get_node_indices() .into_iter() @@ -62,13 +64,15 @@ impl StructureLearningAlgorithm for CTPC
<P>
{ child_node, *parent_node, &separation_set, - &mut self.cache, + dataset, + &mut cache, ) && self.Chi2test.call( &net, child_node, *parent_node, &separation_set, - &mut self.cache, + dataset, + &mut cache, ) { candidate_parent_set_TMP.remove(parent_node); break; diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index dd3bbf7..dd683ab 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; -use crate::{parameter_learning, process}; +use crate::{parameter_learning, process, tools::Dataset}; pub trait HypothesisTest { fn call( @@ -15,6 +15,7 @@ pub trait HypothesisTest { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache
<P>
, ) -> bool where @@ -84,19 +85,25 @@ impl HypothesisTest for F { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache
<P>
, ) -> bool where T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { - let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { Params::DiscreteStatesContinousTime(node) => node, }; let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); - let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + let P_big = match cache.fit( + net, + &dataset, + child_node, + Some(extended_separation_set.clone()), + ) { Params::DiscreteStatesContinousTime(node) => node, }; let partial_cardinality_product: usize = extended_separation_set @@ -218,6 +225,7 @@ impl HypothesisTest for ChiSquare { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache
<P>
, ) -> bool where @@ -227,14 +235,19 @@ impl HypothesisTest for ChiSquare { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // di dimensione nxn // (CIM, M, T) - let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { Params::DiscreteStatesContinousTime(node) => node, }; // let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); - let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + let P_big = match cache.fit( + net, + &dataset, + child_node, + Some(extended_separation_set.clone()), + ) { Params::DiscreteStatesContinousTime(node) => node, }; // Commentare qui diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index d59f0c1..d65ea88 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub struct HillClimbing { score_function: S, @@ -21,7 +21,7 @@ impl HillClimbing { } impl StructureLearningAlgorithm for HillClimbing { - fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &Dataset) -> T where T: process::NetworkProcess, { diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index c0deffd..6f97c9d 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -59,7 +59,7 @@ fn simple_bic() { ); } -fn check_compatibility_between_dataset_and_network(mut sl: T) { +fn check_compatibility_between_dataset_and_network(sl: T) { let mut net = CtbnNetwork::new(); 
let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -126,7 +126,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } -fn learn_ternary_net_2_nodes(mut sl: T) { +fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -321,7 +321,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } -fn learn_mixed_discrete_net_3_nodes(mut sl: T) { +fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -343,7 +343,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } -fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(mut sl: T) { +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -393,7 +393,7 @@ pub fn chi_square_compare_matrices() { [ 700, 800, 0] ], ]); - let chi_sq = ChiSquare::new(0.1); + let chi_sq = ChiSquare::new(1e-4); assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -423,7 +423,7 @@ pub fn chi_square_compare_matrices_2() { [ 400, 0, 600], [ 700, 800, 0]] ]); - let chi_sq = ChiSquare::new(0.1); + let chi_sq = ChiSquare::new(1e-4); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -455,7 +455,7 @@ pub fn chi_square_compare_matrices_3() { [ 700, 800, 0] ], ]); - let chi_sq = ChiSquare::new(0.1); + let chi_sq = ChiSquare::new(1e-4); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -469,14 +469,14 @@ pub fn chi_square_call() { let N1: usize = 0; let mut separation_set = BTreeSet::new(); let parameter_learning = 
BayesianApproach { alpha: 1, tau:1.0 }; - let mut cache = Cache::new(parameter_learning, data); - let chi_sq = ChiSquare::new(0.0001); + let mut cache = Cache::new(¶meter_learning); + let chi_sq = ChiSquare::new(1e-4); - assert!(chi_sq.call(&net, N1, N3, &separation_set, &mut cache)); - assert!(!chi_sq.call(&net, N3, N1, &separation_set, &mut cache)); - assert!(!chi_sq.call(&net, N3, N2, &separation_set, &mut cache)); + assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache)); + assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache)); + assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); - assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache)); + assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache)); } #[test] @@ -488,91 +488,31 @@ pub fn f_call() { let N1: usize = 0; let mut separation_set = BTreeSet::new(); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let mut cache = Cache::new(parameter_learning, data); - let f = F::new(0.000001); + let mut cache = Cache::new(¶meter_learning); + let f = F::new(1e-6); - assert!(f.call(&net, N1, N3, &separation_set, &mut cache)); - assert!(!f.call(&net, N3, N1, &separation_set, &mut cache)); - assert!(!f.call(&net, N3, N2, &separation_set, &mut cache)); + assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache)); + assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache)); + assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); - assert!(f.call(&net, N2, N3, &separation_set, &mut cache)); + assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache)); } #[test] pub fn learn_ternary_net_2_nodes_ctpc() { - let mut net = CtbnNetwork::new(); - let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) - .unwrap(); - let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) - .unwrap(); - 
net.add_edge(n1, n2); - - match &mut net.get_node_mut(n1) { - params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!( - Ok(()), - param.set_cim(arr3(&[ - [ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ], - ])) - ); - } - } - - match &mut net.get_node_mut(n2) { - params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!( - Ok(()), - param.set_cim(arr3(&[ - [ - [-1.0, 0.5, 0.5], - [3.0, -4.0, 1.0], - [0.9, 0.1, -1.0] - ], - [ - [-6.0, 2.0, 4.0], - [1.5, -2.0, 0.5], - [3.0, 1.0, -4.0] - ], - [ - [-1.0, 0.1, 0.9], - [2.0, -2.5, 0.5], - [0.9, 0.1, -1.0] - ], - ])) - ); - } - } - - let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); - - let f = F::new(0.000001); - let chi_sq = ChiSquare::new(0.0001); + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let cache = Cache::new(parameter_learning, data.clone()); - let mut ctpc = CTPC::new(f, chi_sq, cache); - - - let net = ctpc.fit_transform(net, &data); - assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); - assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes(ctpc); } #[test] fn learn_mixed_discrete_net_3_nodes_ctpc() { - let (_, data) = get_mixed_discrete_net_3_nodes_with_data(); - - let f = F::new(1e-24); - let chi_sq = ChiSquare::new(1e-24); + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let cache = Cache::new(parameter_learning, data); - let ctpc = CTPC::new(f, chi_sq, cache); - + let ctpc = CTPC::new(parameter_learning, f, chi_sq); learn_mixed_discrete_net_3_nodes(ctpc); }