From 86d2a0b7672a4c0aadbfbdae7c371bc611d5667b Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 6 Apr 2022 07:38:14 +0200 Subject: [PATCH] Implemented LL for ctbn --- Cargo.toml | 1 + src/lib.rs | 1 + src/structure_learning.rs | 82 +++++++++++++++++++++++++++++++++++++ tests/structure_learning.rs | 37 +++++++++++++++++ 4 files changed, 121 insertions(+) create mode 100644 src/structure_learning.rs create mode 100644 tests/structure_learning.rs diff --git a/Cargo.toml b/Cargo.toml index 3aa7c53..37f87e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ thiserror = "*" rand = "*" bimap = "*" enum_dispatch = "*" +statrs = "*" [dev-dependencies] approx = "*" diff --git a/src/lib.rs b/src/lib.rs index 65e4b11..ec12261 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,4 +8,5 @@ pub mod network; pub mod ctbn; pub mod tools; pub mod parameter_learning; +pub mod structure_learning; diff --git a/src/structure_learning.rs b/src/structure_learning.rs new file mode 100644 index 0000000..f4c369b --- /dev/null +++ b/src/structure_learning.rs @@ -0,0 +1,82 @@ +use crate::network; +use crate::parameter_learning; +use crate::params; +use crate::tools; +use ndarray::prelude::*; +use statrs::function::gamma; +use std::collections::BTreeSet; + +pub trait StructureLearning { + fn fit(&self, net: T, dataset: &tools::Dataset) -> T + where + T: network::Network; +} + +pub trait ScoreFunction { + fn compute_score( + &self, + net: &T, + node: usize, + parent_set: &BTreeSet, + dataset: &tools::Dataset, + ) -> f64 + where + T: network::Network; +} + +pub struct LogLikelihood { + alpha: usize, + tau: f64, +} + +impl LogLikelihood { + pub fn init(alpha: usize, tau: f64) -> LogLikelihood { + if tau < 0.0 { + panic!("tau must be >=0.0"); + } + LogLikelihood { alpha, tau } + } +} + +impl ScoreFunction for LogLikelihood { + fn compute_score( + &self, + net: &T, + node: usize, + parent_set: &BTreeSet, + dataset: &tools::Dataset, + ) -> f64 + where + T: network::Network, + { + match &net.get_node(node).params { + params::Params::DiscreteStatesContinousTime(params) => { + let (M, T) = + parameter_learning::sufficient_statistics(net, dataset, node, parent_set); + let alpha = self.alpha as f64 / M.shape()[0] as f64; + let tau = self.tau / M.shape()[0] as f64; + + let log_ll_q:f64 = M + .sum_axis(Axis(2)) + .iter() + .zip(T.iter()) + .map(|(m, t)| { + gamma::ln_gamma(alpha + *m as f64 + 1.0) + + (alpha + 1.0) * f64::ln(tau) + - gamma::ln_gamma(alpha + 1.0) + - (alpha + *m as f64 + 1.0) * f64::ln(tau + t) + }) + .sum(); + + let log_ll_theta: f64 = M.outer_iter() + .map(|x| x.outer_iter() + .map(|y| gamma::ln_gamma(alpha) + - gamma::ln_gamma(alpha + y.sum() as f64) + + y.iter().map(|z| + gamma::ln_gamma(alpha + *z as f64) + - gamma::ln_gamma(alpha)).sum::()).sum::()).sum(); + log_ll_theta + log_ll_q + } + } + } +} diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs new file mode 100644 index 0000000..95b95fa --- /dev/null +++ b/tests/structure_learning.rs @@ -0,0 +1,37 @@ + +mod utils; +use utils::*; + +use rustyCTBN::ctbn::*; +use rustyCTBN::network::Network; +use rustyCTBN::tools::*; +use rustyCTBN::structure_learning::*; +use ndarray::{arr1, arr2}; +use std::collections::BTreeSet; + + +#[macro_use] +extern crate approx; + +#[test] +fn simple_log_likelihood() { + let mut net = CtbnNetwork::init(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) + .unwrap(); + + let trj = Trajectory{ + time: arr1(&[0.0,0.1,0.3]), + events: arr2(&[[0],[1],[1]])}; + + let dataset = Dataset{ + trajectories: vec![trj]}; + + let ll = LogLikelihood::init(1, 1.0); + + assert_abs_diff_eq!(0.04257, ll.compute_score(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); + + + + +}