|
|
|
@ -1,17 +1,84 @@ |
|
|
|
|
//! Module containing constraint-based algorithms like CTPC and Hiton.
|
|
|
|
|
|
|
|
|
|
use crate::params::Params; |
|
|
|
|
use itertools::Itertools; |
|
|
|
|
use rayon::iter::{IntoParallelIterator, ParallelIterator}; |
|
|
|
|
use rayon::prelude::ParallelExtend; |
|
|
|
|
use std::collections::BTreeSet; |
|
|
|
|
use std::collections::{BTreeSet, HashMap}; |
|
|
|
|
use std::mem; |
|
|
|
|
use std::usize; |
|
|
|
|
|
|
|
|
|
use super::hypothesis_test::*; |
|
|
|
|
use crate::parameter_learning::{Cache, ParameterLearning}; |
|
|
|
|
use crate::parameter_learning::ParameterLearning; |
|
|
|
|
use crate::process; |
|
|
|
|
use crate::structure_learning::StructureLearningAlgorithm; |
|
|
|
|
use crate::tools::Dataset; |
|
|
|
|
|
|
|
|
|
/// Memoization layer around a [`ParameterLearning`] algorithm.
///
/// Learned [`Params`] are kept in two `HashMap` "generations" keyed by the
/// optional parent set, so repeated fits for the same parent set are served
/// from memory instead of being recomputed.
///
/// NOTE(review): the cache key is only the parent set and does not include the
/// node index — this assumes one `Cache` instance is used per node; confirm
/// against the callers before sharing an instance across nodes.
pub struct Cache<'a, P: ParameterLearning> {
    // Underlying algorithm invoked on a cache miss.
    parameter_learning: &'a P,
    // Generation holding entries for parent sets of size
    // `parent_set_size_small` and smaller.
    cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>,
    // Generation holding entries for parent sets larger than
    // `parent_set_size_small`; promoted to "small" as sizes grow.
    cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>,
    // Current size threshold separating the two generations.
    parent_set_size_small: usize,
}
|
|
|
|
|
|
|
|
|
impl<'a, P: ParameterLearning> Cache<'a, P> { |
|
|
|
|
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { |
|
|
|
|
Cache { |
|
|
|
|
parameter_learning, |
|
|
|
|
cache_persistent_small: HashMap::new(), |
|
|
|
|
cache_persistent_big: HashMap::new(), |
|
|
|
|
parent_set_size_small: 0, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
pub fn fit<T: process::NetworkProcess>( |
|
|
|
|
&mut self, |
|
|
|
|
net: &T, |
|
|
|
|
dataset: &Dataset, |
|
|
|
|
node: usize, |
|
|
|
|
parent_set: Option<BTreeSet<usize>>, |
|
|
|
|
) -> Params { |
|
|
|
|
let parent_set_len = parent_set.as_ref().unwrap().len(); |
|
|
|
|
if parent_set_len > self.parent_set_size_small + 1 { |
|
|
|
|
//self.cache_persistent_small = self.cache_persistent_big;
|
|
|
|
|
mem::swap( |
|
|
|
|
&mut self.cache_persistent_small, |
|
|
|
|
&mut self.cache_persistent_big, |
|
|
|
|
); |
|
|
|
|
self.cache_persistent_big = HashMap::new(); |
|
|
|
|
self.parent_set_size_small += 1; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if parent_set_len > self.parent_set_size_small { |
|
|
|
|
match self.cache_persistent_big.get(&parent_set) { |
|
|
|
|
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
|
|
|
|
// not cloning requires a minor and reasoned refactoring across the library
|
|
|
|
|
Some(params) => params.clone(), |
|
|
|
|
None => { |
|
|
|
|
let params = |
|
|
|
|
self.parameter_learning |
|
|
|
|
.fit(net, dataset, node, parent_set.clone()); |
|
|
|
|
self.cache_persistent_big.insert(parent_set, params.clone()); |
|
|
|
|
params |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
match self.cache_persistent_small.get(&parent_set) { |
|
|
|
|
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
|
|
|
|
// not cloning requires a minor and reasoned refactoring across the library
|
|
|
|
|
Some(params) => params.clone(), |
|
|
|
|
None => { |
|
|
|
|
let params = |
|
|
|
|
self.parameter_learning |
|
|
|
|
.fit(net, dataset, node, parent_set.clone()); |
|
|
|
|
self.cache_persistent_small |
|
|
|
|
.insert(parent_set, params.clone()); |
|
|
|
|
params |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct CTPC<P: ParameterLearning> { |
|
|
|
|
parameter_learning: P, |
|
|
|
|
Ftest: F, |
|
|
|
|