Implemented `UniformParametersGenerator` and its test

pull/85/head
Meliurwen 2 years ago
parent d6f0fb9623
commit f4e3c98c79
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 40
      reCTBN/src/tools.rs
  2. 21
      reCTBN/tests/tools.rs

@ -179,23 +179,15 @@ impl RandomParametersGenerator for UniformParametersGenerator {
}; };
UniformParametersGenerator { interval, rng } UniformParametersGenerator { interval, rng }
} }
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) { fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) {
for node in net.get_node_indices() { for node in net.get_node_indices() {
let parent_set = net.get_parent_set(node); let parent_set_state_space_cardinality: usize = net
let parent_set_state_space_cardinality: usize = parent_set .get_parent_set(node)
.iter() .iter()
.map(|x| net.get_node(*x).get_reserved_space_as_parent()) .map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product(); .product();
println!(
"parent_set_state_space_cardinality = {}",
parent_set_state_space_cardinality
);
let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent();
println!("node_domain_cardinality = {}", node_domain_cardinality);
let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64)
..=(self.interval.end / node_domain_cardinality as f64);
println!("cim_single_param_range = {:?}", cim_single_param_range);
let mut cim = Array3::<f64>::from_shape_fn( let mut cim = Array3::<f64>::from_shape_fn(
( (
parent_set_state_space_cardinality, parent_set_state_space_cardinality,
@ -204,38 +196,18 @@ impl RandomParametersGenerator for UniformParametersGenerator {
), ),
|_| self.rng.gen(), |_| self.rng.gen(),
); );
//let diagonal = cim.axis_iter(Axis(0));
cim.axis_iter_mut(Axis(0))
.for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}")));
cim.axis_iter_mut(Axis(0)).for_each(|mut x| { cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
x.diag_mut().fill(0.0); x.diag_mut().fill(0.0);
let sum_axis = x.sum_axis(Axis(0)); x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1)));
//let division = 1.0 / &sum_axis;
x.div_assign(&sum_axis);
println!("{}", x);
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| { let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
self.rng.gen_range(self.interval.clone()) self.rng.gen_range(self.interval.clone())
}); });
x.mul_assign(&diag); x.mul_assign(&diag.clone().insert_axis(Axis(1)));
println!("{}", x);
x.diag_mut().assign(&-diag) x.diag_mut().assign(&-diag)
}); });
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.diag().iter().for_each(|x| println!("{x}")));
println!("Sum Axis");
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}")));
println!("Matrices");
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.iter().for_each(|x| println!("{}", x)));
//.any(|x| x.diag().iter().any(|x| x >= &0.0))
//println!("{:?}", diagonal);
match &mut net.get_node_mut(node) { match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(cim)); param.set_cim_unchecked(cim);
} }
} }
} }

@ -149,17 +149,30 @@ fn uniform_parameters_generator_wrong_density_2() {
#[test] #[test]
fn uniform_parameters_generator_right_densities() { fn uniform_parameters_generator_right_densities() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
for node_label in 0..3 { let nodes_cardinality = 0..5;
let nodes_domain_cardinality = 9;
for node_label in nodes_cardinality {
net.add_node( net.add_node(
utils::generate_discrete_time_continous_node( utils::generate_discrete_time_continous_node(
node_label.to_string(), node_label.to_string(),
9, nodes_domain_cardinality,
) )
).unwrap(); ).unwrap();
} }
let density = 1.0/3.0; let density = 1.0/3.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); let seed = Some(7641630759785120);
let interval = 0.0..7.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, seed);
structure_generator.generate_graph(&mut net); structure_generator.generate_graph(&mut net);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120)); let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed);
cim_generator.generate_parameters(&mut net); cim_generator.generate_parameters(&mut net);
for node in net.get_node_indices() {
match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(param.get_cim().clone().unwrap()));
}
}
}
} }

Loading…
Cancel
Save