Merge pull request #55 from AlessandroBregoli/54-refactor-make-the-code-compliant-to-rustfmt

Refactored `src/` and `tests/` files to be compliant with `rustfmt`
Meliurwen committed via GitHub (commit a5b24e9eee)
19 changed files (lines changed in parentheses):
- .github/workflows/build.yml (11)
- rustfmt.toml (5)
- src/ctbn.rs (70)
- src/lib.rs (7)
- src/network.rs (17)
- src/parameter_learning.rs (37)
- src/params.rs (42)
- src/structure_learning.rs (9)
- src/structure_learning/constraint_based_algorithm.rs (2)
- src/structure_learning/hypothesis_test.rs (20)
- src/structure_learning/score_based_algorithm.rs (6)
- src/structure_learning/score_function.rs (50)
- src/tools.rs (13)
- tests/ctbn.rs (3)
- tests/parameter_learning.rs (133)
- tests/params.rs (3)
- tests/structure_learning.rs (147)
- tests/tools.rs (17)
- tests/utils.rs (13)

@ -16,12 +16,20 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup (rust)
- name: Setup Rust stable (default)
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
default: true
components: clippy, rustfmt
- name: Setup Rust nightly
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly
default: false
components: rustfmt
- name: Linting (clippy)
uses: actions-rs/clippy-check@v1
with:
@ -30,6 +38,7 @@ jobs:
- name: Formatting (rustfmt)
uses: actions-rs/cargo@v1
with:
toolchain: nightly
command: fmt
args: --all -- --check --verbose
- name: Tests (test)

@ -33,4 +33,7 @@ newline_style = "Unix"
#error_on_unformatted = true
# Files to ignore like third party code which is formatted upstream.
#ignore = []
# Ignoring tests is a temporary measure due some issues regarding rank-3 tensors
ignore = [
"tests/"
]

@ -1,10 +1,9 @@
use ndarray::prelude::*;
use crate::params::{StateType, Params, ParamsTrait};
use crate::network;
use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::network;
use crate::params::{Params, ParamsTrait, StateType};
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
///composed of the following elements:
@ -54,23 +53,23 @@ use std::collections::BTreeSet;
/// ```
pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>,
nodes: Vec<Params>
nodes: Vec<Params>,
}
impl CtbnNetwork {
pub fn new() -> CtbnNetwork {
CtbnNetwork {
adj_matrix: None,
nodes: Vec::new()
nodes: Vec::new(),
}
}
}
impl network::Network for CtbnNetwork {
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f()));
self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
}
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
@ -103,60 +102,61 @@ impl network::Network for CtbnNetwork {
&self.nodes[node_idx]
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
&mut self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize {
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
}).0
})
.0
}
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize {
parent_set.iter().fold((0, 1), |mut acc, x| {
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<StateType>,
parent_set: &BTreeSet<usize>,
) -> usize {
parent_set
.iter()
.fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc
}).0
})
.0
}
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix.as_ref()
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| {
if x > &0 {
Some(idx)
} else {
None
}
}).collect()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix.as_ref()
self.adj_matrix
.as_ref()
.unwrap()
.row(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| {
if x > &0 {
Some(idx)
} else {
None
}
}).collect()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
}
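Aside: the fold in `get_param_index_network` is only reflowed here, but what it computes is a mixed-radix index over the states of the node's parents. A minimal standalone sketch of the same indexing, with hypothetical `states` and `cardinalities` slices standing in for the `state_to_index` and `get_reserved_space_as_parent` lookups:

// Standalone sketch of the mixed-radix indexing done by the fold above.
// `states[k]` is the current state index of parent k; `cardinalities[k]` is
// the number of states that parent can assume.
fn param_index(states: &[usize], cardinalities: &[usize]) -> usize {
    states
        .iter()
        .zip(cardinalities)
        // acc.0 accumulates the index, acc.1 is the weight of the next digit.
        .fold((0, 1), |mut acc, (s, c)| {
            acc.0 += *s * acc.1;
            acc.1 *= *c;
            acc
        })
        .0
}

fn main() {
    // Two ternary parents in states (2, 1): index = 2 * 1 + 1 * 3 = 5.
    assert_eq!(param_index(&[2, 1], &[3, 3]), 5);
}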

@ -2,10 +2,9 @@
#[cfg(test)]
extern crate approx;
pub mod params;
pub mod network;
pub mod ctbn;
pub mod tools;
pub mod network;
pub mod parameter_learning;
pub mod params;
pub mod structure_learning;
pub mod tools;

@ -1,15 +1,16 @@
use std::collections::BTreeSet;
use thiserror::Error;
use crate::params;
use std::collections::BTreeSet;
/// Error types for trait Network
#[derive(Error, Debug)]
pub enum NetworkError {
#[error("Error during node insertion")]
NodeInsertionError(String)
NodeInsertionError(String),
}
///Network
///The Network trait defines the required methods for a structure used as a PGM (such as a CTBN).
pub trait Network {
@ -26,13 +27,17 @@ pub trait Network {
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are
///the ones in the parent set of the *node*.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
-> usize;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network and a generic parent_set. Usually, the only values really used
///in *current_state* are the ones in the parent set of the *node*.
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<params::StateType>, parent_set: &BTreeSet<usize>) -> usize;
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<params::StateType>,
parent_set: &BTreeSet<usize>,
) -> usize;
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@ -1,9 +1,10 @@
use crate::network;
use crate::params::*;
use crate::tools;
use ndarray::prelude::*;
use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::params::*;
use crate::{network, tools};
pub trait ParameterLearning {
fn fit<T: network::Network>(
&self,
@ -18,20 +19,15 @@ pub fn sufficient_statistics<T:network::Network>(
net: &T,
dataset: &tools::Dataset,
node: usize,
parent_set: &BTreeSet<usize>
parent_set: &BTreeSet<usize>,
) -> (Array3<usize>, Array2<f64>) {
//Get the number of values assumable by the node
let node_domain = net
.get_node(node.clone())
.get_reserved_space_as_parent();
let node_domain = net.get_node(node.clone()).get_reserved_space_as_parent();
//Get the number of values assumable by each parent of the node
let parentset_domain: Vec<usize> = parent_set
.iter()
.map(|x| {
net.get_node(x.clone())
.get_reserved_space_as_parent()
})
.map(|x| net.get_node(x.clone()).get_reserved_space_as_parent())
.collect();
//Vector used to convert a specific configuration of the parent_set to the corresponding index
@ -70,13 +66,11 @@ pub fn sufficient_statistics<T:network::Network>(
}
return (M, T);
}
pub struct MLE {}
impl ParameterLearning for MLE {
fn fit<T: network::Network>(
&self,
net: &T,
@ -84,7 +78,6 @@ impl ParameterLearning for MLE {
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params {
//Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set {
Some(p) => p,
@ -106,8 +99,6 @@ impl ParameterLearning for MLE {
C.diag_mut().assign(&diag);
});
let mut n: Params = net.get_node(node).clone();
match n {
@ -115,8 +106,6 @@ impl ParameterLearning for MLE {
dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M);
dsct.set_residence_time(T);
}
};
return n;
@ -125,7 +114,7 @@ impl ParameterLearning for MLE {
pub struct BayesianApproach {
pub alpha: usize,
pub tau: f64
pub tau: f64,
}
impl ParameterLearning for BayesianApproach {
@ -161,8 +150,6 @@ impl ParameterLearning for BayesianApproach {
C.diag_mut().assign(&diag);
});
let mut n: Params = net.get_node(node).clone();
match n {
@ -170,15 +157,12 @@ impl ParameterLearning for BayesianApproach {
dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M);
dsct.set_residence_time(T);
}
};
return n;
}
}
pub struct Cache<P: ParameterLearning> {
parameter_learning: P,
dataset: tools::Dataset,
@ -191,6 +175,7 @@ impl<P: ParameterLearning> Cache<P> {
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params {
self.parameter_learning.fit(net, &self.dataset, node, parent_set)
self.parameter_learning
.fit(net, &self.dataset, node, parent_set)
}
}
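For orientation: `M` from `sufficient_statistics` holds per-parent-configuration transition counts (an `Array3<usize>`) and `T` the residence times (an `Array2<f64>`); the `C.diag_mut().assign(&diag)` lines above complete each CIM row so it sums to zero. A hedged sketch of the standard MLE step this corresponds to, with toy arrays in place of the real statistics (the full body of `MLE::fit` is not shown in this diff):

use ndarray::{arr2, arr3, Array2, Array3};

// Off-diagonal rates are transition counts divided by residence time; each
// diagonal entry is then set to minus the sum of its row, so every row of
// the resulting CIM sums to 0.
fn mle_cim(m: &Array3<usize>, t: &Array2<f64>) -> Array3<f64> {
    let mut cim = Array3::<f64>::zeros(m.raw_dim());
    for ((u, i, j), &count) in m.indexed_iter() {
        if i != j {
            cim[[u, i, j]] = count as f64 / t[[u, i]];
        }
    }
    let n = cim.shape()[1];
    for u in 0..cim.shape()[0] {
        for i in 0..n {
            let off_diag: f64 = (0..n).filter(|&j| j != i).map(|j| cim[[u, i, j]]).sum();
            cim[[u, i, i]] = -off_diag;
        }
    }
    cim
}

fn main() {
    // One parent configuration, binary node: 3 transitions out of state 0
    // observed over 1.5 time units gives a rate of 2.0.
    let m = arr3(&[[[0usize, 3], [2, 0]]]);
    let t = arr2(&[[1.5, 1.0]]);
    let cim = mle_cim(&m, &t);
    assert_eq!(cim[[0, 0, 1]], 2.0);
    assert_eq!(cim[[0, 0, 0]], -2.0);
}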

@ -1,9 +1,10 @@
use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch;
use ndarray::prelude::*;
use rand::Rng;
use std::collections::{BTreeSet};
use thiserror::Error;
use rand_chacha::ChaCha8Rng;
use thiserror::Error;
/// Error types for trait Params
#[derive(Error, Debug, PartialEq)]
@ -35,11 +36,21 @@ pub trait ParamsTrait {
/// Randomly generate a residence time for the given node taking into account the node state
/// and its parent set.
fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError>;
fn get_random_residence_time(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<f64, ParamsError>;
/// Randomly generate a possible state for the given node taking into account the node state
/// and its parent set.
fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError>;
fn get_random_state(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<StateType, ParamsError>;
/// Used by children of the node described by these parameters to reserve space in their CIMs.
fn get_reserved_space_as_parent(&self) -> usize;
@ -113,7 +124,6 @@ impl DiscreteStatesContinousTimeParams {
}
}
///Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
self.cim = Some(cim);
@ -124,7 +134,6 @@ impl DiscreteStatesContinousTimeParams {
&self.transitions
}
///Setter function for transitions
pub fn set_transitions(&mut self, transitions: Array3<usize>) {
self.transitions = Some(transitions);
@ -135,12 +144,10 @@ impl DiscreteStatesContinousTimeParams {
&self.residence_time
}
///Setter function for residence_time
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
self.residence_time = Some(residence_time);
}
}
impl ParamsTrait for DiscreteStatesContinousTimeParams {
@ -154,7 +161,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
StateType::Discrete(rng.gen_range(0..(self.domain.len())))
}
fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError> {
fn get_random_residence_time(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<f64, ParamsError> {
// Generate a random residence time given the current state of the node and its parent set.
// The method used is described in:
// https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
@ -170,7 +182,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
}
}
fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError> {
fn get_random_state(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<StateType, ParamsError> {
// Generate a random transition given the current state of the node and its parent set.
// The method used is described in:
// https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
@ -246,7 +263,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
}
// Check if each row sums to 0
if cim.sum_axis(Axis(2)).iter()
if cim
.sum_axis(Axis(2))
.iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
{
return Err(ParamsError::InvalidCIM(String::from(
@ -260,5 +279,4 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
fn get_label(&self) -> &String {
&self.label
}
}
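The comment in `get_random_residence_time` above cites the standard inverse-transform recipe for exponential variates. A minimal sketch of that recipe, assuming a hypothetical constant rate `lambda` in place of the CIM entry for the current state and parent configuration:

use rand::Rng;
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;

// If U is uniform on (0, 1), then -ln(U) / lambda is exponentially
// distributed with rate lambda (the method cited in the comment above).
fn random_residence_time(lambda: f64, rng: &mut ChaCha8Rng) -> f64 {
    let u: f64 = rng.gen_range(f64::EPSILON..1.0);
    -u.ln() / lambda
}

fn main() {
    let mut rng = ChaCha8Rng::seed_from_u64(42);
    let t = random_residence_time(3.0, &mut rng);
    assert!(t >= 0.0);
}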

@ -1,12 +1,11 @@
pub mod score_function;
pub mod score_based_algorithm;
pub mod constraint_based_algorithm;
pub mod hypothesis_test;
use crate::network;
use crate::tools;
pub mod score_based_algorithm;
pub mod score_function;
use crate::{network, tools};
pub trait StructureLearningAlgorithm {
fn fit_transform<T, >(&self, net: T, dataset: &tools::Dataset) -> T
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where
T: network::Network;
}
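After the cleanup the trait reads plainly: `fit_transform` takes ownership of a network, consults a dataset, and returns the same network type with the learned structure. A generic usage sketch (every name here is a placeholder):

use reCTBN::network::Network;
use reCTBN::structure_learning::StructureLearningAlgorithm;
use reCTBN::tools::Dataset;

// Any structure-learning algorithm maps a (network, dataset) pair back to a
// network of the same type; `sl`, `net` and `data` are placeholders.
fn learn_structure<N: Network, S: StructureLearningAlgorithm>(
    sl: &S,
    net: N,
    data: &Dataset,
) -> N {
    sl.fit_transform(net, data)
}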

@ -1,11 +1,10 @@
use ndarray::Array3;
use ndarray::Axis;
use std::collections::BTreeSet;
use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::network;
use crate::parameter_learning;
use crate::params::*;
use std::collections::BTreeSet;
use crate::{network, parameter_learning};
pub trait HypothesisTest {
fn call<T, P>(
@ -111,14 +110,14 @@ impl HypothesisTest for ChiSquare {
// of size n x n
// (CIM, M, T)
let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node
Params::DiscreteStatesContinousTime(node) => node,
};
//
let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node);
let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node
Params::DiscreteStatesContinousTime(node) => node,
};
// TODO: comment here
let partial_cardinality_product: usize = extended_separation_set
@ -132,7 +131,12 @@ impl HypothesisTest for ChiSquare {
/ (partial_cardinality_product
* net.get_node(parent_node).get_reserved_space_as_parent()))
* partial_cardinality_product;
if !self.compare_matrices(idx_M_small, P_small.get_transitions().as_ref().unwrap(), idx_M_big, P_big.get_transitions().as_ref().unwrap()) {
if !self.compare_matrices(
idx_M_small,
P_small.get_transitions().as_ref().unwrap(),
idx_M_big,
P_big.get_transitions().as_ref().unwrap(),
) {
return false;
}
}
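The call shape of `compare_matrices` can be read off the tests near the end of this diff; a toy invocation with made-up count matrices (element type assumed to be `usize`, as in those tests):

use ndarray::arr3;
use reCTBN::structure_learning::hypothesis_test::ChiSquare;

fn main() {
    // Significance level 0.1, as in the tests below; slice 1 of m1 is
    // compared against slice 0 of m2. The counts are arbitrary toy values.
    let chi_sq = ChiSquare::new(0.1);
    let m1 = arr3(&[[[0, 2], [3, 0]], [[0, 20], [30, 0]]]);
    let m2 = arr3(&[[[0, 200], [300, 0]]]);
    let compatible = chi_sq.compare_matrices(1, &m1, 0, &m2);
    println!("null hypothesis kept: {}", compatible);
}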

@ -1,8 +1,8 @@
use crate::network;
use std::collections::BTreeSet;
use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::tools;
use std::collections::BTreeSet;
use crate::{network, tools};
pub struct HillClimbing<S: ScoreFunction> {
score_function: S,

@ -1,10 +1,9 @@
use crate::network;
use crate::parameter_learning;
use crate::params;
use crate::tools;
use std::collections::BTreeSet;
use ndarray::prelude::*;
use statrs::function::gamma;
use std::collections::BTreeSet;
use crate::{network, parameter_learning, params, tools};
pub trait ScoreFunction {
fn call<T>(
@ -25,7 +24,6 @@ pub struct LogLikelihood {
impl LogLikelihood {
pub fn new(alpha: usize, tau: f64) -> LogLikelihood {
//Tau must be >=0.0
if tau < 0.0 {
panic!("tau must be >=0.0");
@ -62,28 +60,33 @@ impl LogLikelihood {
.iter()
.zip(T.iter())
.map(|(m, t)| {
gamma::ln_gamma(alpha + *m as f64 + 1.0)
+ (alpha + 1.0) * f64::ln(tau)
gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau)
- gamma::ln_gamma(alpha + 1.0)
- (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
})
.sum();
//Compute the log likelihood for theta
let log_ll_theta: f64 = M.outer_iter()
.map(|x| x.outer_iter()
.map(|y| gamma::ln_gamma(alpha)
- gamma::ln_gamma(alpha + y.sum() as f64)
+ y.iter().map(|z|
let log_ll_theta: f64 = M
.outer_iter()
.map(|x| {
x.outer_iter()
.map(|y| {
gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64)
+ y.iter()
.map(|z| {
gamma::ln_gamma(alpha + *z as f64)
- gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum();
- gamma::ln_gamma(alpha)
})
.sum::<f64>()
})
.sum::<f64>()
})
.sum();
(log_ll_theta + log_ll_q, M)
}
}
}
}
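Read back in formula form, the two refactored closures compute (with alpha and tau the hyperparameters validated in `LogLikelihood::new`, and the first sum running over the paired entries of the count and residence-time arrays):

\log \ell_q = \sum_{(m,\,t)} \Big[ \ln\Gamma(\alpha + m + 1) + (\alpha + 1)\ln\tau - \ln\Gamma(\alpha + 1) - (\alpha + m + 1)\ln(\tau + t) \Big]

\log \ell_\theta = \sum_{u} \sum_{i} \Big[ \ln\Gamma(\alpha) - \ln\Gamma\big(\alpha + \textstyle\sum_j M_{uij}\big) + \sum_j \big( \ln\Gamma(\alpha + M_{uij}) - \ln\Gamma(\alpha) \big) \Big]

and `compute_score` returns \log\ell_\theta + \log\ell_q together with M.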
impl ScoreFunction for LogLikelihood {
@ -102,13 +105,13 @@ impl ScoreFunction for LogLikelihood {
}
pub struct BIC {
ll: LogLikelihood
ll: LogLikelihood,
}
impl BIC {
pub fn new(alpha: usize, tau: f64) -> BIC {
BIC {
ll: LogLikelihood::new(alpha, tau)
ll: LogLikelihood::new(alpha, tau),
}
}
}
@ -122,14 +125,19 @@ impl ScoreFunction for BIC {
dataset: &tools::Dataset,
) -> f64
where
T: network::Network {
T: network::Network,
{
//Compute the log-likelihood
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
//Compute the number of parameters
let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
//TODO: Optimize this
//Compute the sample size
let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum();
let sample_size: usize = dataset
.get_trajectories()
.iter()
.map(|x| x.get_time().len() - 1)
.sum();
//Compute BIC
ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
}
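In formula form, the value returned above is the usual BIC penalty applied to the marginal log-likelihood, with both pieces computed exactly as in the code:

\mathrm{BIC} = \log\mathcal{L} - \frac{\ln N}{2}\, k, \qquad k = U \cdot S \cdot (S - 1), \qquad N = \sum_{\text{trajectories}} \big( |\text{time}| - 1 \big)

where (U, S, S) is the shape of M, so k counts the free parameters per conditional intensity matrix, and N is the total number of observed transitions.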

@ -1,10 +1,10 @@
use crate::network;
use crate::params;
use crate::params::ParamsTrait;
use ndarray::prelude::*;
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
use crate::params::ParamsTrait;
use crate::{network, params};
pub struct Trajectory {
time: Array1<f64>,
events: Array2<usize>,
@ -35,7 +35,6 @@ pub struct Dataset {
impl Dataset {
pub fn new(trajectories: Vec<Trajectory>) -> Dataset {
//All the trajectories in the same dataset must represent the same process. For this reason
//each trajectory must represent the same number of variables.
if trajectories
@ -58,7 +57,6 @@ pub fn trajectory_generator<T: network::Network>(
t_end: f64,
seed: Option<u64>,
) -> Dataset {
//Tmp growing vector containing generated trajectories.
let mut trajectories: Vec<Trajectory> = Vec::new();
@ -67,7 +65,7 @@ pub fn trajectory_generator<T: network::Network>(
//If a seed is present use it to initialize the random generator.
Some(seed) => SeedableRng::seed_from_u64(seed),
//Otherwise create a new random generator using the method `from_entropy`
None => SeedableRng::from_entropy()
None => SeedableRng::from_entropy(),
};
//Each iteration generate one trajectory
@ -78,7 +76,8 @@ pub fn trajectory_generator<T: network::Network>(
let mut time: Vec<f64> = Vec::new();
//Configuration of the process variables at time t initialized with a uniform
//distribution.
let mut current_state: Vec<params::StateType> = net.get_node_indices()
let mut current_state: Vec<params::StateType> = net
.get_node_indices()
.map(|x| net.get_node(x).get_random_state_uniform(&mut rng))
.collect();
//History of all the configurations of the process variables.
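For reference, a runnable sketch of how this generator is driven end to end, mirroring the single-node setup in `tests/tools.rs` further down (labels, seed, and sizes are arbitrary; the final assertion assumes the second argument is the trajectory count, which matches how the tests read):

use std::collections::BTreeSet;

use ndarray::arr3;
use reCTBN::ctbn::CtbnNetwork;
use reCTBN::network::Network;
use reCTBN::params;
use reCTBN::tools::trajectory_generator;

fn main() {
    // One binary node with a hand-set CIM, as in tests/tools.rs.
    let mut net = CtbnNetwork::new();
    let domain: BTreeSet<String> = ["0", "1"].iter().map(|s| s.to_string()).collect();
    let n1 = net
        .add_node(params::Params::DiscreteStatesContinousTime(
            params::DiscreteStatesContinousTimeParams::new(String::from("n1"), domain),
        ))
        .unwrap();
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
        }
    }
    // 10 trajectories simulated up to t_end = 10.0, seeded for repeatability.
    let data = trajectory_generator(&net, 10, 10.0, Some(42));
    assert_eq!(data.get_trajectories().len(), 10);
}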

@ -1,8 +1,9 @@
mod utils;
use std::collections::BTreeSet;
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::params::{self, ParamsTrait};
use std::collections::BTreeSet;
use utils::generate_discrete_time_continous_node;
#[test]

@ -1,13 +1,13 @@
#![allow(non_snake_case)]
mod utils;
use utils::*;
use ndarray::arr3;
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::parameter_learning::*;
use reCTBN::{params, tools::*};
use reCTBN::params;
use reCTBN::tools::*;
use utils::*;
extern crate approx;
@ -32,8 +32,14 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
[
[-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
]))
);
}
@ -41,11 +47,20 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 1, None) {
params::Params::DiscreteStatesContinousTime(p) => p
params::Params::DiscreteStatesContinousTime(p) => p,
};
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]),
&arr3(&[
[
[-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
]),
0.1
));
}
@ -76,11 +91,13 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
]]))
],
]))
);
}
}
@ -90,9 +107,21 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
[
[-1.0, 0.5, 0.5],
[3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]))
);
}
@ -100,14 +129,26 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 1, None) {
params::Params::DiscreteStatesContinousTime(p) => p
params::Params::DiscreteStatesContinousTime(p) => p,
};
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
[
[-1.0, 0.5, 0.5],
[3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]),
0.1
));
@ -139,11 +180,13 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
]]))
]
]))
);
}
}
@ -153,9 +196,21 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
[
[-1.0, 0.5, 0.5],
[3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]))
);
}
@ -163,11 +218,17 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 0, None) {
params::Params::DiscreteStatesContinousTime(p) => p
params::Params::DiscreteStatesContinousTime(p) => p,
};
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]),
&arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
],
]),
0.1
));
}
@ -204,11 +265,13 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
]]))
],
]))
);
}
}
@ -218,9 +281,21 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
[
[-1.0, 0.5, 0.5],
[3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]))
);
}
@ -292,7 +367,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 2, None) {
params::Params::DiscreteStatesContinousTime(p) => p
params::Params::DiscreteStatesContinousTime(p) => p,
};
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(

@ -1,5 +1,6 @@
use ndarray::prelude::*;
use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
use reCTBN::params::{ParamsTrait, *};
mod utils;

@ -1,17 +1,18 @@
#![allow(non_snake_case)]
mod utils;
use utils::*;
use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::params;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm};
use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::score_based_algorithm::*;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::StructureLearningAlgorithm;
use reCTBN::tools::*;
use std::collections::BTreeSet;
use utils::*;
#[macro_use]
extern crate approx;
@ -69,11 +70,13 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
]]))
],
]))
);
}
}
@ -83,9 +86,21 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
[
[-1.0, 0.5, 0.5],
[3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]))
);
}
@ -122,11 +137,13 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
]]))
],
]))
);
}
}
@ -136,9 +153,21 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
[
[-1.0, 0.5, 0.5],
[3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]))
);
}
@ -185,11 +214,13 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
param.set_cim(arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
]]))
],
]))
);
}
}
@ -199,9 +230,21 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
[
[-1.0, 0.5, 0.5],
[3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]))
);
}
@ -323,21 +366,29 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
pub fn chi_square_compare_matrices() {
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 12, 90],
[ 7, 8, 0]
],
[
[0, 12, 90],
[ 3, 0, 40],
[ 6, 40, 0]],
[[ 0, 2, 3],
[ 6, 40, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
[ 44, 66, 0]
],
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[
[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
[ 700, 800, 0]
],
]);
let chi_sq = ChiSquare::new(0.1);
assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
@ -347,15 +398,21 @@ pub fn chi_square_compare_matrices () {
pub fn chi_square_compare_matrices_2() {
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 20, 30],
[ 7, 8, 0]
],
[
[0, 20, 30],
[ 40, 0, 60],
[ 70, 80, 0]],
[[ 0, 2, 3],
[ 70, 80, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
[ 44, 66, 0]
],
]);
let j: usize = 0;
let M2 = arr3(&[
@ -371,21 +428,29 @@ pub fn chi_square_compare_matrices_2 () {
pub fn chi_square_compare_matrices_3() {
let i: usize = 1;
let M1 = arr3(&[
[[ 0, 2, 3],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 7, 8, 0]],
[[0, 21, 31],
[ 7, 8, 0]
],
[
[0, 21, 31],
[ 41, 0, 59],
[ 71, 79, 0]],
[[ 0, 2, 3],
[ 71, 79, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
[ 44, 66, 0]
],
]);
let j: usize = 0;
let M2 = arr3(&[
[[ 0, 200, 300],
[
[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
[ 700, 800, 0]
],
]);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2));

@ -29,15 +29,26 @@ fn run_sampling() {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
param.set_cim(arr3(&[
[
[-3.0, 3.0],
[2.0, -2.0]
],
]));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
[
[-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
]));
}
}

@ -1,12 +1,19 @@
use reCTBN::params;
use std::collections::BTreeSet;
use reCTBN::params;
#[allow(dead_code)]
pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params {
params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality))
params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(
label,
cardinality,
))
}
pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
pub fn generate_discrete_time_continous_params(
label: String,
cardinality: usize,
) -> params::DiscreteStatesContinousTimeParams {
let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
params::DiscreteStatesContinousTimeParams::new(label, domain)
}
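A usage sketch of these helpers, following the pattern visible in `tests/ctbn.rs` earlier in this diff (label and cardinality are placeholders; the call is unqualified because this sits in the same `utils` module):

use reCTBN::ctbn::CtbnNetwork;
use reCTBN::network::Network;

// Build a ternary node (domain {"0", "1", "2"}) and insert it into a fresh
// network; `add_node` yields the node index used by the rest of the tests.
fn example_node_insertion() -> usize {
    let mut net = CtbnNetwork::new();
    net.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap()
}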
