From a01a9ef20107983667cc2c30f627c8fcf3662df5 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 31 Jan 2023 13:16:52 +0100 Subject: [PATCH] Recomputing the diagonal when generating parameters to counter the precision loss and increase `f64::EPSILON` calculating its square root instead of multiplying it with the node's `domain_size` --- reCTBN/src/params.rs | 4 +--- reCTBN/src/tools.rs | 5 ++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 9f63860..dc941e5 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,13 +267,11 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } - let domain_size = domain_size as f64; - // Check if each row sum up to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) + .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 344c66c..e9b9fd8 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -203,7 +203,10 @@ impl RandomParametersGenerator for UniformParametersGenerator { self.rng.gen_range(self.interval.clone()) }); x.mul_assign(&diag.clone().insert_axis(Axis(1))); - x.diag_mut().assign(&-diag) + // Recomputing the diagonal in order to reduce the issues caused by the loss of + // precision when validating the parameters. + let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) }); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => {