Changed trajectory_generator to use generics instead of Box for the network object

pull/19/head
Alessandro Bregoli 3 years ago
parent f87900fdbd
commit ef04e04d1e
  1. 4
      src/tools.rs
  2. 6
      tests/parameter_learning.rs
  3. 2
      tests/tools.rs

@ -13,8 +13,8 @@ pub struct Dataset {
pub trajectories: Vec<Trajectory>, pub trajectories: Vec<Trajectory>,
} }
pub fn trajectory_generator( pub fn trajectory_generator<T: network::Network>(
net: Box<&dyn network::Network>, net: &T,
n_trajectories: u64, n_trajectories: u64,
t_end: f64, t_end: f64,
) -> Dataset { ) -> Dataset {

@ -41,7 +41,7 @@ fn learn_binary_cim_MLE() {
} }
} }
let data = trajectory_generator(Box::new(&net), 100, 100.0); let data = trajectory_generator(&net, 100, 100.0);
let mle = MLE{}; let mle = MLE{};
let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None); let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
@ -82,7 +82,7 @@ fn learn_ternary_cim_MLE() {
} }
} }
let data = trajectory_generator(Box::new(&net), 100, 200.0); let data = trajectory_generator(&net, 100, 200.0);
let mle = MLE{}; let mle = MLE{};
let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None); let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
@ -123,7 +123,7 @@ fn learn_ternary_cim_MLE_no_parents() {
} }
} }
let data = trajectory_generator(Box::new(&net), 100, 200.0); let data = trajectory_generator(&net, 100, 200.0);
let mle = MLE{}; let mle = MLE{};
let (CIM, M, T) = mle.fit(Box::new(&net), &data, 0, None); let (CIM, M, T) = mle.fit(Box::new(&net), &data, 0, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);

@ -36,7 +36,7 @@ fn run_sampling() {
} }
} }
let data = trajectory_generator(Box::new(&net), 4, 1.0); let data = trajectory_generator(&net, 4, 1.0);
assert_eq!(4, data.trajectories.len()); assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);

Loading…
Cancel
Save