Bug fix of sampling + implementation of Bayesian Approach

pull/19/head
AlessandroBregoli 3 years ago
parent 5c816ebba7
commit f87900fdbd
  1. 2
      Cargo.toml
  2. 41
      src/parameter_learning.rs
  3. 10
      src/params.rs
  4. 124
      src/tools.rs
  5. 55
      tests/parameter_learning.rs
  6. 2
      tests/tools.rs

@ -7,7 +7,7 @@ edition = "2021"
[dependencies]
ndarray = "*"
ndarray = {version="*", features=["approx"]}
thiserror = "*"
rand = "*"
bimap = "*"

@ -112,3 +112,44 @@ impl ParameterLearning for MLE {
return (CIM, M, T);
}
}
pub struct BayesianApproach {
default_alpha: usize,
default_tau: f64
}
impl ParameterLearning for BayesianApproach {
fn fit(
&self,
net: Box<&dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set {
Some(p) => p,
None => net.get_parent_set(node),
};
let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
M.mapv_inplace(|x|{x + self.default_alpha});
T.mapv_inplace(|x|{x + self.default_tau});
//Compute the CIM as M[i,x,y]/T[i,x]
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
CIM.axis_iter_mut(Axis(2))
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
.for_each(|(mut C, m)| C.assign(&(&m/&T)));
//Set the diagonal of the inner matrices to the the row sum multiplied by -1
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
CIM.outer_iter_mut()
.zip(tmp_diag_sum.outer_iter())
.for_each(|(mut C, diag)| {
C.diag_mut().assign(&diag);
});
return (CIM, M, T);
}
}

@ -102,7 +102,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
Option::Some(cim) => {
let mut rng = rand::thread_rng();
let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..1.0);
let x: f64 = rng.gen_range(0.0..=1.0);
Ok(-x.ln() / lambda)
}
Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
@ -119,15 +119,17 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
Option::Some(cim) => {
let mut rng = rand::thread_rng();
let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..1.0);
let urand: f64 = rng.gen_range(0.0..=1.0);
let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold(
(0, 0.0),
|mut acc, ele| {
if &acc.1 + ele < x && ele > &0.0 {
acc.1 += x;
if &acc.1 + ele < urand && ele > &0.0 {
acc.0 += 1;
}
if ele > &0.0 {
acc.1 += ele;
}
acc
},
);

@ -1,22 +1,25 @@
use ndarray::prelude::*;
use crate::network;
use crate::node;
use crate::params;
use crate::params::ParamsTrait;
use ndarray::prelude::*;
pub struct Trajectory {
pub time: Array1<f64>,
pub events: Array2<usize>
pub events: Array2<usize>,
}
pub struct Dataset {
pub trajectories: Vec<Trajectory>
pub trajectories: Vec<Trajectory>,
}
pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
let mut dataset = Dataset{
trajectories: Vec::new()
pub fn trajectory_generator(
net: Box<&dyn network::Network>,
n_trajectories: u64,
t_end: f64,
) -> Dataset {
let mut dataset = Dataset {
trajectories: Vec::new(),
};
let node_idx: Vec<_> = net.get_node_indices().collect();
@ -24,72 +27,93 @@ pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64
let mut t = 0.0;
let mut time: Vec<f64> = Vec::new();
let mut events: Vec<Array1<usize>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
net.get_node(*x).params.get_random_state_uniform()
}).collect();
let mut next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect();
events.push(current_state.iter().map(|x| match x {
params::StateType::Discrete(state) => state.clone()
}).collect());
let mut current_state: Vec<params::StateType> = node_idx
.iter()
.map(|x| net.get_node(*x).params.get_random_state_uniform())
.collect();
let mut next_transitions: Vec<Option<f64>> =
(0..node_idx.len()).map(|_| Option::None).collect();
events.push(
current_state
.iter()
.map(|x| match x {
params::StateType::Discrete(state) => state.clone(),
})
.collect(),
);
time.push(t.clone());
while t < t_end {
for (idx, val) in next_transitions.iter_mut().enumerate(){
for (idx, val) in next_transitions.iter_mut().enumerate() {
if let None = val {
*val = Some(net.get_node(idx).params
.get_random_residence_time(net.get_node(idx).params.state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state)).unwrap() + t);
*val = Some(
net.get_node(idx)
.params
.get_random_residence_time(
net.get_node(idx).params.state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state),
)
.unwrap()
+ t,
);
}
};
}
let next_node_transition = next_transitions
.iter()
.enumerate()
.min_by(|x, y|
x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
.unwrap().0;
.min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
.unwrap()
.0;
if next_transitions[next_node_transition].unwrap() > t_end {
break
break;
}
t = next_transitions[next_node_transition].unwrap().clone();
time.push(t.clone());
current_state[next_node_transition] = net.get_node(next_node_transition).params
.get_random_state(
net.get_node(next_node_transition).params.
state_to_index(
&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state))
.unwrap();
events.push(Array::from_vec(current_state.iter().map(|x| match x {
params::StateType::Discrete(state) => state.clone()
}).collect()));
current_state[next_node_transition] = net
.get_node(next_node_transition)
.params
.get_random_state(
net.get_node(next_node_transition)
.params
.state_to_index(&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state),
)
.unwrap();
events.push(Array::from_vec(
current_state
.iter()
.map(|x| match x {
params::StateType::Discrete(state) => state.clone(),
})
.collect(),
));
next_transitions[next_node_transition] = None;
for child in net.get_children_set(next_node_transition){
for child in net.get_children_set(next_node_transition) {
next_transitions[child] = None
}
}
events.push(current_state.iter().map(|x| match x {
params::StateType::Discrete(state) => state.clone()
}).collect());
events.push(
current_state
.iter()
.map(|x| match x {
params::StateType::Discrete(state) => state.clone(),
})
.collect(),
);
time.push(t_end.clone());
dataset.trajectories.push(Trajectory {
time: Array::from_vec(time),
events: Array2::from_shape_vec((events.len(), current_state.len()), events.iter().flatten().cloned().collect()).unwrap()
events: Array2::from_shape_vec(
(events.len(), current_state.len()),
events.iter().flatten().cloned().collect(),
)
.unwrap(),
});
}
dataset
}

@ -46,10 +46,10 @@ fn learn_binary_cim_MLE() {
let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2);
assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2);
assert!(CIM.abs_diff_eq(&arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
]), 0.2));
}
@ -87,9 +87,48 @@ fn learn_ternary_cim_MLE() {
let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2);
assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2);
assert!(CIM.abs_diff_eq(&arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]), 0.2));
}
#[test]
fn learn_ternary_cim_MLE_no_parents() {
let mut net = CtbnNetwork::init();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"),3))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]));
}
}
let data = trajectory_generator(Box::new(&net), 100, 200.0);
let mle = MLE{};
let (CIM, M, T) = mle.fit(Box::new(&net), &data, 0, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [1, 3, 3]);
assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]), 0.2));
}

@ -41,3 +41,5 @@ fn run_sampling() {
assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);
}

Loading…
Cancel
Save