From df12b93d559b1a74ab7f056a33e73023f1aa8e93 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 14 Apr 2022 10:56:27 +0200 Subject: [PATCH] Lmit parent set --- src/structure_learning/score_based_algorithm.rs | 8 +++++--- tests/structure_learning.rs | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs index 0a23c36..b5590ed 100644 --- a/src/structure_learning/score_based_algorithm.rs +++ b/src/structure_learning/score_based_algorithm.rs @@ -6,11 +6,12 @@ use std::collections::BTreeSet; pub struct HillClimbing { score_function: S, + max_parent_set: Option } impl HillClimbing { - pub fn init(score_function: S) -> HillClimbing { - HillClimbing { score_function } + pub fn init(score_function: S, max_parent_set: Option) -> HillClimbing { + HillClimbing { score_function, max_parent_set } } } @@ -20,6 +21,7 @@ impl StructureLearningAlgorithm for HillClimbing { T: network::Network, { let mut net = net; + let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); net.initialize_adj_matrix(); for node in net.get_node_indices() { let mut parent_set: BTreeSet = BTreeSet::new(); @@ -32,7 +34,7 @@ impl StructureLearningAlgorithm for HillClimbing { continue; } let is_removed = parent_set.remove(&parent); - if !is_removed { + if !is_removed && parent_set.len() < max_parent_set { parent_set.insert(parent); } diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 25ce1e8..f9c0034 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -96,14 +96,14 @@ fn learn_ternary_net_2_nodes (sl: T) { #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { let ll = LogLikelihood::init(1, 1.0); - let hl = HillClimbing::init(ll); + let hl = HillClimbing::init(ll, None); learn_ternary_net_2_nodes(hl); } #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { let bic = BIC::init(1, 1.0); - let hl = HillClimbing::init(bic); + let hl = HillClimbing::init(bic, None); learn_ternary_net_2_nodes(hl); } @@ -175,13 +175,13 @@ fn learn_mixed_discrete_net_3_nodes (sl: T) { #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { let ll = LogLikelihood::init(1, 1.0); - let hl = HillClimbing::init(ll); + let hl = HillClimbing::init(ll, None); learn_mixed_discrete_net_3_nodes(hl); } #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { let bic = BIC::init(1, 1.0); - let hl = HillClimbing::init(bic); + let hl = HillClimbing::init(bic, None); learn_mixed_discrete_net_3_nodes(hl); }