Implemented basic cache

pull/79/head
Meliurwen 2 years ago
parent 19856195c3
commit ea3e406bf1
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 17
      reCTBN/src/parameter_learning.rs
  2. 4
      reCTBN/tests/structure_learning.rs

@ -1,6 +1,6 @@
//! Module containing methods used to learn the parameters. //! Module containing methods used to learn the parameters.
use std::collections::BTreeSet; use std::collections::{BTreeSet,HashMap};
use ndarray::prelude::*; use ndarray::prelude::*;
@ -165,13 +165,15 @@ impl ParameterLearning for BayesianApproach {
} }
} }
// TODO: Move to constraint_based_algorithm.rs
pub struct Cache<'a, P: ParameterLearning> { pub struct Cache<'a, P: ParameterLearning> {
parameter_learning: &'a P, parameter_learning: &'a P,
cache_persistent: HashMap<Option<BTreeSet<usize>>, Params>,
} }
impl<'a, P: ParameterLearning> Cache<'a, P> { impl<'a, P: ParameterLearning> Cache<'a, P> {
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
Cache { parameter_learning } Cache { parameter_learning, cache_persistent: HashMap::new() }
} }
pub fn fit<T: process::NetworkProcess>( pub fn fit<T: process::NetworkProcess>(
&mut self, &mut self,
@ -180,6 +182,15 @@ impl<'a, P: ParameterLearning> Cache<'a, P> {
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> Params { ) -> Params {
self.parameter_learning.fit(net, dataset, node, parent_set) match self.cache_persistent.get(&parent_set) {
// TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O
// not cloning requires a minor and reasoned refactoring across the library
Some(params) => params.clone(),
None => {
let params = self.parameter_learning.fit(net, dataset, node, parent_set.clone());
self.cache_persistent.insert(parent_set, params.clone());
params
}
}
} }
} }

@ -473,9 +473,11 @@ pub fn chi_square_call() {
let chi_sq = ChiSquare::new(1e-4); let chi_sq = ChiSquare::new(1e-4);
assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache)); assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache));
let mut cache = Cache::new(&parameter_learning);
assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache)); assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache)); assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache));
separation_set.insert(N1); separation_set.insert(N1);
let mut cache = Cache::new(&parameter_learning);
assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache)); assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache));
} }
@ -493,9 +495,11 @@ pub fn f_call() {
assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache)); assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache));
let mut cache = Cache::new(&parameter_learning);
assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache)); assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache)); assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache));
separation_set.insert(N1); separation_set.insert(N1);
let mut cache = Cache::new(&parameter_learning);
assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache)); assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache));
} }

Loading…
Cancel
Save