diff --git a/Cargo.toml b/Cargo.toml index 3aa7c53..9941ed6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ thiserror = "*" rand = "*" bimap = "*" enum_dispatch = "*" +rand_chacha = "*" [dev-dependencies] approx = "*" diff --git a/src/params.rs b/src/params.rs index 019e281..f0e5efa 100644 --- a/src/params.rs +++ b/src/params.rs @@ -3,6 +3,7 @@ use ndarray::prelude::*; use rand::Rng; use std::collections::{BTreeSet, HashMap}; use thiserror::Error; +use rand_chacha::ChaCha8Rng; /// Error types for trait Params #[derive(Error, Debug, PartialEq)] @@ -30,15 +31,15 @@ pub trait ParamsTrait { /// Randomly generate a possible state of the node disregarding the state of the node and it's /// parents. - fn get_random_state_uniform(&self) -> StateType; + fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType; /// Randomly generate a residence time for the given node taking into account the node state /// and its parent set. - fn get_random_residence_time(&self, state: usize, u: usize) -> Result; + fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result; /// Randomly generate a possible state for the given node taking into account the node state /// and its parent set. - fn get_random_state(&self, state: usize, u: usize) -> Result; + fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result; /// Used by childern of the node described by this parameters to reserve spaces in their CIMs. fn get_reserved_space_as_parent(&self) -> usize; @@ -137,18 +138,16 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { self.residence_time = Option::None; } - fn get_random_state_uniform(&self) -> StateType { - let mut rng = rand::thread_rng(); + fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType { StateType::Discrete(rng.gen_range(0..(self.domain.len()))) } - fn get_random_residence_time(&self, state: usize, u: usize) -> Result { + fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result { // Generate a random residence time given the current state of the node and its parent set. // The method used is described in: // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates match &self.cim { Option::Some(cim) => { - let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; let x: f64 = rng.gen_range(0.0..=1.0); Ok(-x.ln() / lambda) @@ -159,13 +158,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { } } - fn get_random_state(&self, state: usize, u: usize) -> Result { + fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result { // Generate a random transition given the current state of the node and its parent set. // The method used is described in: // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution match &self.cim { Option::Some(cim) => { - let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; let urand: f64 = rng.gen_range(0.0..=1.0); diff --git a/src/tools.rs b/src/tools.rs index 27438f9..2a38d34 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -3,6 +3,8 @@ use crate::node; use crate::params; use crate::params::ParamsTrait; use ndarray::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_chacha::rand_core::SeedableRng; pub struct Trajectory { pub time: Array1, @@ -17,11 +19,16 @@ pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, + seed: Option, ) -> Dataset { let mut dataset = Dataset { trajectories: Vec::new(), }; + let seed = seed.unwrap_or_else(rand::random); + + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let node_idx: Vec<_> = net.get_node_indices().collect(); for _ in 0..n_trajectories { let mut t = 0.0; @@ -29,7 +36,7 @@ pub fn trajectory_generator( let mut events: Vec> = Vec::new(); let mut current_state: Vec = node_idx .iter() - .map(|x| net.get_node(*x).params.get_random_state_uniform()) + .map(|x| net.get_node(*x).params.get_random_state_uniform(&mut rng)) .collect(); let mut next_transitions: Vec> = (0..node_idx.len()).map(|_| Option::None).collect(); @@ -51,6 +58,7 @@ pub fn trajectory_generator( .get_random_residence_time( net.get_node(idx).params.state_to_index(¤t_state[idx]), net.get_param_index_network(idx, ¤t_state), + &mut rng, ) .unwrap() + t, @@ -78,6 +86,7 @@ pub fn trajectory_generator( .params .state_to_index(¤t_state[next_node_transition]), net.get_param_index_network(next_node_transition, ¤t_state), + &mut rng, ) .unwrap(); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 345b8d1..15245fd 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -40,14 +40,14 @@ fn learn_binary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 100.0); + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [2, 2, 2]); assert!(CIM.abs_diff_eq(&arr3(&[ [[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]], - ]), 0.2)); + ]), 0.1)); } #[test] @@ -93,7 +93,7 @@ fn learn_ternary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0); + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [3, 3, 3]); @@ -101,7 +101,7 @@ fn learn_ternary_cim (pl: T) { [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]), 0.2)); + ]), 0.1)); } @@ -148,13 +148,13 @@ fn learn_ternary_cim_no_parents (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0); + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); let (CIM, M, T) = pl.fit(&net, &data, 0, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [1, 3, 3]); assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]), 0.2)); + [0.4, 0.6, -1.0]]]), 0.1)); } @@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim (pl: T) { } - let data = trajectory_generator(&net, 300, 300.0); + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); @@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim (pl: T) { [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]), 0.2)); + ]), 0.1)); } #[test] diff --git a/tests/params.rs b/tests/params.rs index cbc7636..b049d4e 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,6 +1,8 @@ use ndarray::prelude::*; use rustyCTBN::params::*; use std::collections::BTreeSet; +use rand_chacha::ChaCha8Rng; +use rand_chacha::rand_core::SeedableRng; mod utils; @@ -21,8 +23,10 @@ fn test_uniform_generation() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); + let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); + states.mapv_inplace(|_| { - if let StateType::Discrete(val) = param.get_random_state_uniform() { + if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) { val } else { panic!() @@ -38,8 +42,10 @@ fn test_random_generation_state() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); + let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); + states.mapv_inplace(|_| { - if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { + if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() { val } else { panic!() @@ -57,7 +63,9 @@ fn test_random_generation_residence_time() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); - states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap()); + let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); + + states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap()); assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); } diff --git a/tests/tools.rs b/tests/tools.rs index 257c957..76847ef 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -36,7 +36,7 @@ fn run_sampling() { } } - let data = trajectory_generator(&net, 4, 1.0); + let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),); assert_eq!(4, data.trajectories.len()); assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);