Refactored `src/` and `tests/` files to be compliant with `rustfmt`

pull/55/head
Meliurwen 2 years ago
parent d6e93b351c
commit 780515707c
Files changed (changed-line counts in parentheses):

 1. src/ctbn.rs (108)
 2. src/lib.rs (7)
 3. src/network.rs (19)
 4. src/parameter_learning.rs (61)
 5. src/params.rs (54)
 6. src/structure_learning.rs (9)
 7. src/structure_learning/constraint_based_algorithm.rs (2)
 8. src/structure_learning/hypothesis_test.rs (24)
 9. src/structure_learning/score_based_algorithm.rs (6)
10. src/structure_learning/score_function.rs (62)
11. src/tools.rs (33)
12. tests/ctbn.rs (3)
13. tests/parameter_learning.rs (20)
14. tests/params.rs (3)
15. tests/structure_learning.rs (75)
16. tests/utils.rs (13)
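All of the changes below are mechanical formatting fixes (import ordering and grouping, trailing commas, line-width wrapping, chained-call indentation); no behavior changes. To reproduce or verify them locally, the standard toolchain commands should suffice (assuming a default rustfmt configuration, which this page does not show): `cargo fmt --all` rewrites the sources in place, and `cargo fmt --all -- --check` exits non-zero if anything is left unformatted.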

src/ctbn.rs:

@@ -1,10 +1,9 @@
-use ndarray::prelude::*;
-use crate::params::{StateType, Params, ParamsTrait};
-use crate::network;
 use std::collections::BTreeSet;
+
+use ndarray::prelude::*;
+
+use crate::network;
+use crate::params::{Params, ParamsTrait, StateType};
 
 ///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
 ///composed by the following elements:
@@ -22,12 +21,12 @@ use std::collections::BTreeSet;
 /// use reCTBN::ctbn::*;
 ///
 /// //Create the domain for a discrete node
 /// let mut domain = BTreeSet::new();
 /// domain.insert(String::from("A"));
 /// domain.insert(String::from("B"));
 ///
 /// //Create the parameters for a discrete node using the domain
 /// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
 ///
 /// //Create the node using the parameters
 /// let X1 = params::Params::DiscreteStatesContinousTime(param);
@@ -37,14 +36,14 @@ use std::collections::BTreeSet;
 /// domain.insert(String::from("B"));
 /// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
 /// let X2 = params::Params::DiscreteStatesContinousTime(param);
 ///
 /// //Initialize a ctbn
 /// let mut net = CtbnNetwork::new();
 ///
 /// //Add nodes
 /// let X1 = net.add_node(X1).unwrap();
 /// let X2 = net.add_node(X2).unwrap();
 ///
 /// //Add an edge
 /// net.add_edge(X1, X2);
 ///
@@ -54,30 +53,30 @@ use std::collections::BTreeSet;
 /// ```
 pub struct CtbnNetwork {
     adj_matrix: Option<Array2<u16>>,
-    nodes: Vec<Params>
+    nodes: Vec<Params>,
 }
 
 impl CtbnNetwork {
     pub fn new() -> CtbnNetwork {
         CtbnNetwork {
             adj_matrix: None,
-            nodes: Vec::new()
+            nodes: Vec::new(),
         }
     }
 }
 
 impl network::Network for CtbnNetwork {
     fn initialize_adj_matrix(&mut self) {
-        self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f()));
+        self.adj_matrix = Some(Array2::<u16>::zeros(
+            (self.nodes.len(), self.nodes.len()).f(),
+        ));
     }
 
     fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
         n.reset_params();
         self.adj_matrix = Option::None;
         self.nodes.push(n);
-        Ok(self.nodes.len() -1)
+        Ok(self.nodes.len() - 1)
     }
 
     fn add_edge(&mut self, parent: usize, child: usize) {
@@ -91,7 +90,7 @@ impl network::Network for CtbnNetwork {
         }
     }
 
-    fn get_node_indices(&self) -> std::ops::Range<usize>{
+    fn get_node_indices(&self) -> std::ops::Range<usize> {
         0..self.nodes.len()
     }
 
@@ -99,64 +98,65 @@ impl network::Network for CtbnNetwork {
         self.nodes.len()
     }
 
-    fn get_node(&self, node_idx: usize) -> &Params{
+    fn get_node(&self, node_idx: usize) -> &Params {
         &self.nodes[node_idx]
     }
 
-    fn get_node_mut(&mut self, node_idx: usize) -> &mut Params{
+    fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
         &mut self.nodes[node_idx]
     }
 
-    fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
-        self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
-                if x.1 > &0 {
-                    acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
-                    acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
-                }
-                acc
-            }).0
+    fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize {
+        self.adj_matrix
+            .as_ref()
+            .unwrap()
+            .column(node)
+            .iter()
+            .enumerate()
+            .fold((0, 1), |mut acc, x| {
+                if x.1 > &0 {
+                    acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
+                    acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
+                }
+                acc
+            })
+            .0
     }
 
-    fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize {
-        parent_set.iter().fold((0, 1), |mut acc, x| {
-                acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
-                acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
-                acc
-            }).0
+    fn get_param_index_from_custom_parent_set(
+        &self,
+        current_state: &Vec<StateType>,
+        parent_set: &BTreeSet<usize>,
+    ) -> usize {
+        parent_set
+            .iter()
+            .fold((0, 1), |mut acc, x| {
+                acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
+                acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
+                acc
+            })
+            .0
     }
 
     fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
-        self.adj_matrix.as_ref()
+        self.adj_matrix
+            .as_ref()
             .unwrap()
             .column(node)
             .iter()
            .enumerate()
-            .filter_map(|(idx, x)| {
-                if x > &0 {
-                    Some(idx)
-                } else {
-                    None
-                }
-            }).collect()
+            .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
+            .collect()
     }
 
-    fn get_children_set(&self, node: usize) -> BTreeSet<usize>{
-        self.adj_matrix.as_ref()
+    fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
+        self.adj_matrix
+            .as_ref()
             .unwrap()
             .row(node)
            .iter()
            .enumerate()
-            .filter_map(|(idx, x)| {
-                if x > &0 {
-                    Some(idx)
-                } else {
-                    None
-                }
-            }).collect()
+            .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
+            .collect()
    }
 }
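The two `fold`s above are only reformatted, not changed: they encode a configuration of the parent set as a single index, treating each parent's state as a digit in a mixed-radix number whose radices are the parents' cardinalities. A minimal standalone sketch of that indexing scheme (a hypothetical helper, not part of the reCTBN API):

```rust
/// Mixed-radix encoding of parent states, as performed by the folds above:
/// `states[i]` is the current state of parent i, `cards[i]` its cardinality.
fn mixed_radix_index(states: &[usize], cards: &[usize]) -> usize {
    states
        .iter()
        .zip(cards.iter())
        .fold((0, 1), |mut acc, (s, c)| {
            acc.0 += s * acc.1; // add this digit, scaled by the accumulated weight
            acc.1 *= c; // the next digit's weight grows by this cardinality
            acc
        })
        .0
}

fn main() {
    // Two parents with cardinalities 3 and 2, in states (2, 1): 2 + 1 * 3 = 5.
    assert_eq!(mixed_radix_index(&[2, 1], &[3, 2]), 5);
}
```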

src/lib.rs:

@@ -2,10 +2,9 @@
 #[cfg(test)]
 extern crate approx;
 
-pub mod params;
-pub mod network;
 pub mod ctbn;
-pub mod tools;
+pub mod network;
 pub mod parameter_learning;
+pub mod params;
 pub mod structure_learning;
+pub mod tools;
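The reordering here, and in the import blocks throughout the rest of the commit, follows the usual grouping convention: declarations sorted alphabetically, with `std` imports first, external crates second, and `crate`-local imports last.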

src/network.rs:

@@ -1,20 +1,21 @@
+use std::collections::BTreeSet;
+
 use thiserror::Error;
 use crate::params;
-use std::collections::BTreeSet;
 
 /// Error types for trait Network
 #[derive(Error, Debug)]
 pub enum NetworkError {
     #[error("Error during node insertion")]
-    NodeInsertionError(String)
+    NodeInsertionError(String),
 }
 
 ///Network
 ///The Network trait define the required methods for a structure used as pgm (such as ctbn).
 pub trait Network {
     fn initialize_adj_matrix(&mut self);
     fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
     fn add_edge(&mut self, parent: usize, child: usize);
 
     ///Get all the indices of the nodes contained inside the network
@@ -26,13 +27,17 @@ pub trait Network {
     ///Compute the index that must be used to access the parameters of a node given a specific
     ///configuration of the network. Usually, the only values really used in *current_state* are
     ///the ones in the parent set of the *node*.
-    fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
+    fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
+        -> usize;
 
     ///Compute the index that must be used to access the parameters of a node given a specific
     ///configuration of the network and a generic parent_set. Usually, the only values really used
     ///in *current_state* are the ones in the parent set of the *node*.
-    fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<params::StateType>, parent_set: &BTreeSet<usize>) -> usize;
+    fn get_param_index_from_custom_parent_set(
+        &self,
+        current_state: &Vec<params::StateType>,
+        parent_set: &BTreeSet<usize>,
+    ) -> usize;
 
     fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
     fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
 }

src/parameter_learning.rs:

@@ -1,11 +1,12 @@
-use crate::network;
-use crate::params::*;
-use crate::tools;
-use ndarray::prelude::*;
 use std::collections::BTreeSet;
+
+use ndarray::prelude::*;
+
+use crate::params::*;
+use crate::{network, tools};
 
-pub trait ParameterLearning{
-    fn fit<T:network::Network>(
+pub trait ParameterLearning {
+    fn fit<T: network::Network>(
         &self,
         net: &T,
         dataset: &tools::Dataset,
@@ -14,24 +15,19 @@ pub trait ParameterLearning{
     ) -> Params;
 }
 
-pub fn sufficient_statistics<T:network::Network>(
+pub fn sufficient_statistics<T: network::Network>(
     net: &T,
     dataset: &tools::Dataset,
     node: usize,
-    parent_set: &BTreeSet<usize>
+    parent_set: &BTreeSet<usize>,
 ) -> (Array3<usize>, Array2<f64>) {
     //Get the number of values assumable by the node
-    let node_domain = net
-        .get_node(node.clone())
-        .get_reserved_space_as_parent();
+    let node_domain = net.get_node(node.clone()).get_reserved_space_as_parent();
     //Get the number of values assumable by each parent of the node
     let parentset_domain: Vec<usize> = parent_set
         .iter()
-        .map(|x| {
-            net.get_node(x.clone())
-                .get_reserved_space_as_parent()
-        })
+        .map(|x| net.get_node(x.clone()).get_reserved_space_as_parent())
         .collect();
     //Vector used to convert a specific configuration of the parent_set to the corresponding index
@@ -45,7 +41,7 @@ pub fn sufficient_statistics<T: network::Network>(
         vector_to_idx[*idx] = acc;
         acc * x
     });
 
     //Number of transition given a specific configuration of the parent set
     let mut M: Array3<usize> =
         Array::zeros((parentset_domain.iter().product(), node_domain, node_domain));
@@ -70,13 +66,11 @@ pub fn sufficient_statistics<T: network::Network>(
        }
    }
 
    return (M, T);
 }
 
 pub struct MLE {}
 
 impl ParameterLearning for MLE {
     fn fit<T: network::Network>(
         &self,
         net: &T,
@@ -84,19 +78,18 @@ impl ParameterLearning for MLE {
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
         //Use parent_set from parameter if present. Otherwise use parent_set from network.
         let parent_set = match parent_set {
             Some(p) => p,
             None => net.get_parent_set(node),
         };
 
         let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
         //Compute the CIM as M[i,x,y]/T[i,x]
         let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
         CIM.axis_iter_mut(Axis(2))
             .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-            .for_each(|(mut C, m)| C.assign(&(&m/&T)));
+            .for_each(|(mut C, m)| C.assign(&(&m / &T)));
 
         //Set the diagonal of the inner matrices to the the row sum multiplied by -1
         let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@@ -105,8 +98,6 @@ impl ParameterLearning for MLE {
             .for_each(|(mut C, diag)| {
                 C.diag_mut().assign(&diag);
             });
 
         let mut n: Params = net.get_node(node).clone();
@@ -115,8 +106,6 @@ impl ParameterLearning for MLE {
                 dsct.set_cim_unchecked(CIM);
                 dsct.set_transitions(M);
                 dsct.set_residence_time(T);
             }
         };
         return n;
@@ -125,7 +114,7 @@ impl ParameterLearning for MLE {
 
 pub struct BayesianApproach {
     pub alpha: usize,
-    pub tau: f64
+    pub tau: f64,
 }
 
 impl ParameterLearning for BayesianApproach {
@@ -141,17 +130,17 @@ impl ParameterLearning for BayesianApproach {
             Some(p) => p,
             None => net.get_parent_set(node),
         };
 
         let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
 
         let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64;
         let tau: f64 = self.tau as f64 / M.shape()[0] as f64;
 
         //Compute the CIM as M[i,x,y]/T[i,x]
         let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
         CIM.axis_iter_mut(Axis(2))
             .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-            .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha)/&T.mapv(|y| y + tau))));
+            .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau))));
 
         //Set the diagonal of the inner matrices to the the row sum multiplied by -1
         let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@@ -161,8 +150,6 @@ impl ParameterLearning for BayesianApproach {
                 C.diag_mut().assign(&diag);
             });
 
         let mut n: Params = net.get_node(node).clone();
 
         match n {
@@ -170,27 +157,25 @@ impl ParameterLearning for BayesianApproach {
                 dsct.set_cim_unchecked(CIM);
                 dsct.set_transitions(M);
                 dsct.set_residence_time(T);
             }
         };
         return n;
     }
 }
 
 pub struct Cache<P: ParameterLearning> {
     parameter_learning: P,
     dataset: tools::Dataset,
 }
 
 impl<P: ParameterLearning> Cache<P> {
-    pub fn fit<T:network::Network>(
+    pub fn fit<T: network::Network>(
         &mut self,
         net: &T,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
-        self.parameter_learning.fit(net, &self.dataset, node, parent_set)
+        self.parameter_learning
+            .fit(net, &self.dataset, node, parent_set)
     }
 }
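For reference, the estimator that `MLE::fit` wraps is unchanged by this commit: off-diagonal intensities are the transition counts divided by the residence times, and each diagonal entry is set so that its row sums to zero. A self-contained sketch on plain nested `Vec`s (illustrative only; the real code above operates on `ndarray` arrays over all parent configurations):

```rust
/// CIM estimate for one parent configuration, sketching the MLE step above:
/// q[x][y] = m[x][y] / t[x] for x != y, and q[x][x] = -(sum of the row).
fn mle_cim(m: &[Vec<usize>], t: &[f64]) -> Vec<Vec<f64>> {
    let n = t.len();
    let mut q = vec![vec![0.0; n]; n];
    for x in 0..n {
        for y in 0..n {
            if x != y {
                q[x][y] = m[x][y] as f64 / t[x]; // transitions per unit of residence time
            }
        }
        q[x][x] = -q[x].iter().sum::<f64>(); // rows of a CIM sum to zero
    }
    q
}

fn main() {
    let q = mle_cim(&[vec![0, 3], vec![2, 0]], &[1.5, 4.0]);
    assert!((q[0][1] - 2.0).abs() < 1e-12); // 3 transitions over 1.5 time units
    assert!((q[0][0] + 2.0).abs() < 1e-12); // diagonal balances the row
}
```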

src/params.rs:

@@ -1,9 +1,10 @@
+use std::collections::BTreeSet;
+
 use enum_dispatch::enum_dispatch;
 use ndarray::prelude::*;
 use rand::Rng;
-use std::collections::{BTreeSet};
-use thiserror::Error;
 use rand_chacha::ChaCha8Rng;
+use thiserror::Error;
 
 /// Error types for trait Params
 #[derive(Error, Debug, PartialEq)]
@@ -35,11 +36,21 @@ pub trait ParamsTrait {
 
     /// Randomly generate a residence time for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError>;
+    fn get_random_residence_time(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<f64, ParamsError>;
 
     /// Randomly generate a possible state for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError>;
+    fn get_random_state(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<StateType, ParamsError>;
 
     /// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
     fn get_reserved_space_as_parent(&self) -> usize;
@@ -49,7 +60,7 @@ pub trait ParamsTrait {
     /// Validate parameters against domain
     fn validate_params(&self) -> Result<(), ParamsError>;
 
     /// Return a reference to the associated label
     fn get_label(&self) -> &String;
 }
@@ -92,17 +103,17 @@ impl DiscreteStatesContinousTimeParams {
             residence_time: Option::None,
         }
     }
 
     ///Getter function for CIM
     pub fn get_cim(&self) -> &Option<Array3<f64>> {
         &self.cim
     }
 
     ///Setter function for CIM.\\
     ///This function check if the cim is valid using the validate_params method.
     ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(())
     ///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError
-    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError>{
+    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
         self.cim = Some(cim);
         match self.validate_params() {
             Ok(()) => Ok(()),
@@ -113,7 +124,6 @@ impl DiscreteStatesContinousTimeParams {
         }
     }
 
     ///Unchecked version of the setter function for CIM.
     pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
         self.cim = Some(cim);
@@ -124,7 +134,6 @@ impl DiscreteStatesContinousTimeParams {
         &self.transitions
     }
 
     ///Setter function for transitions
     pub fn set_transitions(&mut self, transitions: Array3<usize>) {
         self.transitions = Some(transitions);
@@ -135,12 +144,10 @@ impl DiscreteStatesContinousTimeParams {
         &self.residence_time
     }
 
     ///Setter function for residence_time
     pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
         self.residence_time = Some(residence_time);
     }
 }
 
 impl ParamsTrait for DiscreteStatesContinousTimeParams {
@@ -154,7 +161,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         StateType::Discrete(rng.gen_range(0..(self.domain.len())))
     }
 
-    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError> {
+    fn get_random_residence_time(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<f64, ParamsError> {
         // Generate a random residence time given the current state of the node and its parent set.
         // The method used is described in:
         // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
@@ -170,7 +182,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         }
     }
 
-    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError> {
+    fn get_random_state(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<StateType, ParamsError> {
         // Generate a random transition given the current state of the node and its parent set.
         // The method used is described in:
         // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
@@ -246,7 +263,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         }
 
         // Check if each row sum up to 0
-        if cim.sum_axis(Axis(2)).iter()
+        if cim
+            .sum_axis(Axis(2))
+            .iter()
             .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
         {
             return Err(ParamsError::InvalidCIM(String::from(
@@ -257,8 +276,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         return Ok(());
     }
 
     fn get_label(&self) -> &String {
         &self.label
     }
 }
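The reformatted check in `validate_params` above still enforces the defining property of a CIM: every row of every conditional intensity matrix sums to (approximately) zero. A minimal sketch of that invariant on plain slices (a hypothetical helper, using the same tolerance as the code above):

```rust
/// Row-sum invariant checked by validate_params above: each row of each
/// matrix must sum to zero, up to a small floating-point tolerance.
fn rows_sum_to_zero(cim: &[Vec<Vec<f64>>]) -> bool {
    cim.iter().all(|matrix| {
        matrix
            .iter()
            .all(|row| row.iter().sum::<f64>().abs() <= f64::EPSILON * 3.0)
    })
}

fn main() {
    assert!(rows_sum_to_zero(&[vec![vec![-3.0, 3.0], vec![2.0, -2.0]]]));
    assert!(!rows_sum_to_zero(&[vec![vec![-3.0, 2.0], vec![2.0, -2.0]]]));
}
```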

src/structure_learning.rs:

@@ -1,12 +1,11 @@
-pub mod score_function;
-pub mod score_based_algorithm;
 pub mod constraint_based_algorithm;
 pub mod hypothesis_test;
-use crate::network;
-use crate::tools;
+pub mod score_based_algorithm;
+pub mod score_function;
+use crate::{network, tools};
 
 pub trait StructureLearningAlgorithm {
-    fn fit_transform<T, >(&self, net: T, dataset: &tools::Dataset) -> T
+    fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
     where
         T: network::Network;
 }

src/structure_learning/constraint_based_algorithm.rs:

@@ -1,5 +1,3 @@
 //pub struct CTPC {
 //
 //}

src/structure_learning/hypothesis_test.rs:

@@ -1,11 +1,10 @@
-use ndarray::Array3;
-use ndarray::Axis;
+use std::collections::BTreeSet;
+
+use ndarray::{Array3, Axis};
 use statrs::distribution::{ChiSquared, ContinuousCDF};
-use crate::network;
-use crate::parameter_learning;
 use crate::params::*;
-use std::collections::BTreeSet;
+use crate::{network, parameter_learning};
 
 pub trait HypothesisTest {
     fn call<T, P>(
@@ -110,15 +109,15 @@ impl HypothesisTest for ChiSquare {
         // Fetch the parameter learning result from the cache; it is a CIM
         // of size n x n
         // (CIM, M, T)
-        let P_small = match cache.fit(net, child_node, Some(separation_set.clone())){
-            Params::DiscreteStatesContinousTime(node) => node
+        let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
+            Params::DiscreteStatesContinousTime(node) => node,
         };
         //
         let mut extended_separation_set = separation_set.clone();
         extended_separation_set.insert(parent_node);
-        let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())){
-            Params::DiscreteStatesContinousTime(node) => node
+        let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) {
+            Params::DiscreteStatesContinousTime(node) => node,
         };
         // Comment here
         let partial_cardinality_product: usize = extended_separation_set
@@ -132,7 +131,12 @@ impl HypothesisTest for ChiSquare {
             / (partial_cardinality_product
                 * net.get_node(parent_node).get_reserved_space_as_parent()))
             * partial_cardinality_product;
-        if !self.compare_matrices(idx_M_small, P_small.get_transitions().as_ref().unwrap(), idx_M_big, P_big.get_transitions().as_ref().unwrap()) {
+        if !self.compare_matrices(
+            idx_M_small,
+            P_small.get_transitions().as_ref().unwrap(),
+            idx_M_big,
+            P_big.get_transitions().as_ref().unwrap(),
+        ) {
             return false;
         }
     }

src/structure_learning/score_based_algorithm.rs:

@@ -1,8 +1,8 @@
-use crate::network;
+use std::collections::BTreeSet;
+
 use crate::structure_learning::score_function::ScoreFunction;
 use crate::structure_learning::StructureLearningAlgorithm;
-use crate::tools;
-use std::collections::BTreeSet;
+use crate::{network, tools};
 
 pub struct HillClimbing<S: ScoreFunction> {
     score_function: S,

src/structure_learning/score_function.rs:

@@ -1,10 +1,9 @@
-use crate::network;
-use crate::parameter_learning;
-use crate::params;
-use crate::tools;
+use std::collections::BTreeSet;
+
 use ndarray::prelude::*;
 use statrs::function::gamma;
-use std::collections::BTreeSet;
+
+use crate::{network, parameter_learning, params, tools};
 
 pub trait ScoreFunction {
     fn call<T>(
@@ -25,7 +24,6 @@ pub struct LogLikelihood {
 
 impl LogLikelihood {
     pub fn new(alpha: usize, tau: f64) -> LogLikelihood {
         //Tau must be >=0.0
         if tau < 0.0 {
             panic!("tau must be >=0.0");
@@ -42,9 +40,9 @@ impl LogLikelihood {
     ) -> (f64, Array3<usize>)
     where
         T: network::Network,
     {
         //Identify the type of node used
-        match &net.get_node(node){
+        match &net.get_node(node) {
             params::Params::DiscreteStatesContinousTime(_params) => {
                 //Compute the sufficient statistics M (number of transistions) and T (residence
                 //time)
@@ -55,35 +53,40 @@ impl LogLikelihood {
                 let alpha = self.alpha as f64 / M.shape()[0] as f64;
                 //Scale tau accordingly to the size of the parent set
                 let tau = self.tau / M.shape()[0] as f64;
 
                 //Compute the log likelihood for q
-                let log_ll_q:f64 = M
+                let log_ll_q: f64 = M
                     .sum_axis(Axis(2))
                     .iter()
                     .zip(T.iter())
                     .map(|(m, t)| {
-                        gamma::ln_gamma(alpha + *m as f64 + 1.0)
-                            + (alpha + 1.0) * f64::ln(tau)
+                        gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau)
                             - gamma::ln_gamma(alpha + 1.0)
                             - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
                     })
                     .sum();
 
                 //Compute the log likelihood for theta
-                let log_ll_theta: f64 = M.outer_iter()
-                    .map(|x| x.outer_iter()
-                        .map(|y| gamma::ln_gamma(alpha)
-                            - gamma::ln_gamma(alpha + y.sum() as f64)
-                            + y.iter().map(|z|
-                                gamma::ln_gamma(alpha + *z as f64)
-                                - gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum();
+                let log_ll_theta: f64 = M
+                    .outer_iter()
+                    .map(|x| {
+                        x.outer_iter()
+                            .map(|y| {
+                                gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64)
+                                    + y.iter()
+                                        .map(|z| {
+                                            gamma::ln_gamma(alpha + *z as f64)
+                                                - gamma::ln_gamma(alpha)
+                                        })
+                                        .sum::<f64>()
+                            })
+                            .sum::<f64>()
+                    })
+                    .sum();
                 (log_ll_theta + log_ll_q, M)
             }
         }
     }
 }
 
 impl ScoreFunction for LogLikelihood {
@@ -102,13 +105,13 @@ impl ScoreFunction for LogLikelihood {
 }
 
 pub struct BIC {
-    ll: LogLikelihood
+    ll: LogLikelihood,
 }
 
 impl BIC {
     pub fn new(alpha: usize, tau: f64) -> BIC {
         BIC {
-            ll: LogLikelihood::new(alpha, tau)
+            ll: LogLikelihood::new(alpha, tau),
         }
     }
 }
@@ -122,14 +125,19 @@ impl ScoreFunction for BIC {
         dataset: &tools::Dataset,
     ) -> f64
     where
-        T: network::Network {
+        T: network::Network,
+    {
         //Compute the log-likelihood
         let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
         //Compute the number of parameters
         let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
         //TODO: Optimize this
         //Compute the sample size
-        let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum();
+        let sample_size: usize = dataset
+            .get_trajectories()
+            .iter()
+            .map(|x| x.get_time().len() - 1)
+            .sum();
         //Compute BIC
         ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
     }
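Pulled out of the chained calls above, the score that `BIC::call` computes is the familiar penalized log-likelihood, where the sample size is the total number of transitions across all trajectories:

```rust
/// BIC as computed above: log-likelihood minus (ln(sample size) / 2) per parameter.
fn bic(log_likelihood: f64, sample_size: usize, n_parameters: usize) -> f64 {
    log_likelihood - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
}

fn main() {
    // With equal fit, the model with fewer parameters scores higher.
    assert!(bic(-100.0, 1000, 4) > bic(-100.0, 1000, 8));
}
```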

src/tools.rs:

@@ -1,10 +1,10 @@
-use crate::network;
-use crate::params;
-use crate::params::ParamsTrait;
 use ndarray::prelude::*;
 use rand_chacha::rand_core::SeedableRng;
 use rand_chacha::ChaCha8Rng;
+
+use crate::params::ParamsTrait;
+use crate::{network, params};
 
 pub struct Trajectory {
     time: Array1<f64>,
     events: Array2<usize>,
@@ -19,7 +19,7 @@ impl Trajectory {
         }
 
         Trajectory { time, events }
     }
 
     pub fn get_time(&self) -> &Array1<f64> {
         &self.time
     }
@@ -35,7 +35,6 @@ pub struct Dataset {
 
 impl Dataset {
     pub fn new(trajectories: Vec<Trajectory>) -> Dataset {
         //All the trajectories in the same dataset must represent the same process. For this reason
         //each trajectory must represent the same number of variables.
         if trajectories
@@ -58,18 +57,17 @@ pub fn trajectory_generator<T: network::Network>(
     t_end: f64,
     seed: Option<u64>,
 ) -> Dataset {
     //Tmp growing vector containing generated trajectories.
     let mut trajectories: Vec<Trajectory> = Vec::new();
 
     //Random Generator object
     let mut rng: ChaCha8Rng = match seed {
         //If a seed is present use it to initialize the random generator.
         Some(seed) => SeedableRng::seed_from_u64(seed),
         //Otherwise create a new random generator using the method `from_entropy`
-        None => SeedableRng::from_entropy()
+        None => SeedableRng::from_entropy(),
     };
 
     //Each iteration generate one trajectory
     for _ in 0..n_trajectories {
         //Current time of the sampling process
@@ -78,15 +76,16 @@ pub fn trajectory_generator<T: network::Network>(
         let mut time: Vec<f64> = Vec::new();
         //Configuration of the process variables at time t initialized with an uniform
         //distribution.
-        let mut current_state: Vec<params::StateType> = net.get_node_indices()
+        let mut current_state: Vec<params::StateType> = net
+            .get_node_indices()
             .map(|x| net.get_node(x).get_random_state_uniform(&mut rng))
             .collect();
 
         //History of all the configurations of the process variables.
         let mut events: Vec<Array1<usize>> = Vec::new();
         //Vector containing to time to the next transition for each variable.
         let mut next_transitions: Vec<Option<f64>> =
             net.get_node_indices().map(|_| Option::None).collect();
 
         //Add the starting time for the trajectory.
         time.push(t.clone());
         //Add the starting configuration of the trajectory.
@@ -115,7 +114,7 @@ pub fn trajectory_generator<T: network::Network>(
                 );
             }
         }
 
         //Get the variable with the smallest transition time.
         let next_node_transition = next_transitions
             .iter()
@@ -131,7 +130,7 @@ pub fn trajectory_generator<T: network::Network>(
         t = next_transitions[next_node_transition].unwrap().clone();
         //Add the transition time to next
         time.push(t.clone());
 
         //Compute the new state of the transitioning variable.
         current_state[next_node_transition] = net
             .get_node(next_node_transition)
@@ -142,7 +141,7 @@ pub fn trajectory_generator<T: network::Network>(
                 &mut rng,
             )
             .unwrap();
 
         //Add the new state to events
         events.push(Array::from_vec(
             current_state
@@ -160,7 +159,7 @@ pub fn trajectory_generator<T: network::Network>(
             next_transitions[child] = None
         }
     }
 
     //Add current_state as last state.
     events.push(
         current_state
@@ -172,7 +171,7 @@ pub fn trajectory_generator<T: network::Network>(
     );
     //Add t_end as last time.
     time.push(t_end.clone());
 
     //Add the sampled trajectory to trajectories.
     trajectories.push(Trajectory::new(
         Array::from_vec(time),
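One detail worth noting in `trajectory_generator` above is the seeding logic: a fixed seed makes the sampled trajectories reproducible (the tests below rely on this), while `None` falls back to OS entropy. A minimal extract of just that part:

```rust
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;

/// Seeding as done in trajectory_generator above: a fixed seed yields a
/// reproducible generator, no seed draws a fresh one from OS entropy.
fn make_rng(seed: Option<u64>) -> ChaCha8Rng {
    match seed {
        Some(seed) => SeedableRng::seed_from_u64(seed),
        None => SeedableRng::from_entropy(),
    }
}

fn main() {
    use rand::Rng;
    let mut a = make_rng(Some(6347747169756259));
    let mut b = make_rng(Some(6347747169756259));
    assert_eq!(a.gen::<u64>(), b.gen::<u64>()); // same seed, same stream
}
```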

@ -1,8 +1,9 @@
mod utils; mod utils;
use std::collections::BTreeSet;
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::params::{self, ParamsTrait}; use reCTBN::params::{self, ParamsTrait};
use std::collections::BTreeSet;
use utils::generate_discrete_time_continous_node; use utils::generate_discrete_time_continous_node;
#[test] #[test]

@ -1,13 +1,13 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
mod utils; mod utils;
use utils::*;
use ndarray::arr3; use ndarray::arr3;
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::parameter_learning::*; use reCTBN::parameter_learning::*;
use reCTBN::{params, tools::*}; use reCTBN::params;
use reCTBN::tools::*;
use utils::*;
extern crate approx; extern crate approx;
@ -41,7 +41,7 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 1, None) { let p = match pl.fit(&net, &data, 1, None) {
params::Params::DiscreteStatesContinousTime(p) => p params::Params::DiscreteStatesContinousTime(p) => p,
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
@ -99,8 +99,8 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
} }
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 1, None){ let p = match pl.fit(&net, &data, 1, None) {
params::Params::DiscreteStatesContinousTime(p) => p params::Params::DiscreteStatesContinousTime(p) => p,
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
@ -162,8 +162,8 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
} }
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 0, None){ let p = match pl.fit(&net, &data, 0, None) {
params::Params::DiscreteStatesContinousTime(p) => p params::Params::DiscreteStatesContinousTime(p) => p,
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
@ -291,8 +291,8 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
} }
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let p = match pl.fit(&net, &data, 2, None){ let p = match pl.fit(&net, &data, 2, None) {
params::Params::DiscreteStatesContinousTime(p) => p params::Params::DiscreteStatesContinousTime(p) => p,
}; };
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]);
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(

@ -1,5 +1,6 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
use reCTBN::params::{ParamsTrait, *}; use reCTBN::params::{ParamsTrait, *};
mod utils; mod utils;

@ -1,17 +1,18 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
mod utils; mod utils;
use utils::*; use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3}; use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::params; use reCTBN::params;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm};
use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::score_based_algorithm::*;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::StructureLearningAlgorithm;
use reCTBN::tools::*; use reCTBN::tools::*;
use std::collections::BTreeSet; use utils::*;
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
@ -320,73 +321,43 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
} }
#[test] #[test]
pub fn chi_square_compare_matrices () { pub fn chi_square_compare_matrices() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[ 0, 2, 3], [[0, 2, 3], [4, 0, 6], [7, 8, 0]],
[ 4, 0, 6], [[0, 12, 90], [3, 0, 40], [6, 40, 0]],
[ 7, 8, 0]], [[0, 2, 3], [4, 0, 6], [44, 66, 0]],
[[0, 12, 90],
[ 3, 0, 40],
[ 6, 40, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[ let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]);
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(!chi_sq.compare_matrices( i, &M1, j, &M2)); assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
} }
#[test] #[test]
pub fn chi_square_compare_matrices_2 () { pub fn chi_square_compare_matrices_2() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[ 0, 2, 3], [[0, 2, 3], [4, 0, 6], [7, 8, 0]],
[ 4, 0, 6], [[0, 20, 30], [40, 0, 60], [70, 80, 0]],
[ 7, 8, 0]], [[0, 2, 3], [4, 0, 6], [44, 66, 0]],
[[0, 20, 30],
[ 40, 0, 60],
[ 70, 80, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[ let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]);
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
} }
#[test] #[test]
pub fn chi_square_compare_matrices_3 () { pub fn chi_square_compare_matrices_3() {
let i: usize = 1; let i: usize = 1;
let M1 = arr3(&[ let M1 = arr3(&[
[[ 0, 2, 3], [[0, 2, 3], [4, 0, 6], [7, 8, 0]],
[ 4, 0, 6], [[0, 21, 31], [41, 0, 59], [71, 79, 0]],
[ 7, 8, 0]], [[0, 2, 3], [4, 0, 6], [44, 66, 0]],
[[0, 21, 31],
[ 41, 0, 59],
[ 71, 79, 0]],
[[ 0, 2, 3],
[ 4, 0, 6],
[ 44, 66, 0]]
]); ]);
let j: usize = 0; let j: usize = 0;
let M2 = arr3(&[ let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]);
[[ 0, 200, 300],
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(0.1); let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
} }

tests/utils.rs:

@@ -1,12 +1,19 @@
-use reCTBN::params;
 use std::collections::BTreeSet;
+
+use reCTBN::params;
 
 #[allow(dead_code)]
 pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params {
-    params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality))
+    params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(
+        label,
+        cardinality,
+    ))
 }
 
-pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
+pub fn generate_discrete_time_continous_params(
+    label: String,
+    cardinality: usize,
+) -> params::DiscreteStatesContinousTimeParams {
     let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
     params::DiscreteStatesContinousTimeParams::new(label, domain)
 }
