pull/19/head
AlessandroBregoli 3 years ago
parent aa9ff5e05d
commit dc53e5167e
  1. 40
      src/ctbn.rs
  2. 9
      src/network.rs
  3. 188
      src/parameter_learning.rs
  4. 10
      src/params.rs
  5. 12
      src/tools.rs

@ -97,6 +97,10 @@ impl network::Network for CtbnNetwork {
0..self.nodes.len()
}
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
fn get_node(&self, node_idx: usize) -> &node::Node{
&self.nodes[node_idx]
}
@ -117,6 +121,15 @@ impl network::Network for CtbnNetwork {
}).0
}
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize {
parent_set.iter().fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].params.state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent();
acc
}).0
}
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix.as_ref()
.unwrap()
@ -229,5 +242,32 @@ mod tests {
params::StateType::Discrete(1),
params::StateType::Discrete(0)]);
assert_eq!(1, idx);
}
#[test]
fn compute_index_from_custom_parent_set() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
let n3 = net.add_node(define_binary_node(String::from("n3"))).unwrap();
let idx = net.get_param_index_from_custom_parent_set(&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(0),
params::StateType::Discrete(1)],
&BTreeSet::from([1]));
assert_eq!(0, idx);
let idx = net.get_param_index_from_custom_parent_set(&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(0),
params::StateType::Discrete(1)],
&BTreeSet::from([1,2]));
assert_eq!(2, idx);
}
}

@ -20,13 +20,20 @@ pub trait Network {
///Get all the indices of the nodes contained inside the network
fn get_node_indices(&self) -> std::ops::Range<usize>;
fn get_number_of_nodes(&self) -> usize;
fn get_node(&self, node_idx: usize) -> &node::Node;
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node;
///Compute the index that must be used to access the parameter of a node given a specific
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are
///the ones in the parent set of the *node*.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network and a generic parent_set. Usually, the only values really used
///in *current_state* are the ones in the parent set of the *node*.
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<params::StateType>, parent_set: &BTreeSet<usize>) -> usize;
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@ -1,21 +1,197 @@
use crate::params::*;
use crate::network;
use crate::params::*;
use crate::tools;
use ndarray::prelude::*;
use ndarray::{concatenate, Slice};
use std::collections::BTreeSet;
pub fn MLE(net: Box<dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>) {
pub fn MLE(
net: Box<&dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set {
Some(p) => p,
None => net.get_parent_set(node)
None => net.get_parent_set(node),
};
//Get the number of values assumable by the node
let node_domain = net
.get_node(node.clone())
.params
.get_reserved_space_as_parent();
//Get the number of values assumable by each parent of the node
let parentset_domain: Vec<usize> = parent_set
.iter()
.map(|x| {
net.get_node(x.clone())
.params
.get_reserved_space_as_parent()
})
.collect();
//Vector used to convert a specific configuration of the parent_set to the corresponding index
//for CIM, M and T
let mut vector_to_idx: Array1<usize> = Array::zeros(net.get_number_of_nodes());
parent_set
.iter()
.zip(parentset_domain.iter())
.fold(1, |acc, (idx, x)| {
vector_to_idx[*idx] = acc;
acc * x
});
//Number of transition given a specific configuration of the parent set
let mut M: Array3<usize> =
Array::zeros((parentset_domain.iter().product(), node_domain, node_domain));
//Residence time given a specific configuration of the parent set
let mut T: Array2<f64> = Array::zeros((parentset_domain.iter().product(), node_domain));
//Compute the sufficient statistics
for trj in dataset.trajectories.iter() {
for idx in 0..(trj.time.len() - 1) {
let t1 = trj.time[idx];
let t2 = trj.time[idx + 1];
let ev1 = trj.events.row(idx);
let ev2 = trj.events.row(idx + 1);
let idx1 = vector_to_idx.dot(&ev1);
T[[idx1, ev1[node]]] += t2 - t1;
if ev1[node] != ev2[node] {
M[[idx1, ev1[node], ev2[node]]] += 1;
}
}
}
//Compute the CIM as M[i,x,y]/T[i,x]
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
CIM.axis_iter_mut(Axis(2))
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
.for_each(|(mut C, m)| C.assign(&(&m/&T)));
//Set the diagonal of the inner matrices to the the row sum multiplied by -1
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
CIM.outer_iter_mut()
.zip(tmp_diag_sum.outer_iter())
.for_each(|(mut C, diag)| {
C.diag_mut().assign(&diag);
});
return (CIM, M, T);
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ctbn::*;
use crate::network::Network;
use crate::node;
use crate::params;
use ndarray::arr3;
use std::collections::BTreeSet;
use tools::*;
fn define_binary_node(name: String) -> node::Node {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
let param = params::DiscreteStatesContinousTimeParams::init(domain);
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n;
}
fn define_ternary_node(name: String) -> node::Node {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
domain.insert(String::from("C"));
let param = params::DiscreteStatesContinousTimeParams::init(domain);
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n;
}
#[test]
fn learn_binary_cim_MLE() {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(define_binary_node(String::from("n1")))
.unwrap();
let n2 = net
.add_node(define_binary_node(String::from("n2")))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
]));
}
}
let data = trajectory_generator(Box::new(&net), 10, 100.0);
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2);
assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2);
}
#[test]
fn learn_ternary_cim_MLE() {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(define_ternary_node(String::from("n1")))
.unwrap();
let n2 = net
.add_node(define_ternary_node(String::from("n2")))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]));
}
}
let data = trajectory_generator(Box::new(&net), 100, 200.0);
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2);
assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2);
}
}

