|
|
@ -5,6 +5,16 @@ use ndarray::prelude::*; |
|
|
|
use ndarray::{concatenate, Slice}; |
|
|
|
use ndarray::{concatenate, Slice}; |
|
|
|
use std::collections::BTreeSet; |
|
|
|
use std::collections::BTreeSet; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub trait ParameterLearning{ |
|
|
|
|
|
|
|
fn fit( |
|
|
|
|
|
|
|
&self, |
|
|
|
|
|
|
|
net: Box<&dyn network::Network>, |
|
|
|
|
|
|
|
dataset: &tools::Dataset, |
|
|
|
|
|
|
|
node: usize, |
|
|
|
|
|
|
|
parent_set: Option<BTreeSet<usize>>, |
|
|
|
|
|
|
|
) -> (Array3<f64>, Array3<usize>, Array2<f64>); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
pub fn sufficient_statistics( |
|
|
|
pub fn sufficient_statistics( |
|
|
|
net: Box<&dyn network::Network>, |
|
|
|
net: Box<&dyn network::Network>, |
|
|
|
dataset: &tools::Dataset, |
|
|
|
dataset: &tools::Dataset, |
|
|
@ -66,34 +76,39 @@ pub fn sufficient_statistics( |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
pub fn MLE( |
|
|
|
pub struct MLE {} |
|
|
|
net: Box<&dyn network::Network>, |
|
|
|
|
|
|
|
dataset: &tools::Dataset, |
|
|
|
|
|
|
|
node: usize, |
|
|
|
|
|
|
|
parent_set: Option<BTreeSet<usize>>, |
|
|
|
|
|
|
|
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { |
|
|
|
|
|
|
|
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//Use parent_set from parameter if present. Otherwise use parent_set from network.
|
|
|
|
|
|
|
|
let parent_set = match parent_set { |
|
|
|
|
|
|
|
Some(p) => p, |
|
|
|
|
|
|
|
None => net.get_parent_set(node), |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); |
|
|
|
|
|
|
|
//Compute the CIM as M[i,x,y]/T[i,x]
|
|
|
|
|
|
|
|
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); |
|
|
|
|
|
|
|
CIM.axis_iter_mut(Axis(2)) |
|
|
|
|
|
|
|
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) |
|
|
|
|
|
|
|
.for_each(|(mut C, m)| C.assign(&(&m/&T))); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//Set the diagonal of the inner matrices to the the row sum multiplied by -1
|
|
|
|
|
|
|
|
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); |
|
|
|
|
|
|
|
CIM.outer_iter_mut() |
|
|
|
|
|
|
|
.zip(tmp_diag_sum.outer_iter()) |
|
|
|
|
|
|
|
.for_each(|(mut C, diag)| { |
|
|
|
|
|
|
|
C.diag_mut().assign(&diag); |
|
|
|
|
|
|
|
}); |
|
|
|
|
|
|
|
return (CIM, M, T); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl ParameterLearning for MLE { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn fit( |
|
|
|
|
|
|
|
&self, |
|
|
|
|
|
|
|
net: Box<&dyn network::Network>, |
|
|
|
|
|
|
|
dataset: &tools::Dataset, |
|
|
|
|
|
|
|
node: usize, |
|
|
|
|
|
|
|
parent_set: Option<BTreeSet<usize>>, |
|
|
|
|
|
|
|
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { |
|
|
|
|
|
|
|
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//Use parent_set from parameter if present. Otherwise use parent_set from network.
|
|
|
|
|
|
|
|
let parent_set = match parent_set { |
|
|
|
|
|
|
|
Some(p) => p, |
|
|
|
|
|
|
|
None => net.get_parent_set(node), |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); |
|
|
|
|
|
|
|
//Compute the CIM as M[i,x,y]/T[i,x]
|
|
|
|
|
|
|
|
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); |
|
|
|
|
|
|
|
CIM.axis_iter_mut(Axis(2)) |
|
|
|
|
|
|
|
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) |
|
|
|
|
|
|
|
.for_each(|(mut C, m)| C.assign(&(&m/&T))); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//Set the diagonal of the inner matrices to the the row sum multiplied by -1
|
|
|
|
|
|
|
|
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); |
|
|
|
|
|
|
|
CIM.outer_iter_mut() |
|
|
|
|
|
|
|
.zip(tmp_diag_sum.outer_iter()) |
|
|
|
|
|
|
|
.for_each(|(mut C, diag)| { |
|
|
|
|
|
|
|
C.diag_mut().assign(&diag); |
|
|
|
|
|
|
|
}); |
|
|
|
|
|
|
|
return (CIM, M, T); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|