Made seed optional in `trajectory_generator`

pull/33/head
Meliurwen 3 years ago
parent 79dbd88529
commit 79ec08b29a
  1. 4
      src/tools.rs
  2. 8
      tests/parameter_learning.rs
  3. 2
      tests/tools.rs

@ -19,12 +19,14 @@ pub fn trajectory_generator<T: network::Network>(
net: &T,
n_trajectories: u64,
t_end: f64,
seed: u64,
seed: Option<u64>,
) -> Dataset {
let mut dataset = Dataset {
trajectories: Vec::new(),
};
let seed = seed.unwrap_or_else(rand::random);
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let node_idx: Vec<_> = net.get_node_indices().collect();

@ -40,7 +40,7 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 100.0, 1234,);
let data = trajectory_generator(&net, 100, 100.0, Some(1234),);
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
@ -93,7 +93,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 200.0, 1234,);
let data = trajectory_generator(&net, 100, 200.0, Some(1234),);
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
@ -148,7 +148,7 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
}
}
let data = trajectory_generator(&net, 100, 200.0, 1234,);
let data = trajectory_generator(&net, 100, 200.0, Some(1234),);
let (CIM, M, T) = pl.fit(&net, &data, 0, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [1, 3, 3]);
@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
}
let data = trajectory_generator(&net, 300, 300.0, 1234,);
let data = trajectory_generator(&net, 300, 300.0, Some(1234),);
let (CIM, M, T) = pl.fit(&net, &data, 2, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [9, 4, 4]);

@ -36,7 +36,7 @@ fn run_sampling() {
}
}
let data = trajectory_generator(&net, 4, 1.0, 1234,);
let data = trajectory_generator(&net, 4, 1.0, Some(1234),);
assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);

Loading…
Cancel
Save