diff --git a/rustfmt.toml b/rustfmt.toml index 3e7fb50..b6f1257 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -33,4 +33,7 @@ newline_style = "Unix" #error_on_unformatted = true # Files to ignore like third party code which is formatted upstream. -#ignore = [] +# Ignoring tests is a temporary measure due some issues regarding rank-3 tensors +ignore = [ + "tests/" +] diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 0409402..5de02d7 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -32,8 +32,14 @@ fn learn_binary_cim(pl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], ])) ); } @@ -45,7 +51,16 @@ fn learn_binary_cim(pl: T) { }; assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( - &arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), + &arr3(&[ + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], + ]), 0.1 )); } @@ -76,11 +91,13 @@ fn learn_ternary_cim(pl: T) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -90,9 +107,21 @@ fn learn_ternary_cim(pl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } @@ -105,9 +134,21 @@ fn learn_ternary_cim(pl: T) { assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( &arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ]), 0.1 )); @@ -139,11 +180,13 @@ fn learn_ternary_cim_no_parents(pl: T) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ] + ])) ); } } @@ -153,9 +196,21 @@ fn learn_ternary_cim_no_parents(pl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } @@ -167,7 +222,13 @@ fn learn_ternary_cim_no_parents(pl: T) { }; assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( - &arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), + &arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ]), 0.1 )); } @@ -204,11 +265,13 @@ fn learn_mixed_discrete_cim(pl: T) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -218,9 +281,21 @@ fn learn_mixed_discrete_cim(pl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index ee5109e..81a4ed3 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -70,11 +70,13 @@ fn check_compatibility_between_dataset_and_network { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -84,9 +86,21 @@ fn check_compatibility_between_dataset_and_network(sl: T) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -137,9 +153,21 @@ fn learn_ternary_net_2_nodes(sl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } @@ -186,11 +214,13 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -200,9 +230,21 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } @@ -324,12 +366,30 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() pub fn chi_square_compare_matrices() { let i: usize = 1; let M1 = arr3(&[ - [[0, 2, 3], [4, 0, 6], [7, 8, 0]], - [[0, 12, 90], [3, 0, 40], [6, 40, 0]], - [[0, 2, 3], [4, 0, 6], [44, 66, 0]], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 12, 90], + [ 3, 0, 40], + [ 6, 40, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], ]); let j: usize = 0; - let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); + let M2 = arr3(&[ + [ + [ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0] + ], + ]); let chi_sq = ChiSquare::new(0.1); assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -338,12 +398,28 @@ pub fn chi_square_compare_matrices() { pub fn chi_square_compare_matrices_2() { let i: usize = 1; let M1 = arr3(&[ - [[0, 2, 3], [4, 0, 6], [7, 8, 0]], - [[0, 20, 30], [40, 0, 60], [70, 80, 0]], - [[0, 2, 3], [4, 0, 6], [44, 66, 0]], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 20, 30], + [ 40, 0, 60], + [ 70, 80, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], ]); let j: usize = 0; - let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); + let M2 = arr3(&[ + [[ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); let chi_sq = ChiSquare::new(0.1); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -352,12 +428,30 @@ pub fn chi_square_compare_matrices_2() { pub fn chi_square_compare_matrices_3() { let i: usize = 1; let M1 = arr3(&[ - [[0, 2, 3], [4, 0, 6], [7, 8, 0]], - [[0, 21, 31], [41, 0, 59], [71, 79, 0]], - [[0, 2, 3], [4, 0, 6], [44, 66, 0]], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 21, 31], + [ 41, 0, 59], + [ 71, 79, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], ]); let j: usize = 0; - let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); + let M2 = arr3(&[ + [ + [ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0] + ], + ]); let chi_sq = ChiSquare::new(0.1); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } diff --git a/tests/tools.rs b/tests/tools.rs index f7435f7..589b04e 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -29,15 +29,26 @@ fn run_sampling() { match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); + param.set_cim(arr3(&[ + [ + [-3.0, 3.0], + [2.0, -2.0] + ], + ])); } } match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { param.set_cim(arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], ])); } }