Bug fix of sampling + implementation of Bayesian Approach

pull/19/head
AlessandroBregoli 3 years ago
parent 5c816ebba7
commit f87900fdbd
  1. Cargo.toml (2 lines changed)
  2. src/parameter_learning.rs (41 lines changed)
  3. src/params.rs (10 lines changed)
  4. src/tools.rs (110 lines changed)
  5. tests/parameter_learning.rs (55 lines changed)
  6. tests/tools.rs (2 lines changed)

Cargo.toml
@@ -7,7 +7,7 @@ edition = "2021"
 [dependencies]
-ndarray = "*"
+ndarray = {version="*", features=["approx"]}
 thiserror = "*"
 rand = "*"
 bimap = "*"
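
The only manifest change turns on ndarray's "approx" feature, which implements the approx crate's comparison traits for ndarray arrays. The reworked tests below rely on this to compare a whole estimated CIM against an expected matrix with a single tolerance instead of checking individual entries. A minimal sketch of that usage, assuming approx is already available as a (dev-)dependency, as the existing assert_relative_eq! calls suggest (the array values here are illustrative, not taken from the tests):

    // Whole-array approximate comparison enabled by ndarray's "approx" feature.
    use approx::AbsDiffEq;
    use ndarray::arr3;

    fn main() {
        let estimated = arr3(&[[[-1.05, 1.05], [3.9, -3.9]]]);
        let expected = arr3(&[[[-1.0, 1.0], [4.0, -4.0]]]);
        // Every element must be within 0.2 of its counterpart.
        assert!(estimated.abs_diff_eq(&expected, 0.2));
    }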

src/parameter_learning.rs
@@ -112,3 +112,44 @@ impl ParameterLearning for MLE {
         return (CIM, M, T);
     }
 }
+
+pub struct BayesianApproach {
+    default_alpha: usize,
+    default_tau: f64
+}
+
+impl ParameterLearning for BayesianApproach {
+    fn fit(
+        &self,
+        net: Box<&dyn network::Network>,
+        dataset: &tools::Dataset,
+        node: usize,
+        parent_set: Option<BTreeSet<usize>>,
+    ) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
+        //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
+
+        //Use parent_set from parameter if present. Otherwise use parent_set from network.
+        let parent_set = match parent_set {
+            Some(p) => p,
+            None => net.get_parent_set(node),
+        };
+
+        let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
+        M.mapv_inplace(|x| {x + self.default_alpha});
+        T.mapv_inplace(|x| {x + self.default_tau});
+
+        //Compute the CIM as M[i,x,y]/T[i,x]
+        let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
+        CIM.axis_iter_mut(Axis(2))
+            .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
+            .for_each(|(mut C, m)| C.assign(&(&m / &T)));
+
+        //Set the diagonal of the inner matrices to the row sum multiplied by -1
+        let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
+        CIM.outer_iter_mut()
+            .zip(tmp_diag_sum.outer_iter())
+            .for_each(|(mut C, diag)| {
+                C.diag_mut().assign(&diag);
+            });
+        return (CIM, M, T);
+    }
+}
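
For context on the estimator: sufficient_statistics returns the transition counts M[u, x, y] and the residence times T[u, x] for each parent configuration u. The MLE divides them directly (the "Compute the CIM as M[i,x,y]/T[i,x]" step above); the Bayesian variant first adds default_alpha pseudo-transitions to every count and default_tau pseudo-time to every residence time, so rates stay finite even for configurations that never appear in the data. A standalone sketch of that computation on plain ndarray arrays (the function name and the example numbers are illustrative, not part of the patch):

    use ndarray::{arr2, arr3, Array2, Array3, Axis};

    /// Smoothed rate estimate (M + alpha) / (T + tau), with each diagonal
    /// entry then replaced by minus its row sum, mirroring fit() above.
    fn smoothed_cim(m: &Array3<usize>, t: &Array2<f64>, alpha: usize, tau: f64) -> Array3<f64> {
        let m = m.mapv(|x| (x + alpha) as f64);
        let t = t.mapv(|x| x + tau);
        let mut cim = Array3::<f64>::zeros(m.raw_dim());
        // cim[u, x, y] = m[u, x, y] / t[u, x]
        cim.axis_iter_mut(Axis(2))
            .zip(m.axis_iter(Axis(2)))
            .for_each(|(mut c, m_col)| c.assign(&(&m_col / &t)));
        // Diagonal entries become minus the row sum.
        let neg_row_sum = cim.sum_axis(Axis(2)).mapv(|x| -x);
        cim.outer_iter_mut()
            .zip(neg_row_sum.outer_iter())
            .for_each(|(mut c, d)| c.diag_mut().assign(&d));
        cim
    }

    fn main() {
        let m = arr3(&[[[0, 3], [1, 0]]]); // transition counts M[u, x, y]
        let t = arr2(&[[1.5, 0.5]]);       // residence times  T[u, x]
        println!("{:?}", smoothed_cim(&m, &t, 1, 0.1));
    }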

src/params.rs
@@ -102,7 +102,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
             Option::Some(cim) => {
                 let mut rng = rand::thread_rng();
                 let lambda = cim[[u, state, state]] * -1.0;
-                let x: f64 = rng.gen_range(0.0..1.0);
+                let x: f64 = rng.gen_range(0.0..=1.0);
                 Ok(-x.ln() / lambda)
             }
             Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
@@ -119,15 +119,17 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
             Option::Some(cim) => {
                 let mut rng = rand::thread_rng();
                 let lambda = cim[[u, state, state]] * -1.0;
-                let x: f64 = rng.gen_range(0.0..1.0);
+                let urand: f64 = rng.gen_range(0.0..=1.0);
                 let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold(
                     (0, 0.0),
                     |mut acc, ele| {
-                        if &acc.1 + ele < x && ele > &0.0 {
-                            acc.1 += x;
+                        if &acc.1 + ele < urand && ele > &0.0 {
                             acc.0 += 1;
                         }
+                        if ele > &0.0 {
+                            acc.1 += ele;
+                        }
                         acc
                     },
                 );
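
The second hunk is the actual sampling fix. cim.slice(s![u, state, ..]).map(|x| x / lambda) turns one CIM row into transition probabilities (the diagonal entry becomes -1.0), and the fold walks their cumulative sum to invert the CDF at the uniform draw. Previously the accumulator was incremented by the draw x itself rather than by the probability ele, so the running CDF was wrong after the first positive entry. For reference, a standalone sketch of the same inverse-transform idea, written to return the row position directly (the function name, slice input, and Option return are illustrative, not the patch's API):

    /// Inverse-transform sampling over one CIM row already divided by lambda:
    /// the negative diagonal entry is skipped, the positive entries form a
    /// categorical distribution, and the walk stops once the running
    /// cumulative sum reaches the uniform draw.
    fn sample_next_state(probs: &[f64], urand: f64) -> Option<usize> {
        let mut cumulative = 0.0;
        for (idx, &p) in probs.iter().enumerate() {
            if p > 0.0 {
                cumulative += p;
                if urand <= cumulative {
                    return Some(idx);
                }
            }
        }
        None // only reachable if urand exceeds the total mass (rounding)
    }

    fn main() {
        // Row of a 3-state CIM for current state 0, divided by lambda = 1.0.
        let probs = [-1.0, 0.3, 0.7];
        assert_eq!(sample_next_state(&probs, 0.1), Some(1));
        assert_eq!(sample_next_state(&probs, 0.9), Some(2));
    }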

src/tools.rs
@@ -1,22 +1,25 @@
+use ndarray::prelude::*;
 use crate::network;
 use crate::node;
 use crate::params;
 use crate::params::ParamsTrait;
-use ndarray::prelude::*;

 pub struct Trajectory {
     pub time: Array1<f64>,
-    pub events: Array2<usize>
+    pub events: Array2<usize>,
 }

 pub struct Dataset {
-    pub trajectories: Vec<Trajectory>
+    pub trajectories: Vec<Trajectory>,
 }

-pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
+pub fn trajectory_generator(
+    net: Box<&dyn network::Network>,
+    n_trajectories: u64,
+    t_end: f64,
+) -> Dataset {
     let mut dataset = Dataset {
-        trajectories: Vec::new()
+        trajectories: Vec::new(),
     };
     let node_idx: Vec<_> = net.get_node_indices().collect();
@@ -24,72 +27,93 @@ pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64
         let mut t = 0.0;
         let mut time: Vec<f64> = Vec::new();
         let mut events: Vec<Array1<usize>> = Vec::new();
-        let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
-            net.get_node(*x).params.get_random_state_uniform()
-        }).collect();
-        let mut next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect();
-        events.push(current_state.iter().map(|x| match x {
-            params::StateType::Discrete(state) => state.clone()
-        }).collect());
+        let mut current_state: Vec<params::StateType> = node_idx
+            .iter()
+            .map(|x| net.get_node(*x).params.get_random_state_uniform())
+            .collect();
+        let mut next_transitions: Vec<Option<f64>> =
+            (0..node_idx.len()).map(|_| Option::None).collect();
+        events.push(
+            current_state
+                .iter()
+                .map(|x| match x {
+                    params::StateType::Discrete(state) => state.clone(),
+                })
+                .collect(),
+        );
         time.push(t.clone());
         while t < t_end {
             for (idx, val) in next_transitions.iter_mut().enumerate() {
                 if let None = val {
-                    *val = Some(net.get_node(idx).params
-                        .get_random_residence_time(net.get_node(idx).params.state_to_index(&current_state[idx]),
-                        net.get_param_index_network(idx, &current_state)).unwrap() + t);
+                    *val = Some(
+                        net.get_node(idx)
+                            .params
+                            .get_random_residence_time(
+                                net.get_node(idx).params.state_to_index(&current_state[idx]),
+                                net.get_param_index_network(idx, &current_state),
+                            )
+                            .unwrap()
+                            + t,
+                    );
                 }
-            };
+            }

             let next_node_transition = next_transitions
                 .iter()
                 .enumerate()
-                .min_by(|x, y|
-                    x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
-                .unwrap().0;
+                .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
+                .unwrap()
+                .0;

             if next_transitions[next_node_transition].unwrap() > t_end {
-                break
+                break;
             }

             t = next_transitions[next_node_transition].unwrap().clone();
             time.push(t.clone());

-            current_state[next_node_transition] = net.get_node(next_node_transition).params
+            current_state[next_node_transition] = net
+                .get_node(next_node_transition)
+                .params
                 .get_random_state(
-                    net.get_node(next_node_transition).params.
-                    state_to_index(
-                        &current_state[next_node_transition]),
-                    net.get_param_index_network(next_node_transition, &current_state))
+                    net.get_node(next_node_transition)
+                        .params
+                        .state_to_index(&current_state[next_node_transition]),
+                    net.get_param_index_network(next_node_transition, &current_state),
+                )
                 .unwrap();

-            events.push(Array::from_vec(current_state.iter().map(|x| match x {
-                params::StateType::Discrete(state) => state.clone()
-            }).collect()));
+            events.push(Array::from_vec(
+                current_state
+                    .iter()
+                    .map(|x| match x {
+                        params::StateType::Discrete(state) => state.clone(),
+                    })
+                    .collect(),
+            ));

             next_transitions[next_node_transition] = None;

             for child in net.get_children_set(next_node_transition) {
                 next_transitions[child] = None
             }
         }

-        events.push(current_state.iter().map(|x| match x {
-            params::StateType::Discrete(state) => state.clone()
-        }).collect());
+        events.push(
+            current_state
+                .iter()
+                .map(|x| match x {
+                    params::StateType::Discrete(state) => state.clone(),
+                })
+                .collect(),
+        );
         time.push(t_end.clone());

         dataset.trajectories.push(Trajectory {
             time: Array::from_vec(time),
-            events: Array2::from_shape_vec((events.len(), current_state.len()), events.iter().flatten().cloned().collect()).unwrap()
+            events: Array2::from_shape_vec(
+                (events.len(), current_state.len()),
+                events.iter().flatten().cloned().collect(),
+            )
+            .unwrap(),
         });
     }

     dataset
 }
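
Reading the generator as an algorithm: every node keeps an absolute candidate transition time; whenever that slot is None it is redrawn as t plus an exponential residence time (the -x.ln() / lambda draw from params.rs, with lambda = -cim[[u, state, state]]); the node with the smallest candidate fires, its state is resampled from its CIM row, and its own slot plus its children's slots are cleared so they are redrawn under the new parent configuration. A reduced sketch of one scheduling round with plain vectors and hypothetical per-node exit rates (none of these names come from the crate):

    use rand::Rng;

    /// One round of the race in trajectory_generator, reduced to flat vectors:
    /// refill every empty slot with t + Exp(lambda), then return the node that
    /// fires first together with its firing time.
    fn next_transition<R: Rng>(
        rng: &mut R,
        t: f64,
        lambdas: &[f64],                // hypothetical exit rates, one per node
        pending: &mut Vec<Option<f64>>, // absolute candidate transition times
    ) -> (usize, f64) {
        for (node, slot) in pending.iter_mut().enumerate() {
            if slot.is_none() {
                let x: f64 = rng.gen_range(0.0..=1.0);
                *slot = Some(t + (-x.ln() / lambdas[node]));
            }
        }
        pending
            .iter()
            .enumerate()
            .map(|(node, time)| (node, time.unwrap()))
            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .unwrap()
    }

    fn main() {
        let mut rng = rand::thread_rng();
        let mut pending = vec![None, None];
        let (node, t_next) = next_transition(&mut rng, 0.0, &[1.0, 4.0], &mut pending);
        println!("node {} fires at t = {:.3}", node, t_next);
        // trajectory_generator would now resample the state of `node` and clear
        // pending[node] and the slots of its children before the next round.
    }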

tests/parameter_learning.rs
@@ -46,10 +46,10 @@ fn learn_binary_cim_MLE() {
     let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [2, 2, 2]);
-    assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
-    assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
-    assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2);
-    assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2);
+    assert!(CIM.abs_diff_eq(&arr3(&[
+        [[-1.0, 1.0], [4.0, -4.0]],
+        [[-6.0, 6.0], [2.0, -2.0]],
+    ]), 0.2));
 }
@@ -87,9 +87,48 @@ fn learn_ternary_cim_MLE() {
     let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [3, 3, 3]);
-    assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
-    assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2);
-    assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2);
-    assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2);
+    assert!(CIM.abs_diff_eq(&arr3(&[
+        [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
+        [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
+        [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
+    ]), 0.2));
 }
+
+#[test]
+fn learn_ternary_cim_MLE_no_parents() {
+    let mut net = CtbnNetwork::init();
+    let n1 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n1"),3))
+        .unwrap();
+    let n2 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n2"),3))
+        .unwrap();
+    net.add_edge(n1, n2);
+
+    match &mut net.get_node_mut(n1).params {
+        params::Params::DiscreteStatesContinousTime(param) => {
+            param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
+                                     [1.5, -2.0, 0.5],
+                                     [0.4, 0.6, -1.0]]]));
+        }
+    }
+
+    match &mut net.get_node_mut(n2).params {
+        params::Params::DiscreteStatesContinousTime(param) => {
+            param.cim = Some(arr3(&[
+                [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
+                [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
+                [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
+            ]));
+        }
+    }
+
+    let data = trajectory_generator(Box::new(&net), 100, 200.0);
+    let mle = MLE{};
+    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 0, None);
+    print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
+    assert_eq!(CIM.shape(), [1, 3, 3]);
+    assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
+                                     [1.5, -2.0, 0.5],
+                                     [0.4, 0.6, -1.0]]]), 0.2));
+}

tests/tools.rs
@@ -41,3 +41,5 @@ fn run_sampling() {
     assert_eq!(4, data.trajectories.len());
     assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);
 }
