parameter_learning refactor

pull/19/head
AlessandroBregoli 3 years ago
parent a191bdef1c
commit 5c816ebba7
  1. 21
      src/parameter_learning.rs
  2. 8
      tests/parameter_learning.rs

@@ -5,6 +5,16 @@ use ndarray::prelude::*;
use ndarray::{concatenate, Slice};
use std::collections::BTreeSet;
pub trait ParameterLearning{
fn fit(
&self,
net: Box<&dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>);
}
pub fn sufficient_statistics(
net: Box<&dyn network::Network>,
dataset: &tools::Dataset,
@@ -66,12 +76,17 @@ pub fn sufficient_statistics(
}
pub struct MLE {}
impl ParameterLearning for MLE {
fn fit(
&self,
net: Box<&dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network.
@@ -95,5 +110,5 @@ pub fn MLE(
C.diag_mut().assign(&diag);
});
return (CIM, M, T);
}
}

@@ -42,8 +42,8 @@ fn learn_binary_cim_MLE() {
}
let data = trajectory_generator(Box::new(&net), 100, 100.0);
let mle = MLE{};
let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);
@@ -83,8 +83,8 @@ fn learn_ternary_cim_MLE() {
}
let data = trajectory_generator(Box::new(&net), 100, 200.0);
let mle = MLE{};
let (CIM, M, T) = mle.fit(Box::new(&net), &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2);

Loading…
Cancel
Save