diff --git a/src/ctbn.rs b/src/ctbn.rs index 6cfb9ba..5e7e276 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -97,6 +97,10 @@ impl network::Network for CtbnNetwork { 0..self.nodes.len() } + fn get_number_of_nodes(&self) -> usize { + self.nodes.len() + } + fn get_node(&self, node_idx: usize) -> &node::Node{ &self.nodes[node_idx] } @@ -117,6 +121,15 @@ impl network::Network for CtbnNetwork { }).0 } + + fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize { + parent_set.iter().fold((0, 1), |mut acc, x| { + acc.0 += self.nodes[*x].params.state_to_index(¤t_state[*x]) * acc.1; + acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent(); + acc + }).0 + } + fn get_parent_set(&self, node: usize) -> BTreeSet { self.adj_matrix.as_ref() .unwrap() @@ -229,5 +242,32 @@ mod tests { params::StateType::Discrete(1), params::StateType::Discrete(0)]); assert_eq!(1, idx); + + } + + + + #[test] + fn compute_index_from_custom_parent_set() { + let mut net = CtbnNetwork::init(); + let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap(); + let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); + let n3 = net.add_node(define_binary_node(String::from("n3"))).unwrap(); + + + let idx = net.get_param_index_from_custom_parent_set(&vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(0), + params::StateType::Discrete(1)], + &BTreeSet::from([1])); + assert_eq!(0, idx); + + + let idx = net.get_param_index_from_custom_parent_set(&vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(0), + params::StateType::Discrete(1)], + &BTreeSet::from([1,2])); + assert_eq!(2, idx); } } diff --git a/src/network.rs b/src/network.rs index 2f36738..3b6ce06 100644 --- a/src/network.rs +++ b/src/network.rs @@ -20,13 +20,20 @@ pub trait Network { ///Get all the indices of the nodes contained inside the network fn get_node_indices(&self) -> std::ops::Range; + fn get_number_of_nodes(&self) -> usize; fn get_node(&self, node_idx: usize) -> &node::Node; fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node; - ///Compute the index that must be used to access the parameter of a node given a specific + ///Compute the index that must be used to access the parameters of a node given a specific ///configuration of the network. Usually, the only values really used in *current_state* are ///the ones in the parent set of the *node*. fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize; + + + ///Compute the index that must be used to access the parameters of a node given a specific + ///configuration of the network and a generic parent_set. Usually, the only values really used + ///in *current_state* are the ones in the parent set of the *node*. + fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize; fn get_parent_set(&self, node: usize) -> BTreeSet; fn get_children_set(&self, node: usize) -> BTreeSet; } diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 16c0b8f..cd2dba2 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -1,21 +1,197 @@ -use crate::params::*; use crate::network; +use crate::params::*; use crate::tools; use ndarray::prelude::*; +use ndarray::{concatenate, Slice}; use std::collections::BTreeSet; -pub fn MLE(net: Box, - dataset: &tools::Dataset, - node: usize, - parent_set: Option>) { - +pub fn MLE( + net: Box<&dyn network::Network>, + dataset: &tools::Dataset, + node: usize, + parent_set: Option>, +) -> (Array3, Array3, Array2) { + //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes + + //Use parent_set from parameter if present. Otherwise use parent_set from network. let parent_set = match parent_set { Some(p) => p, - None => net.get_parent_set(node) + None => net.get_parent_set(node), }; + //Get the number of values assumable by the node + let node_domain = net + .get_node(node.clone()) + .params + .get_reserved_space_as_parent(); - + //Get the number of values assumable by each parent of the node + let parentset_domain: Vec = parent_set + .iter() + .map(|x| { + net.get_node(x.clone()) + .params + .get_reserved_space_as_parent() + }) + .collect(); + + //Vector used to convert a specific configuration of the parent_set to the corresponding index + //for CIM, M and T + let mut vector_to_idx: Array1 = Array::zeros(net.get_number_of_nodes()); + parent_set + .iter() + .zip(parentset_domain.iter()) + .fold(1, |acc, (idx, x)| { + vector_to_idx[*idx] = acc; + acc * x + }); + //Number of transition given a specific configuration of the parent set + let mut M: Array3 = + Array::zeros((parentset_domain.iter().product(), node_domain, node_domain)); + + //Residence time given a specific configuration of the parent set + let mut T: Array2 = Array::zeros((parentset_domain.iter().product(), node_domain)); + + //Compute the sufficient statistics + for trj in dataset.trajectories.iter() { + for idx in 0..(trj.time.len() - 1) { + let t1 = trj.time[idx]; + let t2 = trj.time[idx + 1]; + let ev1 = trj.events.row(idx); + let ev2 = trj.events.row(idx + 1); + let idx1 = vector_to_idx.dot(&ev1); + + T[[idx1, ev1[node]]] += t2 - t1; + if ev1[node] != ev2[node] { + M[[idx1, ev1[node], ev2[node]]] += 1; + } + } + } + + //Compute the CIM as M[i,x,y]/T[i,x] + let mut CIM: Array3 = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); + CIM.axis_iter_mut(Axis(2)) + .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) + .for_each(|(mut C, m)| C.assign(&(&m/&T))); + + //Set the diagonal of the inner matrices to the the row sum multiplied by -1 + let tmp_diag_sum: Array2 = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); + CIM.outer_iter_mut() + .zip(tmp_diag_sum.outer_iter()) + .for_each(|(mut C, diag)| { + C.diag_mut().assign(&diag); + }); + return (CIM, M, T); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ctbn::*; + use crate::network::Network; + use crate::node; + use crate::params; + use ndarray::arr3; + use std::collections::BTreeSet; + use tools::*; + + fn define_binary_node(name: String) -> node::Node { + let mut domain = BTreeSet::new(); + domain.insert(String::from("A")); + domain.insert(String::from("B")); + let param = params::DiscreteStatesContinousTimeParams::init(domain); + let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name); + return n; + } + + + fn define_ternary_node(name: String) -> node::Node { + let mut domain = BTreeSet::new(); + domain.insert(String::from("A")); + domain.insert(String::from("B")); + domain.insert(String::from("C")); + let param = params::DiscreteStatesContinousTimeParams::init(domain); + let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name); + return n; + } + + #[test] + fn learn_binary_cim_MLE() { + let mut net = CtbnNetwork::init(); + let n1 = net + .add_node(define_binary_node(String::from("n1"))) + .unwrap(); + let n2 = net + .add_node(define_binary_node(String::from("n2"))) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1).params { + params::Params::DiscreteStatesContinousTime(param) => { + param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); + } + } + + match &mut net.get_node_mut(n2).params { + params::Params::DiscreteStatesContinousTime(param) => { + param.cim = Some(arr3(&[ + [[-1.0, 1.0], [4.0, -4.0]], + [[-6.0, 6.0], [2.0, -2.0]], + ])); + } + } + + let data = trajectory_generator(Box::new(&net), 10, 100.0); + + let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); + print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); + assert_eq!(CIM.shape(), [2, 2, 2]); + assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); + assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); + assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2); + assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2); + } + + + #[test] + fn learn_ternary_cim_MLE() { + let mut net = CtbnNetwork::init(); + let n1 = net + .add_node(define_ternary_node(String::from("n1"))) + .unwrap(); + let n2 = net + .add_node(define_ternary_node(String::from("n2"))) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1).params { + params::Params::DiscreteStatesContinousTime(param) => { + param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0]]])); + } + } + + match &mut net.get_node_mut(n2).params { + params::Params::DiscreteStatesContinousTime(param) => { + param.cim = Some(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])); + } + } + + let data = trajectory_generator(Box::new(&net), 100, 200.0); + + let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); + print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); + assert_eq!(CIM.shape(), [3, 3, 3]); + assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); + assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); + assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2); + assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2); + } } diff --git a/src/params.rs b/src/params.rs index 1c7c7af..029b062 100644 --- a/src/params.rs +++ b/src/params.rs @@ -16,7 +16,7 @@ pub enum ParamsError { /// Allowed type of states #[derive(Clone)] pub enum StateType { - Discrete(u32), + Discrete(usize), } /// Parameters @@ -91,7 +91,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { fn get_random_state_uniform(&self) -> StateType { let mut rng = rand::thread_rng(); - StateType::Discrete(rng.gen_range(0..(self.domain.len() as u32))) + StateType::Discrete(rng.gen_range(0..(self.domain.len()))) } fn get_random_residence_time(&self, state: usize, u: usize) -> Result { @@ -138,7 +138,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { next_state.0 + 1 }; - Ok(StateType::Discrete(next_state as u32)) + Ok(StateType::Discrete(next_state)) } Option::None => Err(ParamsError::ParametersNotInitialized(String::from( "CIM not initialized", @@ -177,7 +177,7 @@ mod tests { #[test] fn test_uniform_generation() { let param = create_ternary_discrete_time_continous_param(); - let mut states = Array1::::zeros(10000); + let mut states = Array1::::zeros(10000); states.mapv_inplace(|_| { if let StateType::Discrete(val) = param.get_random_state_uniform() { @@ -194,7 +194,7 @@ mod tests { #[test] fn test_random_generation_state() { let param = create_ternary_discrete_time_continous_param(); - let mut states = Array1::::zeros(10000); + let mut states = Array1::::zeros(10000); states.mapv_inplace(|_| { if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { diff --git a/src/tools.rs b/src/tools.rs index eaf2ef6..f891a7d 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -5,16 +5,16 @@ use crate::params; use crate::params::ParamsTrait; pub struct Trajectory { - time: Array1, - events: Array2 + pub time: Array1, + pub events: Array2 } pub struct Dataset { - trajectories: Vec + pub trajectories: Vec } -pub fn trajectory_generator(net: Box, n_trajectories: u64, t_end: f64) -> Dataset { +pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset { let mut dataset = Dataset{ trajectories: Vec::new() }; @@ -23,7 +23,7 @@ pub fn trajectory_generator(net: Box, n_trajectories: u64, for _ in 0..n_trajectories { let mut t = 0.0; let mut time: Vec = Vec::new(); - let mut events: Vec> = Vec::new(); + let mut events: Vec> = Vec::new(); let mut current_state: Vec = node_idx.iter().map(|x| { net.get_node(*x).params.get_random_state_uniform() }).collect(); @@ -135,7 +135,7 @@ mod tests { } } - let data = trajectory_generator(Box::from(net), 4, 1.0); + let data = trajectory_generator(Box::new(&net), 4, 1.0); assert_eq!(4, data.trajectories.len()); assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);