Renamed `StructureGen` with `UniformRandomGenerator` and defining the new trait `RandomGraphGenerator`

pull/85/head
Meliurwen 2 years ago
parent a077f738ee
commit 4b994d8a19
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 15
      reCTBN/src/tools.rs
  2. 16
      reCTBN/tests/tools.rs

@ -113,13 +113,18 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
Dataset::new(trajectories) Dataset::new(trajectories)
} }
pub struct StructureGen { pub trait RandomGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> Self;
fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork;
}
pub struct UniformRandomGenerator {
density: f64, density: f64,
rng: ChaCha8Rng, rng: ChaCha8Rng,
} }
impl StructureGen { impl RandomGraphGenerator for UniformRandomGenerator {
pub fn new(density: f64, seed: Option<u64>) -> StructureGen { fn new(density: f64, seed: Option<u64>) -> UniformRandomGenerator {
if density < 0.0 || density > 1.0 { if density < 0.0 || density > 1.0 {
panic!( panic!(
"Density value must be between 1.0 and 0.0, got {}.", "Density value must be between 1.0 and 0.0, got {}.",
@ -130,10 +135,10 @@ impl StructureGen {
Some(seed) => SeedableRng::seed_from_u64(seed), Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(), None => SeedableRng::from_entropy(),
}; };
StructureGen { density, rng } UniformRandomGenerator { density, rng }
} }
pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork {
let last_node_idx = net.get_node_indices().len(); let last_node_idx = net.get_node_indices().len();
for parent in 0..last_node_idx { for parent in 0..last_node_idx {
for child in 0..last_node_idx { for child in 0..last_node_idx {

@ -85,20 +85,20 @@ fn dataset_wrong_shape() {
#[test] #[test]
#[should_panic] #[should_panic]
fn structure_gen_wrong_density() { fn uniform_random_generator_wrong_density() {
let density = 2.1; let density = 2.1;
StructureGen::new(density, None); let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None);
} }
#[test] #[test]
fn structure_gen_right_densities() { fn uniform_random_generator_right_densities() {
for density in [1.0, 0.75, 0.5, 0.25, 0.0] { for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
StructureGen::new(density, None); let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None);
} }
} }
#[test] #[test]
fn structure_gen_gen_structure() { fn uniform_random_generator_generate_graph() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
for node_label in 0..100 { for node_label in 0..100 {
net.add_node( net.add_node(
@ -109,8 +109,8 @@ fn structure_gen_gen_structure() {
).unwrap(); ).unwrap();
} }
let density = 1.0/3.0; let density = 1.0/3.0;
let mut structure_generator = StructureGen::new(density, Some(7641630759785120)); let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120));
structure_generator.gen_structure(&mut net); structure_generator.generate_graph(&mut net);
let mut edges = 0; let mut edges = 0;
for node in net.get_node_indices(){ for node in net.get_node_indices(){
edges += net.get_children_set(node).len() edges += net.get_children_set(node).len()
@ -118,7 +118,7 @@ fn structure_gen_gen_structure() {
let nodes = net.get_node_indices().len() as f64; let nodes = net.get_node_indices().len() as f64;
let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance
// As the way `gen_structure()` is implemented we can only reasonably // As the way `generate_graph()` is implemented we can only reasonably
// expect the number of edges to be somewhere around the expected value. // expect the number of edges to be somewhere around the expected value.
assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance));
} }

Loading…
Cancel
Save