Renamed `UniformRandomGenerator` to `UniformGraphGenerator`, replaced `CtbnNetwork` requirement with `NetworkProcess` in `RandomGraphGenerator`, some related tweaks

pull/85/head
Meliurwen 2 years ago
parent 4b994d8a19
commit 434e671f0a
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 15
      reCTBN/src/tools.rs
  2. 10
      reCTBN/tests/tools.rs

@ -4,7 +4,6 @@ use ndarray::prelude::*;
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
use crate::process::ctbn::CtbnNetwork;
use crate::process::NetworkProcess; use crate::process::NetworkProcess;
use crate::sampling::{ForwardSampler, Sampler}; use crate::sampling::{ForwardSampler, Sampler};
use crate::{params, process}; use crate::{params, process};
@ -115,16 +114,16 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
pub trait RandomGraphGenerator { pub trait RandomGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> Self; fn new(density: f64, seed: Option<u64>) -> Self;
fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork; fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T);
} }
pub struct UniformRandomGenerator { pub struct UniformGraphGenerator {
density: f64, density: f64,
rng: ChaCha8Rng, rng: ChaCha8Rng,
} }
impl RandomGraphGenerator for UniformRandomGenerator { impl RandomGraphGenerator for UniformGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> UniformRandomGenerator { fn new(density: f64, seed: Option<u64>) -> UniformGraphGenerator {
if density < 0.0 || density > 1.0 { if density < 0.0 || density > 1.0 {
panic!( panic!(
"Density value must be between 1.0 and 0.0, got {}.", "Density value must be between 1.0 and 0.0, got {}.",
@ -135,10 +134,11 @@ impl RandomGraphGenerator for UniformRandomGenerator {
Some(seed) => SeedableRng::seed_from_u64(seed), Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(), None => SeedableRng::from_entropy(),
}; };
UniformRandomGenerator { density, rng } UniformGraphGenerator { density, rng }
} }
fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T) {
net.initialize_adj_matrix();
let last_node_idx = net.get_node_indices().len(); let last_node_idx = net.get_node_indices().len();
for parent in 0..last_node_idx { for parent in 0..last_node_idx {
for child in 0..last_node_idx { for child in 0..last_node_idx {
@ -149,6 +149,5 @@ impl RandomGraphGenerator for UniformRandomGenerator {
} }
} }
} }
net
} }
} }

@ -87,13 +87,13 @@ fn dataset_wrong_shape() {
#[should_panic] #[should_panic]
fn uniform_random_generator_wrong_density() { fn uniform_random_generator_wrong_density() {
let density = 2.1; let density = 2.1;
let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None);
} }
#[test] #[test]
fn uniform_random_generator_right_densities() { fn uniform_random_generator_right_densities() {
for density in [1.0, 0.75, 0.5, 0.25, 0.0] { for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None);
} }
} }
@ -109,7 +109,7 @@ fn uniform_random_generator_generate_graph() {
).unwrap(); ).unwrap();
} }
let density = 1.0/3.0; let density = 1.0/3.0;
let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120));
structure_generator.generate_graph(&mut net); structure_generator.generate_graph(&mut net);
let mut edges = 0; let mut edges = 0;
for node in net.get_node_indices(){ for node in net.get_node_indices(){
@ -117,8 +117,8 @@ fn uniform_random_generator_generate_graph() {
} }
let nodes = net.get_node_indices().len() as f64; let nodes = net.get_node_indices().len() as f64;
let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance
// As the way `generate_graph()` is implemented we can only reasonably // As the way `generate_graph()` is implemented we can only reasonably
// expect the number of edges to be somewhere around the expected value. // expect the number of edges to be somewhere around the expected value.
assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
} }

Loading…
Cancel
Save