From 7c3cba50d4afb08c1087711ef8fba12a2351ad54 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 11:14:41 +0100 Subject: [PATCH] Implemented amalgamation --- reCTBN/src/process/ctbn.rs | 78 +++++++++++++++++++++++++++++++++++++- reCTBN/tests/ctbn.rs | 36 +++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c59d99d..3852c50 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,8 +4,11 @@ use std::collections::BTreeSet; use ndarray::prelude::*; +use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; use crate::process; -use crate::params::{Params, ParamsTrait, StateType}; + +use super::ctmp::CtmpProcess; +use super::NetworkProcess; /// It represents both the structure and the parameters of a CTBN. /// @@ -67,6 +70,79 @@ impl CtbnNetwork { nodes: Vec::new(), } } + + pub fn amalgamation(&self) -> CtmpProcess { + for v in self.nodes.iter() { + match v { + Params::DiscreteStatesContinousTime(_) => {} + _ => panic!("Unsupported node"), + } + } + + let variables_domain = + Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); + + let state_space = variables_domain.product(); + let variables_set = BTreeSet::from_iter(self.get_node_indices()); + let mut amalgamated_cim: Array3 = Array::zeros((1, state_space, state_space)); + + for idx_current_state in 0..state_space { + let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); + let current_state_statetype: Vec = current_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + for idx_node in 0..self.nodes.len() { + let p = match self.get_node(idx_node) { + Params::DiscreteStatesContinousTime(p) => p, + }; + for next_node_state in 0..variables_domain[idx_node] { + let mut next_state = current_state.clone(); + next_state[idx_node] = next_node_state; + + let next_state_statetype: Vec = next_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + let idx_next_state = self.get_param_index_from_custom_parent_set( + &next_state_statetype, + &variables_set, + ); + amalgamated_cim[[0, idx_current_state, idx_next_state]] += + p.get_cim().as_ref().unwrap()[[ + self.get_param_index_network(idx_node, ¤t_state_statetype), + current_state[idx_node], + next_node_state, + ]]; + } + } + } + + let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( + "ctmp".to_string(), + BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), + ); + + println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim); + + amalgamated_param.set_cim(amalgamated_cim).unwrap(); + + let mut ctmp = CtmpProcess::new(); + + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)).unwrap(); + return ctmp; + } + + pub fn idx_to_state(variables_domain: &Array1, state: usize) -> Array1 { + let mut state = state; + let mut array_state = Array1::zeros(variables_domain.shape()[0]); + for (idx, var) in variables_domain.indexed_iter() { + array_state[idx] = state % var; + state = state / var; + } + + return array_state; + } } impl process::NetworkProcess for CtbnNetwork { diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 0ad0fc4..fc17a94 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,7 +1,10 @@ mod utils; use std::collections::BTreeSet; +use std::f64::EPSILON; -use reCTBN::process::ctbn::*; +use approx::AbsDiffEq; +use ndarray::arr3; +use reCTBN::process::{ctbn::*, ctmp::*}; use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; use utils::generate_discrete_time_continous_node; @@ -129,3 +132,34 @@ fn compute_index_from_custom_parent_set() { ); assert_eq!(2, idx); } + +#[test] +fn simple_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + net.initialize_adj_matrix(); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); + } + } + + let ctmp = net.amalgamation(); + let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0){ + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + + assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); +}