Decreased epsilon to `0.1` with a new seed

Branch: pull/33/head
Author: Meliurwen (3 years ago)
Parent: 62fcbd466a
Commit: a350ddc980
Changed files:
1. tests/parameter_learning.rs (16 changed lines)
2. tests/params.rs (6 changed lines)
3. tests/tools.rs (2 changed lines)
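The epsilon in the commit title is the tolerance passed to `abs_diff_eq` in the assertions below: every entry of the learned CIM must lie within that absolute distance of the expected value. A minimal sketch of the comparison these tests rely on, assuming the `approx` crate and ndarray built with its `approx` feature (which is what provides `abs_diff_eq` on arrays); the numeric values are made up for illustration:

```rust
// Sketch only: mirrors the element-wise tolerance check used in the tests.
// Assumes `approx` and ndarray with its `approx` feature enabled.
use approx::AbsDiffEq;
use ndarray::arr3;

fn main() {
    // Hypothetical "learned" CIM, close to the expected one.
    let estimated = arr3(&[[[-1.05, 1.05], [4.08, -4.08]],
                           [[-5.94, 5.94], [2.02, -2.02]]]);
    let expected = arr3(&[[[-1.0, 1.0], [4.0, -4.0]],
                          [[-6.0, 6.0], [2.0, -2.0]]]);

    // Passes: every entry differs from its target by at most 0.1.
    // The old tolerance of 0.3 would also accept much noisier estimates,
    // which is why tightening it goes together with a new seed that
    // yields a closer fit.
    assert!(estimated.abs_diff_eq(&expected, 0.1));
}
```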

tests/parameter_learning.rs
@@ -40,14 +40,14 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
         }
     }
-    let data = trajectory_generator(&net, 100, 100.0, Some(1234),);
+    let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [2, 2, 2]);
     assert!(CIM.abs_diff_eq(&arr3(&[
         [[-1.0, 1.0], [4.0, -4.0]],
         [[-6.0, 6.0], [2.0, -2.0]],
-    ]), 0.3));
+    ]), 0.1));
 }

 #[test]
@@ -93,7 +93,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
         }
     }
-    let data = trajectory_generator(&net, 100, 200.0, Some(1234),);
+    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [3, 3, 3]);
@@ -101,7 +101,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
         [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
         [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
         [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
-    ]), 0.3));
+    ]), 0.1));
 }
@@ -148,13 +148,13 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
         }
     }
-    let data = trajectory_generator(&net, 100, 200.0, Some(1234),);
+    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 0, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [1, 3, 3]);
     assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
                                      [1.5, -2.0, 0.5],
-                                     [0.4, 0.6, -1.0]]]), 0.3));
+                                     [0.4, 0.6, -1.0]]]), 0.1));
 }
@@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
     }
-    let data = trajectory_generator(&net, 300, 300.0, Some(1234),);
+    let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 2, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [9, 4, 4]);
@@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
         [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
         [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
         [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
-    ]), 0.3));
+    ]), 0.1));
 }

 #[test]
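The tests/params.rs hunks below only swap the fixed RNG seed. A small illustration of why a fixed seed keeps these assertions stable, assuming the `rand` and `rand_chacha` crates that the tests already use:

```rust
// Sketch only: a fixed u64 seed makes ChaCha8Rng fully deterministic,
// so statistics computed from its output are identical on every run.
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;

fn main() {
    let mut a = ChaCha8Rng::seed_from_u64(6347747169756259);
    let mut b = ChaCha8Rng::seed_from_u64(6347747169756259);

    // Same seed, same stream: both generators produce the same draws.
    let xs: Vec<f64> = (0..5).map(|_| a.gen::<f64>()).collect();
    let ys: Vec<f64> = (0..5).map(|_| b.gen::<f64>()).collect();
    assert_eq!(xs, ys);
}
```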

tests/params.rs
@@ -23,7 +23,7 @@ fn test_uniform_generation() {
     let param = create_ternary_discrete_time_continous_param();
     let mut states = Array1::<usize>::zeros(10000);
-    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
     states.mapv_inplace(|_| {
         if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) {
@@ -42,7 +42,7 @@ fn test_random_generation_state() {
     let param = create_ternary_discrete_time_continous_param();
     let mut states = Array1::<usize>::zeros(10000);
-    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
     states.mapv_inplace(|_| {
         if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() {
@@ -63,7 +63,7 @@ fn test_random_generation_residence_time() {
     let param = create_ternary_discrete_time_continous_param();
     let mut states = Array1::<f64>::zeros(10000);
-    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
     states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap());

tests/tools.rs
@@ -36,7 +36,7 @@ fn run_sampling() {
         }
     }
-    let data = trajectory_generator(&net, 4, 1.0, Some(1234),);
+    let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),);
     assert_eq!(4, data.trajectories.len());
     assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);
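For the end-of-trajectory check above, `assert_relative_eq!` comes from the `approx` crate and tolerates floating-point rounding that exact equality would not. A minimal example, independent of this project's types:

```rust
// Sketch only: approx::assert_relative_eq! accepts values that differ by
// floating-point rounding, where assert_eq! would fail.
use approx::assert_relative_eq;

fn main() {
    let sum = 0.1_f64 + 0.2_f64; // 0.30000000000000004
    // assert_eq!(sum, 0.3) fails, but the relative comparison passes
    // within the default f64 tolerance.
    assert_relative_eq!(sum, 0.3);
}
```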
