diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 599b420..9222239 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -4,7 +4,6 @@ use ndarray::prelude::*; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use crate::process::ctbn::CtbnNetwork; use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -115,16 +114,16 @@ pub fn trajectory_generator( pub trait RandomGraphGenerator { fn new(density: f64, seed: Option) -> Self; - fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork; + fn generate_graph(&mut self, net: &mut T); } -pub struct UniformRandomGenerator { +pub struct UniformGraphGenerator { density: f64, rng: ChaCha8Rng, } -impl RandomGraphGenerator for UniformRandomGenerator { - fn new(density: f64, seed: Option) -> UniformRandomGenerator { +impl RandomGraphGenerator for UniformGraphGenerator { + fn new(density: f64, seed: Option) -> UniformGraphGenerator { if density < 0.0 || density > 1.0 { panic!( "Density value must be between 1.0 and 0.0, got {}.", @@ -135,10 +134,11 @@ impl RandomGraphGenerator for UniformRandomGenerator { Some(seed) => SeedableRng::seed_from_u64(seed), None => SeedableRng::from_entropy(), }; - UniformRandomGenerator { density, rng } + UniformGraphGenerator { density, rng } } - fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + fn generate_graph(&mut self, net: &mut T) { + net.initialize_adj_matrix(); let last_node_idx = net.get_node_indices().len(); for parent in 0..last_node_idx { for child in 0..last_node_idx { @@ -149,6 +149,5 @@ impl RandomGraphGenerator for UniformRandomGenerator { } } } - net } } diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 4c32de7..9a96959 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -87,13 +87,13 @@ fn dataset_wrong_shape() { #[should_panic] fn uniform_random_generator_wrong_density() { let density = 2.1; - let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } #[test] fn uniform_random_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { - let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } } @@ -109,7 +109,7 @@ fn uniform_random_generator_generate_graph() { ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); structure_generator.generate_graph(&mut net); let mut edges = 0; for node in net.get_node_indices(){ @@ -117,8 +117,8 @@ fn uniform_random_generator_generate_graph() { } let nodes = net.get_node_indices().len() as f64; let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; - let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance + let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance // As the way `generate_graph()` is implemented we can only reasonably // expect the number of edges to be somewhere around the expected value. - assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); + assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); }