Lmit parent set

pull/42/head
AlessandroBregoli 3 years ago
parent 8ca93c931b
commit df12b93d55
  1. 8
      src/structure_learning/score_based_algorithm.rs
  2. 8
      tests/structure_learning.rs

@ -6,11 +6,12 @@ use std::collections::BTreeSet;
pub struct HillClimbing<S: ScoreFunction> {
score_function: S,
max_parent_set: Option<usize>
}
impl<S: ScoreFunction> HillClimbing<S> {
pub fn init(score_function: S) -> HillClimbing<S> {
HillClimbing { score_function }
pub fn init(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
HillClimbing { score_function, max_parent_set }
}
}
@ -20,6 +21,7 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
T: network::Network,
{
let mut net = net;
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
net.initialize_adj_matrix();
for node in net.get_node_indices() {
let mut parent_set: BTreeSet<usize> = BTreeSet::new();
@ -32,7 +34,7 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
continue;
}
let is_removed = parent_set.remove(&parent);
if !is_removed {
if !is_removed && parent_set.len() < max_parent_set {
parent_set.insert(parent);
}

@ -96,14 +96,14 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
let ll = LogLikelihood::init(1, 1.0);
let hl = HillClimbing::init(ll);
let hl = HillClimbing::init(ll, None);
learn_ternary_net_2_nodes(hl);
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
let bic = BIC::init(1, 1.0);
let hl = HillClimbing::init(bic);
let hl = HillClimbing::init(bic, None);
learn_ternary_net_2_nodes(hl);
}
@ -175,13 +175,13 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) {
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
let ll = LogLikelihood::init(1, 1.0);
let hl = HillClimbing::init(ll);
let hl = HillClimbing::init(ll, None);
learn_mixed_discrete_net_3_nodes(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
let bic = BIC::init(1, 1.0);
let hl = HillClimbing::init(bic);
let hl = HillClimbing::init(bic, None);
learn_mixed_discrete_net_3_nodes(hl);
}

Loading…
Cancel
Save