Merge branch 'ctpc-parallelization' into '8-feature-constraint-based-structure-learning-algorithm-for-ctbn'

Added parallelization to CTPC
pull/81/head
Meliurwen 2 years ago
commit 0e1cca0456
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 1
      reCTBN/Cargo.toml
  2. 2
      reCTBN/src/parameter_learning.rs
  3. 2
      reCTBN/src/process.rs
  4. 8
      reCTBN/src/structure_learning/constraint_based_algorithm.rs

@ -14,6 +14,7 @@ enum_dispatch = "~0.3"
statrs = "~0.16" statrs = "~0.16"
rand_chacha = "~0.3" rand_chacha = "~0.3"
itertools = "~0.10" itertools = "~0.10"
rayon = "~1.6"
[dev-dependencies] [dev-dependencies]
approx = { package = "approx", version = "~0.5" } approx = { package = "approx", version = "~0.5" }

@ -8,7 +8,7 @@ use ndarray::prelude::*;
use crate::params::*; use crate::params::*;
use crate::{process, tools::Dataset}; use crate::{process, tools::Dataset};
pub trait ParameterLearning { pub trait ParameterLearning: Sync {
fn fit<T: process::NetworkProcess>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,

@ -21,7 +21,7 @@ pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
/// as a CTBN). /// as a CTBN).
pub trait NetworkProcess { pub trait NetworkProcess: Sync {
fn initialize_adj_matrix(&mut self); fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add an **directed edge** between a two nodes of the network. /// Add an **directed edge** between a two nodes of the network.

@ -1,6 +1,8 @@
//! Module containing constraint based algorithms like CTPC and Hiton. //! Module containing constraint based algorithms like CTPC and Hiton.
use itertools::Itertools; use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::usize; use std::usize;
@ -41,7 +43,8 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
net.initialize_adj_matrix(); net.initialize_adj_matrix();
for child_node in net.get_node_indices() { let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| {
let mut cache = Cache::new(&self.parameter_learning); let mut cache = Cache::new(&self.parameter_learning);
let mut candidate_parent_set: BTreeSet<usize> = net let mut candidate_parent_set: BTreeSet<usize> = net
.get_node_indices() .get_node_indices()
@ -82,6 +85,9 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
candidate_parent_set = candidate_parent_set_TMP; candidate_parent_set = candidate_parent_set_TMP;
separation_set_size += 1; separation_set_size += 1;
} }
(child_node, candidate_parent_set)
}));
for (child_node, candidate_parent_set) in learned_parent_sets {
for parent_node in candidate_parent_set.iter() { for parent_node in candidate_parent_set.iter() {
net.add_edge(*parent_node, child_node); net.add_edge(*parent_node, child_node);
} }

Loading…
Cancel
Save