Refactored Cache, laying the groundwork for its node-centered implementation and changing its signature; propagated this change and refactored the CTPC tests

Branch: pull/79/head
Author: Meliurwen, 2 years ago
Parent: ea5df7cad6
Commit: 19856195c3
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. reCTBN/src/parameter_learning.rs (28 changes)
  2. reCTBN/src/structure_learning.rs (4 changes)
  3. reCTBN/src/structure_learning/constraint_based_algorithm.rs (20 changes)
  4. reCTBN/src/structure_learning/hypothesis_test.rs (23 changes)
  5. reCTBN/src/structure_learning/score_based_algorithm.rs (4 changes)
  6. reCTBN/tests/structure_learning.rs (112 changes)
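
In short, Cache no longer owns the parameter learner and the dataset: it borrows the learner and receives the dataset on every fit call, while CTPC now owns the learner and builds a fresh cache per node. A minimal before/after sketch of the call sites, assembled from the diff below (names and signatures come from this commit, not from a released API):

    // Before: the cache owned the learner and a clone of the dataset,
    // and had to be built up front and handed to CTPC.
    //   let cache = Cache::new(parameter_learning, data.clone());
    //   let mut ctpc = CTPC::new(f, chi_sq, cache);

    // After: CTPC owns the learner; the dataset is passed at fit time.
    let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
    let ctpc = CTPC::new(parameter_learning, F::new(1e-6), ChiSquare::new(1e-4));
    let net = ctpc.fit_transform(net, &data);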

reCTBN/src/parameter_learning.rs

@@ -5,13 +5,13 @@ use std::collections::BTreeSet;
 use ndarray::prelude::*;
 
 use crate::params::*;
-use crate::{process, tools};
+use crate::{process, tools::Dataset};
 
 pub trait ParameterLearning {
     fn fit<T: process::NetworkProcess>(
         &self,
         net: &T,
-        dataset: &tools::Dataset,
+        dataset: &Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params;
@@ -19,7 +19,7 @@ pub trait ParameterLearning {
 
 pub fn sufficient_statistics<T: process::NetworkProcess>(
     net: &T,
-    dataset: &tools::Dataset,
+    dataset: &Dataset,
     node: usize,
     parent_set: &BTreeSet<usize>,
 ) -> (Array3<usize>, Array2<f64>) {
@@ -76,7 +76,7 @@ impl ParameterLearning for MLE {
     fn fit<T: process::NetworkProcess>(
         &self,
         net: &T,
-        dataset: &tools::Dataset,
+        dataset: &Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
@@ -123,7 +123,7 @@ impl ParameterLearning for BayesianApproach {
     fn fit<T: process::NetworkProcess>(
         &self,
         net: &T,
-        dataset: &tools::Dataset,
+        dataset: &Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
@@ -165,25 +165,21 @@ impl ParameterLearning for BayesianApproach {
     }
 }
 
-pub struct Cache<P: ParameterLearning> {
-    parameter_learning: P,
-    dataset: tools::Dataset,
+pub struct Cache<'a, P: ParameterLearning> {
+    parameter_learning: &'a P,
 }
 
-impl<P: ParameterLearning> Cache<P> {
-    pub fn new(parameter_learning: P, dataset: tools::Dataset) -> Cache<P> {
-        Cache {
-            parameter_learning,
-            dataset,
-        }
+impl<'a, P: ParameterLearning> Cache<'a, P> {
+    pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
+        Cache { parameter_learning }
     }
     pub fn fit<T: process::NetworkProcess>(
         &mut self,
         net: &T,
+        dataset: &Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
-        self.parameter_learning
-            .fit(net, &self.dataset, node, parent_set)
+        self.parameter_learning.fit(net, dataset, node, parent_set)
     }
 }
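
Since Cache now holds only a borrowed &'a P, short-lived caches can share one learner without moving it or cloning the dataset; the dataset travels as an explicit &Dataset argument instead. A sketch of the new usage, assuming a net and data built as in the updated tests below:

    let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
    // The cache only borrows the learner; nothing is moved or cloned.
    let mut cache = Cache::new(&parameter_learning);
    // The dataset is now an argument of every fit call.
    let params = cache.fit(&net, &data, 0, None);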

reCTBN/src/structure_learning.rs

@@ -4,10 +4,10 @@ pub mod constraint_based_algorithm;
 pub mod hypothesis_test;
 pub mod score_based_algorithm;
 pub mod score_function;
 
-use crate::{process, tools};
+use crate::{process, tools::Dataset};
 
 pub trait StructureLearningAlgorithm {
-    fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
+    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
     where
         T: process::NetworkProcess;
 }

reCTBN/src/structure_learning/constraint_based_algorithm.rs

@@ -6,27 +6,28 @@ use std::usize;
 use super::hypothesis_test::*;
 use crate::parameter_learning::{Cache, ParameterLearning};
+use crate::process;
 use crate::structure_learning::StructureLearningAlgorithm;
-use crate::{process, tools};
+use crate::tools::Dataset;
 
 pub struct CTPC<P: ParameterLearning> {
+    parameter_learning: P,
     Ftest: F,
     Chi2test: ChiSquare,
-    cache: Cache<P>,
 }
 
 impl<P: ParameterLearning> CTPC<P> {
-    pub fn new(Ftest: F, Chi2test: ChiSquare, cache: Cache<P>) -> CTPC<P> {
+    pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> {
         CTPC {
-            Chi2test,
+            parameter_learning,
             Ftest,
-            cache,
+            Chi2test,
         }
     }
 }
 
 impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
-    fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
+    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
     where
         T: process::NetworkProcess,
     {
@@ -41,6 +42,7 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
         net.initialize_adj_matrix();
 
         for child_node in net.get_node_indices() {
+            let mut cache = Cache::new(&self.parameter_learning);
             let mut candidate_parent_set: BTreeSet<usize> = net
                 .get_node_indices()
                 .into_iter()
@@ -62,13 +64,15 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
                         child_node,
                         *parent_node,
                         &separation_set,
-                        &mut self.cache,
+                        dataset,
+                        &mut cache,
                     ) && self.Chi2test.call(
                         &net,
                         child_node,
                         *parent_node,
                         &separation_set,
-                        &mut self.cache,
+                        dataset,
+                        &mut cache,
                     ) {
                         candidate_parent_set_TMP.remove(parent_node);
                         break;
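
Note the design choice in the second hunk: the cache is created inside the loop over child nodes, so memoized parameter fits are scoped to a single child node and the borrow of self.parameter_learning ends with each iteration. This is the groundwork for the node-centered cache named in the commit message; restated with comments:

    for child_node in net.get_node_indices() {
        // A fresh, node-local cache per iteration: memoized fits can never
        // mix statistics computed for different child nodes.
        let mut cache = Cache::new(&self.parameter_learning);
        // ...the F-test and chi-square calls receive `dataset` and `&mut cache`...
    }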

reCTBN/src/structure_learning/hypothesis_test.rs

@@ -6,7 +6,7 @@ use ndarray::{Array3, Axis};
 use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
 
 use crate::params::*;
-use crate::{parameter_learning, process};
+use crate::{parameter_learning, process, tools::Dataset};
 
 pub trait HypothesisTest {
     fn call<T, P>(
@@ -15,6 +15,7 @@ pub trait HypothesisTest {
         child_node: usize,
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
+        dataset: &Dataset,
         cache: &mut parameter_learning::Cache<P>,
     ) -> bool
     where
@@ -84,19 +85,25 @@ impl HypothesisTest for F {
         child_node: usize,
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
+        dataset: &Dataset,
         cache: &mut parameter_learning::Cache<P>,
     ) -> bool
     where
         T: process::NetworkProcess,
         P: parameter_learning::ParameterLearning,
     {
-        let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
+        let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
             Params::DiscreteStatesContinousTime(node) => node,
         };
         let mut extended_separation_set = separation_set.clone();
         extended_separation_set.insert(parent_node);
-        let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) {
+        let P_big = match cache.fit(
+            net,
+            &dataset,
+            child_node,
+            Some(extended_separation_set.clone()),
+        ) {
             Params::DiscreteStatesContinousTime(node) => node,
         };
         let partial_cardinality_product: usize = extended_separation_set
@@ -218,6 +225,7 @@ impl HypothesisTest for ChiSquare {
         child_node: usize,
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
+        dataset: &Dataset,
         cache: &mut parameter_learning::Cache<P>,
     ) -> bool
     where
@@ -227,14 +235,19 @@ impl HypothesisTest for ChiSquare {
         // Fetch the learned parameters from the cache; this is a CIM
         // of size n x n
         // (CIM, M, T)
-        let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
+        let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
             Params::DiscreteStatesContinousTime(node) => node,
         };
         //
         let mut extended_separation_set = separation_set.clone();
         extended_separation_set.insert(parent_node);
-        let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) {
+        let P_big = match cache.fit(
+            net,
+            &dataset,
+            child_node,
+            Some(extended_separation_set.clone()),
+        ) {
             Params::DiscreteStatesContinousTime(node) => node,
         };
         // Comment here

reCTBN/src/structure_learning/score_based_algorithm.rs

@@ -4,7 +4,7 @@ use std::collections::BTreeSet;
 use crate::structure_learning::score_function::ScoreFunction;
 use crate::structure_learning::StructureLearningAlgorithm;
-use crate::{process, tools};
+use crate::{process, tools::Dataset};
 
 pub struct HillClimbing<S: ScoreFunction> {
     score_function: S,
@@ -21,7 +21,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
 }
 
 impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
-    fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
+    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
     where
         T: process::NetworkProcess,
     {

reCTBN/tests/structure_learning.rs

@@ -59,7 +59,7 @@ fn simple_bic() {
     );
 }
 
-fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(mut sl: T) {
+fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) {
     let mut net = CtbnNetwork::new();
     let n1 = net
         .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
@@ -126,7 +126,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
     check_compatibility_between_dataset_and_network(hl);
 }
 
-fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(mut sl: T) {
+fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
     let mut net = CtbnNetwork::new();
     let n1 = net
         .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
@@ -321,7 +321,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
     return (net, data);
 }
 
-fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(mut sl: T) {
+fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
     let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
     let net = sl.fit_transform(net, &data);
     assert_eq!(BTreeSet::new(), net.get_parent_set(0));
@@ -343,7 +343,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
     learn_mixed_discrete_net_3_nodes(hl);
 }
 
-fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(mut sl: T) {
+fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
     let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
     let net = sl.fit_transform(net, &data);
     assert_eq!(BTreeSet::new(), net.get_parent_set(0));
@@ -393,7 +393,7 @@ pub fn chi_square_compare_matrices() {
            [ 700, 800, 0]
         ],
     ]);
-    let chi_sq = ChiSquare::new(0.1);
+    let chi_sq = ChiSquare::new(1e-4);
     assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
 }
@@ -423,7 +423,7 @@ pub fn chi_square_compare_matrices_2() {
            [ 400, 0, 600],
            [ 700, 800, 0]]
     ]);
-    let chi_sq = ChiSquare::new(0.1);
+    let chi_sq = ChiSquare::new(1e-4);
     assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
 }
@@ -455,7 +455,7 @@ pub fn chi_square_compare_matrices_3() {
            [ 700, 800, 0]
         ],
     ]);
-    let chi_sq = ChiSquare::new(0.1);
+    let chi_sq = ChiSquare::new(1e-4);
     assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
 }
@@ -469,14 +469,14 @@ pub fn chi_square_call() {
     let N1: usize = 0;
     let mut separation_set = BTreeSet::new();
     let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
-    let mut cache = Cache::new(parameter_learning, data);
-    let chi_sq = ChiSquare::new(0.0001);
+    let mut cache = Cache::new(&parameter_learning);
+    let chi_sq = ChiSquare::new(1e-4);
 
-    assert!(chi_sq.call(&net, N1, N3, &separation_set, &mut cache));
-    assert!(!chi_sq.call(&net, N3, N1, &separation_set, &mut cache));
-    assert!(!chi_sq.call(&net, N3, N2, &separation_set, &mut cache));
+    assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache));
+    assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
+    assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache));
     separation_set.insert(N1);
-    assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache));
+    assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache));
 }
 
 #[test]
@@ -488,91 +488,31 @@ pub fn f_call() {
     let N1: usize = 0;
     let mut separation_set = BTreeSet::new();
     let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
-    let mut cache = Cache::new(parameter_learning, data);
-    let f = F::new(0.000001);
+    let mut cache = Cache::new(&parameter_learning);
+    let f = F::new(1e-6);
 
-    assert!(f.call(&net, N1, N3, &separation_set, &mut cache));
-    assert!(!f.call(&net, N3, N1, &separation_set, &mut cache));
-    assert!(!f.call(&net, N3, N2, &separation_set, &mut cache));
+    assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache));
+    assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache));
+    assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache));
     separation_set.insert(N1);
-    assert!(f.call(&net, N2, N3, &separation_set, &mut cache));
+    assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache));
 }
 
 #[test]
 pub fn learn_ternary_net_2_nodes_ctpc() {
-    let mut net = CtbnNetwork::new();
-    let n1 = net
-        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
-        .unwrap();
-    let n2 = net
-        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
-        .unwrap();
-    net.add_edge(n1, n2);
-    match &mut net.get_node_mut(n1) {
-        params::Params::DiscreteStatesContinousTime(param) => {
-            assert_eq!(
-                Ok(()),
-                param.set_cim(arr3(&[
-                    [
-                        [-3.0, 2.0, 1.0],
-                        [1.5, -2.0, 0.5],
-                        [0.4, 0.6, -1.0]
-                    ],
-                ]))
-            );
-        }
-    }
-    match &mut net.get_node_mut(n2) {
-        params::Params::DiscreteStatesContinousTime(param) => {
-            assert_eq!(
-                Ok(()),
-                param.set_cim(arr3(&[
-                    [
-                        [-1.0, 0.5, 0.5],
-                        [3.0, -4.0, 1.0],
-                        [0.9, 0.1, -1.0]
-                    ],
-                    [
-                        [-6.0, 2.0, 4.0],
-                        [1.5, -2.0, 0.5],
-                        [3.0, 1.0, -4.0]
-                    ],
-                    [
-                        [-1.0, 0.1, 0.9],
-                        [2.0, -2.5, 0.5],
-                        [0.9, 0.1, -1.0]
-                    ],
-                ]))
-            );
-        }
-    }
-
-    let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
-
-    let f = F::new(0.000001);
-    let chi_sq = ChiSquare::new(0.0001);
+    let f = F::new(1e-6);
+    let chi_sq = ChiSquare::new(1e-4);
     let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
-    let cache = Cache::new(parameter_learning, data.clone());
-    let mut ctpc = CTPC::new(f, chi_sq, cache);
-    let net = ctpc.fit_transform(net, &data);
-    assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
-    assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
+    let ctpc = CTPC::new(parameter_learning, f, chi_sq);
+    learn_ternary_net_2_nodes(ctpc);
 }
 
 #[test]
 fn learn_mixed_discrete_net_3_nodes_ctpc() {
-    let (_, data) = get_mixed_discrete_net_3_nodes_with_data();
-
-    let f = F::new(1e-24);
-    let chi_sq = ChiSquare::new(1e-24);
+    let f = F::new(1e-6);
+    let chi_sq = ChiSquare::new(1e-4);
     let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
-    let cache = Cache::new(parameter_learning, data);
-    let ctpc = CTPC::new(f, chi_sq, cache);
+    let ctpc = CTPC::new(parameter_learning, f, chi_sq);
     learn_mixed_discrete_net_3_nodes(ctpc);
 }
