Sufficient statistics extracted from MLE function

pull/19/head
AlessandroBregoli 3 years ago
parent 4adfbfa4e4
commit a191bdef1c
  1. 33
      src/parameter_learning.rs

@ -5,20 +5,12 @@ use ndarray::prelude::*;
use ndarray::{concatenate, Slice}; use ndarray::{concatenate, Slice};
use std::collections::BTreeSet; use std::collections::BTreeSet;
pub fn MLE( pub fn sufficient_statistics(
net: Box<&dyn network::Network>, net: Box<&dyn network::Network>,
dataset: &tools::Dataset, dataset: &tools::Dataset,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: &BTreeSet<usize>
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { ) -> (Array3<usize>, Array2<f64>) {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set {
Some(p) => p,
None => net.get_parent_set(node),
};
//Get the number of values assumable by the node //Get the number of values assumable by the node
let node_domain = net let node_domain = net
.get_node(node.clone()) .get_node(node.clone())
@ -70,6 +62,25 @@ pub fn MLE(
} }
} }
return (M, T);
}
pub fn MLE(
net: Box<&dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set {
Some(p) => p,
None => net.get_parent_set(node),
};
let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
//Compute the CIM as M[i,x,y]/T[i,x] //Compute the CIM as M[i,x,y]/T[i,x]
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
CIM.axis_iter_mut(Axis(2)) CIM.axis_iter_mut(Axis(2))

Loading…
Cancel
Save