diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index ff6a7b9..536a9d5 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,7 +1,6 @@ //! Module containing methods used to learn the parameters. -use std::collections::{BTreeSet, HashMap}; -use std::mem; +use std::collections::BTreeSet; use ndarray::prelude::*; @@ -165,69 +164,3 @@ impl ParameterLearning for BayesianApproach { return n; } } - -// TODO: Move to constraint_based_algorithm.rs -pub struct Cache<'a, P: ParameterLearning> { - parameter_learning: &'a P, - cache_persistent_small: HashMap>, Params>, - cache_persistent_big: HashMap>, Params>, - parent_set_size_small: usize, -} - -impl<'a, P: ParameterLearning> Cache<'a, P> { - pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { - Cache { - parameter_learning, - cache_persistent_small: HashMap::new(), - cache_persistent_big: HashMap::new(), - parent_set_size_small: 0, - } - } - pub fn fit( - &mut self, - net: &T, - dataset: &Dataset, - node: usize, - parent_set: Option>, - ) -> Params { - let parent_set_len = parent_set.as_ref().unwrap().len(); - if parent_set_len > self.parent_set_size_small + 1 { - //self.cache_persistent_small = self.cache_persistent_big; - mem::swap( - &mut self.cache_persistent_small, - &mut self.cache_persistent_big, - ); - self.cache_persistent_big = HashMap::new(); - self.parent_set_size_small += 1; - } - - if parent_set_len > self.parent_set_size_small { - match self.cache_persistent_big.get(&parent_set) { - // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = - self.parameter_learning - .fit(net, dataset, node, parent_set.clone()); - self.cache_persistent_big.insert(parent_set, params.clone()); - params - } - } - } else { - match self.cache_persistent_small.get(&parent_set) { - // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = - self.parameter_learning - .fit(net, dataset, node, parent_set.clone()); - self.cache_persistent_small - .insert(parent_set, params.clone()); - params - } - } - } - } -} diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 634c144..f49b194 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,17 +1,84 @@ //! Module containing constraint based algorithms like CTPC and Hiton. +use crate::params::Params; use itertools::Itertools; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use rayon::prelude::ParallelExtend; -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashMap}; +use std::mem; use std::usize; use super::hypothesis_test::*; -use crate::parameter_learning::{Cache, ParameterLearning}; +use crate::parameter_learning::ParameterLearning; use crate::process; use crate::structure_learning::StructureLearningAlgorithm; use crate::tools::Dataset; +pub struct Cache<'a, P: ParameterLearning> { + parameter_learning: &'a P, + cache_persistent_small: HashMap>, Params>, + cache_persistent_big: HashMap>, Params>, + parent_set_size_small: usize, +} + +impl<'a, P: ParameterLearning> Cache<'a, P> { + pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { + Cache { + parameter_learning, + cache_persistent_small: HashMap::new(), + cache_persistent_big: HashMap::new(), + parent_set_size_small: 0, + } + } + pub fn fit( + &mut self, + net: &T, + dataset: &Dataset, + node: usize, + parent_set: Option>, + ) -> Params { + let parent_set_len = parent_set.as_ref().unwrap().len(); + if parent_set_len > self.parent_set_size_small + 1 { + //self.cache_persistent_small = self.cache_persistent_big; + mem::swap( + &mut self.cache_persistent_small, + &mut self.cache_persistent_big, + ); + self.cache_persistent_big = HashMap::new(); + self.parent_set_size_small += 1; + } + + if parent_set_len > self.parent_set_size_small { + match self.cache_persistent_big.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_big.insert(parent_set, params.clone()); + params + } + } + } else { + match self.cache_persistent_small.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_small + .insert(parent_set, params.clone()); + params + } + } + } + } +} + pub struct CTPC { parameter_learning: P, Ftest: F, diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index dd683ab..4c02929 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,6 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; +use crate::structure_learning::constraint_based_algorithm::Cache; use crate::{parameter_learning, process, tools::Dataset}; pub trait HypothesisTest { @@ -16,7 +17,7 @@ pub trait HypothesisTest { parent_node: usize, separation_set: &BTreeSet, dataset: &Dataset, - cache: &mut parameter_learning::Cache

, + cache: &mut Cache

, ) -> bool where T: process::NetworkProcess, @@ -86,7 +87,7 @@ impl HypothesisTest for F { parent_node: usize, separation_set: &BTreeSet, dataset: &Dataset, - cache: &mut parameter_learning::Cache

, + cache: &mut Cache

, ) -> bool where T: process::NetworkProcess, @@ -226,7 +227,7 @@ impl HypothesisTest for ChiSquare { parent_node: usize, separation_set: &BTreeSet, dataset: &Dataset, - cache: &mut parameter_learning::Cache

, + cache: &mut Cache

, ) -> bool where T: process::NetworkProcess, diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a37f2b3..9a69b45 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -7,7 +7,6 @@ use ndarray::{arr1, arr2, arr3}; use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::BayesianApproach; -use reCTBN::parameter_learning::Cache; use reCTBN::params; use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::structure_learning::constraint_based_algorithm::*;