diff --git a/Cargo.toml b/Cargo.toml index 3aa7c53..4cb6c06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ thiserror = "*" rand = "*" bimap = "*" enum_dispatch = "*" +rand_core = "*" +rand_chacha = "*" [dev-dependencies] approx = "*" diff --git a/src/params.rs b/src/params.rs index 019e281..b418df6 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,8 +1,10 @@ use enum_dispatch::enum_dispatch; use ndarray::prelude::*; use rand::Rng; +use rand::rngs::ThreadRng; use std::collections::{BTreeSet, HashMap}; use thiserror::Error; +use rand_chacha::ChaCha8Rng; /// Error types for trait Params #[derive(Error, Debug, PartialEq)] @@ -30,7 +32,7 @@ pub trait ParamsTrait { /// Randomly generate a possible state of the node disregarding the state of the node and it's /// parents. - fn get_random_state_uniform(&self) -> StateType; + fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType; /// Randomly generate a residence time for the given node taking into account the node state /// and its parent set. @@ -137,8 +139,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { self.residence_time = Option::None; } - fn get_random_state_uniform(&self) -> StateType { - let mut rng = rand::thread_rng(); + fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType { StateType::Discrete(rng.gen_range(0..(self.domain.len()))) } diff --git a/src/tools.rs b/src/tools.rs index 27438f9..4efe085 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -3,6 +3,8 @@ use crate::node; use crate::params; use crate::params::ParamsTrait; use ndarray::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_core::SeedableRng; pub struct Trajectory { pub time: Array1, @@ -17,11 +19,14 @@ pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, + seed: u64, ) -> Dataset { let mut dataset = Dataset { trajectories: Vec::new(), }; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let node_idx: Vec<_> = net.get_node_indices().collect(); for _ in 0..n_trajectories { let mut t = 0.0; @@ -29,7 +34,7 @@ pub fn trajectory_generator( let mut events: Vec> = Vec::new(); let mut current_state: Vec = node_idx .iter() - .map(|x| net.get_node(*x).params.get_random_state_uniform()) + .map(|x| net.get_node(*x).params.get_random_state_uniform(&mut rng)) .collect(); let mut next_transitions: Vec> = (0..node_idx.len()).map(|_| Option::None).collect(); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 345b8d1..96b6ce1 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -40,7 +40,7 @@ fn learn_binary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 100.0); + let data = trajectory_generator(&net, 100, 100.0, 1234,); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [2, 2, 2]); @@ -93,7 +93,7 @@ fn learn_ternary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0); + let data = trajectory_generator(&net, 100, 200.0, 1234,); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [3, 3, 3]); @@ -148,7 +148,7 @@ fn learn_ternary_cim_no_parents (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0); + let data = trajectory_generator(&net, 100, 200.0, 1234,); let (CIM, M, T) = pl.fit(&net, &data, 0, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [1, 3, 3]); @@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim (pl: T) { } - let data = trajectory_generator(&net, 300, 300.0); + let data = trajectory_generator(&net, 300, 300.0, 1234,); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); diff --git a/tests/params.rs b/tests/params.rs index cbc7636..23c99fa 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,6 +1,8 @@ use ndarray::prelude::*; use rustyCTBN::params::*; use std::collections::BTreeSet; +use rand_chacha::ChaCha8Rng; +use rand_core::SeedableRng; mod utils; @@ -21,8 +23,10 @@ fn test_uniform_generation() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); + let mut rng = ChaCha8Rng::seed_from_u64(123456); + states.mapv_inplace(|_| { - if let StateType::Discrete(val) = param.get_random_state_uniform() { + if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) { val } else { panic!() diff --git a/tests/tools.rs b/tests/tools.rs index 257c957..f831ec4 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -36,7 +36,7 @@ fn run_sampling() { } } - let data = trajectory_generator(&net, 4, 1.0); + let data = trajectory_generator(&net, 4, 1.0, 1234,); assert_eq!(4, data.trajectories.len()); assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);