CTPC parallelization at the nodes level with `rayon`

pull/79/head
Meliurwen 2 years ago
parent 867bf02934
commit 4d3f9518e4
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 1
      reCTBN/Cargo.toml
  2. 2
      reCTBN/src/parameter_learning.rs
  3. 2
      reCTBN/src/process.rs
  4. 8
      reCTBN/src/structure_learning/constraint_based_algorithm.rs

@ -14,6 +14,7 @@ enum_dispatch = "~0.3"
statrs = "~0.16"
rand_chacha = "~0.3"
itertools = "~0.10"
rayon = "~1.6"
[dev-dependencies]
approx = { package = "approx", version = "~0.5" }

@ -8,7 +8,7 @@ use ndarray::prelude::*;
use crate::params::*;
use crate::{process, tools::Dataset};
pub trait ParameterLearning {
pub trait ParameterLearning: Sync {
fn fit<T: process::NetworkProcess>(
&self,
net: &T,

@ -21,7 +21,7 @@ pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
/// as a CTBN).
pub trait NetworkProcess {
pub trait NetworkProcess: Sync {
fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add an **directed edge** between a two nodes of the network.

@ -1,6 +1,8 @@
//! Module containing constraint based algorithms like CTPC and Hiton.
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
use std::collections::BTreeSet;
use std::usize;
@ -41,7 +43,8 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
net.initialize_adj_matrix();
for child_node in net.get_node_indices() {
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| {
let mut cache = Cache::new(&self.parameter_learning);
let mut candidate_parent_set: BTreeSet<usize> = net
.get_node_indices()
@ -82,6 +85,9 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
candidate_parent_set = candidate_parent_set_TMP;
separation_set_size += 1;
}
(child_node, candidate_parent_set)
}));
for (child_node, candidate_parent_set) in learned_parent_sets {
for parent_node in candidate_parent_set.iter() {
net.add_edge(*parent_node, child_node);
}

Loading…
Cancel
Save