diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 89c19a9..0a48410 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -325,29 +325,29 @@ impl RandomParametersGenerator for UniformParametersGenerator { .iter() .map(|x| net.get_node(*x).get_reserved_space_as_parent()) .product(); - let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); - let mut cim = Array3::::from_shape_fn( - ( - parent_set_state_space_cardinality, - node_domain_cardinality, - node_domain_cardinality, - ), - |_| self.rng.gen(), - ); - cim.axis_iter_mut(Axis(0)).for_each(|mut x| { - x.diag_mut().fill(0.0); - x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); - let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { - self.rng.gen_range(self.interval.clone()) - }); - x.mul_assign(&diag.clone().insert_axis(Axis(1))); - // Recomputing the diagonal in order to reduce the issues caused by the loss of - // precision when validating the parameters. - let diag_sum = -x.sum_axis(Axis(1)); - x.diag_mut().assign(&diag_sum) - }); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => { + let node_domain_cardinality = param.get_reserved_space_as_parent(); + let mut cim = Array3::::from_shape_fn( + ( + parent_set_state_space_cardinality, + node_domain_cardinality, + node_domain_cardinality, + ), + |_| self.rng.gen(), + ); + cim.axis_iter_mut(Axis(0)).for_each(|mut x| { + x.diag_mut().fill(0.0); + x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); + let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { + self.rng.gen_range(self.interval.clone()) + }); + x.mul_assign(&diag.clone().insert_axis(Axis(1))); + // Recomputing the diagonal in order to reduce the issues caused by the + // loss of precision when validating the parameters. + let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) + }); param.set_cim_unchecked(cim); } }