Greatly improved memory consumption in cache; stale data no longer piles up

pull/79/head
Meliurwen 2 years ago
parent a0da3e2fe8
commit 867bf02934
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 56
      reCTBN/src/parameter_learning.rs

@ -1,6 +1,7 @@
//! Module containing methods used to learn the parameters. //! Module containing methods used to learn the parameters.
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeSet, HashMap};
use std::mem;
use ndarray::prelude::*; use ndarray::prelude::*;
@ -168,14 +169,18 @@ impl ParameterLearning for BayesianApproach {
// TODO: Move to constraint_based_algorithm.rs // TODO: Move to constraint_based_algorithm.rs
pub struct Cache<'a, P: ParameterLearning> { pub struct Cache<'a, P: ParameterLearning> {
parameter_learning: &'a P, parameter_learning: &'a P,
cache_persistent: HashMap<Option<BTreeSet<usize>>, Params>, cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>,
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>,
parent_set_size_small: usize,
} }
impl<'a, P: ParameterLearning> Cache<'a, P> { impl<'a, P: ParameterLearning> Cache<'a, P> {
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
Cache { Cache {
parameter_learning, parameter_learning,
cache_persistent: HashMap::new(), cache_persistent_small: HashMap::new(),
cache_persistent_big: HashMap::new(),
parent_set_size_small: 0,
} }
} }
pub fn fit<T: process::NetworkProcess>( pub fn fit<T: process::NetworkProcess>(
@ -185,16 +190,43 @@ impl<'a, P: ParameterLearning> Cache<'a, P> {
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> Params { ) -> Params {
match self.cache_persistent.get(&parent_set) { let parent_set_len = parent_set.as_ref().unwrap().len();
// TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O if parent_set_len > self.parent_set_size_small + 1 {
// not cloning requires a minor and reasoned refactoring across the library //self.cache_persistent_small = self.cache_persistent_big;
Some(params) => params.clone(), mem::swap(
None => { &mut self.cache_persistent_small,
let params = self &mut self.cache_persistent_big,
.parameter_learning );
.fit(net, dataset, node, parent_set.clone()); self.cache_persistent_big = HashMap::new();
self.cache_persistent.insert(parent_set, params.clone()); self.parent_set_size_small += 1;
params }
if parent_set_len > self.parent_set_size_small {
match self.cache_persistent_big.get(&parent_set) {
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
// not cloning requires a minor and reasoned refactoring across the library
Some(params) => params.clone(),
None => {
let params =
self.parameter_learning
.fit(net, dataset, node, parent_set.clone());
self.cache_persistent_big.insert(parent_set, params.clone());
params
}
}
} else {
match self.cache_persistent_small.get(&parent_set) {
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
// not cloning requires a minor and reasoned refactoring across the library
Some(params) => params.clone(),
None => {
let params =
self.parameter_learning
.fit(net, dataset, node, parent_set.clone());
self.cache_persistent_small
.insert(parent_set, params.clone());
params
}
} }
} }
} }

Loading…
Cancel
Save