Merge pull request #42 from AlessandroBregoli/9-score-based-algorithm
9 score based algorithmpull/43/head
commit
c2d44e7332
@ -0,0 +1,10 @@ |
|||||||
|
pub mod score_function; |
||||||
|
pub mod score_based_algorithm; |
||||||
|
use crate::network; |
||||||
|
use crate::tools; |
||||||
|
|
||||||
|
pub trait StructureLearningAlgorithm { |
||||||
|
fn fit_transform<T, >(&self, net: T, dataset: &tools::Dataset) -> T |
||||||
|
where |
||||||
|
T: network::Network; |
||||||
|
} |
@ -0,0 +1,83 @@ |
|||||||
|
use crate::network; |
||||||
|
use crate::structure_learning::score_function::ScoreFunction; |
||||||
|
use crate::structure_learning::StructureLearningAlgorithm; |
||||||
|
use crate::tools; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
pub struct HillClimbing<S: ScoreFunction> { |
||||||
|
score_function: S, |
||||||
|
max_parent_set: Option<usize>, |
||||||
|
} |
||||||
|
|
||||||
|
impl<S: ScoreFunction> HillClimbing<S> { |
||||||
|
pub fn init(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> { |
||||||
|
HillClimbing { |
||||||
|
score_function, |
||||||
|
max_parent_set, |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { |
||||||
|
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T |
||||||
|
where |
||||||
|
T: network::Network, |
||||||
|
{ |
||||||
|
//Check the coherence between dataset and network
|
||||||
|
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { |
||||||
|
panic!("Dataset and Network must have the same number of variables.") |
||||||
|
} |
||||||
|
|
||||||
|
//Make the network mutable.
|
||||||
|
let mut net = net; |
||||||
|
//Check if the max_parent_set constraint is present.
|
||||||
|
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); |
||||||
|
//Reset the adj matrix
|
||||||
|
net.initialize_adj_matrix(); |
||||||
|
//Iterate over each node to learn their parent set.
|
||||||
|
for node in net.get_node_indices() { |
||||||
|
//Initialize an empty parent set.
|
||||||
|
let mut parent_set: BTreeSet<usize> = BTreeSet::new(); |
||||||
|
//Compute the score for the empty parent set
|
||||||
|
let mut current_score = self.score_function.call(&net, node, &parent_set, dataset); |
||||||
|
//Set the old score to -\infty.
|
||||||
|
let mut old_score = f64::NEG_INFINITY; |
||||||
|
//Iterate until convergence
|
||||||
|
while current_score > old_score { |
||||||
|
//Save the current_score.
|
||||||
|
old_score = current_score; |
||||||
|
//Iterate over each node.
|
||||||
|
for parent in net.get_node_indices() { |
||||||
|
//Continue if the parent and the node are the same.
|
||||||
|
if parent == node { |
||||||
|
continue; |
||||||
|
} |
||||||
|
//Try to remove parent from the parent_set.
|
||||||
|
let is_removed = parent_set.remove(&parent); |
||||||
|
//If parent was not in the parent_set add it.
|
||||||
|
if !is_removed && parent_set.len() < max_parent_set { |
||||||
|
parent_set.insert(parent); |
||||||
|
} |
||||||
|
//Compute the score with the modified parent_set.
|
||||||
|
let tmp_score = self.score_function.call(&net, node, &parent_set, dataset); |
||||||
|
//If tmp_score is worst than current_score revert the change to the parent set
|
||||||
|
if tmp_score < current_score { |
||||||
|
if is_removed { |
||||||
|
parent_set.insert(parent); |
||||||
|
} else { |
||||||
|
parent_set.remove(&parent); |
||||||
|
} |
||||||
|
} |
||||||
|
//Otherwise save the computed score as current_score
|
||||||
|
else { |
||||||
|
current_score = tmp_score; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
//Apply the learned parent_set to the network struct.
|
||||||
|
parent_set.iter().for_each(|p| net.add_edge(*p, node)); |
||||||
|
} |
||||||
|
|
||||||
|
return net; |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,136 @@ |
|||||||
|
use crate::network; |
||||||
|
use crate::parameter_learning; |
||||||
|
use crate::params; |
||||||
|
use crate::tools; |
||||||
|
use ndarray::prelude::*; |
||||||
|
use statrs::function::gamma; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
pub trait ScoreFunction { |
||||||
|
fn call<T>( |
||||||
|
&self, |
||||||
|
net: &T, |
||||||
|
node: usize, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
dataset: &tools::Dataset, |
||||||
|
) -> f64 |
||||||
|
where |
||||||
|
T: network::Network; |
||||||
|
} |
||||||
|
|
||||||
|
pub struct LogLikelihood { |
||||||
|
alpha: usize, |
||||||
|
tau: f64, |
||||||
|
} |
||||||
|
|
||||||
|
impl LogLikelihood { |
||||||
|
pub fn init(alpha: usize, tau: f64) -> LogLikelihood { |
||||||
|
|
||||||
|
//Tau must be >=0.0
|
||||||
|
if tau < 0.0 { |
||||||
|
panic!("tau must be >=0.0"); |
||||||
|
} |
||||||
|
LogLikelihood { alpha, tau } |
||||||
|
} |
||||||
|
|
||||||
|
fn compute_score<T>( |
||||||
|
&self, |
||||||
|
net: &T, |
||||||
|
node: usize, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
dataset: &tools::Dataset, |
||||||
|
) -> (f64, Array3<usize>) |
||||||
|
where |
||||||
|
T: network::Network, |
||||||
|
{
|
||||||
|
//Identify the type of node used
|
||||||
|
match &net.get_node(node).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(_params) => { |
||||||
|
//Compute the sufficient statistics M (number of transistions) and T (residence
|
||||||
|
//time)
|
||||||
|
let (M, T) = |
||||||
|
parameter_learning::sufficient_statistics(net, dataset, node, parent_set); |
||||||
|
|
||||||
|
//Scale alpha accordingly to the size of the parent set
|
||||||
|
let alpha = self.alpha as f64 / M.shape()[0] as f64; |
||||||
|
//Scale tau accordingly to the size of the parent set
|
||||||
|
let tau = self.tau / M.shape()[0] as f64; |
||||||
|
|
||||||
|
//Compute the log likelihood for q
|
||||||
|
let log_ll_q:f64 = M |
||||||
|
.sum_axis(Axis(2)) |
||||||
|
.iter() |
||||||
|
.zip(T.iter()) |
||||||
|
.map(|(m, t)| { |
||||||
|
gamma::ln_gamma(alpha + *m as f64 + 1.0) |
||||||
|
+ (alpha + 1.0) * f64::ln(tau) |
||||||
|
- gamma::ln_gamma(alpha + 1.0) |
||||||
|
- (alpha + *m as f64 + 1.0) * f64::ln(tau + t) |
||||||
|
}) |
||||||
|
.sum(); |
||||||
|
|
||||||
|
//Compute the log likelihood for theta
|
||||||
|
let log_ll_theta: f64 = M.outer_iter() |
||||||
|
.map(|x| x.outer_iter() |
||||||
|
.map(|y| gamma::ln_gamma(alpha)
|
||||||
|
- gamma::ln_gamma(alpha + y.sum() as f64) |
||||||
|
+ y.iter().map(|z|
|
||||||
|
gamma::ln_gamma(alpha + *z as f64)
|
||||||
|
- gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum(); |
||||||
|
(log_ll_theta + log_ll_q, M) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
impl ScoreFunction for LogLikelihood { |
||||||
|
fn call<T>( |
||||||
|
&self, |
||||||
|
net: &T, |
||||||
|
node: usize, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
dataset: &tools::Dataset, |
||||||
|
) -> f64 |
||||||
|
where |
||||||
|
T: network::Network, |
||||||
|
{ |
||||||
|
self.compute_score(net, node, parent_set, dataset).0 |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub struct BIC { |
||||||
|
ll: LogLikelihood |
||||||
|
} |
||||||
|
|
||||||
|
impl BIC { |
||||||
|
pub fn init(alpha: usize, tau: f64) -> BIC { |
||||||
|
BIC { |
||||||
|
ll: LogLikelihood::init(alpha, tau) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl ScoreFunction for BIC { |
||||||
|
fn call<T>( |
||||||
|
&self, |
||||||
|
net: &T, |
||||||
|
node: usize, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
dataset: &tools::Dataset, |
||||||
|
) -> f64 |
||||||
|
where |
||||||
|
T: network::Network { |
||||||
|
//Compute the log-likelihood
|
||||||
|
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); |
||||||
|
//Compute the number of parameters
|
||||||
|
let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); |
||||||
|
//TODO: Optimize this
|
||||||
|
//Compute the sample size
|
||||||
|
let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum(); |
||||||
|
//Compute BIC
|
||||||
|
ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,260 @@ |
|||||||
|
|
||||||
|
mod utils; |
||||||
|
use utils::*; |
||||||
|
|
||||||
|
use rustyCTBN::ctbn::*; |
||||||
|
use rustyCTBN::network::Network; |
||||||
|
use rustyCTBN::tools::*; |
||||||
|
use rustyCTBN::structure_learning::score_function::*; |
||||||
|
use rustyCTBN::structure_learning::score_based_algorithm::*; |
||||||
|
use rustyCTBN::structure_learning::StructureLearningAlgorithm; |
||||||
|
use ndarray::{arr1, arr2, arr3}; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
use rustyCTBN::params; |
||||||
|
|
||||||
|
|
||||||
|
#[macro_use] |
||||||
|
extern crate approx; |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_score_test() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let trj = Trajectory::init( |
||||||
|
arr1(&[0.0,0.1,0.3]), |
||||||
|
arr2(&[[0],[1],[1]])); |
||||||
|
|
||||||
|
let dataset = Dataset::init(vec![trj]); |
||||||
|
|
||||||
|
let ll = LogLikelihood::init(1, 1.0); |
||||||
|
|
||||||
|
assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_bic() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let trj = Trajectory::init( |
||||||
|
arr1(&[0.0,0.1,0.3]), |
||||||
|
arr2(&[[0],[1],[1]])); |
||||||
|
|
||||||
|
let dataset = Dataset::init(vec![trj]); |
||||||
|
let bic = BIC::init(1, 1.0); |
||||||
|
|
||||||
|
assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm> (sl: T) { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
|
||||||
|
[1.5, -2.0, 0.5], |
||||||
|
[0.4, 0.6, -1.0]]]))); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
||||||
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||||
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||||
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||||
|
]))); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); |
||||||
|
|
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let _n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||||
|
.unwrap(); |
||||||
|
let net = sl.fit_transform(net, &data); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
#[should_panic] |
||||||
|
pub fn check_compatibility_between_dataset_and_network_hill_climbing() { |
||||||
|
let ll = LogLikelihood::init(1, 1.0); |
||||||
|
let hl = HillClimbing::init(ll, None); |
||||||
|
check_compatibility_between_dataset_and_network(hl); |
||||||
|
} |
||||||
|
|
||||||
|
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
|
||||||
|
[1.5, -2.0, 0.5], |
||||||
|
[0.4, 0.6, -1.0]]]))); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
||||||
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||||
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||||
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||||
|
]))); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); |
||||||
|
|
||||||
|
let net = sl.fit_transform(net, &data); |
||||||
|
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); |
||||||
|
assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { |
||||||
|
let ll = LogLikelihood::init(1, 1.0); |
||||||
|
let hl = HillClimbing::init(ll, None); |
||||||
|
learn_ternary_net_2_nodes(hl); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { |
||||||
|
let bic = BIC::init(1, 1.0); |
||||||
|
let hl = HillClimbing::init(bic, None); |
||||||
|
learn_ternary_net_2_nodes(hl); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let n3 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n3"),4)) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
net.add_edge(n1, n3); |
||||||
|
net.add_edge(n2, n3); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
|
||||||
|
[1.5, -2.0, 0.5], |
||||||
|
[0.4, 0.6, -1.0]]]))); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
||||||
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||||
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||||
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||||
|
]))); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
match &mut net.get_node_mut(n3).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
||||||
|
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], |
||||||
|
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], |
||||||
|
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], |
||||||
|
|
||||||
|
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], |
||||||
|
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], |
||||||
|
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], |
||||||
|
|
||||||
|
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], |
||||||
|
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], |
||||||
|
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], |
||||||
|
]))); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); |
||||||
|
return (net, data); |
||||||
|
} |
||||||
|
|
||||||
|
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) { |
||||||
|
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); |
||||||
|
let net = sl.fit_transform(net, &data); |
||||||
|
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||||
|
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||||
|
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { |
||||||
|
let ll = LogLikelihood::init(1, 1.0); |
||||||
|
let hl = HillClimbing::init(ll, None); |
||||||
|
learn_mixed_discrete_net_3_nodes(hl); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { |
||||||
|
let bic = BIC::init(1, 1.0); |
||||||
|
let hl = HillClimbing::init(bic, None); |
||||||
|
learn_mixed_discrete_net_3_nodes(hl); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm> (sl: T) { |
||||||
|
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); |
||||||
|
let net = sl.fit_transform(net, &data); |
||||||
|
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||||
|
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||||
|
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { |
||||||
|
let ll = LogLikelihood::init(1, 1.0); |
||||||
|
let hl = HillClimbing::init(ll, Some(1)); |
||||||
|
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() { |
||||||
|
let bic = BIC::init(1, 1.0); |
||||||
|
let hl = HillClimbing::init(bic, Some(1)); |
||||||
|
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); |
||||||
|
} |
Loading…
Reference in new issue