diff --git a/Cargo.toml b/Cargo.toml index 4779b47..56d0452 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] - -ndarray = {version="*", features=["approx"]} +ndarray = {version="*", features=["approx-0_5"]} thiserror = "*" rand = "*" bimap = "*" diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 5270d9e..19c0e4c 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -153,3 +153,20 @@ impl ParameterLearning for BayesianApproach { return (CIM, M, T); } } + + +pub struct Cache { + parameter_learning: P, +} + +impl Cache

{ + pub fn fit( + &mut self, + net: &T, + dataset: &tools::Dataset, + node: usize, + parent_set: Option>, + ) -> (Array3, Array3, Array2) { + self.parameter_learning.fit(net, dataset, node, parent_set) + } +} diff --git a/src/structure_learning.rs b/src/structure_learning.rs index 8ba91df..b7db7ed 100644 --- a/src/structure_learning.rs +++ b/src/structure_learning.rs @@ -1,5 +1,7 @@ pub mod score_function; pub mod score_based_algorithm; +pub mod constraint_based_algorithm; +pub mod hypothesis_test; use crate::network; use crate::tools; diff --git a/src/structure_learning/constraint_based_algorithm.rs b/src/structure_learning/constraint_based_algorithm.rs new file mode 100644 index 0000000..0d8b655 --- /dev/null +++ b/src/structure_learning/constraint_based_algorithm.rs @@ -0,0 +1,5 @@ + +//pub struct CTPC { +// +//} + diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs new file mode 100644 index 0000000..fc5c86f --- /dev/null +++ b/src/structure_learning/hypothesis_test.rs @@ -0,0 +1,101 @@ +use ndarray::Array2; +use ndarray::Array3; +use ndarray::Axis; + +use crate::network; +use crate::parameter_learning; +use std::collections::BTreeSet; + +pub trait HypothesisTest { + + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + cache: parameter_learning::Cache

+ ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning; + +} + + +pub struct ChiSquare { + pub alpha: f64, +} + +pub struct F { + +} + +impl ChiSquare { + pub fn compare_matrices( + &self, i: usize, + M1: &Array3, + j: usize, + M2: &Array3 + ) -> bool { + // Bregoli, A., Scutari, M. and Stella, F., 2021. + // A constraint-based algorithm for the structural learning of + // continuous-time Bayesian networks. + // International Journal of Approximate Reasoning, 138, pp.105-122. + // + // M = M M = M + // 1 xx'|s 2 xx'|y,s + let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); + let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); + // __________________ + // / === + // / \ M + // / / xx'|s + // / === + // / x'ϵVal /X \ + // / \ i/ 1 + //K = / ------------------ L = - + // / === K + // / \ M + // / / xx'|y,s + // / === + // / x'ϵVal /X \ + // \ / \ i/ + // \/ + let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); + let K = K.mapv(f64::sqrt); + // Reshape to column vector. + let K = { + let n = K.len(); + K.into_shape((n, 1)).unwrap() + }; + let L = 1.0 / &K; + // ===== + // \ K . M - L . M + // \ 2 1 + // / --------------- + // / M + M + // ===== 2 1 + // x'ϵVal /X \ + // \ i/ + let X_2 = (( K * &M2 - L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1)).sum_axis(Axis(1)); + println!("X_2: {:?}", X_2); + true + } +} + +impl HypothesisTest for ChiSquare { + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + cache: parameter_learning::Cache

+ ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning { + todo!() + } +} diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index c91f508..be9c8d5 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -7,6 +7,7 @@ use reCTBN::network::Network; use reCTBN::params; use reCTBN::structure_learning::score_function::*; use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm}; +use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::tools::*; use std::collections::BTreeSet; @@ -315,3 +316,25 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() let hl = HillClimbing::new(bic, Some(1)); learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } + +#[test] +pub fn chi_square_compare_matrices () { + let i: usize = 1; + let M1 = arr3(&[ + [[ 1, 2, 3], + [ 4, 5, 6]], + [[ 22, 12, 90], + [3, 20, 40]], + [[ 1, 2, 3], + [ 4, 5, 6]], + [[ 7, 8, 9], + [10, 11, 12]] + ]); + let j: usize = 1; + let M2 = arr3(&[[[ 1, 2, 3], // -- 2 rows \_ + [ 4, 5, 6]], + [[ 7, 8, 9], + [10, 11, 12]]]); + let chi_sq = ChiSquare {alpha: 0.5}; + chi_sq.compare_matrices( i, &M1, j, &M2); +}