From ea3e406bf14e40c622ffc40e02abb18e55d1e30a Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Thu, 29 Dec 2022 23:03:03 +0100
Subject: [PATCH] Implemented basic cache

---
 reCTBN/src/parameter_learning.rs   | 17 ++++++++++++++---
 reCTBN/tests/structure_learning.rs |  4 ++++
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs
index f8a7664..021b100 100644
--- a/reCTBN/src/parameter_learning.rs
+++ b/reCTBN/src/parameter_learning.rs
@@ -1,6 +1,6 @@
 //! Module containing methods used to learn the parameters.
 
-use std::collections::BTreeSet;
+use std::collections::{BTreeSet, HashMap};
 
 use ndarray::prelude::*;
 
@@ -165,13 +165,15 @@ impl ParameterLearning for BayesianApproach {
     }
 }
 
+// TODO: Move to constraint_based_algorithm.rs
 pub struct Cache<'a, P: ParameterLearning> {
     parameter_learning: &'a P,
+    cache_persistent: HashMap<Option<BTreeSet<usize>>, Params>,
 }
 
 impl<'a, P: ParameterLearning> Cache<'a, P> {
     pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
-        Cache { parameter_learning }
+        Cache { parameter_learning, cache_persistent: HashMap::new() }
     }
     pub fn fit(
         &mut self,
@@ -180,6 +182,15 @@ impl<'a, P: ParameterLearning> Cache<'a, P> {
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
-        self.parameter_learning.fit(net, dataset, node, parent_set)
+        match self.cache_persistent.get(&parent_set) {
+            // TODO: Better to avoid cloning `params`: it wastes clock cycles, RAM and I/O;
+            // not cloning requires a small but carefully reasoned refactoring across the library.
+            Some(params) => params.clone(),
+            None => {
+                let params = self.parameter_learning.fit(net, dataset, node, parent_set.clone());
+                self.cache_persistent.insert(parent_set, params.clone());
+                params
+            }
+        }
     }
 }
diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs
index 6f97c9d..a37f2b3 100644
--- a/reCTBN/tests/structure_learning.rs
+++ b/reCTBN/tests/structure_learning.rs
@@ -473,9 +473,11 @@ pub fn chi_square_call() {
     let chi_sq = ChiSquare::new(1e-4);
 
     assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache));
 
+    let mut cache = Cache::new(&parameter_learning);
     assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
     assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache));
     separation_set.insert(N1);
+    let mut cache = Cache::new(&parameter_learning);
     assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache));
 }
@@ -493,9 +495,11 @@ pub fn f_call() {
     let f = F::new(1e-6);
 
     assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache));
 
+    let mut cache = Cache::new(&parameter_learning);
     assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache));
     assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache));
     separation_set.insert(N1);
+    let mut cache = Cache::new(&parameter_learning);
     assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache));
 }
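
Note on the cloning TODO in `Cache::fit`: the current body clones `params` on every cache hit and once more when populating the map. Below is a minimal, self-contained sketch of how the per-hit clone could be avoided by handing back a borrow and computing only on a miss via `HashMap::entry`; the `Params` struct, the `fit_ref` name, and the closure argument are illustrative stand-ins, not the library's API:

    use std::collections::{BTreeSet, HashMap};

    // Hypothetical stand-in for the library's `Params` type, kept trivial so the
    // sketch compiles on its own.
    #[derive(Clone, Debug)]
    struct Params(Vec<f64>);

    struct Cache {
        cache_persistent: HashMap<Option<BTreeSet<usize>>, Params>,
    }

    impl Cache {
        fn new() -> Self {
            Cache { cache_persistent: HashMap::new() }
        }

        // Returning a borrow avoids cloning on a hit; `entry().or_insert_with()`
        // runs `compute` only when the parent set has not been seen yet.
        fn fit_ref<F>(&mut self, parent_set: Option<BTreeSet<usize>>, compute: F) -> &Params
        where
            F: FnOnce() -> Params,
        {
            self.cache_persistent.entry(parent_set).or_insert_with(compute)
        }
    }

    fn main() {
        let mut cache = Cache::new();
        let ps = Some(BTreeSet::from([0, 1]));

        // The first call computes; the second call with the same parent set is a
        // cache hit, so its closure is never executed.
        let first = cache.fit_ref(ps.clone(), || Params(vec![0.1, 0.9])).0.len();
        let second = cache.fit_ref(ps, || unreachable!("should be cached")).0.len();
        assert_eq!(first, second);
    }

The borrow is tied to the cache, so a caller that needs an owned `Params` would still clone, but only at the point where ownership is actually required.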
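
Note on the test changes: a fresh `Cache` is created before querying a different node because `cache_persistent` is keyed on the parent set alone, so parameters fitted for one node would otherwise be served for another node that happens to use the same separation set. The following standalone sketch shows a key that also carries the node index, which would let one cache instance be shared across all the calls above; the type names and the `Vec<f64>` stand-in for `Params` are assumptions for illustration only:

    use std::collections::{BTreeSet, HashMap};

    // Hypothetical key: (node index, parent set). Hits are returned only for the
    // same node with the same parent set.
    type CacheKey = (usize, Option<BTreeSet<usize>>);

    struct NodeAwareCache {
        cache_persistent: HashMap<CacheKey, Vec<f64>>,
    }

    impl NodeAwareCache {
        fn new() -> Self {
            NodeAwareCache { cache_persistent: HashMap::new() }
        }

        // Compute on a miss, then hand back an owned copy of the cached value.
        fn fit<F>(&mut self, node: usize, parent_set: Option<BTreeSet<usize>>, compute: F) -> Vec<f64>
        where
            F: FnOnce() -> Vec<f64>,
        {
            self.cache_persistent
                .entry((node, parent_set))
                .or_insert_with(compute)
                .clone()
        }
    }

    fn main() {
        let mut cache = NodeAwareCache::new();
        let sep_set: Option<BTreeSet<usize>> = Some(BTreeSet::new());

        // The same (empty) separation set no longer collides across different nodes.
        let p1 = cache.fit(1, sep_set.clone(), || vec![0.25]);
        let p3 = cache.fit(3, sep_set, || vec![0.75]);
        assert_ne!(p1, p3);
    }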