WIP: implementing `UniformParametersGenerator`

pull/85/head
Meliurwen 2 years ago
parent 434e671f0a
commit d6f0fb9623
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 92
      reCTBN/src/tools.rs
  2. 47
      reCTBN/tests/tools.rs

@ -1,9 +1,12 @@
//! Contains commonly used methods used across the crate. //! Contains commonly used methods used across the crate.
use ndarray::prelude::*; use std::ops::{DivAssign, MulAssign, Range};
use ndarray::{Array, Array1, Array2, Array3, Axis};
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
use crate::params::ParamsTrait;
use crate::process::NetworkProcess; use crate::process::NetworkProcess;
use crate::sampling::{ForwardSampler, Sampler}; use crate::sampling::{ForwardSampler, Sampler};
use crate::{params, process}; use crate::{params, process};
@ -151,3 +154,90 @@ impl RandomGraphGenerator for UniformGraphGenerator {
} }
} }
} }
pub trait RandomParametersGenerator {
fn new(interval: Range<f64>, seed: Option<u64>) -> Self;
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T);
}
pub struct UniformParametersGenerator {
interval: Range<f64>,
rng: ChaCha8Rng,
}
impl RandomParametersGenerator for UniformParametersGenerator {
fn new(interval: Range<f64>, seed: Option<u64>) -> UniformParametersGenerator {
if interval.start < 0.0 || interval.end < 0.0 {
panic!(
"Interval must be entirely less or equal than 0, got {}..{}.",
interval.start, interval.end
);
}
let rng: ChaCha8Rng = match seed {
Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(),
};
UniformParametersGenerator { interval, rng }
}
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) {
for node in net.get_node_indices() {
let parent_set = net.get_parent_set(node);
let parent_set_state_space_cardinality: usize = parent_set
.iter()
.map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product();
println!(
"parent_set_state_space_cardinality = {}",
parent_set_state_space_cardinality
);
let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent();
println!("node_domain_cardinality = {}", node_domain_cardinality);
let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64)
..=(self.interval.end / node_domain_cardinality as f64);
println!("cim_single_param_range = {:?}", cim_single_param_range);
let mut cim = Array3::<f64>::from_shape_fn(
(
parent_set_state_space_cardinality,
node_domain_cardinality,
node_domain_cardinality,
),
|_| self.rng.gen(),
);
//let diagonal = cim.axis_iter(Axis(0));
cim.axis_iter_mut(Axis(0))
.for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}")));
cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
x.diag_mut().fill(0.0);
let sum_axis = x.sum_axis(Axis(0));
//let division = 1.0 / &sum_axis;
x.div_assign(&sum_axis);
println!("{}", x);
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
self.rng.gen_range(self.interval.clone())
});
x.mul_assign(&diag);
println!("{}", x);
x.diag_mut().assign(&-diag)
});
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.diag().iter().for_each(|x| println!("{x}")));
println!("Sum Axis");
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}")));
println!("Matrices");
cim.axis_iter_mut(Axis(0))
.for_each(|x| x.iter().for_each(|x| println!("{}", x)));
//.any(|x| x.diag().iter().any(|x| x >= &0.0))
//println!("{:?}", diagonal);
match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(cim));
}
}
}
}
}

@ -1,3 +1,5 @@
use std::ops::Range;
use ndarray::{arr1, arr2, arr3}; use ndarray::{arr1, arr2, arr3};
use reCTBN::process::ctbn::*; use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess; use reCTBN::process::NetworkProcess;
@ -85,20 +87,27 @@ fn dataset_wrong_shape() {
#[test] #[test]
#[should_panic] #[should_panic]
fn uniform_random_generator_wrong_density() { fn uniform_graph_generator_wrong_density_1() {
let density = 2.1; let density = 2.1;
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None);
} }
#[test] #[test]
fn uniform_random_generator_right_densities() { #[should_panic]
fn uniform_graph_generator_wrong_density_2() {
let density = -0.5;
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None);
}
#[test]
fn uniform_graph_generator_right_densities() {
for density in [1.0, 0.75, 0.5, 0.25, 0.0] { for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None);
} }
} }
#[test] #[test]
fn uniform_random_generator_generate_graph() { fn uniform_graph_generator_generate_graph() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
for node_label in 0..100 { for node_label in 0..100 {
net.add_node( net.add_node(
@ -122,3 +131,35 @@ fn uniform_random_generator_generate_graph() {
// expect the number of edges to be somewhere around the expected value. // expect the number of edges to be somewhere around the expected value.
assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
} }
#[test]
#[should_panic]
fn uniform_parameters_generator_wrong_density_1() {
let interval: Range<f64> = -2.0..-5.0;
let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None);
}
#[test]
#[should_panic]
fn uniform_parameters_generator_wrong_density_2() {
let interval: Range<f64> = -1.0..0.0;
let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None);
}
#[test]
fn uniform_parameters_generator_right_densities() {
let mut net = CtbnNetwork::new();
for node_label in 0..3 {
net.add_node(
utils::generate_discrete_time_continous_node(
node_label.to_string(),
9,
)
).unwrap();
}
let density = 1.0/3.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120));
structure_generator.generate_graph(&mut net);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120));
cim_generator.generate_parameters(&mut net);
}

Loading…
Cancel
Save