Lmit parent set

pull/42/head
AlessandroBregoli 3 years ago
parent 8ca93c931b
commit df12b93d55
  1. 8
      src/structure_learning/score_based_algorithm.rs
  2. 8
      tests/structure_learning.rs

@ -6,11 +6,12 @@ use std::collections::BTreeSet;
pub struct HillClimbing<S: ScoreFunction> { pub struct HillClimbing<S: ScoreFunction> {
score_function: S, score_function: S,
max_parent_set: Option<usize>
} }
impl<S: ScoreFunction> HillClimbing<S> { impl<S: ScoreFunction> HillClimbing<S> {
pub fn init(score_function: S) -> HillClimbing<S> { pub fn init(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
HillClimbing { score_function } HillClimbing { score_function, max_parent_set }
} }
} }
@ -20,6 +21,7 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
T: network::Network, T: network::Network,
{ {
let mut net = net; let mut net = net;
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
net.initialize_adj_matrix(); net.initialize_adj_matrix();
for node in net.get_node_indices() { for node in net.get_node_indices() {
let mut parent_set: BTreeSet<usize> = BTreeSet::new(); let mut parent_set: BTreeSet<usize> = BTreeSet::new();
@ -32,7 +34,7 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
continue; continue;
} }
let is_removed = parent_set.remove(&parent); let is_removed = parent_set.remove(&parent);
if !is_removed { if !is_removed && parent_set.len() < max_parent_set {
parent_set.insert(parent); parent_set.insert(parent);
} }

@ -96,14 +96,14 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
#[test] #[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
let ll = LogLikelihood::init(1, 1.0); let ll = LogLikelihood::init(1, 1.0);
let hl = HillClimbing::init(ll); let hl = HillClimbing::init(ll, None);
learn_ternary_net_2_nodes(hl); learn_ternary_net_2_nodes(hl);
} }
#[test] #[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
let bic = BIC::init(1, 1.0); let bic = BIC::init(1, 1.0);
let hl = HillClimbing::init(bic); let hl = HillClimbing::init(bic, None);
learn_ternary_net_2_nodes(hl); learn_ternary_net_2_nodes(hl);
} }
@ -175,13 +175,13 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) {
#[test] #[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
let ll = LogLikelihood::init(1, 1.0); let ll = LogLikelihood::init(1, 1.0);
let hl = HillClimbing::init(ll); let hl = HillClimbing::init(ll, None);
learn_mixed_discrete_net_3_nodes(hl); learn_mixed_discrete_net_3_nodes(hl);
} }
#[test] #[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
let bic = BIC::init(1, 1.0); let bic = BIC::init(1, 1.0);
let hl = HillClimbing::init(bic); let hl = HillClimbing::init(bic, None);
learn_mixed_discrete_net_3_nodes(hl); learn_mixed_discrete_net_3_nodes(hl);
} }

Loading…
Cancel
Save