diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 73193ca..fcc47c4 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,6 +1,7 @@ //! Module containing methods used to learn the parameters. use std::collections::{BTreeSet, HashMap}; +use std::mem; use ndarray::prelude::*; @@ -168,14 +169,18 @@ impl ParameterLearning for BayesianApproach { // TODO: Move to constraint_based_algorithm.rs pub struct Cache<'a, P: ParameterLearning> { parameter_learning: &'a P, - cache_persistent: HashMap>, Params>, + cache_persistent_small: HashMap>, Params>, + cache_persistent_big: HashMap>, Params>, + parent_set_size_small: usize, } impl<'a, P: ParameterLearning> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { Cache { parameter_learning, - cache_persistent: HashMap::new(), + cache_persistent_small: HashMap::new(), + cache_persistent_big: HashMap::new(), + parent_set_size_small: 0, } } pub fn fit( @@ -185,16 +190,43 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { node: usize, parent_set: Option>, ) -> Params { - match self.cache_persistent.get(&parent_set) { - // TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = self - .parameter_learning - .fit(net, dataset, node, parent_set.clone()); - self.cache_persistent.insert(parent_set, params.clone()); - params + let parent_set_len = parent_set.as_ref().unwrap().len(); + if parent_set_len > self.parent_set_size_small + 1 { + //self.cache_persistent_small = self.cache_persistent_big; + mem::swap( + &mut self.cache_persistent_small, + &mut self.cache_persistent_big, + ); + self.cache_persistent_big = HashMap::new(); + self.parent_set_size_small += 1; + } + + if parent_set_len > self.parent_set_size_small { + match self.cache_persistent_big.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_big.insert(parent_set, params.clone()); + params + } + } + } else { + match self.cache_persistent_small.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_small + .insert(parent_set, params.clone()); + params + } } } }