Hill Climbing + Simple test

pull/42/head
AlessandroBregoli 3 years ago
parent 394970adca
commit a4b0a406f4
  1. 10
      src/structure_learning/mod.rs
  2. 61
      src/structure_learning/score_based_algorithm.rs
  3. 6
      src/structure_learning/score_function.rs
  4. 49
      tests/structure_learning.rs

@ -0,0 +1,10 @@
pub mod score_function;
pub mod score_based_algorithm;
use crate::network;
use crate::tools;
pub trait StructureLearningAlgorithm {
fn call<T, >(&self, net: T, dataset: &tools::Dataset) -> T
where
T: network::Network;
}

@ -0,0 +1,61 @@
use crate::params;
use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::tools;
use crate::{network, parameter_learning};
use ndarray::prelude::*;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use std::collections::BTreeSet;
pub struct HillClimbing<S: ScoreFunction> {
score_function: S,
}
impl<S: ScoreFunction> HillClimbing<S> {
pub fn init(score_function: S) -> HillClimbing<S> {
HillClimbing { score_function }
}
}
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn call<T>(&self, net: T, dataset: &tools::Dataset) -> T
where
T: network::Network,
{
let mut net = net;
net.initialize_adj_matrix();
for node in net.get_node_indices() {
let mut parent_set: BTreeSet<usize> = BTreeSet::new();
let mut current_ll = self.score_function.call(&net, node, &parent_set, dataset);
let mut old_ll = f64::NEG_INFINITY;
while current_ll > old_ll {
old_ll = current_ll;
for parent in net.get_node_indices() {
if parent == node {
continue;
}
let is_removed = parent_set.remove(&parent);
if !is_removed {
parent_set.insert(parent);
}
let tmp_ll = self.score_function.call(&net, node, &parent_set, dataset);
if tmp_ll < current_ll {
if is_removed {
parent_set.insert(parent);
} else {
parent_set.remove(&parent);
}
} else {
current_ll = tmp_ll;
}
}
}
parent_set.iter().for_each(|p| net.add_edge(*p, node));
}
return net;
}
}

@ -6,12 +6,6 @@ use ndarray::prelude::*;
use statrs::function::gamma; use statrs::function::gamma;
use std::collections::BTreeSet; use std::collections::BTreeSet;
pub trait StructureLearning {
fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
where
T: network::Network;
}
pub trait ScoreFunction { pub trait ScoreFunction {
fn call<T>( fn call<T>(
&self, &self,

@ -5,9 +5,12 @@ use utils::*;
use rustyCTBN::ctbn::*; use rustyCTBN::ctbn::*;
use rustyCTBN::network::Network; use rustyCTBN::network::Network;
use rustyCTBN::tools::*; use rustyCTBN::tools::*;
use rustyCTBN::structure_learning::*; use rustyCTBN::structure_learning::score_function::*;
use ndarray::{arr1, arr2}; use rustyCTBN::structure_learning::score_based_algorithm::*;
use rustyCTBN::structure_learning::StructureLearningAlgorithm;
use ndarray::{arr1, arr2, arr3};
use std::collections::BTreeSet; use std::collections::BTreeSet;
use rustyCTBN::params;
#[macro_use] #[macro_use]
@ -53,3 +56,45 @@ fn simple_bic() {
assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
} }
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"),3))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])));
}
}
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
let net = sl.call(net, &data);
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing() {
let bic = BIC::init(1, 1.0);
let hl = HillClimbing::init(bic);
learn_ternary_net_2_nodes(hl);
}

Loading…
Cancel
Save