diff --git a/src/ctbn.rs b/src/ctbn.rs
index 2cede4a..e2f5dd7 100644
--- a/src/ctbn.rs
+++ b/src/ctbn.rs
@@ -1,10 +1,9 @@
-use ndarray::prelude::*;
-use crate::params::{StateType, Params, ParamsTrait};
-use crate::network;
 use std::collections::BTreeSet;
+
+use ndarray::prelude::*;
-
+use crate::network;
+use crate::params::{Params, ParamsTrait, StateType};
 
 ///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
 ///composed by the following elements:
@@ -22,12 +21,12 @@ use std::collections::BTreeSet;
 /// use reCTBN::ctbn::*;
 ///
 /// //Create the domain for a discrete node
-/// let mut domain = BTreeSet::new(); 
+/// let mut domain = BTreeSet::new();
 /// domain.insert(String::from("A"));
 /// domain.insert(String::from("B"));
 ///
 /// //Create the parameters for a discrete node using the domain
-/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain); 
+/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
 ///
 /// //Create the node using the parameters
 /// let X1 = params::Params::DiscreteStatesContinousTime(param);
@@ -37,14 +36,14 @@ use std::collections::BTreeSet;
 /// domain.insert(String::from("B"));
 /// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
 /// let X2 = params::Params::DiscreteStatesContinousTime(param);
-/// 
+///
 /// //Initialize a ctbn
 /// let mut net = CtbnNetwork::new();
 ///
 /// //Add nodes
 /// let X1 = net.add_node(X1).unwrap();
 /// let X2 = net.add_node(X2).unwrap();
-/// 
+///
 /// //Add an edge
 /// net.add_edge(X1, X2);
 ///
@@ -54,30 +53,30 @@ use std::collections::BTreeSet;
 /// ```
 pub struct CtbnNetwork {
     adj_matrix: Option<Array2<u16>>,
-    nodes: Vec<Params>
+    nodes: Vec<Params>,
 }
-
 impl CtbnNetwork {
     pub fn new() -> CtbnNetwork {
         CtbnNetwork {
             adj_matrix: None,
-            nodes: Vec::new()
+            nodes: Vec::new(),
         }
     }
 }
 
 impl network::Network for CtbnNetwork {
     fn initialize_adj_matrix(&mut self) {
-        self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f()));
-
+        self.adj_matrix = Some(Array2::<u16>::zeros(
+            (self.nodes.len(), self.nodes.len()).f(),
+        ));
     }
 
-    fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> { 
+    fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
         n.reset_params();
        self.adj_matrix = Option::None;
         self.nodes.push(n);
-        Ok(self.nodes.len() -1)
+        Ok(self.nodes.len() - 1)
     }
 
     fn add_edge(&mut self, parent: usize, child: usize) {
@@ -91,7 +90,7 @@ impl network::Network for CtbnNetwork {
         }
     }
 
-    fn get_node_indices(&self) -> std::ops::Range<usize>{
+    fn get_node_indices(&self) -> std::ops::Range<usize> {
         0..self.nodes.len()
     }
 
@@ -99,64 +98,65 @@ impl network::Network for CtbnNetwork {
         self.nodes.len()
     }
 
-    fn get_node(&self, node_idx: usize) -> &Params{
+    fn get_node(&self, node_idx: usize) -> &Params {
         &self.nodes[node_idx]
     }
-
-    fn get_node_mut(&mut self, node_idx: usize) -> &mut Params{
+    fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
         &mut self.nodes[node_idx]
     }
-
-    fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
-        self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
-            if x.1 > &0 {
-                acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
-                acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
-            }
-            acc
-        }).0
+    fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize {
+        self.adj_matrix
+            .as_ref()
+            .unwrap()
+            .column(node)
+            .iter()
+            .enumerate()
+            .fold((0, 1), |mut acc, x| {
+                if x.1 > &0 {
+                    acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
+                    acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
+                }
+                acc
+            })
+            .0
     }
-
-    fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize {
-        parent_set.iter().fold((0, 1), |mut acc, x| {
-            acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
-            acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
-            acc
-        }).0
+    fn get_param_index_from_custom_parent_set(
+        &self,
+        current_state: &Vec<StateType>,
+        parent_set: &BTreeSet<usize>,
+    ) -> usize {
+        parent_set
+            .iter()
+            .fold((0, 1), |mut acc, x| {
+                acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
+                acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
+                acc
+            })
+            .0
     }
 
     fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
-        self.adj_matrix.as_ref()
+        self.adj_matrix
+            .as_ref()
             .unwrap()
             .column(node)
             .iter()
             .enumerate()
-            .filter_map(|(idx, x)| {
-                if x > &0 {
-                    Some(idx)
-                } else {
-                    None
-                }
-            }).collect()
+            .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
+            .collect()
     }
 
-    fn get_children_set(&self, node: usize) -> BTreeSet<usize>{
-        self.adj_matrix.as_ref()
+    fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
+        self.adj_matrix
+            .as_ref()
             .unwrap()
             .row(node)
             .iter()
             .enumerate()
-            .filter_map(|(idx, x)| {
-                if x > &0 {
-                    Some(idx)
-                } else {
-                    None
-                }
-            }).collect()
+            .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
+            .collect()
     }
-
 }
-
diff --git a/src/lib.rs b/src/lib.rs
index 1dcc637..8c57af2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,10 +2,9 @@
 #[cfg(test)]
 extern crate approx;
 
-pub mod params;
-pub mod network;
 pub mod ctbn;
-pub mod tools;
+pub mod network;
 pub mod parameter_learning;
+pub mod params;
 pub mod structure_learning;
-
+pub mod tools;
diff --git a/src/network.rs b/src/network.rs
index 1c962b0..cbae339 100644
--- a/src/network.rs
+++ b/src/network.rs
@@ -1,20 +1,21 @@
+use std::collections::BTreeSet;
+
 use thiserror::Error;
+
 use crate::params;
-use std::collections::BTreeSet;
 
 /// Error types for trait Network
 #[derive(Error, Debug)]
 pub enum NetworkError {
     #[error("Error during node insertion")]
-    NodeInsertionError(String)
+    NodeInsertionError(String),
 }
-
 ///Network
 ///The Network trait define the required methods for a structure used as pgm (such as ctbn).
 pub trait Network {
     fn initialize_adj_matrix(&mut self);
-    fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; 
+    fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
     fn add_edge(&mut self, parent: usize, child: usize);
 
     ///Get all the indices of the nodes contained inside the network
@@ -26,13 +27,17 @@ pub trait Network {
     ///Compute the index that must be used to access the parameters of a node given a specific
     ///configuration of the network. Usually, the only values really used in *current_state* are
     ///the ones in the parent set of the *node*.
-    fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
-
+    fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
+        -> usize;
     ///Compute the index that must be used to access the parameters of a node given a specific
     ///configuration of the network and a generic parent_set. Usually, the only values really used
     ///in *current_state* are the ones in the parent set of the *node*.
-    fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<params::StateType>, parent_set: &BTreeSet<usize>) -> usize;
+    fn get_param_index_from_custom_parent_set(
+        &self,
+        current_state: &Vec<params::StateType>,
+        parent_set: &BTreeSet<usize>,
+    ) -> usize;
     fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
     fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
 }
diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs
index ffd1db8..10f0257 100644
--- a/src/parameter_learning.rs
+++ b/src/parameter_learning.rs
@@ -1,11 +1,12 @@
-use crate::network;
-use crate::params::*;
-use crate::tools;
-use ndarray::prelude::*;
 use std::collections::BTreeSet;
 
-pub trait ParameterLearning{
-    fn fit<T:network::Network>(
+use ndarray::prelude::*;
+
+use crate::params::*;
+use crate::{network, tools};
+
+pub trait ParameterLearning {
+    fn fit<T: network::Network>(
         &self,
         net: &T,
         dataset: &tools::Dataset,
@@ -14,24 +15,19 @@ pub trait ParameterLearning{
     ) -> Params;
 }
 
-pub fn sufficient_statistics<T:network::Network>(
+pub fn sufficient_statistics<T: network::Network>(
     net: &T,
     dataset: &tools::Dataset,
     node: usize,
-    parent_set: &BTreeSet<usize>
-    ) -> (Array3<usize>, Array2<f64>) {
+    parent_set: &BTreeSet<usize>,
+) -> (Array3<usize>, Array2<f64>) {
     //Get the number of values assumable by the node
-    let node_domain = net
-        .get_node(node.clone())
-        .get_reserved_space_as_parent();
+    let node_domain = net.get_node(node.clone()).get_reserved_space_as_parent();
 
     //Get the number of values assumable by each parent of the node
     let parentset_domain: Vec<usize> = parent_set
         .iter()
-        .map(|x| {
-            net.get_node(x.clone())
-                .get_reserved_space_as_parent()
-        })
+        .map(|x| net.get_node(x.clone()).get_reserved_space_as_parent())
         .collect();
 
     //Vector used to convert a specific configuration of the parent_set to the corresponding index
@@ -45,7 +41,7 @@ pub fn sufficient_statistics<T: network::Network>(
         vector_to_idx[*idx] = acc;
         acc * x
     });
-
+
     //Number of transition given a specific configuration of the parent set
     let mut M: Array3<usize> =
         Array::zeros((parentset_domain.iter().product(), node_domain, node_domain));
@@ -70,13 +66,11 @@ pub fn sufficient_statistics<T: network::Network>(
     }
 
     return (M, T);
-
 }
 
 pub struct MLE {}
 
 impl ParameterLearning for MLE {
-
     fn fit<T: network::Network>(
         &self,
         net: &T,
@@ -84,19 +78,18 @@ impl ParameterLearning for MLE {
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
-
         //Use parent_set from parameter if present. Otherwise use parent_set from network.
         let parent_set = match parent_set {
             Some(p) => p,
             None => net.get_parent_set(node),
         };
-
+
         let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
 
-        //Compute the CIM as M[i,x,y]/T[i,x] 
+        //Compute the CIM as M[i,x,y]/T[i,x]
         let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
         CIM.axis_iter_mut(Axis(2))
             .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-            .for_each(|(mut C, m)| C.assign(&(&m/&T)));
+            .for_each(|(mut C, m)| C.assign(&(&m / &T)));
 
         //Set the diagonal of the inner matrices to the the row sum multiplied by -1
         let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@@ -105,8 +98,6 @@ impl ParameterLearning for MLE {
             .for_each(|(mut C, diag)| {
                 C.diag_mut().assign(&diag);
             });
-
-
 
         let mut n: Params = net.get_node(node).clone();
 
@@ -115,8 +106,6 @@ impl ParameterLearning for MLE {
                 dsct.set_cim_unchecked(CIM);
                 dsct.set_transitions(M);
                 dsct.set_residence_time(T);
-
-
             }
         };
         return n;
@@ -125,7 +114,7 @@ impl ParameterLearning for MLE {
 
 pub struct BayesianApproach {
     pub alpha: usize,
-    pub tau: f64
+    pub tau: f64,
 }
 
 impl ParameterLearning for BayesianApproach {
@@ -141,17 +130,17 @@ impl ParameterLearning for BayesianApproach {
             Some(p) => p,
             None => net.get_parent_set(node),
         };
-
+
         let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
 
         let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64;
         let tau: f64 = self.tau as f64 / M.shape()[0] as f64;
 
-        //Compute the CIM as M[i,x,y]/T[i,x] 
+        //Compute the CIM as M[i,x,y]/T[i,x]
         let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
         CIM.axis_iter_mut(Axis(2))
             .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-            .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha)/&T.mapv(|y| y + tau))));
+            .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau))));
 
         //Set the diagonal of the inner matrices to the the row sum multiplied by -1
         let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@@ -161,8 +150,6 @@ impl ParameterLearning for BayesianApproach {
                 C.diag_mut().assign(&diag);
             });
 
-
-
         let mut n: Params = net.get_node(node).clone();
 
         match n {
@@ -170,27 +157,25 @@ impl ParameterLearning for BayesianApproach {
                 dsct.set_cim_unchecked(CIM);
                 dsct.set_transitions(M);
                 dsct.set_residence_time(T);
-
-
             }
         };
         return n;
     }
 }
 
-
 pub struct Cache<P: ParameterLearning> {
     parameter_learning: P,
     dataset: tools::Dataset,
 }
 
 impl<P: ParameterLearning> Cache<P> {
-    pub fn fit<T:network::Network>(
+    pub fn fit<T: network::Network>(
         &mut self,
         net: &T,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
-        self.parameter_learning.fit(net, &self.dataset, node, parent_set)
+        self.parameter_learning
+            .fit(net, &self.dataset, node, parent_set)
     }
 }
diff --git a/src/params.rs b/src/params.rs
index c2768b1..f994b99 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -1,9 +1,10 @@
+use std::collections::BTreeSet;
+
 use enum_dispatch::enum_dispatch;
 use ndarray::prelude::*;
 use rand::Rng;
-use std::collections::{BTreeSet};
-use thiserror::Error;
 use rand_chacha::ChaCha8Rng;
+use thiserror::Error;
 
 /// Error types for trait Params
 #[derive(Error, Debug, PartialEq)]
@@ -35,11 +36,21 @@ pub trait ParamsTrait {
 
     /// Randomly generate a residence time for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError>;
+    fn get_random_residence_time(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<f64, ParamsError>;
 
     /// Randomly generate a possible state for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError>;
+    fn get_random_state(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<StateType, ParamsError>;
 
     /// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
     fn get_reserved_space_as_parent(&self) -> usize;
@@ -49,7 +60,7 @@ pub trait ParamsTrait {
 
     /// Validate parameters against domain
     fn validate_params(&self) -> Result<(), ParamsError>;
-    
+
     /// Return a reference to the associated label
     fn get_label(&self) -> &String;
 }
@@ -92,17 +103,17 @@ impl DiscreteStatesContinousTimeParams {
             residence_time: Option::None,
         }
     }
-    
+
     ///Getter function for CIM
     pub fn get_cim(&self) -> &Option<Array3<f64>> {
         &self.cim
-    } 
+    }
 
     ///Setter function for CIM.\\
-    ///This function check if the cim is valid using the validate_params method. 
+    ///This function check if the cim is valid using the validate_params method.
     ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(())
     ///- **Invalid cim inserted**: it replace the self.cim value with None and it return ParamsError
-    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError>{
+    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
         self.cim = Some(cim);
         match self.validate_params() {
             Ok(()) => Ok(()),
@@ -113,7 +124,6 @@ impl DiscreteStatesContinousTimeParams {
         }
     }
 
-
     ///Unchecked version of the setter function for CIM.
     pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
         self.cim = Some(cim);
@@ -124,7 +134,6 @@ impl DiscreteStatesContinousTimeParams {
         &self.transitions
     }
 
-
     ///Setter function for transitions
     pub fn set_transitions(&mut self, transitions: Array3<usize>) {
         self.transitions = Some(transitions);
@@ -135,12 +144,10 @@ impl DiscreteStatesContinousTimeParams {
         &self.residence_time
     }
 
-
     ///Setter function for residence_time
     pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
         self.residence_time = Some(residence_time);
     }
-
 }
 
 impl ParamsTrait for DiscreteStatesContinousTimeParams {
@@ -154,7 +161,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         StateType::Discrete(rng.gen_range(0..(self.domain.len())))
     }
 
-    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError> {
+    fn get_random_residence_time(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<f64, ParamsError> {
         // Generate a random residence time given the current state of the node and its parent set.
         // The method used is described in:
         // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
@@ -170,7 +182,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         }
     }
 
-    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError> {
+    fn get_random_state(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<StateType, ParamsError> {
         // Generate a random transition given the current state of the node and its parent set.
         // The method used is described in:
         // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
@@ -246,7 +263,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         }
 
         // Check if each row sum up to 0
-        if cim.sum_axis(Axis(2)).iter()
+        if cim
+            .sum_axis(Axis(2))
+            .iter()
             .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
         {
             return Err(ParamsError::InvalidCIM(String::from(
@@ -257,8 +276,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         return Ok(());
     }
 
-    fn get_label(&self) -> &String { 
+    fn get_label(&self) -> &String {
         &self.label
     }
-
 }
diff --git a/src/structure_learning.rs b/src/structure_learning.rs
index b7db7ed..8b90cdf 100644
--- a/src/structure_learning.rs
+++ b/src/structure_learning.rs
@@ -1,12 +1,11 @@
-pub mod score_function;
-pub mod score_based_algorithm;
 pub mod constraint_based_algorithm;
 pub mod hypothesis_test;
-use crate::network;
-use crate::tools;
+pub mod score_based_algorithm;
+pub mod score_function;
+use crate::{network, tools};
 
 pub trait StructureLearningAlgorithm {
-    fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T 
+    fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
     where
         T: network::Network;
 }
diff --git a/src/structure_learning/constraint_based_algorithm.rs b/src/structure_learning/constraint_based_algorithm.rs
index 0d8b655..b3fc3e1 100644
--- a/src/structure_learning/constraint_based_algorithm.rs
+++ b/src/structure_learning/constraint_based_algorithm.rs
@@ -1,5 +1,3 @@
-
 //pub struct CTPC {
 //
 //}
-
diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs
index f8eeb30..5ddcc51 100644
--- a/src/structure_learning/hypothesis_test.rs
+++ b/src/structure_learning/hypothesis_test.rs
@@ -1,11 +1,10 @@
-use ndarray::Array3;
-use ndarray::Axis;
+use std::collections::BTreeSet;
+
+use ndarray::{Array3, Axis};
 use statrs::distribution::{ChiSquared, ContinuousCDF};
 
-use crate::network;
-use crate::parameter_learning;
 use crate::params::*;
-use std::collections::BTreeSet;
+use crate::{network, parameter_learning};
 
 pub trait HypothesisTest {
     fn call(
@@ -110,15 +109,15 @@ impl HypothesisTest for ChiSquare {
         // Take the parameter estimate from the cache; this would be a CIM
        // of dimension n x n
         // (CIM, M, T)
-        let P_small = match cache.fit(net, child_node, Some(separation_set.clone())){
-            Params::DiscreteStatesContinousTime(node) => node
+        let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
+            Params::DiscreteStatesContinousTime(node) => node,
         };
         //
         let mut extended_separation_set = separation_set.clone();
         extended_separation_set.insert(parent_node);
-        let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())){
-            Params::DiscreteStatesContinousTime(node) => node
+        let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) {
+            Params::DiscreteStatesContinousTime(node) => node,
         };
         // Comment here
         let partial_cardinality_product: usize = extended_separation_set
@@ -132,7 +131,12 @@ impl HypothesisTest for ChiSquare {
             / (partial_cardinality_product
                 * net.get_node(parent_node).get_reserved_space_as_parent()))
             * partial_cardinality_product;
-        if !self.compare_matrices(idx_M_small, P_small.get_transitions().as_ref().unwrap(), idx_M_big, P_big.get_transitions().as_ref().unwrap()) {
+        if !self.compare_matrices(
+            idx_M_small,
+            P_small.get_transitions().as_ref().unwrap(),
+            idx_M_big,
+            P_big.get_transitions().as_ref().unwrap(),
+        ) {
             return false;
         }
     }
diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs
index fe4e4ff..cc8541a 100644
--- a/src/structure_learning/score_based_algorithm.rs
+++ b/src/structure_learning/score_based_algorithm.rs
@@ -1,8 +1,8 @@
-use crate::network;
+use std::collections::BTreeSet;
+
 use crate::structure_learning::score_function::ScoreFunction;
 use crate::structure_learning::StructureLearningAlgorithm;
-use crate::tools;
-use std::collections::BTreeSet;
+use crate::{network, tools};
 
 pub struct HillClimbing<S: ScoreFunction> {
     score_function: S,
diff --git a/src/structure_learning/score_function.rs b/src/structure_learning/score_function.rs
index ea53db5..b3b1597 100644
--- a/src/structure_learning/score_function.rs
+++ b/src/structure_learning/score_function.rs
@@ -1,10 +1,9 @@
-use crate::network;
-use crate::parameter_learning;
-use crate::params;
-use crate::tools;
+use std::collections::BTreeSet;
+
 use ndarray::prelude::*;
 use statrs::function::gamma;
-use std::collections::BTreeSet;
+
+use crate::{network, parameter_learning, params, tools};
 
 pub trait ScoreFunction {
     fn call<T>(
@@ -25,7 +24,6 @@ pub struct LogLikelihood {
 
 impl LogLikelihood {
     pub fn new(alpha: usize, tau: f64) -> LogLikelihood {
-
         //Tau must be >=0.0
         if tau < 0.0 {
             panic!("tau must be >=0.0");
@@ -42,9 +40,9 @@ impl LogLikelihood {
     ) -> (f64, Array3<usize>)
     where
         T: network::Network,
-    { 
+    {
         //Identify the type of node used
-        match &net.get_node(node){
+        match &net.get_node(node) {
             params::Params::DiscreteStatesContinousTime(_params) => {
                 //Compute the sufficient statistics M (number of transistions) and T (residence
                 //time)
@@ -55,35 +53,40 @@ impl LogLikelihood {
                 let alpha = self.alpha as f64 / M.shape()[0] as f64;
                 //Scale tau accordingly to the size of the parent set
                 let tau = self.tau / M.shape()[0] as f64;
-
+
                 //Compute the log likelihood for q
-                let log_ll_q:f64 = M
+                let log_ll_q: f64 = M
                     .sum_axis(Axis(2))
                     .iter()
                     .zip(T.iter())
                     .map(|(m, t)| {
-                        gamma::ln_gamma(alpha + *m as f64 + 1.0)
-                            + (alpha + 1.0) * f64::ln(tau)
+                        gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau)
                             - gamma::ln_gamma(alpha + 1.0)
                             - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
                     })
                     .sum();
-
+
                 //Compute the log likelihood for theta
-                let log_ll_theta: f64 = M.outer_iter()
-                    .map(|x| x.outer_iter()
-                         .map(|y| gamma::ln_gamma(alpha)
-                              - gamma::ln_gamma(alpha + y.sum() as f64)
-                              + y.iter().map(|z|
-                                             gamma::ln_gamma(alpha + *z as f64)
-                                             - gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum();
+                let log_ll_theta: f64 = M
+                    .outer_iter()
+                    .map(|x| {
+                        x.outer_iter()
+                            .map(|y| {
+                                gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64)
+                                    + y.iter()
+                                        .map(|z| {
+                                            gamma::ln_gamma(alpha + *z as f64)
+                                                - gamma::ln_gamma(alpha)
+                                        })
+                                        .sum::<f64>()
+                            })
+                            .sum::<f64>()
+                    })
+                    .sum();
                 (log_ll_theta + log_ll_q, M)
             }
         }
     }
-
-
-
 }
 
 impl ScoreFunction for LogLikelihood {
@@ -102,13 +105,13 @@ impl ScoreFunction for LogLikelihood {
 }
 
 pub struct BIC {
-    ll: LogLikelihood
+    ll: LogLikelihood,
 }
 
 impl BIC {
     pub fn new(alpha: usize, tau: f64) -> BIC {
         BIC {
-            ll: LogLikelihood::new(alpha, tau)
+            ll: LogLikelihood::new(alpha, tau),
         }
     }
 }
@@ -122,14 +125,19 @@ impl ScoreFunction for BIC {
         dataset: &tools::Dataset,
     ) -> f64
     where
-        T: network::Network {
+        T: network::Network,
+    {
         //Compute the log-likelihood
         let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
         //Compute the number of parameters
         let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
         //TODO: Optimize this
         //Compute the sample size
-        let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum();
+        let sample_size: usize = dataset
+            .get_trajectories()
+            .iter()
+            .map(|x| x.get_time().len() - 1)
+            .sum();
         //Compute BIC
         ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
     }
diff --git a/src/tools.rs b/src/tools.rs
index 115fd67..448b26f 100644
--- a/src/tools.rs
+++ b/src/tools.rs
@@ -1,10 +1,10 @@
-use crate::network;
-use crate::params;
-use crate::params::ParamsTrait;
 use ndarray::prelude::*;
 use rand_chacha::rand_core::SeedableRng;
 use rand_chacha::ChaCha8Rng;
 
+use crate::params::ParamsTrait;
+use crate::{network, params};
+
 pub struct Trajectory {
     time: Array1<f64>,
     events: Array2<usize>,
@@ -19,7 +19,7 @@ impl Trajectory {
         }
         Trajectory { time, events }
     }
-    
+
     pub fn get_time(&self) -> &Array1<f64> {
         &self.time
     }
@@ -35,7 +35,6 @@ pub struct Dataset {
 
 impl Dataset {
     pub fn new(trajectories: Vec<Trajectory>) -> Dataset {
-
         //All the trajectories in the same dataset must represent the same process. For this reason
         //each trajectory must represent the same number of variables.
         if trajectories
@@ -58,18 +57,17 @@ pub fn trajectory_generator<T: network::Network>(
     t_end: f64,
     seed: Option<u64>,
 ) -> Dataset {
-
     //Tmp growing vector containing generated trajectories.
     let mut trajectories: Vec<Trajectory> = Vec::new();
-    
+
     //Random Generator object
     let mut rng: ChaCha8Rng = match seed {
         //If a seed is present use it to initialize the random generator.
         Some(seed) => SeedableRng::seed_from_u64(seed),
         //Otherwise create a new random generator using the method `from_entropy`
-        None => SeedableRng::from_entropy()
+        None => SeedableRng::from_entropy(),
     };
-    
+
     //Each iteration generate one trajectory
     for _ in 0..n_trajectories {
         //Current time of the sampling process
@@ -78,15 +76,16 @@ pub fn trajectory_generator<T: network::Network>(
         let mut time: Vec<f64> = Vec::new();
         //Configuration of the process variables at time t initialized with an uniform
         //distribution.
-        let mut current_state: Vec<params::StateType> = net.get_node_indices()
+        let mut current_state: Vec<params::StateType> = net
+            .get_node_indices()
             .map(|x| net.get_node(x).get_random_state_uniform(&mut rng))
             .collect();
-        //History of all the configurations of the process variables. 
+        //History of all the configurations of the process variables.
         let mut events: Vec<Array1<usize>> = Vec::new();
         //Vector containing to time to the next transition for each variable.
         let mut next_transitions: Vec<Option<f64>> =
             net.get_node_indices().map(|_| Option::None).collect();
-        
+
         //Add the starting time for the trajectory.
         time.push(t.clone());
         //Add the starting configuration of the trajectory.
@@ -115,7 +114,7 @@ pub fn trajectory_generator<T: network::Network>(
                 );
             }
         }
-        
+
         //Get the variable with the smallest transition time.
         let next_node_transition = next_transitions
             .iter()
@@ -131,7 +130,7 @@ pub fn trajectory_generator<T: network::Network>(
         t = next_transitions[next_node_transition].unwrap().clone();
         //Add the transition time to next
         time.push(t.clone());
-        
+
         //Compute the new state of the transitioning variable.
         current_state[next_node_transition] = net
             .get_node(next_node_transition)
@@ -142,7 +141,7 @@ pub fn trajectory_generator<T: network::Network>(
                 &mut rng,
             )
             .unwrap();
-        
+
         //Add the new state to events
         events.push(Array::from_vec(
             current_state
@@ -160,7 +159,7 @@ pub fn trajectory_generator<T: network::Network>(
                 next_transitions[child] = None
             }
         }
-        
+
         //Add current_state as last state.
         events.push(
             current_state
@@ -172,7 +171,7 @@ pub fn trajectory_generator<T: network::Network>(
         );
         //Add t_end as last time.
         time.push(t_end.clone());
-        
+
         //Add the sampled trajectory to trajectories.
         trajectories.push(Trajectory::new(
             Array::from_vec(time),
diff --git a/tests/ctbn.rs b/tests/ctbn.rs
index e5cad1e..63c9621 100644
--- a/tests/ctbn.rs
+++ b/tests/ctbn.rs
@@ -1,8 +1,9 @@
 mod utils;
+use std::collections::BTreeSet;
+
 use reCTBN::ctbn::*;
 use reCTBN::network::Network;
 use reCTBN::params::{self, ParamsTrait};
-use std::collections::BTreeSet;
 use utils::generate_discrete_time_continous_node;
 
 #[test]
diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index b624e94..0409402 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -1,13 +1,13 @@
 #![allow(non_snake_case)]
 mod utils;
-use utils::*;
-
 use ndarray::arr3;
 use reCTBN::ctbn::*;
 use reCTBN::network::Network;
 use reCTBN::parameter_learning::*;
-use reCTBN::{params, tools::*};
+use reCTBN::params;
+use reCTBN::tools::*;
+use utils::*;
 
 extern crate approx;
 
@@ -41,7 +41,7 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
     let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
 
     let p = match pl.fit(&net, &data, 1, None) {
-        params::Params::DiscreteStatesContinousTime(p) => p
+        params::Params::DiscreteStatesContinousTime(p) => p,
     };
     assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
     assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
@@ -99,8 +99,8 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
     }
 
     let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
-    let p = match pl.fit(&net, &data, 1, None){
-        params::Params::DiscreteStatesContinousTime(p) => p
+    let p = match pl.fit(&net, &data, 1, None) {
+        params::Params::DiscreteStatesContinousTime(p) => p,
     };
     assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
     assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
@@ -162,8 +162,8 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
     }
 
     let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
-    let p = match pl.fit(&net, &data, 0, None){
-        params::Params::DiscreteStatesContinousTime(p) => p
+    let p = match pl.fit(&net, &data, 0, None) {
+        params::Params::DiscreteStatesContinousTime(p) => p,
     };
     assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
     assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
@@ -291,8 +291,8 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
     }
     let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
 
-    let p = match pl.fit(&net, &data, 2, None){
-        params::Params::DiscreteStatesContinousTime(p) => p
+    let p = match pl.fit(&net, &data, 2, None) {
+        params::Params::DiscreteStatesContinousTime(p) => p,
     };
     assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]);
     assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
diff --git a/tests/params.rs b/tests/params.rs
index e07121c..7f16f12 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -1,5 +1,6 @@
 use ndarray::prelude::*;
-use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
+use rand_chacha::rand_core::SeedableRng;
+use rand_chacha::ChaCha8Rng;
 use reCTBN::params::{ParamsTrait, *};
 
 mod utils;
diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs
index 790a4b6..ee5109e 100644
--- a/tests/structure_learning.rs
+++ b/tests/structure_learning.rs
@@ -1,17 +1,18 @@
 #![allow(non_snake_case)]
 mod utils;
-use utils::*;
+use std::collections::BTreeSet;
 
 use ndarray::{arr1, arr2, arr3};
 use reCTBN::ctbn::*;
 use reCTBN::network::Network;
 use reCTBN::params;
-use reCTBN::structure_learning::score_function::*;
-use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm};
 use reCTBN::structure_learning::hypothesis_test::*;
+use reCTBN::structure_learning::score_based_algorithm::*;
+use reCTBN::structure_learning::score_function::*;
+use reCTBN::structure_learning::StructureLearningAlgorithm;
 use reCTBN::tools::*;
-use std::collections::BTreeSet;
+use utils::*;
 
 #[macro_use]
 extern crate approx;
 
@@ -320,73 +321,43 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
 }
 
 #[test]
-pub fn chi_square_compare_matrices () {
+pub fn chi_square_compare_matrices() {
     let i: usize = 1;
     let M1 = arr3(&[
-        [[ 0, 2, 3],
-         [ 4, 0, 6],
-         [ 7, 8, 0]],
-        [[0, 12, 90],
-         [ 3, 0, 40],
-         [ 6, 40, 0]],
-        [[ 0, 2, 3],
-         [ 4, 0, 6],
-         [ 44, 66, 0]]
+        [[0, 2, 3], [4, 0, 6], [7, 8, 0]],
+        [[0, 12, 90], [3, 0, 40], [6, 40, 0]],
+        [[0, 2, 3], [4, 0, 6], [44, 66, 0]],
     ]);
     let j: usize = 0;
-    let M2 = arr3(&[
-        [[ 0, 200, 300],
-         [ 400, 0, 600],
-         [ 700, 800, 0]]
-    ]);
+    let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]);
     let chi_sq = ChiSquare::new(0.1);
-    assert!(!chi_sq.compare_matrices( i, &M1, j, &M2));
+    assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
 }
 
 #[test]
-pub fn chi_square_compare_matrices_2 () {
+pub fn chi_square_compare_matrices_2() {
     let i: usize = 1;
     let M1 = arr3(&[
-        [[ 0, 2, 3],
-         [ 4, 0, 6],
-         [ 7, 8, 0]],
-        [[0, 20, 30],
-         [ 40, 0, 60],
-         [ 70, 80, 0]],
-        [[ 0, 2, 3],
-         [ 4, 0, 6],
-         [ 44, 66, 0]]
+        [[0, 2, 3], [4, 0, 6], [7, 8, 0]],
+        [[0, 20, 30], [40, 0, 60], [70, 80, 0]],
+        [[0, 2, 3], [4, 0, 6], [44, 66, 0]],
     ]);
     let j: usize = 0;
-    let M2 = arr3(&[
-        [[ 0, 200, 300],
-         [ 400, 0, 600],
-         [ 700, 800, 0]]
-    ]);
+    let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]);
     let chi_sq = ChiSquare::new(0.1);
-    assert!(chi_sq.compare_matrices( i, &M1, j, &M2));
+    assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
 }
 
 #[test]
-pub fn chi_square_compare_matrices_3 () {
+pub fn chi_square_compare_matrices_3() {
     let i: usize = 1;
     let M1 = arr3(&[
-        [[ 0, 2, 3],
-         [ 4, 0, 6],
-         [ 7, 8, 0]],
-        [[0, 21, 31],
-         [ 41, 0, 59],
-         [ 71, 79, 0]],
-        [[ 0, 2, 3],
-         [ 4, 0, 6],
-         [ 44, 66, 0]]
+        [[0, 2, 3], [4, 0, 6], [7, 8, 0]],
+        [[0, 21, 31], [41, 0, 59], [71, 79, 0]],
+        [[0, 2, 3], [4, 0, 6], [44, 66, 0]],
     ]);
     let j: usize = 0;
-    let M2 = arr3(&[
-        [[ 0, 200, 300],
-         [ 400, 0, 600],
-         [ 700, 800, 0]]
-    ]);
+    let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]);
     let chi_sq = ChiSquare::new(0.1);
-    assert!(chi_sq.compare_matrices( i, &M1, j, &M2));
+    assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
 }
diff --git a/tests/utils.rs b/tests/utils.rs
index 1449b1d..ed43215 100644
--- a/tests/utils.rs
+++ b/tests/utils.rs
@@ -1,12 +1,19 @@
-use reCTBN::params;
 use std::collections::BTreeSet;
 
+use reCTBN::params;
+
 #[allow(dead_code)]
 pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params {
-    params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality))
+    params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(
+        label,
+        cardinality,
+    ))
 }
 
-pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
+pub fn generate_discrete_time_continous_params(
+    label: String,
+    cardinality: usize,
+) -> params::DiscreteStatesContinousTimeParams {
     let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
     params::DiscreteStatesContinousTimeParams::new(label, domain)
 }
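
========================================================================
Reviewer notes (not part of the patch)
========================================================================

The newly wrapped fold in CtbnNetwork::get_param_index_network is easier to audit once it is read as a mixed-radix encoding: each parent contributes one digit whose base is that parent's cardinality (get_reserved_space_as_parent), while acc.0 and acc.1 are the running index and stride. A minimal standalone sketch of the same logic follows; mixed_radix_index, parent_states and parent_cardinalities are hypothetical names for illustration, not items from the crate.

// Sketch (hypothetical helper) of the mixed-radix encoding performed by
// `get_param_index_network` and `get_param_index_from_custom_parent_set`.
fn mixed_radix_index(parent_states: &[usize], parent_cardinalities: &[usize]) -> usize {
    parent_states
        .iter()
        .zip(parent_cardinalities.iter())
        .fold((0, 1), |(idx, stride), (&state, &card)| {
            // `idx` and `stride` play the roles of `acc.0` and `acc.1` in the
            // patch: add the current digit, then widen the stride by this
            // parent's cardinality so the next parent occupies higher digits.
            (idx + state * stride, stride * card)
        })
        .0
}

fn main() {
    // Two parents with 3 and 2 states: configuration (2, 1) maps to
    // 2 * 1 + 1 * 3 = 5, the index of the CIM used for that configuration.
    assert_eq!(mixed_radix_index(&[2, 1], &[3, 2]), 5);
}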
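While checking the reflowed BIC::call in score_function.rs: the computation is unchanged by the reformat and, written out (this is a restatement of the code above, with N = sample_size, the total number of transitions summed over trajectories, and k = n_parameters), it is

\[
\mathrm{BIC}(X \mid \mathbf{U}) \;=\; \log \mathcal{L}(X \mid \mathbf{U}) \;-\; \frac{\ln N}{2}\, k,
\qquad
k \;=\; |\mathrm{Val}(\mathbf{U})| \cdot |X| \cdot (|X| - 1),
\]

matching `ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64`, where |Val(U)| is M.shape()[0], |X| is M.shape()[1], and |X| - 1 is M.shape()[2] - 1 (the diagonal of each CIM row is determined by the off-diagonal entries, hence the -1).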
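For anyone verifying that the formatting pass did not change behaviour, here is a minimal end-to-end sketch assembled from the doc-comment example in src/ctbn.rs and the patterns in tests/parameter_learning.rs. The CIM values are illustrative only, and the snippet assumes the crate exactly as it stands in this diff.

use std::collections::BTreeSet;

use ndarray::arr3;
use reCTBN::ctbn::CtbnNetwork;
use reCTBN::network::Network;
use reCTBN::parameter_learning::{ParameterLearning, MLE};
use reCTBN::params;
use reCTBN::tools::trajectory_generator;

fn main() {
    // Two binary variables X1 -> X2, built as in the ctbn.rs doc example.
    let mut domain = BTreeSet::new();
    domain.insert(String::from("A"));
    domain.insert(String::from("B"));
    let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain.clone());
    let x1 = params::Params::DiscreteStatesContinousTime(param);
    let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
    let x2 = params::Params::DiscreteStatesContinousTime(param);

    let mut net = CtbnNetwork::new();
    let x1 = net.add_node(x1).unwrap();
    let x2 = net.add_node(x2).unwrap();
    net.add_edge(x1, x2);

    // Ground-truth CIMs (illustrative values); X2 gets one CIM per state of X1.
    match net.get_node_mut(x1) {
        params::Params::DiscreteStatesContinousTime(p) => {
            p.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
        }
    }
    match net.get_node_mut(x2) {
        params::Params::DiscreteStatesContinousTime(p) => {
            p.set_cim(arr3(&[
                [[-1.0, 1.0], [4.0, -4.0]],
                [[-6.0, 6.0], [2.0, -2.0]],
            ]))
            .unwrap();
        }
    }

    // Sample trajectories, then re-estimate X2's CIMs by maximum likelihood,
    // mirroring the `match pl.fit(...)` pattern used in the tests.
    let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
    let learned = match (MLE {}).fit(&net, &data, x2, None) {
        params::Params::DiscreteStatesContinousTime(p) => p,
    };
    println!("estimated CIMs for X2: {:?}", learned.get_cim());
}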