pull/42/head
AlessandroBregoli 3 years ago
parent 5044a88b6d
commit 8ca93c931b
  1. 2
      src/structure_learning.rs
  2. 8
      src/structure_learning/score_based_algorithm.rs
  3. 4
      tests/structure_learning.rs

@ -4,7 +4,7 @@ use crate::network;
use crate::tools;
pub trait StructureLearningAlgorithm {
fn fit<T, >(&self, net: T, dataset: &tools::Dataset) -> T
fn fit_transform<T, >(&self, net: T, dataset: &tools::Dataset) -> T
where
T: network::Network;
}

@ -1,11 +1,7 @@
use crate::params;
use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::tools;
use crate::{network, parameter_learning};
use ndarray::prelude::*;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use crate::network;
use std::collections::BTreeSet;
pub struct HillClimbing<S: ScoreFunction> {
@ -19,7 +15,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
}
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where
T: network::Network,
{

@ -87,7 +87,7 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),);
let net = sl.fit(net, &data);
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
}
@ -164,7 +164,7 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) {
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),);
let net = sl.fit(net, &data);
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));

Loading…
Cancel
Save