From e08d12ac1f243511befbc76c0c35c7dc03efd679 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 09:04:35 +0100 Subject: [PATCH] Added tests for structure learning algorithms using uniform graph and parameters generators as complementary to their handcrafted version --- reCTBN/tests/structure_learning.rs | 171 +++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 9a69b45..3d7e230 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -117,6 +117,50 @@ fn check_compatibility_between_dataset_and_network(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259)); + + let mut net = CtbnNetwork::new(); + let _n1 = net + .add_node( + generate_discrete_time_continous_node(String::from("0"), + 3) + ).unwrap(); + let _net = sl.fit_transform(net, &data); +} + #[test] #[should_panic] pub fn check_compatibility_between_dataset_and_network_hill_climbing() { @@ -125,6 +169,14 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } +#[test] +#[should_panic] +pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + check_compatibility_between_dataset_and_network_gen(hl); +} + fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -182,6 +234,25 @@ fn learn_ternary_net_2_nodes(sl: T) { assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } +fn learn_ternary_net_2_nodes_gen(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +} + #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -189,6 +260,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { learn_ternary_net_2_nodes(hl); } +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_ternary_net_2_nodes_gen(hl); +} + #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { let bic = BIC::new(1, 1.0); @@ -196,6 +274,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { learn_ternary_net_2_nodes(hl); } +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_ternary_net_2_nodes_gen(hl); +} + fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { let mut net = CtbnNetwork::new(); let n1 = net @@ -320,6 +405,30 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } +fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); + return (net, data); +} + fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); @@ -328,6 +437,14 @@ fn learn_mixed_discrete_net_3_nodes(sl: T) { assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); } +fn learn_mixed_discrete_net_3_nodes_gen(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -335,6 +452,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { learn_mixed_discrete_net_3_nodes(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { let bic = BIC::new(1, 1.0); @@ -342,6 +466,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); @@ -350,6 +481,14 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { let ll = LogLikelihood::new(1, 1.0); @@ -357,6 +496,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() { let bic = BIC::new(1, 1.0); @@ -364,6 +510,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + #[test] pub fn chi_square_compare_matrices() { let i: usize = 1; @@ -511,6 +664,15 @@ pub fn learn_ternary_net_2_nodes_ctpc() { learn_ternary_net_2_nodes(ctpc); } +#[test] +pub fn learn_ternary_net_2_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes_gen(ctpc); +} + #[test] fn learn_mixed_discrete_net_3_nodes_ctpc() { let f = F::new(1e-6); @@ -519,3 +681,12 @@ fn learn_mixed_discrete_net_3_nodes_ctpc() { let ctpc = CTPC::new(parameter_learning, f, chi_sq); learn_mixed_discrete_net_3_nodes(ctpc); } + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_mixed_discrete_net_3_nodes_gen(ctpc); +}