Replaced the current RNG with a seedable one (`rand_chacha`)

pull/33/head
Meliurwen 3 years ago
parent cd89edd6d4
commit 651148fffd
  1. 2
      Cargo.toml
  2. 7
      src/params.rs
  3. 7
      src/tools.rs
  4. 8
      tests/parameter_learning.rs
  5. 6
      tests/params.rs
  6. 2
      tests/tools.rs

@ -12,6 +12,8 @@ thiserror = "*"
rand = "*"
bimap = "*"
enum_dispatch = "*"
rand_core = "*"
rand_chacha = "*"
[dev-dependencies]
approx = "*"

@ -1,8 +1,10 @@
use enum_dispatch::enum_dispatch;
use ndarray::prelude::*;
use rand::Rng;
use rand::rngs::ThreadRng;
use std::collections::{BTreeSet, HashMap};
use thiserror::Error;
use rand_chacha::ChaCha8Rng;
/// Error types for trait Params
#[derive(Error, Debug, PartialEq)]
@ -30,7 +32,7 @@ pub trait ParamsTrait {
/// Randomly generate a possible state of the node disregarding the state of the node and it's
/// parents.
fn get_random_state_uniform(&self) -> StateType;
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType;
/// Randomly generate a residence time for the given node taking into account the node state
/// and its parent set.
@ -137,8 +139,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
self.residence_time = Option::None;
}
fn get_random_state_uniform(&self) -> StateType {
let mut rng = rand::thread_rng();
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType {
StateType::Discrete(rng.gen_range(0..(self.domain.len())))
}

@ -3,6 +3,8 @@ use crate::node;
use crate::params;
use crate::params::ParamsTrait;
use ndarray::prelude::*;
use rand_chacha::ChaCha8Rng;
use rand_core::SeedableRng;
pub struct Trajectory {
pub time: Array1<f64>,
@ -17,11 +19,14 @@ pub fn trajectory_generator<T: network::Network>(
net: &T,
n_trajectories: u64,
t_end: f64,
seed: u64,
) -> Dataset {
let mut dataset = Dataset {
trajectories: Vec::new(),
};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let node_idx: Vec<_> = net.get_node_indices().collect();
for _ in 0..n_trajectories {
let mut t = 0.0;
@ -29,7 +34,7 @@ pub fn trajectory_generator<T: network::Network>(
let mut events: Vec<Array1<usize>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx
.iter()
.map(|x| net.get_node(*x).params.get_random_state_uniform())
.map(|x| net.get_node(*x).params.get_random_state_uniform(&mut rng))
.collect();
let mut next_transitions: Vec<Option<f64>> =
(0..node_idx.len()).map(|_| Option::None).collect();

@ -40,7 +40,7 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 100.0);
let data = trajectory_generator(&net, 100, 100.0, 1234,);
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
@ -93,7 +93,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 200.0);
let data = trajectory_generator(&net, 100, 200.0, 1234,);
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
@ -148,7 +148,7 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 200.0);
let data = trajectory_generator(&net, 100, 200.0, 1234,);
let (CIM, M, T) = pl.fit(&net, &data, 0, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [1, 3, 3]);
@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
}
let data = trajectory_generator(&net, 300, 300.0);
let data = trajectory_generator(&net, 300, 300.0, 1234,);
let (CIM, M, T) = pl.fit(&net, &data, 2, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [9, 4, 4]);

@ -1,6 +1,8 @@
use ndarray::prelude::*;
use rustyCTBN::params::*;
use std::collections::BTreeSet;
use rand_chacha::ChaCha8Rng;
use rand_core::SeedableRng;
mod utils;
@ -21,8 +23,10 @@ fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<usize>::zeros(10000);
let mut rng = ChaCha8Rng::seed_from_u64(123456);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state_uniform() {
if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) {
val
} else {
panic!()

@ -36,7 +36,7 @@ fn run_sampling() {
}
}
let data = trajectory_generator(&net, 4, 1.0);
let data = trajectory_generator(&net, 4, 1.0, 1234,);
assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);

Loading…
Cancel
Save