Merge branch '12-feature-add-seed-for-random-generation' into 'dev'

pull/39/head
Meliurwen 3 years ago
commit 258d016155
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 1
      Cargo.toml
  2. 16
      src/params.rs
  3. 11
      src/tools.rs
  4. 16
      tests/parameter_learning.rs
  5. 14
      tests/params.rs
  6. 2
      tests/tools.rs

@ -12,6 +12,7 @@ thiserror = "*"
rand = "*"
bimap = "*"
enum_dispatch = "*"
rand_chacha = "*"
[dev-dependencies]
approx = "*"

@ -3,6 +3,7 @@ use ndarray::prelude::*;
use rand::Rng;
use std::collections::{BTreeSet, HashMap};
use thiserror::Error;
use rand_chacha::ChaCha8Rng;
/// Error types for trait Params
#[derive(Error, Debug, PartialEq)]
@ -30,15 +31,15 @@ pub trait ParamsTrait {
/// Randomly generate a possible state of the node disregarding the state of the node and it's
/// parents.
fn get_random_state_uniform(&self) -> StateType;
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType;
/// Randomly generate a residence time for the given node taking into account the node state
/// and its parent set.
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>;
fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError>;
/// Randomly generate a possible state for the given node taking into account the node state
/// and its parent set.
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError>;
fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError>;
/// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
fn get_reserved_space_as_parent(&self) -> usize;
@ -137,18 +138,16 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
self.residence_time = Option::None;
}
fn get_random_state_uniform(&self) -> StateType {
let mut rng = rand::thread_rng();
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType {
StateType::Discrete(rng.gen_range(0..(self.domain.len())))
}
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> {
fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError> {
// Generate a random residence time given the current state of the node and its parent set.
// The method used is described in:
// https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
match &self.cim {
Option::Some(cim) => {
let mut rng = rand::thread_rng();
let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..=1.0);
Ok(-x.ln() / lambda)
@ -159,13 +158,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
}
}
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> {
fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError> {
// Generate a random transition given the current state of the node and its parent set.
// The method used is described in:
// https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
match &self.cim {
Option::Some(cim) => {
let mut rng = rand::thread_rng();
let lambda = cim[[u, state, state]] * -1.0;
let urand: f64 = rng.gen_range(0.0..=1.0);

@ -3,6 +3,8 @@ use crate::node;
use crate::params;
use crate::params::ParamsTrait;
use ndarray::prelude::*;
use rand_chacha::ChaCha8Rng;
use rand_chacha::rand_core::SeedableRng;
pub struct Trajectory {
pub time: Array1<f64>,
@ -17,11 +19,16 @@ pub fn trajectory_generator<T: network::Network>(
net: &T,
n_trajectories: u64,
t_end: f64,
seed: Option<u64>,
) -> Dataset {
let mut dataset = Dataset {
trajectories: Vec::new(),
};
let seed = seed.unwrap_or_else(rand::random);
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let node_idx: Vec<_> = net.get_node_indices().collect();
for _ in 0..n_trajectories {
let mut t = 0.0;
@ -29,7 +36,7 @@ pub fn trajectory_generator<T: network::Network>(
let mut events: Vec<Array1<usize>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx
.iter()
.map(|x| net.get_node(*x).params.get_random_state_uniform())
.map(|x| net.get_node(*x).params.get_random_state_uniform(&mut rng))
.collect();
let mut next_transitions: Vec<Option<f64>> =
(0..node_idx.len()).map(|_| Option::None).collect();
@ -51,6 +58,7 @@ pub fn trajectory_generator<T: network::Network>(
.get_random_residence_time(
net.get_node(idx).params.state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state),
&mut rng,
)
.unwrap()
+ t,
@ -78,6 +86,7 @@ pub fn trajectory_generator<T: network::Network>(
.params
.state_to_index(&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state),
&mut rng,
)
.unwrap();

@ -40,14 +40,14 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 100.0);
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),);
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
assert!(CIM.abs_diff_eq(&arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
]), 0.2));
]), 0.1));
}
#[test]
@ -93,7 +93,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 200.0);
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
@ -101,7 +101,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]), 0.2));
]), 0.1));
}
@ -148,13 +148,13 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 200.0);
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
let (CIM, M, T) = pl.fit(&net, &data, 0, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [1, 3, 3]);
assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]), 0.2));
[0.4, 0.6, -1.0]]]), 0.1));
}
@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
}
let data = trajectory_generator(&net, 300, 300.0);
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),);
let (CIM, M, T) = pl.fit(&net, &data, 2, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [9, 4, 4]);
@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
]), 0.2));
]), 0.1));
}
#[test]

@ -1,6 +1,8 @@
use ndarray::prelude::*;
use rustyCTBN::params::*;
use std::collections::BTreeSet;
use rand_chacha::ChaCha8Rng;
use rand_chacha::rand_core::SeedableRng;
mod utils;
@ -21,8 +23,10 @@ fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<usize>::zeros(10000);
let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state_uniform() {
if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) {
val
} else {
panic!()
@ -38,8 +42,10 @@ fn test_random_generation_state() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<usize>::zeros(10000);
let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
states.mapv_inplace(|_| {
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() {
val
} else {
panic!()
@ -57,7 +63,9 @@ fn test_random_generation_residence_time() {
let param = create_ternary_discrete_time_continous_param();
let mut states = Array1::<f64>::zeros(10000);
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap());
let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap());
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
}

@ -36,7 +36,7 @@ fn run_sampling() {
}
}
let data = trajectory_generator(&net, 4, 1.0);
let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),);
assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);

Loading…
Cancel
Save