From d6f0fb9623b16187bf4341707797e1b094e6eeba Mon Sep 17 00:00:00 2001 From: meliurwen Date: Sun, 29 Jan 2023 16:44:13 +0100 Subject: [PATCH] WIP: implementing `UniformParametersGenerator` --- reCTBN/src/tools.rs | 92 ++++++++++++++++++++++++++++++++++++++++++- reCTBN/tests/tools.rs | 47 ++++++++++++++++++++-- 2 files changed, 135 insertions(+), 4 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 9222239..7c438d5 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -1,9 +1,12 @@ //! Contains commonly used methods used across the crate. -use ndarray::prelude::*; +use std::ops::{DivAssign, MulAssign, Range}; + +use ndarray::{Array, Array1, Array2, Array3, Axis}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; +use crate::params::ParamsTrait; use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -151,3 +154,90 @@ impl RandomGraphGenerator for UniformGraphGenerator { } } } + +pub trait RandomParametersGenerator { + fn new(interval: Range, seed: Option) -> Self; + fn generate_parameters(&mut self, net: &mut T); +} + +pub struct UniformParametersGenerator { + interval: Range, + rng: ChaCha8Rng, +} + +impl RandomParametersGenerator for UniformParametersGenerator { + fn new(interval: Range, seed: Option) -> UniformParametersGenerator { + if interval.start < 0.0 || interval.end < 0.0 { + panic!( + "Interval must be entirely less or equal than 0, got {}..{}.", + interval.start, interval.end + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + UniformParametersGenerator { interval, rng } + } + fn generate_parameters(&mut self, net: &mut T) { + for node in net.get_node_indices() { + let parent_set = net.get_parent_set(node); + let parent_set_state_space_cardinality: usize = parent_set + .iter() + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + println!( + "parent_set_state_space_cardinality = {}", + parent_set_state_space_cardinality + ); + let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); + println!("node_domain_cardinality = {}", node_domain_cardinality); + let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64) + ..=(self.interval.end / node_domain_cardinality as f64); + println!("cim_single_param_range = {:?}", cim_single_param_range); + + let mut cim = Array3::::from_shape_fn( + ( + parent_set_state_space_cardinality, + node_domain_cardinality, + node_domain_cardinality, + ), + |_| self.rng.gen(), + ); + + //let diagonal = cim.axis_iter(Axis(0)); + cim.axis_iter_mut(Axis(0)) + .for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}"))); + cim.axis_iter_mut(Axis(0)).for_each(|mut x| { + x.diag_mut().fill(0.0); + let sum_axis = x.sum_axis(Axis(0)); + //let division = 1.0 / &sum_axis; + x.div_assign(&sum_axis); + println!("{}", x); + let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { + self.rng.gen_range(self.interval.clone()) + }); + x.mul_assign(&diag); + println!("{}", x); + x.diag_mut().assign(&-diag) + }); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.diag().iter().for_each(|x| println!("{x}"))); + + println!("Sum Axis"); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}"))); + println!("Matrices"); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.iter().for_each(|x| println!("{}", x))); + //.any(|x| x.diag().iter().any(|x| x >= &0.0)) + + //println!("{:?}", diagonal); + match &mut net.get_node_mut(node) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(cim)); + } + } + } + } +} diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 9a96959..e91cf04 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,3 +1,5 @@ +use std::ops::Range; + use ndarray::{arr1, arr2, arr3}; use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; @@ -85,20 +87,27 @@ fn dataset_wrong_shape() { #[test] #[should_panic] -fn uniform_random_generator_wrong_density() { +fn uniform_graph_generator_wrong_density_1() { let density = 2.1; let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } #[test] -fn uniform_random_generator_right_densities() { +#[should_panic] +fn uniform_graph_generator_wrong_density_2() { + let density = -0.5; + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); +} + +#[test] +fn uniform_graph_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } } #[test] -fn uniform_random_generator_generate_graph() { +fn uniform_graph_generator_generate_graph() { let mut net = CtbnNetwork::new(); for node_label in 0..100 { net.add_node( @@ -122,3 +131,35 @@ fn uniform_random_generator_generate_graph() { // expect the number of edges to be somewhere around the expected value. assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); } + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_1() { + let interval: Range = -2.0..-5.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); +} + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_2() { + let interval: Range = -1.0..0.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); +} + +#[test] +fn uniform_parameters_generator_right_densities() { + let mut net = CtbnNetwork::new(); + for node_label in 0..3 { + net.add_node( + utils::generate_discrete_time_continous_node( + node_label.to_string(), + 9, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + structure_generator.generate_graph(&mut net); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120)); + cim_generator.generate_parameters(&mut net); +}