diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index c58403a..599b420 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -113,13 +113,18 @@ pub fn trajectory_generator( Dataset::new(trajectories) } -pub struct StructureGen { +pub trait RandomGraphGenerator { + fn new(density: f64, seed: Option) -> Self; + fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork; +} + +pub struct UniformRandomGenerator { density: f64, rng: ChaCha8Rng, } -impl StructureGen { - pub fn new(density: f64, seed: Option) -> StructureGen { +impl RandomGraphGenerator for UniformRandomGenerator { + fn new(density: f64, seed: Option) -> UniformRandomGenerator { if density < 0.0 || density > 1.0 { panic!( "Density value must be between 1.0 and 0.0, got {}.", @@ -130,10 +135,10 @@ impl StructureGen { Some(seed) => SeedableRng::seed_from_u64(seed), None => SeedableRng::from_entropy(), }; - StructureGen { density, rng } + UniformRandomGenerator { density, rng } } - pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { let last_node_idx = net.get_node_indices().len(); for parent in 0..last_node_idx { for child in 0..last_node_idx { diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index ac64f8d..4c32de7 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -85,20 +85,20 @@ fn dataset_wrong_shape() { #[test] #[should_panic] -fn structure_gen_wrong_density() { +fn uniform_random_generator_wrong_density() { let density = 2.1; - StructureGen::new(density, None); + let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); } #[test] -fn structure_gen_right_densities() { +fn uniform_random_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { - StructureGen::new(density, None); + let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); } } #[test] -fn structure_gen_gen_structure() { +fn uniform_random_generator_generate_graph() { let mut net = CtbnNetwork::new(); for node_label in 0..100 { net.add_node( @@ -109,8 +109,8 @@ fn structure_gen_gen_structure() { ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator = StructureGen::new(density, Some(7641630759785120)); - structure_generator.gen_structure(&mut net); + let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + structure_generator.generate_graph(&mut net); let mut edges = 0; for node in net.get_node_indices(){ edges += net.get_children_set(node).len() @@ -118,7 +118,7 @@ fn structure_gen_gen_structure() { let nodes = net.get_node_indices().len() as f64; let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance - // As the way `gen_structure()` is implemented we can only reasonably + // As the way `generate_graph()` is implemented we can only reasonably // expect the number of edges to be somewhere around the expected value. assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); }