From 394970adca780712cd9ab599c292110381318db5 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 13 Apr 2022 14:05:36 +0200 Subject: [PATCH] Test for BIC --- src/structure_learning.rs | 3 ++- tests/structure_learning.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/structure_learning.rs b/src/structure_learning.rs index 4de13ae..ba76b7a 100644 --- a/src/structure_learning.rs +++ b/src/structure_learning.rs @@ -121,7 +121,8 @@ impl ScoreFunction for BIC { T: network::Network { let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); - let sample_size = M.sum(); + //TODO: Optimize this + let sample_size: usize = dataset.trajectories.iter().map(|x| x.time.len() -1).sum(); ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 } } diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index f3633b5..a9feea9 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -50,6 +50,6 @@ fn simple_bic() { let ll = BIC::init(1, 1.0); - assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); + assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); }