Added tests for the learning of parameters using uniform graph and parameters generators as complementary to their handcrafted version

pull/85/head
Meliurwen 2 years ago
parent e08d12ac1f
commit 430033afdb
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 203
      reCTBN/tests/parameter_learning.rs

@ -6,6 +6,7 @@ use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess; use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::*; use reCTBN::parameter_learning::*;
use reCTBN::params; use reCTBN::params;
use reCTBN::params::Params::DiscreteStatesContinousTime;
use reCTBN::tools::*; use reCTBN::tools::*;
use utils::*; use utils::*;
@ -66,18 +67,78 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
)); ));
} }
fn generate_nodes(
net: &mut CtbnNetwork,
nodes_cardinality: usize,
nodes_domain_cardinality: usize
) {
for node_label in 0..nodes_cardinality {
net.add_node(
generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
}
fn learn_binary_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 2);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(1) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 1, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test] #[test]
fn learn_binary_cim_MLE() { fn learn_binary_cim_MLE() {
let mle = MLE {}; let mle = MLE {};
learn_binary_cim(mle); learn_binary_cim(mle);
} }
#[test]
fn learn_binary_cim_MLE_gen() {
let mle = MLE {};
learn_binary_cim_gen(mle);
}
#[test] #[test]
fn learn_binary_cim_BA() { fn learn_binary_cim_BA() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_binary_cim(ba); learn_binary_cim(ba);
} }
#[test]
fn learn_binary_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_binary_cim_gen(ba);
}
fn learn_ternary_cim<T: ParameterLearning>(pl: T) { fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
@ -155,18 +216,63 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
)); ));
} }
fn learn_ternary_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
4.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(1) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 1, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test] #[test]
fn learn_ternary_cim_MLE() { fn learn_ternary_cim_MLE() {
let mle = MLE {}; let mle = MLE {};
learn_ternary_cim(mle); learn_ternary_cim(mle);
} }
#[test]
fn learn_ternary_cim_MLE_gen() {
let mle = MLE {};
learn_ternary_cim_gen(mle);
}
#[test] #[test]
fn learn_ternary_cim_BA() { fn learn_ternary_cim_BA() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim(ba); learn_ternary_cim(ba);
} }
#[test]
fn learn_ternary_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_gen(ba);
}
fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) { fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
@ -234,18 +340,63 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
)); ));
} }
fn learn_ternary_cim_no_parents_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(0) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 0, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test] #[test]
fn learn_ternary_cim_no_parents_MLE() { fn learn_ternary_cim_no_parents_MLE() {
let mle = MLE {}; let mle = MLE {};
learn_ternary_cim_no_parents(mle); learn_ternary_cim_no_parents(mle);
} }
#[test]
fn learn_ternary_cim_no_parents_MLE_gen() {
let mle = MLE {};
learn_ternary_cim_no_parents_gen(mle);
}
#[test] #[test]
fn learn_ternary_cim_no_parents_BA() { fn learn_ternary_cim_no_parents_BA() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_no_parents(ba); learn_ternary_cim_no_parents(ba);
} }
#[test]
fn learn_ternary_cim_no_parents_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_no_parents_gen(ba);
}
fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) { fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
@ -432,14 +583,66 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
)); ));
} }
fn learn_mixed_discrete_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_node(
generate_discrete_time_continous_node(
String::from("3"),
4
)
).unwrap();
net.add_edge(0, 1);
net.add_edge(0, 2);
net.add_edge(1, 2);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..8.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(2) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 2, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.2
)
);
}
#[test] #[test]
fn learn_mixed_discrete_cim_MLE() { fn learn_mixed_discrete_cim_MLE() {
let mle = MLE {}; let mle = MLE {};
learn_mixed_discrete_cim(mle); learn_mixed_discrete_cim(mle);
} }
#[test]
fn learn_mixed_discrete_cim_MLE_gen() {
let mle = MLE {};
learn_mixed_discrete_cim_gen(mle);
}
#[test] #[test]
fn learn_mixed_discrete_cim_BA() { fn learn_mixed_discrete_cim_BA() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_mixed_discrete_cim(ba); learn_mixed_discrete_cim(ba);
} }
#[test]
fn learn_mixed_discrete_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_mixed_discrete_cim_gen(ba);
}

Loading…
Cancel
Save