diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 47a067d..c58403a 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -1,7 +1,11 @@ //! Contains commonly used methods used across the crate. use ndarray::prelude::*; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use crate::process::ctbn::CtbnNetwork; +use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -108,3 +112,38 @@ pub fn trajectory_generator( //Return a dataset object with the sampled trajectories. Dataset::new(trajectories) } + +pub struct StructureGen { + density: f64, + rng: ChaCha8Rng, +} + +impl StructureGen { + pub fn new(density: f64, seed: Option) -> StructureGen { + if density < 0.0 || density > 1.0 { + panic!( + "Density value must be between 1.0 and 0.0, got {}.", + density + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + StructureGen { density, rng } + } + + pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + let last_node_idx = net.get_node_indices().len(); + for parent in 0..last_node_idx { + for child in 0..last_node_idx { + if parent != child { + if self.rng.gen_bool(self.density) { + net.add_edge(parent, child); + } + } + } + } + net + } +} diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 806faef..ac64f8d 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -82,3 +82,43 @@ fn dataset_wrong_shape() { let t2 = Trajectory::new(time, events); Dataset::new(vec![t1, t2]); } + +#[test] +#[should_panic] +fn structure_gen_wrong_density() { + let density = 2.1; + StructureGen::new(density, None); +} + +#[test] +fn structure_gen_right_densities() { + for density in [1.0, 0.75, 0.5, 0.25, 0.0] { + StructureGen::new(density, None); + } +} + +#[test] +fn structure_gen_gen_structure() { + let mut net = CtbnNetwork::new(); + for node_label in 0..100 { + net.add_node( + utils::generate_discrete_time_continous_node( + node_label.to_string(), + 2, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let mut structure_generator = StructureGen::new(density, Some(7641630759785120)); + structure_generator.gen_structure(&mut net); + let mut edges = 0; + for node in net.get_node_indices(){ + edges += net.get_children_set(node).len() + } + let nodes = net.get_node_indices().len() as f64; + let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; + let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance + // As the way `gen_structure()` is implemented we can only reasonably + // expect the number of edges to be somewhere around the expected value. + assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); +}