Implemented log-likelihood (LL) score for CTBN

Ref: pull/42/head
AlessandroBregoli, 3 years ago
Parent: 2d7e52f8f1
Commit: 86d2a0b767
Changed files:
  Cargo.toml (+1)
  src/lib.rs (+1)
  src/structure_learning.rs (+82)
  tests/structure_learning.rs (+37)

Cargo.toml
@@ -12,6 +12,7 @@ thiserror = "*"
rand = "*"
bimap = "*"
enum_dispatch = "*"
statrs = "*"
[dev-dependencies]
approx = "*"

src/lib.rs
@@ -8,4 +8,5 @@ pub mod network;
pub mod ctbn;
pub mod tools;
pub mod parameter_learning;
pub mod structure_learning;

src/structure_learning.rs
@@ -0,0 +1,82 @@
use crate::network;
use crate::parameter_learning;
use crate::params;
use crate::tools;
use ndarray::prelude::*;
use statrs::function::gamma;
use std::collections::BTreeSet;

pub trait StructureLearning {
    fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
    where
        T: network::Network;
}

pub trait ScoreFunction {
    fn compute_score<T>(
        &self,
        net: &T,
        node: usize,
        parent_set: &BTreeSet<usize>,
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: network::Network;
}
pub struct LogLikelihood {
    // Prior hyperparameters: alpha acts as imaginary counts, tau as imaginary
    // elapsed time; both are rescaled per parent configuration in compute_score.
    alpha: usize,
    tau: f64,
}

impl LogLikelihood {
    pub fn init(alpha: usize, tau: f64) -> LogLikelihood {
        if tau < 0.0 {
            panic!("tau must be >= 0.0");
        }
        LogLikelihood { alpha, tau }
    }
}
impl ScoreFunction for LogLikelihood {
    fn compute_score<T>(
        &self,
        net: &T,
        node: usize,
        parent_set: &BTreeSet<usize>,
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: network::Network,
    {
        match &net.get_node(node).params {
            params::Params::DiscreteStatesContinousTime(params) => {
                // M: transition counts, T: sojourn times for `node` given `parent_set`.
                let (M, T) =
                    parameter_learning::sufficient_statistics(net, dataset, node, parent_set);
                // Rescale the hyperparameters over the parent configurations.
                let alpha = self.alpha as f64 / M.shape()[0] as f64;
                let tau = self.tau / M.shape()[0] as f64;

                // Marginal log-likelihood of the exit rates q.
                let log_ll_q: f64 = M
                    .sum_axis(Axis(2))
                    .iter()
                    .zip(T.iter())
                    .map(|(m, t)| {
                        gamma::ln_gamma(alpha + *m as f64 + 1.0)
                            + (alpha + 1.0) * f64::ln(tau)
                            - gamma::ln_gamma(alpha + 1.0)
                            - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
                    })
                    .sum();

                // Marginal log-likelihood of the transition probabilities theta.
                let log_ll_theta: f64 = M
                    .outer_iter()
                    .map(|x| {
                        x.outer_iter()
                            .map(|y| {
                                gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64)
                                    + y.iter()
                                        .map(|z| {
                                            gamma::ln_gamma(alpha + *z as f64)
                                                - gamma::ln_gamma(alpha)
                                        })
                                        .sum::<f64>()
                            })
                            .sum::<f64>()
                    })
                    .sum();

                log_ll_theta + log_ll_q
            }
        }
    }
}
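For reference, a sketch of the quantity compute_score evaluates, read off the code above: u ranges over parent-set configurations, x and x' over the node's states, and alpha and tau denote the rescaled hyperparameters self.alpha / M.shape()[0] and self.tau / M.shape()[0]. The indexing of M as [parent configuration, from state, to state] and of T as [parent configuration, state] is an assumption about what sufficient_statistics returns.

\log LL_q = \sum_{u}\sum_{x}\Big[\ln\Gamma(\alpha + m_{ux} + 1) + (\alpha + 1)\ln\tau - \ln\Gamma(\alpha + 1) - (\alpha + m_{ux} + 1)\ln(\tau + T[u,x])\Big]

\log LL_\theta = \sum_{u}\sum_{x}\Big[\ln\Gamma(\alpha) - \ln\Gamma(\alpha + m_{ux}) + \sum_{x'}\big(\ln\Gamma(\alpha + M[u,x,x']) - \ln\Gamma(\alpha)\big)\Big]

where m_{ux} = \sum_{x'} M[u,x,x'], and the returned score is \log LL_\theta + \log LL_q.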

tests/structure_learning.rs
@@ -0,0 +1,37 @@
mod utils;
use utils::*;

use rustyCTBN::ctbn::*;
use rustyCTBN::network::Network;
use rustyCTBN::tools::*;
use rustyCTBN::structure_learning::*;

use ndarray::{arr1, arr2};
use std::collections::BTreeSet;

#[macro_use]
extern crate approx;

#[test]
fn simple_log_likelihood() {
    let mut net = CtbnNetwork::init();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();

    let trj = Trajectory {
        time: arr1(&[0.0, 0.1, 0.3]),
        events: arr2(&[[0], [1], [1]]),
    };
    let dataset = Dataset {
        trajectories: vec![trj],
    };

    let ll = LogLikelihood::init(1, 1.0);

    assert_abs_diff_eq!(
        0.04257,
        ll.compute_score(&net, n1, &BTreeSet::new(), &dataset),
        epsilon = 1e-3
    );
}
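As a hypothetical follow-up (not part of this commit), the same score can be evaluated for a non-empty parent set; the sketch below assumes a second node n2 has been added to net with add_node and that the dataset contains a column of events for it:

    // Hypothetical usage sketch: score n1 with n2 as a candidate parent.
    // Assumes `n2` was obtained from net.add_node(...) and that `dataset`
    // covers both variables.
    let mut parents = BTreeSet::new();
    parents.insert(n2);
    let score_with_parent = ll.compute_score(&net, n1, &parents, &dataset);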