From 8ca93c931b14a90718312ba383403fed97cd0c2f Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 14 Apr 2022 10:04:09 +0200 Subject: [PATCH] Refactor --- src/{structure_learning/mod.rs => structure_learning.rs} | 2 +- src/structure_learning/score_based_algorithm.rs | 8 ++------ tests/structure_learning.rs | 4 ++-- 3 files changed, 5 insertions(+), 9 deletions(-) rename src/{structure_learning/mod.rs => structure_learning.rs} (70%) diff --git a/src/structure_learning/mod.rs b/src/structure_learning.rs similarity index 70% rename from src/structure_learning/mod.rs rename to src/structure_learning.rs index a335101..8ba91df 100644 --- a/src/structure_learning/mod.rs +++ b/src/structure_learning.rs @@ -4,7 +4,7 @@ use crate::network; use crate::tools; pub trait StructureLearningAlgorithm { - fn fit(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where T: network::Network; } diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs index 63620fe..0a23c36 100644 --- a/src/structure_learning/score_based_algorithm.rs +++ b/src/structure_learning/score_based_algorithm.rs @@ -1,11 +1,7 @@ -use crate::params; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; use crate::tools; -use crate::{network, parameter_learning}; -use ndarray::prelude::*; -use rand::prelude::*; -use rand_chacha::ChaCha8Rng; +use crate::network; use std::collections::BTreeSet; pub struct HillClimbing { @@ -19,7 +15,7 @@ impl HillClimbing { } impl StructureLearningAlgorithm for HillClimbing { - fn fit(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where T: network::Network, { diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 4ce89a3..25ce1e8 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -87,7 +87,7 @@ fn learn_ternary_net_2_nodes (sl: T) { let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); - let net = sl.fit(net, &data); + let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } @@ -164,7 +164,7 @@ fn learn_mixed_discrete_net_3_nodes (sl: T) { let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); - let net = sl.fit(net, &data); + let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));