|
|
|
@ -5,20 +5,12 @@ use ndarray::prelude::*; |
|
|
|
|
use ndarray::{concatenate, Slice}; |
|
|
|
|
use std::collections::BTreeSet; |
|
|
|
|
|
|
|
|
|
pub fn MLE( |
|
|
|
|
pub fn sufficient_statistics( |
|
|
|
|
net: Box<&dyn network::Network>, |
|
|
|
|
dataset: &tools::Dataset, |
|
|
|
|
node: usize, |
|
|
|
|
parent_set: Option<BTreeSet<usize>>, |
|
|
|
|
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { |
|
|
|
|
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
|
|
|
|
|
|
|
|
|
|
//Use parent_set from parameter if present. Otherwise use parent_set from network.
|
|
|
|
|
let parent_set = match parent_set { |
|
|
|
|
Some(p) => p, |
|
|
|
|
None => net.get_parent_set(node), |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
parent_set: &BTreeSet<usize> |
|
|
|
|
) -> (Array3<usize>, Array2<f64>) { |
|
|
|
|
//Get the number of values assumable by the node
|
|
|
|
|
let node_domain = net |
|
|
|
|
.get_node(node.clone()) |
|
|
|
@ -70,6 +62,25 @@ pub fn MLE( |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return (M, T); |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn MLE( |
|
|
|
|
net: Box<&dyn network::Network>, |
|
|
|
|
dataset: &tools::Dataset, |
|
|
|
|
node: usize, |
|
|
|
|
parent_set: Option<BTreeSet<usize>>, |
|
|
|
|
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { |
|
|
|
|
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
|
|
|
|
|
|
|
|
|
|
//Use parent_set from parameter if present. Otherwise use parent_set from network.
|
|
|
|
|
let parent_set = match parent_set { |
|
|
|
|
Some(p) => p, |
|
|
|
|
None => net.get_parent_set(node), |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); |
|
|
|
|
//Compute the CIM as M[i,x,y]/T[i,x]
|
|
|
|
|
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); |
|
|
|
|
CIM.axis_iter_mut(Axis(2)) |
|
|
|
|