|
|
|
@ -1,7 +1,11 @@ |
|
|
|
|
//! Contains commonly used methods used across the crate.
|
|
|
|
|
|
|
|
|
|
use ndarray::prelude::*; |
|
|
|
|
use rand::{Rng, SeedableRng}; |
|
|
|
|
use rand_chacha::ChaCha8Rng; |
|
|
|
|
|
|
|
|
|
use crate::process::ctbn::CtbnNetwork; |
|
|
|
|
use crate::process::NetworkProcess; |
|
|
|
|
use crate::sampling::{ForwardSampler, Sampler}; |
|
|
|
|
use crate::{params, process}; |
|
|
|
|
|
|
|
|
@ -108,3 +112,38 @@ pub fn trajectory_generator<T: process::NetworkProcess>( |
|
|
|
|
//Return a dataset object with the sampled trajectories.
|
|
|
|
|
Dataset::new(trajectories) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct StructureGen { |
|
|
|
|
density: f64, |
|
|
|
|
rng: ChaCha8Rng, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl StructureGen { |
|
|
|
|
pub fn new(density: f64, seed: Option<u64>) -> StructureGen { |
|
|
|
|
if density < 0.0 || density > 1.0 { |
|
|
|
|
panic!( |
|
|
|
|
"Density value must be between 1.0 and 0.0, got {}.", |
|
|
|
|
density |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
let rng: ChaCha8Rng = match seed { |
|
|
|
|
Some(seed) => SeedableRng::seed_from_u64(seed), |
|
|
|
|
None => SeedableRng::from_entropy(), |
|
|
|
|
}; |
|
|
|
|
StructureGen { density, rng } |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { |
|
|
|
|
let last_node_idx = net.get_node_indices().len(); |
|
|
|
|
for parent in 0..last_node_idx { |
|
|
|
|
for child in 0..last_node_idx { |
|
|
|
|
if parent != child { |
|
|
|
|
if self.rng.gen_bool(self.density) { |
|
|
|
|
net.add_edge(parent, child); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
net |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|