pull/42/head
AlessandroBregoli 3 years ago
parent e299acb921
commit 4a7c34af17
  1. 55
      src/structure_learning.rs
  2. 20
      tests/structure_learning.rs

@ -13,7 +13,7 @@ pub trait StructureLearning {
} }
pub trait ScoreFunction { pub trait ScoreFunction {
fn compute_score<T>( fn call<T>(
&self, &self,
net: &T, net: &T,
node: usize, node: usize,
@ -36,16 +36,14 @@ impl LogLikelihood {
} }
LogLikelihood { alpha, tau } LogLikelihood { alpha, tau }
} }
}
impl ScoreFunction for LogLikelihood {
fn compute_score<T>( fn compute_score<T>(
&self, &self,
net: &T, net: &T,
node: usize, node: usize,
parent_set: &BTreeSet<usize>, parent_set: &BTreeSet<usize>,
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> f64 ) -> (f64, Array3<usize>)
where where
T: network::Network, T: network::Network,
{ {
@ -75,8 +73,55 @@ impl ScoreFunction for LogLikelihood {
+ y.iter().map(|z| + y.iter().map(|z|
gamma::ln_gamma(alpha + *z as f64) gamma::ln_gamma(alpha + *z as f64)
- gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum(); - gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum();
log_ll_theta + log_ll_q (log_ll_theta + log_ll_q, M)
}
}
}
}
impl ScoreFunction for LogLikelihood {
fn call<T>(
&self,
net: &T,
node: usize,
parent_set: &BTreeSet<usize>,
dataset: &tools::Dataset,
) -> f64
where
T: network::Network,
{
self.compute_score(net, node, parent_set, dataset).0
}
}
pub struct BIC {
ll: LogLikelihood
}
impl BIC {
pub fn init(alpha: usize, tau: f64) -> BIC {
BIC {
ll: LogLikelihood::init(alpha, tau)
} }
} }
} }
impl ScoreFunction for BIC {
fn call<T>(
&self,
net: &T,
node: usize,
parent_set: &BTreeSet<usize>,
dataset: &tools::Dataset,
) -> f64
where
T: network::Network {
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
let sample_size = M.sum();
ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
}
} }

@ -29,9 +29,27 @@ fn simple_log_likelihood() {
let ll = LogLikelihood::init(1, 1.0); let ll = LogLikelihood::init(1, 1.0);
assert_abs_diff_eq!(0.04257, ll.compute_score(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
}
#[test]
fn simple_bic() {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),2))
.unwrap();
let trj = Trajectory{
time: arr1(&[0.0,0.1,0.3]),
events: arr2(&[[0],[1],[1]])};
let dataset = Dataset{
trajectories: vec![trj]};
let ll = BIC::init(1, 1.0);
assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
} }

Loading…
Cancel
Save