pull/42/head
AlessandroBregoli 3 years ago
parent 808bc0098c
commit 5044a88b6d
  1. 2
      src/structure_learning/mod.rs
  2. 2
      src/structure_learning/score_based_algorithm.rs
  3. 4
      tests/structure_learning.rs

@ -4,7 +4,7 @@ use crate::network;
use crate::tools;
pub trait StructureLearningAlgorithm {
fn call<T, >(&self, net: T, dataset: &tools::Dataset) -> T
fn fit<T, >(&self, net: T, dataset: &tools::Dataset) -> T
where
T: network::Network;
}

@ -19,7 +19,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
}
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn call<T>(&self, net: T, dataset: &tools::Dataset) -> T
fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
where
T: network::Network,
{

@ -87,7 +87,7 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),);
let net = sl.call(net, &data);
let net = sl.fit(net, &data);
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
}
@ -164,7 +164,7 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) {
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),);
let net = sl.call(net, &data);
let net = sl.fit(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));

Loading…
Cancel
Save