From 4a7c34af1793b33913eec7f92a52930927e83a83 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 13 Apr 2022 11:49:04 +0200 Subject: [PATCH] BIC --- src/structure_learning.rs | 55 +++++++++++++++++++++++++++++++++---- tests/structure_learning.rs | 22 +++++++++++++-- 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/src/structure_learning.rs b/src/structure_learning.rs index f4c369b..4de13ae 100644 --- a/src/structure_learning.rs +++ b/src/structure_learning.rs @@ -13,7 +13,7 @@ pub trait StructureLearning { } pub trait ScoreFunction { - fn compute_score( + fn call( &self, net: &T, node: usize, @@ -36,16 +36,14 @@ impl LogLikelihood { } LogLikelihood { alpha, tau } } -} -impl ScoreFunction for LogLikelihood { fn compute_score( &self, net: &T, node: usize, parent_set: &BTreeSet, dataset: &tools::Dataset, - ) -> f64 + ) -> (f64, Array3) where T: network::Network, { @@ -75,8 +73,55 @@ impl ScoreFunction for LogLikelihood { + y.iter().map(|z| gamma::ln_gamma(alpha + *z as f64) - gamma::ln_gamma(alpha)).sum::()).sum::()).sum(); - log_ll_theta + log_ll_q + (log_ll_theta + log_ll_q, M) } } } + + + +} + +impl ScoreFunction for LogLikelihood { + fn call( + &self, + net: &T, + node: usize, + parent_set: &BTreeSet, + dataset: &tools::Dataset, + ) -> f64 + where + T: network::Network, + { + self.compute_score(net, node, parent_set, dataset).0 + } +} + +pub struct BIC { + ll: LogLikelihood +} + +impl BIC { + pub fn init(alpha: usize, tau: f64) -> BIC { + BIC { + ll: LogLikelihood::init(alpha, tau) + } + } +} + +impl ScoreFunction for BIC { + fn call( + &self, + net: &T, + node: usize, + parent_set: &BTreeSet, + dataset: &tools::Dataset, + ) -> f64 + where + T: network::Network { + let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); + let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); + let sample_size = M.sum(); + ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 + } } diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 95b95fa..f3633b5 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -29,9 +29,27 @@ fn simple_log_likelihood() { let ll = LogLikelihood::init(1, 1.0); - assert_abs_diff_eq!(0.04257, ll.compute_score(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); + assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); - +} + + +#[test] +fn simple_bic() { + let mut net = CtbnNetwork::init(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) + .unwrap(); + + let trj = Trajectory{ + time: arr1(&[0.0,0.1,0.3]), + events: arr2(&[[0],[1],[1]])}; + + let dataset = Dataset{ + trajectories: vec![trj]}; + + let ll = BIC::init(1, 1.0); + assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); }