Merge pull request #55 from AlessandroBregoli/54-refactor-make-the-code-compliant-to-rustfmt

Refactored `src/` and `tests/` files to be compliant to `rustfmt`
pull/58/head
Meliurwen 2 years ago committed by GitHub
commit a5b24e9eee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 11
      .github/workflows/build.yml
  2. 5
      rustfmt.toml
  3. 100
      src/ctbn.rs
  4. 7
      src/lib.rs
  5. 19
      src/network.rs
  6. 51
      src/parameter_learning.rs
  7. 46
      src/params.rs
  8. 9
      src/structure_learning.rs
  9. 2
      src/structure_learning/constraint_based_algorithm.rs
  10. 24
      src/structure_learning/hypothesis_test.rs
  11. 6
      src/structure_learning/score_based_algorithm.rs
  12. 56
      src/structure_learning/score_function.rs
  13. 13
      src/tools.rs
  14. 3
      tests/ctbn.rs
  15. 157
      tests/parameter_learning.rs
  16. 3
      tests/params.rs
  17. 199
      tests/structure_learning.rs
  18. 17
      tests/tools.rs
  19. 13
      tests/utils.rs

@ -16,12 +16,20 @@ jobs:
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup (rust) - name: Setup Rust stable (default)
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
profile: minimal profile: minimal
toolchain: stable toolchain: stable
default: true
components: clippy, rustfmt components: clippy, rustfmt
- name: Setup Rust nightly
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly
default: false
components: rustfmt
- name: Linting (clippy) - name: Linting (clippy)
uses: actions-rs/clippy-check@v1 uses: actions-rs/clippy-check@v1
with: with:
@ -30,6 +38,7 @@ jobs:
- name: Formatting (rustfmt) - name: Formatting (rustfmt)
uses: actions-rs/cargo@v1 uses: actions-rs/cargo@v1
with: with:
toolchain: nightly
command: fmt command: fmt
args: --all -- --check --verbose args: --all -- --check --verbose
- name: Tests (test) - name: Tests (test)

@ -33,4 +33,7 @@ newline_style = "Unix"
#error_on_unformatted = true #error_on_unformatted = true
# Files to ignore like third party code which is formatted upstream. # Files to ignore like third party code which is formatted upstream.
#ignore = [] # Ignoring tests is a temporary measure due some issues regarding rank-3 tensors
ignore = [
"tests/"
]

@ -1,10 +1,9 @@
use ndarray::prelude::*;
use crate::params::{StateType, Params, ParamsTrait};
use crate::network;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::network;
use crate::params::{Params, ParamsTrait, StateType};
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is ///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
///composed by the following elements: ///composed by the following elements:
@ -54,30 +53,30 @@ use std::collections::BTreeSet;
/// ``` /// ```
pub struct CtbnNetwork { pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>, adj_matrix: Option<Array2<u16>>,
nodes: Vec<Params> nodes: Vec<Params>,
} }
impl CtbnNetwork { impl CtbnNetwork {
pub fn new() -> CtbnNetwork { pub fn new() -> CtbnNetwork {
CtbnNetwork { CtbnNetwork {
adj_matrix: None, adj_matrix: None,
nodes: Vec::new() nodes: Vec::new(),
} }
} }
} }
impl network::Network for CtbnNetwork { impl network::Network for CtbnNetwork {
fn initialize_adj_matrix(&mut self) { fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f())); self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
} }
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> { fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
n.reset_params(); n.reset_params();
self.adj_matrix = Option::None; self.adj_matrix = Option::None;
self.nodes.push(n); self.nodes.push(n);
Ok(self.nodes.len() -1) Ok(self.nodes.len() - 1)
} }
fn add_edge(&mut self, parent: usize, child: usize) { fn add_edge(&mut self, parent: usize, child: usize) {
@ -91,7 +90,7 @@ impl network::Network for CtbnNetwork {
} }
} }
fn get_node_indices(&self) -> std::ops::Range<usize>{ fn get_node_indices(&self) -> std::ops::Range<usize> {
0..self.nodes.len() 0..self.nodes.len()
} }
@ -99,64 +98,65 @@ impl network::Network for CtbnNetwork {
self.nodes.len() self.nodes.len()
} }
fn get_node(&self, node_idx: usize) -> &Params{ fn get_node(&self, node_idx: usize) -> &Params {
&self.nodes[node_idx] &self.nodes[node_idx]
} }
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params{
&mut self.nodes[node_idx] &mut self.nodes[node_idx]
} }
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize {
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{ self.adj_matrix
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { .as_ref()
if x.1 > &0 { .unwrap()
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1; .column(node)
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); .iter()
} .enumerate()
acc .fold((0, 1), |mut acc, x| {
}).0 if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
})
.0
} }
fn get_param_index_from_custom_parent_set(
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize { &self,
parent_set.iter().fold((0, 1), |mut acc, x| { current_state: &Vec<StateType>,
acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1; parent_set: &BTreeSet<usize>,
acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); ) -> usize {
acc parent_set
}).0 .iter()
.fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc
})
.0
} }
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix.as_ref() self.adj_matrix
.as_ref()
.unwrap() .unwrap()
.column(node) .column(node)
.iter() .iter()
.enumerate() .enumerate()
.filter_map(|(idx, x)| { .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
if x > &0 { .collect()
Some(idx)
} else {
None
}
}).collect()
} }
fn get_children_set(&self, node: usize) -> BTreeSet<usize>{ fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix.as_ref() self.adj_matrix
.as_ref()
.unwrap() .unwrap()
.row(node) .row(node)
.iter() .iter()
.enumerate() .enumerate()
.filter_map(|(idx, x)| { .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
if x > &0 { .collect()
Some(idx)
} else {
None
}
}).collect()
} }
} }

