Fixed alpha and tau in the Bayesian approach #30

Merged
meliurwen merged 1 commits from 28-bug-bayesian-approach into dev 3 years ago
  1. 12
      src/parameter_learning.rs
  2. 16
      tests/parameter_learning.rs

@ -114,8 +114,8 @@ impl ParameterLearning for MLE {
} }
pub struct BayesianApproach { pub struct BayesianApproach {
pub default_alpha: usize, pub alpha: usize,
pub default_tau: f64 pub tau: f64
} }
impl ParameterLearning for BayesianApproach { impl ParameterLearning for BayesianApproach {
@ -135,13 +135,15 @@ impl ParameterLearning for BayesianApproach {
}; };
let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
M.mapv_inplace(|x|{x + self.default_alpha});
T.mapv_inplace(|x|{x + self.default_tau}); let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64;
let tau: f64 = self.tau as f64 / M.shape()[0] as f64;
//Compute the CIM as M[i,x,y]/T[i,x] //Compute the CIM as M[i,x,y]/T[i,x]
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
CIM.axis_iter_mut(Axis(2)) CIM.axis_iter_mut(Axis(2))
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
.for_each(|(mut C, m)| C.assign(&(&m/&T))); .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha)/&T.mapv(|y| y + tau))));
//Set the diagonal of the inner matrices to the the row sum multiplied by -1 //Set the diagonal of the inner matrices to the the row sum multiplied by -1
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);

@ -60,8 +60,8 @@ fn learn_binary_cim_MLE() {
#[test] #[test]
fn learn_binary_cim_BA() { fn learn_binary_cim_BA() {
let ba = BayesianApproach{ let ba = BayesianApproach{
default_alpha: 1, alpha: 1,
default_tau: 1.0}; tau: 1.0};
learn_binary_cim(ba); learn_binary_cim(ba);
} }
@ -115,8 +115,8 @@ fn learn_ternary_cim_MLE() {
#[test] #[test]
fn learn_ternary_cim_BA() { fn learn_ternary_cim_BA() {
let ba = BayesianApproach{ let ba = BayesianApproach{
default_alpha: 1, alpha: 1,
default_tau: 1.0}; tau: 1.0};
learn_ternary_cim(ba); learn_ternary_cim(ba);
} }
@ -168,8 +168,8 @@ fn learn_ternary_cim_no_parents_MLE() {
#[test] #[test]
fn learn_ternary_cim_no_parents_BA() { fn learn_ternary_cim_no_parents_BA() {
let ba = BayesianApproach{ let ba = BayesianApproach{
default_alpha: 1, alpha: 1,
default_tau: 1.0}; tau: 1.0};
learn_ternary_cim_no_parents(ba); learn_ternary_cim_no_parents(ba);
} }
@ -257,7 +257,7 @@ fn learn_mixed_discrete_cim_MLE() {
#[test] #[test]
fn learn_mixed_discrete_cim_BA() { fn learn_mixed_discrete_cim_BA() {
let ba = BayesianApproach{ let ba = BayesianApproach{
default_alpha: 1, alpha: 1,
default_tau: 1.0}; tau: 1.0};
learn_mixed_discrete_cim(ba); learn_mixed_discrete_cim(ba);
} }

Loading…
Cancel
Save