Implemented basic cache

pull/79/head
Meliurwen 2 years ago
parent 19856195c3
commit ea3e406bf1
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 17
      reCTBN/src/parameter_learning.rs
  2. 4
      reCTBN/tests/structure_learning.rs

@ -1,6 +1,6 @@
//! Module containing methods used to learn the parameters.
use std::collections::BTreeSet;
use std::collections::{BTreeSet,HashMap};
use ndarray::prelude::*;
@ -165,13 +165,15 @@ impl ParameterLearning for BayesianApproach {
}
}
// TODO: Move to constraint_based_algorithm.rs
pub struct Cache<'a, P: ParameterLearning> {
parameter_learning: &'a P,
cache_persistent: HashMap<Option<BTreeSet<usize>>, Params>,
}
impl<'a, P: ParameterLearning> Cache<'a, P> {
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
Cache { parameter_learning }
Cache { parameter_learning, cache_persistent: HashMap::new() }
}
pub fn fit<T: process::NetworkProcess>(
&mut self,
@ -180,6 +182,15 @@ impl<'a, P: ParameterLearning> Cache<'a, P> {
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params {
self.parameter_learning.fit(net, dataset, node, parent_set)
match self.cache_persistent.get(&parent_set) {
// TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O
// not cloning requires a minor and reasoned refactoring across the library
Some(params) => params.clone(),
None => {
let params = self.parameter_learning.fit(net, dataset, node, parent_set.clone());
self.cache_persistent.insert(parent_set, params.clone());
params
}
}
}
}

@ -473,9 +473,11 @@ pub fn chi_square_call() {
let chi_sq = ChiSquare::new(1e-4);
assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache));
let mut cache = Cache::new(&parameter_learning);
assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache));
separation_set.insert(N1);
let mut cache = Cache::new(&parameter_learning);
assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache));
}
@ -493,9 +495,11 @@ pub fn f_call() {
assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache));
let mut cache = Cache::new(&parameter_learning);
assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache));
separation_set.insert(N1);
let mut cache = Cache::new(&parameter_learning);
assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache));
}

Loading…
Cancel
Save