|
|
|
@ -325,29 +325,29 @@ impl RandomParametersGenerator for UniformParametersGenerator { |
|
|
|
|
.iter() |
|
|
|
|
.map(|x| net.get_node(*x).get_reserved_space_as_parent()) |
|
|
|
|
.product(); |
|
|
|
|
let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); |
|
|
|
|
let mut cim = Array3::<f64>::from_shape_fn( |
|
|
|
|
( |
|
|
|
|
parent_set_state_space_cardinality, |
|
|
|
|
node_domain_cardinality, |
|
|
|
|
node_domain_cardinality, |
|
|
|
|
), |
|
|
|
|
|_| self.rng.gen(), |
|
|
|
|
); |
|
|
|
|
cim.axis_iter_mut(Axis(0)).for_each(|mut x| { |
|
|
|
|
x.diag_mut().fill(0.0); |
|
|
|
|
x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); |
|
|
|
|
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| { |
|
|
|
|
self.rng.gen_range(self.interval.clone()) |
|
|
|
|
}); |
|
|
|
|
x.mul_assign(&diag.clone().insert_axis(Axis(1))); |
|
|
|
|
// Recomputing the diagonal in order to reduce the issues caused by the loss of
|
|
|
|
|
// precision when validating the parameters.
|
|
|
|
|
let diag_sum = -x.sum_axis(Axis(1)); |
|
|
|
|
x.diag_mut().assign(&diag_sum) |
|
|
|
|
}); |
|
|
|
|
match &mut net.get_node_mut(node) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
let node_domain_cardinality = param.get_reserved_space_as_parent(); |
|
|
|
|
let mut cim = Array3::<f64>::from_shape_fn( |
|
|
|
|
( |
|
|
|
|
parent_set_state_space_cardinality, |
|
|
|
|
node_domain_cardinality, |
|
|
|
|
node_domain_cardinality, |
|
|
|
|
), |
|
|
|
|
|_| self.rng.gen(), |
|
|
|
|
); |
|
|
|
|
cim.axis_iter_mut(Axis(0)).for_each(|mut x| { |
|
|
|
|
x.diag_mut().fill(0.0); |
|
|
|
|
x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); |
|
|
|
|
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| { |
|
|
|
|
self.rng.gen_range(self.interval.clone()) |
|
|
|
|
}); |
|
|
|
|
x.mul_assign(&diag.clone().insert_axis(Axis(1))); |
|
|
|
|
// Recomputing the diagonal in order to reduce the issues caused by the
|
|
|
|
|
// loss of precision when validating the parameters.
|
|
|
|
|
let diag_sum = -x.sum_axis(Axis(1)); |
|
|
|
|
x.diag_mut().assign(&diag_sum) |
|
|
|
|
}); |
|
|
|
|
param.set_cim_unchecked(cim); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|