|
|
|
@ -4,9 +4,9 @@ use std::f64::EPSILON; |
|
|
|
|
|
|
|
|
|
use approx::AbsDiffEq; |
|
|
|
|
use ndarray::arr3; |
|
|
|
|
use reCTBN::process::{ctbn::*, ctmp::*}; |
|
|
|
|
use reCTBN::process::NetworkProcess; |
|
|
|
|
use reCTBN::params::{self, ParamsTrait}; |
|
|
|
|
use reCTBN::process::NetworkProcess; |
|
|
|
|
use reCTBN::process::{ctbn::*, ctmp::*}; |
|
|
|
|
use utils::generate_discrete_time_continous_node; |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
@ -160,6 +160,226 @@ fn simple_amalgamation() { |
|
|
|
|
unreachable!(); |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn chain_amalgamation() { |
|
|
|
|
let mut net = CtbnNetwork::new(); |
|
|
|
|
let n1 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
|
|
|
.unwrap(); |
|
|
|
|
let n2 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
|
|
|
|
.unwrap(); |
|
|
|
|
let n3 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
|
|
|
|
.unwrap(); |
|
|
|
|
|
|
|
|
|
net.add_edge(n1, n2); |
|
|
|
|
net.add_edge(n2, n3); |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n1) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n2) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!( |
|
|
|
|
Ok(()), |
|
|
|
|
param.set_cim(arr3(&[ |
|
|
|
|
[[-0.01, 0.01], [5.0, -5.0]], |
|
|
|
|
[[-5.0, 5.0], [0.01, -0.01]] |
|
|
|
|
])) |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n3) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!( |
|
|
|
|
Ok(()), |
|
|
|
|
param.set_cim(arr3(&[ |
|
|
|
|
[[-0.01, 0.01], [5.0, -5.0]], |
|
|
|
|
[[-5.0, 5.0], [0.01, -0.01]] |
|
|
|
|
])) |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
let ctmp = net.amalgamation(); |
|
|
|
|
|
|
|
|
|
let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { |
|
|
|
|
p.get_cim().as_ref().unwrap() |
|
|
|
|
} else { |
|
|
|
|
unreachable!(); |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
let p_ctmp_handmade = arr3(&[[ |
|
|
|
|
[ |
|
|
|
|
-1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00, |
|
|
|
|
], |
|
|
|
|
]]); |
|
|
|
|
|
|
|
|
|
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn chainfork_amalgamation() { |
|
|
|
|
let mut net = CtbnNetwork::new(); |
|
|
|
|
let n1 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
|
|
|
.unwrap(); |
|
|
|
|
let n2 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
|
|
|
|
.unwrap(); |
|
|
|
|
let n3 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
|
|
|
|
.unwrap(); |
|
|
|
|
let n4 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n4"), 2)) |
|
|
|
|
.unwrap(); |
|
|
|
|
|
|
|
|
|
net.add_edge(n1, n3); |
|
|
|
|
net.add_edge(n2, n3); |
|
|
|
|
net.add_edge(n3, n4); |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n1) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n2) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n3) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!( |
|
|
|
|
Ok(()), |
|
|
|
|
param.set_cim(arr3(&[ |
|
|
|
|
[[-0.01, 0.01], [5.0, -5.0]], |
|
|
|
|
[[-0.01, 0.01], [5.0, -5.0]], |
|
|
|
|
[[-0.01, 0.01], [5.0, -5.0]], |
|
|
|
|
[[-5.0, 5.0], [0.01, -0.01]] |
|
|
|
|
])) |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n4) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!( |
|
|
|
|
Ok(()), |
|
|
|
|
param.set_cim(arr3(&[ |
|
|
|
|
[[-0.01, 0.01], [5.0, -5.0]], |
|
|
|
|
[[-5.0, 5.0], [0.01, -0.01]] |
|
|
|
|
])) |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let ctmp = net.amalgamation(); |
|
|
|
|
|
|
|
|
|
let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { |
|
|
|
|
p.get_cim().as_ref().unwrap() |
|
|
|
|
} else { |
|
|
|
|
unreachable!(); |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
let p_ctmp_handmade = arr3(&[[ |
|
|
|
|
[ |
|
|
|
|
-2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, |
|
|
|
|
0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, |
|
|
|
|
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00, |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01, |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01, |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
-5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, |
|
|
|
|
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, |
|
|
|
|
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, |
|
|
|
|
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, |
|
|
|
|
], |
|
|
|
|
[ |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, |
|
|
|
|
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00, |
|
|
|
|
], |
|
|
|
|
]]); |
|
|
|
|
|
|
|
|
|
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); |
|
|
|
|
} |
|
|
|
|