Renamed `UniformRandomGenerator` to `UniformGraphGenerator`, replaced `CtbnNetwork` requirement with `NetworkProcess` in `RandomGraphGenerator`, some related tweaks

pull/85/head
Meliurwen 2 years ago
parent 4b994d8a19
commit 434e671f0a
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 15
      reCTBN/src/tools.rs
  2. 10
      reCTBN/tests/tools.rs

@ -4,7 +4,6 @@ use ndarray::prelude::*;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::process::ctbn::CtbnNetwork;
use crate::process::NetworkProcess;
use crate::sampling::{ForwardSampler, Sampler};
use crate::{params, process};
@ -115,16 +114,16 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
pub trait RandomGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> Self;
fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork;
fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T);
}
pub struct UniformRandomGenerator {
pub struct UniformGraphGenerator {
density: f64,
rng: ChaCha8Rng,
}
impl RandomGraphGenerator for UniformRandomGenerator {
fn new(density: f64, seed: Option<u64>) -> UniformRandomGenerator {
impl RandomGraphGenerator for UniformGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> UniformGraphGenerator {
if density < 0.0 || density > 1.0 {
panic!(
"Density value must be between 1.0 and 0.0, got {}.",
@ -135,10 +134,11 @@ impl RandomGraphGenerator for UniformRandomGenerator {
Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(),
};
UniformRandomGenerator { density, rng }
UniformGraphGenerator { density, rng }
}
fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork {
fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T) {
net.initialize_adj_matrix();
let last_node_idx = net.get_node_indices().len();
for parent in 0..last_node_idx {
for child in 0..last_node_idx {
@ -149,6 +149,5 @@ impl RandomGraphGenerator for UniformRandomGenerator {
}
}
}
net
}
}

@ -87,13 +87,13 @@ fn dataset_wrong_shape() {
#[should_panic]
fn uniform_random_generator_wrong_density() {
let density = 2.1;
let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None);
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None);
}
#[test]
fn uniform_random_generator_right_densities() {
for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None);
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None);
}
}
@ -109,7 +109,7 @@ fn uniform_random_generator_generate_graph() {
).unwrap();
}
let density = 1.0/3.0;
let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120));
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120));
structure_generator.generate_graph(&mut net);
let mut edges = 0;
for node in net.get_node_indices(){
@ -117,8 +117,8 @@ fn uniform_random_generator_generate_graph() {
}
let nodes = net.get_node_indices().len() as f64;
let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance
let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance
// As the way `generate_graph()` is implemented we can only reasonably
// expect the number of edges to be somewhere around the expected value.
assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance));
assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
}

Loading…
Cancel
Save