|
|
|
@ -54,6 +54,54 @@ fn simple_bic() { |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm> (sl: T) { |
|
|
|
|
let mut net = CtbnNetwork::init(); |
|
|
|
|
let n1 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
|
|
|
|
.unwrap(); |
|
|
|
|
let n2 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
|
|
|
|
.unwrap(); |
|
|
|
|
net.add_edge(n1, n2); |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n1).params { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
|
|
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
|
[0.4, 0.6, -1.0]]]))); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n2).params { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
|
|
|
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
|
|
|
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
|
|
|
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
|
|
|
|
]))); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); |
|
|
|
|
|
|
|
|
|
let mut net = CtbnNetwork::init(); |
|
|
|
|
let _n1 = net |
|
|
|
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
|
|
|
|
.unwrap(); |
|
|
|
|
let net = sl.fit_transform(net, &data); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
#[should_panic] |
|
|
|
|
pub fn check_compatibility_between_dataset_and_network_hill_climbing() { |
|
|
|
|
let ll = LogLikelihood::init(1, 1.0); |
|
|
|
|
let hl = HillClimbing::init(ll, None); |
|
|
|
|
check_compatibility_between_dataset_and_network(hl); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) { |
|
|
|
|
let mut net = CtbnNetwork::init(); |
|
|
|
|
let n1 = net |
|
|
|
|