@ -2,10 +2,9 @@
#[cfg(test)] #[cfg(test)]
extern crate approx; extern crate approx;
pub mod params;
pub mod network;
pub mod ctbn; pub mod ctbn;
pub mod tools; pub mod network;
pub mod parameter_learning; pub mod parameter_learning;
pub mod params;
pub mod structure_learning; pub mod structure_learning;
pub mod tools;

@ -1,20 +1,21 @@
use std::collections::BTreeSet;
use thiserror::Error; use thiserror::Error;
use crate::params; use crate::params;
use std::collections::BTreeSet;
/// Error types for trait Network /// Error types for trait Network
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum NetworkError { pub enum NetworkError {
#[error("Error during node insertion")] #[error("Error during node insertion")]
NodeInsertionError(String) NodeInsertionError(String),
} }
///Network ///Network
///The Network trait define the required methods for a structure used as pgm (such as ctbn). ///The Network trait define the required methods for a structure used as pgm (such as ctbn).
pub trait Network { pub trait Network {
fn initialize_adj_matrix(&mut self); fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
fn add_edge(&mut self, parent: usize, child: usize); fn add_edge(&mut self, parent: usize, child: usize);
///Get all the indices of the nodes contained inside the network ///Get all the indices of the nodes contained inside the network
@ -26,13 +27,17 @@ pub trait Network {
///Compute the index that must be used to access the parameters of a node given a specific ///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are ///configuration of the network. Usually, the only values really used in *current_state* are
///the ones in the parent set of the *node*. ///the ones in the parent set of the *node*.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize; fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
-> usize;
///Compute the index that must be used to access the parameters of a node given a specific ///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network and a generic parent_set. Usually, the only values really used ///configuration of the network and a generic parent_set. Usually, the only values really used
///in *current_state* are the ones in the parent set of the *node*. ///in *current_state* are the ones in the parent set of the *node*.
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<params::StateType>, parent_set: &BTreeSet<usize>) -> usize; fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<params::StateType>,
parent_set: &BTreeSet<usize>,
) -> usize;
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
} }

