parent
aa9ff5e05d
commit
dc53e5167e
@ -1,21 +1,197 @@ |
|||||||
use crate::params::*; |
|
||||||
use crate::network; |
use crate::network; |
||||||
|
use crate::params::*; |
||||||
use crate::tools; |
use crate::tools; |
||||||
use ndarray::prelude::*; |
use ndarray::prelude::*; |
||||||
|
use ndarray::{concatenate, Slice}; |
||||||
use std::collections::BTreeSet; |
use std::collections::BTreeSet; |
||||||
|
|
||||||
pub fn MLE(net: Box<dyn network::Network>,
|
pub fn MLE( |
||||||
dataset: &tools::Dataset,
|
net: Box<&dyn network::Network>, |
||||||
node: usize,
|
dataset: &tools::Dataset, |
||||||
parent_set: Option<BTreeSet<usize>>) { |
node: usize, |
||||||
|
parent_set: Option<BTreeSet<usize>>, |
||||||
|
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { |
||||||
|
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
|
||||||
|
|
||||||
|
//Use parent_set from parameter if present. Otherwise use parent_set from network.
|
||||||
let parent_set = match parent_set { |
let parent_set = match parent_set { |
||||||
Some(p) => p, |
Some(p) => p, |
||||||
None => net.get_parent_set(node) |
None => net.get_parent_set(node), |
||||||
}; |
}; |
||||||
|
|
||||||
|
//Get the number of values assumable by the node
|
||||||
|
let node_domain = net |
||||||
|
.get_node(node.clone()) |
||||||
|
.params |
||||||
|
.get_reserved_space_as_parent(); |
||||||
|
|
||||||
|
//Get the number of values assumable by each parent of the node
|
||||||
|
let parentset_domain: Vec<usize> = parent_set |
||||||
|
.iter() |
||||||
|
.map(|x| { |
||||||
|
net.get_node(x.clone()) |
||||||
|
.params |
||||||
|
.get_reserved_space_as_parent() |
||||||
|
}) |
||||||
|
.collect(); |
||||||
|
|
||||||
|
//Vector used to convert a specific configuration of the parent_set to the corresponding index
|
||||||
|
//for CIM, M and T
|
||||||
|
let mut vector_to_idx: Array1<usize> = Array::zeros(net.get_number_of_nodes()); |
||||||
|
|
||||||
|
parent_set |
||||||
|
.iter() |
||||||
|
.zip(parentset_domain.iter()) |
||||||
|
.fold(1, |acc, (idx, x)| { |
||||||
|
vector_to_idx[*idx] = acc; |
||||||
|
acc * x |
||||||
|
}); |
||||||
|
|
||||||
|
//Number of transition given a specific configuration of the parent set
|
||||||
|
let mut M: Array3<usize> = |
||||||
|
Array::zeros((parentset_domain.iter().product(), node_domain, node_domain)); |
||||||
|
|
||||||
|
//Residence time given a specific configuration of the parent set
|
||||||
|
let mut T: Array2<f64> = Array::zeros((parentset_domain.iter().product(), node_domain)); |
||||||
|
|
||||||
|
//Compute the sufficient statistics
|
||||||
|
for trj in dataset.trajectories.iter() { |
||||||
|
for idx in 0..(trj.time.len() - 1) { |
||||||
|
let t1 = trj.time[idx]; |
||||||
|
let t2 = trj.time[idx + 1]; |
||||||
|
let ev1 = trj.events.row(idx); |
||||||
|
let ev2 = trj.events.row(idx + 1); |
||||||
|
let idx1 = vector_to_idx.dot(&ev1); |
||||||
|
|
||||||
|
T[[idx1, ev1[node]]] += t2 - t1; |
||||||
|
if ev1[node] != ev2[node] { |
||||||
|
M[[idx1, ev1[node], ev2[node]]] += 1; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
//Compute the CIM as M[i,x,y]/T[i,x]
|
||||||
|
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); |
||||||
|
CIM.axis_iter_mut(Axis(2)) |
||||||
|
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) |
||||||
|
.for_each(|(mut C, m)| C.assign(&(&m/&T))); |
||||||
|
|
||||||
|
//Set the diagonal of the inner matrices to the the row sum multiplied by -1
|
||||||
|
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); |
||||||
|
CIM.outer_iter_mut() |
||||||
|
.zip(tmp_diag_sum.outer_iter()) |
||||||
|
.for_each(|(mut C, diag)| { |
||||||
|
C.diag_mut().assign(&diag); |
||||||
|
}); |
||||||
|
return (CIM, M, T); |
||||||
|
} |
||||||
|
|
||||||
|
#[cfg(test)] |
||||||
|
mod tests { |
||||||
|
use super::*; |
||||||
|
use crate::ctbn::*; |
||||||
|
use crate::network::Network; |
||||||
|
use crate::node; |
||||||
|
use crate::params; |
||||||
|
use ndarray::arr3; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
use tools::*; |
||||||
|
|
||||||
|
fn define_binary_node(name: String) -> node::Node { |
||||||
|
let mut domain = BTreeSet::new(); |
||||||
|
domain.insert(String::from("A")); |
||||||
|
domain.insert(String::from("B")); |
||||||
|
let param = params::DiscreteStatesContinousTimeParams::init(domain); |
||||||
|
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name); |
||||||
|
return n; |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
fn define_ternary_node(name: String) -> node::Node { |
||||||
|
let mut domain = BTreeSet::new(); |
||||||
|
domain.insert(String::from("A")); |
||||||
|
domain.insert(String::from("B")); |
||||||
|
domain.insert(String::from("C")); |
||||||
|
let param = params::DiscreteStatesContinousTimeParams::init(domain); |
||||||
|
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name); |
||||||
|
return n; |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn learn_binary_cim_MLE() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(define_binary_node(String::from("n1"))) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(define_binary_node(String::from("n2"))) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some(arr3(&[ |
||||||
|
[[-1.0, 1.0], [4.0, -4.0]], |
||||||
|
[[-6.0, 6.0], [2.0, -2.0]], |
||||||
|
])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
let data = trajectory_generator(Box::new(&net), 10, 100.0); |
||||||
|
|
||||||
|
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); |
||||||
|
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||||
|
assert_eq!(CIM.shape(), [2, 2, 2]); |
||||||
|
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn learn_ternary_cim_MLE() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(define_ternary_node(String::from("n1"))) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(define_ternary_node(String::from("n2"))) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
|
||||||
|
[1.5, -2.0, 0.5], |
||||||
|
[0.4, 0.6, -1.0]]])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some(arr3(&[ |
||||||
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||||
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||||
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||||
|
])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
let data = trajectory_generator(Box::new(&net), 100, 200.0); |
||||||
|
|
||||||
|
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); |
||||||
|
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||||
|
assert_eq!(CIM.shape(), [3, 3, 3]); |
||||||
|
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2); |
||||||
|
assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2); |
||||||
|
} |
||||||
} |
} |
||||||
|
Loading…
Reference in new issue