From ef04e04d1e4f4ddeb47f8945146b57f05e66c31e Mon Sep 17 00:00:00 2001 From: Alessandro Bregoli Date: Fri, 18 Mar 2022 07:45:57 +0100 Subject: [PATCH] Changed trajectory_generator to use generics instead of Box for the network object --- src/tools.rs | 4 ++-- tests/parameter_learning.rs | 6 +++--- tests/tools.rs | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tools.rs b/src/tools.rs index a719bb9..27438f9 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -13,8 +13,8 @@ pub struct Dataset { pub trajectories: Vec, } -pub fn trajectory_generator( - net: Box<&dyn network::Network>, +pub fn trajectory_generator( + net: &T, n_trajectories: u64, t_end: f64, ) -> Dataset { diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index c4b2a67..c138f64 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -41,7 +41,7 @@ fn learn_binary_cim_MLE() { } } - let data = trajectory_generator(Box::new(&net), 100, 100.0); + let data = trajectory_generator(&net, 100, 100.0); let mle = MLE{}; let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); @@ -82,7 +82,7 @@ fn learn_ternary_cim_MLE() { } } - let data = trajectory_generator(Box::new(&net), 100, 200.0); + let data = trajectory_generator(&net, 100, 200.0); let mle = MLE{}; let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); @@ -123,7 +123,7 @@ fn learn_ternary_cim_MLE_no_parents() { } } - let data = trajectory_generator(Box::new(&net), 100, 200.0); + let data = trajectory_generator(&net, 100, 200.0); let mle = MLE{}; let (CIM, M, T) = mle.fit(Box::new(&net), &data, 0, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); diff --git a/tests/tools.rs b/tests/tools.rs index efeef2e..802e2fe 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -36,7 +36,7 @@ fn run_sampling() { } } - let data = trajectory_generator(Box::new(&net), 4, 1.0); + let data = trajectory_generator(&net, 4, 1.0); assert_eq!(4, data.trajectories.len()); assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);