Renamed `StructureGen` with `UniformRandomGenerator` and defining the new trait `RandomGraphGenerator`

pull/85/head
Meliurwen 2 years ago
parent a077f738ee
commit 4b994d8a19
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 15
      reCTBN/src/tools.rs
  2. 16
      reCTBN/tests/tools.rs

@ -113,13 +113,18 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
Dataset::new(trajectories)
}
pub struct StructureGen {
pub trait RandomGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> Self;
fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork;
}
pub struct UniformRandomGenerator {
density: f64,
rng: ChaCha8Rng,
}
impl StructureGen {
pub fn new(density: f64, seed: Option<u64>) -> StructureGen {
impl RandomGraphGenerator for UniformRandomGenerator {
fn new(density: f64, seed: Option<u64>) -> UniformRandomGenerator {
if density < 0.0 || density > 1.0 {
panic!(
"Density value must be between 1.0 and 0.0, got {}.",
@ -130,10 +135,10 @@ impl StructureGen {
Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(),
};
StructureGen { density, rng }
UniformRandomGenerator { density, rng }
}
pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork {
fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork {
let last_node_idx = net.get_node_indices().len();
for parent in 0..last_node_idx {
for child in 0..last_node_idx {

@ -85,20 +85,20 @@ fn dataset_wrong_shape() {
#[test]
#[should_panic]
fn structure_gen_wrong_density() {
fn uniform_random_generator_wrong_density() {
let density = 2.1;
StructureGen::new(density, None);
let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None);
}
#[test]
fn structure_gen_right_densities() {
fn uniform_random_generator_right_densities() {
for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
StructureGen::new(density, None);
let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None);
}
}
#[test]
fn structure_gen_gen_structure() {
fn uniform_random_generator_generate_graph() {
let mut net = CtbnNetwork::new();
for node_label in 0..100 {
net.add_node(
@ -109,8 +109,8 @@ fn structure_gen_gen_structure() {
).unwrap();
}
let density = 1.0/3.0;
let mut structure_generator = StructureGen::new(density, Some(7641630759785120));
structure_generator.gen_structure(&mut net);
let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120));
structure_generator.generate_graph(&mut net);
let mut edges = 0;
for node in net.get_node_indices(){
edges += net.get_children_set(node).len()
@ -118,7 +118,7 @@ fn structure_gen_gen_structure() {
let nodes = net.get_node_indices().len() as f64;
let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance
// As the way `gen_structure()` is implemented we can only reasonably
// As the way `generate_graph()` is implemented we can only reasonably
// expect the number of edges to be somewhere around the expected value.
assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance));
}

Loading…
Cancel
Save