From 6104dcc329481a591c62d9a91735fc65ea593bd1 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 5 Apr 2022 16:38:43 +0200 Subject: [PATCH 01/10] In the Bayesian approach alpha and tau are now divided by the number of possible configurations in its parent set --- src/parameter_learning.rs | 12 +++++++----- tests/parameter_learning.rs | 16 ++++++++-------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 67ea07f..4fe3bdd 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -114,8 +114,8 @@ impl ParameterLearning for MLE { } pub struct BayesianApproach { - pub default_alpha: usize, - pub default_tau: f64 + pub alpha: usize, + pub tau: f64 } impl ParameterLearning for BayesianApproach { @@ -135,13 +135,15 @@ impl ParameterLearning for BayesianApproach { }; let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); - M.mapv_inplace(|x|{x + self.default_alpha}); - T.mapv_inplace(|x|{x + self.default_tau}); + + let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64; + let tau: f64 = self.tau as f64 / M.shape()[0] as f64; + //Compute the CIM as M[i,x,y]/T[i,x] let mut CIM: Array3 = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); CIM.axis_iter_mut(Axis(2)) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) - .for_each(|(mut C, m)| C.assign(&(&m/&T))); + .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha)/&T.mapv(|y| y + tau)))); //Set the diagonal of the inner matrices to the the row sum multiplied by -1 let tmp_diag_sum: Array2 = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index d6b8fd2..345b8d1 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -60,8 +60,8 @@ fn learn_binary_cim_MLE() { #[test] fn learn_binary_cim_BA() { let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; + alpha: 1, + tau: 1.0}; learn_binary_cim(ba); } @@ 
-115,8 +115,8 @@ fn learn_ternary_cim_MLE() { #[test] fn learn_ternary_cim_BA() { let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; + alpha: 1, + tau: 1.0}; learn_ternary_cim(ba); } @@ -168,8 +168,8 @@ fn learn_ternary_cim_no_parents_MLE() { #[test] fn learn_ternary_cim_no_parents_BA() { let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; + alpha: 1, + tau: 1.0}; learn_ternary_cim_no_parents(ba); } @@ -257,7 +257,7 @@ fn learn_mixed_discrete_cim_MLE() { #[test] fn learn_mixed_discrete_cim_BA() { let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; + alpha: 1, + tau: 1.0}; learn_mixed_discrete_cim(ba); } From bb42365fb81cfad449609b76575b10122d76568a Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 6 Apr 2022 11:18:29 +0200 Subject: [PATCH 02/10] Added meta and refactor issue templates --- .github/ISSUE_TEMPLATE/meta_request.md | 26 ++++++++++++++++++++++ .github/ISSUE_TEMPLATE/refactor_request.md | 26 ++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/meta_request.md create mode 100644 .github/ISSUE_TEMPLATE/refactor_request.md diff --git a/.github/ISSUE_TEMPLATE/meta_request.md b/.github/ISSUE_TEMPLATE/meta_request.md new file mode 100644 index 0000000..d80ccde --- /dev/null +++ b/.github/ISSUE_TEMPLATE/meta_request.md @@ -0,0 +1,26 @@ +--- +name: 📑 Meta request +about: Suggest an idea or a change for this same repository +title: '[Meta] ' +labels: 'meta' +assignees: '' + +--- + +## Description + +As a X, I want to Y, so Z. 
+ +## Acceptance Criteria + +* Criteria 1 +* Criteria 2 + +## Checklist + +* [ ] Element 1 +* [ ] Element 2 + +## (Optional) Extra info + +None diff --git a/.github/ISSUE_TEMPLATE/refactor_request.md b/.github/ISSUE_TEMPLATE/refactor_request.md new file mode 100644 index 0000000..503e3f3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/refactor_request.md @@ -0,0 +1,26 @@ +--- +name: ⚙️ Refactor request +about: Suggest a refactor for this project +title: '[Refactor] ' +labels: 'enhancement' +assignees: '' + +--- + +## Description + +As a X, I want to Y, so Z. + +## Acceptance Criteria + +* Criteria 1 +* Criteria 2 + +## Checklist + +* [ ] Element 1 +* [ ] Element 2 + +## (Optional) Extra info + +None From 651148fffdd6c6d1c4d9c2095e912269e881f89d Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 11 Apr 2022 13:51:15 +0200 Subject: [PATCH 03/10] Replaced the current RNG with a seedable one (`rand_chacha`) --- Cargo.toml | 2 ++ src/params.rs | 7 ++++--- src/tools.rs | 7 ++++++- tests/parameter_learning.rs | 8 ++++---- tests/params.rs | 6 +++++- tests/tools.rs | 2 +- 6 files changed, 22 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3aa7c53..4cb6c06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ thiserror = "*" rand = "*" bimap = "*" enum_dispatch = "*" +rand_core = "*" +rand_chacha = "*" [dev-dependencies] approx = "*" diff --git a/src/params.rs b/src/params.rs index 019e281..b418df6 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,8 +1,10 @@ use enum_dispatch::enum_dispatch; use ndarray::prelude::*; use rand::Rng; +use rand::rngs::ThreadRng; use std::collections::{BTreeSet, HashMap}; use thiserror::Error; +use rand_chacha::ChaCha8Rng; /// Error types for trait Params #[derive(Error, Debug, PartialEq)] @@ -30,7 +32,7 @@ pub trait ParamsTrait { /// Randomly generate a possible state of the node disregarding the state of the node and it's /// parents. 
- fn get_random_state_uniform(&self) -> StateType; + fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType; /// Randomly generate a residence time for the given node taking into account the node state /// and its parent set. @@ -137,8 +139,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { self.residence_time = Option::None; } - fn get_random_state_uniform(&self) -> StateType { - let mut rng = rand::thread_rng(); + fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType { StateType::Discrete(rng.gen_range(0..(self.domain.len()))) } diff --git a/src/tools.rs b/src/tools.rs index 27438f9..4efe085 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -3,6 +3,8 @@ use crate::node; use crate::params; use crate::params::ParamsTrait; use ndarray::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_core::SeedableRng; pub struct Trajectory { pub time: Array1, @@ -17,11 +19,14 @@ pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, + seed: u64, ) -> Dataset { let mut dataset = Dataset { trajectories: Vec::new(), }; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let node_idx: Vec<_> = net.get_node_indices().collect(); for _ in 0..n_trajectories { let mut t = 0.0; @@ -29,7 +34,7 @@ pub fn trajectory_generator( let mut events: Vec> = Vec::new(); let mut current_state: Vec = node_idx .iter() - .map(|x| net.get_node(*x).params.get_random_state_uniform()) + .map(|x| net.get_node(*x).params.get_random_state_uniform(&mut rng)) .collect(); let mut next_transitions: Vec> = (0..node_idx.len()).map(|_| Option::None).collect(); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 345b8d1..96b6ce1 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -40,7 +40,7 @@ fn learn_binary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 100.0); + let data = trajectory_generator(&net, 100, 100.0, 1234,); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: 
{:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [2, 2, 2]); @@ -93,7 +93,7 @@ fn learn_ternary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0); + let data = trajectory_generator(&net, 100, 200.0, 1234,); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [3, 3, 3]); @@ -148,7 +148,7 @@ fn learn_ternary_cim_no_parents (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0); + let data = trajectory_generator(&net, 100, 200.0, 1234,); let (CIM, M, T) = pl.fit(&net, &data, 0, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [1, 3, 3]); @@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim (pl: T) { } - let data = trajectory_generator(&net, 300, 300.0); + let data = trajectory_generator(&net, 300, 300.0, 1234,); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); diff --git a/tests/params.rs b/tests/params.rs index cbc7636..23c99fa 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,6 +1,8 @@ use ndarray::prelude::*; use rustyCTBN::params::*; use std::collections::BTreeSet; +use rand_chacha::ChaCha8Rng; +use rand_core::SeedableRng; mod utils; @@ -21,8 +23,10 @@ fn test_uniform_generation() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); + let mut rng = ChaCha8Rng::seed_from_u64(123456); + states.mapv_inplace(|_| { - if let StateType::Discrete(val) = param.get_random_state_uniform() { + if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) { val } else { panic!() diff --git a/tests/tools.rs b/tests/tools.rs index 257c957..f831ec4 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -36,7 +36,7 @@ fn run_sampling() { } } - let data = trajectory_generator(&net, 4, 1.0); + let data = trajectory_generator(&net, 4, 1.0, 1234,); assert_eq!(4, data.trajectories.len()); 
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); From 185e1756cacc11476cdc11e0d5c6a5740f2c1d2b Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 11 Apr 2022 14:36:38 +0200 Subject: [PATCH 04/10] The residence time generation is now seedable --- src/params.rs | 4 ++-- src/tools.rs | 1 + tests/params.rs | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/params.rs b/src/params.rs index b418df6..963ff8c 100644 --- a/src/params.rs +++ b/src/params.rs @@ -36,7 +36,7 @@ pub trait ParamsTrait { /// Randomly generate a residence time for the given node taking into account the node state /// and its parent set. - fn get_random_residence_time(&self, state: usize, u: usize) -> Result; + fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result; /// Randomly generate a possible state for the given node taking into account the node state /// and its parent set. @@ -143,7 +143,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { StateType::Discrete(rng.gen_range(0..(self.domain.len()))) } - fn get_random_residence_time(&self, state: usize, u: usize) -> Result { + fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result { // Generate a random residence time given the current state of the node and its parent set. 
// The method used is described in: // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates diff --git a/src/tools.rs b/src/tools.rs index 4efe085..acee937 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -56,6 +56,7 @@ pub fn trajectory_generator( .get_random_residence_time( net.get_node(idx).params.state_to_index(&current_state[idx]), net.get_param_index_network(idx, &current_state), + &mut rng, ) .unwrap() + t, diff --git a/tests/params.rs b/tests/params.rs index 23c99fa..f8b1154 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -61,7 +61,9 @@ fn test_random_generation_residence_time() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); - states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap()); + let mut rng = ChaCha8Rng::seed_from_u64(123456); + + states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap()); assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); } From 9316fcee30b68021b26e3ef4385fba143cccb6be Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 11 Apr 2022 14:41:18 +0200 Subject: [PATCH 05/10] The state generation is now seedable --- src/params.rs | 4 ++-- src/tools.rs | 1 + tests/params.rs | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/params.rs b/src/params.rs index 963ff8c..6173d75 100644 --- a/src/params.rs +++ b/src/params.rs @@ -40,7 +40,7 @@ pub trait ParamsTrait { /// Randomly generate a possible state for the given node taking into account the node state /// and its parent set. - fn get_random_state(&self, state: usize, u: usize) -> Result; + fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result; /// Used by childern of the node described by this parameters to reserve spaces in their CIMs. 
fn get_reserved_space_as_parent(&self) -> usize; @@ -160,7 +160,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { } } - fn get_random_state(&self, state: usize, u: usize) -> Result { + fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result { // Generate a random transition given the current state of the node and its parent set. // The method used is described in: // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution diff --git a/src/tools.rs b/src/tools.rs index acee937..858923e 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -84,6 +84,7 @@ pub fn trajectory_generator( .params .state_to_index(&current_state[next_node_transition]), net.get_param_index_network(next_node_transition, &current_state), + &mut rng, ) .unwrap(); diff --git a/tests/params.rs b/tests/params.rs index f8b1154..8ab81c1 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -42,8 +42,10 @@ fn test_random_generation_state() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); + let mut rng = ChaCha8Rng::seed_from_u64(123456); + states.mapv_inplace(|_| { - if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { + if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() { val } else { panic!() From 05af0f37c4fdd61f20e753a7a1fe4e092a57c790 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 11 Apr 2022 14:57:18 +0200 Subject: [PATCH 06/10] Added `.vscode` folder to `.gitignore` --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 96ef6c0..c640ca5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target Cargo.lock +.vscode From 79dbd885296c602370a57592847ba08b2b3b7ed8 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 11 Apr 2022 15:35:23 +0200 Subject: [PATCH 07/10] Get rid of the `rand_core` dependency --- Cargo.toml | 1 - src/tools.rs | 2 +- tests/params.rs | 2 +- 3 files changed, 2 
insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4cb6c06..9941ed6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,6 @@ thiserror = "*" rand = "*" bimap = "*" enum_dispatch = "*" -rand_core = "*" rand_chacha = "*" [dev-dependencies] diff --git a/src/tools.rs b/src/tools.rs index 858923e..8cec2a2 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -4,7 +4,7 @@ use crate::params; use crate::params::ParamsTrait; use ndarray::prelude::*; use rand_chacha::ChaCha8Rng; -use rand_core::SeedableRng; +use rand_chacha::rand_core::SeedableRng; pub struct Trajectory { pub time: Array1, diff --git a/tests/params.rs b/tests/params.rs index 8ab81c1..255aba6 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -2,7 +2,7 @@ use ndarray::prelude::*; use rustyCTBN::params::*; use std::collections::BTreeSet; use rand_chacha::ChaCha8Rng; -use rand_core::SeedableRng; +use rand_chacha::rand_core::SeedableRng; mod utils; From 79ec08b29af34a99388f398865ecddb183ac00c2 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 11 Apr 2022 16:33:39 +0200 Subject: [PATCH 08/10] Made seed optional in `trajectory_generator` --- src/tools.rs | 4 +++- tests/parameter_learning.rs | 8 ++++---- tests/tools.rs | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/tools.rs b/src/tools.rs index 8cec2a2..2a38d34 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -19,12 +19,14 @@ pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, - seed: u64, + seed: Option, ) -> Dataset { let mut dataset = Dataset { trajectories: Vec::new(), }; + let seed = seed.unwrap_or_else(rand::random); + let mut rng = ChaCha8Rng::seed_from_u64(seed); let node_idx: Vec<_> = net.get_node_indices().collect(); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 96b6ce1..af57291 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -40,7 +40,7 @@ fn learn_binary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 100.0, 
1234,); + let data = trajectory_generator(&net, 100, 100.0, Some(1234),); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [2, 2, 2]); @@ -93,7 +93,7 @@ fn learn_ternary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0, 1234,); + let data = trajectory_generator(&net, 100, 200.0, Some(1234),); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [3, 3, 3]); @@ -148,7 +148,7 @@ fn learn_ternary_cim_no_parents (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0, 1234,); + let data = trajectory_generator(&net, 100, 200.0, Some(1234),); let (CIM, M, T) = pl.fit(&net, &data, 0, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [1, 3, 3]); @@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim (pl: T) { } - let data = trajectory_generator(&net, 300, 300.0, 1234,); + let data = trajectory_generator(&net, 300, 300.0, Some(1234),); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); diff --git a/tests/tools.rs b/tests/tools.rs index f831ec4..fc9b930 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -36,7 +36,7 @@ fn run_sampling() { } } - let data = trajectory_generator(&net, 4, 1.0, 1234,); + let data = trajectory_generator(&net, 4, 1.0, Some(1234),); assert_eq!(4, data.trajectories.len()); assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); From 62fcbd466a08d7dead12c48a40c8200f3fa57698 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 12 Apr 2022 09:33:07 +0200 Subject: [PATCH 09/10] Removed `rand::thread_rng` overriding the ChaCha's `rng`, increased the epsilon from 0.2 to 0.3 in the tests --- src/params.rs | 3 --- tests/parameter_learning.rs | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/params.rs b/src/params.rs 
index 6173d75..f0e5efa 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,7 +1,6 @@ use enum_dispatch::enum_dispatch; use ndarray::prelude::*; use rand::Rng; -use rand::rngs::ThreadRng; use std::collections::{BTreeSet, HashMap}; use thiserror::Error; use rand_chacha::ChaCha8Rng; @@ -149,7 +148,6 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates match &self.cim { Option::Some(cim) => { - let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; let x: f64 = rng.gen_range(0.0..=1.0); Ok(-x.ln() / lambda) @@ -166,7 +164,6 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution match &self.cim { Option::Some(cim) => { - let mut rng = rand::thread_rng(); let lambda = cim[[u, state, state]] * -1.0; let urand: f64 = rng.gen_range(0.0..=1.0); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index af57291..adff6e8 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -47,7 +47,7 @@ fn learn_binary_cim (pl: T) { assert!(CIM.abs_diff_eq(&arr3(&[ [[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]], - ]), 0.2)); + ]), 0.3)); } #[test] @@ -101,7 +101,7 @@ fn learn_ternary_cim (pl: T) { [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]), 0.2)); + ]), 0.3)); } @@ -154,7 +154,7 @@ fn learn_ternary_cim_no_parents (pl: T) { assert_eq!(CIM.shape(), [1, 3, 3]); assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]), 0.2)); + [0.4, 0.6, -1.0]]]), 0.3)); } @@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim (pl: T) { [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, 
-1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]), 0.2)); + ]), 0.3)); } #[test] From a350ddc980204218c72879bf28d12b2dddecf825 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 12 Apr 2022 13:40:00 +0200 Subject: [PATCH 10/10] Decreased epsilon to `0.1` with a new seed --- tests/parameter_learning.rs | 16 ++++++++-------- tests/params.rs | 6 +++--- tests/tools.rs | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index adff6e8..15245fd 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -40,14 +40,14 @@ fn learn_binary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 100.0, Some(1234),); + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [2, 2, 2]); assert!(CIM.abs_diff_eq(&arr3(&[ [[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]], - ]), 0.3)); + ]), 0.1)); } #[test] @@ -93,7 +93,7 @@ fn learn_ternary_cim (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0, Some(1234),); + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [3, 3, 3]); @@ -101,7 +101,7 @@ fn learn_ternary_cim (pl: T) { [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]), 0.3)); + ]), 0.1)); } @@ -148,13 +148,13 @@ fn learn_ternary_cim_no_parents (pl: T) { } } - let data = trajectory_generator(&net, 100, 200.0, Some(1234),); + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); let (CIM, M, T) = pl.fit(&net, &data, 0, None); print!("CIM: 
{:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [1, 3, 3]); assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]), 0.3)); + [0.4, 0.6, -1.0]]]), 0.1)); } @@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim (pl: T) { } - let data = trajectory_generator(&net, 300, 300.0, Some(1234),); + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); @@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim (pl: T) { [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]), 0.3)); + ]), 0.1)); } #[test] diff --git a/tests/params.rs b/tests/params.rs index 255aba6..b049d4e 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -23,7 +23,7 @@ fn test_uniform_generation() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); - let mut rng = ChaCha8Rng::seed_from_u64(123456); + let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); states.mapv_inplace(|_| { if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) { @@ -42,7 +42,7 @@ fn test_random_generation_state() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); - let mut rng = ChaCha8Rng::seed_from_u64(123456); + let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); states.mapv_inplace(|_| { if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() { @@ -63,7 +63,7 @@ fn test_random_generation_residence_time() { let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); - let mut rng = ChaCha8Rng::seed_from_u64(123456); 
+ let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap()); diff --git a/tests/tools.rs b/tests/tools.rs index fc9b930..76847ef 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -36,7 +36,7 @@ fn run_sampling() { } } - let data = trajectory_generator(&net, 4, 1.0, Some(1234),); + let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),); assert_eq!(4, data.trajectories.len()); assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);