Conformed all rank-3 tensors to the same notation and now `rustfmt` ignores `tests/`

pull/55/head
Meliurwen 2 years ago
parent 780515707c
commit a1c1448da7
  1. 5
      rustfmt.toml
  2. 119
      tests/parameter_learning.rs
  3. 148
      tests/structure_learning.rs
  4. 17
      tests/tools.rs

@ -33,4 +33,7 @@ newline_style = "Unix"
#error_on_unformatted = true #error_on_unformatted = true
# Files to ignore like third party code which is formatted upstream. # Files to ignore like third party code which is formatted upstream.
#ignore = [] # Ignoring tests is a temporary measure due some issues regarding rank-3 tensors
ignore = [
"tests/"
]

@ -32,8 +32,14 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]], [
[[-6.0, 6.0], [2.0, -2.0]], [-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
])) ]))
); );
} }
@ -45,7 +51,16 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), &arr3(&[
[
[-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
]),
0.1 0.1
)); ));
} }
@ -76,11 +91,13 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0], [-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0] [0.4, 0.6, -1.0]
]])) ],
]))
); );
} }
} }
@ -90,9 +107,21 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -105,9 +134,21 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[ &arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]), ]),
0.1 0.1
)); ));
@ -139,11 +180,13 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0], [-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0] [0.4, 0.6, -1.0]
]])) ]
]))
); );
} }
} }
@ -153,9 +196,21 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -167,7 +222,13 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), &arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
],
]),
0.1 0.1
)); ));
} }
@ -204,11 +265,13 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0], [-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0] [0.4, 0.6, -1.0]
]])) ],
]))
); );
} }
} }
@ -218,9 +281,21 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }

@ -70,11 +70,13 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0], [-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0] [0.4, 0.6, -1.0]
]])) ],
]))
); );
} }
} }
@ -84,9 +86,21 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -123,11 +137,13 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0], [-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0] [0.4, 0.6, -1.0]
]])) ],
]))
); );
} }
} }
@ -137,9 +153,21 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -186,11 +214,13 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0], [-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0] [0.4, 0.6, -1.0]
]])) ],
]))
); );
} }
} }
@ -200,9 +230,21 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -324,12 +366,30 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
pub fn chi_square_compare_matrices() { pub fn chi_square_compare_matrices() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[0, 2, 3], [4, 0, 6], [7, 8, 0]], [
[[0, 12, 90], [3, 0, 40], [6, 40, 0]], [ 0, 2, 3],
[[0, 2, 3], [4, 0, 6], [44, 66, 0]], [ 4, 0, 6],
[ 7, 8, 0]
],
[
[0, 12, 90],
[ 3, 0, 40],
[ 6, 40, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]
],
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); let M2 = arr3(&[
[
[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]
],
]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
} }
@ -338,12 +398,28 @@ pub fn chi_square_compare_matrices() {
pub fn chi_square_compare_matrices_2() { pub fn chi_square_compare_matrices_2() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[0, 2, 3], [4, 0, 6], [7, 8, 0]], [
[[0, 20, 30], [40, 0, 60], [70, 80, 0]], [ 0, 2, 3],
[[0, 2, 3], [4, 0, 6], [44, 66, 0]], [ 4, 0, 6],
[ 7, 8, 0]
],
[
[0, 20, 30],
[ 40, 0, 60],
[ 70, 80, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]
],
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); let M2 = arr3(&[
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
} }
@ -352,12 +428,30 @@ pub fn chi_square_compare_matrices_2() {
pub fn chi_square_compare_matrices_3() { pub fn chi_square_compare_matrices_3() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[0, 2, 3], [4, 0, 6], [7, 8, 0]], [
[[0, 21, 31], [41, 0, 59], [71, 79, 0]], [ 0, 2, 3],
[[0, 2, 3], [4, 0, 6], [44, 66, 0]], [ 4, 0, 6],
[ 7, 8, 0]
],
[
[0, 21, 31],
[ 41, 0, 59],
[ 71, 79, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]
],
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); let M2 = arr3(&[
[
[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]
],
]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
} }

@ -29,15 +29,26 @@ fn run_sampling() {
match &mut net.get_node_mut(n1) { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); param.set_cim(arr3(&[
[
[-3.0, 3.0],
[2.0, -2.0]
],
]));
} }
} }
match &mut net.get_node_mut(n2) { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]], [
[[-6.0, 6.0], [2.0, -2.0]], [-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
])); ]));
} }
} }

Loading…
Cancel
Save