Implemented `UniformParametersGenerator` and its test

pull/85/head
Meliurwen 2 years ago
parent d6f0fb9623
commit f4e3c98c79
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 40
      reCTBN/src/tools.rs
  2. 21
      reCTBN/tests/tools.rs

@ -179,23 +179,15 @@ impl RandomParametersGenerator for UniformParametersGenerator {
};
UniformParametersGenerator { interval, rng }
}
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) {
for node in net.get_node_indices() {
let parent_set = net.get_parent_set(node);
let parent_set_state_space_cardinality: usize = parent_set
let parent_set_state_space_cardinality: usize = net
.get_parent_set(node)
.iter()
.map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product();
println!(
"parent_set_state_space_cardinality = {}",
parent_set_state_space_cardinality
);
let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent();
println!("node_domain_cardinality = {}", node_domain_cardinality);
let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64)
..=(self.interval.end / node_domain_cardinality as f64);
println!("cim_single_param_range = {:?}", cim_single_param_range);
let mut cim = Array3::<f64>::from_shape_fn(
(
parent_set_state_space_cardinality,
@ -204,38 +196,18 @@ impl RandomParametersGenerator for UniformParametersGenerator {
),
|_| self.rng.gen(),
);
//let diagonal = cim.axis_iter(Axis(0));
cim.axis_iter_mut(Axis(0))
.for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}")));
cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
x.diag_mut().fill(0.0);
let sum_axis = x.sum_axis(Axis(0));
//let division = 1.0 / &sum_axis;
x.div_assign(&sum_axis);
println!("{}", x);
x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1)));
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
self.rng.gen_range(self.interval.clone())
});
x.mul_assign(&diag);
println!("{}", x);
x.mul_assign(&diag.clone().insert_axis(Axis(1)));
x.diag_mut().assign(&-diag)
});
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.diag().iter().for_each(|x| println!("{x}")));
println!("Sum Axis");
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}")));
println!("Matrices");
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.iter().for_each(|x| println!("{}", x)));
//.any(|x| x.diag().iter().any(|x| x >= &0.0))
//println!("{:?}", diagonal);
match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(cim));
param.set_cim_unchecked(cim);
}
}
}

@ -149,17 +149,30 @@ fn uniform_parameters_generator_wrong_density_2() {
#[test]
fn uniform_parameters_generator_right_densities() {
let mut net = CtbnNetwork::new();
for node_label in 0..3 {
let nodes_cardinality = 0..5;
let nodes_domain_cardinality = 9;
for node_label in nodes_cardinality {
net.add_node(
utils::generate_discrete_time_continous_node(
node_label.to_string(),
9,
nodes_domain_cardinality,
)
).unwrap();
}
let density = 1.0/3.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120));
let seed = Some(7641630759785120);
let interval = 0.0..7.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, seed);
structure_generator.generate_graph(&mut net);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120));
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed);
cim_generator.generate_parameters(&mut net);
for node in net.get_node_indices() {
match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(param.get_cim().clone().unwrap()));
}
}
}
}

Loading…
Cancel
Save