@ -1,11 +1,12 @@
use crate::network;
use crate::params::*;
use crate::tools;
use ndarray::prelude::*;
use std::collections::BTreeSet; use std::collections::BTreeSet;
pub trait ParameterLearning{ use ndarray::prelude::*;
fn fit<T:network::Network>(
use crate::params::*;
use crate::{network, tools};
pub trait ParameterLearning {
fn fit<T: network::Network>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
@ -14,24 +15,19 @@ pub trait ParameterLearning{
) -> Params; ) -> Params;
} }
pub fn sufficient_statistics<T:network::Network>( pub fn sufficient_statistics<T: network::Network>(
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
node: usize, node: usize,
parent_set: &BTreeSet<usize> parent_set: &BTreeSet<usize>,
) -> (Array3<usize>, Array2<f64>) { ) -> (Array3<usize>, Array2<f64>) {
//Get the number of values assumable by the node //Get the number of values assumable by the node
let node_domain = net let node_domain = net.get_node(node.clone()).get_reserved_space_as_parent();
.get_node(node.clone())
.get_reserved_space_as_parent();
//Get the number of values assumable by each parent of the node //Get the number of values assumable by each parent of the node
let parentset_domain: Vec<usize> = parent_set let parentset_domain: Vec<usize> = parent_set
.iter() .iter()
.map(|x| { .map(|x| net.get_node(x.clone()).get_reserved_space_as_parent())
net.get_node(x.clone())
.get_reserved_space_as_parent()
})
.collect(); .collect();
//Vector used to convert a specific configuration of the parent_set to the corresponding index //Vector used to convert a specific configuration of the parent_set to the corresponding index
@ -70,13 +66,11 @@ pub fn sufficient_statistics<T:network::Network>(
} }
return (M, T); return (M, T);
} }
pub struct MLE {} pub struct MLE {}
impl ParameterLearning for MLE { impl ParameterLearning for MLE {
fn fit<T: network::Network>( fn fit<T: network::Network>(
&self, &self,
net: &T, net: &T,
@ -84,7 +78,6 @@ impl ParameterLearning for MLE {
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> Params { ) -> Params {
//Use parent_set from parameter if present. Otherwise use parent_set from network. //Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set { let parent_set = match parent_set {
Some(p) => p, Some(p) => p,
@ -96,7 +89,7 @@ impl ParameterLearning for MLE {
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
CIM.axis_iter_mut(Axis(2)) CIM.axis_iter_mut(Axis(2))
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
.for_each(|(mut C, m)| C.assign(&(&m/&T))); .for_each(|(mut C, m)| C.assign(&(&m / &T)));
//Set the diagonal of the inner matrices to the the row sum multiplied by -1 //Set the diagonal of the inner matrices to the the row sum multiplied by -1
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@ -106,8 +99,6 @@ impl ParameterLearning for MLE {
C.diag_mut().assign(&diag); C.diag_mut().assign(&diag);
}); });
let mut n: Params = net.get_node(node).clone(); let mut n: Params = net.get_node(node).clone();
match n { match n {
@ -115,8 +106,6 @@ impl ParameterLearning for MLE {
dsct.set_cim_unchecked(CIM); dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M); dsct.set_transitions(M);
dsct.set_residence_time(T); dsct.set_residence_time(T);
} }
}; };
return n; return n;
@ -125,7 +114,7 @@ impl ParameterLearning for MLE {
pub struct BayesianApproach { pub struct BayesianApproach {
pub alpha: usize, pub alpha: usize,
pub tau: f64 pub tau: f64,
} }
impl ParameterLearning for BayesianApproach { impl ParameterLearning for BayesianApproach {
@ -151,7 +140,7 @@ impl ParameterLearning for BayesianApproach {
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
CIM.axis_iter_mut(Axis(2)) CIM.axis_iter_mut(Axis(2))
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
.for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha)/&T.mapv(|y| y + tau)))); .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau))));
//Set the diagonal of the inner matrices to the the row sum multiplied by -1 //Set the diagonal of the inner matrices to the the row sum multiplied by -1
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@ -161,8 +150,6 @@ impl ParameterLearning for BayesianApproach {
C.diag_mut().assign(&diag); C.diag_mut().assign(&diag);
}); });
let mut n: Params = net.get_node(node).clone(); let mut n: Params = net.get_node(node).clone();
match n { match n {
@ -170,27 +157,25 @@ impl ParameterLearning for BayesianApproach {
dsct.set_cim_unchecked(CIM); dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M); dsct.set_transitions(M);
dsct.set_residence_time(T); dsct.set_residence_time(T);
} }
}; };
return n; return n;
} }
} }
pub struct Cache<P: ParameterLearning> { pub struct Cache<P: ParameterLearning> {
parameter_learning: P, parameter_learning: P,
dataset: tools::Dataset, dataset: tools::Dataset,
} }
impl<P: ParameterLearning> Cache<P> { impl<P: ParameterLearning> Cache<P> {
pub fn fit<T:network::Network>( pub fn fit<T: network::Network>(
&mut self, &mut self,
net: &T, net: &T,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> Params { ) -> Params {
self.parameter_learning.fit(net, &self.dataset, node, parent_set) self.parameter_learning
.fit(net, &self.dataset, node, parent_set)
} }
} }

@ -1,9 +1,10 @@
use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch; use enum_dispatch::enum_dispatch;
use ndarray::prelude::*; use ndarray::prelude::*;
use rand::Rng; use rand::Rng;
use std::collections::{BTreeSet};
use thiserror::Error;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
use thiserror::Error;
/// Error types for trait Params /// Error types for trait Params
#[derive(Error, Debug, PartialEq)] #[derive(Error, Debug, PartialEq)]
@ -35,11 +36,21 @@ pub trait ParamsTrait {
/// Randomly generate a residence time for the given node taking into account the node state /// Randomly generate a residence time for the given node taking into account the node state
/// and its parent set. /// and its parent set.
fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError>; fn get_random_residence_time(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<f64, ParamsError>;
/// Randomly generate a possible state for the given node taking into account the node state /// Randomly generate a possible state for the given node taking into account the node state
/// and its parent set. /// and its parent set.
fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError>; fn get_random_state(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<StateType, ParamsError>;
/// Used by childern of the node described by this parameters to reserve spaces in their CIMs. /// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
fn get_reserved_space_as_parent(&self) -> usize; fn get_reserved_space_as_parent(&self) -> usize;
@ -102,7 +113,7 @@ impl DiscreteStatesContinousTimeParams {
///This function check if the cim is valid using the validate_params method. ///This function check if the cim is valid using the validate_params method.
///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(()) ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(())
///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError ///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError>{ pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
self.cim = Some(cim); self.cim = Some(cim);
match self.validate_params() { match self.validate_params() {
Ok(()) => Ok(()), Ok(()) => Ok(()),
@ -113,7 +124,6 @@ impl DiscreteStatesContinousTimeParams {
} }
} }
///Unchecked version of the setter function for CIM. ///Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) { pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
self.cim = Some(cim); self.cim = Some(cim);
@ -124,7 +134,6 @@ impl DiscreteStatesContinousTimeParams {
&self.transitions &self.transitions
} }
///Setter function for transitions ///Setter function for transitions
pub fn set_transitions(&mut self, transitions: Array3<usize>) { pub fn set_transitions(&mut self, transitions: Array3<usize>) {
self.transitions = Some(transitions); self.transitions = Some(transitions);
@ -135,12 +144,10 @@ impl DiscreteStatesContinousTimeParams {
&self.residence_time &self.residence_time
} }
///Setter function for residence_time ///Setter function for residence_time
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) { pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
self.residence_time = Some(residence_time); self.residence_time = Some(residence_time);
} }
} }
impl ParamsTrait for DiscreteStatesContinousTimeParams { impl ParamsTrait for DiscreteStatesContinousTimeParams {
@ -154,7 +161,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
StateType::Discrete(rng.gen_range(0..(self.domain.len()))) StateType::Discrete(rng.gen_range(0..(self.domain.len())))
} }
fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError> { fn get_random_residence_time(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<f64, ParamsError> {
// Generate a random residence time given the current state of the node and its parent set. // Generate a random residence time given the current state of the node and its parent set.
// The method used is described in: // The method used is described in:
// https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
@ -170,7 +182,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
} }
} }
fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError> { fn get_random_state(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<StateType, ParamsError> {
// Generate a random transition given the current state of the node and its parent set. // Generate a random transition given the current state of the node and its parent set.
// The method used is described in: // The method used is described in:
// https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
@ -246,7 +263,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
} }
// Check if each row sum up to 0 // Check if each row sum up to 0
if cim.sum_axis(Axis(2)).iter() if cim
.sum_axis(Axis(2))
.iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
{ {
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
@ -257,8 +276,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
return Ok(()); return Ok(());
} }
fn get_label(&self) -> &String { fn get_label(&self) -> &String {
&self.label &self.label
} }
} }

