Changed parameter_learning module to use generics instead of Box for the network object

Ref: pull/19/head
Author: Alessandro Bregoli (3 years ago)
Parent: ef04e04d1e
Commit: cc8071ca07
Changed files:
  src/parameter_learning.rs (16 changed lines)
  tests/parameter_learning.rs (6 changed lines)

src/parameter_learning.rs

@@ -6,17 +6,17 @@ use ndarray::{concatenate, Slice};
 use std::collections::BTreeSet;
 pub trait ParameterLearning{
-    fn fit(
+    fn fit<T:network::Network>(
         &self,
-        net: Box<&dyn network::Network>,
+        net: &T,
         dataset: &tools::Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> (Array3<f64>, Array3<usize>, Array2<f64>);
 }
-pub fn sufficient_statistics(
-    net: Box<&dyn network::Network>,
+pub fn sufficient_statistics<T:network::Network>(
+    net: &T,
     dataset: &tools::Dataset,
     node: usize,
     parent_set: &BTreeSet<usize>
@@ -80,9 +80,9 @@ pub struct MLE {}
 impl ParameterLearning for MLE {
-    fn fit(
+    fn fit<T: network::Network>(
         &self,
-        net: Box<&dyn network::Network>,
+        net: &T,
         dataset: &tools::Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
@@ -119,9 +119,9 @@ pub struct BayesianApproach {
 }
 impl ParameterLearning for BayesianApproach {
-    fn fit(
+    fn fit<T: network::Network>(
         &self,
-        net: Box<&dyn network::Network>,
+        net: &T,
         dataset: &tools::Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
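
For context: the patch replaces dynamic dispatch through a boxed trait-object reference (Box<&dyn network::Network>) with a generic type parameter bounded by network::Network. The sketch below is a minimal, self-contained illustration of the two signatures; the Network trait, Dataset struct, and ToyNet type here are stand-ins invented for the example, not the crate's real definitions.

use std::collections::BTreeSet;

// Stand-in for network::Network (illustrative only).
trait Network {
    fn get_number_of_nodes(&self) -> usize;
}

// Stand-in for tools::Dataset (illustrative only).
struct Dataset;

// Before: the network is passed as a boxed trait-object reference,
// so every method call on `net` goes through dynamic dispatch.
fn fit_dyn(net: Box<&dyn Network>, _dataset: &Dataset, _node: usize,
           _parent_set: Option<BTreeSet<usize>>) -> usize {
    net.get_number_of_nodes()
}

// After: the network is a generic parameter bounded by the trait,
// so calls are statically dispatched and callers pass `&net` directly.
fn fit_generic<T: Network>(net: &T, _dataset: &Dataset, _node: usize,
                           _parent_set: Option<BTreeSet<usize>>) -> usize {
    net.get_number_of_nodes()
}

struct ToyNet;
impl Network for ToyNet {
    fn get_number_of_nodes(&self) -> usize { 1 }
}

fn main() {
    let (net, data) = (ToyNet, Dataset);
    assert_eq!(fit_dyn(Box::new(&net), &data, 0, None), 1);
    assert_eq!(fit_generic(&net, &data, 0, None), 1);
}

One trade-off of this style: a generic fit method makes ParameterLearning no longer object-safe, so the trait itself can no longer be used behind dyn; in exchange, calls are statically dispatched and call sites no longer need to box the reference.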

tests/parameter_learning.rs

@@ -43,7 +43,7 @@ fn learn_binary_cim_MLE() {
     let data = trajectory_generator(&net, 100, 100.0);
     let mle = MLE{};
-    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
+    let (CIM, M, T) = mle.fit(&net, &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [2, 2, 2]);
     assert!(CIM.abs_diff_eq(&arr3(&[
@@ -84,7 +84,7 @@ fn learn_ternary_cim_MLE() {
     let data = trajectory_generator(&net, 100, 200.0);
     let mle = MLE{};
-    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
+    let (CIM, M, T) = mle.fit(&net, &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [3, 3, 3]);
     assert!(CIM.abs_diff_eq(&arr3(&[
@@ -125,7 +125,7 @@ fn learn_ternary_cim_MLE_no_parents() {
     let data = trajectory_generator(&net, 100, 200.0);
     let mle = MLE{};
-    let (CIM, M, T) = mle.fit(Box::new(&net), &data, 0, None);
+    let (CIM, M, T) = mle.fit(&net, &data, 0, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [1, 3, 3]);
     assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
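
A matching call-site sketch, again with invented stand-ins rather than the real test fixtures: because T is inferred from the argument, the tests only change from mle.fit(Box::new(&net), ...) to mle.fit(&net, ...).

// Call-site sketch with illustrative stand-ins (not the crate's real types).
trait Network { fn get_number_of_nodes(&self) -> usize; }

struct ToyNet;
impl Network for ToyNet {
    fn get_number_of_nodes(&self) -> usize { 3 }
}

struct MLE;
impl MLE {
    // Generic `fit` in the spirit of the patched ParameterLearning::fit,
    // with a simplified return value.
    fn fit<T: Network>(&self, net: &T, node: usize) -> usize {
        net.get_number_of_nodes() + node
    }
}

fn main() {
    let net = ToyNet;
    let mle = MLE;
    // Before the patch this line would have read `mle.fit(Box::new(&net), 1)`;
    // with the generic bound, `T` is inferred and `&net` is passed directly.
    assert_eq!(mle.fit(&net, 1), 4);
}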
