Refactored `generate_parameters` moving some code inside `match` statement

pull/85/head
Meliurwen 2 years ago
parent 4884010ea9
commit 0639a755d0
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 10
      reCTBN/src/tools.rs

@ -325,7 +325,9 @@ impl RandomParametersGenerator for UniformParametersGenerator {
.iter() .iter()
.map(|x| net.get_node(*x).get_reserved_space_as_parent()) .map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product(); .product();
let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => {
let node_domain_cardinality = param.get_reserved_space_as_parent();
let mut cim = Array3::<f64>::from_shape_fn( let mut cim = Array3::<f64>::from_shape_fn(
( (
parent_set_state_space_cardinality, parent_set_state_space_cardinality,
@ -341,13 +343,11 @@ impl RandomParametersGenerator for UniformParametersGenerator {
self.rng.gen_range(self.interval.clone()) self.rng.gen_range(self.interval.clone())
}); });
x.mul_assign(&diag.clone().insert_axis(Axis(1))); x.mul_assign(&diag.clone().insert_axis(Axis(1)));
// Recomputing the diagonal in order to reduce the issues caused by the loss of // Recomputing the diagonal in order to reduce the issues caused by the
// precision when validating the parameters. // loss of precision when validating the parameters.
let diag_sum = -x.sum_axis(Axis(1)); let diag_sum = -x.sum_axis(Axis(1));
x.diag_mut().assign(&diag_sum) x.diag_mut().assign(&diag_sum)
}); });
match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim_unchecked(cim); param.set_cim_unchecked(cim);
} }
} }

Loading…
Cancel
Save