Added `itertools` a WIP version of CTPC and some hacky and temporary modifications

pull/79/head
Meliurwen 2 years ago
parent 6e90458418
commit 6d952f8c07
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 1
      reCTBN/Cargo.toml
  2. 2
      reCTBN/src/structure_learning.rs
  3. 44
      reCTBN/src/structure_learning/constraint_based_algorithm.rs
  4. 2
      reCTBN/src/structure_learning/score_based_algorithm.rs
  5. 8
      reCTBN/tests/structure_learning.rs

@ -13,6 +13,7 @@ bimap = "~0.6"
enum_dispatch = "~0.3"
statrs = "~0.16"
rand_chacha = "~0.3"
itertools = "~0.10"
[dev-dependencies]
approx = { package = "approx", version = "~0.5" }

@ -7,7 +7,7 @@ pub mod score_function;
use crate::{process, tools};
pub trait StructureLearningAlgorithm {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
where
T: process::NetworkProcess;
}

@ -1,15 +1,18 @@
//! Module containing constraint based algorithms like CTPC and Hiton.
use itertools::Itertools;
use std::collections::BTreeSet;
use std::usize;
use super::hypothesis_test::*;
use crate::parameter_learning::{Cache, ParameterLearning};
use crate::structure_learning::StructureLearningAlgorithm;
use crate::{process, tools};
use crate::parameter_learning::{Cache, ParameterLearning};
pub struct CTPC<P: ParameterLearning> {
Ftest: F,
Chi2test: ChiSquare,
cache: Cache<P>,
}
impl<P: ParameterLearning> CTPC<P> {
@ -23,7 +26,7 @@ impl<P: ParameterLearning> CTPC<P> {
}
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
where
T: process::NetworkProcess,
{
@ -34,6 +37,41 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
//Make the network mutable.
let mut net = net;
net.initialize_adj_matrix();
for child_node in net.get_node_indices() {
let mut candidate_parent_set: BTreeSet<usize> = net
.get_node_indices()
.into_iter()
.filter(|x| x != &child_node)
.collect();
let mut b = 0;
while b < candidate_parent_set.len() {
for parent_node in candidate_parent_set.iter() {
for separation_set in candidate_parent_set
.iter()
.filter(|x| x != &parent_node)
.map(|x| *x)
.combinations(b)
{
let separation_set = separation_set.into_iter().collect();
if self.Ftest.call(
&net,
child_node,
*parent_node,
&separation_set,
&mut self.cache,
) && self.Chi2test.call(&net, child_node, *parent_node, &separation_set, &mut self.cache) {
candidate_parent_set.remove(&parent_node);
break;
}
}
}
b = b + 1;
}
}
net
}
}

@ -21,7 +21,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
}
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
where
T: process::NetworkProcess,
{

@ -58,7 +58,7 @@ fn simple_bic() {
);
}
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) {
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(mut sl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
@ -125,7 +125,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
check_compatibility_between_dataset_and_network(hl);
}
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(mut sl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
@ -320,7 +320,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
return (net, data);
}
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(mut sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(0));
@ -342,7 +342,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
learn_mixed_discrete_net_3_nodes(hl);
}
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(mut sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(0));

Loading…
Cancel
Save