|
|
@ -6,6 +6,9 @@ use crate::structure_learning::score_function::ScoreFunction; |
|
|
|
use crate::structure_learning::StructureLearningAlgorithm; |
|
|
|
use crate::structure_learning::StructureLearningAlgorithm; |
|
|
|
use crate::{process, tools::Dataset}; |
|
|
|
use crate::{process, tools::Dataset}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use rayon::iter::{IntoParallelIterator, ParallelIterator}; |
|
|
|
|
|
|
|
use rayon::prelude::ParallelExtend; |
|
|
|
|
|
|
|
|
|
|
|
pub struct HillClimbing<S: ScoreFunction> { |
|
|
|
pub struct HillClimbing<S: ScoreFunction> { |
|
|
|
score_function: S, |
|
|
|
score_function: S, |
|
|
|
max_parent_set: Option<usize>, |
|
|
|
max_parent_set: Option<usize>, |
|
|
@ -36,8 +39,9 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { |
|
|
|
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); |
|
|
|
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); |
|
|
|
//Reset the adj matrix
|
|
|
|
//Reset the adj matrix
|
|
|
|
net.initialize_adj_matrix(); |
|
|
|
net.initialize_adj_matrix(); |
|
|
|
|
|
|
|
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![]; |
|
|
|
//Iterate over each node to learn their parent set.
|
|
|
|
//Iterate over each node to learn their parent set.
|
|
|
|
for node in net.get_node_indices() { |
|
|
|
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| { |
|
|
|
//Initialize an empty parent set.
|
|
|
|
//Initialize an empty parent set.
|
|
|
|
let mut parent_set: BTreeSet<usize> = BTreeSet::new(); |
|
|
|
let mut parent_set: BTreeSet<usize> = BTreeSet::new(); |
|
|
|
//Compute the score for the empty parent set
|
|
|
|
//Compute the score for the empty parent set
|
|
|
@ -76,10 +80,14 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
//Apply the learned parent_set to the network struct.
|
|
|
|
(node, parent_set) |
|
|
|
parent_set.iter().for_each(|p| net.add_edge(*p, node)); |
|
|
|
})); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (child_node, candidate_parent_set) in learned_parent_sets { |
|
|
|
|
|
|
|
for parent_node in candidate_parent_set.iter() { |
|
|
|
|
|
|
|
net.add_edge(*parent_node, child_node); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return net; |
|
|
|
return net; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|