Add the possibility to insert a seed for the random generation #33

Merged
meliurwen merged 7 commits from 12-feature-add-seed-for-random-generation into dev 3 years ago
Changed files:

  1. Cargo.toml (1 changed line)
  2. src/params.rs (16 changed lines)
  3. src/tools.rs (11 changed lines)
  4. tests/parameter_learning.rs (16 changed lines)
  5. tests/params.rs (14 changed lines)
  6. tests/tools.rs (2 changed lines)

Cargo.toml

@@ -12,6 +12,7 @@ thiserror = "*"
 rand = "*"
 bimap = "*"
 enum_dispatch = "*"
+rand_chacha = "*"

 [dev-dependencies]
 approx = "*"
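The new dependency is the seedable ChaCha generator from the rand ecosystem. A minimal, self-contained illustration (not code from the PR) of the property the change relies on: the same u64 seed always reproduces the same stream.

    use rand::Rng;
    use rand_chacha::rand_core::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    fn main() {
        // Two generators built from the same seed yield identical sequences,
        // which is what makes seeded trajectory generation reproducible.
        let mut a = ChaCha8Rng::seed_from_u64(6347747169756259);
        let mut b = ChaCha8Rng::seed_from_u64(6347747169756259);
        assert_eq!(a.gen_range(0..10), b.gen_range(0..10));
        assert_eq!(a.gen::<f64>(), b.gen::<f64>());
    }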

src/params.rs

@@ -3,6 +3,7 @@ use ndarray::prelude::*;
 use rand::Rng;
 use std::collections::{BTreeSet, HashMap};
 use thiserror::Error;
+use rand_chacha::ChaCha8Rng;

 /// Error types for trait Params
 #[derive(Error, Debug, PartialEq)]
@@ -30,15 +31,15 @@ pub trait ParamsTrait {
     /// Randomly generate a possible state of the node disregarding the state of the node and it's
     /// parents.
-    fn get_random_state_uniform(&self) -> StateType;
+    fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType;

     /// Randomly generate a residence time for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>;
+    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError>;

     /// Randomly generate a possible state for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError>;
+    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError>;

     /// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
     fn get_reserved_space_as_parent(&self) -> usize;
@@ -137,18 +138,16 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         self.residence_time = Option::None;
     }

-    fn get_random_state_uniform(&self) -> StateType {
-        let mut rng = rand::thread_rng();
+    fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType {
         StateType::Discrete(rng.gen_range(0..(self.domain.len())))
     }

-    fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> {
+    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError> {
         // Generate a random residence time given the current state of the node and its parent set.
         // The method used is described in:
         // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
         match &self.cim {
             Option::Some(cim) => {
-                let mut rng = rand::thread_rng();
                 let lambda = cim[[u, state, state]] * -1.0;
                 let x: f64 = rng.gen_range(0.0..=1.0);
                 Ok(-x.ln() / lambda)
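For reference, the residence-time draw above is plain inverse-transform sampling of an exponential. A standalone sketch of the same technique (the rate is made up for illustration; in the library it comes from the CIM diagonal with the sign flipped):

    use rand::Rng;
    use rand_chacha::rand_core::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    fn main() {
        // Illustrative exit rate, i.e. lambda = -cim[[u, state, state]].
        let lambda = 4.0_f64;
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        // If X ~ U(0, 1], then -ln(X) / lambda is exponentially distributed
        // with rate lambda (mean 1 / lambda).
        let x: f64 = rng.gen_range(0.0..=1.0);
        let residence_time = -x.ln() / lambda;
        println!("sampled residence time: {}", residence_time);
    }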
@@ -159,13 +158,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         }
     }

-    fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> {
+    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError> {
         // Generate a random transition given the current state of the node and its parent set.
         // The method used is described in:
         // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
         match &self.cim {
             Option::Some(cim) => {
-                let mut rng = rand::thread_rng();
                 let lambda = cim[[u, state, state]] * -1.0;
                 let urand: f64 = rng.gen_range(0.0..=1.0);
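The hunk above is truncated by the diff viewer, but the linked method is just a categorical draw: normalize the off-diagonal rates of the current CIM row by the exit rate and walk the cumulative sum against a uniform sample. A hedged, self-contained sketch of that kind of draw (the probabilities are made up, not taken from any CIM in the repo):

    use rand::Rng;
    use rand_chacha::rand_core::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    fn main() {
        // Illustrative transition probabilities, i.e. the off-diagonal rates
        // of the current CIM row already divided by the exit rate lambda.
        let probs = [0.2, 0.5, 0.3];
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let urand: f64 = rng.gen_range(0.0..=1.0);

        // Walk the cumulative distribution until it covers the uniform draw.
        let mut acc = 0.0;
        let mut next_state = probs.len() - 1;
        for (i, p) in probs.iter().enumerate() {
            acc += p;
            if urand <= acc {
                next_state = i;
                break;
            }
        }
        println!("next state: {}", next_state);
    }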

src/tools.rs

@@ -3,6 +3,8 @@ use crate::node;
 use crate::params;
 use crate::params::ParamsTrait;
 use ndarray::prelude::*;
+use rand_chacha::ChaCha8Rng;
+use rand_chacha::rand_core::SeedableRng;

 pub struct Trajectory {
     pub time: Array1<f64>,
@@ -17,11 +19,16 @@ pub fn trajectory_generator<T: network::Network>(
     net: &T,
     n_trajectories: u64,
     t_end: f64,
+    seed: Option<u64>,
 ) -> Dataset {
     let mut dataset = Dataset {
         trajectories: Vec::new(),
     };
+
+    let seed = seed.unwrap_or_else(rand::random);
+    let mut rng = ChaCha8Rng::seed_from_u64(seed);
+
     let node_idx: Vec<_> = net.get_node_indices().collect();
     for _ in 0..n_trajectories {
         let mut t = 0.0;
@@ -29,7 +36,7 @@ pub fn trajectory_generator<T: network::Network>(
         let mut events: Vec<Array1<usize>> = Vec::new();
         let mut current_state: Vec<params::StateType> = node_idx
             .iter()
-            .map(|x| net.get_node(*x).params.get_random_state_uniform())
+            .map(|x| net.get_node(*x).params.get_random_state_uniform(&mut rng))
             .collect();
         let mut next_transitions: Vec<Option<f64>> =
             (0..node_idx.len()).map(|_| Option::None).collect();
@@ -51,6 +58,7 @@ pub fn trajectory_generator<T: network::Network>(
                         .get_random_residence_time(
                             net.get_node(idx).params.state_to_index(&current_state[idx]),
                             net.get_param_index_network(idx, &current_state),
+                            &mut rng,
                         )
                         .unwrap()
                         + t,
@@ -78,6 +86,7 @@ pub fn trajectory_generator<T: network::Network>(
                         .params
                         .state_to_index(&current_state[next_node_transition]),
                     net.get_param_index_network(next_node_transition, &current_state),
+                    &mut rng,
                 )
                 .unwrap();
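One note on the fallback introduced above: when no seed is supplied, a random one is drawn, so the old non-deterministic behaviour is preserved. A small, self-contained sketch of the same pattern in isolation (make_rng is a hypothetical helper, not part of the PR):

    use rand::Rng;
    use rand_chacha::rand_core::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // Some(seed) -> reproducible stream; None -> freshly drawn random seed,
    // mirroring the fallback added to trajectory_generator.
    fn make_rng(seed: Option<u64>) -> ChaCha8Rng {
        let seed = seed.unwrap_or_else(rand::random);
        ChaCha8Rng::seed_from_u64(seed)
    }

    fn main() {
        let mut fixed = make_rng(Some(6347747169756259));
        let mut fresh = make_rng(None);
        println!("{} {}", fixed.gen::<u64>(), fresh.gen::<u64>());
    }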

tests/parameter_learning.rs

@@ -40,14 +40,14 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
         }
     }

-    let data = trajectory_generator(&net, 100, 100.0);
+    let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [2, 2, 2]);
     assert!(CIM.abs_diff_eq(&arr3(&[
         [[-1.0, 1.0], [4.0, -4.0]],
         [[-6.0, 6.0], [2.0, -2.0]],
-    ]), 0.2));
+    ]), 0.1));
 }

 #[test]
@@ -93,7 +93,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
         }
     }

-    let data = trajectory_generator(&net, 100, 200.0);
+    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [3, 3, 3]);
@@ -101,7 +101,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
         [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
         [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
         [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
-    ]), 0.2));
+    ]), 0.1));
 }
@@ -148,13 +148,13 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
         }
     }

-    let data = trajectory_generator(&net, 100, 200.0);
+    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 0, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [1, 3, 3]);
     assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
                                      [1.5, -2.0, 0.5],
-                                     [0.4, 0.6, -1.0]]]), 0.2));
+                                     [0.4, 0.6, -1.0]]]), 0.1));
 }
@@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
     }

-    let data = trajectory_generator(&net, 300, 300.0);
+    let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 2, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [9, 4, 4]);
@@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
         [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
         [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
         [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
-    ]), 0.2));
+    ]), 0.1));
 }

 #[test]

tests/params.rs

@@ -1,6 +1,8 @@
 use ndarray::prelude::*;
 use rustyCTBN::params::*;
 use std::collections::BTreeSet;
+use rand_chacha::ChaCha8Rng;
+use rand_chacha::rand_core::SeedableRng;

 mod utils;
@@ -21,8 +23,10 @@ fn test_uniform_generation() {
     let param = create_ternary_discrete_time_continous_param();

     let mut states = Array1::<usize>::zeros(10000);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
+
     states.mapv_inplace(|_| {
-        if let StateType::Discrete(val) = param.get_random_state_uniform() {
+        if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) {
             val
         } else {
             panic!()
@@ -38,8 +42,10 @@ fn test_random_generation_state() {
     let param = create_ternary_discrete_time_continous_param();

     let mut states = Array1::<usize>::zeros(10000);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
+
     states.mapv_inplace(|_| {
-        if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
+        if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() {
             val
         } else {
             panic!()
@@ -57,7 +63,9 @@ fn test_random_generation_residence_time() {
     let param = create_ternary_discrete_time_continous_param();

     let mut states = Array1::<f64>::zeros(10000);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
+
-    states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap());
+    states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap());

     assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
 }
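The asserted mean follows from the exponential expectation: with the fixture's exit rate of lambda = 5 (implied by the assertion itself), E[T] = 1/lambda = 1/5 = 0.2, so the empirical mean of 10,000 seeded draws should land within 0.01 of 0.2.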

tests/tools.rs

@@ -36,7 +36,7 @@ fn run_sampling() {
         }
     }

-    let data = trajectory_generator(&net, 4, 1.0);
+    let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),);
     assert_eq!(4, data.trajectories.len());
     assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);
