pull/42/head
AlessandroBregoli 3 years ago
parent 808bc0098c
commit 5044a88b6d
  1. 2
      src/structure_learning/mod.rs
  2. 2
      src/structure_learning/score_based_algorithm.rs
  3. 4
      tests/structure_learning.rs

@ -4,7 +4,7 @@ use crate::network;
use crate::tools; use crate::tools;
pub trait StructureLearningAlgorithm { pub trait StructureLearningAlgorithm {
fn call<T, >(&self, net: T, dataset: &tools::Dataset) -> T fn fit<T, >(&self, net: T, dataset: &tools::Dataset) -> T
where where
T: network::Network; T: network::Network;
} }

@ -19,7 +19,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
} }
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn call<T>(&self, net: T, dataset: &tools::Dataset) -> T fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
where where
T: network::Network, T: network::Network,
{ {

@ -87,7 +87,7 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),);
let net = sl.call(net, &data); let net = sl.fit(net, &data);
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
} }
@ -164,7 +164,7 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) {
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),);
let net = sl.call(net, &data); let net = sl.fit(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));

Loading…
Cancel
Save