Added `itertools` a WIP version of CTPC and some hacky and temporary modifications

pull/79/head
Meliurwen 2 years ago
parent 6e90458418
commit 6d952f8c07
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 1
      reCTBN/Cargo.toml
  2. 2
      reCTBN/src/structure_learning.rs
  3. 44
      reCTBN/src/structure_learning/constraint_based_algorithm.rs
  4. 2
      reCTBN/src/structure_learning/score_based_algorithm.rs
  5. 8
      reCTBN/tests/structure_learning.rs

@ -13,6 +13,7 @@ bimap = "~0.6"
enum_dispatch = "~0.3" enum_dispatch = "~0.3"
statrs = "~0.16" statrs = "~0.16"
rand_chacha = "~0.3" rand_chacha = "~0.3"
itertools = "~0.10"
[dev-dependencies] [dev-dependencies]
approx = { package = "approx", version = "~0.5" } approx = { package = "approx", version = "~0.5" }

@ -7,7 +7,7 @@ pub mod score_function;
use crate::{process, tools}; use crate::{process, tools};
pub trait StructureLearningAlgorithm { pub trait StructureLearningAlgorithm {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
where where
T: process::NetworkProcess; T: process::NetworkProcess;
} }

@ -1,15 +1,18 @@
//! Module containing constraint based algorithms like CTPC and Hiton. //! Module containing constraint based algorithms like CTPC and Hiton.
use itertools::Itertools;
use std::collections::BTreeSet;
use std::usize;
use super::hypothesis_test::*; use super::hypothesis_test::*;
use crate::parameter_learning::{Cache, ParameterLearning};
use crate::structure_learning::StructureLearningAlgorithm; use crate::structure_learning::StructureLearningAlgorithm;
use crate::{process, tools}; use crate::{process, tools};
use crate::parameter_learning::{Cache, ParameterLearning};
pub struct CTPC<P: ParameterLearning> { pub struct CTPC<P: ParameterLearning> {
Ftest: F, Ftest: F,
Chi2test: ChiSquare, Chi2test: ChiSquare,
cache: Cache<P>, cache: Cache<P>,
} }
impl<P: ParameterLearning> CTPC<P> { impl<P: ParameterLearning> CTPC<P> {
@ -23,7 +26,7 @@ impl<P: ParameterLearning> CTPC<P> {
} }
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> { impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
where where
T: process::NetworkProcess, T: process::NetworkProcess,
{ {
@ -34,6 +37,41 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
//Make the network mutable. //Make the network mutable.
let mut net = net; let mut net = net;
net.initialize_adj_matrix();
for child_node in net.get_node_indices() {
let mut candidate_parent_set: BTreeSet<usize> = net
.get_node_indices()
.into_iter()
.filter(|x| x != &child_node)
.collect();
let mut b = 0;
while b < candidate_parent_set.len() {
for parent_node in candidate_parent_set.iter() {
for separation_set in candidate_parent_set
.iter()
.filter(|x| x != &parent_node)
.map(|x| *x)
.combinations(b)
{
let separation_set = separation_set.into_iter().collect();
if self.Ftest.call(
&net,
child_node,
*parent_node,
&separation_set,
&mut self.cache,
) && self.Chi2test.call(&net, child_node, *parent_node, &separation_set, &mut self.cache) {
candidate_parent_set.remove(&parent_node);
break;
}
}
}
b = b + 1;
}
}
net net
} }
} }

@ -21,7 +21,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
} }
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
where where
T: process::NetworkProcess, T: process::NetworkProcess,
{ {

@ -58,7 +58,7 @@ fn simple_bic() {
); );
} }
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) { fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(mut sl: T) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
@ -125,7 +125,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
check_compatibility_between_dataset_and_network(hl); check_compatibility_between_dataset_and_network(hl);
} }
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) { fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(mut sl: T) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
@ -320,7 +320,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
return (net, data); return (net, data);
} }
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) { fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(mut sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data); let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); assert_eq!(BTreeSet::new(), net.get_parent_set(0));
@ -342,7 +342,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
learn_mixed_discrete_net_3_nodes(hl); learn_mixed_discrete_net_3_nodes(hl);
} }
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) { fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(mut sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data); let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); assert_eq!(BTreeSet::new(), net.get_parent_set(0));

Loading…
Cancel
Save