Added `StructureGen` struct for generating the structure of a `CtbnNetwork`

pull/85/head
Meliurwen 2 years ago
parent 49c2c55f61
commit a077f738ee
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 39
      reCTBN/src/tools.rs
  2. 40
      reCTBN/tests/tools.rs

@ -1,7 +1,11 @@
//! Contains commonly used methods used across the crate. //! Contains commonly used methods used across the crate.
use ndarray::prelude::*; use ndarray::prelude::*;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::process::ctbn::CtbnNetwork;
use crate::process::NetworkProcess;
use crate::sampling::{ForwardSampler, Sampler}; use crate::sampling::{ForwardSampler, Sampler};
use crate::{params, process}; use crate::{params, process};
@ -108,3 +112,38 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
//Return a dataset object with the sampled trajectories. //Return a dataset object with the sampled trajectories.
Dataset::new(trajectories) Dataset::new(trajectories)
} }
pub struct StructureGen {
density: f64,
rng: ChaCha8Rng,
}
impl StructureGen {
pub fn new(density: f64, seed: Option<u64>) -> StructureGen {
if density < 0.0 || density > 1.0 {
panic!(
"Density value must be between 1.0 and 0.0, got {}.",
density
);
}
let rng: ChaCha8Rng = match seed {
Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(),
};
StructureGen { density, rng }
}
pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork {
let last_node_idx = net.get_node_indices().len();
for parent in 0..last_node_idx {
for child in 0..last_node_idx {
if parent != child {
if self.rng.gen_bool(self.density) {
net.add_edge(parent, child);
}
}
}
}
net
}
}

@ -82,3 +82,43 @@ fn dataset_wrong_shape() {
let t2 = Trajectory::new(time, events); let t2 = Trajectory::new(time, events);
Dataset::new(vec![t1, t2]); Dataset::new(vec![t1, t2]);
} }
#[test]
#[should_panic]
fn structure_gen_wrong_density() {
let density = 2.1;
StructureGen::new(density, None);
}
#[test]
fn structure_gen_right_densities() {
for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
StructureGen::new(density, None);
}
}
#[test]
fn structure_gen_gen_structure() {
let mut net = CtbnNetwork::new();
for node_label in 0..100 {
net.add_node(
utils::generate_discrete_time_continous_node(
node_label.to_string(),
2,
)
).unwrap();
}
let density = 1.0/3.0;
let mut structure_generator = StructureGen::new(density, Some(7641630759785120));
structure_generator.gen_structure(&mut net);
let mut edges = 0;
for node in net.get_node_indices(){
edges += net.get_children_set(node).len()
}
let nodes = net.get_node_indices().len() as f64;
let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance
// As the way `gen_structure()` is implemented we can only reasonably
// expect the number of edges to be somewhere around the expected value.
assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance));
}

Loading…
Cancel
Save