From ec72a6a2f9e6981da7407a2e7e4a66b2345d003c Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 28 Oct 2022 16:02:12 +0200 Subject: [PATCH 01/18] Defined the `compare_matrices` function for the F-test --- .../src/structure_learning/hypothesis_test.rs | 58 ++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 6474155..7534eaf 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -3,7 +3,7 @@ use std::collections::BTreeSet; use ndarray::{Array3, Axis}; -use statrs::distribution::{ChiSquared, ContinuousCDF}; +use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; use crate::{network, parameter_learning}; @@ -37,7 +37,61 @@ pub struct ChiSquare { alpha: f64, } -pub struct F {} +pub struct F { + alpha: f64, +} + +impl F { + pub fn new(alpha: f64) -> F { + F { alpha } + } + + pub fn compare_matrices( + &self, + i: usize, + M1: &Array3, + cim_1: &Array3, + j: usize, + M2: &Array3, + cim_2: &Array3, + ) -> bool { + let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); + let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); + let cim_1 = cim_1.index_axis(Axis(0), i); + let cim_2 = cim_2.index_axis(Axis(0), j); + let r1 = M1.sum_axis(Axis(1)); + let r2 = M2.sum_axis(Axis(1)); + let q1 = cim_1.diag(); + let q2 = cim_2.diag(); + for idx in 0..r1.shape()[0] { + let s = q2[idx] / q1[idx]; + let F = FisherSnedecor::new(r1[idx], r2[idx]); + let lim_sx = F.as_ref().expect("REASON").cdf(self.alpha / 2.0); + let lim_dx = F.as_ref().expect("REASON").cdf(1.0 - (self.alpha / 2.0)); + if s < lim_sx || s > lim_dx { + return false; + } + } + true + } +} + +impl HypothesisTest for F { + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + cache: &mut parameter_learning::Cache

, + ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning, + { + true + } +} impl ChiSquare { pub fn new(alpha: f64) -> ChiSquare { From c08f4e1985edf949049b09eceba059e7de0cb1f1 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 10 Nov 2022 14:10:17 +0100 Subject: [PATCH 02/18] Added F call function --- .../src/structure_learning/hypothesis_test.rs | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 7534eaf..75c0eac 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -89,7 +89,38 @@ impl HypothesisTest for F { T: network::Network, P: parameter_learning::ParameterLearning, { - true + let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + Params::DiscreteStatesContinousTime(node) => node, + }; + let mut extended_separation_set = separation_set.clone(); + extended_separation_set.insert(parent_node); + + let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + Params::DiscreteStatesContinousTime(node) => node, + }; + let partial_cardinality_product: usize = extended_separation_set + .iter() + .take_while(|x| **x != parent_node) + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] { + let idx_M_small: usize = idx_M_big % partial_cardinality_product + + (idx_M_big + / (partial_cardinality_product + * net.get_node(parent_node).get_reserved_space_as_parent())) + * partial_cardinality_product; + if !self.compare_matrices( + idx_M_small, + P_small.get_transitions().as_ref().unwrap(), + P_small.get_cim().as_ref().unwrap(), + idx_M_big, + P_big.get_transitions().as_ref().unwrap(), + P_big.get_cim().as_ref().unwrap(), + ) { + return false; + } + } + return true; } } From 
9fbdf25149f6ebe7fcaded6e609e711783053736 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 16 Nov 2022 10:40:19 +0100 Subject: [PATCH 03/18] Fixed `chi_square_call` test, the test was passing, but only for pure chance --- reCTBN/tests/structure_learning.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a1667c2..a8cf3c6 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -108,7 +108,7 @@ fn check_compatibility_between_dataset_and_network Date: Wed, 16 Nov 2022 10:46:55 +0100 Subject: [PATCH 04/18] Slight optimization of `F::compare_matrices` --- reCTBN/src/structure_learning/hypothesis_test.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 75c0eac..9f7a518 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -65,9 +65,10 @@ impl F { let q2 = cim_2.diag(); for idx in 0..r1.shape()[0] { let s = q2[idx] / q1[idx]; - let F = FisherSnedecor::new(r1[idx], r2[idx]); - let lim_sx = F.as_ref().expect("REASON").cdf(self.alpha / 2.0); - let lim_dx = F.as_ref().expect("REASON").cdf(1.0 - (self.alpha / 2.0)); + let F = FisherSnedecor::new(r1[idx], r2[idx]).unwrap(); + let s = F.cdf(s); + let lim_sx = self.alpha / 2.0; + let lim_dx = 1.0 - (self.alpha / 2.0); if s < lim_sx || s > lim_dx { return false; } From 3a0151a9f62a85da46d916af0f35c5309b5be9e5 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 16 Nov 2022 12:44:06 +0100 Subject: [PATCH 05/18] Added test for F-test call function --- reCTBN/tests/structure_learning.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a8cf3c6..5c1db80 100644 --- a/reCTBN/tests/structure_learning.rs +++ 
b/reCTBN/tests/structure_learning.rs @@ -477,3 +477,23 @@ pub fn chi_square_call() { separation_set.insert(N1); assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache)); } + +#[test] +pub fn f_call() { + + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let N3: usize = 2; + let N2: usize = 1; + let N1: usize = 0; + let mut separation_set = BTreeSet::new(); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let mut cache = Cache::new(parameter_learning, data); + let f = F::new(0.000001); + + + assert!(f.call(&net, N1, N3, &separation_set, &mut cache)); + assert!(!f.call(&net, N3, N1, &separation_set, &mut cache)); + assert!(!f.call(&net, N3, N2, &separation_set, &mut cache)); + separation_set.insert(N1); + assert!(f.call(&net, N2, N3, &separation_set, &mut cache)); +} From cac19b17565e06bb192298d7fc715c8c7701897b Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 25 Nov 2022 10:19:05 +0100 Subject: [PATCH 06/18] Aligned F-test to the new changes --- reCTBN/src/structure_learning/hypothesis_test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index aa37cfa..dd3bbf7 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -87,7 +87,7 @@ impl HypothesisTest for F { cache: &mut parameter_learning::Cache

, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { From 6e90458418c8d3d6b9c9107faa8ec4ef77a238e6 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 28 Nov 2022 11:08:20 +0100 Subject: [PATCH 07/18] Laying grounds for CTPC --- .../constraint_based_algorithm.rs | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 670c8ed..d931f78 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,5 +1,39 @@ //! Module containing constraint based algorithms like CTPC and Hiton. -//pub struct CTPC { -// -//} +use super::hypothesis_test::*; +use crate::structure_learning::StructureLearningAlgorithm; +use crate::{process, tools}; +use crate::parameter_learning::{Cache, ParameterLearning}; + +pub struct CTPC { + Ftest: F, + Chi2test: ChiSquare, + cache: Cache

, + +} + +impl CTPC

{ + pub fn new(Ftest: F, Chi2test: ChiSquare, cache: Cache

) -> CTPC

{ + CTPC { + Chi2test, + Ftest, + cache, + } + } +} + +impl StructureLearningAlgorithm for CTPC

{ + fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + where + T: process::NetworkProcess, + { + //Check the coherence between dataset and network + if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { + panic!("Dataset and Network must have the same number of variables.") + } + + //Make the network mutable. + let mut net = net; + net + } +} From 6d952f8c0741faf827ac725e328f1b6e8b4d5b8a Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 08:52:40 +0100 Subject: [PATCH 08/18] Added `itertools` a WIP version of CTPC and some hacky and temporary modifications --- reCTBN/Cargo.toml | 1 + reCTBN/src/structure_learning.rs | 2 +- .../constraint_based_algorithm.rs | 44 +++++++++++++++++-- .../score_based_algorithm.rs | 2 +- reCTBN/tests/structure_learning.rs | 8 ++-- 5 files changed, 48 insertions(+), 9 deletions(-) diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml index b0a691b..fdac697 100644 --- a/reCTBN/Cargo.toml +++ b/reCTBN/Cargo.toml @@ -13,6 +13,7 @@ bimap = "~0.6" enum_dispatch = "~0.3" statrs = "~0.16" rand_chacha = "~0.3" +itertools = "~0.10" [dev-dependencies] approx = { package = "approx", version = "~0.5" } diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index b272e22..d119ab2 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -7,7 +7,7 @@ pub mod score_function; use crate::{process, tools}; pub trait StructureLearningAlgorithm { - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T where T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index d931f78..6fd5b79 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,15 +1,18 @@ //! 
Module containing constraint based algorithms like CTPC and Hiton. +use itertools::Itertools; +use std::collections::BTreeSet; +use std::usize; + use super::hypothesis_test::*; +use crate::parameter_learning::{Cache, ParameterLearning}; use crate::structure_learning::StructureLearningAlgorithm; use crate::{process, tools}; -use crate::parameter_learning::{Cache, ParameterLearning}; pub struct CTPC { Ftest: F, Chi2test: ChiSquare, cache: Cache

, - } impl CTPC

{ @@ -23,7 +26,7 @@ impl CTPC

{ } impl StructureLearningAlgorithm for CTPC

{ - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T where T: process::NetworkProcess, { @@ -34,6 +37,41 @@ impl StructureLearningAlgorithm for CTPC

{ //Make the network mutable. let mut net = net; + + net.initialize_adj_matrix(); + + for child_node in net.get_node_indices() { + let mut candidate_parent_set: BTreeSet = net + .get_node_indices() + .into_iter() + .filter(|x| x != &child_node) + .collect(); + let mut b = 0; + while b < candidate_parent_set.len() { + for parent_node in candidate_parent_set.iter() { + for separation_set in candidate_parent_set + .iter() + .filter(|x| x != &parent_node) + .map(|x| *x) + .combinations(b) + { + let separation_set = separation_set.into_iter().collect(); + if self.Ftest.call( + &net, + child_node, + *parent_node, + &separation_set, + &mut self.cache, + ) && self.Chi2test.call(&net, child_node, *parent_node, &separation_set, &mut self.cache) { + candidate_parent_set.remove(&parent_node); + break; + } + } + } + b = b + 1; + } + } + net } } diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index 16e9056..d59f0c1 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -21,7 +21,7 @@ impl HillClimbing { } impl StructureLearningAlgorithm for HillClimbing { - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T where T: process::NetworkProcess, { diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 6134510..4bf9027 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -58,7 +58,7 @@ fn simple_bic() { ); } -fn check_compatibility_between_dataset_and_network(sl: T) { +fn check_compatibility_between_dataset_and_network(mut sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -125,7 +125,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { 
check_compatibility_between_dataset_and_network(hl); } -fn learn_ternary_net_2_nodes(sl: T) { +fn learn_ternary_net_2_nodes(mut sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -320,7 +320,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } -fn learn_mixed_discrete_net_3_nodes(sl: T) { +fn learn_mixed_discrete_net_3_nodes(mut sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -342,7 +342,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } -fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(mut sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); From 6d42d8a805c493aad5dcbdd63a4fad4bb638646c Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 12:54:08 +0100 Subject: [PATCH 09/18] Solved issue with `candidate_parent_set` variable in CTPC and added loop to fill the adjacency matrix --- .../constraint_based_algorithm.rs | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 6fd5b79..d94c793 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -48,6 +48,7 @@ impl StructureLearningAlgorithm for CTPC

{ .collect(); let mut b = 0; while b < candidate_parent_set.len() { + let mut not_parent_node: usize = child_node; for parent_node in candidate_parent_set.iter() { for separation_set in candidate_parent_set .iter() @@ -62,16 +63,30 @@ impl StructureLearningAlgorithm for CTPC

{ *parent_node, &separation_set, &mut self.cache, - ) && self.Chi2test.call(&net, child_node, *parent_node, &separation_set, &mut self.cache) { - candidate_parent_set.remove(&parent_node); + ) && self.Chi2test.call( + &net, + child_node, + *parent_node, + &separation_set, + &mut self.cache, + ) { + not_parent_node = parent_node.clone(); break; } } + if not_parent_node != child_node { + break; + } + } + if not_parent_node != child_node { + candidate_parent_set.remove(¬_parent_node); } b = b + 1; } + for parent_node in candidate_parent_set.iter() { + net.add_edge(*parent_node, child_node); + } } - net } } From 8d0f9db289b8453bed413d2ba70ec1693b0c376d Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 17:21:53 +0100 Subject: [PATCH 10/18] WIP: Added tests for CTPC --- reCTBN/tests/structure_learning.rs | 79 ++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 4bf9027..c0deffd 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -10,6 +10,7 @@ use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::Cache; use reCTBN::params; use reCTBN::structure_learning::hypothesis_test::*; +use reCTBN::structure_learning::constraint_based_algorithm::*; use reCTBN::structure_learning::score_based_algorithm::*; use reCTBN::structure_learning::score_function::*; use reCTBN::structure_learning::StructureLearningAlgorithm; @@ -497,3 +498,81 @@ pub fn f_call() { separation_set.insert(N1); assert!(f.call(&net, N2, N3, &separation_set, &mut cache)); } + +#[test] +pub fn learn_ternary_net_2_nodes_ctpc() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + 
params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let f = F::new(0.000001); + let chi_sq = ChiSquare::new(0.0001); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let cache = Cache::new(parameter_learning, data.clone()); + let mut ctpc = CTPC::new(f, chi_sq, cache); + + + let net = ctpc.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); + assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); +} + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc() { + let (_, data) = get_mixed_discrete_net_3_nodes_with_data(); + + let f = F::new(1e-24); + let chi_sq = ChiSquare::new(1e-24); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let cache = Cache::new(parameter_learning, data); + let ctpc = CTPC::new(f, chi_sq, cache); + + learn_mixed_discrete_net_3_nodes(ctpc); +} From 468ebf09cc330c2448b90ed27e090e1613336045 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 17:24:23 +0100 Subject: [PATCH 11/18] WIP: Added `#[derive(Clone)]` to `Dataset` and `Trajectory` --- reCTBN/src/tools.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index ecfeff9..47a067d 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -5,6 +5,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; +#[derive(Clone)] pub 
struct Trajectory { time: Array1, events: Array2, @@ -29,6 +30,7 @@ impl Trajectory { } } +#[derive(Clone)] pub struct Dataset { trajectories: Vec, } From ea5df7cad6485742905f2a0297a95cfc6cf2f801 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 20 Dec 2022 12:36:28 +0100 Subject: [PATCH 12/18] Solved another issue with `candidate_parent_set` variable in CTPC --- .../constraint_based_algorithm.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index d94c793..8949aa5 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -46,15 +46,15 @@ impl StructureLearningAlgorithm for CTPC

{ .into_iter() .filter(|x| x != &child_node) .collect(); - let mut b = 0; - while b < candidate_parent_set.len() { - let mut not_parent_node: usize = child_node; + let mut separation_set_size = 0; + while separation_set_size < candidate_parent_set.len() { + let mut candidate_parent_set_TMP = candidate_parent_set.clone(); for parent_node in candidate_parent_set.iter() { for separation_set in candidate_parent_set .iter() .filter(|x| x != &parent_node) .map(|x| *x) - .combinations(b) + .combinations(separation_set_size) { let separation_set = separation_set.into_iter().collect(); if self.Ftest.call( @@ -70,18 +70,13 @@ impl StructureLearningAlgorithm for CTPC

{ &separation_set, &mut self.cache, ) { - not_parent_node = parent_node.clone(); + candidate_parent_set_TMP.remove(parent_node); break; } } - if not_parent_node != child_node { - break; - } - } - if not_parent_node != child_node { - candidate_parent_set.remove(¬_parent_node); } - b = b + 1; + candidate_parent_set = candidate_parent_set_TMP; + separation_set_size += 1; } for parent_node in candidate_parent_set.iter() { net.add_edge(*parent_node, child_node); From 19856195c39e5428e906fbb0b7ab0ecb8e9e6394 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 21 Dec 2022 11:41:26 +0100 Subject: [PATCH 13/18] Refactored cache laying grounds for its node-centered implementation changing also its signature, propagated this change and refactored CTPC tests --- reCTBN/src/parameter_learning.rs | 28 ++--- reCTBN/src/structure_learning.rs | 4 +- .../constraint_based_algorithm.rs | 20 ++-- .../src/structure_learning/hypothesis_test.rs | 23 +++- .../score_based_algorithm.rs | 4 +- reCTBN/tests/structure_learning.rs | 112 ++++-------------- 6 files changed, 72 insertions(+), 119 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 2aa518c..f8a7664 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -5,13 +5,13 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use crate::params::*; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub trait ParameterLearning { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params; @@ -19,7 +19,7 @@ pub trait ParameterLearning { pub fn sufficient_statistics( net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: &BTreeSet, ) -> (Array3, Array2) { @@ -76,7 +76,7 @@ impl ParameterLearning for MLE { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { @@ -123,7 +123,7 @@ impl 
ParameterLearning for BayesianApproach { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { @@ -165,25 +165,21 @@ impl ParameterLearning for BayesianApproach { } } -pub struct Cache { - parameter_learning: P, - dataset: tools::Dataset, +pub struct Cache<'a, P: ParameterLearning> { + parameter_learning: &'a P, } -impl Cache

{ - pub fn new(parameter_learning: P, dataset: tools::Dataset) -> Cache

{ - Cache { - parameter_learning, - dataset, - } +impl<'a, P: ParameterLearning> Cache<'a, P> { + pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { + Cache { parameter_learning } } pub fn fit( &mut self, net: &T, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { - self.parameter_learning - .fit(net, &self.dataset, node, parent_set) + self.parameter_learning.fit(net, dataset, node, parent_set) } } diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index d119ab2..a4c6ea1 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -4,10 +4,10 @@ pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; pub mod score_function; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub trait StructureLearningAlgorithm { - fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &Dataset) -> T where T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 8949aa5..6d54fe7 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -6,27 +6,28 @@ use std::usize; use super::hypothesis_test::*; use crate::parameter_learning::{Cache, ParameterLearning}; +use crate::process; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{process, tools}; +use crate::tools::Dataset; pub struct CTPC { + parameter_learning: P, Ftest: F, Chi2test: ChiSquare, - cache: Cache

, } impl CTPC

{ - pub fn new(Ftest: F, Chi2test: ChiSquare, cache: Cache

) -> CTPC

{ + pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC

{ CTPC { - Chi2test, + parameter_learning, Ftest, - cache, + Chi2test, } } } impl StructureLearningAlgorithm for CTPC

{ - fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &Dataset) -> T where T: process::NetworkProcess, { @@ -41,6 +42,7 @@ impl StructureLearningAlgorithm for CTPC

{ net.initialize_adj_matrix(); for child_node in net.get_node_indices() { + let mut cache = Cache::new(&self.parameter_learning); let mut candidate_parent_set: BTreeSet = net .get_node_indices() .into_iter() @@ -62,13 +64,15 @@ impl StructureLearningAlgorithm for CTPC

{ child_node, *parent_node, &separation_set, - &mut self.cache, + dataset, + &mut cache, ) && self.Chi2test.call( &net, child_node, *parent_node, &separation_set, - &mut self.cache, + dataset, + &mut cache, ) { candidate_parent_set_TMP.remove(parent_node); break; diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index dd3bbf7..dd683ab 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; -use crate::{parameter_learning, process}; +use crate::{parameter_learning, process, tools::Dataset}; pub trait HypothesisTest { fn call( @@ -15,6 +15,7 @@ pub trait HypothesisTest { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache

, ) -> bool where @@ -84,19 +85,25 @@ impl HypothesisTest for F { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache

, ) -> bool where T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { - let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { Params::DiscreteStatesContinousTime(node) => node, }; let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); - let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + let P_big = match cache.fit( + net, + &dataset, + child_node, + Some(extended_separation_set.clone()), + ) { Params::DiscreteStatesContinousTime(node) => node, }; let partial_cardinality_product: usize = extended_separation_set @@ -218,6 +225,7 @@ impl HypothesisTest for ChiSquare { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache

, ) -> bool where @@ -227,14 +235,19 @@ impl HypothesisTest for ChiSquare { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // di dimensione nxn // (CIM, M, T) - let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { Params::DiscreteStatesContinousTime(node) => node, }; // let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); - let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + let P_big = match cache.fit( + net, + &dataset, + child_node, + Some(extended_separation_set.clone()), + ) { Params::DiscreteStatesContinousTime(node) => node, }; // Commentare qui diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index d59f0c1..d65ea88 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub struct HillClimbing { score_function: S, @@ -21,7 +21,7 @@ impl HillClimbing { } impl StructureLearningAlgorithm for HillClimbing { - fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &Dataset) -> T where T: process::NetworkProcess, { diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index c0deffd..6f97c9d 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -59,7 +59,7 @@ fn simple_bic() { ); } -fn check_compatibility_between_dataset_and_network(mut sl: T) { +fn check_compatibility_between_dataset_and_network(sl: T) { let mut net = CtbnNetwork::new(); 
let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -126,7 +126,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } -fn learn_ternary_net_2_nodes(mut sl: T) { +fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -321,7 +321,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } -fn learn_mixed_discrete_net_3_nodes(mut sl: T) { +fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -343,7 +343,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } -fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(mut sl: T) { +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -393,7 +393,7 @@ pub fn chi_square_compare_matrices() { [ 700, 800, 0] ], ]); - let chi_sq = ChiSquare::new(0.1); + let chi_sq = ChiSquare::new(1e-4); assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -423,7 +423,7 @@ pub fn chi_square_compare_matrices_2() { [ 400, 0, 600], [ 700, 800, 0]] ]); - let chi_sq = ChiSquare::new(0.1); + let chi_sq = ChiSquare::new(1e-4); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -455,7 +455,7 @@ pub fn chi_square_compare_matrices_3() { [ 700, 800, 0] ], ]); - let chi_sq = ChiSquare::new(0.1); + let chi_sq = ChiSquare::new(1e-4); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -469,14 +469,14 @@ pub fn chi_square_call() { let N1: usize = 0; let mut separation_set = BTreeSet::new(); let parameter_learning = 
BayesianApproach { alpha: 1, tau:1.0 }; - let mut cache = Cache::new(parameter_learning, data); - let chi_sq = ChiSquare::new(0.0001); + let mut cache = Cache::new(&parameter_learning); + let chi_sq = ChiSquare::new(1e-4); - assert!(chi_sq.call(&net, N1, N3, &separation_set, &mut cache)); - assert!(!chi_sq.call(&net, N3, N1, &separation_set, &mut cache)); - assert!(!chi_sq.call(&net, N3, N2, &separation_set, &mut cache)); + assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache)); + assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache)); + assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); - assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache)); + assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache)); } #[test] @@ -488,91 +488,31 @@ pub fn f_call() { let N1: usize = 0; let mut separation_set = BTreeSet::new(); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let mut cache = Cache::new(parameter_learning, data); - let f = F::new(0.000001); + let mut cache = Cache::new(&parameter_learning); + let f = F::new(1e-6); - assert!(f.call(&net, N1, N3, &separation_set, &mut cache)); - assert!(!f.call(&net, N3, N1, &separation_set, &mut cache)); - assert!(!f.call(&net, N3, N2, &separation_set, &mut cache)); + assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache)); + assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache)); + assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); - assert!(f.call(&net, N2, N3, &separation_set, &mut cache)); + assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache)); } #[test] pub fn learn_ternary_net_2_nodes_ctpc() { - let mut net = CtbnNetwork::new(); - let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) - .unwrap(); - let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) - .unwrap(); -
net.add_edge(n1, n2); - - match &mut net.get_node_mut(n1) { - params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!( - Ok(()), - param.set_cim(arr3(&[ - [ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ], - ])) - ); - } - } - - match &mut net.get_node_mut(n2) { - params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!( - Ok(()), - param.set_cim(arr3(&[ - [ - [-1.0, 0.5, 0.5], - [3.0, -4.0, 1.0], - [0.9, 0.1, -1.0] - ], - [ - [-6.0, 2.0, 4.0], - [1.5, -2.0, 0.5], - [3.0, 1.0, -4.0] - ], - [ - [-1.0, 0.1, 0.9], - [2.0, -2.5, 0.5], - [0.9, 0.1, -1.0] - ], - ])) - ); - } - } - - let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); - - let f = F::new(0.000001); - let chi_sq = ChiSquare::new(0.0001); + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let cache = Cache::new(parameter_learning, data.clone()); - let mut ctpc = CTPC::new(f, chi_sq, cache); - - - let net = ctpc.fit_transform(net, &data); - assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); - assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes(ctpc); } #[test] fn learn_mixed_discrete_net_3_nodes_ctpc() { - let (_, data) = get_mixed_discrete_net_3_nodes_with_data(); - - let f = F::new(1e-24); - let chi_sq = ChiSquare::new(1e-24); + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let cache = Cache::new(parameter_learning, data); - let ctpc = CTPC::new(f, chi_sq, cache); - + let ctpc = CTPC::new(parameter_learning, f, chi_sq); learn_mixed_discrete_net_3_nodes(ctpc); } From ea3e406bf14e40c622ffc40e02abb18e55d1e30a Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 29 Dec 2022 23:03:03 +0100 Subject: [PATCH 14/18] Implemented basic cache --- reCTBN/src/parameter_learning.rs | 17 ++++++++++++++--- 
reCTBN/tests/structure_learning.rs | 4 ++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index f8a7664..021b100 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,6 +1,6 @@ //! Module containing methods used to learn the parameters. -use std::collections::BTreeSet; +use std::collections::{BTreeSet,HashMap}; use ndarray::prelude::*; @@ -165,13 +165,15 @@ impl ParameterLearning for BayesianApproach { } } +// TODO: Move to constraint_based_algorithm.rs pub struct Cache<'a, P: ParameterLearning> { parameter_learning: &'a P, + cache_persistent: HashMap>, Params>, } impl<'a, P: ParameterLearning> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { - Cache { parameter_learning } + Cache { parameter_learning, cache_persistent: HashMap::new() } } pub fn fit( &mut self, @@ -180,6 +182,15 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { node: usize, parent_set: Option>, ) -> Params { - self.parameter_learning.fit(net, dataset, node, parent_set) + match self.cache_persistent.get(&parent_set) { + // TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = self.parameter_learning.fit(net, dataset, node, parent_set.clone()); + self.cache_persistent.insert(parent_set, params.clone()); + params + } + } } } diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 6f97c9d..a37f2b3 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -473,9 +473,11 @@ pub fn chi_square_call() { let chi_sq = ChiSquare::new(1e-4); assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache)); + let mut cache = Cache::new(&parameter_learning); assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); + let mut cache = Cache::new(&parameter_learning); assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache)); } @@ -493,9 +495,11 @@ pub fn f_call() { assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache)); + let mut cache = Cache::new(&parameter_learning); assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache)); assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); + let mut cache = Cache::new(&parameter_learning); assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache)); } From a0da3e2fe8fb8e3dde02ca2c375e2826623d4ee0 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 30 Dec 2022 17:47:32 +0100 Subject: [PATCH 15/18] Fixed formatting issue --- reCTBN/src/parameter_learning.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 021b100..73193ca 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,6 +1,6 @@ //! Module containing methods used to learn the parameters.
-use std::collections::{BTreeSet,HashMap}; +use std::collections::{BTreeSet, HashMap}; use ndarray::prelude::*; @@ -173,7 +173,10 @@ pub struct Cache<'a, P: ParameterLearning> { impl<'a, P: ParameterLearning> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { - Cache { parameter_learning, cache_persistent: HashMap::new() } + Cache { + parameter_learning, + cache_persistent: HashMap::new(), + } } pub fn fit( &mut self, @@ -187,7 +190,9 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { // not cloning requires a minor and reasoned refactoring across the library Some(params) => params.clone(), None => { - let params = self.parameter_learning.fit(net, dataset, node, parent_set.clone()); + let params = self + .parameter_learning + .fit(net, dataset, node, parent_set.clone()); self.cache_persistent.insert(parent_set, params.clone()); params } From 867bf029345855637ea6a608cf9b3ae58d0937eb Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 30 Dec 2022 17:55:26 +0100 Subject: [PATCH 16/18] Greatly improved memory consumption in cache, stale data no longer pile up --- reCTBN/src/parameter_learning.rs | 56 +++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 73193ca..fcc47c4 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,6 +1,7 @@ //! Module containing methods used to learn the parameters. 
use std::collections::{BTreeSet, HashMap}; +use std::mem; use ndarray::prelude::*; @@ -168,14 +169,18 @@ impl ParameterLearning for BayesianApproach { // TODO: Move to constraint_based_algorithm.rs pub struct Cache<'a, P: ParameterLearning> { parameter_learning: &'a P, - cache_persistent: HashMap>, Params>, + cache_persistent_small: HashMap>, Params>, + cache_persistent_big: HashMap>, Params>, + parent_set_size_small: usize, } impl<'a, P: ParameterLearning> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { Cache { parameter_learning, - cache_persistent: HashMap::new(), + cache_persistent_small: HashMap::new(), + cache_persistent_big: HashMap::new(), + parent_set_size_small: 0, } } pub fn fit( @@ -185,16 +190,43 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { node: usize, parent_set: Option>, ) -> Params { - match self.cache_persistent.get(&parent_set) { - // TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = self - .parameter_learning - .fit(net, dataset, node, parent_set.clone()); - self.cache_persistent.insert(parent_set, params.clone()); - params + let parent_set_len = parent_set.as_ref().unwrap().len(); + if parent_set_len > self.parent_set_size_small + 1 { + //self.cache_persistent_small = self.cache_persistent_big; + mem::swap( + &mut self.cache_persistent_small, + &mut self.cache_persistent_big, + ); + self.cache_persistent_big = HashMap::new(); + self.parent_set_size_small += 1; + } + + if parent_set_len > self.parent_set_size_small { + match self.cache_persistent_big.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + 
self.cache_persistent_big.insert(parent_set, params.clone()); + params + } + } + } else { + match self.cache_persistent_small.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_small + .insert(parent_set, params.clone()); + params + } } } } From 4d3f9518e4137e911d55cc0f723e839e8f391752 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 4 Jan 2023 12:14:36 +0100 Subject: [PATCH 17/18] CTPC parallelization at the nodes level with `rayon` --- reCTBN/Cargo.toml | 1 + reCTBN/src/parameter_learning.rs | 2 +- reCTBN/src/process.rs | 2 +- .../src/structure_learning/constraint_based_algorithm.rs | 8 +++++++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml index fdac697..4749b23 100644 --- a/reCTBN/Cargo.toml +++ b/reCTBN/Cargo.toml @@ -14,6 +14,7 @@ enum_dispatch = "~0.3" statrs = "~0.16" rand_chacha = "~0.3" itertools = "~0.10" +rayon = "~1.6" [dev-dependencies] approx = { package = "approx", version = "~0.5" } diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index fcc47c4..ff6a7b9 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -8,7 +8,7 @@ use ndarray::prelude::*; use crate::params::*; use crate::{process, tools::Dataset}; -pub trait ParameterLearning { +pub trait ParameterLearning: Sync { fn fit( &self, net: &T, diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs index dc297bc..45c5e0a 100644 --- a/reCTBN/src/process.rs +++ b/reCTBN/src/process.rs @@ -21,7 +21,7 @@ pub type NetworkProcessState = Vec; /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). 
-pub trait NetworkProcess { +pub trait NetworkProcess: Sync { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; /// Add an **directed edge** between a two nodes of the network. diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 6d54fe7..634c144 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,6 +1,8 @@ //! Module containing constraint based algorithms like CTPC and Hiton. use itertools::Itertools; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::ParallelExtend; use std::collections::BTreeSet; use std::usize; @@ -41,7 +43,8 @@ impl StructureLearningAlgorithm for CTPC<P> { net.initialize_adj_matrix(); - for child_node in net.get_node_indices() { + let mut learned_parent_sets: Vec<(usize, BTreeSet)> = vec![]; + learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| { let mut cache = Cache::new(&self.parameter_learning); let mut candidate_parent_set: BTreeSet = net .get_node_indices() @@ -82,6 +85,9 @@ impl StructureLearningAlgorithm for CTPC<P> { candidate_parent_set = candidate_parent_set_TMP; separation_set_size += 1; } + (child_node, candidate_parent_set) + })); + for (child_node, candidate_parent_set) in learned_parent_sets { for parent_node in candidate_parent_set.iter() { net.add_edge(*parent_node, child_node); } From 5632833963ed3f27514c02bb42c711a79cc06b74 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 5 Jan 2023 10:53:59 +0100 Subject: [PATCH 18/18] Moved `Cache` to `constraint_based_algorithm.rs` --- reCTBN/src/parameter_learning.rs | 69 +----------------- .../constraint_based_algorithm.rs | 71 ++++++++++++++++++- .../src/structure_learning/hypothesis_test.rs | 7 +- reCTBN/tests/structure_learning.rs | 1 - 4 files changed, 74 insertions(+), 74 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index ff6a7b9..536a9d5 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,7 +1,6 @@ //! Module containing methods used to learn the parameters.
-use std::collections::{BTreeSet, HashMap}; -use std::mem; +use std::collections::BTreeSet; use ndarray::prelude::*; @@ -165,69 +164,3 @@ impl ParameterLearning for BayesianApproach { return n; } } - -// TODO: Move to constraint_based_algorithm.rs -pub struct Cache<'a, P: ParameterLearning> { - parameter_learning: &'a P, - cache_persistent_small: HashMap>, Params>, - cache_persistent_big: HashMap>, Params>, - parent_set_size_small: usize, -} - -impl<'a, P: ParameterLearning> Cache<'a, P> { - pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { - Cache { - parameter_learning, - cache_persistent_small: HashMap::new(), - cache_persistent_big: HashMap::new(), - parent_set_size_small: 0, - } - } - pub fn fit( - &mut self, - net: &T, - dataset: &Dataset, - node: usize, - parent_set: Option>, - ) -> Params { - let parent_set_len = parent_set.as_ref().unwrap().len(); - if parent_set_len > self.parent_set_size_small + 1 { - //self.cache_persistent_small = self.cache_persistent_big; - mem::swap( - &mut self.cache_persistent_small, - &mut self.cache_persistent_big, - ); - self.cache_persistent_big = HashMap::new(); - self.parent_set_size_small += 1; - } - - if parent_set_len > self.parent_set_size_small { - match self.cache_persistent_big.get(&parent_set) { - // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = - self.parameter_learning - .fit(net, dataset, node, parent_set.clone()); - self.cache_persistent_big.insert(parent_set, params.clone()); - params - } - } - } else { - match self.cache_persistent_small.get(&parent_set) { - // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = - self.parameter_learning - .fit(net, dataset, node, parent_set.clone()); - 
self.cache_persistent_small - .insert(parent_set, params.clone()); - params - } - } - } - } -} diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 634c144..f49b194 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,17 +1,84 @@ //! Module containing constraint based algorithms like CTPC and Hiton. +use crate::params::Params; use itertools::Itertools; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use rayon::prelude::ParallelExtend; -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashMap}; +use std::mem; use std::usize; use super::hypothesis_test::*; -use crate::parameter_learning::{Cache, ParameterLearning}; +use crate::parameter_learning::ParameterLearning; use crate::process; use crate::structure_learning::StructureLearningAlgorithm; use crate::tools::Dataset; +pub struct Cache<'a, P: ParameterLearning> { + parameter_learning: &'a P, + cache_persistent_small: HashMap>, Params>, + cache_persistent_big: HashMap>, Params>, + parent_set_size_small: usize, +} + +impl<'a, P: ParameterLearning> Cache<'a, P> { + pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { + Cache { + parameter_learning, + cache_persistent_small: HashMap::new(), + cache_persistent_big: HashMap::new(), + parent_set_size_small: 0, + } + } + pub fn fit( + &mut self, + net: &T, + dataset: &Dataset, + node: usize, + parent_set: Option>, + ) -> Params { + let parent_set_len = parent_set.as_ref().unwrap().len(); + if parent_set_len > self.parent_set_size_small + 1 { + //self.cache_persistent_small = self.cache_persistent_big; + mem::swap( + &mut self.cache_persistent_small, + &mut self.cache_persistent_big, + ); + self.cache_persistent_big = HashMap::new(); + self.parent_set_size_small += 1; + } + + if parent_set_len > self.parent_set_size_small { + match 
self.cache_persistent_big.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_big.insert(parent_set, params.clone()); + params + } + } + } else { + match self.cache_persistent_small.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_small + .insert(parent_set, params.clone()); + params + } + } + } + } +} + pub struct CTPC { parameter_learning: P, Ftest: F, diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index dd683ab..4c02929 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,6 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; +use crate::structure_learning::constraint_based_algorithm::Cache; use crate::{parameter_learning, process, tools::Dataset}; pub trait HypothesisTest { @@ -16,7 +17,7 @@ pub trait HypothesisTest { parent_node: usize, separation_set: &BTreeSet, dataset: &Dataset, - cache: &mut parameter_learning::Cache<P>, + cache: &mut Cache<P>, ) -> bool where T: process::NetworkProcess, @@ -86,7 +87,7 @@ impl HypothesisTest for F { parent_node: usize, separation_set: &BTreeSet, dataset: &Dataset, - cache: &mut parameter_learning::Cache<P>, + cache: &mut Cache<P>, ) -> bool where T: process::NetworkProcess, @@ -226,7 +227,7 @@ impl HypothesisTest for ChiSquare { parent_node: usize, separation_set: &BTreeSet, dataset: &Dataset, - cache: &mut parameter_learning::Cache<P>, + cache: &mut Cache<P>, ) -> bool where T: process::NetworkProcess, diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a37f2b3..9a69b45 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -7,7 +7,6 @@ use ndarray::{arr1, arr2, arr3}; use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::BayesianApproach; -use reCTBN::parameter_learning::Cache; use reCTBN::params; use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::structure_learning::constraint_based_algorithm::*;