parent
394970adca
commit
a4b0a406f4
@ -0,0 +1,10 @@ |
|||||||
|
pub mod score_function; |
||||||
|
pub mod score_based_algorithm; |
||||||
|
use crate::network; |
||||||
|
use crate::tools; |
||||||
|
|
||||||
|
pub trait StructureLearningAlgorithm { |
||||||
|
fn call<T, >(&self, net: T, dataset: &tools::Dataset) -> T |
||||||
|
where |
||||||
|
T: network::Network; |
||||||
|
} |
@ -0,0 +1,61 @@ |
|||||||
|
use crate::params; |
||||||
|
use crate::structure_learning::score_function::ScoreFunction; |
||||||
|
use crate::structure_learning::StructureLearningAlgorithm; |
||||||
|
use crate::tools; |
||||||
|
use crate::{network, parameter_learning}; |
||||||
|
use ndarray::prelude::*; |
||||||
|
use rand::prelude::*; |
||||||
|
use rand_chacha::ChaCha8Rng; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
pub struct HillClimbing<S: ScoreFunction> { |
||||||
|
score_function: S, |
||||||
|
} |
||||||
|
|
||||||
|
impl<S: ScoreFunction> HillClimbing<S> { |
||||||
|
pub fn init(score_function: S) -> HillClimbing<S> { |
||||||
|
HillClimbing { score_function } |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { |
||||||
|
fn call<T>(&self, net: T, dataset: &tools::Dataset) -> T |
||||||
|
where |
||||||
|
T: network::Network, |
||||||
|
{ |
||||||
|
let mut net = net; |
||||||
|
net.initialize_adj_matrix(); |
||||||
|
for node in net.get_node_indices() { |
||||||
|
let mut parent_set: BTreeSet<usize> = BTreeSet::new(); |
||||||
|
let mut current_ll = self.score_function.call(&net, node, &parent_set, dataset); |
||||||
|
let mut old_ll = f64::NEG_INFINITY; |
||||||
|
while current_ll > old_ll { |
||||||
|
old_ll = current_ll; |
||||||
|
for parent in net.get_node_indices() { |
||||||
|
if parent == node { |
||||||
|
continue; |
||||||
|
} |
||||||
|
let is_removed = parent_set.remove(&parent); |
||||||
|
if !is_removed { |
||||||
|
parent_set.insert(parent); |
||||||
|
} |
||||||
|
|
||||||
|
let tmp_ll = self.score_function.call(&net, node, &parent_set, dataset); |
||||||
|
|
||||||
|
if tmp_ll < current_ll { |
||||||
|
if is_removed { |
||||||
|
parent_set.insert(parent); |
||||||
|
} else { |
||||||
|
parent_set.remove(&parent); |
||||||
|
} |
||||||
|
} else { |
||||||
|
current_ll = tmp_ll; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
parent_set.iter().for_each(|p| net.add_edge(*p, node)); |
||||||
|
} |
||||||
|
|
||||||
|
return net; |
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue