From a4b0a406f4d65f83d7a638f990f84a229a68fa54 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 13 Apr 2022 19:42:24 +0200 Subject: [PATCH] Hill Climbing + Simple test --- src/structure_learning/mod.rs | 10 +++ .../score_based_algorithm.rs | 61 +++++++++++++++++++ .../score_function.rs} | 6 -- tests/structure_learning.rs | 49 ++++++++++++++- 4 files changed, 118 insertions(+), 8 deletions(-) create mode 100644 src/structure_learning/mod.rs create mode 100644 src/structure_learning/score_based_algorithm.rs rename src/{structure_learning.rs => structure_learning/score_function.rs} (96%) diff --git a/src/structure_learning/mod.rs b/src/structure_learning/mod.rs new file mode 100644 index 0000000..d72862d --- /dev/null +++ b/src/structure_learning/mod.rs @@ -0,0 +1,10 @@ +pub mod score_function; +pub mod score_based_algorithm; +use crate::network; +use crate::tools; + +pub trait StructureLearningAlgorithm { + fn call(&self, net: T, dataset: &tools::Dataset) -> T + where + T: network::Network; +} diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs new file mode 100644 index 0000000..ed54092 --- /dev/null +++ b/src/structure_learning/score_based_algorithm.rs @@ -0,0 +1,61 @@ +use crate::params; +use crate::structure_learning::score_function::ScoreFunction; +use crate::structure_learning::StructureLearningAlgorithm; +use crate::tools; +use crate::{network, parameter_learning}; +use ndarray::prelude::*; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use std::collections::BTreeSet; + +pub struct HillClimbing { + score_function: S, +} + +impl HillClimbing { + pub fn init(score_function: S) -> HillClimbing { + HillClimbing { score_function } + } +} + +impl StructureLearningAlgorithm for HillClimbing { + fn call(&self, net: T, dataset: &tools::Dataset) -> T + where + T: network::Network, + { + let mut net = net; + net.initialize_adj_matrix(); + for node in net.get_node_indices() { + let mut parent_set: BTreeSet = BTreeSet::new(); + let mut current_ll = self.score_function.call(&net, node, &parent_set, dataset); + let mut old_ll = f64::NEG_INFINITY; + while current_ll > old_ll { + old_ll = current_ll; + for parent in net.get_node_indices() { + if parent == node { + continue; + } + let is_removed = parent_set.remove(&parent); + if !is_removed { + parent_set.insert(parent); + } + + let tmp_ll = self.score_function.call(&net, node, &parent_set, dataset); + + if tmp_ll < current_ll { + if is_removed { + parent_set.insert(parent); + } else { + parent_set.remove(&parent); + } + } else { + current_ll = tmp_ll; + } + } + } + parent_set.iter().for_each(|p| net.add_edge(*p, node)); + } + + return net; + } +} diff --git a/src/structure_learning.rs b/src/structure_learning/score_function.rs similarity index 96% rename from src/structure_learning.rs rename to src/structure_learning/score_function.rs index ba76b7a..06f9fb9 100644 --- a/src/structure_learning.rs +++ b/src/structure_learning/score_function.rs @@ -6,12 +6,6 @@ use ndarray::prelude::*; use statrs::function::gamma; use std::collections::BTreeSet; -pub trait StructureLearning { - fn fit(&self, net: T, dataset: &tools::Dataset) -> T - where - T: network::Network; -} - pub trait ScoreFunction { fn call( &self, diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index a9feea9..e3a43e4 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -5,9 +5,12 @@ use utils::*; use rustyCTBN::ctbn::*; use rustyCTBN::network::Network; use rustyCTBN::tools::*; -use rustyCTBN::structure_learning::*; -use ndarray::{arr1, arr2}; +use rustyCTBN::structure_learning::score_function::*; +use rustyCTBN::structure_learning::score_based_algorithm::*; +use rustyCTBN::structure_learning::StructureLearningAlgorithm; +use ndarray::{arr1, arr2, arr3}; use std::collections::BTreeSet; +use rustyCTBN::params; #[macro_use] @@ -53,3 +56,45 @@ fn simple_bic() { assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); } + +fn learn_ternary_net_2_nodes (sl: T) { + let mut net = CtbnNetwork::init(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1).params { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2).params { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ]))); + } + } + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); + + let net = sl.call(net, &data); + assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); + assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); +} + +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing() { + let bic = BIC::init(1, 1.0); + let hl = HillClimbing::init(bic); + learn_ternary_net_2_nodes(hl); +}