parameter_learning refactor

pull/19/head
AlessandroBregoli 3 years ago
parent a191bdef1c
commit 5c816ebba7
2 changed files:
  src/parameter_learning.rs (75 lines changed)
  tests/parameter_learning.rs (8 lines changed)

src/parameter_learning.rs:

@@ -5,6 +5,16 @@ use ndarray::prelude::*;
 use ndarray::{concatenate, Slice};
 use std::collections::BTreeSet;
+
+pub trait ParameterLearning{
+    fn fit(
+        &self,
+        net: Box<&dyn network::Network>,
+        dataset: &tools::Dataset,
+        node: usize,
+        parent_set: Option<BTreeSet<usize>>,
+    ) -> (Array3<f64>, Array3<usize>, Array2<f64>);
+}
 
 pub fn sufficient_statistics(
     net: Box<&dyn network::Network>,
     dataset: &tools::Dataset,
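Since fit now lives behind a trait, callers can be written once and reused with any estimator. A minimal sketch of such a caller, assuming the crate's network and tools modules are in scope; the helper name learn_node_params is illustrative and not part of this commit:

    use ndarray::{Array2, Array3};

    // Hypothetical generic caller: MLE below, or any future estimator that
    // implements ParameterLearning, can be passed in without changes here.
    fn learn_node_params<P: ParameterLearning>(
        estimator: &P,
        net: Box<&dyn network::Network>,
        dataset: &tools::Dataset,
        node: usize,
    ) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
        // None falls back to the parent set stored in the network.
        estimator.fit(net, dataset, node, None)
    }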
src/parameter_learning.rs (same file, second hunk):

@@ -66,34 +76,39 @@ pub fn sufficient_statistics(
 }
 
-pub fn MLE(
-    net: Box<&dyn network::Network>,
-    dataset: &tools::Dataset,
-    node: usize,
-    parent_set: Option<BTreeSet<usize>>,
-) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
-    //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
-    //Use parent_set from parameter if present. Otherwise use parent_set from network.
-    let parent_set = match parent_set {
-        Some(p) => p,
-        None => net.get_parent_set(node),
-    };
-    let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
-    //Compute the CIM as M[i,x,y]/T[i,x]
-    let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
-    CIM.axis_iter_mut(Axis(2))
-        .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-        .for_each(|(mut C, m)| C.assign(&(&m/&T)));
-    //Set the diagonal of the inner matrices to the row sum multiplied by -1
-    let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
-    CIM.outer_iter_mut()
-        .zip(tmp_diag_sum.outer_iter())
-        .for_each(|(mut C, diag)| {
-            C.diag_mut().assign(&diag);
-        });
-    return (CIM, M, T);
-}
+pub struct MLE {}
+
+impl ParameterLearning for MLE {
+    fn fit(
+        &self,
+        net: Box<&dyn network::Network>,
+        dataset: &tools::Dataset,
+        node: usize,
+        parent_set: Option<BTreeSet<usize>>,
+    ) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
+        //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
+        //Use parent_set from parameter if present. Otherwise use parent_set from network.
+        let parent_set = match parent_set {
+            Some(p) => p,
+            None => net.get_parent_set(node),
+        };
+        let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
+        //Compute the CIM as M[i,x,y]/T[i,x]
+        let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
+        CIM.axis_iter_mut(Axis(2))
+            .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
+            .for_each(|(mut C, m)| C.assign(&(&m/&T)));
+        //Set the diagonal of the inner matrices to the row sum multiplied by -1
+        let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
+        CIM.outer_iter_mut()
+            .zip(tmp_diag_sum.outer_iter())
+            .for_each(|(mut C, diag)| {
+                C.diag_mut().assign(&diag);
+            });
+        return (CIM, M, T);
+    }
+}
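For intuition, here is the same CIM arithmetic run standalone on invented numbers (M is kept as f64 here to skip the usize conversion that fit performs via mapv). With a single parent configuration, 6 observed 0→1 transitions over 3 time units spent in state 0 give rate 6/3 = 2.0, and 2 observed 1→0 transitions over 4 time units in state 1 give 2/4 = 0.5:

    use ndarray::{array, Array2, Array3, Axis};

    fn main() {
        // Toy sufficient statistics for one parent configuration (axis 0):
        // M[i, x, y] = observed x -> y transitions, T[i, x] = time in state x.
        let m: Array3<f64> = array![[[0.0, 6.0], [2.0, 0.0]]];
        let t: Array2<f64> = array![[3.0, 4.0]];

        // Off-diagonal rates: CIM[i, x, y] = M[i, x, y] / T[i, x].
        let mut cim = Array3::<f64>::zeros(m.raw_dim());
        cim.axis_iter_mut(Axis(2))
            .zip(m.axis_iter(Axis(2)))
            .for_each(|(mut c, mm)| c.assign(&(&mm / &t)));

        // Diagonal = minus the row sum, so each row of the conditional
        // intensity matrix sums to zero.
        let diag = cim.sum_axis(Axis(2)).mapv(|x| -x);
        cim.outer_iter_mut()
            .zip(diag.outer_iter())
            .for_each(|(mut c, d)| c.diag_mut().assign(&d));

        assert_eq!(cim, array![[[-2.0, 2.0], [0.5, -0.5]]]);
    }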

tests/parameter_learning.rs:

@@ -42,8 +42,8 @@ fn learn_binary_cim_MLE() {
     }
     let data = trajectory_generator(Box::new(&net), 100, 100.0);
-    let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
+    let mle = MLE{};
+    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [2, 2, 2]);
     assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
@@ -83,8 +83,8 @@ fn learn_ternary_cim_MLE() {
     }
     let data = trajectory_generator(Box::new(&net), 100, 200.0);
-    let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
+    let mle = MLE{};
+    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [3, 3, 3]);
     assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
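Both updated tests pass None, so fit falls back to the network's parent set. The trait also accepts an explicit override; a sketch of that path, assuming net and data are built as in the tests above (using the override to evaluate a candidate parent set is an assumption, not something this commit states):

    use std::collections::BTreeSet;

    let mle = MLE {};
    // Hypothetical: fit node 1 with parent set {0} instead of whatever the
    // network stores, e.g. while scoring candidate structures.
    let candidate_parents: BTreeSet<usize> = BTreeSet::from([0]);
    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, Some(candidate_parents));
    assert_eq!(CIM.shape()[0], M.shape()[0]); // one CIM per parent configuration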
