Recomputing the diagonal when generating parameters to counter the precision loss and increase `f64::EPSILON` calculating its square root instead of multiplying it with the node's `domain_size`

pull/85/head
Meliurwen 2 years ago
parent 097dc25030
commit a01a9ef201
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 4
      reCTBN/src/params.rs
  2. 5
      reCTBN/src/tools.rs

@ -267,13 +267,11 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
))); )));
} }
let domain_size = domain_size as f64;
// Check if each row sum up to 0 // Check if each row sum up to 0
if cim if cim
.sum_axis(Axis(2)) .sum_axis(Axis(2))
.iter() .iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
{ {
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0", "The sum of each row must be 0",

@ -203,7 +203,10 @@ impl RandomParametersGenerator for UniformParametersGenerator {
self.rng.gen_range(self.interval.clone()) self.rng.gen_range(self.interval.clone())
}); });
x.mul_assign(&diag.clone().insert_axis(Axis(1))); x.mul_assign(&diag.clone().insert_axis(Axis(1)));
x.diag_mut().assign(&-diag) // Recomputing the diagonal in order to reduce the issues caused by the loss of
// precision when validating the parameters.
let diag_sum = -x.sum_axis(Axis(1));
x.diag_mut().assign(&diag_sum)
}); });
match &mut net.get_node_mut(node) { match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {

Loading…
Cancel
Save