Implemented part of matrices comparison in chi square

pull/47/head
Meliurwen 3 years ago
parent 88ad3eba1b
commit 2605bf3816
  1. 3
      Cargo.toml
  2. 17
      src/parameter_learning.rs
  3. 2
      src/structure_learning.rs
  4. 5
      src/structure_learning/constraint_based_algorithm.rs
  5. 101
      src/structure_learning/hypothesis_test.rs
  6. 23
      tests/structure_learning.rs

@ -6,8 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ndarray = {version="*", features=["approx"]}
ndarray = {version="*", features=["approx-0_5"]}
thiserror = "*"
rand = "*"
bimap = "*"

@ -153,3 +153,20 @@ impl ParameterLearning for BayesianApproach {
return (CIM, M, T);
}
}
pub struct Cache<P: ParameterLearning> {
parameter_learning: P,
}
impl<P: ParameterLearning> Cache<P> {
pub fn fit<T:network::Network>(
&mut self,
net: &T,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
self.parameter_learning.fit(net, dataset, node, parent_set)
}
}

@ -1,5 +1,7 @@
pub mod score_function;
pub mod score_based_algorithm;
pub mod constraint_based_algorithm;
pub mod hypothesis_test;
use crate::network;
use crate::tools;

@ -0,0 +1,101 @@
use ndarray::Array2;
use ndarray::Array3;
use ndarray::Axis;
use crate::network;
use crate::parameter_learning;
use std::collections::BTreeSet;
pub trait HypothesisTest {
fn call<T, P>(
&self,
net: &T,
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
cache: parameter_learning::Cache<P>
) -> bool
where
T: network::Network,
P: parameter_learning::ParameterLearning;
}
pub struct ChiSquare {
pub alpha: f64,
}
pub struct F {
}
impl ChiSquare {
pub fn compare_matrices(
&self, i: usize,
M1: &Array3<usize>,
j: usize,
M2: &Array3<usize>
) -> bool {
// Bregoli, A., Scutari, M. and Stella, F., 2021.
// A constraint-based algorithm for the structural learning of
// continuous-time Bayesian networks.
// International Journal of Approximate Reasoning, 138, pp.105-122.
//
// M = M M = M
// 1 xx'|s 2 xx'|y,s
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
// __________________
// / ===
// / \ M
// / / xx'|s
// / ===
// / x'ϵVal /X \
// / \ i/ 1
//K = / ------------------ L = -
// / === K
// / \ M
// / / xx'|y,s
// / ===
// / x'ϵVal /X \
// \ / \ i/
// \/
let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
let K = K.mapv(f64::sqrt);
// Reshape to column vector.
let K = {
let n = K.len();
K.into_shape((n, 1)).unwrap()
};
let L = 1.0 / &K;
// =====
// \ K . M - L . M
// \ 2 1
// / ---------------
// / M + M
// ===== 2 1
// x'ϵVal /X \
// \ i/
let X_2 = (( K * &M2 - L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1)).sum_axis(Axis(1));
println!("X_2: {:?}", X_2);
true
}
}
impl HypothesisTest for ChiSquare {
fn call<T, P>(
&self,
net: &T,
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
cache: parameter_learning::Cache<P>
) -> bool
where
T: network::Network,
P: parameter_learning::ParameterLearning {
todo!()
}
}

@ -7,6 +7,7 @@ use reCTBN::network::Network;
use reCTBN::params;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm};
use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::tools::*;
use std::collections::BTreeSet;
@ -315,3 +316,25 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
let hl = HillClimbing::new(bic, Some(1));
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
}
#[test]
pub fn chi_square_compare_matrices () {
let i: usize = 1;
let M1 = arr3(&[
[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 22, 12, 90],
[3, 20, 40]],
[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]
]);
let j: usize = 1;
let M2 = arr3(&[[[ 1, 2, 3], // -- 2 rows \_
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]]);
let chi_sq = ChiSquare {alpha: 0.5};
chi_sq.compare_matrices( i, &M1, j, &M2);
}

Loading…
Cancel
Save