diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 67ea07f..4fe3bdd 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -114,8 +114,8 @@ impl ParameterLearning for MLE { } pub struct BayesianApproach { - pub default_alpha: usize, - pub default_tau: f64 + pub alpha: usize, + pub tau: f64 } impl ParameterLearning for BayesianApproach { @@ -135,13 +135,15 @@ impl ParameterLearning for BayesianApproach { }; let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); - M.mapv_inplace(|x|{x + self.default_alpha}); - T.mapv_inplace(|x|{x + self.default_tau}); + + let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64; + let tau: f64 = self.tau as f64 / M.shape()[0] as f64; + //Compute the CIM as M[i,x,y]/T[i,x] let mut CIM: Array3 = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); CIM.axis_iter_mut(Axis(2)) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) - .for_each(|(mut C, m)| C.assign(&(&m/&T))); + .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha)/&T.mapv(|y| y + tau)))); //Set the diagonal of the inner matrices to the the row sum multiplied by -1 let tmp_diag_sum: Array2 = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index d6b8fd2..345b8d1 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -60,8 +60,8 @@ fn learn_binary_cim_MLE() { #[test] fn learn_binary_cim_BA() { let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; + alpha: 1, + tau: 1.0}; learn_binary_cim(ba); } @@ -115,8 +115,8 @@ fn learn_ternary_cim_MLE() { #[test] fn learn_ternary_cim_BA() { let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; + alpha: 1, + tau: 1.0}; learn_ternary_cim(ba); } @@ -168,8 +168,8 @@ fn learn_ternary_cim_no_parents_MLE() { #[test] fn learn_ternary_cim_no_parents_BA() { let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; + alpha: 1, + tau: 1.0}; learn_ternary_cim_no_parents(ba); } @@ -257,7 +257,7 @@ fn learn_mixed_discrete_cim_MLE() { #[test] fn learn_mixed_discrete_cim_BA() { let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; + alpha: 1, + tau: 1.0}; learn_mixed_discrete_cim(ba); }