|
|
|
@ -1,6 +1,7 @@ |
|
|
|
|
//! Module containing methods used to learn the parameters.
|
|
|
|
|
|
|
|
|
|
use std::collections::{BTreeSet, HashMap}; |
|
|
|
|
use std::mem; |
|
|
|
|
|
|
|
|
|
use ndarray::prelude::*; |
|
|
|
|
|
|
|
|
@ -168,14 +169,18 @@ impl ParameterLearning for BayesianApproach { |
|
|
|
|
// TODO: Move to constraint_based_algorithm.rs
|
|
|
|
|
pub struct Cache<'a, P: ParameterLearning> { |
|
|
|
|
parameter_learning: &'a P, |
|
|
|
|
cache_persistent: HashMap<Option<BTreeSet<usize>>, Params>, |
|
|
|
|
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>, |
|
|
|
|
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>, |
|
|
|
|
parent_set_size_small: usize, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl<'a, P: ParameterLearning> Cache<'a, P> { |
|
|
|
|
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { |
|
|
|
|
Cache { |
|
|
|
|
parameter_learning, |
|
|
|
|
cache_persistent: HashMap::new(), |
|
|
|
|
cache_persistent_small: HashMap::new(), |
|
|
|
|
cache_persistent_big: HashMap::new(), |
|
|
|
|
parent_set_size_small: 0, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
pub fn fit<T: process::NetworkProcess>( |
|
|
|
@ -185,17 +190,44 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { |
|
|
|
|
node: usize, |
|
|
|
|
parent_set: Option<BTreeSet<usize>>, |
|
|
|
|
) -> Params { |
|
|
|
|
match self.cache_persistent.get(&parent_set) { |
|
|
|
|
// TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O
|
|
|
|
|
let parent_set_len = parent_set.as_ref().unwrap().len(); |
|
|
|
|
if parent_set_len > self.parent_set_size_small + 1 { |
|
|
|
|
//self.cache_persistent_small = self.cache_persistent_big;
|
|
|
|
|
mem::swap( |
|
|
|
|
&mut self.cache_persistent_small, |
|
|
|
|
&mut self.cache_persistent_big, |
|
|
|
|
); |
|
|
|
|
self.cache_persistent_big = HashMap::new(); |
|
|
|
|
self.parent_set_size_small += 1; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if parent_set_len > self.parent_set_size_small { |
|
|
|
|
match self.cache_persistent_big.get(&parent_set) { |
|
|
|
|
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
|
|
|
|
// not cloning requires a minor and reasoned refactoring across the library
|
|
|
|
|
Some(params) => params.clone(), |
|
|
|
|
None => { |
|
|
|
|
let params = |
|
|
|
|
self.parameter_learning |
|
|
|
|
.fit(net, dataset, node, parent_set.clone()); |
|
|
|
|
self.cache_persistent_big.insert(parent_set, params.clone()); |
|
|
|
|
params |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
match self.cache_persistent_small.get(&parent_set) { |
|
|
|
|
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
|
|
|
|
// not cloning requires a minor and reasoned refactoring across the library
|
|
|
|
|
Some(params) => params.clone(), |
|
|
|
|
None => { |
|
|
|
|
let params = self |
|
|
|
|
.parameter_learning |
|
|
|
|
let params = |
|
|
|
|
self.parameter_learning |
|
|
|
|
.fit(net, dataset, node, parent_set.clone()); |
|
|
|
|
self.cache_persistent.insert(parent_set, params.clone()); |
|
|
|
|
self.cache_persistent_small |
|
|
|
|
.insert(parent_set, params.clone()); |
|
|
|
|
params |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|