From 4d3f9518e4137e911d55cc0f723e839e8f391752 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 4 Jan 2023 12:14:36 +0100 Subject: [PATCH] CTPC parallelization at the nodes level with `rayon` --- reCTBN/Cargo.toml | 1 + reCTBN/src/parameter_learning.rs | 2 +- reCTBN/src/process.rs | 2 +- .../src/structure_learning/constraint_based_algorithm.rs | 8 +++++++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml index fdac697..4749b23 100644 --- a/reCTBN/Cargo.toml +++ b/reCTBN/Cargo.toml @@ -14,6 +14,7 @@ enum_dispatch = "~0.3" statrs = "~0.16" rand_chacha = "~0.3" itertools = "~0.10" +rayon = "~1.6" [dev-dependencies] approx = { package = "approx", version = "~0.5" } diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index fcc47c4..ff6a7b9 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -8,7 +8,7 @@ use ndarray::prelude::*; use crate::params::*; use crate::{process, tools::Dataset}; -pub trait ParameterLearning { +pub trait ParameterLearning: Sync { fn fit( &self, net: &T, diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs index dc297bc..45c5e0a 100644 --- a/reCTBN/src/process.rs +++ b/reCTBN/src/process.rs @@ -21,7 +21,7 @@ pub type NetworkProcessState = Vec; /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). -pub trait NetworkProcess { +pub trait NetworkProcess: Sync { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; /// Add an **directed edge** between a two nodes of the network. diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 6d54fe7..634c144 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,6 +1,8 @@ //! Module containing constraint based algorithms like CTPC and Hiton. use itertools::Itertools; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::ParallelExtend; use std::collections::BTreeSet; use std::usize; @@ -41,7 +43,8 @@ impl StructureLearningAlgorithm for CTPC

{ net.initialize_adj_matrix(); - for child_node in net.get_node_indices() { + let mut learned_parent_sets: Vec<(usize, BTreeSet)> = vec![]; + learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| { let mut cache = Cache::new(&self.parameter_learning); let mut candidate_parent_set: BTreeSet = net .get_node_indices() @@ -82,6 +85,9 @@ impl StructureLearningAlgorithm for CTPC

{ candidate_parent_set = candidate_parent_set_TMP; separation_set_size += 1; } + (child_node, candidate_parent_set) + })); + for (child_node, candidate_parent_set) in learned_parent_sets { for parent_node in candidate_parent_set.iter() { net.add_edge(*parent_node, child_node); } -- 2.36.3