diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs index b5590ed..7537483 100644 --- a/src/structure_learning/score_based_algorithm.rs +++ b/src/structure_learning/score_based_algorithm.rs @@ -20,6 +20,10 @@ impl StructureLearningAlgorithm for HillClimbing { where T: network::Network, { + if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { + panic!("Dataset and Network must have the same number of variables.") + } + let mut net = net; let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); net.initialize_adj_matrix(); diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index d24170a..c3482cc 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -54,6 +54,54 @@ fn simple_bic() { } + + +fn check_compatibility_between_dataset_and_network (sl: T) { + let mut net = CtbnNetwork::init(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1).params { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2).params { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ]))); + } + } + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); + + let mut net = CtbnNetwork::init(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .unwrap(); + let net = sl.fit_transform(net, &data); +} + + +#[test] +#[should_panic] +pub fn check_compatibility_between_dataset_and_network_hill_climbing() { + let ll = LogLikelihood::init(1, 1.0); + let hl = HillClimbing::init(ll, None); + check_compatibility_between_dataset_and_network(hl); +} + fn learn_ternary_net_2_nodes (sl: T) { let mut net = CtbnNetwork::init(); let n1 = net