Added more tests

pull/73/head
AlessandroBregoli 2 years ago
parent 7c3cba50d4
commit 4a7a8c5fba
  1. 4
      reCTBN/src/params.rs
  2. 8
      reCTBN/src/process/ctbn.rs
  3. 226
      reCTBN/tests/ctbn.rs

@ -267,11 +267,13 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
))); )));
} }
let domain_size = domain_size as f64;
// Check if each row sum up to 0 // Check if each row sum up to 0
if cim if cim
.sum_axis(Axis(2)) .sum_axis(Axis(2))
.iter() .iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size)
{ {
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0", "The sum of each row must be 0",

@ -71,6 +71,11 @@ impl CtbnNetwork {
} }
} }
///Transform the **CTBN** into a **CTMP**
///
/// # Return
///
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
pub fn amalgamation(&self) -> CtmpProcess { pub fn amalgamation(&self) -> CtmpProcess {
for v in self.nodes.iter() { for v in self.nodes.iter() {
match v { match v {
@ -123,8 +128,7 @@ impl CtbnNetwork {
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
); );
println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim); println!("{:?}", amalgamated_cim);
amalgamated_param.set_cim(amalgamated_cim).unwrap(); amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new(); let mut ctmp = CtmpProcess::new();

@ -4,9 +4,9 @@ use std::f64::EPSILON;
use approx::AbsDiffEq; use approx::AbsDiffEq;
use ndarray::arr3; use ndarray::arr3;
use reCTBN::process::{ctbn::*, ctmp::*};
use reCTBN::process::NetworkProcess;
use reCTBN::params::{self, ParamsTrait}; use reCTBN::params::{self, ParamsTrait};
use reCTBN::process::NetworkProcess;
use reCTBN::process::{ctbn::*, ctmp::*};
use utils::generate_discrete_time_continous_node; use utils::generate_discrete_time_continous_node;
#[test] #[test]
@ -160,6 +160,226 @@ fn simple_amalgamation() {
unreachable!(); unreachable!();
}; };
assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON));
} }
#[test]
fn chain_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) {
p.get_cim().as_ref().unwrap()
} else {
unreachable!();
};
let p_ctmp_handmade = arr3(&[[
[
-1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}
#[test]
fn chainfork_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
let n4 = net
.add_node(generate_discrete_time_continous_node(String::from("n4"), 2))
.unwrap();
net.add_edge(n1, n3);
net.add_edge(n2, n3);
net.add_edge(n3, n4);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n4) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) {
p.get_cim().as_ref().unwrap()
} else {
unreachable!();
};
let p_ctmp_handmade = arr3(&[[
[
-2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
-5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}

Loading…
Cancel
Save