@ -1,12 +1,11 @@
pub mod score_function;
pub mod score_based_algorithm;
pub mod constraint_based_algorithm; pub mod constraint_based_algorithm;
pub mod hypothesis_test; pub mod hypothesis_test;
use crate::network; pub mod score_based_algorithm;
use crate::tools; pub mod score_function;
use crate::{network, tools};
pub trait StructureLearningAlgorithm { pub trait StructureLearningAlgorithm {
fn fit_transform<T, >(&self, net: T, dataset: &tools::Dataset) -> T fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where where
T: network::Network; T: network::Network;
} }

@ -1,5 +1,3 @@
//pub struct CTPC { //pub struct CTPC {
// //
//} //}

@ -1,11 +1,10 @@
use ndarray::Array3; use std::collections::BTreeSet;
use ndarray::Axis;
use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF}; use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::network;
use crate::parameter_learning;
use crate::params::*; use crate::params::*;
use std::collections::BTreeSet; use crate::{network, parameter_learning};
pub trait HypothesisTest { pub trait HypothesisTest {
fn call<T, P>( fn call<T, P>(
@ -110,15 +109,15 @@ impl HypothesisTest for ChiSquare {
// Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM
// di dimensione nxn // di dimensione nxn
// (CIM, M, T) // (CIM, M, T)
let P_small = match cache.fit(net, child_node, Some(separation_set.clone())){ let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node Params::DiscreteStatesContinousTime(node) => node,
}; };
// //
let mut extended_separation_set = separation_set.clone(); let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node); extended_separation_set.insert(parent_node);
let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())){ let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node Params::DiscreteStatesContinousTime(node) => node,
}; };
// Commentare qui // Commentare qui
let partial_cardinality_product: usize = extended_separation_set let partial_cardinality_product: usize = extended_separation_set
@ -132,7 +131,12 @@ impl HypothesisTest for ChiSquare {
/ (partial_cardinality_product / (partial_cardinality_product
* net.get_node(parent_node).get_reserved_space_as_parent())) * net.get_node(parent_node).get_reserved_space_as_parent()))
* partial_cardinality_product; * partial_cardinality_product;
if !self.compare_matrices(idx_M_small, P_small.get_transitions().as_ref().unwrap(), idx_M_big, P_big.get_transitions().as_ref().unwrap()) { if !self.compare_matrices(
idx_M_small,
P_small.get_transitions().as_ref().unwrap(),
idx_M_big,
P_big.get_transitions().as_ref().unwrap(),
) {
return false; return false;
} }
} }

