From 62fcbd466a08d7dead12c48a40c8200f3fa57698 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 12 Apr 2022 09:33:07 +0200 Subject: [PATCH] Removed `rand::thread_rng` overriding the ChaCha's `rng`, increased the epsilon from 0.2 to 0.3 in the tests --- src/params.rs | 3 --- tests/parameter_learning.rs | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/params.rs b/src/params.rs index 6173d75..f0e5efa 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,7 +1,6 @@ use enum_dispatch::enum_dispatch; use ndarray::prelude::*; use rand::Rng; -use rand::rngs::ThreadRng; use std::collections::{BTreeSet, HashMap}; use thiserror::Error; use rand_chacha::ChaCha8Rng; @@ -149,7 +148,6 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates match &self.cim { Option::Some(cim) => { - let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; let x: f64 = rng.gen_range(0.0..=1.0); Ok(-x.ln() / lambda) @@ -166,7 +164,6 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution match &self.cim { Option::Some(cim) => { - let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; let urand: f64 = rng.gen_range(0.0..=1.0); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index af57291..adff6e8 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -47,7 +47,7 @@ fn learn_binary_cim (pl: T) { assert!(CIM.abs_diff_eq(&arr3(&[ [[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]], - ]), 0.2)); + ]), 0.3)); } #[test] @@ -101,7 +101,7 @@ fn learn_ternary_cim (pl: T) { [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]), 0.2)); + ]), 0.3)); } @@ -154,7 +154,7 @@ fn learn_ternary_cim_no_parents (pl: T) { assert_eq!(CIM.shape(), [1, 3, 3]); assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]), 0.2)); + [0.4, 0.6, -1.0]]]), 0.3)); } @@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim (pl: T) { [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]), 0.2)); + ]), 0.3)); } #[test]