Implemented amalgamation

pull/73/head
AlessandroBregoli 2 years ago
parent ed5471c7cf
commit 7c3cba50d4
  1. 78
      reCTBN/src/process/ctbn.rs
  2. 36
      reCTBN/tests/ctbn.rs

@ -4,8 +4,11 @@ use std::collections::BTreeSet;
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
use crate::process; use crate::process;
use crate::params::{Params, ParamsTrait, StateType};
use super::ctmp::CtmpProcess;
use super::NetworkProcess;
/// It represents both the structure and the parameters of a CTBN. /// It represents both the structure and the parameters of a CTBN.
/// ///
@ -67,6 +70,79 @@ impl CtbnNetwork {
nodes: Vec::new(), nodes: Vec::new(),
} }
} }
pub fn amalgamation(&self) -> CtmpProcess {
for v in self.nodes.iter() {
match v {
Params::DiscreteStatesContinousTime(_) => {}
_ => panic!("Unsupported node"),
}
}
let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
let state_space = variables_domain.product();
let variables_set = BTreeSet::from_iter(self.get_node_indices());
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space));
for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: Vec<StateType> = current_state
.iter()
.map(|x| StateType::Discrete(*x))
.collect();
for idx_node in 0..self.nodes.len() {
let p = match self.get_node(idx_node) {
Params::DiscreteStatesContinousTime(p) => p,
};
for next_node_state in 0..variables_domain[idx_node] {
let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state;
let next_state_statetype: Vec<StateType> = next_state
.iter()
.map(|x| StateType::Discrete(*x))
.collect();
let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype,
&variables_set,
);
amalgamated_cim[[0, idx_current_state, idx_next_state]] +=
p.get_cim().as_ref().unwrap()[[
self.get_param_index_network(idx_node, &current_state_statetype),
current_state[idx_node],
next_node_state,
]];
}
}
}
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new(
"ctmp".to_string(),
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
);
println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim);
amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new();
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)).unwrap();
return ctmp;
}
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> {
let mut state = state;
let mut array_state = Array1::zeros(variables_domain.shape()[0]);
for (idx, var) in variables_domain.indexed_iter() {
array_state[idx] = state % var;
state = state / var;
}
return array_state;
}
} }
impl process::NetworkProcess for CtbnNetwork { impl process::NetworkProcess for CtbnNetwork {

@ -1,7 +1,10 @@
mod utils; mod utils;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::f64::EPSILON;
use reCTBN::process::ctbn::*; use approx::AbsDiffEq;
use ndarray::arr3;
use reCTBN::process::{ctbn::*, ctmp::*};
use reCTBN::process::NetworkProcess; use reCTBN::process::NetworkProcess;
use reCTBN::params::{self, ParamsTrait}; use reCTBN::params::{self, ParamsTrait};
use utils::generate_discrete_time_continous_node; use utils::generate_discrete_time_continous_node;
@ -129,3 +132,34 @@ fn compute_index_from_custom_parent_set() {
); );
assert_eq!(2, idx); assert_eq!(2, idx);
} }
#[test]
fn simple_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.initialize_adj_matrix();
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
}
}
let ctmp = net.amalgamation();
let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0){
p.get_cim().as_ref().unwrap()
} else {
unreachable!();
};
let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) {
p.get_cim().as_ref().unwrap()
} else {
unreachable!();
};
assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON));
}

Loading…
Cancel
Save