Changed the return type for ParameterLearning::fit from tuple to Param

pull/48/head
AlessandroBregoli 2 years ago
parent e091cc4d2e
commit f5756f71d3
  1. 41
      src/parameter_learning.rs
  2. 15
      src/params.rs
  3. 51
      src/structure_learning/hypothesis_test.rs
  4. 36
      tests/parameter_learning.rs

@ -12,7 +12,7 @@ pub trait ParameterLearning{
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>);
) -> Params;
}
pub fn sufficient_statistics<T:network::Network>(
@ -84,8 +84,7 @@ impl ParameterLearning for MLE {
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
) -> Params {
//Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set {
@ -107,7 +106,21 @@ impl ParameterLearning for MLE {
.for_each(|(mut C, diag)| {
C.diag_mut().assign(&diag);
});
return (CIM, M, T);
let mut n: Params = net.get_node(node).clone();
match n {
Params::DiscreteStatesContinousTime(ref mut dsct) => {
dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M);
dsct.set_residence_time(T);
}
};
return n;
}
}
@ -123,7 +136,7 @@ impl ParameterLearning for BayesianApproach {
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
) -> Params {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network.
@ -150,7 +163,21 @@ impl ParameterLearning for BayesianApproach {
.for_each(|(mut C, diag)| {
C.diag_mut().assign(&diag);
});
return (CIM, M, T);
let mut n: Params = net.get_node(node).clone();
match n {
Params::DiscreteStatesContinousTime(ref mut dsct) => {
dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M);
dsct.set_residence_time(T);
}
};
return n;
}
}
@ -166,7 +193,7 @@ impl<P: ParameterLearning> Cache<P> {
net: &T,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) {
) -> Params {
self.parameter_learning.fit(net, &self.dataset, node, parent_set)
}
}

@ -55,7 +55,8 @@ pub trait ParamsTrait {
}
/// The Params enum is the core element for building different types of nodes. The goal is to
/// define all the supported type of parameters.
/// define all the supported type of Parameters
#[derive(Clone)]
#[enum_dispatch]
pub enum Params {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
@ -72,11 +73,12 @@ pub enum Params {
/// realization of the parent set
/// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set
#[derive(Clone)]
pub struct DiscreteStatesContinousTimeParams {
label: String,
domain: BTreeSet<String>,
cim: Option<Array3<f64>>,
transitions: Option<Array3<u64>>,
transitions: Option<Array3<usize>>,
residence_time: Option<Array2<f64>>,
}
@ -112,14 +114,19 @@ impl DiscreteStatesContinousTimeParams {
}
///Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
self.cim = Some(cim);
}
///Getter function for transitions
pub fn get_transitions(&self) -> &Option<Array3<u64>> {
pub fn get_transitions(&self) -> &Option<Array3<usize>> {
&self.transitions
}
///Setter function for transitions
pub fn set_transitions(&mut self, transitions: Array3<u64>) {
pub fn set_transitions(&mut self, transitions: Array3<usize>) {
self.transitions = Some(transitions);
}

@ -5,46 +5,39 @@ use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::network;
use crate::parameter_learning;
use crate::params::ParamsTrait;
use crate::params::*;
use std::collections::BTreeSet;
pub trait HypothesisTest {
fn call<T, P>(
&self,
net: &T,
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
cache: &mut parameter_learning::Cache<P>
cache: &mut parameter_learning::Cache<P>,
) -> bool
where
T: network::Network,
P: parameter_learning::ParameterLearning;
}
pub struct ChiSquare {
alpha: f64,
}
pub struct F {
}
pub struct F {}
impl ChiSquare {
pub fn new( alpha: f64) -> ChiSquare {
ChiSquare {
alpha
}
pub fn new(alpha: f64) -> ChiSquare {
ChiSquare { alpha }
}
pub fn compare_matrices(
&self,
i: usize,
M1: &Array3<usize>,
j: usize,
M2: &Array3<usize>
M2: &Array3<usize>,
) -> bool {
// Bregoli, A., Scutari, M. and Stella, F., 2021.
// A constraint-based algorithm for the structural learning of
@ -87,7 +80,7 @@ impl ChiSquare {
// ===== 2 1
// x'ϵVal /X \
// \ i/
let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1);
let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
println!("M1: {:?}", M1);
println!("M2: {:?}", M2);
println!("L*M1: {:?}", (L * &M1));
@ -109,24 +102,38 @@ impl HypothesisTest for ChiSquare {
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
cache: &mut parameter_learning::Cache<P>
cache: &mut parameter_learning::Cache<P>,
) -> bool
where
T: network::Network,
P: parameter_learning::ParameterLearning {
P: parameter_learning::ParameterLearning,
{
// Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM
// di dimensione nxn
// (CIM, M, T)
let ( _, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone()));
let P_small = match cache.fit(net, child_node, Some(separation_set.clone())){
Params::DiscreteStatesContinousTime(node) => node
};
//
let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node);
let ( _, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone()));
let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())){
Params::DiscreteStatesContinousTime(node) => node
};
// Commentare qui
let partial_cardinality_product:usize = extended_separation_set.iter().take_while(|x| **x != parent_node).map(|x| net.get_node(*x).get_reserved_space_as_parent()).product();
for idx_M_big in 0..M_big.shape()[0] {
let idx_M_small: usize = idx_M_big%partial_cardinality_product + (idx_M_big/(partial_cardinality_product*net.get_node(parent_node).get_reserved_space_as_parent()))*partial_cardinality_product;
if ! self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) {
let partial_cardinality_product: usize = extended_separation_set
.iter()
.take_while(|x| **x != parent_node)
.map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product();
for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
let idx_M_small: usize = idx_M_big % partial_cardinality_product
+ (idx_M_big
/ (partial_cardinality_product
* net.get_node(parent_node).get_reserved_space_as_parent()))
* partial_cardinality_product;
if !self.compare_matrices(idx_M_small, P_small.get_transitions().as_ref().unwrap(), idx_M_big, P_big.get_transitions().as_ref().unwrap()) {
return false;
}
}

@ -40,10 +40,11 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
}
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
assert!(CIM.abs_diff_eq(
let p = match pl.fit(&net, &data, 1, None) {
params::Params::DiscreteStatesContinousTime(p) => p
};
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]),
0.1
));
@ -98,10 +99,11 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
}
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
assert!(CIM.abs_diff_eq(
let p = match pl.fit(&net, &data, 1, None){
params::Params::DiscreteStatesContinousTime(p) => p
};
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
@ -160,10 +162,11 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
}
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 0, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [1, 3, 3]);
assert!(CIM.abs_diff_eq(
let p = match pl.fit(&net, &data, 0, None){
params::Params::DiscreteStatesContinousTime(p) => p
};
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]),
0.1
));
@ -288,10 +291,11 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
}
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 2, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [9, 4, 4]);
assert!(CIM.abs_diff_eq(
let p = match pl.fit(&net, &data, 2, None){
params::Params::DiscreteStatesContinousTime(p) => p
};
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[
[
[-1.0, 0.5, 0.3, 0.2],

Loading…
Cancel
Save