parameter_learning refactor

pull/19/head
AlessandroBregoli 3 years ago
parent a191bdef1c
commit 5c816ebba7
  1. 19
      src/parameter_learning.rs
  2. 8
      tests/parameter_learning.rs

@ -5,6 +5,16 @@ use ndarray::prelude::*;
use ndarray::{concatenate, Slice}; use ndarray::{concatenate, Slice};
use std::collections::BTreeSet; use std::collections::BTreeSet;
pub trait ParameterLearning{
fn fit(
&self,
net: Box<&dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>);
}
pub fn sufficient_statistics( pub fn sufficient_statistics(
net: Box<&dyn network::Network>, net: Box<&dyn network::Network>,
dataset: &tools::Dataset, dataset: &tools::Dataset,
@ -66,7 +76,12 @@ pub fn sufficient_statistics(
} }
pub fn MLE( pub struct MLE {}
impl ParameterLearning for MLE {
fn fit(
&self,
net: Box<&dyn network::Network>, net: Box<&dyn network::Network>,
dataset: &tools::Dataset, dataset: &tools::Dataset,
node: usize, node: usize,
@ -96,4 +111,4 @@ pub fn MLE(
}); });
return (CIM, M, T); return (CIM, M, T);
} }
}

@ -42,8 +42,8 @@ fn learn_binary_cim_MLE() {
} }
let data = trajectory_generator(Box::new(&net), 100, 100.0); let data = trajectory_generator(Box::new(&net), 100, 100.0);
let mle = MLE{};
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]); assert_eq!(CIM.shape(), [2, 2, 2]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
@ -83,8 +83,8 @@ fn learn_ternary_cim_MLE() {
} }
let data = trajectory_generator(Box::new(&net), 100, 200.0); let data = trajectory_generator(Box::new(&net), 100, 200.0);
let mle = MLE{};
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]); assert_eq!(CIM.shape(), [3, 3, 3]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);

Loading…
Cancel
Save