Moved `Cache` to `constraint_based_algorithm.rs`

pull/81/head
Meliurwen 2 years ago
parent 0e1cca0456
commit 5632833963
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 69
      reCTBN/src/parameter_learning.rs
  2. 71
      reCTBN/src/structure_learning/constraint_based_algorithm.rs
  3. 7
      reCTBN/src/structure_learning/hypothesis_test.rs
  4. 1
      reCTBN/tests/structure_learning.rs

@ -1,7 +1,6 @@
//! Module containing methods used to learn the parameters.
use std::collections::{BTreeSet, HashMap};
use std::mem;
use std::collections::BTreeSet;
use ndarray::prelude::*;
@ -165,69 +164,3 @@ impl ParameterLearning for BayesianApproach {
return n;
}
}
// TODO: Move to constraint_based_algorithm.rs
/// Memoization cache for [`ParameterLearning::fit`] results, keyed by parent set.
///
/// Two generations of entries are kept: a "small" map for parent sets of size
/// `parent_set_size_small` or below, and a "big" map for sets one element
/// larger. As the explored parent-set size grows, the big map is rotated into
/// the small one (see `Cache::fit`).
pub struct Cache<'a, P: ParameterLearning> {
/// Learner whose `fit` results are memoized.
parameter_learning: &'a P,
/// Entries for parent sets of size <= `parent_set_size_small`.
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>,
/// Entries for parent sets one element larger than the small tier.
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>,
/// Parent-set size currently associated with the small tier.
parent_set_size_small: usize,
}
impl<'a, P: ParameterLearning> Cache<'a, P> {
    /// Creates an empty two-tier cache backed by the given parameter learner.
    pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
        Cache {
            parameter_learning,
            cache_persistent_small: HashMap::new(),
            cache_persistent_big: HashMap::new(),
            parent_set_size_small: 0,
        }
    }

    /// Returns the learned `Params` for `node` given `parent_set`, fitting and
    /// caching them on a miss.
    ///
    /// The cache keeps two generations keyed by parent set: a "small" tier for
    /// sets of size `parent_set_size_small` and below, and a "big" tier for
    /// sets one element larger. When a strictly larger parent set is requested,
    /// the big tier is promoted to the small one and a fresh big tier is
    /// started.
    pub fn fit<T: process::NetworkProcess>(
        &mut self,
        net: &T,
        dataset: &Dataset,
        node: usize,
        parent_set: Option<BTreeSet<usize>>,
    ) -> Params {
        // `None` is a valid cache key (the maps are keyed by
        // `Option<BTreeSet<usize>>`), so treat it as the empty parent set
        // instead of panicking on `unwrap`.
        let parent_set_len = parent_set.as_ref().map_or(0, |s| s.len());

        if parent_set_len > self.parent_set_size_small + 1 {
            // Promote the big tier to the small one; `mem::take` leaves an
            // empty map behind, replacing the previous swap-and-reassign pair.
            self.cache_persistent_small = mem::take(&mut self.cache_persistent_big);
            self.parent_set_size_small += 1;
        }

        // Select the tier once so the lookup/insert logic is not duplicated.
        let use_big = parent_set_len > self.parent_set_size_small;
        let tier = if use_big {
            &self.cache_persistent_big
        } else {
            &self.cache_persistent_small
        };
        // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
        // not cloning requires a minor and reasoned refactoring across the library
        if let Some(params) = tier.get(&parent_set) {
            return params.clone();
        }

        let params = self
            .parameter_learning
            .fit(net, dataset, node, parent_set.clone());
        let tier = if use_big {
            &mut self.cache_persistent_big
        } else {
            &mut self.cache_persistent_small
        };
        tier.insert(parent_set, params.clone());
        params
    }
}

