diff --git a/Cargo.toml b/Cargo.toml index 3b1bb3e..3aa7c53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] -ndarray = "*" +ndarray = {version="*", features=["approx"]} thiserror = "*" rand = "*" bimap = "*" diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 1582fce..3a6557b 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -112,3 +112,44 @@ impl ParameterLearning for MLE { return (CIM, M, T); } } + +pub struct BayesianApproach { + default_alpha: usize, + default_tau: f64 +} + +impl ParameterLearning for BayesianApproach { + fn fit( + &self, + net: Box<&dyn network::Network>, + dataset: &tools::Dataset, + node: usize, + parent_set: Option>, + ) -> (Array3, Array3, Array2) { + //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes + + //Use parent_set from parameter if present. Otherwise use parent_set from network. + let parent_set = match parent_set { + Some(p) => p, + None => net.get_parent_set(node), + }; + + let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); + M.mapv_inplace(|x|{x + self.default_alpha}); + T.mapv_inplace(|x|{x + self.default_tau}); + //Compute the CIM as M[i,x,y]/T[i,x] + let mut CIM: Array3 = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); + CIM.axis_iter_mut(Axis(2)) + .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) + .for_each(|(mut C, m)| C.assign(&(&m/&T))); + + //Set the diagonal of the inner matrices to the the row sum multiplied by -1 + let tmp_diag_sum: Array2 = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); + CIM.outer_iter_mut() + .zip(tmp_diag_sum.outer_iter()) + .for_each(|(mut C, diag)| { + C.diag_mut().assign(&diag); + }); + return (CIM, M, T); + } +} diff --git a/src/params.rs b/src/params.rs index 06164ab..c5a9acf 100644 --- a/src/params.rs +++ b/src/params.rs @@ -102,7 +102,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { Option::Some(cim) => { let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; - let x: f64 = rng.gen_range(0.0..1.0); + let x: f64 = rng.gen_range(0.0..=1.0); Ok(-x.ln() / lambda) } Option::None => Err(ParamsError::ParametersNotInitialized(String::from( @@ -119,15 +119,17 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { Option::Some(cim) => { let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; - let x: f64 = rng.gen_range(0.0..1.0); + let urand: f64 = rng.gen_range(0.0..=1.0); let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold( (0, 0.0), |mut acc, ele| { - if &acc.1 + ele < x && ele > &0.0 { - acc.1 += x; + if &acc.1 + ele < urand && ele > &0.0 { acc.0 += 1; } + if ele > &0.0 { + acc.1 += ele; + } acc }, ); diff --git a/src/tools.rs b/src/tools.rs index 95ee013..a719bb9 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,22 +1,25 @@ -use ndarray::prelude::*; use crate::network; use crate::node; use crate::params; use crate::params::ParamsTrait; +use ndarray::prelude::*; pub struct Trajectory { pub time: Array1, - pub events: Array2 + pub events: Array2, } pub struct Dataset { - pub trajectories: Vec + pub trajectories: Vec, } - -pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset { - let mut dataset = Dataset{ - trajectories: Vec::new() +pub fn trajectory_generator( + net: Box<&dyn network::Network>, + n_trajectories: u64, + t_end: f64, +) -> Dataset { + let mut dataset = Dataset { + trajectories: Vec::new(), }; let node_idx: Vec<_> = net.get_node_indices().collect(); @@ -24,72 +27,93 @@ pub fn trajectory_generator(net: Box<&dyn network::Network>, n_trajectories: u64 let mut t = 0.0; let mut time: Vec = Vec::new(); let mut events: Vec> = Vec::new(); - let mut current_state: Vec = node_idx.iter().map(|x| { - net.get_node(*x).params.get_random_state_uniform() - }).collect(); - let mut next_transitions: Vec> = (0..node_idx.len()).map(|_| Option::None).collect(); - events.push(current_state.iter().map(|x| match x { - params::StateType::Discrete(state) => state.clone() - }).collect()); + let mut current_state: Vec = node_idx + .iter() + .map(|x| net.get_node(*x).params.get_random_state_uniform()) + .collect(); + let mut next_transitions: Vec> = + (0..node_idx.len()).map(|_| Option::None).collect(); + events.push( + current_state + .iter() + .map(|x| match x { + params::StateType::Discrete(state) => state.clone(), + }) + .collect(), + ); time.push(t.clone()); while t < t_end { - for (idx, val) in next_transitions.iter_mut().enumerate(){ + for (idx, val) in next_transitions.iter_mut().enumerate() { if let None = val { - *val = Some(net.get_node(idx).params - .get_random_residence_time(net.get_node(idx).params.state_to_index(¤t_state[idx]), - net.get_param_index_network(idx, ¤t_state)).unwrap() + t); + *val = Some( + net.get_node(idx) + .params + .get_random_residence_time( + net.get_node(idx).params.state_to_index(¤t_state[idx]), + net.get_param_index_network(idx, ¤t_state), + ) + .unwrap() + + t, + ); } - }; + } let next_node_transition = next_transitions .iter() .enumerate() - .min_by(|x, y| - x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) - .unwrap().0; - + .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) + .unwrap() + .0; if next_transitions[next_node_transition].unwrap() > t_end { - break + break; } - t = next_transitions[next_node_transition].unwrap().clone(); time.push(t.clone()); - current_state[next_node_transition] = net.get_node(next_node_transition).params - .get_random_state( - net.get_node(next_node_transition).params. - state_to_index( - ¤t_state[next_node_transition]), - net.get_param_index_network(next_node_transition, ¤t_state)) - .unwrap(); - - - events.push(Array::from_vec(current_state.iter().map(|x| match x { - params::StateType::Discrete(state) => state.clone() - }).collect())); + current_state[next_node_transition] = net + .get_node(next_node_transition) + .params + .get_random_state( + net.get_node(next_node_transition) + .params + .state_to_index(¤t_state[next_node_transition]), + net.get_param_index_network(next_node_transition, ¤t_state), + ) + .unwrap(); + + events.push(Array::from_vec( + current_state + .iter() + .map(|x| match x { + params::StateType::Discrete(state) => state.clone(), + }) + .collect(), + )); next_transitions[next_node_transition] = None; - - for child in net.get_children_set(next_node_transition){ + + for child in net.get_children_set(next_node_transition) { next_transitions[child] = None } - } - events.push(current_state.iter().map(|x| match x { - params::StateType::Discrete(state) => state.clone() - }).collect()); + events.push( + current_state + .iter() + .map(|x| match x { + params::StateType::Discrete(state) => state.clone(), + }) + .collect(), + ); time.push(t_end.clone()); - dataset.trajectories.push(Trajectory { time: Array::from_vec(time), - events: Array2::from_shape_vec((events.len(), current_state.len()), events.iter().flatten().cloned().collect()).unwrap() + events: Array2::from_shape_vec( + (events.len(), current_state.len()), + events.iter().flatten().cloned().collect(), + ) + .unwrap(), }); - - } - dataset } - - diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index a59f0ee..c4b2a67 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -46,10 +46,10 @@ fn learn_binary_cim_MLE() { let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [2, 2, 2]); - assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); - assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); - assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2); - assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2); + assert!(CIM.abs_diff_eq(&arr3(&[ + [[-1.0, 1.0], [4.0, -4.0]], + [[-6.0, 6.0], [2.0, -2.0]], + ]), 0.2)); } @@ -87,9 +87,48 @@ fn learn_ternary_cim_MLE() { let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [3, 3, 3]); - assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); - assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); - assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2); - assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2); + assert!(CIM.abs_diff_eq(&arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ]), 0.2)); } +#[test] +fn learn_ternary_cim_MLE_no_parents() { + let mut net = CtbnNetwork::init(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1).params { + params::Params::DiscreteStatesContinousTime(param) => { + param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0]]])); + } + } + + match &mut net.get_node_mut(n2).params { + params::Params::DiscreteStatesContinousTime(param) => { + param.cim = Some(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])); + } + } + + let data = trajectory_generator(Box::new(&net), 100, 200.0); + let mle = MLE{}; + let (CIM, M, T) = mle.fit(Box::new(&net), &data, 0, None); + print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); + assert_eq!(CIM.shape(), [1, 3, 3]); + assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0]]]), 0.2)); +} diff --git a/tests/tools.rs b/tests/tools.rs index 74e358b..efeef2e 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -41,3 +41,5 @@ fn run_sampling() { assert_eq!(4, data.trajectories.len()); assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); } + +