Refactored `generate_parameters` moving some code inside `match` statement

pull/85/head
Meliurwen 2 years ago
parent 4884010ea9
commit 0639a755d0
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 42
      reCTBN/src/tools.rs

@ -325,29 +325,29 @@ impl RandomParametersGenerator for UniformParametersGenerator {
.iter() .iter()
.map(|x| net.get_node(*x).get_reserved_space_as_parent()) .map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product(); .product();
let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent();
let mut cim = Array3::<f64>::from_shape_fn(
(
parent_set_state_space_cardinality,
node_domain_cardinality,
node_domain_cardinality,
),
|_| self.rng.gen(),
);
cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
x.diag_mut().fill(0.0);
x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1)));
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
self.rng.gen_range(self.interval.clone())
});
x.mul_assign(&diag.clone().insert_axis(Axis(1)));
// Recomputing the diagonal in order to reduce the issues caused by the loss of
// precision when validating the parameters.
let diag_sum = -x.sum_axis(Axis(1));
x.diag_mut().assign(&diag_sum)
});
match &mut net.get_node_mut(node) { match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
let node_domain_cardinality = param.get_reserved_space_as_parent();
let mut cim = Array3::<f64>::from_shape_fn(
(
parent_set_state_space_cardinality,
node_domain_cardinality,
node_domain_cardinality,
),
|_| self.rng.gen(),
);
cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
x.diag_mut().fill(0.0);
x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1)));
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
self.rng.gen_range(self.interval.clone())
});
x.mul_assign(&diag.clone().insert_axis(Axis(1)));
// Recomputing the diagonal in order to reduce the issues caused by the
// loss of precision when validating the parameters.
let diag_sum = -x.sum_axis(Axis(1));
x.diag_mut().assign(&diag_sum)
});
param.set_cim_unchecked(cim); param.set_cim_unchecked(cim);
} }
} }

Loading…
Cancel
Save