@ -1,8 +1,8 @@
use crate::network; use std::collections::BTreeSet;
use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm; use crate::structure_learning::StructureLearningAlgorithm;
use crate::tools; use crate::{network, tools};
use std::collections::BTreeSet;
pub struct HillClimbing<S: ScoreFunction> { pub struct HillClimbing<S: ScoreFunction> {
score_function: S, score_function: S,

@ -1,10 +1,9 @@
use crate::network; use std::collections::BTreeSet;
use crate::parameter_learning;
use crate::params;
use crate::tools;
use ndarray::prelude::*; use ndarray::prelude::*;
use statrs::function::gamma; use statrs::function::gamma;
use std::collections::BTreeSet;
use crate::{network, parameter_learning, params, tools};
pub trait ScoreFunction { pub trait ScoreFunction {
fn call<T>( fn call<T>(
@ -25,7 +24,6 @@ pub struct LogLikelihood {
impl LogLikelihood { impl LogLikelihood {
pub fn new(alpha: usize, tau: f64) -> LogLikelihood { pub fn new(alpha: usize, tau: f64) -> LogLikelihood {
//Tau must be >=0.0 //Tau must be >=0.0
if tau < 0.0 { if tau < 0.0 {
panic!("tau must be >=0.0"); panic!("tau must be >=0.0");
@ -44,7 +42,7 @@ impl LogLikelihood {
T: network::Network, T: network::Network,
{ {
//Identify the type of node used //Identify the type of node used
match &net.get_node(node){ match &net.get_node(node) {
params::Params::DiscreteStatesContinousTime(_params) => { params::Params::DiscreteStatesContinousTime(_params) => {
//Compute the sufficient statistics M (number of transistions) and T (residence //Compute the sufficient statistics M (number of transistions) and T (residence
//time) //time)
@ -57,33 +55,38 @@ impl LogLikelihood {
let tau = self.tau / M.shape()[0] as f64; let tau = self.tau / M.shape()[0] as f64;
//Compute the log likelihood for q //Compute the log likelihood for q
let log_ll_q:f64 = M let log_ll_q: f64 = M
.sum_axis(Axis(2)) .sum_axis(Axis(2))
.iter() .iter()
.zip(T.iter()) .zip(T.iter())
.map(|(m, t)| { .map(|(m, t)| {
gamma::ln_gamma(alpha + *m as f64 + 1.0) gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau)
+ (alpha + 1.0) * f64::ln(tau)
- gamma::ln_gamma(alpha + 1.0) - gamma::ln_gamma(alpha + 1.0)
- (alpha + *m as f64 + 1.0) * f64::ln(tau + t) - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
}) })
.sum(); .sum();
//Compute the log likelihood for theta //Compute the log likelihood for theta
let log_ll_theta: f64 = M.outer_iter() let log_ll_theta: f64 = M
.map(|x| x.outer_iter() .outer_iter()
.map(|y| gamma::ln_gamma(alpha) .map(|x| {
- gamma::ln_gamma(alpha + y.sum() as f64) x.outer_iter()
+ y.iter().map(|z| .map(|y| {
gamma::ln_gamma(alpha + *z as f64) gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64)
- gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum(); + y.iter()
.map(|z| {
gamma::ln_gamma(alpha + *z as f64)
- gamma::ln_gamma(alpha)
})
.sum::<f64>()
})
.sum::<f64>()
})
.sum();
(log_ll_theta + log_ll_q, M) (log_ll_theta + log_ll_q, M)
} }
} }
} }
} }
impl ScoreFunction for LogLikelihood { impl ScoreFunction for LogLikelihood {
@ -102,13 +105,13 @@ impl ScoreFunction for LogLikelihood {
} }
pub struct BIC { pub struct BIC {
ll: LogLikelihood ll: LogLikelihood,
} }
impl BIC { impl BIC {
pub fn new(alpha: usize, tau: f64) -> BIC { pub fn new(alpha: usize, tau: f64) -> BIC {
BIC { BIC {
ll: LogLikelihood::new(alpha, tau) ll: LogLikelihood::new(alpha, tau),
} }
} }
} }
@ -122,14 +125,19 @@ impl ScoreFunction for BIC {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> f64 ) -> f64
where where
T: network::Network { T: network::Network,
{
//Compute the log-likelihood //Compute the log-likelihood
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
//Compute the number of parameters //Compute the number of parameters
let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
//TODO: Optimize this //TODO: Optimize this
//Compute the sample size //Compute the sample size
let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum(); let sample_size: usize = dataset
.get_trajectories()
.iter()
.map(|x| x.get_time().len() - 1)
.sum();
//Compute BIC //Compute BIC
ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
} }

@ -1,10 +1,10 @@
use crate::network;
use crate::params;
use crate::params::ParamsTrait;
use ndarray::prelude::*; use ndarray::prelude::*;
use rand_chacha::rand_core::SeedableRng; use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
use crate::params::ParamsTrait;
use crate::{network, params};
pub struct Trajectory { pub struct Trajectory {
time: Array1<f64>, time: Array1<f64>,
events: Array2<usize>, events: Array2<usize>,
@ -35,7 +35,6 @@ pub struct Dataset {
impl Dataset { impl Dataset {
pub fn new(trajectories: Vec<Trajectory>) -> Dataset { pub fn new(trajectories: Vec<Trajectory>) -> Dataset {
//All the trajectories in the same dataset must represent the same process. For this reason //All the trajectories in the same dataset must represent the same process. For this reason
//each trajectory must represent the same number of variables. //each trajectory must represent the same number of variables.
if trajectories if trajectories
@ -58,7 +57,6 @@ pub fn trajectory_generator<T: network::Network>(
t_end: f64, t_end: f64,
seed: Option<u64>, seed: Option<u64>,
) -> Dataset { ) -> Dataset {
//Tmp growing vector containing generated trajectories. //Tmp growing vector containing generated trajectories.
let mut trajectories: Vec<Trajectory> = Vec::new(); let mut trajectories: Vec<Trajectory> = Vec::new();
@ -67,7 +65,7 @@ pub fn trajectory_generator<T: network::Network>(
//If a seed is present use it to initialize the random generator. //If a seed is present use it to initialize the random generator.
Some(seed) => SeedableRng::seed_from_u64(seed), Some(seed) => SeedableRng::seed_from_u64(seed),
//Otherwise create a new random generator using the method `from_entropy` //Otherwise create a new random generator using the method `from_entropy`
None => SeedableRng::from_entropy() None => SeedableRng::from_entropy(),
}; };
//Each iteration generate one trajectory //Each iteration generate one trajectory
@ -78,7 +76,8 @@ pub fn trajectory_generator<T: network::Network>(
let mut time: Vec<f64> = Vec::new(); let mut time: Vec<f64> = Vec::new();
//Configuration of the process variables at time t initialized with an uniform //Configuration of the process variables at time t initialized with an uniform
//distribution. //distribution.
let mut current_state: Vec<params::StateType> = net.get_node_indices() let mut current_state: Vec<params::StateType> = net
.get_node_indices()
.map(|x| net.get_node(x).get_random_state_uniform(&mut rng)) .map(|x| net.get_node(x).get_random_state_uniform(&mut rng))
.collect(); .collect();
//History of all the configurations of the process variables. //History of all the configurations of the process variables.

@ -1,8 +1,9 @@
mod utils; mod utils;
use std::collections::BTreeSet;
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::params::{self, ParamsTrait}; use reCTBN::params::{self, ParamsTrait};
use std::collections::BTreeSet;
use utils::generate_discrete_time_continous_node; use utils::generate_discrete_time_continous_node;
#[test] #[test]

@ -1,13 +1,13 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
mod utils; mod utils;
use utils::*;
use ndarray::arr3; use ndarray::arr3;
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::parameter_learning::*; use reCTBN::parameter_learning::*;
use reCTBN::{params, tools::*}; use reCTBN::params;
use reCTBN::tools::*;
use utils::*;
extern crate approx; extern crate approx;
@ -32,8 +32,14 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]], [
[[-6.0, 6.0], [2.0, -2.0]], [-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
])) ]))
); );
} }
@ -41,11 +47,20 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 1, None) { let p = match pl.fit(&net, &data, 1, None) {
params::Params::DiscreteStatesContinousTime(p) => p params::Params::DiscreteStatesContinousTime(p) => p,
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), &arr3(&[
[
[-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
]),
0.1 0.1
)); ));
} }
@ -76,11 +91,13 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[-3.0, 2.0, 1.0], [
[1.5, -2.0, 0.5], [-3.0, 2.0, 1.0],
[0.4, 0.6, -1.0] [1.5, -2.0, 0.5],
]])) [0.4, 0.6, -1.0]
],
]))
); );
} }
} }
@ -90,24 +107,48 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
} }
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 1, None){ let p = match pl.fit(&net, &data, 1, None) {
params::Params::DiscreteStatesContinousTime(p) => p params::Params::DiscreteStatesContinousTime(p) => p,
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[ &arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
]), ]),
0.1 0.1
)); ));
@ -139,11 +180,13 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[-3.0, 2.0, 1.0], [
[1.5, -2.0, 0.5], [-3.0, 2.0, 1.0],
[0.4, 0.6, -1.0] [1.5, -2.0, 0.5],
]])) [0.4, 0.6, -1.0]
]
]))
); );
} }
} }
@ -153,21 +196,39 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
} }
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 0, None){ let p = match pl.fit(&net, &data, 0, None) {
params::Params::DiscreteStatesContinousTime(p) => p params::Params::DiscreteStatesContinousTime(p) => p,
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), &arr3(&[
[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]
],
]),
0.1 0.1
)); ));
} }
@ -204,11 +265,13 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[-3.0, 2.0, 1.0], [
[1.5, -2.0, 0.5], [-3.0, 2.0, 1.0],
[0.4, 0.6, -1.0] [1.5, -2.0, 0.5],
]])) [0.4, 0.6, -1.0]
],
]))
); );
} }
} }
@ -218,9 +281,21 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -291,8 +366,8 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
} }
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 2, None){ let p = match pl.fit(&net, &data, 2, None) {
params::Params::DiscreteStatesContinousTime(p) => p params::Params::DiscreteStatesContinousTime(p) => p,
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(

@ -1,5 +1,6 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
use reCTBN::params::{ParamsTrait, *}; use reCTBN::params::{ParamsTrait, *};
mod utils; mod utils;

@ -1,17 +1,18 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
mod utils; mod utils;
use utils::*; use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3}; use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::params; use reCTBN::params;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm};
use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::score_based_algorithm::*;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::StructureLearningAlgorithm;
use reCTBN::tools::*; use reCTBN::tools::*;
use std::collections::BTreeSet; use utils::*;
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
@ -69,11 +70,13 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[-3.0, 2.0, 1.0], [
[1.5, -2.0, 0.5], [-3.0, 2.0, 1.0],
[0.4, 0.6, -1.0] [1.5, -2.0, 0.5],
]])) [0.4, 0.6, -1.0]
],
]))
); );
} }
} }
@ -83,9 +86,21 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -122,11 +137,13 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[-3.0, 2.0, 1.0], [
[1.5, -2.0, 0.5], [-3.0, 2.0, 1.0],
[0.4, 0.6, -1.0] [1.5, -2.0, 0.5],
]])) [0.4, 0.6, -1.0]
],
]))
); );
} }
} }
@ -136,9 +153,21 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -185,11 +214,13 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[[ param.set_cim(arr3(&[
[-3.0, 2.0, 1.0], [
[1.5, -2.0, 0.5], [-3.0, 2.0, 1.0],
[0.4, 0.6, -1.0] [1.5, -2.0, 0.5],
]])) [0.4, 0.6, -1.0]
],
]))
); );
} }
} }
@ -199,9 +230,21 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
assert_eq!( assert_eq!(
Ok(()), Ok(()),
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [-1.0, 0.5, 0.5],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [3.0, -4.0, 1.0],
[0.9, 0.1, -1.0]
],
[
[-6.0, 2.0, 4.0],
[1.5, -2.0, 0.5],
[3.0, 1.0, -4.0]
],
[
[-1.0, 0.1, 0.9],
[2.0, -2.5, 0.5],
[0.9, 0.1, -1.0]
],
])) ]))
); );
} }
@ -320,42 +363,56 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
} }
#[test] #[test]
pub fn chi_square_compare_matrices () { pub fn chi_square_compare_matrices() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[ 0, 2, 3], [
[ 4, 0, 6], [ 0, 2, 3],
[ 7, 8, 0]], [ 4, 0, 6],
[[0, 12, 90], [ 7, 8, 0]
[ 3, 0, 40], ],
[ 6, 40, 0]], [
[[ 0, 2, 3], [0, 12, 90],
[ 4, 0, 6], [ 3, 0, 40],
[ 44, 66, 0]] [ 6, 40, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]
],
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[ let M2 = arr3(&[
[[ 0, 200, 300], [
[ 400, 0, 600], [ 0, 200, 300],
[ 700, 800, 0]] [ 400, 0, 600],
[ 700, 800, 0]
],
]); ]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(!chi_sq.compare_matrices( i, &M1, j, &M2)); assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
} }
#[test] #[test]
pub fn chi_square_compare_matrices_2 () { pub fn chi_square_compare_matrices_2() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[ 0, 2, 3], [
[ 4, 0, 6], [ 0, 2, 3],
[ 7, 8, 0]], [ 4, 0, 6],
[[0, 20, 30], [ 7, 8, 0]
[ 40, 0, 60], ],
[ 70, 80, 0]], [
[[ 0, 2, 3], [0, 20, 30],
[ 4, 0, 6], [ 40, 0, 60],
[ 44, 66, 0]] [ 70, 80, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]
],
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[ let M2 = arr3(&[
@ -364,29 +421,37 @@ pub fn chi_square_compare_matrices_2 () {
[ 700, 800, 0]] [ 700, 800, 0]]
]); ]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
} }
#[test] #[test]
pub fn chi_square_compare_matrices_3 () { pub fn chi_square_compare_matrices_3() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[ 0, 2, 3], [
[ 4, 0, 6], [ 0, 2, 3],
[ 7, 8, 0]], [ 4, 0, 6],
[[0, 21, 31], [ 7, 8, 0]
[ 41, 0, 59], ],
[ 71, 79, 0]], [
[[ 0, 2, 3], [0, 21, 31],
[ 4, 0, 6], [ 41, 0, 59],
[ 44, 66, 0]] [ 71, 79, 0]
],
[
[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]
],
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[ let M2 = arr3(&[
[[ 0, 200, 300], [
[ 400, 0, 600], [ 0, 200, 300],
[ 700, 800, 0]] [ 400, 0, 600],
[ 700, 800, 0]
],
]); ]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
} }

@ -29,15 +29,26 @@ fn run_sampling() {
match &mut net.get_node_mut(n1) { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); param.set_cim(arr3(&[
[
[-3.0, 3.0],
[2.0, -2.0]
],
]));
} }
} }
match &mut net.get_node_mut(n2) { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]], [
[[-6.0, 6.0], [2.0, -2.0]], [-1.0, 1.0],
[4.0, -4.0]
],
[
[-6.0, 6.0],
[2.0, -2.0]
],
])); ]));
} }
} }

@ -1,12 +1,19 @@
use reCTBN::params;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use reCTBN::params;
#[allow(dead_code)] #[allow(dead_code)]
pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params { pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params {
params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality)) params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(
label,
cardinality,
))
} }
pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ pub fn generate_discrete_time_continous_params(
label: String,
cardinality: usize,
) -> params::DiscreteStatesContinousTimeParams {
let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect(); let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
params::DiscreteStatesContinousTimeParams::new(label, domain) params::DiscreteStatesContinousTimeParams::new(label, domain)
} }

Loading…
Cancel
Save