Refactor of Dataset and Trajectory to ensure some basic properties.

pull/42/head
AlessandroBregoli 3 years ago
parent df12b93d55
commit f49523f35a
  1. 12
      src/parameter_learning.rs
  2. 2
      src/structure_learning/score_function.rs
  3. 55
      src/tools.rs
  4. 25
      tests/structure_learning.rs
  5. 4
      tests/tools.rs

@ -57,12 +57,12 @@ pub fn sufficient_statistics<T:network::Network>(
let mut T: Array2<f64> = Array::zeros((parentset_domain.iter().product(), node_domain)); let mut T: Array2<f64> = Array::zeros((parentset_domain.iter().product(), node_domain));
//Compute the sufficient statistics //Compute the sufficient statistics
for trj in dataset.trajectories.iter() { for trj in dataset.get_trajectories().iter() {
for idx in 0..(trj.time.len() - 1) { for idx in 0..(trj.get_time().len() - 1) {
let t1 = trj.time[idx]; let t1 = trj.get_time()[idx];
let t2 = trj.time[idx + 1]; let t2 = trj.get_time()[idx + 1];
let ev1 = trj.events.row(idx); let ev1 = trj.get_events().row(idx);
let ev2 = trj.events.row(idx + 1); let ev2 = trj.get_events().row(idx + 1);
let idx1 = vector_to_idx.dot(&ev1); let idx1 = vector_to_idx.dot(&ev1);
T[[idx1, ev1[node]]] += t2 - t1; T[[idx1, ev1[node]]] += t2 - t1;

@ -116,7 +116,7 @@ impl ScoreFunction for BIC {
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
//TODO: Optimize this //TODO: Optimize this
let sample_size: usize = dataset.trajectories.iter().map(|x| x.time.len() -1).sum(); let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum();
ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
} }
} }

@ -3,16 +3,49 @@ use crate::node;
use crate::params; use crate::params;
use crate::params::ParamsTrait; use crate::params::ParamsTrait;
use ndarray::prelude::*; use ndarray::prelude::*;
use rand_chacha::ChaCha8Rng;
use rand_chacha::rand_core::SeedableRng; use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
pub struct Trajectory { pub struct Trajectory {
pub time: Array1<f64>, time: Array1<f64>,
pub events: Array2<usize>, events: Array2<usize>,
}
impl Trajectory {
pub fn init(time: Array1<f64>, events: Array2<usize>) -> Trajectory {
if time.shape()[0] != events.shape()[0] {
panic!("time.shape[0] must be equal to events.shape[0]");
}
Trajectory { time, events }
}
pub fn get_time(&self) -> &Array1<f64> {
&self.time
}
pub fn get_events(&self) -> &Array2<usize> {
&self.events
}
} }
pub struct Dataset { pub struct Dataset {
pub trajectories: Vec<Trajectory>, trajectories: Vec<Trajectory>,
}
impl Dataset {
pub fn init(trajectories: Vec<Trajectory>) -> Dataset {
if trajectories
.iter()
.any(|x| trajectories[0].get_events().shape()[1] != x.get_events().shape()[1])
{
panic!("All the trajectories mus represents the same number of variables");
}
Dataset { trajectories }
}
pub fn get_trajectories(&self) -> &Vec<Trajectory> {
&self.trajectories
}
} }
pub fn trajectory_generator<T: network::Network>( pub fn trajectory_generator<T: network::Network>(
@ -21,10 +54,8 @@ pub fn trajectory_generator<T: network::Network>(
t_end: f64, t_end: f64,
seed: Option<u64>, seed: Option<u64>,
) -> Dataset { ) -> Dataset {
let mut dataset = Dataset {
trajectories: Vec::new(),
};
let mut trajectories: Vec<Trajectory> = Vec::new();
let seed = seed.unwrap_or_else(rand::random); let seed = seed.unwrap_or_else(rand::random);
let mut rng = ChaCha8Rng::seed_from_u64(seed); let mut rng = ChaCha8Rng::seed_from_u64(seed);
@ -115,14 +146,14 @@ pub fn trajectory_generator<T: network::Network>(
); );
time.push(t_end.clone()); time.push(t_end.clone());
dataset.trajectories.push(Trajectory { trajectories.push(Trajectory::init(
time: Array::from_vec(time), Array::from_vec(time),
events: Array2::from_shape_vec( Array2::from_shape_vec(
(events.len(), current_state.len()), (events.len(), current_state.len()),
events.iter().flatten().cloned().collect(), events.iter().flatten().cloned().collect(),
) )
.unwrap(), .unwrap(),
}); ));
} }
dataset Dataset::init(trajectories)
} }

@ -17,18 +17,17 @@ use rustyCTBN::params;
extern crate approx; extern crate approx;
#[test] #[test]
fn simple_log_likelihood() { fn simple_score_test() {
let mut net = CtbnNetwork::init(); let mut net = CtbnNetwork::init();
let n1 = net let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) .add_node(generate_discrete_time_continous_node(String::from("n1"),2))
.unwrap(); .unwrap();
let trj = Trajectory{ let trj = Trajectory::init(
time: arr1(&[0.0,0.1,0.3]), arr1(&[0.0,0.1,0.3]),
events: arr2(&[[0],[1],[1]])}; arr2(&[[0],[1],[1]]));
let dataset = Dataset{ let dataset = Dataset::init(vec![trj]);
trajectories: vec![trj]};
let ll = LogLikelihood::init(1, 1.0); let ll = LogLikelihood::init(1, 1.0);
@ -44,16 +43,14 @@ fn simple_bic() {
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) .add_node(generate_discrete_time_continous_node(String::from("n1"),2))
.unwrap(); .unwrap();
let trj = Trajectory{ let trj = Trajectory::init(
time: arr1(&[0.0,0.1,0.3]), arr1(&[0.0,0.1,0.3]),
events: arr2(&[[0],[1],[1]])}; arr2(&[[0],[1],[1]]));
let dataset = Dataset{ let dataset = Dataset::init(vec![trj]);
trajectories: vec![trj]}; let bic = BIC::init(1, 1.0);
let ll = BIC::init(1, 1.0);
assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
} }

@ -38,8 +38,8 @@ fn run_sampling() {
let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),); let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),);
assert_eq!(4, data.trajectories.len()); assert_eq!(4, data.get_trajectories().len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); assert_relative_eq!(1.0, data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len()-1]);
} }

Loading…
Cancel
Save