Merge pull request #42 from AlessandroBregoli/9-score-based-algorithm
9 score based algorithmpull/43/head
commit
c2d44e7332
@ -0,0 +1,10 @@ |
||||
pub mod score_function; |
||||
pub mod score_based_algorithm; |
||||
use crate::network; |
||||
use crate::tools; |
||||
|
||||
pub trait StructureLearningAlgorithm { |
||||
fn fit_transform<T, >(&self, net: T, dataset: &tools::Dataset) -> T |
||||
where |
||||
T: network::Network; |
||||
} |
@ -0,0 +1,83 @@ |
||||
use crate::network; |
||||
use crate::structure_learning::score_function::ScoreFunction; |
||||
use crate::structure_learning::StructureLearningAlgorithm; |
||||
use crate::tools; |
||||
use std::collections::BTreeSet; |
||||
|
||||
pub struct HillClimbing<S: ScoreFunction> { |
||||
score_function: S, |
||||
max_parent_set: Option<usize>, |
||||
} |
||||
|
||||
impl<S: ScoreFunction> HillClimbing<S> { |
||||
pub fn init(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> { |
||||
HillClimbing { |
||||
score_function, |
||||
max_parent_set, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { |
||||
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T |
||||
where |
||||
T: network::Network, |
||||
{ |
||||
//Check the coherence between dataset and network
|
||||
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { |
||||
panic!("Dataset and Network must have the same number of variables.") |
||||
} |
||||
|
||||
//Make the network mutable.
|
||||
let mut net = net; |
||||
//Check if the max_parent_set constraint is present.
|
||||
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); |
||||
//Reset the adj matrix
|
||||
net.initialize_adj_matrix(); |
||||
//Iterate over each node to learn their parent set.
|
||||
for node in net.get_node_indices() { |
||||
//Initialize an empty parent set.
|
||||
let mut parent_set: BTreeSet<usize> = BTreeSet::new(); |
||||
//Compute the score for the empty parent set
|
||||
let mut current_score = self.score_function.call(&net, node, &parent_set, dataset); |
||||
//Set the old score to -\infty.
|
||||
let mut old_score = f64::NEG_INFINITY; |
||||
//Iterate until convergence
|
||||
while current_score > old_score { |
||||
//Save the current_score.
|
||||
old_score = current_score; |
||||
//Iterate over each node.
|
||||
for parent in net.get_node_indices() { |
||||
//Continue if the parent and the node are the same.
|
||||
if parent == node { |
||||
continue; |
||||
} |
||||
//Try to remove parent from the parent_set.
|
||||
let is_removed = parent_set.remove(&parent); |
||||
//If parent was not in the parent_set add it.
|
||||
if !is_removed && parent_set.len() < max_parent_set { |
||||
parent_set.insert(parent); |
||||
} |
||||
//Compute the score with the modified parent_set.
|
||||
let tmp_score = self.score_function.call(&net, node, &parent_set, dataset); |
||||
//If tmp_score is worst than current_score revert the change to the parent set
|
||||
if tmp_score < current_score { |
||||
if is_removed { |
||||
parent_set.insert(parent); |
||||
} else { |
||||
parent_set.remove(&parent); |
||||
} |
||||
} |
||||
//Otherwise save the computed score as current_score
|
||||
else { |
||||
current_score = tmp_score; |
||||
} |
||||
} |
||||
} |
||||
//Apply the learned parent_set to the network struct.
|
||||
parent_set.iter().for_each(|p| net.add_edge(*p, node)); |
||||
} |
||||
|
||||
return net; |
||||
} |
||||
} |
@ -0,0 +1,136 @@ |
||||
use crate::network; |
||||
use crate::parameter_learning; |
||||
use crate::params; |
||||
use crate::tools; |
||||
use ndarray::prelude::*; |
||||
use statrs::function::gamma; |
||||
use std::collections::BTreeSet; |
||||
|
||||
pub trait ScoreFunction { |
||||
fn call<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> f64 |
||||
where |
||||
T: network::Network; |
||||
} |
||||
|
||||
pub struct LogLikelihood { |
||||
alpha: usize, |
||||
tau: f64, |
||||
} |
||||
|
||||
impl LogLikelihood { |
||||
pub fn init(alpha: usize, tau: f64) -> LogLikelihood { |
||||
|
||||
//Tau must be >=0.0
|
||||
if tau < 0.0 { |
||||
panic!("tau must be >=0.0"); |
||||
} |
||||
LogLikelihood { alpha, tau } |
||||
} |
||||
|
||||
fn compute_score<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> (f64, Array3<usize>) |
||||
where |
||||
T: network::Network, |
||||
{
|
||||
//Identify the type of node used
|
||||
match &net.get_node(node).params { |
||||
params::Params::DiscreteStatesContinousTime(_params) => { |
||||
//Compute the sufficient statistics M (number of transistions) and T (residence
|
||||
//time)
|
||||
let (M, T) = |
||||
parameter_learning::sufficient_statistics(net, dataset, node, parent_set); |
||||
|
||||
//Scale alpha accordingly to the size of the parent set
|
||||
let alpha = self.alpha as f64 / M.shape()[0] as f64; |
||||
//Scale tau accordingly to the size of the parent set
|
||||
let tau = self.tau / M.shape()[0] as f64; |
||||
|
||||
//Compute the log likelihood for q
|
||||
let log_ll_q:f64 = M |
||||
.sum_axis(Axis(2)) |
||||
.iter() |
||||
.zip(T.iter()) |
||||
.map(|(m, t)| { |
||||
gamma::ln_gamma(alpha + *m as f64 + 1.0) |
||||
+ (alpha + 1.0) * f64::ln(tau) |
||||
- gamma::ln_gamma(alpha + 1.0) |
||||
- (alpha + *m as f64 + 1.0) * f64::ln(tau + t) |
||||
}) |
||||
.sum(); |
||||
|
||||
//Compute the log likelihood for theta
|
||||
let log_ll_theta: f64 = M.outer_iter() |
||||
.map(|x| x.outer_iter() |
||||
.map(|y| gamma::ln_gamma(alpha)
|
||||
- gamma::ln_gamma(alpha + y.sum() as f64) |
||||
+ y.iter().map(|z|
|
||||
gamma::ln_gamma(alpha + *z as f64)
|
||||
- gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum(); |
||||
(log_ll_theta + log_ll_q, M) |
||||
} |
||||
} |
||||
} |
||||
|
||||
|
||||
|
||||
} |
||||
|
||||
impl ScoreFunction for LogLikelihood { |
||||
fn call<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> f64 |
||||
where |
||||
T: network::Network, |
||||
{ |
||||
self.compute_score(net, node, parent_set, dataset).0 |
||||
} |
||||
} |
||||
|
||||
pub struct BIC { |
||||
ll: LogLikelihood |
||||
} |
||||
|
||||
impl BIC { |
||||
pub fn init(alpha: usize, tau: f64) -> BIC { |
||||
BIC { |
||||
ll: LogLikelihood::init(alpha, tau) |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl ScoreFunction for BIC { |
||||
fn call<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> f64 |
||||
where |
||||
T: network::Network { |
||||
//Compute the log-likelihood
|
||||
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); |
||||
//Compute the number of parameters
|
||||
let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); |
||||
//TODO: Optimize this
|
||||
//Compute the sample size
|
||||
let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum(); |
||||
//Compute BIC
|
||||
ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 |
||||
} |
||||
} |
@ -0,0 +1,260 @@ |
||||
|
||||
mod utils; |
||||
use utils::*; |
||||
|
||||
use rustyCTBN::ctbn::*; |
||||
use rustyCTBN::network::Network; |
||||
use rustyCTBN::tools::*; |
||||
use rustyCTBN::structure_learning::score_function::*; |
||||
use rustyCTBN::structure_learning::score_based_algorithm::*; |
||||
use rustyCTBN::structure_learning::StructureLearningAlgorithm; |
||||
use ndarray::{arr1, arr2, arr3}; |
||||
use std::collections::BTreeSet; |
||||
use rustyCTBN::params; |
||||
|
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
#[test] |
||||
fn simple_score_test() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||
.unwrap(); |
||||
|
||||
let trj = Trajectory::init( |
||||
arr1(&[0.0,0.1,0.3]), |
||||
arr2(&[[0],[1],[1]])); |
||||
|
||||
let dataset = Dataset::init(vec![trj]); |
||||
|
||||
let ll = LogLikelihood::init(1, 1.0); |
||||
|
||||
assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); |
||||
|
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn simple_bic() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||
.unwrap(); |
||||
|
||||
let trj = Trajectory::init( |
||||
arr1(&[0.0,0.1,0.3]), |
||||
arr2(&[[0],[1],[1]])); |
||||
|
||||
let dataset = Dataset::init(vec![trj]); |
||||
let bic = BIC::init(1, 1.0); |
||||
|
||||
assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); |
||||
|
||||
} |
||||
|
||||
|
||||
|
||||
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm> (sl: T) { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
|
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0]]]))); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
||||
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||
]))); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); |
||||
|
||||
let mut net = CtbnNetwork::init(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||
.unwrap(); |
||||
let net = sl.fit_transform(net, &data); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
pub fn check_compatibility_between_dataset_and_network_hill_climbing() { |
||||
let ll = LogLikelihood::init(1, 1.0); |
||||
let hl = HillClimbing::init(ll, None); |
||||
check_compatibility_between_dataset_and_network(hl); |
||||
} |
||||
|
||||
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
|
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0]]]))); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
||||
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||
]))); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); |
||||
|
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { |
||||
let ll = LogLikelihood::init(1, 1.0); |
||||
let hl = HillClimbing::init(ll, None); |
||||
learn_ternary_net_2_nodes(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { |
||||
let bic = BIC::init(1, 1.0); |
||||
let hl = HillClimbing::init(bic, None); |
||||
learn_ternary_net_2_nodes(hl); |
||||
} |
||||
|
||||
|
||||
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||
.unwrap(); |
||||
|
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"),4)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n1, n3); |
||||
net.add_edge(n2, n3); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
|
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0]]]))); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
||||
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||
]))); |
||||
} |
||||
} |
||||
|
||||
|
||||
match &mut net.get_node_mut(n3).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[ |
||||
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], |
||||
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], |
||||
|
||||
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], |
||||
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], |
||||
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], |
||||
|
||||
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], |
||||
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], |
||||
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], |
||||
]))); |
||||
} |
||||
} |
||||
|
||||
|
||||
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); |
||||
return (net, data); |
||||
} |
||||
|
||||
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) { |
||||
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); |
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { |
||||
let ll = LogLikelihood::init(1, 1.0); |
||||
let hl = HillClimbing::init(ll, None); |
||||
learn_mixed_discrete_net_3_nodes(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { |
||||
let bic = BIC::init(1, 1.0); |
||||
let hl = HillClimbing::init(bic, None); |
||||
learn_mixed_discrete_net_3_nodes(hl); |
||||
} |
||||
|
||||
|
||||
|
||||
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm> (sl: T) { |
||||
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); |
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { |
||||
let ll = LogLikelihood::init(1, 1.0); |
||||
let hl = HillClimbing::init(ll, Some(1)); |
||||
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() { |
||||
let bic = BIC::init(1, 1.0); |
||||
let hl = HillClimbing::init(bic, Some(1)); |
||||
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); |
||||
} |
Loading…
Reference in new issue