parent
2d7e52f8f1
commit
86d2a0b767
@ -0,0 +1,82 @@ |
||||
use crate::network; |
||||
use crate::parameter_learning; |
||||
use crate::params; |
||||
use crate::tools; |
||||
use ndarray::prelude::*; |
||||
use statrs::function::gamma; |
||||
use std::collections::BTreeSet; |
||||
|
||||
pub trait StructureLearning { |
||||
fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T |
||||
where |
||||
T: network::Network; |
||||
} |
||||
|
||||
pub trait ScoreFunction { |
||||
fn compute_score<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> f64 |
||||
where |
||||
T: network::Network; |
||||
} |
||||
|
||||
pub struct LogLikelihood { |
||||
alpha: usize, |
||||
tau: f64, |
||||
} |
||||
|
||||
impl LogLikelihood { |
||||
pub fn init(alpha: usize, tau: f64) -> LogLikelihood { |
||||
if tau < 0.0 { |
||||
panic!("tau must be >=0.0"); |
||||
} |
||||
LogLikelihood { alpha, tau } |
||||
} |
||||
} |
||||
|
||||
impl ScoreFunction for LogLikelihood { |
||||
fn compute_score<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> f64 |
||||
where |
||||
T: network::Network, |
||||
{ |
||||
match &net.get_node(node).params { |
||||
params::Params::DiscreteStatesContinousTime(params) => { |
||||
let (M, T) = |
||||
parameter_learning::sufficient_statistics(net, dataset, node, parent_set); |
||||
let alpha = self.alpha as f64 / M.shape()[0] as f64; |
||||
let tau = self.tau / M.shape()[0] as f64; |
||||
|
||||
let log_ll_q:f64 = M |
||||
.sum_axis(Axis(2)) |
||||
.iter() |
||||
.zip(T.iter()) |
||||
.map(|(m, t)| { |
||||
gamma::ln_gamma(alpha + *m as f64 + 1.0) |
||||
+ (alpha + 1.0) * f64::ln(tau) |
||||
- gamma::ln_gamma(alpha + 1.0) |
||||
- (alpha + *m as f64 + 1.0) * f64::ln(tau + t) |
||||
}) |
||||
.sum(); |
||||
|
||||
let log_ll_theta: f64 = M.outer_iter() |
||||
.map(|x| x.outer_iter() |
||||
.map(|y| gamma::ln_gamma(alpha)
|
||||
- gamma::ln_gamma(alpha + y.sum() as f64) |
||||
+ y.iter().map(|z|
|
||||
gamma::ln_gamma(alpha + *z as f64)
|
||||
- gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum(); |
||||
log_ll_theta + log_ll_q |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,37 @@ |
||||
|
||||
mod utils; |
||||
use utils::*; |
||||
|
||||
use rustyCTBN::ctbn::*; |
||||
use rustyCTBN::network::Network; |
||||
use rustyCTBN::tools::*; |
||||
use rustyCTBN::structure_learning::*; |
||||
use ndarray::{arr1, arr2}; |
||||
use std::collections::BTreeSet; |
||||
|
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
#[test] |
||||
fn simple_log_likelihood() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||
.unwrap(); |
||||
|
||||
let trj = Trajectory{ |
||||
time: arr1(&[0.0,0.1,0.3]), |
||||
events: arr2(&[[0],[1],[1]])}; |
||||
|
||||
let dataset = Dataset{ |
||||
trajectories: vec![trj]}; |
||||
|
||||
let ll = LogLikelihood::init(1, 1.0); |
||||
|
||||
assert_abs_diff_eq!(0.04257, ll.compute_score(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); |
||||
|
||||
|
||||
|
||||
|
||||
} |
Loading…
Reference in new issue