|
|
@ -1,6 +1,6 @@ |
|
|
|
//! Module containing methods used to learn the parameters.
|
|
|
|
//! Module containing methods used to learn the parameters.
|
|
|
|
|
|
|
|
|
|
|
|
use std::collections::{BTreeSet,HashMap}; |
|
|
|
use std::collections::{BTreeSet, HashMap}; |
|
|
|
|
|
|
|
|
|
|
|
use ndarray::prelude::*; |
|
|
|
use ndarray::prelude::*; |
|
|
|
|
|
|
|
|
|
|
@ -173,7 +173,10 @@ pub struct Cache<'a, P: ParameterLearning> { |
|
|
|
|
|
|
|
|
|
|
|
impl<'a, P: ParameterLearning> Cache<'a, P> { |
|
|
|
impl<'a, P: ParameterLearning> Cache<'a, P> { |
|
|
|
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { |
|
|
|
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { |
|
|
|
Cache { parameter_learning, cache_persistent: HashMap::new() } |
|
|
|
Cache { |
|
|
|
|
|
|
|
parameter_learning, |
|
|
|
|
|
|
|
cache_persistent: HashMap::new(), |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
pub fn fit<T: process::NetworkProcess>( |
|
|
|
pub fn fit<T: process::NetworkProcess>( |
|
|
|
&mut self, |
|
|
|
&mut self, |
|
|
@ -187,7 +190,9 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { |
|
|
|
// not cloning requires a minor and reasoned refactoring across the library
|
|
|
|
// not cloning requires a minor and reasoned refactoring across the library
|
|
|
|
Some(params) => params.clone(), |
|
|
|
Some(params) => params.clone(), |
|
|
|
None => { |
|
|
|
None => { |
|
|
|
let params = self.parameter_learning.fit(net, dataset, node, parent_set.clone()); |
|
|
|
let params = self |
|
|
|
|
|
|
|
.parameter_learning |
|
|
|
|
|
|
|
.fit(net, dataset, node, parent_set.clone()); |
|
|
|
self.cache_persistent.insert(parent_set, params.clone()); |
|
|
|
self.cache_persistent.insert(parent_set, params.clone()); |
|
|
|
params |
|
|
|
params |
|
|
|
} |
|
|
|
} |
|
|
|