parent
2d7e52f8f1
commit
86d2a0b767
@ -0,0 +1,82 @@ |
|||||||
|
use crate::network; |
||||||
|
use crate::parameter_learning; |
||||||
|
use crate::params; |
||||||
|
use crate::tools; |
||||||
|
use ndarray::prelude::*; |
||||||
|
use statrs::function::gamma; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
pub trait StructureLearning { |
||||||
|
fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T |
||||||
|
where |
||||||
|
T: network::Network; |
||||||
|
} |
||||||
|
|
||||||
|
pub trait ScoreFunction { |
||||||
|
fn compute_score<T>( |
||||||
|
&self, |
||||||
|
net: &T, |
||||||
|
node: usize, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
dataset: &tools::Dataset, |
||||||
|
) -> f64 |
||||||
|
where |
||||||
|
T: network::Network; |
||||||
|
} |
||||||
|
|
||||||
|
pub struct LogLikelihood { |
||||||
|
alpha: usize, |
||||||
|
tau: f64, |
||||||
|
} |
||||||
|
|
||||||
|
impl LogLikelihood { |
||||||
|
pub fn init(alpha: usize, tau: f64) -> LogLikelihood { |
||||||
|
if tau < 0.0 { |
||||||
|
panic!("tau must be >=0.0"); |
||||||
|
} |
||||||
|
LogLikelihood { alpha, tau } |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl ScoreFunction for LogLikelihood { |
||||||
|
fn compute_score<T>( |
||||||
|
&self, |
||||||
|
net: &T, |
||||||
|
node: usize, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
dataset: &tools::Dataset, |
||||||
|
) -> f64 |
||||||
|
where |
||||||
|
T: network::Network, |
||||||
|
{ |
||||||
|
match &net.get_node(node).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(params) => { |
||||||
|
let (M, T) = |
||||||
|
parameter_learning::sufficient_statistics(net, dataset, node, parent_set); |
||||||
|
let alpha = self.alpha as f64 / M.shape()[0] as f64; |
||||||
|
let tau = self.tau / M.shape()[0] as f64; |
||||||
|
|
||||||
|
let log_ll_q:f64 = M |
||||||
|
.sum_axis(Axis(2)) |
||||||
|
.iter() |
||||||
|
.zip(T.iter()) |
||||||
|
.map(|(m, t)| { |
||||||
|
gamma::ln_gamma(alpha + *m as f64 + 1.0) |
||||||
|
+ (alpha + 1.0) * f64::ln(tau) |
||||||
|
- gamma::ln_gamma(alpha + 1.0) |
||||||
|
- (alpha + *m as f64 + 1.0) * f64::ln(tau + t) |
||||||
|
}) |
||||||
|
.sum(); |
||||||
|
|
||||||
|
let log_ll_theta: f64 = M.outer_iter() |
||||||
|
.map(|x| x.outer_iter() |
||||||
|
.map(|y| gamma::ln_gamma(alpha)
|
||||||
|
- gamma::ln_gamma(alpha + y.sum() as f64) |
||||||
|
+ y.iter().map(|z|
|
||||||
|
gamma::ln_gamma(alpha + *z as f64)
|
||||||
|
- gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum(); |
||||||
|
log_ll_theta + log_ll_q |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,37 @@ |
|||||||
|
|
||||||
|
mod utils; |
||||||
|
use utils::*; |
||||||
|
|
||||||
|
use rustyCTBN::ctbn::*; |
||||||
|
use rustyCTBN::network::Network; |
||||||
|
use rustyCTBN::tools::*; |
||||||
|
use rustyCTBN::structure_learning::*; |
||||||
|
use ndarray::{arr1, arr2}; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
|
||||||
|
#[macro_use] |
||||||
|
extern crate approx; |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_log_likelihood() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let trj = Trajectory{ |
||||||
|
time: arr1(&[0.0,0.1,0.3]), |
||||||
|
events: arr2(&[[0],[1],[1]])}; |
||||||
|
|
||||||
|
let dataset = Dataset{ |
||||||
|
trajectories: vec![trj]}; |
||||||
|
|
||||||
|
let ll = LogLikelihood::init(1, 1.0); |
||||||
|
|
||||||
|
assert_abs_diff_eq!(0.04257, ll.compute_score(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
} |
Loading…
Reference in new issue