diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs
index 36a7e01..1582fce 100644
--- a/src/parameter_learning.rs
+++ b/src/parameter_learning.rs
@@ -5,6 +5,16 @@ use ndarray::prelude::*;
 use ndarray::{concatenate, Slice};
 use std::collections::BTreeSet;
 
+pub trait ParameterLearning{
+    fn fit(
+        &self,
+        net: Box<&dyn network::Network>,
+        dataset: &tools::Dataset,
+        node: usize,
+        parent_set: Option<BTreeSet<usize>>,
+    ) -> (Array3<f64>, Array3<usize>, Array2<f64>);
+}
+
 pub fn sufficient_statistics(
     net: Box<&dyn network::Network>,
     dataset: &tools::Dataset,
@@ -66,34 +76,39 @@
 }
 
-pub fn MLE(
-    net: Box<&dyn network::Network>,
-    dataset: &tools::Dataset,
-    node: usize,
-    parent_set: Option<BTreeSet<usize>>,
-) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
-    //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
+pub struct MLE {}
 
-    //Use parent_set from parameter if present. Otherwise use parent_set from network.
-    let parent_set = match parent_set {
-        Some(p) => p,
-        None => net.get_parent_set(node),
-    };
-
-    let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
-    //Compute the CIM as M[i,x,y]/T[i,x]
-    let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
-    CIM.axis_iter_mut(Axis(2))
-        .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-        .for_each(|(mut C, m)| C.assign(&(&m/&T)));
+impl ParameterLearning for MLE {
 
-    //Set the diagonal of the inner matrices to the the row sum multiplied by -1
-    let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
-    CIM.outer_iter_mut()
-        .zip(tmp_diag_sum.outer_iter())
-        .for_each(|(mut C, diag)| {
-            C.diag_mut().assign(&diag);
-        });
-    return (CIM, M, T);
-}
+    fn fit(
+        &self,
+        net: Box<&dyn network::Network>,
+        dataset: &tools::Dataset,
+        node: usize,
+        parent_set: Option<BTreeSet<usize>>,
+    ) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
+        //TODO: make this function general. Now it works only on ContinuousTimeDiscreteState nodes
+
+        //Use parent_set from parameter if present. Otherwise use parent_set from network.
+        let parent_set = match parent_set {
+            Some(p) => p,
+            None => net.get_parent_set(node),
+        };
+
+        let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
+        //Compute the CIM as M[i,x,y]/T[i,x]
+        let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
+        CIM.axis_iter_mut(Axis(2))
+            .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
+            .for_each(|(mut C, m)| C.assign(&(&m/&T)));
+
+        //Set the diagonal of the inner matrices to the row sum multiplied by -1
+        let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
+        CIM.outer_iter_mut()
+            .zip(tmp_diag_sum.outer_iter())
+            .for_each(|(mut C, diag)| {
+                C.diag_mut().assign(&diag);
+            });
+        return (CIM, M, T);
+    }
+}
diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index 26573a8..a59f0ee 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -42,8 +42,8 @@ fn learn_binary_cim_MLE() {
     }
 
     let data = trajectory_generator(Box::new(&net), 100, 100.0);
-
-    let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
+    let mle = MLE{};
+    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [2, 2, 2]);
     assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
@@ -83,8 +83,8 @@ fn learn_ternary_cim_MLE() {
     }
 
    let data = trajectory_generator(Box::new(&net), 100, 200.0);
-
-    let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
+    let mle = MLE{};
+    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
    print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
    assert_eq!(CIM.shape(), [3, 3, 3]);
    assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
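
Note on the new API: with fit behind the ParameterLearning trait, call sites can
stay agnostic of the concrete estimator. Below is a minimal sketch of that usage,
assuming the crate's network::Network and tools::Dataset types and the signatures
added above; the helper learn_node_cim is hypothetical and not part of this patch.

    use ndarray::prelude::*;

    // Hypothetical helper: accepts any estimator as a trait object and
    // returns only the estimated CIM, discarding the sufficient statistics.
    fn learn_node_cim(
        estimator: &dyn ParameterLearning,
        net: Box<&dyn network::Network>,
        dataset: &tools::Dataset,
        node: usize,
    ) -> Array3<f64> {
        // `None` delegates the parent set to the network, as in the tests above.
        let (cim, _m, _t) = estimator.fit(net, dataset, node, None);
        cim
    }

    // Usage mirroring the updated tests:
    //     let cim = learn_node_cim(&MLE {}, Box::new(&net), &data, 1);

A future estimator (e.g. a Bayesian one with pseudocounts) would then only need
its own impl of ParameterLearning, and call sites like the one above keep
working unchanged.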