@ -1,17 +1,84 @@
//! Module containing constraint based algorithms like CTPC and Hiton.
use crate::params::Params;
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
use std::collections::BTreeSet;
use std::collections::{BTreeSet, HashMap};
use std::mem;
use std::usize;
use super::hypothesis_test::*;
use crate::parameter_learning::{Cache, ParameterLearning};
use crate::parameter_learning::ParameterLearning;
use crate::process;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::tools::Dataset;
/// Memoization cache for [`ParameterLearning::fit`] results, keyed by parent set.
///
/// Two generations of entries are kept: a "small" map for parent sets of size
/// `parent_set_size_small` or below, and a "big" map for sets one element
/// larger. As the explored parent-set size grows, the big map is rotated into
/// the small one (see `Cache::fit`).
pub struct Cache<'a, P: ParameterLearning> {
/// Learner whose `fit` results are memoized.
parameter_learning: &'a P,
/// Entries for parent sets of size <= `parent_set_size_small`.
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>,
/// Entries for parent sets one element larger than the small tier.
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>,
/// Parent-set size currently associated with the small tier.
parent_set_size_small: usize,
}
impl<'a, P: ParameterLearning> Cache<'a, P> {
    /// Builds a cache with empty tiers around the supplied learner.
    pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
        Cache {
            parameter_learning,
            cache_persistent_small: HashMap::new(),
            cache_persistent_big: HashMap::new(),
            parent_set_size_small: 0,
        }
    }

    /// Fits (or retrieves from cache) the `Params` of `node` for `parent_set`.
    ///
    /// Results are memoized in two tiers keyed by the parent set: a "small"
    /// tier for sets of size `parent_set_size_small` and below, and a "big"
    /// tier for sets one element larger.
    pub fn fit<T: process::NetworkProcess>(
        &mut self,
        net: &T,
        dataset: &Dataset,
        node: usize,
        parent_set: Option<BTreeSet<usize>>,
    ) -> Params {
        let set_size = parent_set.as_ref().unwrap().len();

        // Rotate generations: once the requested parent-set size outgrows both
        // tiers, the big tier becomes the small one and a fresh big tier starts.
        if set_size > self.parent_set_size_small + 1 {
            mem::swap(
                &mut self.cache_persistent_small,
                &mut self.cache_persistent_big,
            );
            self.cache_persistent_big = HashMap::new();
            self.parent_set_size_small += 1;
        }

        // Decide the tier once, then share the lookup/insert logic below.
        let hit_big = set_size > self.parent_set_size_small;
        let lookup = if hit_big {
            self.cache_persistent_big.get(&parent_set)
        } else {
            self.cache_persistent_small.get(&parent_set)
        };
        // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
        // not cloning requires a minor and reasoned refactoring across the library
        match lookup {
            Some(params) => params.clone(),
            None => {
                let params = self
                    .parameter_learning
                    .fit(net, dataset, node, parent_set.clone());
                let tier = if hit_big {
                    &mut self.cache_persistent_big
                } else {
                    &mut self.cache_persistent_small
                };
                tier.insert(parent_set, params.clone());
                params
            }
        }
    }
}
pub struct CTPC<P: ParameterLearning> {
parameter_learning: P,
Ftest: F,

@ -6,6 +6,7 @@ use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
use crate::params::*;
use crate::structure_learning::constraint_based_algorithm::Cache;
use crate::{parameter_learning, process, tools::Dataset};
pub trait HypothesisTest {
@ -16,7 +17,7 @@ pub trait HypothesisTest {
parent_node: usize,
separation_set: &BTreeSet<usize>,
dataset: &Dataset,
cache: &mut parameter_learning::Cache<P>,
cache: &mut Cache<P>,
) -> bool
where
T: process::NetworkProcess,
@ -86,7 +87,7 @@ impl HypothesisTest for F {
parent_node: usize,
separation_set: &BTreeSet<usize>,
dataset: &Dataset,
cache: &mut parameter_learning::Cache<P>,
cache: &mut Cache<P>,
) -> bool
where
T: process::NetworkProcess,
@ -226,7 +227,7 @@ impl HypothesisTest for ChiSquare {
parent_node: usize,
separation_set: &BTreeSet<usize>,
dataset: &Dataset,
cache: &mut parameter_learning::Cache<P>,
cache: &mut Cache<P>,
) -> bool
where
T: process::NetworkProcess,

@ -7,7 +7,6 @@ use ndarray::{arr1, arr2, arr3};
use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::BayesianApproach;
use reCTBN::parameter_learning::Cache;
use reCTBN::params;
use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::constraint_based_algorithm::*;

Loading…
Cancel
Save