diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml index b0a691b..fdac697 100644 --- a/reCTBN/Cargo.toml +++ b/reCTBN/Cargo.toml @@ -13,6 +13,7 @@ bimap = "~0.6" enum_dispatch = "~0.3" statrs = "~0.16" rand_chacha = "~0.3" +itertools = "~0.10" [dev-dependencies] approx = { package = "approx", version = "~0.5" } diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index b272e22..d119ab2 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -7,7 +7,7 @@ pub mod score_function; use crate::{process, tools}; pub trait StructureLearningAlgorithm { - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T where T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index d931f78..6fd5b79 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,15 +1,18 @@ //! Module containing constraint based algorithms like CTPC and Hiton. +use itertools::Itertools; +use std::collections::BTreeSet; +use std::usize; + use super::hypothesis_test::*; +use crate::parameter_learning::{Cache, ParameterLearning}; use crate::structure_learning::StructureLearningAlgorithm; use crate::{process, tools}; -use crate::parameter_learning::{Cache, ParameterLearning}; pub struct CTPC { Ftest: F, Chi2test: ChiSquare, cache: Cache

, - } impl CTPC

{ @@ -23,7 +26,7 @@ impl CTPC

{ } impl StructureLearningAlgorithm for CTPC

{ - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T where T: process::NetworkProcess, { @@ -34,6 +37,41 @@ impl StructureLearningAlgorithm for CTPC

{ //Make the network mutable. let mut net = net; + + net.initialize_adj_matrix(); + + for child_node in net.get_node_indices() { + let mut candidate_parent_set: BTreeSet = net + .get_node_indices() + .into_iter() + .filter(|x| x != &child_node) + .collect(); + let mut b = 0; + while b < candidate_parent_set.len() { + for parent_node in candidate_parent_set.iter() { + for separation_set in candidate_parent_set + .iter() + .filter(|x| x != &parent_node) + .map(|x| *x) + .combinations(b) + { + let separation_set = separation_set.into_iter().collect(); + if self.Ftest.call( + &net, + child_node, + *parent_node, + &separation_set, + &mut self.cache, + ) && self.Chi2test.call(&net, child_node, *parent_node, &separation_set, &mut self.cache) { + candidate_parent_set.remove(&parent_node); + break; + } + } + } + b = b + 1; + } + } + net } } diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index 16e9056..d59f0c1 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -21,7 +21,7 @@ impl HillClimbing { } impl StructureLearningAlgorithm for HillClimbing { - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T where T: process::NetworkProcess, { diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 6134510..4bf9027 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -58,7 +58,7 @@ fn simple_bic() { ); } -fn check_compatibility_between_dataset_and_network(sl: T) { +fn check_compatibility_between_dataset_and_network(mut sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -125,7 +125,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } -fn learn_ternary_net_2_nodes(sl: T) { +fn learn_ternary_net_2_nodes(mut sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -320,7 +320,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } -fn learn_mixed_discrete_net_3_nodes(sl: T) { +fn learn_mixed_discrete_net_3_nodes(mut sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -342,7 +342,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } -fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(mut sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0));