@ -121,7 +121,8 @@ impl ScoreFunction for BIC {
T: network::Network {
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
let sample_size = M.sum();
//TODO: Optimize this
let sample_size: usize = dataset.trajectories.iter().map(|x| x.time.len() -1).sum();
ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
}
@ -50,6 +50,6 @@ fn simple_bic() {
let ll = BIC::init(1, 1.0);
assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);