diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 7c438d5..344c66c 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -179,23 +179,15 @@ impl RandomParametersGenerator for UniformParametersGenerator { }; UniformParametersGenerator { interval, rng } } + fn generate_parameters(&mut self, net: &mut T) { for node in net.get_node_indices() { - let parent_set = net.get_parent_set(node); - let parent_set_state_space_cardinality: usize = parent_set + let parent_set_state_space_cardinality: usize = net + .get_parent_set(node) .iter() .map(|x| net.get_node(*x).get_reserved_space_as_parent()) .product(); - println!( - "parent_set_state_space_cardinality = {}", - parent_set_state_space_cardinality - ); let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); - println!("node_domain_cardinality = {}", node_domain_cardinality); - let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64) - ..=(self.interval.end / node_domain_cardinality as f64); - println!("cim_single_param_range = {:?}", cim_single_param_range); - let mut cim = Array3::::from_shape_fn( ( parent_set_state_space_cardinality, @@ -204,38 +196,18 @@ impl RandomParametersGenerator for UniformParametersGenerator { ), |_| self.rng.gen(), ); - - //let diagonal = cim.axis_iter(Axis(0)); - cim.axis_iter_mut(Axis(0)) - .for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}"))); cim.axis_iter_mut(Axis(0)).for_each(|mut x| { x.diag_mut().fill(0.0); - let sum_axis = x.sum_axis(Axis(0)); - //let division = 1.0 / &sum_axis; - x.div_assign(&sum_axis); - println!("{}", x); + x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { self.rng.gen_range(self.interval.clone()) }); - x.mul_assign(&diag); - println!("{}", x); + x.mul_assign(&diag.clone().insert_axis(Axis(1))); x.diag_mut().assign(&-diag) }); - cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.diag().iter().for_each(|x| println!("{x}"))); - - println!("Sum Axis"); - cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}"))); - println!("Matrices"); - cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.iter().for_each(|x| println!("{}", x))); - //.any(|x| x.diag().iter().any(|x| x >= &0.0)) - - //println!("{:?}", diagonal); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(cim)); + param.set_cim_unchecked(cim); } } } diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index e91cf04..f04fb2a 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -149,17 +149,30 @@ fn uniform_parameters_generator_wrong_density_2() { #[test] fn uniform_parameters_generator_right_densities() { let mut net = CtbnNetwork::new(); - for node_label in 0..3 { + let nodes_cardinality = 0..5; + let nodes_domain_cardinality = 9; + for node_label in nodes_cardinality { net.add_node( utils::generate_discrete_time_continous_node( node_label.to_string(), - 9, + nodes_domain_cardinality, ) ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, seed); structure_generator.generate_graph(&mut net); - let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120)); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed); cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + match &mut net.get_node_mut(node) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(param.get_cim().clone().unwrap())); + } + } + } }