@ -16,7 +16,7 @@ pub enum ParamsError {
/// Allowed type of states
#[derive(Clone)]
pub enum StateType {
Discrete(u32),
Discrete(usize),
}
/// Parameters
@ -91,7 +91,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
fn get_random_state_uniform(&self) -> StateType {
let mut rng = rand::thread_rng();
StateType::Discrete(rng.gen_range(0..(self.domain.len() as u32)))
StateType::Discrete(rng.gen_range(0..(self.domain.len())))
}
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> {
@ -138,7 +138,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
next_state.0 + 1
};
Ok(StateType::Discrete(next_state as u32))
Ok(StateType::Discrete(next_state))
}
Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
@ -177,7 +177,7 @@ mod tests {
#[test]
fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<u32>::zeros(10000);
let mut states = Array1::<usize>::zeros(10000);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state_uniform() {
@ -194,7 +194,7 @@ mod tests {
#[test]
fn test_random_generation_state() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<u32>::zeros(10000);
let mut states = Array1::<usize>::zeros(10000);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {

@ -5,16 +5,16 @@ use crate::params;
use crate::params::ParamsTrait;
pub struct Trajectory {
time: Array1<f64>,
events: Array2<u32>
pub time: Array1<f64>,
pub events: Array2<usize>
}
pub struct Dataset {
trajectories: Vec<Trajectory>
pub trajectories: Vec<Trajectory>
}
pub fn trajectory_generator(net: Box<dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
let mut dataset = Dataset{
trajectories: Vec::new()
};
@ -23,7 +23,7 @@ pub fn trajectory_generator(net: Box<dyn network::Network>, n_trajectories: u64,
for _ in 0..n_trajectories {
let mut t = 0.0;
let mut time: Vec<f64> = Vec::new();
let mut events: Vec<Array1<u32>> = Vec::new();
let mut events: Vec<Array1<usize>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
net.get_node(*x).params.get_random_state_uniform()
}).collect();
@ -135,7 +135,7 @@ mod tests {
}
}
let data = trajectory_generator(Box::from(net), 4, 1.0);
let data = trajectory_generator(Box::new(&net), 4, 1.0);
assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);

Loading…
Cancel
Save