From 8d0f9db289b8453bed413d2ba70ec1693b0c376d Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 17:21:53 +0100 Subject: [PATCH] WIP: Added tests for CTPC --- reCTBN/tests/structure_learning.rs | 79 ++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 4bf9027..c0deffd 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -10,6 +10,7 @@ use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::Cache; use reCTBN::params; use reCTBN::structure_learning::hypothesis_test::*; +use reCTBN::structure_learning::constraint_based_algorithm::*; use reCTBN::structure_learning::score_based_algorithm::*; use reCTBN::structure_learning::score_function::*; use reCTBN::structure_learning::StructureLearningAlgorithm; @@ -497,3 +498,81 @@ pub fn f_call() { separation_set.insert(N1); assert!(f.call(&net, N2, N3, &separation_set, &mut cache)); } + +#[test] +pub fn learn_ternary_net_2_nodes_ctpc() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let f = F::new(0.000001); + let chi_sq = ChiSquare::new(0.0001); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let cache = Cache::new(parameter_learning, data.clone()); + let mut ctpc = CTPC::new(f, chi_sq, cache); + + + let net = ctpc.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); + assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); +} + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc() { + let (_, data) = get_mixed_discrete_net_3_nodes_with_data(); + + let f = F::new(1e-24); + let chi_sq = ChiSquare::new(1e-24); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let cache = Cache::new(parameter_learning, data); + let ctpc = CTPC::new(f, chi_sq, cache); + + learn_mixed_discrete_net_3_nodes(ctpc); +}