From 430033afdb17a239ada1ccad16f3c32e3ce48234 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 11:20:13 +0100 Subject: [PATCH] Added tests for the learning of parameters using uniform graph and parameters generators as complementary to their handcrafted version --- reCTBN/tests/parameter_learning.rs | 203 +++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 2cbc185..0a09a2a 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -6,6 +6,7 @@ use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; +use reCTBN::params::Params::DiscreteStatesContinousTime; use reCTBN::tools::*; use utils::*; @@ -66,18 +67,78 @@ fn learn_binary_cim(pl: T) { )); } +fn generate_nodes( + net: &mut CtbnNetwork, + nodes_cardinality: usize, + nodes_domain_cardinality: usize +) { + for node_label in 0..nodes_cardinality { + net.add_node( + generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } +} + +fn learn_binary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 2); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_binary_cim_MLE() { let mle = MLE {}; learn_binary_cim(mle); } +#[test] +fn learn_binary_cim_MLE_gen() { + let mle = MLE {}; + learn_binary_cim_gen(mle); +} + #[test] fn learn_binary_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_binary_cim(ba); } +#[test] +fn learn_binary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_binary_cim_gen(ba); +} + fn learn_ternary_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -155,18 +216,63 @@ fn learn_ternary_cim(pl: T) { )); } +fn learn_ternary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 4.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_ternary_cim_MLE() { let mle = MLE {}; learn_ternary_cim(mle); } +#[test] +fn learn_ternary_cim_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_gen(mle); +} + #[test] fn learn_ternary_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim(ba); } +#[test] +fn learn_ternary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_gen(ba); +} + fn learn_ternary_cim_no_parents(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -234,18 +340,63 @@ fn learn_ternary_cim_no_parents(pl: T) { )); } +fn learn_ternary_cim_no_parents_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(0) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 0, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_ternary_cim_no_parents_MLE() { let mle = MLE {}; learn_ternary_cim_no_parents(mle); } +#[test] +fn learn_ternary_cim_no_parents_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_no_parents_gen(mle); +} + #[test] fn learn_ternary_cim_no_parents_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim_no_parents(ba); } +#[test] +fn learn_ternary_cim_no_parents_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_no_parents_gen(ba); +} + fn learn_mixed_discrete_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -432,14 +583,66 @@ fn learn_mixed_discrete_cim(pl: T) { )); } +fn learn_mixed_discrete_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..8.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(2) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 2, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.2 + ) + ); +} + #[test] fn learn_mixed_discrete_cim_MLE() { let mle = MLE {}; learn_mixed_discrete_cim(mle); } +#[test] +fn learn_mixed_discrete_cim_MLE_gen() { + let mle = MLE {}; + learn_mixed_discrete_cim_gen(mle); +} + #[test] fn learn_mixed_discrete_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_mixed_discrete_cim(ba); } + +#[test] +fn learn_mixed_discrete_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_mixed_discrete_cim_gen(ba); +}