diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 9f63860..dc941e5 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,13 +267,11 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } - let domain_size = domain_size as f64; - // Check if each row sum up to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) + .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 344c66c..e9b9fd8 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -203,7 +203,10 @@ impl RandomParametersGenerator for UniformParametersGenerator { self.rng.gen_range(self.interval.clone()) }); x.mul_assign(&diag.clone().insert_axis(Axis(1))); - x.diag_mut().assign(&-diag) + // Recomputing the diagonal in order to reduce the issues caused by the loss of + // precision when validating the parameters. + let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) }); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => {