Parallelize score based strucutre learning

pull/87/head
Alessandro Bregoli 2 years ago
parent a104d1fbf9
commit ff235b4b77
  1. 16
      reCTBN/src/structure_learning/score_based_algorithm.rs
  2. 2
      reCTBN/src/structure_learning/score_function.rs

@ -6,6 +6,9 @@ use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm; use crate::structure_learning::StructureLearningAlgorithm;
use crate::{process, tools::Dataset}; use crate::{process, tools::Dataset};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
pub struct HillClimbing<S: ScoreFunction> { pub struct HillClimbing<S: ScoreFunction> {
score_function: S, score_function: S,
max_parent_set: Option<usize>, max_parent_set: Option<usize>,
@ -36,8 +39,9 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
//Reset the adj matrix //Reset the adj matrix
net.initialize_adj_matrix(); net.initialize_adj_matrix();
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
//Iterate over each node to learn their parent set. //Iterate over each node to learn their parent set.
for node in net.get_node_indices() { learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| {
//Initialize an empty parent set. //Initialize an empty parent set.
let mut parent_set: BTreeSet<usize> = BTreeSet::new(); let mut parent_set: BTreeSet<usize> = BTreeSet::new();
//Compute the score for the empty parent set //Compute the score for the empty parent set
@ -76,10 +80,14 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
} }
} }
} }
//Apply the learned parent_set to the network struct. (node, parent_set)
parent_set.iter().for_each(|p| net.add_edge(*p, node)); }));
}
for (child_node, candidate_parent_set) in learned_parent_sets {
for parent_node in candidate_parent_set.iter() {
net.add_edge(*parent_node, child_node);
}
}
return net; return net;
} }
} }

@ -7,7 +7,7 @@ use statrs::function::gamma;
use crate::{parameter_learning, params, process, tools}; use crate::{parameter_learning, params, process, tools};
pub trait ScoreFunction { pub trait ScoreFunction: Sync {
fn call<T>( fn call<T>(
&self, &self,
net: &T, net: &T,

Loading…
Cancel
Save