From dc8013d6352b2498e7112a7ad11c2a111cd30834 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Wed, 23 Mar 2022 16:28:17 +0100
Subject: [PATCH 001/126] Added method validate_params to ParamsTrait.

---
 src/params.rs | 54 +++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 50 insertions(+), 4 deletions(-)

diff --git a/src/params.rs b/src/params.rs
index c5a9acf..3efdc6f 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -1,16 +1,18 @@
+use enum_dispatch::enum_dispatch;
 use ndarray::prelude::*;
 use rand::Rng;
 use std::collections::{BTreeSet, HashMap};
 use thiserror::Error;
-use enum_dispatch::enum_dispatch;
 
 /// Error types for trait Params
-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq)]
 pub enum ParamsError {
     #[error("Unsupported method")]
     UnsupportedMethod(String),
     #[error("Parameters not initialized")]
     ParametersNotInitialized(String),
+    #[error("Invalid cim for parameter")]
+    InvalidCIM(String),
 }
 
 /// Allowed type of states
@@ -43,6 +45,9 @@ pub trait ParamsTrait {
 
     /// Index used by discrete nodes to represent their states as usize.
     fn state_to_index(&self, state: &StateType) -> usize;
+
+    /// Validate the parameters against the domain
+    fn validate_params(&self) -> Result<(), ParamsError>;
 }
 
 /// The Params enum is the core element for building different types of nodes. The goal is to
@@ -52,7 +57,6 @@ pub enum Params {
     DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
 }
 
-
 /// DiscreteStatesContinousTime.
 /// This represents the parameters of a classical discrete node for CTBNs and is composed of the
 /// following elements:
@@ -157,5 +161,47 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
             StateType::Discrete(val) => val.clone() as usize,
         }
     }
-}
 
+    fn validate_params(&self) -> Result<(), ParamsError> {
+        let domain_size = self.domain.len();
+
+        // Check if the CIM is initialized
+        if let None = self.cim {
+            return Err(ParamsError::ParametersNotInitialized(String::from(
+                "CIM not initialized",
+            )));
+        }
+        let cim = self.cim.as_ref().unwrap();
+        // Check if the inner dimensions of the CIM are equal to the cardinality of the variable
+        if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size {
+            return Err(ParamsError::InvalidCIM(format!(
+                "Incompatible shape {:?} with domain {:?}",
+                cim.shape(),
+                domain_size
+            )));
+        }
+
+        // Check if the diagonal of each CIM is non-positive
+        if cim
+            .axis_iter(Axis(0))
+            .any(|x| x.diag().iter().any(|x| x >= &0.0))
+        {
+            return Err(ParamsError::InvalidCIM(String::from(
+                "The diagonal of each cim must be non-positive",
+            )));
+        }
+
+        // Check if each row sums up to 0
+        let zeros = Array::zeros(domain_size);
+        if cim
+            .axis_iter(Axis(0))
+            .any(|x| !x.sum_axis(Axis(1)).abs_diff_eq(&zeros, f64::MIN_POSITIVE))
+        {
+            return Err(ParamsError::InvalidCIM(String::from(
+                "The sum of each row must be 0",
+            )));
+        }
+
+        return Ok(());
+    }
+}
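The new check can be exercised directly. A minimal sketch, assuming the rustyCTBN crate layout from this patch (the `cim` field is still public at this stage; patch 003 later routes all writes through a validating setter):

    use ndarray::prelude::*;
    use rustyCTBN::params::*;
    use std::collections::BTreeSet;

    fn main() {
        // Two-state node: domain {"A", "B"}.
        let domain: BTreeSet<String> = ["A", "B"].iter().map(|s| s.to_string()).collect();
        let mut params = DiscreteStatesContinousTimeParams::init(domain);

        // A well-formed CIM: negative diagonal, rows summing to zero.
        params.cim = Some(array![[[-2.0, 2.0], [1.0, -1.0]]]);
        assert_eq!(Ok(()), params.validate_params());

        // A malformed CIM: the first row sums to 1.0, so validation fails.
        params.cim = Some(array![[[-2.0, 3.0], [1.0, -1.0]]]);
        assert!(matches!(params.validate_params(), Err(ParamsError::InvalidCIM(_))));
    }
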
From 331c2006e9633c1b1e6630eded47ae9515f0c01e Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Wed, 23 Mar 2022 16:29:02 +0100
Subject: [PATCH 002/126] Tested validate_params implementation for
 DiscreteStatesContinousTimeParams

---
 tests/params.rs | 65 ++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 62 insertions(+), 3 deletions(-)

diff --git a/tests/params.rs b/tests/params.rs
index ed601b2..6901293 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -1,5 +1,5 @@
-use rustyCTBN::params::*;
 use ndarray::prelude::*;
+use rustyCTBN::params::*;
 use std::collections::BTreeSet;
 
 mod utils;
@@ -7,11 +7,10 @@ mod utils;
 #[macro_use]
 extern crate approx;
 
-
 fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
     let mut params = utils::generate_discrete_time_continous_param(3);
 
-    let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]];
+    let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
 
     params.cim = Some(cim);
     params
@@ -62,3 +61,63 @@ fn test_random_generation_residence_time() {
 
     assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
 }
+
+#[test]
+fn test_validate_params_valid_cim() {
+    let param = create_ternary_discrete_time_continous_param();
+
+    assert_eq!(Ok(()), param.validate_params());
+}
+
+#[test]
+fn test_validate_params_cim_not_initialized() {
+    let param = utils::generate_discrete_time_continous_param(3);
+    assert_eq!(
+        Err(ParamsError::ParametersNotInitialized(String::from(
+            "CIM not initialized",
+        ))),
+        param.validate_params()
+    );
+}
+
+
+#[test]
+fn test_validate_params_wrong_shape() {
+    let mut param = utils::generate_discrete_time_continous_param(4);
+    let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
+    param.cim = Some(cim);
+    assert_eq!(
+        Err(ParamsError::InvalidCIM(String::from(
+            "Incompatible shape [1, 3, 3] with domain 4"
+        ))),
+        param.validate_params()
+    );
+}
+
+
+#[test]
+fn test_validate_params_positive_diag() {
+    let mut param = utils::generate_discrete_time_continous_param(3);
+    let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
+    param.cim = Some(cim);
+    assert_eq!(
+        Err(ParamsError::InvalidCIM(String::from(
+            "The diagonal of each cim must be non-positive",
+        ))),
+        param.validate_params()
+    );
+}
+
+
+#[test]
+fn test_validate_params_row_not_sum_to_zero() {
+    let mut param = utils::generate_discrete_time_continous_param(3);
+    let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
+    param.cim = Some(cim);
+    assert_eq!(
+        Err(ParamsError::InvalidCIM(String::from(
+            "The sum of each row must be 0"
+        ))),
+        param.validate_params()
+    );
+}
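Why the fixture's 3.2 becomes 2.3: the old third row 3.2 + 1.7 - 4.0 sums to 0.9, so it violates the row-sum check introduced in patch 001, while the replacement cancels to exactly 0.0 in f64, which the f64::MIN_POSITIVE tolerance effectively requires. A quick arithmetic check of that assumption:

    fn main() {
        // Old fixture row: clearly not a valid CIM row (its sum is 0.9).
        assert!((3.2_f64 + 1.7 - 4.0 - 0.9).abs() < 1e-12);
        // New fixture row: 2.3 + 1.7 rounds to exactly 4.0 in f64, so the
        // row sum is exactly 0.0, as test_validate_params_valid_cim needs.
        assert_eq!(2.3_f64 + 1.7 - 4.0, 0.0);
    }
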
From c178862664e0aa216f28da010573b31ca617893d Mon Sep 17 00:00:00 2001
From: Alessandro Bregoli
Date: Sat, 26 Mar 2022 10:33:30 +0100
Subject: [PATCH 003/126] Enforced correct set of parameters (cim) when
 inserted manually

---
 src/params.rs               | 50 ++++++++++++++++++++++++++++++-------
 tests/parameter_learning.rs | 38 ++++++++++++++--------------
 tests/params.rs             | 14 +++++------
 tests/tools.rs              |  4 +--
 tests/utils.rs              |  2 +-
 5 files changed, 70 insertions(+), 38 deletions(-)

diff --git a/src/params.rs b/src/params.rs
index 3efdc6f..154fd12 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -69,21 +69,55 @@ pub enum Params {
 /// - **residence_time**: permanence time in each possible states given a specific
 /// realization of the parent set
 pub struct DiscreteStatesContinousTimeParams {
-    pub domain: BTreeSet<String>,
-    pub cim: Option<Array3<f64>>,
-    pub transitions: Option<Array3<usize>>,
-    pub residence_time: Option<Array2<f64>>,
+    domain: BTreeSet<String>,
+    cim: Option<Array3<f64>>,
+    transitions: Option<Array3<usize>>,
+    residence_time: Option<Array2<f64>>,
 }
 
 impl DiscreteStatesContinousTimeParams {
     pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
         DiscreteStatesContinousTimeParams {
-            domain: domain,
+            domain,
             cim: Option::None,
             transitions: Option::None,
             residence_time: Option::None,
         }
     }
+
+    pub fn get_cim(&self) -> &Option<Array3<f64>> {
+        &self.cim
+    }
+
+    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
+        self.cim = Some(cim);
+        match self.validate_params() {
+            Ok(()) => Ok(()),
+            Err(e) => {
+                self.cim = None;
+                Err(e)
+            }
+        }
+    }
+
+    pub fn get_transitions(&self) -> &Option<Array3<usize>> {
+        &self.transitions
+    }
+
+
+    pub fn set_transitions(&mut self, transitions: Array3<usize>) {
+        self.transitions = Some(transitions);
+    }
+
+    pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
+        &self.residence_time
+    }
+
+
+    pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
+        self.residence_time = Some(residence_time);
+    }
+
 }
 
 impl ParamsTrait for DiscreteStatesContinousTimeParams {
@@ -192,10 +226,8 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         }
 
         // Check if each row sums up to 0
-        let zeros = Array::zeros(domain_size);
-        if cim
-            .axis_iter(Axis(0))
-            .any(|x| !x.sum_axis(Axis(1)).abs_diff_eq(&zeros, f64::MIN_POSITIVE))
+        if cim.sum_axis(Axis(2)).iter()
+            .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
         {
             return Err(ParamsError::InvalidCIM(String::from(
                 "The sum of each row must be 0",
diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index a5cca51..d6b8fd2 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -27,16 +27,16 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
 
     match &mut net.get_node_mut(n1).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
+            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
         }
     }
 
     match &mut net.get_node_mut(n2).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[
+            assert_eq!(Ok(()), param.set_cim(arr3(&[
                 [[-1.0, 1.0], [4.0, -4.0]],
                 [[-6.0, 6.0], [2.0, -2.0]],
-            ]));
+            ])));
         }
     }
 
@@ -77,19 +77,19 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
 
     match &mut net.get_node_mut(n1).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
+            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
                                     [1.5, -2.0, 0.5],
-                                    [0.4, 0.6, -1.0]]]));
+                                    [0.4, 0.6, -1.0]]])));
         }
     }
 
     match &mut net.get_node_mut(n2).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[
+            assert_eq!(Ok(()), param.set_cim(arr3(&[
                 [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
                 [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
                 [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
-            ]));
+            ])));
         }
     }
 
@@ -132,19 +132,19 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
 
     match &mut net.get_node_mut(n1).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
+            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
                                     [1.5, -2.0, 0.5],
-                                    [0.4, 0.6, -1.0]]]));
+                                    [0.4, 0.6, -1.0]]])));
         }
     }
 
     match &mut net.get_node_mut(n2).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[
+            assert_eq!(Ok(()), param.set_cim(arr3(&[
                 [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
                 [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
                 [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
-            ]));
+            ])));
         }
     }
 
@@ -192,28 +192,28 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
 
     match &mut net.get_node_mut(n1).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
+            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
                                     [1.5, -2.0, 0.5],
-                                    [0.4, 0.6, -1.0]]]));
+                                    [0.4, 0.6, -1.0]]])));
         }
     }
 
     match &mut net.get_node_mut(n2).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[
+            assert_eq!(Ok(()), param.set_cim(arr3(&[
                 [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
                 [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
                 [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
-            ]));
+            ])));
         }
     }
 
     match &mut net.get_node_mut(n3).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some(arr3(&[
+            assert_eq!(Ok(()), param.set_cim(arr3(&[
                 [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]],
-                [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]],
+                [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]],
                 [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]],
 
                 [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]],
                 [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]],
                 [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]],
 
                 [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
                 [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
                 [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
-            ]));
+            ])));
         }
     }
 
-    let data = trajectory_generator(&net, 300, 200.0);
+    let data = trajectory_generator(&net, 300, 300.0);
     let (CIM, M, T) = pl.fit(&net, &data, 2, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [9, 4, 4]);
diff --git a/tests/params.rs b/tests/params.rs
index 6901293..9bc9b49 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -12,7 +12,7 @@ fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTime
 
     let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
 
-    params.cim = Some(cim);
+    params.set_cim(cim);
     params
 }
@@ -85,12 +85,12 @@ fn test_validate_params_cim_not_initialized() {
 fn test_validate_params_wrong_shape() {
     let mut param = utils::generate_discrete_time_continous_param(4);
     let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
-    param.cim = Some(cim);
+    let result = param.set_cim(cim);
     assert_eq!(
         Err(ParamsError::InvalidCIM(String::from(
             "Incompatible shape [1, 3, 3] with domain 4"
         ))),
-        param.validate_params()
+        result
     );
 }
 
 
 #[test]
 fn test_validate_params_positive_diag() {
     let mut param = utils::generate_discrete_time_continous_param(3);
     let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
-    param.cim = Some(cim);
+    let result = param.set_cim(cim);
     assert_eq!(
         Err(ParamsError::InvalidCIM(String::from(
             "The diagonal of each cim must be non-positive",
         ))),
-        param.validate_params()
+        result
     );
 }
 
 
 #[test]
 fn test_validate_params_row_not_sum_to_zero() {
     let mut param = utils::generate_discrete_time_continous_param(3);
     let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
-    param.cim = Some(cim);
+    let result = param.set_cim(cim);
     assert_eq!(
         Err(ParamsError::InvalidCIM(String::from(
             "The sum of each row must be 0"
         ))),
-        param.validate_params()
+        result
     );
 }
diff --git a/tests/tools.rs b/tests/tools.rs
index 802e2fe..257c957 100644
--- a/tests/tools.rs
+++ b/tests/tools.rs
@@ -23,14 +23,14 @@ fn run_sampling() {
 
     match &mut net.get_node_mut(n1).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]]));
+            param.set_cim(arr3(&[[[-3.0,3.0],[2.0,-2.0]]]));
         }
     }
 
     match &mut net.get_node_mut(n2).params {
         params::Params::DiscreteStatesContinousTime(param) => {
-            param.cim = Some (arr3(&[
+            param.set_cim(arr3(&[
                 [[-1.0,1.0],[4.0,-4.0]],
                 [[-6.0,6.0],[2.0,-2.0]]]));
         }
diff --git a/tests/utils.rs b/tests/utils.rs
index a973926..290fa2e 100644
--- a/tests/utils.rs
+++ b/tests/utils.rs
@@ -8,7 +8,7 @@ pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -
 
 pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
-    let mut domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
+    let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
     params::DiscreteStatesContinousTimeParams::init(domain)
 }

From 490fe4e010ea6d1cc7841778bc30a65cbd00ac4b Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Tue, 29 Mar 2022 15:53:50 +0200
Subject: [PATCH 004/126] Comments

---
 src/params.rs | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/params.rs b/src/params.rs
index 154fd12..5c4347e 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -84,11 +84,16 @@ impl DiscreteStatesContinousTimeParams {
             residence_time: Option::None,
         }
     }
-    
+    
+    ///Getter function for CIM
     pub fn get_cim(&self) -> &Option<Array3<f64>> {
         &self.cim
    }
 
+    ///Setter function for CIM
+    ///This function checks if the CIM is valid using the validate_params method:
+    ///- **Valid CIM inserted**: it substitutes the CIM in self.cim and returns Ok(())
+    ///- **Invalid CIM inserted**: it replaces the self.cim value with None and returns a ParamsError
     pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
         self.cim = Some(cim);
         match self.validate_params() {
             Ok(()) => Ok(()),
             Err(e) => {
                 self.cim = None;
                 Err(e)
             }
         }
     }
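A short usage sketch of the accessor pattern established by the two patches above: after patch 003 the CIM can only be replaced through set_cim, which validates before storing (hypothetical two-state node; types as in this series):

    use ndarray::prelude::*;
    use rustyCTBN::params::*;
    use std::collections::BTreeSet;

    fn main() {
        let domain: BTreeSet<String> = ["on", "off"].iter().map(|s| s.to_string()).collect();
        let mut params = DiscreteStatesContinousTimeParams::init(domain);

        // An invalid CIM is rejected and the stored value is reset to None...
        assert!(params.set_cim(array![[[-2.0, 1.0], [1.0, -1.0]]]).is_err());
        assert_eq!(&None, params.get_cim());

        // ...while a valid one is stored and Ok(()) is returned.
        assert_eq!(Ok(()), params.set_cim(array![[[-2.0, 2.0], [1.0, -1.0]]]));
        assert!(params.get_cim().is_some());
    }
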
From 16714b48b5a759ba44c14a2522270327624e0fbe Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Tue, 29 Mar 2022 15:56:34 +0200
Subject: [PATCH 005/126] Comments

---
 src/params.rs | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/params.rs b/src/params.rs
index 5c4347e..019e281 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -90,7 +90,7 @@ impl DiscreteStatesContinousTimeParams {
         &self.cim
     }
 
-    ///Setter function for CIM
+    ///Setter function for CIM.\\
     ///This function checks if the CIM is valid using the validate_params method:
     ///- **Valid CIM inserted**: it substitutes the CIM in self.cim and returns Ok(())
     ///- **Invalid CIM inserted**: it replaces the self.cim value with None and returns a ParamsError
@@ -105,20 +105,25 @@ impl DiscreteStatesContinousTimeParams {
         }
     }
 
+
+    ///Getter function for transitions
     pub fn get_transitions(&self) -> &Option<Array3<usize>> {
         &self.transitions
     }
 
 
+    ///Setter function for transitions
     pub fn set_transitions(&mut self, transitions: Array3<usize>) {
         self.transitions = Some(transitions);
     }
 
+    ///Getter function for residence_time
     pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
         &self.residence_time
     }
 
 
+    ///Setter function for residence_time
     pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
         self.residence_time = Some(residence_time);
     }

From 82e3c779a04a3a6cfd8472bc29fc26b558ad75d6 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Tue, 29 Mar 2022 16:03:27 +0200
Subject: [PATCH 006/126] Added test: test_validate_params_valid_cim_with_huge_values

---
 tests/params.rs | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/tests/params.rs b/tests/params.rs
index 9bc9b49..cbc7636 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -69,6 +69,14 @@ fn test_validate_params_valid_cim() {
     assert_eq!(Ok(()), param.validate_params());
 }
 
+#[test]
+fn test_validate_params_valid_cim_with_huge_values() {
+    let mut param = utils::generate_discrete_time_continous_param(3);
+    let cim = array![[[-2e10, 1e10, 1e10], [1.5e10, -3e10, 1.5e10], [1e10, 1e10, -2e10]]];
+    let result = param.set_cim(cim);
+    assert_eq!(Ok(()), result);
+}
+
 #[test]
 fn test_validate_params_cim_not_initialized() {
     let param = utils::generate_discrete_time_continous_param(3);
@@ -80,7 +88,6 @@ fn test_validate_params_cim_not_initialized() {
     );
 }
 
-
 #[test]
 fn test_validate_params_wrong_shape() {
     let mut param = utils::generate_discrete_time_continous_param(4);
@@ -88,13 +95,12 @@ fn test_validate_params_wrong_shape() {
     let result = param.set_cim(cim);
     assert_eq!(
         Err(ParamsError::InvalidCIM(String::from(
-            "Incompatible shape [1, 3, 3] with domain 4"
-        ))),
+            "Incompatible shape [1, 3, 3] with domain 4"
+        ))),
         result
     );
 }
 
-
 #[test]
 fn test_validate_params_positive_diag() {
     let mut param = utils::generate_discrete_time_continous_param(3);
     let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
     let result = param.set_cim(cim);
     assert_eq!(
         Err(ParamsError::InvalidCIM(String::from(
-            "The diagonal of each cim must be non-positive",
-        ))),
+            "The diagonal of each cim must be non-positive",
+        ))),
         result
     );
 }
 
-
 #[test]
 fn test_validate_params_row_not_sum_to_zero() {
     let mut param = utils::generate_discrete_time_continous_param(3);
     let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
     let result = param.set_cim(cim);
     assert_eq!(
         Err(ParamsError::InvalidCIM(String::from(
             "The sum of each row must be 0"
-        ))),
+        ))),
         result
     );
 }
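The huge-values test passes because the row-sum tolerance from patch 003 is absolute (f64::EPSILON * 3.0) and the 1e10-scale entries used in the fixture cancel exactly in f64. A sketch of that assumption:

    fn main() {
        // Integers up to 2^53 (including 1e10 and 1.5e10) are exactly
        // representable in f64, so these row sums are identically 0.0.
        assert_eq!(-2e10 + 1e10 + 1e10, 0.0);
        assert_eq!(1.5e10 - 3e10 + 1.5e10, 0.0);
    }
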
From 6104dcc329481a591c62d9a91735fc65ea593bd1 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Tue, 5 Apr 2022 16:38:43 +0200
Subject: [PATCH 007/126] In the Bayesian approach alpha and tau are now
 divided by the number of possible configurations in its parent set

---
 src/parameter_learning.rs   | 12 +++++++-----
 tests/parameter_learning.rs | 16 ++++++++--------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs
index 67ea07f..4fe3bdd 100644
--- a/src/parameter_learning.rs
+++ b/src/parameter_learning.rs
@@ -114,8 +114,8 @@ impl ParameterLearning for MLE {
 }
 
 pub struct BayesianApproach {
-    pub default_alpha: usize,
-    pub default_tau: f64
+    pub alpha: usize,
+    pub tau: f64
 }
 
 impl ParameterLearning for BayesianApproach {
@@ -135,13 +135,15 @@ impl ParameterLearning for BayesianApproach {
         };
 
         let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
-        M.mapv_inplace(|x|{x + self.default_alpha});
-        T.mapv_inplace(|x|{x + self.default_tau});
+
+        let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64;
+        let tau: f64 = self.tau as f64 / M.shape()[0] as f64;
+
         //Compute the CIM as M[i,x,y]/T[i,x]
         let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
         CIM.axis_iter_mut(Axis(2))
             .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-            .for_each(|(mut C, m)| C.assign(&(&m/&T)));
+            .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha)/&T.mapv(|y| y + tau))));
 
         //Set the diagonal of the inner matrices to the row sum multiplied by -1
         let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index d6b8fd2..345b8d1 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -60,8 +60,8 @@ fn learn_binary_cim_MLE() {
 #[test]
 fn learn_binary_cim_BA() {
     let ba = BayesianApproach{
-        default_alpha: 1,
-        default_tau: 1.0};
+        alpha: 1,
+        tau: 1.0};
     learn_binary_cim(ba);
 }
 
@@ -115,8 +115,8 @@ fn learn_ternary_cim_MLE() {
 #[test]
 fn learn_ternary_cim_BA() {
     let ba = BayesianApproach{
-        default_alpha: 1,
-        default_tau: 1.0};
+        alpha: 1,
+        tau: 1.0};
     learn_ternary_cim(ba);
 }
 
@@ -168,8 +168,8 @@ fn learn_ternary_cim_no_parents_MLE() {
 #[test]
 fn learn_ternary_cim_no_parents_BA() {
     let ba = BayesianApproach{
-        default_alpha: 1,
-        default_tau: 1.0};
+        alpha: 1,
+        tau: 1.0};
     learn_ternary_cim_no_parents(ba);
 }
 
@@ -257,7 +257,7 @@ fn learn_mixed_discrete_cim_MLE() {
 #[test]
 fn learn_mixed_discrete_cim_BA() {
     let ba = BayesianApproach{
-        default_alpha: 1,
-        default_tau: 1.0};
+        alpha: 1,
+        tau: 1.0};
     learn_mixed_discrete_cim(ba);
 }
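The resulting estimator, distilled: each parent configuration now receives alpha/|U| imaginary transitions and tau/|U| imaginary time, where |U| = M.shape()[0] is the number of parent configurations, so the total prior mass no longer grows with the parent-set size. A standalone sketch of one off-diagonal CIM cell (illustrative names, not part of the patch):

    /// Posterior rate estimate for one (configuration, from-state, to-state) cell.
    fn posterior_rate(m: usize, t: f64, alpha: usize, tau: f64, n_configs: usize) -> f64 {
        let alpha = alpha as f64 / n_configs as f64;
        let tau = tau / n_configs as f64;
        (m as f64 + alpha) / (t + tau)
    }

    fn main() {
        // 3 observed transitions over 2.0 time units, prior alpha = 1,
        // tau = 1.0, spread over 4 parent configurations.
        let q = posterior_rate(3, 2.0, 1, 1.0, 4);
        assert!((q - (3.0 + 0.25) / (2.0 + 0.25)).abs() < 1e-12);
    }
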
From 86d2a0b7672a4c0aadbfbdae7c371bc611d5667b Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Wed, 6 Apr 2022 07:38:14 +0200
Subject: [PATCH 008/126] Implemented LL for ctbn

---
 Cargo.toml                  |  1 +
 src/lib.rs                  |  1 +
 src/structure_learning.rs   | 82 +++++++++++++++++++++++++++++++++++++
 tests/structure_learning.rs | 37 +++++++++++++++++
 4 files changed, 121 insertions(+)
 create mode 100644 src/structure_learning.rs
 create mode 100644 tests/structure_learning.rs

diff --git a/Cargo.toml b/Cargo.toml
index 3aa7c53..37f87e2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,6 +12,7 @@ thiserror = "*"
 rand = "*"
 bimap = "*"
 enum_dispatch = "*"
+statrs = "*"
 
 [dev-dependencies]
 approx = "*"
diff --git a/src/lib.rs b/src/lib.rs
index 65e4b11..ec12261 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -8,4 +8,5 @@ pub mod network;
 pub mod ctbn;
 pub mod tools;
 pub mod parameter_learning;
+pub mod structure_learning;
 
diff --git a/src/structure_learning.rs b/src/structure_learning.rs
new file mode 100644
index 0000000..f4c369b
--- /dev/null
+++ b/src/structure_learning.rs
@@ -0,0 +1,82 @@
+use crate::network;
+use crate::parameter_learning;
+use crate::params;
+use crate::tools;
+use ndarray::prelude::*;
+use statrs::function::gamma;
+use std::collections::BTreeSet;
+
+pub trait StructureLearning {
+    fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
+    where
+        T: network::Network;
+}
+
+pub trait ScoreFunction {
+    fn compute_score<T>(
+        &self,
+        net: &T,
+        node: usize,
+        parent_set: &BTreeSet<usize>,
+        dataset: &tools::Dataset,
+    ) -> f64
+    where
+        T: network::Network;
+}
+
+pub struct LogLikelihood {
+    alpha: usize,
+    tau: f64,
+}
+
+impl LogLikelihood {
+    pub fn init(alpha: usize, tau: f64) -> LogLikelihood {
+        if tau < 0.0 {
+            panic!("tau must be >=0.0");
+        }
+        LogLikelihood { alpha, tau }
+    }
+}
+
+impl ScoreFunction for LogLikelihood {
+    fn compute_score<T>(
+        &self,
+        net: &T,
+        node: usize,
+        parent_set: &BTreeSet<usize>,
+        dataset: &tools::Dataset,
+    ) -> f64
+    where
+        T: network::Network,
+    {
+        match &net.get_node(node).params {
+            params::Params::DiscreteStatesContinousTime(params) => {
+                let (M, T) =
+                    parameter_learning::sufficient_statistics(net, dataset, node, parent_set);
+                let alpha = self.alpha as f64 / M.shape()[0] as f64;
+                let tau = self.tau / M.shape()[0] as f64;
+
+                let log_ll_q:f64 = M
+                    .sum_axis(Axis(2))
+                    .iter()
+                    .zip(T.iter())
+                    .map(|(m, t)| {
+                        gamma::ln_gamma(alpha + *m as f64 + 1.0)
+                            + (alpha + 1.0) * f64::ln(tau)
+                            - gamma::ln_gamma(alpha + 1.0)
+                            - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
+                    })
+                    .sum();
+
+                let log_ll_theta: f64 = M.outer_iter()
+                    .map(|x| x.outer_iter()
+                        .map(|y| gamma::ln_gamma(alpha)
+                            - gamma::ln_gamma(alpha + y.sum() as f64)
+                            + y.iter().map(|z|
+                                gamma::ln_gamma(alpha + *z as f64)
+                                - gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum();
+                log_ll_theta + log_ll_q
+            }
+        }
+    }
+}
diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs
new file mode 100644
index 0000000..95b95fa
--- /dev/null
+++ b/tests/structure_learning.rs
@@ -0,0 +1,37 @@
+
+mod utils;
+use utils::*;
+
+use rustyCTBN::ctbn::*;
+use rustyCTBN::network::Network;
+use rustyCTBN::tools::*;
+use rustyCTBN::structure_learning::*;
+use ndarray::{arr1, arr2};
+use std::collections::BTreeSet;
+
+
+#[macro_use]
+extern crate approx;
+
+#[test]
+fn simple_log_likelihood() {
+    let mut net = CtbnNetwork::init();
+    let n1 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n1"),2))
+        .unwrap();
+
+    let trj = Trajectory{
+        time: arr1(&[0.0,0.1,0.3]),
+        events: arr2(&[[0],[1],[1]])};
+
+    let dataset = Dataset{
+        trajectories: vec![trj]};
+
+    let ll = LogLikelihood::init(1, 1.0);
+
+    assert_abs_diff_eq!(0.04257, ll.compute_score(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
+
+
+
+
+}
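For reference, compute_score evaluates the closed-form Bayesian marginal log-likelihood of a node given a candidate parent set. A sketch of the correspondence with the code, where m_{u,x,y} and t_{u,x} are the entries of the sufficient statistics M and T, m_{u,x} = \sum_y m_{u,x,y}, and \alpha, \tau are already divided by the number of parent configurations:

    \ln LL_q = \sum_{u,x} \Big[ \ln\Gamma(\alpha + m_{u,x} + 1) + (\alpha + 1)\ln\tau
               - \ln\Gamma(\alpha + 1) - (\alpha + m_{u,x} + 1)\ln(\tau + t_{u,x}) \Big]

    \ln LL_\theta = \sum_{u,x} \Big[ \ln\Gamma(\alpha) - \ln\Gamma(\alpha + m_{u,x})
               + \sum_{y} \big( \ln\Gamma(\alpha + m_{u,x,y}) - \ln\Gamma(\alpha) \big) \Big]

The returned score is \ln LL_\theta + \ln LL_q.
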
From bb42365fb81cfad449609b76575b10122d76568a Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Wed, 6 Apr 2022 11:18:29 +0200
Subject: [PATCH 009/126] Added meta and refactor issue templates

---
 .github/ISSUE_TEMPLATE/meta_request.md     | 26 ++++++++++++++++++++++
 .github/ISSUE_TEMPLATE/refactor_request.md | 26 ++++++++++++++++++++++
 2 files changed, 52 insertions(+)
 create mode 100644 .github/ISSUE_TEMPLATE/meta_request.md
 create mode 100644 .github/ISSUE_TEMPLATE/refactor_request.md

diff --git a/.github/ISSUE_TEMPLATE/meta_request.md b/.github/ISSUE_TEMPLATE/meta_request.md
new file mode 100644
index 0000000..d80ccde
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/meta_request.md
@@ -0,0 +1,26 @@
+---
+name: 📑 Meta request
+about: Suggest an idea or a change for this same repository
+title: '[Meta] '
+labels: 'meta'
+assignees: ''
+
+---
+
+## Description
+
+As a X, I want to Y, so Z.
+
+## Acceptance Criteria
+
+* Criteria 1
+* Criteria 2
+
+## Checklist
+
+* [ ] Element 1
+* [ ] Element 2
+
+## (Optional) Extra info
+
+None
diff --git a/.github/ISSUE_TEMPLATE/refactor_request.md b/.github/ISSUE_TEMPLATE/refactor_request.md
new file mode 100644
index 0000000..503e3f3
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/refactor_request.md
@@ -0,0 +1,26 @@
+---
+name: ⚙️ Refactor request
+about: Suggest a refactor for this project
+title: '[Refactor] '
+labels: 'enhancement'
+assignees: ''
+
+---
+
+## Description
+
+As a X, I want to Y, so Z.
+
+## Acceptance Criteria
+
+* Criteria 1
+* Criteria 2
+
+## Checklist
+
+* [ ] Element 1
+* [ ] Element 2
+
+## (Optional) Extra info
+
+None
From 651148fffdd6c6d1c4d9c2095e912269e881f89d Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Mon, 11 Apr 2022 13:51:15 +0200
Subject: [PATCH 010/126] Replaced the current RNG with a seedable one
 (`rand_chacha`)

---
 Cargo.toml                  | 2 ++
 src/params.rs               | 7 ++++---
 src/tools.rs                | 7 ++++++-
 tests/parameter_learning.rs | 8 ++++----
 tests/params.rs             | 6 +++++-
 tests/tools.rs              | 2 +-
 6 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 3aa7c53..4cb6c06 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,6 +12,8 @@ thiserror = "*"
 rand = "*"
 bimap = "*"
 enum_dispatch = "*"
+rand_core = "*"
+rand_chacha = "*"
 
 [dev-dependencies]
 approx = "*"
diff --git a/src/params.rs b/src/params.rs
index 019e281..b418df6 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -1,8 +1,10 @@
 use enum_dispatch::enum_dispatch;
 use ndarray::prelude::*;
 use rand::Rng;
+use rand::rngs::ThreadRng;
 use std::collections::{BTreeSet, HashMap};
 use thiserror::Error;
+use rand_chacha::ChaCha8Rng;
 
 /// Error types for trait Params
 #[derive(Error, Debug, PartialEq)]
@@ -30,7 +32,7 @@ pub trait ParamsTrait {
 
     /// Randomly generate a possible state of the node disregarding the state of the node and its
     /// parents.
-    fn get_random_state_uniform(&self) -> StateType;
+    fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType;
 
     /// Randomly generate a residence time for the given node taking into account the node state
     /// and its parent set.
@@ -137,8 +139,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         self.residence_time = Option::None;
     }
 
-    fn get_random_state_uniform(&self) -> StateType {
-        let mut rng = rand::thread_rng();
+    fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType {
         StateType::Discrete(rng.gen_range(0..(self.domain.len())))
     }
 
diff --git a/src/tools.rs b/src/tools.rs
index 27438f9..4efe085 100644
--- a/src/tools.rs
+++ b/src/tools.rs
@@ -3,6 +3,8 @@ use crate::node;
 use crate::params;
 use crate::params::ParamsTrait;
 use ndarray::prelude::*;
+use rand_chacha::ChaCha8Rng;
+use rand_core::SeedableRng;
 
 pub struct Trajectory {
     pub time: Array1<f64>,
@@ -17,11 +19,14 @@ pub fn trajectory_generator<T: network::Network>(
     net: &T,
     n_trajectories: u64,
     t_end: f64,
+    seed: u64,
 ) -> Dataset {
     let mut dataset = Dataset {
         trajectories: Vec::new(),
     };
 
+    let mut rng = ChaCha8Rng::seed_from_u64(seed);
+
     let node_idx: Vec<_> = net.get_node_indices().collect();
     for _ in 0..n_trajectories {
         let mut t = 0.0;
         let mut events: Vec<Array1<usize>> = Vec::new();
         let mut current_state: Vec<StateType> = node_idx
             .iter()
-            .map(|x| net.get_node(*x).params.get_random_state_uniform())
+            .map(|x| net.get_node(*x).params.get_random_state_uniform(&mut rng))
             .collect();
         let mut next_transitions: Vec<Option<f64>> =
             (0..node_idx.len()).map(|_| Option::None).collect();
diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index 345b8d1..96b6ce1 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -40,7 +40,7 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
 
-    let data = trajectory_generator(&net, 100, 100.0);
+    let data = trajectory_generator(&net, 100, 100.0, 1234,);
     let (CIM, M, T) = pl.fit(&net, &data, 1, None);
@@ -93,7 +93,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
 
-    let data = trajectory_generator(&net, 100, 200.0);
+    let data = trajectory_generator(&net, 100, 200.0, 1234,);
     let (CIM, M, T) = pl.fit(&net, &data, 1, None);
@@ -148,7 +148,7 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
 
-    let data = trajectory_generator(&net, 100, 200.0);
+    let data = trajectory_generator(&net, 100, 200.0, 1234,);
     let (CIM, M, T) = pl.fit(&net, &data, 0, None);
@@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
 
-    let data = trajectory_generator(&net, 300, 300.0);
+    let data = trajectory_generator(&net, 300, 300.0, 1234,);
     let (CIM, M, T) = pl.fit(&net, &data, 2, None);
diff --git a/tests/params.rs b/tests/params.rs
index cbc7636..23c99fa 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -1,6 +1,8 @@
 use ndarray::prelude::*;
 use rustyCTBN::params::*;
 use std::collections::BTreeSet;
+use rand_chacha::ChaCha8Rng;
+use rand_core::SeedableRng;
 
 mod utils;
@@ -21,8 +23,10 @@ fn test_uniform_generation() {
     let param = create_ternary_discrete_time_continous_param();
     let mut states = Array1::<usize>::zeros(10000);
 
+    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+
     states.mapv_inplace(|_| {
-        if let StateType::Discrete(val) = param.get_random_state_uniform() {
+        if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) {
             val
         } else {
             panic!()
diff --git a/tests/tools.rs b/tests/tools.rs
index 257c957..f831ec4 100644
--- a/tests/tools.rs
+++ b/tests/tools.rs
@@ -36,7 +36,7 @@ fn run_sampling() {
         }
     }
 
-    let data = trajectory_generator(&net, 4, 1.0);
+    let data = trajectory_generator(&net, 4, 1.0, 1234,);
 
     assert_eq!(4, data.trajectories.len());
     assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);
From 185e1756cacc11476cdc11e0d5c6a5740f2c1d2b Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Mon, 11 Apr 2022 14:36:38 +0200
Subject: [PATCH 011/126] The residence time generation is now seedable

---
 src/params.rs   | 4 ++--
 src/tools.rs    | 1 +
 tests/params.rs | 4 +++-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/params.rs b/src/params.rs
index b418df6..963ff8c 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -36,7 +36,7 @@ pub trait ParamsTrait {
 
     /// Randomly generate a residence time for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>;
+    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError>;
 
     /// Randomly generate a possible state for the given node taking into account the node state
     /// and its parent set.
@@ -143,7 +143,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         StateType::Discrete(rng.gen_range(0..(self.domain.len())))
     }
 
-    fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> {
+    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError> {
         // Generate a random residence time given the current state of the node and its parent set.
         // The method used is described in:
         // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
diff --git a/src/tools.rs b/src/tools.rs
index 4efe085..acee937 100644
--- a/src/tools.rs
+++ b/src/tools.rs
@@ -56,6 +56,7 @@ pub fn trajectory_generator<T: network::Network>(
                         .get_random_residence_time(
                             net.get_node(idx).params.state_to_index(&current_state[idx]),
                             net.get_param_index_network(idx, &current_state),
+                            &mut rng,
                         )
                         .unwrap()
                         + t,
diff --git a/tests/params.rs b/tests/params.rs
index 23c99fa..f8b1154 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -61,7 +61,9 @@ fn test_random_generation_residence_time() {
     let param = create_ternary_discrete_time_continous_param();
     let mut states = Array1::<f64>::zeros(10000);
 
-    states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap());
+    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+
+    states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap());
 
     assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
 }
From 9316fcee30b68021b26e3ef4385fba143cccb6be Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Mon, 11 Apr 2022 14:41:18 +0200
Subject: [PATCH 012/126] The state generation is now seedable

---
 src/params.rs   | 4 ++--
 src/tools.rs    | 1 +
 tests/params.rs | 4 +++-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/params.rs b/src/params.rs
index 963ff8c..6173d75 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -40,7 +40,7 @@ pub trait ParamsTrait {
 
     /// Randomly generate a possible state for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError>;
+    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError>;
 
     /// Used by children of the node described by these parameters to reserve spaces in their CIMs.
     fn get_reserved_space_as_parent(&self) -> usize;
@@ -160,7 +160,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         }
     }
 
-    fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> {
+    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError> {
         // Generate a random transition given the current state of the node and its parent set.
         // The method used is described in:
         // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
diff --git a/src/tools.rs b/src/tools.rs
index acee937..858923e 100644
--- a/src/tools.rs
+++ b/src/tools.rs
@@ -84,6 +84,7 @@ pub fn trajectory_generator<T: network::Network>(
                         .params
                         .state_to_index(&current_state[next_node_transition]),
                     net.get_param_index_network(next_node_transition, &current_state),
+                    &mut rng,
                 )
                 .unwrap();
 
diff --git a/tests/params.rs b/tests/params.rs
index f8b1154..8ab81c1 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -42,8 +42,10 @@ fn test_random_generation_state() {
     let param = create_ternary_discrete_time_continous_param();
     let mut states = Array1::<usize>::zeros(10000);
 
+    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+
     states.mapv_inplace(|_| {
-        if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
+        if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() {
             val
         } else {
             panic!()

From 05af0f37c4fdd61f20e753a7a1fe4e092a57c790 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Mon, 11 Apr 2022 14:57:18 +0200
Subject: [PATCH 013/126] Added `.vscode` folder to `.gitignore`

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 96ef6c0..c640ca5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 /target
 Cargo.lock
+.vscode

From 79dbd885296c602370a57592847ba08b2b3b7ed8 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Mon, 11 Apr 2022 15:35:23 +0200
Subject: [PATCH 014/126] Get rid of the `rand_core` dependency

---
 Cargo.toml      | 1 -
 src/tools.rs    | 2 +-
 tests/params.rs | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 4cb6c06..9941ed6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,6 @@ thiserror = "*"
 rand = "*"
 bimap = "*"
 enum_dispatch = "*"
-rand_core = "*"
 rand_chacha = "*"
 
 [dev-dependencies]
 approx = "*"
diff --git a/src/tools.rs b/src/tools.rs
index 858923e..8cec2a2 100644
--- a/src/tools.rs
+++ b/src/tools.rs
@@ -4,7 +4,7 @@ use crate::params;
 use crate::params::ParamsTrait;
 use ndarray::prelude::*;
 use rand_chacha::ChaCha8Rng;
-use rand_core::SeedableRng;
+use rand_chacha::rand_core::SeedableRng;
 
 pub struct Trajectory {
     pub time: Array1<f64>,
diff --git a/tests/params.rs b/tests/params.rs
index 8ab81c1..255aba6 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -2,7 +2,7 @@ use ndarray::prelude::*;
 use rustyCTBN::params::*;
 use std::collections::BTreeSet;
 use rand_chacha::ChaCha8Rng;
-use rand_core::SeedableRng;
+use rand_chacha::rand_core::SeedableRng;
 
 mod utils;

From 79ec08b29af34a99388f398865ecddb183ac00c2 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Mon, 11 Apr 2022 16:33:39 +0200
Subject: [PATCH 015/126] Made seed optional in `trajectory_generator`

---
 src/tools.rs                | 4 +++-
 tests/parameter_learning.rs | 8 ++++----
 tests/tools.rs              | 2 +-
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/tools.rs b/src/tools.rs
index 8cec2a2..2a38d34 100644
--- a/src/tools.rs
+++ b/src/tools.rs
@@ -19,12 +19,14 @@ pub fn trajectory_generator<T: network::Network>(
     net: &T,
     n_trajectories: u64,
     t_end: f64,
-    seed: u64,
+    seed: Option<u64>,
 ) -> Dataset {
     let mut dataset = Dataset {
         trajectories: Vec::new(),
     };
 
+    let seed = seed.unwrap_or_else(rand::random);
+
     let mut rng = ChaCha8Rng::seed_from_u64(seed);
 
     let node_idx: Vec<_> = net.get_node_indices().collect();
diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index 96b6ce1..af57291 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -40,7 +40,7 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
-    let data = trajectory_generator(&net, 100, 100.0, 1234,);
+    let data = trajectory_generator(&net, 100, 100.0, Some(1234),);
@@ -93,7 +93,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
-    let data = trajectory_generator(&net, 100, 200.0, 1234,);
+    let data = trajectory_generator(&net, 100, 200.0, Some(1234),);
@@ -148,7 +148,7 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
-    let data = trajectory_generator(&net, 100, 200.0, 1234,);
+    let data = trajectory_generator(&net, 100, 200.0, Some(1234),);
@@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
-    let data = trajectory_generator(&net, 300, 300.0, 1234,);
+    let data = trajectory_generator(&net, 300, 300.0, Some(1234),);
diff --git a/tests/tools.rs b/tests/tools.rs
index f831ec4..fc9b930 100644
--- a/tests/tools.rs
+++ b/tests/tools.rs
@@ -36,7 +36,7 @@ fn run_sampling() {
-    let data = trajectory_generator(&net, 4, 1.0, 1234,);
+    let data = trajectory_generator(&net, 4, 1.0, Some(1234),);
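Taken together, patches 010-015 make sampling reproducible: every random draw flows through one ChaCha8Rng, seeded explicitly or from rand::random when the seed is None. A small sketch of the guarantee, using only the crates added above:

    use rand::Rng;
    use rand_chacha::rand_core::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    fn main() {
        let mut a = ChaCha8Rng::seed_from_u64(6347747169756259);
        let mut b = ChaCha8Rng::seed_from_u64(6347747169756259);
        // Identical seeds produce identical streams, so two calls to
        // trajectory_generator with Some(seed) yield the same dataset.
        for _ in 0..100 {
            let xa: f64 = a.gen_range(0.0..=1.0);
            let xb: f64 = b.gen_range(0.0..=1.0);
            assert_eq!(xa, xb);
        }
    }
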
From 62fcbd466a08d7dead12c48a40c8200f3fa57698 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Tue, 12 Apr 2022 09:33:07 +0200
Subject: [PATCH 016/126] Removed the `rand::thread_rng` that was overriding
 ChaCha's `rng`; increased the epsilon from 0.2 to 0.3 in the tests

---
 src/params.rs               | 3 ---
 tests/parameter_learning.rs | 8 ++++----
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/params.rs b/src/params.rs
index 6173d75..f0e5efa 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -1,7 +1,6 @@
 use enum_dispatch::enum_dispatch;
 use ndarray::prelude::*;
 use rand::Rng;
-use rand::rngs::ThreadRng;
 use std::collections::{BTreeSet, HashMap};
 use thiserror::Error;
 use rand_chacha::ChaCha8Rng;
@@ -149,7 +148,6 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
         match &self.cim {
             Option::Some(cim) => {
-                let mut rng = rand::thread_rng();
                 let lambda = cim[[u, state, state]] * -1.0;
                 let x: f64 = rng.gen_range(0.0..=1.0);
                 Ok(-x.ln() / lambda)
@@ -166,7 +164,6 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
         match &self.cim {
             Option::Some(cim) => {
-                let mut rng = rand::thread_rng();
                 let lambda = cim[[u, state, state]] * -1.0;
                 let urand: f64 = rng.gen_range(0.0..=1.0);
 
diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index af57291..adff6e8 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -47,7 +47,7 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
     assert!(CIM.abs_diff_eq(&arr3(&[
         [[-1.0, 1.0], [4.0, -4.0]],
         [[-6.0, 6.0], [2.0, -2.0]],
-    ]), 0.2));
+    ]), 0.3));
 }
 
@@ -101,7 +101,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
         [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
         [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
         [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
-    ]), 0.2));
+    ]), 0.3));
 }
 
@@ -154,7 +154,7 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
     assert_eq!(CIM.shape(), [1, 3, 3]);
     assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
                                     [1.5, -2.0, 0.5],
-                                    [0.4, 0.6, -1.0]]]), 0.2));
+                                    [0.4, 0.6, -1.0]]]), 0.3));
 }
 
@@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
         [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
         [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
         [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
-    ]), 0.2));
+    ]), 0.3));
 }

From a350ddc980204218c72879bf28d12b2dddecf825 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Tue, 12 Apr 2022 13:40:00 +0200
Subject: [PATCH 017/126] Decreased epsilon to `0.1` with a new seed

---
 tests/parameter_learning.rs | 16 ++++++++--------
 tests/params.rs             |  6 +++---
 tests/tools.rs              |  2 +-
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index adff6e8..15245fd 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -40,14 +40,14 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
-    let data = trajectory_generator(&net, 100, 100.0, Some(1234),);
+    let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),);
     let (CIM, M, T) = pl.fit(&net, &data, 1, None);
     print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
     assert_eq!(CIM.shape(), [2, 2, 2]);
     assert!(CIM.abs_diff_eq(&arr3(&[
         [[-1.0, 1.0], [4.0, -4.0]],
         [[-6.0, 6.0], [2.0, -2.0]],
-    ]), 0.3));
+    ]), 0.1));
 }
 
@@ -93,7 +93,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
-    let data = trajectory_generator(&net, 100, 200.0, Some(1234),);
+    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
@@ -101,7 +101,7 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
-    ]), 0.3));
+    ]), 0.1));
 }
 
@@ -148,7 +148,7 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
-    let data = trajectory_generator(&net, 100, 200.0, Some(1234),);
+    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
@@ -154,7 +154,7 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
-                                    [0.4, 0.6, -1.0]]]), 0.3));
+                                    [0.4, 0.6, -1.0]]]), 0.1));
 }
 
@@ -228,7 +228,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
-    let data = trajectory_generator(&net, 300, 300.0, Some(1234),);
+    let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),);
@@ -244,7 +244,7 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
         [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
         [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
         [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
-    ]), 0.3));
+    ]), 0.1));
 }
diff --git a/tests/params.rs b/tests/params.rs
index 255aba6..b049d4e 100644
--- a/tests/params.rs
+++ b/tests/params.rs
@@ -23,7 +23,7 @@ fn test_uniform_generation() {
-    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
@@ -42,7 +42,7 @@ fn test_random_generation_state() {
-    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
@@ -63,7 +63,7 @@ fn test_random_generation_residence_time() {
-    let mut rng = ChaCha8Rng::seed_from_u64(123456);
+    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
diff --git a/tests/tools.rs b/tests/tools.rs
index fc9b930..76847ef 100644
--- a/tests/tools.rs
+++ b/tests/tools.rs
@@ -36,7 +36,7 @@ fn run_sampling() {
-    let data = trajectory_generator(&net, 4, 1.0, Some(1234),);
+    let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),);

From 4a7c34af1793b33913eec7f92a52930927e83a83 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Wed, 13 Apr 2022 11:49:04 +0200
Subject: [PATCH 018/126] BIC

---
 src/structure_learning.rs   | 55 +++++++++++++++++++++++++++++++++----
 tests/structure_learning.rs | 22 +++++++++++++--
 2 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/src/structure_learning.rs b/src/structure_learning.rs
index f4c369b..4de13ae 100644
--- a/src/structure_learning.rs
+++ b/src/structure_learning.rs
@@ -13,7 +13,7 @@ pub trait StructureLearning {
 }
 
 pub trait ScoreFunction {
-    fn compute_score<T>(
+    fn call<T>(
         &self,
         net: &T,
         node: usize,
@@ -36,16 +36,14 @@ impl LogLikelihood {
         }
         LogLikelihood { alpha, tau }
     }
-}
 
-impl ScoreFunction for LogLikelihood {
     fn compute_score<T>(
         &self,
         net: &T,
         node: usize,
         parent_set: &BTreeSet<usize>,
         dataset: &tools::Dataset,
-    ) -> f64
+    ) -> (f64, Array3<usize>)
     where
         T: network::Network,
     {
@@ -75,8 +73,55 @@ impl LogLikelihood {
                             + y.iter().map(|z|
                                 gamma::ln_gamma(alpha + *z as f64)
                                 - gamma::ln_gamma(alpha)).sum::<f64>()).sum::<f64>()).sum();
-                log_ll_theta + log_ll_q
+                (log_ll_theta + log_ll_q, M)
             }
         }
     }
+
+
+}
+
+impl ScoreFunction for LogLikelihood {
+    fn call<T>(
+        &self,
+        net: &T,
+        node: usize,
+        parent_set: &BTreeSet<usize>,
+        dataset: &tools::Dataset,
+    ) -> f64
+    where
+        T: network::Network,
+    {
+        self.compute_score(net, node, parent_set, dataset).0
+    }
+}
+
+pub struct BIC {
+    ll: LogLikelihood
+}
+
+impl BIC {
+    pub fn init(alpha: usize, tau: f64) -> BIC {
+        BIC {
+            ll: LogLikelihood::init(alpha, tau)
+        }
+    }
+}
+
+impl ScoreFunction for BIC {
+    fn call<T>(
+        &self,
+        net: &T,
+        node: usize,
+        parent_set: &BTreeSet<usize>,
+        dataset: &tools::Dataset,
+    ) -> f64
+    where
+        T: network::Network {
+        let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
+        let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
+        let sample_size = M.sum();
+        ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
+    }
 }
diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs
index 95b95fa..f3633b5 100644
--- a/tests/structure_learning.rs
+++ b/tests/structure_learning.rs
@@ -29,9 +29,27 @@ fn simple_log_likelihood() {
 
     let ll = LogLikelihood::init(1, 1.0);
 
-    assert_abs_diff_eq!(0.04257, ll.compute_score(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
-
-
-
+    assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
 }
+
+
+#[test]
+fn simple_bic() {
+    let mut net = CtbnNetwork::init();
+    let n1 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n1"),2))
+        .unwrap();
+
+    let trj = Trajectory{
+        time: arr1(&[0.0,0.1,0.3]),
+        events: arr2(&[[0],[1],[1]])};
+
+    let dataset = Dataset{
+        trajectories: vec![trj]};
+
+    let ll = BIC::init(1, 1.0);
+
+    assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
+}

From 394970adca780712cd9ab599c292110381318db5 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Wed, 13 Apr 2022 14:05:36 +0200
Subject: [PATCH 019/126] Test for BIC

---
 src/structure_learning.rs   | 3 ++-
 tests/structure_learning.rs | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/structure_learning.rs b/src/structure_learning.rs
index 4de13ae..ba76b7a 100644
--- a/src/structure_learning.rs
+++ b/src/structure_learning.rs
@@ -121,7 +121,8 @@ impl ScoreFunction for BIC {
         T: network::Network {
         let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
         let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
-        let sample_size = M.sum();
+        //TODO: Optimize this
+        let sample_size: usize = dataset.trajectories.iter().map(|x| x.time.len() -1).sum();
         ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
     }
 }
diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs
index f3633b5..a9feea9 100644
--- a/tests/structure_learning.rs
+++ b/tests/structure_learning.rs
@@ -50,6 +50,6 @@ fn simple_bic() {
 
     let ll = BIC::init(1, 1.0);
 
-    assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
+    assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
 }
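The penalty applied above is the standard Bayesian Information Criterion (a sketch; with patch 019's fix, N counts the observed transitions in the dataset rather than M.sum()):

    BIC = \ln LL - \frac{\ln N}{2}\, k, \qquad
    k = |U|\,|X|\,(|X| - 1), \qquad
    N = \sum_{j} \big(\mathrm{len}(trj_j) - 1\big)

where |U| = M.shape()[0] is the number of parent configurations and |X| the cardinality of the node, so k counts the free rate parameters in the CIM.
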
From 0b019de45b47da5992601ac826e700a7e58ceca3 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Wed, 13 Apr 2022 14:38:25 +0200
Subject: [PATCH 020/126] Added labels to `labels.yml` and automated the
 implementation of the edits via GH Actions

---
 .github/ISSUE_TEMPLATE/refactor_request.md |  2 +-
 .github/labels.yml                         | 36 ++++++++++++++++++++++
 .github/pull_request_template.md           |  2 +-
 .github/workflows/labels.yml               | 23 ++++++++++++++
 4 files changed, 61 insertions(+), 2 deletions(-)
 create mode 100644 .github/labels.yml
 create mode 100644 .github/workflows/labels.yml

diff --git a/.github/ISSUE_TEMPLATE/refactor_request.md b/.github/ISSUE_TEMPLATE/refactor_request.md
index 503e3f3..9a4d090 100644
--- a/.github/ISSUE_TEMPLATE/refactor_request.md
+++ b/.github/ISSUE_TEMPLATE/refactor_request.md
@@ -2,7 +2,7 @@
 name: ⚙️ Refactor request
 about: Suggest a refactor for this project
 title: '[Refactor] '
-labels: 'enhancement'
+labels: enhancement, refactor
 assignees: ''
 
 ---
diff --git a/.github/labels.yml b/.github/labels.yml
new file mode 100644
index 0000000..0129d00
--- /dev/null
+++ b/.github/labels.yml
@@ -0,0 +1,36 @@
+- name: "bug"
+  color: "d73a4a"
+  description: "Something isn't working"
+- name: "enhancement"
+  color: "a2eeef"
+  description: "New feature or request"
+- name: "refactor"
+  color: "B06E16"
+  description: "Change in the structure"
+- name: "documentation"
+  color: "0075ca"
+  description: "Improvements or additions to documentation"
+- name: "meta"
+  color: "1D76DB"
+  description: "Something related to the project itself"
+
+- name: "duplicate"
+  color: "cfd3d7"
+  description: "This issue or pull request already exists"
+
+- name: "help wanted"
+  color: "008672"
+  description: "Extra help is needed"
+- name: "urgent"
+  color: "D93F0B"
+  description: ""
+- name: "wontfix"
+  color: "ffffff"
+  description: "This will not be worked on"
+- name: "invalid"
+  color: "e4e669"
+  description: "This doesn't seem right"
+
+- name: "question"
+  color: "d876e3"
+  description: "Further information is requested"
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 7e51286..063a3e0 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -1,4 +1,4 @@
-# Pull/Merge Request into master dev
+
 
 ## Description
diff --git a/.github/workflows/labels.yml b/.github/workflows/labels.yml
new file mode 100644
index 0000000..2d5bc59
--- /dev/null
+++ b/.github/workflows/labels.yml
@@ -0,0 +1,23 @@
+name: meta-github
+
+on:
+  push:
+    branches:
+      - dev
+
+jobs:
+  labeler:
+    runs-on: ubuntu-latest
+    steps:
+      -
+        name: Checkout
+        uses: actions/checkout@v2
+      -
+        name: Run Labeler
+        if: success()
+        uses: crazy-max/ghaction-github-labeler@v3
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          yaml-file: .github/labels.yml
+          skip-delete: false
+          dry-run: false

From a4b0a406f4d65f83d7a638f990f84a229a68fa54 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Wed, 13 Apr 2022 19:42:24 +0200
Subject: [PATCH 021/126] Hill Climbing + Simple test

---
 src/structure_learning/mod.rs                 | 10 +++
 .../score_based_algorithm.rs                  | 61 +++++++++++++++++++
 .../score_function.rs}                        |  6 --
 tests/structure_learning.rs                   | 49 ++++++++++++++-
 4 files changed, 118 insertions(+), 8 deletions(-)
 create mode 100644 src/structure_learning/mod.rs
 create mode 100644 src/structure_learning/score_based_algorithm.rs
 rename src/{structure_learning.rs => structure_learning/score_function.rs} (96%)

diff --git a/src/structure_learning/mod.rs b/src/structure_learning/mod.rs
new file mode 100644
index 0000000..d72862d
--- /dev/null
+++ b/src/structure_learning/mod.rs
@@ -0,0 +1,10 @@
+pub mod score_function;
+pub mod score_based_algorithm;
+use crate::network;
+use crate::tools;
+
+pub trait StructureLearningAlgorithm {
+    fn call<T>(&self, net: T, dataset: &tools::Dataset) -> T
+    where
+        T: network::Network;
+}
diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs
new file mode 100644
index 0000000..ed54092
--- /dev/null
+++ b/src/structure_learning/score_based_algorithm.rs
@@ -0,0 +1,61 @@
+use crate::params;
+use crate::structure_learning::score_function::ScoreFunction;
+use crate::structure_learning::StructureLearningAlgorithm;
+use crate::tools;
+use crate::{network, parameter_learning};
+use ndarray::prelude::*;
+use rand::prelude::*;
+use rand_chacha::ChaCha8Rng;
+use std::collections::BTreeSet;
+
+pub struct HillClimbing<S: ScoreFunction> {
+    score_function: S,
+}
+
+impl<S: ScoreFunction> HillClimbing<S> {
+    pub fn init(score_function: S) -> HillClimbing<S> {
+        HillClimbing { score_function }
+    }
+}
+
+impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
+    fn call<T>(&self, net: T, dataset: &tools::Dataset) -> T
+    where
+        T: network::Network,
+    {
+        let mut net = net;
+        net.initialize_adj_matrix();
+        for node in net.get_node_indices() {
+            let mut parent_set: BTreeSet<usize> = BTreeSet::new();
+            let mut current_ll = self.score_function.call(&net, node, &parent_set, dataset);
+            let mut old_ll = f64::NEG_INFINITY;
+            while current_ll > old_ll {
+                old_ll = current_ll;
+                for parent in net.get_node_indices() {
+                    if parent == node {
+                        continue;
+                    }
+                    let is_removed = parent_set.remove(&parent);
+                    if !is_removed {
+                        parent_set.insert(parent);
+                    }
+
+                    let tmp_ll = self.score_function.call(&net, node, &parent_set, dataset);
+
+                    if tmp_ll < current_ll {
+                        if is_removed {
+                            parent_set.insert(parent);
+                        } else {
+                            parent_set.remove(&parent);
+                        }
+                    } else {
+                        current_ll = tmp_ll;
+                    }
+                }
+            }
+            parent_set.iter().for_each(|p| net.add_edge(*p, node));
+        }
+
+        return net;
+    }
+}
diff --git a/src/structure_learning.rs b/src/structure_learning/score_function.rs
similarity index 96%
rename from src/structure_learning.rs
rename to src/structure_learning/score_function.rs
index ba76b7a..06f9fb9 100644
--- a/src/structure_learning.rs
+++ b/src/structure_learning/score_function.rs
@@ -6,12 +6,6 @@ use ndarray::prelude::*;
 use statrs::function::gamma;
 use std::collections::BTreeSet;
 
-pub trait StructureLearning {
-    fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
-    where
-        T: network::Network;
-}
-
 pub trait ScoreFunction {
     fn call<T>(
         &self,
diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs
index a9feea9..e3a43e4 100644
--- a/tests/structure_learning.rs
+++ b/tests/structure_learning.rs
@@ -5,9 +5,12 @@ use utils::*;
 use rustyCTBN::ctbn::*;
 use rustyCTBN::network::Network;
 use rustyCTBN::tools::*;
-use rustyCTBN::structure_learning::*;
-use ndarray::{arr1, arr2};
+use rustyCTBN::structure_learning::score_function::*;
+use rustyCTBN::structure_learning::score_based_algorithm::*;
+use rustyCTBN::structure_learning::StructureLearningAlgorithm;
+use ndarray::{arr1, arr2, arr3};
 use std::collections::BTreeSet;
+use rustyCTBN::params;
 
 
 #[macro_use]
@@ -53,3 +56,45 @@ fn simple_bic() {
 
     assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
 }
+
+fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
+    let mut net = CtbnNetwork::init();
+    let n1 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n1"),3))
+        .unwrap();
+    let n2 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n2"),3))
+        .unwrap();
+    net.add_edge(n1, n2);
+
+    match &mut net.get_node_mut(n1).params {
+        params::Params::DiscreteStatesContinousTime(param) => {
+            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
+                                    [1.5, -2.0, 0.5],
+                                    [0.4, 0.6, -1.0]]])));
+        }
+    }
+
+    match &mut net.get_node_mut(n2).params {
+        params::Params::DiscreteStatesContinousTime(param) => {
+            assert_eq!(Ok(()), param.set_cim(arr3(&[
+                [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
+                [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
+                [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
+            ])));
+        }
+    }
+
+    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
+
+    let net = sl.call(net, &data);
+    assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
+    assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
+}
+
+#[test]
+pub fn learn_ternary_net_2_nodes_hill_climbing() {
+    let bic = BIC::init(1, 1.0);
+    let hl = HillClimbing::init(bic);
+    learn_ternary_net_2_nodes(hl);
+}
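The greedy loop in HillClimbing::call, distilled to its core: toggle one candidate parent at a time, keep the toggle only when the score does not decrease, and repeat the sweep until no sweep improves the score. A standalone sketch under those assumptions (`score` stands in for ScoreFunction::call; names are illustrative):

    use std::collections::BTreeSet;

    fn hill_climb(
        nodes: &[usize],
        node: usize,
        score: impl Fn(&BTreeSet<usize>) -> f64,
    ) -> BTreeSet<usize> {
        let mut parent_set = BTreeSet::new();
        let mut current_ll = score(&parent_set);
        let mut old_ll = f64::NEG_INFINITY;
        while current_ll > old_ll {
            old_ll = current_ll;
            for &parent in nodes.iter().filter(|&&p| p != node) {
                // Toggle the candidate in or out of the parent set.
                if !parent_set.remove(&parent) {
                    parent_set.insert(parent);
                }
                let tmp_ll = score(&parent_set);
                if tmp_ll < current_ll {
                    // Worse score: undo the toggle.
                    if !parent_set.remove(&parent) {
                        parent_set.insert(parent);
                    }
                } else {
                    current_ll = tmp_ll;
                }
            }
        }
        parent_set
    }

    fn main() {
        // With a score that only rewards parent 1 (minus a small size
        // penalty), the search settles on exactly {1}.
        let parents = hill_climb(&[0, 1, 2], 0, |ps| {
            (if ps.contains(&1) { 1.0 } else { 0.0 }) - ps.len() as f64 * 0.1
        });
        assert_eq!(BTreeSet::from_iter([1]), parents);
    }
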
learn_ternary_net_2_nodes(hl); +} From c7857391007363d738829398104d71250528207c Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 13 Apr 2022 19:44:19 +0200 Subject: [PATCH 022/126] LL test for hill climbing --- tests/structure_learning.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index e3a43e4..2ddd513 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -92,8 +92,16 @@ fn learn_ternary_net_2_nodes (sl: T) { assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } + +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { + let ll = LogLikelihood::init(1, 1.0); + let hl = HillClimbing::init(ll); + learn_ternary_net_2_nodes(hl); +} + #[test] -pub fn learn_ternary_net_2_nodes_hill_climbing() { +pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { let bic = BIC::init(1, 1.0); let hl = HillClimbing::init(bic); learn_ternary_net_2_nodes(hl); From 808bc0098c5069d01d9f5eb2969beb4261bb6099 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 14 Apr 2022 09:06:50 +0200 Subject: [PATCH 023/126] Added learn_mixed_discrete_net_3 test --- tests/structure_learning.rs | 81 ++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 2ddd513..e0e6d9b 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -85,7 +85,7 @@ fn learn_ternary_net_2_nodes (sl: T) { } } - let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); let net = sl.call(net, &data); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); @@ -106,3 +106,82 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { let hl = HillClimbing::init(bic); learn_ternary_net_2_nodes(hl); } + + + +fn learn_mixed_discrete_net_3_nodes (sl: T) { + let mut net = CtbnNetwork::init(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .unwrap(); + + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"),4)) + .unwrap(); + net.add_edge(n1, n2); + net.add_edge(n1, n3); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1).params { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2).params { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ]))); + } + } + + + match &mut net.get_node_mut(n3).params { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], + [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], + [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], + + [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], + [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], 
[0.5, 1.0, 2.5, -4.0]], + [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], + + [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], + [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], + [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], + ]))); + } + } + + + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); + let net = sl.call(net, &data); + + assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); + assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); + assert_eq!(BTreeSet::from_iter(vec![n1, n2]), net.get_parent_set(n3)); +} + + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { + let ll = LogLikelihood::init(1, 1.0); + let hl = HillClimbing::init(ll); + learn_mixed_discrete_net_3_nodes(hl); +} + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { + let bic = BIC::init(1, 1.0); + let hl = HillClimbing::init(bic); + learn_mixed_discrete_net_3_nodes(hl); +} From 5044a88b6d993f9101532b670bc7548459ce0a39 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 14 Apr 2022 09:09:11 +0200 Subject: [PATCH 024/126] refactor --- src/structure_learning/mod.rs | 2 +- src/structure_learning/score_based_algorithm.rs | 2 +- tests/structure_learning.rs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/structure_learning/mod.rs b/src/structure_learning/mod.rs index d72862d..a335101 100644 --- a/src/structure_learning/mod.rs +++ b/src/structure_learning/mod.rs @@ -4,7 +4,7 @@ use crate::network; use crate::tools; pub trait StructureLearningAlgorithm { - fn call(&self, net: T, dataset: &tools::Dataset) -> T + fn fit(&self, net: T, dataset: &tools::Dataset) -> T where T: network::Network; } diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs index ed54092..63620fe 100644 --- a/src/structure_learning/score_based_algorithm.rs +++ b/src/structure_learning/score_based_algorithm.rs @@ -19,7 +19,7 @@ impl HillClimbing { } impl StructureLearningAlgorithm for HillClimbing { - fn call(&self, net: T, dataset: &tools::Dataset) -> T + fn fit(&self, net: T, dataset: &tools::Dataset) -> T where T: network::Network, { diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index e0e6d9b..4ce89a3 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -87,7 +87,7 @@ fn learn_ternary_net_2_nodes (sl: T) { let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); - let net = sl.call(net, &data); + let net = sl.fit(net, &data); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } @@ -164,7 +164,7 @@ fn learn_mixed_discrete_net_3_nodes (sl: T) { let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); - let net = sl.call(net, &data); + let net = sl.fit(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); From 8ca93c931b14a90718312ba383403fed97cd0c2f Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 14 Apr 2022 10:04:09 +0200 Subject: [PATCH 025/126] Refactor --- src/{structure_learning/mod.rs => structure_learning.rs} | 2 +- src/structure_learning/score_based_algorithm.rs | 8 ++------ tests/structure_learning.rs | 4 ++-- 3 files changed, 5 insertions(+), 9 deletions(-) 
rename src/{structure_learning/mod.rs => structure_learning.rs} (70%)

diff --git a/src/structure_learning/mod.rs b/src/structure_learning.rs
similarity index 70%
rename from src/structure_learning/mod.rs
rename to src/structure_learning.rs
index a335101..8ba91df 100644
--- a/src/structure_learning/mod.rs
+++ b/src/structure_learning.rs
@@ -4,7 +4,7 @@ use crate::network;
 use crate::tools;
 
 pub trait StructureLearningAlgorithm {
-    fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
+    fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
     where
         T: network::Network;
 }
diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs
index 63620fe..0a23c36 100644
--- a/src/structure_learning/score_based_algorithm.rs
+++ b/src/structure_learning/score_based_algorithm.rs
@@ -1,11 +1,7 @@
-use crate::params;
 use crate::structure_learning::score_function::ScoreFunction;
 use crate::structure_learning::StructureLearningAlgorithm;
 use crate::tools;
-use crate::{network, parameter_learning};
-use ndarray::prelude::*;
-use rand::prelude::*;
-use rand_chacha::ChaCha8Rng;
+use crate::network;
 use std::collections::BTreeSet;
 
 pub struct HillClimbing<S: ScoreFunction> {
@@ -19,7 +15,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
 }
 
 impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
-    fn fit<T>(&self, net: T, dataset: &tools::Dataset) -> T
+    fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
     where
         T: network::Network,
     {
diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs
index 4ce89a3..25ce1e8 100644
--- a/tests/structure_learning.rs
+++ b/tests/structure_learning.rs
@@ -87,7 +87,7 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
 
     let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),);
 
-    let net = sl.fit(net, &data);
+    let net = sl.fit_transform(net, &data);
     assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
     assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
 }
@@ -164,7 +164,7 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
 
     let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),);
 
-    let net = sl.fit(net, &data);
+    let net = sl.fit_transform(net, &data);
 
     assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
     assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
     assert_eq!(BTreeSet::from_iter(vec![n1, n2]), net.get_parent_set(n3));

From df12b93d559b1a74ab7f056a33e73023f1aa8e93 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Thu, 14 Apr 2022 10:56:27 +0200
Subject: [PATCH 026/126] Limit parent set

---
 src/structure_learning/score_based_algorithm.rs | 8 +++++---
 tests/structure_learning.rs                     | 8 ++++----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs
index 0a23c36..b5590ed 100644
--- a/src/structure_learning/score_based_algorithm.rs
+++ b/src/structure_learning/score_based_algorithm.rs
@@ -6,11 +6,12 @@ use std::collections::BTreeSet;
 
 pub struct HillClimbing<S: ScoreFunction> {
     score_function: S,
+    max_parent_set: Option<usize>
 }
 
 impl<S: ScoreFunction> HillClimbing<S> {
-    pub fn init(score_function: S) -> HillClimbing<S> {
-        HillClimbing { score_function }
+    pub fn init(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
+        HillClimbing { score_function, max_parent_set }
     }
 }
 
@@ -32,7 +34,7 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
                     continue;
                 }
                 let is_removed = parent_set.remove(&parent);
-                if !is_removed {
+                if !is_removed && parent_set.len() < max_parent_set {
                     parent_set.insert(parent);
                 }
diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs
index 25ce1e8..f9c0034 100644
--- a/tests/structure_learning.rs
+++ b/tests/structure_learning.rs
@@ -96,14 +96,14 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
 #[test]
 pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
     let ll = LogLikelihood::init(1, 1.0);
-    let hl = HillClimbing::init(ll);
+    let hl = HillClimbing::init(ll, None);
     learn_ternary_net_2_nodes(hl);
 }
 
 #[test]
 pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
     let bic = BIC::init(1, 1.0);
-    let hl = HillClimbing::init(bic);
+    let hl = HillClimbing::init(bic, None);
     learn_ternary_net_2_nodes(hl);
 }
 
@@ -175,13 +175,13 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
 #[test]
 pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
     let ll = LogLikelihood::init(1, 1.0);
-    let hl = HillClimbing::init(ll);
+    let hl = HillClimbing::init(ll, None);
     learn_mixed_discrete_net_3_nodes(hl);
 }
 
 #[test]
 pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
     let bic = BIC::init(1, 1.0);
-    let hl = HillClimbing::init(bic);
+    let hl = HillClimbing::init(bic, None);
     learn_mixed_discrete_net_3_nodes(hl);
 }

From 9a4914f5d024dc931b188ec8cea65c0a6c80436c Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Thu, 14 Apr 2022 12:05:56 +0200
Subject: [PATCH 027/126] Added basic GH workflow with lint and test (build
 included in the test)

---
 .github/workflows/build.yml | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)
 create mode 100644 .github/workflows/build.yml

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 0000000..a548a6f
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,30 @@
+name: build
+
+on:
+  push:
+    branches: [ main, dev ]
+  pull_request:
+    branches: [ dev ]
+
+env:
+  CARGO_TERM_COLOR: always
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+      - name: Setup Rust
+        uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          components: clippy
+      - uses: actions-rs/clippy-check@v1
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
+          args: --all-features
+      - name: Run tests
+        run: cargo test --verbose

From a6654385fc3755d3791371faa494f3cc9b307b5a Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Thu, 14 Apr 2022 13:55:01 +0200
Subject: [PATCH 028/126] Workflow more compliant with the GH style

---
 .github/workflows/build.yml | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index a548a6f..1f73a77 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -16,15 +16,19 @@ jobs:
 
     steps:
       - uses: actions/checkout@v3
-      - name: Setup Rust
+      - name: Setup (rust)
        uses: actions-rs/toolchain@v1
         with:
           profile: minimal
           toolchain: stable
           components: clippy
-      - uses: actions-rs/clippy-check@v1
+      - name: Linting (clippy)
+        uses: actions-rs/clippy-check@v1
         with:
           token: ${{ secrets.GITHUB_TOKEN }}
           args: --all-features
-      - name: Run tests
-        run: cargo test --verbose
+      - name: Tests (test)
+        uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --tests

From f49523f35a54bb2fff18d6163bab568f5378cffe Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Thu, 14 Apr 2022 14:06:20 +0200
Subject: [PATCH 029/126] Refactor of Dataset and Trajectory to ensure some
 basic properties.

---
 src/parameter_learning.rs                | 12 +++---
 src/structure_learning/score_function.rs |  2 +-
 src/tools.rs                             | 55 ++++++++++++++++++------
 tests/structure_learning.rs              | 25 +++++------
 tests/tools.rs                           |  4 +-
 5 files changed, 63 insertions(+), 35 deletions(-)

diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs
index 4fe3bdd..c4221cb 100644
--- a/src/parameter_learning.rs
+++ b/src/parameter_learning.rs
@@ -57,12 +57,12 @@ pub fn sufficient_statistics(
     let mut T: Array2<f64> = Array::zeros((parentset_domain.iter().product(), node_domain));
 
     //Compute the sufficient statistics
-    for trj in dataset.trajectories.iter() {
-        for idx in 0..(trj.time.len() - 1) {
-            let t1 = trj.time[idx];
-            let t2 = trj.time[idx + 1];
-            let ev1 = trj.events.row(idx);
-            let ev2 = trj.events.row(idx + 1);
+    for trj in dataset.get_trajectories().iter() {
+        for idx in 0..(trj.get_time().len() - 1) {
+            let t1 = trj.get_time()[idx];
+            let t2 = trj.get_time()[idx + 1];
+            let ev1 = trj.get_events().row(idx);
+            let ev2 = trj.get_events().row(idx + 1);
             let idx1 = vector_to_idx.dot(&ev1);
 
             T[[idx1, ev1[node]]] += t2 - t1;
diff --git a/src/structure_learning/score_function.rs b/src/structure_learning/score_function.rs
index 06f9fb9..323bbef 100644
--- a/src/structure_learning/score_function.rs
+++ b/src/structure_learning/score_function.rs
@@ -116,7 +116,7 @@ impl ScoreFunction for BIC {
         let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
         let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
         //TODO: Optimize this
-        let sample_size: usize = dataset.trajectories.iter().map(|x| x.time.len() - 1).sum();
+        let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum();
         ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
     }
 }
diff --git a/src/tools.rs b/src/tools.rs
index 2a38d34..922bb2b 100644
--- a/src/tools.rs
+++ b/src/tools.rs
@@ -3,16 +3,49 @@ use crate::network;
 use crate::node;
 use crate::params;
 use crate::params::ParamsTrait;
 use ndarray::prelude::*;
-use rand_chacha::ChaCha8Rng;
 use rand_chacha::rand_core::SeedableRng;
+use rand_chacha::ChaCha8Rng;
 
 pub struct Trajectory {
-    pub time: Array1<f64>,
-    pub events: Array2<usize>,
+    time: Array1<f64>,
+    events: Array2<usize>,
+}
+
+impl Trajectory {
+    pub fn init(time: Array1<f64>, events: Array2<usize>) -> Trajectory {
+        if time.shape()[0] != events.shape()[0] {
+            panic!("time.shape[0] must be equal to events.shape[0]");
+        }
+        Trajectory { time, events }
+    }
+
+    pub fn get_time(&self) -> &Array1<f64> {
+        &self.time
+    }
+
+    pub fn get_events(&self) -> &Array2<usize> {
+        &self.events
+    }
 }
 
 pub struct Dataset {
-    pub trajectories: Vec<Trajectory>,
+    trajectories: Vec<Trajectory>,
+}
+
+impl Dataset {
+    pub fn init(trajectories: Vec<Trajectory>) -> Dataset {
+        if trajectories
+            .iter()
+            .any(|x| trajectories[0].get_events().shape()[1] != x.get_events().shape()[1])
+        {
+            panic!("All the trajectories must represent the same number of variables");
+        }
+        Dataset { trajectories }
+    }
+
+    pub fn get_trajectories(&self) -> &Vec<Trajectory> {
+        &self.trajectories
+    }
 }
 
 pub fn trajectory_generator(
@@ -21,10 +54,8 @@ pub fn trajectory_generator(
     t_end: f64,
     seed: Option<u64>,
 ) -> Dataset {
-    let mut dataset = Dataset {
-        trajectories: Vec::new(),
-    };
+    let mut trajectories: Vec<Trajectory> = Vec::new();
 
     let seed = seed.unwrap_or_else(rand::random);
 
     let mut rng = ChaCha8Rng::seed_from_u64(seed);
@@ -115,14 +146,14 @@ pub fn trajectory_generator(
         );
         time.push(t_end.clone());
 
-        dataset.trajectories.push(Trajectory {
-            time: Array::from_vec(time),
-            events: Array2::from_shape_vec(
+        trajectories.push(Trajectory::init(
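            //(Trajectory::init panics when time and events disagree on their first dimension; see the check above.)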
Array::from_vec(time), + Array2::from_shape_vec( (events.len(), current_state.len()), events.iter().flatten().cloned().collect(), ) .unwrap(), - }); + )); } - dataset + Dataset::init(trajectories) } diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index f9c0034..ad18c18 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -17,18 +17,17 @@ use rustyCTBN::params; extern crate approx; #[test] -fn simple_log_likelihood() { +fn simple_score_test() { let mut net = CtbnNetwork::init(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) .unwrap(); - let trj = Trajectory{ - time: arr1(&[0.0,0.1,0.3]), - events: arr2(&[[0],[1],[1]])}; + let trj = Trajectory::init( + arr1(&[0.0,0.1,0.3]), + arr2(&[[0],[1],[1]])); - let dataset = Dataset{ - trajectories: vec![trj]}; + let dataset = Dataset::init(vec![trj]); let ll = LogLikelihood::init(1, 1.0); @@ -44,16 +43,14 @@ fn simple_bic() { .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) .unwrap(); - let trj = Trajectory{ - time: arr1(&[0.0,0.1,0.3]), - events: arr2(&[[0],[1],[1]])}; + let trj = Trajectory::init( + arr1(&[0.0,0.1,0.3]), + arr2(&[[0],[1],[1]])); - let dataset = Dataset{ - trajectories: vec![trj]}; - - let ll = BIC::init(1, 1.0); + let dataset = Dataset::init(vec![trj]); + let bic = BIC::init(1, 1.0); - assert_abs_diff_eq!(-0.65058, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); + assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); } diff --git a/tests/tools.rs b/tests/tools.rs index 76847ef..28b3e0d 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -38,8 +38,8 @@ fn run_sampling() { let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),); - assert_eq!(4, data.trajectories.len()); - assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); + assert_eq!(4, data.get_trajectories().len()); + assert_relative_eq!(1.0, data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len()-1]); } From 5444c619a2ef402af095cca376bd47c1db8f8120 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 14 Apr 2022 14:22:53 +0200 Subject: [PATCH 030/126] Added tests --- tests/structure_learning.rs | 40 +++++++++++++++++++++++++++++++------ tests/tools.rs | 24 +++++++++++++++++++++- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index ad18c18..d24170a 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -105,8 +105,7 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { } - -fn learn_mixed_discrete_net_3_nodes (sl: T) { +fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { let mut net = CtbnNetwork::init(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) @@ -161,11 +160,15 @@ fn learn_mixed_discrete_net_3_nodes (sl: T) { let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); - let net = sl.fit_transform(net, &data); + return (net, data); +} - assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); - assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); - assert_eq!(BTreeSet::from_iter(vec![n1, n2]), net.get_parent_set(n3)); +fn learn_mixed_discrete_net_3_nodes (sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + 
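    // Expected structure: n1 (index 0) has no parents, n1 is the sole parent of n2 (index 1),
    // and {n1, n2} are the parents of n3 (index 2).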
+    assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
+    assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
 }
 
@@ -182,3 +185,28 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
     let hl = HillClimbing::init(bic, None);
     learn_mixed_discrete_net_3_nodes(hl);
 }
+
+
+fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
+    let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
+    let net = sl.fit_transform(net, &data);
+    assert_eq!(BTreeSet::new(), net.get_parent_set(0));
+    assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
+    assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
+}
+
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
+    let ll = LogLikelihood::init(1, 1.0);
+    let hl = HillClimbing::init(ll, Some(1));
+    learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() {
+    let bic = BIC::init(1, 1.0);
+    let hl = HillClimbing::init(bic, Some(1));
+    learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
+}
diff --git a/tests/tools.rs b/tests/tools.rs
index 28b3e0d..c341b8c 100644
--- a/tests/tools.rs
+++ b/tests/tools.rs
@@ -5,7 +5,7 @@ use rustyCTBN::ctbn::*;
 use rustyCTBN::node;
 use rustyCTBN::params;
 use std::collections::BTreeSet;
-use ndarray::arr3;
+use ndarray::{arr1, arr2, arr3};
 
@@ -43,3 +43,25 @@ fn run_sampling() {
 }
 
+#[test]
+#[should_panic]
+fn trajectory_wrong_shape() {
+    let time = arr1(&[0.0, 0.2]);
+    let events = arr2(&[[0,3]]);
+    Trajectory::init(time, events);
+}
+
+
+#[test]
+#[should_panic]
+fn dataset_wrong_shape() {
+    let time = arr1(&[0.0, 0.2]);
+    let events = arr2(&[[0,3], [1,2]]);
+    let t1 = Trajectory::init(time, events);
+
+
+    let time = arr1(&[0.0, 0.2]);
+    let events = arr2(&[[0,3,3], [1,2,3]]);
+    let t2 = Trajectory::init(time, events);
+    Dataset::init(vec![t1, t2]);
+}

From b357c9efa0a43c7cc4fa744c62eb73e933cef703 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Thu, 14 Apr 2022 14:38:37 +0200
Subject: [PATCH 031/126] HillClimbing now panics when the dataset is
 incompatible with the network.
---
 .../score_based_algorithm.rs |  4 ++
 tests/structure_learning.rs  | 48 +++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs
index b5590ed..7537483 100644
--- a/src/structure_learning/score_based_algorithm.rs
+++ b/src/structure_learning/score_based_algorithm.rs
@@ -20,6 +20,10 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
     where
         T: network::Network,
     {
+        if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
+            panic!("Dataset and Network must have the same number of variables.")
+        }
+
         let mut net = net;
         let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
         net.initialize_adj_matrix();
diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs
index d24170a..c3482cc 100644
--- a/tests/structure_learning.rs
+++ b/tests/structure_learning.rs
@@ -54,6 +54,54 @@ fn simple_bic() {
 }
 
+
+
+fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) {
+    let mut net = CtbnNetwork::init();
+    let n1 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n1"),3))
+        .unwrap();
+    let n2 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n2"),3))
+        .unwrap();
+    net.add_edge(n1, n2);
+
+    match &mut net.get_node_mut(n1).params {
+        params::Params::DiscreteStatesContinousTime(param) => {
+            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
+                                                     [1.5, -2.0, 0.5],
+                                                     [0.4, 0.6, -1.0]]])));
+        }
+    }
+
+    match &mut net.get_node_mut(n2).params {
+        params::Params::DiscreteStatesContinousTime(param) => {
+            assert_eq!(Ok(()), param.set_cim(arr3(&[
+                [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
+                [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
+                [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
+            ])));
+        }
+    }
+
+    let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),);
+
+    let mut net = CtbnNetwork::init();
+    let _n1 = net
+        .add_node(generate_discrete_time_continous_node(String::from("n1"),3))
+        .unwrap();
+    let net = sl.fit_transform(net, &data);
+}
+
+
+#[test]
+#[should_panic]
+pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
+    let ll = LogLikelihood::init(1, 1.0);
+    let hl = HillClimbing::init(ll, None);
+    check_compatibility_between_dataset_and_network(hl);
+}

From 43d01d2bf84e18199a94dcc37cc5c8073e5c65b6 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Thu, 14 Apr 2022 17:22:27 +0200
Subject: [PATCH 032/126] Added comments.
---
 .../score_based_algorithm.rs             | 46 +++++++++----
 src/structure_learning/score_function.rs | 22 ++++--
 src/tools.rs                             | 69 ++++++++++++++-----
 3 files changed, 102 insertions(+), 35 deletions(-)

diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs
index 7537483..e57c4c1 100644
--- a/src/structure_learning/score_based_algorithm.rs
+++ b/src/structure_learning/score_based_algorithm.rs
@@ -1,17 +1,20 @@
+use crate::network;
 use crate::structure_learning::score_function::ScoreFunction;
 use crate::structure_learning::StructureLearningAlgorithm;
 use crate::tools;
-use crate::network;
 use std::collections::BTreeSet;
 
 pub struct HillClimbing<S: ScoreFunction> {
     score_function: S,
-    max_parent_set: Option<usize>
+    max_parent_set: Option<usize>,
 }
 
 impl<S: ScoreFunction> HillClimbing<S> {
-    pub fn init(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
-        HillClimbing { score_function, max_parent_set }
+    pub fn init(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
+        HillClimbing {
+            score_function,
+            max_parent_set,
+        }
     }
 }
 
@@ -20,6 +23,10 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
     where
         T: network::Network,
     {
+        //Check the coherence between dataset and network
         if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
             panic!("Dataset and Network must have the same number of variables.")
         }
+        //Make the network mutable.
         let mut net = net;
+        //Check if the max_parent_set constraint is present.
         let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
+        //Reset the adj matrix
         net.initialize_adj_matrix();
+        //Iterate over each node to learn its parent set.
         for node in net.get_node_indices() {
+            //Initialize an empty parent set.
             let mut parent_set: BTreeSet<usize> = BTreeSet::new();
+            //Compute the score for the empty parent set
             let mut current_score = self.score_function.call(&net, node, &parent_set, dataset);
+            //Set the old score to -\infty.
             let mut old_score = f64::NEG_INFINITY;
+            //Iterate until convergence
             while current_score > old_score {
+                //Save the current_score.
                 old_score = current_score;
+                //Iterate over each node.
                 for parent in net.get_node_indices() {
+                    //Continue if the parent and the node are the same.
                     if parent == node {
                         continue;
                     }
+                    //Try to remove parent from the parent_set.
                     let is_removed = parent_set.remove(&parent);
+                    //If parent was not in the parent_set add it.
                     if !is_removed && parent_set.len() < max_parent_set {
                         parent_set.insert(parent);
                     }
+                    //Compute the score with the modified parent_set.
                     let tmp_score = self.score_function.call(&net, node, &parent_set, dataset);
+                    //If tmp_score is worse than current_score revert the change to the parent set
                     if tmp_score < current_score {
                         if is_removed {
                             parent_set.insert(parent);
                         } else {
                             parent_set.remove(&parent);
                         }
+                    }
+                    //Otherwise save the computed score as current_score
+                    else {
                         current_score = tmp_score;
                     }
                 }
             }
+            //Apply the learned parent_set to the network struct.
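            //Illustrative invariant check, not part of the original patch: the greedy
            //search never grows a parent set beyond the optional max_parent_set bound.
            debug_assert!(parent_set.len() <= max_parent_set);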
             parent_set.iter().for_each(|p| net.add_edge(*p, node));
         }
diff --git a/src/structure_learning/score_function.rs b/src/structure_learning/score_function.rs
index 323bbef..dba40e2 100644
--- a/src/structure_learning/score_function.rs
+++ b/src/structure_learning/score_function.rs
@@ -25,6 +25,8 @@ pub struct LogLikelihood {
 
 impl LogLikelihood {
     pub fn init(alpha: usize, tau: f64) -> LogLikelihood {
+
+        //Tau must be >=0.0
         if tau < 0.0 {
             panic!("tau must be >=0.0");
         }
@@ -40,14 +42,21 @@ impl LogLikelihood {
     ) -> (f64, Array3<usize>)
     where
         T: network::Network,
-    {
+    {
+        //Identify the type of node used
         match &net.get_node(node).params {
-            params::Params::DiscreteStatesContinousTime(params) => {
+            params::Params::DiscreteStatesContinousTime(_params) => {
+                //Compute the sufficient statistics M (number of transitions) and T (residence
+                //time)
                 let (M, T) =
                     parameter_learning::sufficient_statistics(net, dataset, node, parent_set);
+
+                //Scale alpha according to the size of the parent set
                 let alpha = self.alpha as f64 / M.shape()[0] as f64;
+                //Scale tau according to the size of the parent set
                 let tau = self.tau / M.shape()[0] as f64;
-
+
+                //Compute the log likelihood for q
                 let log_ll_q: f64 = M
                     .sum_axis(Axis(2))
                     .iter()
                     .zip(T.iter())
                     .map(|(m, t)| {
                         gamma::ln_gamma(alpha + *m as f64 + 1.0)
                             + (alpha + 1.0) * f64::ln(tau + t)
                             - gamma::ln_gamma(alpha + 1.0)
                             - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
                     })
                     .sum();
-
+
+                //Compute the log likelihood for theta
                 let log_ll_theta: f64 = M.outer_iter()
                     .map(|x| x.outer_iter()
                          .map(|y| gamma::ln_gamma(alpha)
@@ -113,11 +122,15 @@ impl ScoreFunction for BIC {
     ) -> f64
     where
         T: network::Network {
+        //Compute the log-likelihood
         let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
+        //Compute the number of parameters
         let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
         //TODO: Optimize this
+        //Compute the sample size
         let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum();
+        //Compute BIC
         ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
     }
 }
diff --git a/src/tools.rs b/src/tools.rs
index 922bb2b..7cf205b 100644
--- a/src/tools.rs
+++ b/src/tools.rs
@@ -13,12 +13,14 @@ pub struct Trajectory {
 
 impl Trajectory {
     pub fn init(time: Array1<f64>, events: Array2<usize>) -> Trajectory {
+        //Events and time are two parts of the same trajectory. For this reason they must have the
+        //same number of samples.
         if time.shape()[0] != events.shape()[0] {
             panic!("time.shape[0] must be equal to events.shape[0]");
         }
         Trajectory { time, events }
     }
-
+
     pub fn get_time(&self) -> &Array1<f64> {
         &self.time
     }
@@ -34,6 +36,9 @@ pub struct Dataset {
 
 impl Dataset {
     pub fn init(trajectories: Vec<Trajectory>) -> Dataset {
+
+        //All the trajectories in the same dataset must represent the same process. For this reason
+        //each trajectory must represent the same number of variables.
         if trajectories
             .iter()
             .any(|x| trajectories[0].get_events().shape()[1] != x.get_events().shape()[1])
@@ -54,23 +59,38 @@ pub fn trajectory_generator(
     t_end: f64,
     seed: Option<u64>,
 ) -> Dataset {
-
+
+    //Tmp growing vector containing generated trajectories.
     let mut trajectories: Vec<Trajectory> = Vec::new();
 
-    let seed = seed.unwrap_or_else(rand::random);
-
-    let mut rng = ChaCha8Rng::seed_from_u64(seed);
-
-    let node_idx: Vec<_> = net.get_node_indices().collect();
+
+    //Random Generator object
+    let mut rng: ChaCha8Rng = match seed {
+        //If a seed is present use it to initialize the random generator.
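        //(Passing a fixed seed makes the generated trajectories reproducible across runs.)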
+        Some(seed) => SeedableRng::seed_from_u64(seed),
+        //Otherwise create a new random generator using the method `from_entropy`
+        None => SeedableRng::from_entropy()
+    };
+
+    //Each iteration generates one trajectory
     for _ in 0..n_trajectories {
+        //Current time of the sampling process
         let mut t = 0.0;
+        //History of all the moments in which something changed
         let mut time: Vec<f64> = Vec::new();
-        let mut events: Vec<Array1<usize>> = Vec::new();
-        let mut current_state: Vec<params::StateType> = node_idx
-            .iter()
-            .map(|x| net.get_node(*x).params.get_random_state_uniform(&mut rng))
+        //Configuration of the process variables at time t initialized with a uniform
+        //distribution.
+        let mut current_state: Vec<params::StateType> = net.get_node_indices()
+            .map(|x| net.get_node(x).params.get_random_state_uniform(&mut rng))
             .collect();
+        //History of all the configurations of the process variables.
+        let mut events: Vec<Array1<usize>> = Vec::new();
+        //Vector containing the time to the next transition for each variable.
         let mut next_transitions: Vec<Option<f64>> =
-            (0..node_idx.len()).map(|_| Option::None).collect();
+            net.get_node_indices().map(|_| Option::None).collect();
+
+        //Add the starting time for the trajectory.
+        time.push(t.clone());
+        //Add the starting configuration of the trajectory.
         events.push(
             current_state
                 .iter()
                 .map(|x| match x {
                     params::StateType::Discrete(state) => state.clone(),
                 })
                 .collect(),
         );
-        time.push(t.clone());
+        //Generate new samples until ending time is reached.
         while t < t_end {
+            //Generate the next transition time for each uninitialized variable.
             for (idx, val) in next_transitions.iter_mut().enumerate() {
                 if let None = val {
                     *val = Some(
                         net.get_node(idx)
                             .params
                             .get_random_residence_time(
                                 net.get_node(idx).params.state_to_index(&current_state[idx]),
                                 net.get_param_index_network(idx, &current_state),
                                 &mut rng,
                             )
                             .unwrap()
                             + t,
                     );
                 }
             }
+
+            //Get the variable with the smallest transition time.
             let next_node_transition = next_transitions
                 .iter()
                 .enumerate()
                 .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
                 .unwrap()
                 .0;
+            //Check if the next transition takes place after the ending time.
             if next_transitions[next_node_transition].unwrap() > t_end {
                 break;
             }
+            //Get the time in which the next transition occurs.
             t = next_transitions[next_node_transition].unwrap().clone();
+            //Add the transition time to the time vector.
             time.push(t.clone());
+
+            //Compute the new state of the transitioning variable.
             current_state[next_node_transition] = net
                 .get_node(next_node_transition)
                 .params
                 .get_random_state(
                     net.get_node(next_node_transition)
                         .params
                         .state_to_index(&current_state[next_node_transition]),
                     net.get_param_index_network(next_node_transition, &current_state),
                     &mut rng,
                 )
                 .unwrap();
+
+            //Add the new state to events
             events.push(Array::from_vec(
                 current_state
                     .iter()
                     .map(|x| match x {
                         params::StateType::Discrete(state) => state.clone(),
                     })
                     .collect(),
             ));
+            //Reset the next transition time for the transitioning node.
             next_transitions[next_node_transition] = None;
 
+            //Reset the next transition time for each child of the transitioning node.
             for child in net.get_children_set(next_node_transition) {
                 next_transitions[child] = None
             }
         }
 
+        //Add current_state as last state.
         events.push(
             current_state
                 .iter()
                 .map(|x| match x {
                     params::StateType::Discrete(state) => state.clone(),
                 })
                 .collect(),
         );
+        //Add t_end as last time.
         time.push(t_end.clone());
 
+        //Add the sampled trajectory to trajectories.
         trajectories.push(Trajectory::init(
             Array::from_vec(time),
             Array2::from_shape_vec(
                 (events.len(), current_state.len()),
                 events.iter().flatten().cloned().collect(),
             )
             .unwrap(),
         ));
     }
+    //Return a dataset object with the sampled trajectories.
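    //Illustrative sanity check, not part of the original patch: exactly one
    //trajectory is produced per requested sample.
    debug_assert_eq!(trajectories.len(), n_trajectories as usize);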
Dataset::init(trajectories) } From ab2162b5f162e8dccfd4d025bb8ac1850eae0dd9 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 15 Apr 2022 09:14:15 +0200 Subject: [PATCH 033/126] Constructors renamed from `init` to `new` --- src/ctbn.rs | 12 ++--- src/node.rs | 2 +- src/params.rs | 2 +- .../score_based_algorithm.rs | 2 +- src/structure_learning/score_function.rs | 6 +-- src/tools.rs | 8 +-- tests/ctbn.rs | 12 ++--- tests/parameter_learning.rs | 8 +-- tests/structure_learning.rs | 52 +++++++++---------- tests/tools.rs | 10 ++-- tests/utils.rs | 4 +- 11 files changed, 59 insertions(+), 59 deletions(-) diff --git a/src/ctbn.rs b/src/ctbn.rs index 9cabe20..620e9e9 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -29,19 +29,19 @@ use std::collections::BTreeSet; /// domain.insert(String::from("B")); /// /// //Create the parameters for a discrete node using the domain -/// let param = params::DiscreteStatesContinousTimeParams::init(domain); +/// let param = params::DiscreteStatesContinousTimeParams::new(domain); /// /// //Create the node using the parameters -/// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1")); +/// let X1 = node::Node::new(params::Params::DiscreteStatesContinousTime(param),String::from("X1")); /// /// let mut domain = BTreeSet::new(); /// domain.insert(String::from("A")); /// domain.insert(String::from("B")); -/// let param = params::DiscreteStatesContinousTimeParams::init(domain); -/// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2")); +/// let param = params::DiscreteStatesContinousTimeParams::new(domain); +/// let X2 = node::Node::new(params::Params::DiscreteStatesContinousTime(param), String::from("X2")); /// /// //Initialize a ctbn -/// let mut net = CtbnNetwork::init(); +/// let mut net = CtbnNetwork::new(); /// /// //Add nodes /// let X1 = net.add_node(X1).unwrap(); @@ -61,7 +61,7 @@ pub struct CtbnNetwork { impl CtbnNetwork { - pub fn init() -> CtbnNetwork { + pub fn new() -> CtbnNetwork { CtbnNetwork { adj_matrix: None, nodes: Vec::new() diff --git a/src/node.rs b/src/node.rs index 7ed21ba..3d8815f 100644 --- a/src/node.rs +++ b/src/node.rs @@ -7,7 +7,7 @@ pub struct Node { } impl Node { - pub fn init(params: Params, label: String) -> Node { + pub fn new(params: Params, label: String) -> Node { Node{ params: params, label:label diff --git a/src/params.rs b/src/params.rs index f0e5efa..d80fb43 100644 --- a/src/params.rs +++ b/src/params.rs @@ -77,7 +77,7 @@ pub struct DiscreteStatesContinousTimeParams { } impl DiscreteStatesContinousTimeParams { - pub fn init(domain: BTreeSet) -> DiscreteStatesContinousTimeParams { + pub fn new(domain: BTreeSet) -> DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams { domain, cim: Option::None, diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs index e57c4c1..fe4e4ff 100644 --- a/src/structure_learning/score_based_algorithm.rs +++ b/src/structure_learning/score_based_algorithm.rs @@ -10,7 +10,7 @@ pub struct HillClimbing { } impl HillClimbing { - pub fn init(score_function: S, max_parent_set: Option) -> HillClimbing { + pub fn new(score_function: S, max_parent_set: Option) -> HillClimbing { HillClimbing { score_function, max_parent_set, diff --git a/src/structure_learning/score_function.rs b/src/structure_learning/score_function.rs index dba40e2..ad66b08 100644 --- a/src/structure_learning/score_function.rs +++ b/src/structure_learning/score_function.rs @@ -24,7 +24,7 
@@ pub struct LogLikelihood { } impl LogLikelihood { - pub fn init(alpha: usize, tau: f64) -> LogLikelihood { + pub fn new(alpha: usize, tau: f64) -> LogLikelihood { //Tau must be >=0.0 if tau < 0.0 { @@ -106,9 +106,9 @@ pub struct BIC { } impl BIC { - pub fn init(alpha: usize, tau: f64) -> BIC { + pub fn new(alpha: usize, tau: f64) -> BIC { BIC { - ll: LogLikelihood::init(alpha, tau) + ll: LogLikelihood::new(alpha, tau) } } } diff --git a/src/tools.rs b/src/tools.rs index 7cf205b..b981f69 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -12,7 +12,7 @@ pub struct Trajectory { } impl Trajectory { - pub fn init(time: Array1, events: Array2) -> Trajectory { + pub fn new(time: Array1, events: Array2) -> Trajectory { //Events and time are two part of the same trajectory. For this reason they must have the //same number of sample. if time.shape()[0] != events.shape()[0] { @@ -35,7 +35,7 @@ pub struct Dataset { } impl Dataset { - pub fn init(trajectories: Vec) -> Dataset { + pub fn new(trajectories: Vec) -> Dataset { //All the trajectories in the same dataset must represent the same process. For this reason //each trajectory must represent the same number of variables. @@ -178,7 +178,7 @@ pub fn trajectory_generator( time.push(t_end.clone()); //Add the sampled trajectory to trajectories. - trajectories.push(Trajectory::init( + trajectories.push(Trajectory::new( Array::from_vec(time), Array2::from_shape_vec( (events.len(), current_state.len()), @@ -188,5 +188,5 @@ pub fn trajectory_generator( )); } //Return a dataset object with the sampled trajectories. - Dataset::init(trajectories) + Dataset::new(trajectories) } diff --git a/tests/ctbn.rs b/tests/ctbn.rs index 2d54f5f..c7d33ec 100644 --- a/tests/ctbn.rs +++ b/tests/ctbn.rs @@ -8,20 +8,20 @@ use rustyCTBN::ctbn::*; #[test] fn define_simpe_ctbn() { - let _ = CtbnNetwork::init(); + let _ = CtbnNetwork::new(); assert!(true); } #[test] fn add_node_to_ctbn() { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); assert_eq!(String::from("n1"), net.get_node(n1).label); } #[test] fn add_edge_to_ctbn() { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); net.add_edge(n1, n2); @@ -31,7 +31,7 @@ fn add_edge_to_ctbn() { #[test] fn children_and_parents() { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); net.add_edge(n1, n2); @@ -44,7 +44,7 @@ fn children_and_parents() { #[test] fn compute_index_ctbn() { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); @@ -76,7 +76,7 @@ fn compute_index_ctbn() { #[test] fn compute_index_from_custom_parent_set() { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let _n2 = 
net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 15245fd..a17c925 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -16,7 +16,7 @@ extern crate approx; fn learn_binary_cim (pl: T) { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) .unwrap(); @@ -66,7 +66,7 @@ fn learn_binary_cim_BA() { } fn learn_ternary_cim (pl: T) { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) .unwrap(); @@ -121,7 +121,7 @@ fn learn_ternary_cim_BA() { } fn learn_ternary_cim_no_parents (pl: T) { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) .unwrap(); @@ -175,7 +175,7 @@ fn learn_ternary_cim_no_parents_BA() { fn learn_mixed_discrete_cim (pl: T) { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) .unwrap(); diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index c3482cc..5c1ed84 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -18,18 +18,18 @@ extern crate approx; #[test] fn simple_score_test() { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) .unwrap(); - let trj = Trajectory::init( + let trj = Trajectory::new( arr1(&[0.0,0.1,0.3]), arr2(&[[0],[1],[1]])); - let dataset = Dataset::init(vec![trj]); + let dataset = Dataset::new(vec![trj]); - let ll = LogLikelihood::init(1, 1.0); + let ll = LogLikelihood::new(1, 1.0); assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); @@ -38,17 +38,17 @@ fn simple_score_test() { #[test] fn simple_bic() { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) .unwrap(); - let trj = Trajectory::init( + let trj = Trajectory::new( arr1(&[0.0,0.1,0.3]), arr2(&[[0],[1],[1]])); - let dataset = Dataset::init(vec![trj]); - let bic = BIC::init(1, 1.0); + let dataset = Dataset::new(vec![trj]); + let bic = BIC::new(1, 1.0); assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); @@ -57,7 +57,7 @@ fn simple_bic() { fn check_compatibility_between_dataset_and_network (sl: T) { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) .unwrap(); @@ -86,7 +86,7 @@ fn check_compatibility_between_dataset_and_network (sl: T) { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) .unwrap(); @@ -140,21 +140,21 @@ fn learn_ternary_net_2_nodes (sl: T) { #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { - let ll = LogLikelihood::init(1, 1.0); - let hl = HillClimbing::init(ll, None); + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); learn_ternary_net_2_nodes(hl); } #[test] pub fn 
learn_ternary_net_2_nodes_hill_climbing_bic() { - let bic = BIC::init(1, 1.0); - let hl = HillClimbing::init(bic, None); + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); learn_ternary_net_2_nodes(hl); } fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { - let mut net = CtbnNetwork::init(); + let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) .unwrap(); @@ -222,15 +222,15 @@ fn learn_mixed_discrete_net_3_nodes (sl: T) { #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { - let ll = LogLikelihood::init(1, 1.0); - let hl = HillClimbing::init(ll, None); + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); learn_mixed_discrete_net_3_nodes(hl); } #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { - let bic = BIC::init(1, 1.0); - let hl = HillClimbing::init(bic, None); + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); learn_mixed_discrete_net_3_nodes(hl); } @@ -247,14 +247,14 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint node::Node { - node::Node::init(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) + node::Node::new(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) } pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ let domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); - params::DiscreteStatesContinousTimeParams::init(domain) + params::DiscreteStatesContinousTimeParams::new(domain) } From a6c8d3e16d18d6383139acf0884ea467c791c837 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 15 Apr 2022 09:32:21 +0200 Subject: [PATCH 034/126] Changed the name from `rustyCTBN` to `reCTBN` (Rust Engine for Continuous Time Bayesian Networks) --- Cargo.toml | 2 +- README.md | 2 +- src/ctbn.rs | 8 ++++---- tests/ctbn.rs | 8 ++++---- tests/parameter_learning.rs | 12 ++++++------ tests/params.rs | 2 +- tests/structure_learning.rs | 14 +++++++------- tests/tools.rs | 10 +++++----- tests/utils.rs | 4 ++-- 9 files changed, 31 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7214634..4779b47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rustyCTBN" +name = "reCTBN" version = "0.1.0" edition = "2021" diff --git a/README.md b/README.md index be62df2..36009cf 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@
-# rustyCTBN +# reCTBN
diff --git a/src/ctbn.rs b/src/ctbn.rs index 620e9e9..69196f8 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -18,10 +18,10 @@ use std::collections::BTreeSet; ///``` /// /// use std::collections::BTreeSet; -/// use rustyCTBN::network::Network; -/// use rustyCTBN::node; -/// use rustyCTBN::params; -/// use rustyCTBN::ctbn::*; +/// use reCTBN::network::Network; +/// use reCTBN::node; +/// use reCTBN::params; +/// use reCTBN::ctbn::*; /// /// //Create the domain for a discrete node /// let mut domain = BTreeSet::new(); diff --git a/tests/ctbn.rs b/tests/ctbn.rs index c7d33ec..1458637 100644 --- a/tests/ctbn.rs +++ b/tests/ctbn.rs @@ -1,10 +1,10 @@ mod utils; use utils::generate_discrete_time_continous_node; -use rustyCTBN::network::Network; -use rustyCTBN::node; -use rustyCTBN::params; +use reCTBN::network::Network; +use reCTBN::node; +use reCTBN::params; use std::collections::BTreeSet; -use rustyCTBN::ctbn::*; +use reCTBN::ctbn::*; #[test] fn define_simpe_ctbn() { diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index a17c925..4e22c14 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -1,12 +1,12 @@ mod utils; use utils::*; -use rustyCTBN::parameter_learning::*; -use rustyCTBN::ctbn::*; -use rustyCTBN::network::Network; -use rustyCTBN::node; -use rustyCTBN::params; -use rustyCTBN::tools::*; +use reCTBN::parameter_learning::*; +use reCTBN::ctbn::*; +use reCTBN::network::Network; +use reCTBN::node; +use reCTBN::params; +use reCTBN::tools::*; use ndarray::arr3; use std::collections::BTreeSet; diff --git a/tests/params.rs b/tests/params.rs index b049d4e..fab150b 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,5 +1,5 @@ use ndarray::prelude::*; -use rustyCTBN::params::*; +use reCTBN::params::*; use std::collections::BTreeSet; use rand_chacha::ChaCha8Rng; use rand_chacha::rand_core::SeedableRng; diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 5c1ed84..42f948a 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -2,15 +2,15 @@ mod utils; use utils::*; -use rustyCTBN::ctbn::*; -use rustyCTBN::network::Network; -use rustyCTBN::tools::*; -use rustyCTBN::structure_learning::score_function::*; -use rustyCTBN::structure_learning::score_based_algorithm::*; -use rustyCTBN::structure_learning::StructureLearningAlgorithm; +use reCTBN::ctbn::*; +use reCTBN::network::Network; +use reCTBN::tools::*; +use reCTBN::structure_learning::score_function::*; +use reCTBN::structure_learning::score_based_algorithm::*; +use reCTBN::structure_learning::StructureLearningAlgorithm; use ndarray::{arr1, arr2, arr3}; use std::collections::BTreeSet; -use rustyCTBN::params; +use reCTBN::params; #[macro_use] diff --git a/tests/tools.rs b/tests/tools.rs index e5ba576..ac62288 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -1,9 +1,9 @@ -use rustyCTBN::tools::*; -use rustyCTBN::network::Network; -use rustyCTBN::ctbn::*; -use rustyCTBN::node; -use rustyCTBN::params; +use reCTBN::tools::*; +use reCTBN::network::Network; +use reCTBN::ctbn::*; +use reCTBN::node; +use reCTBN::params; use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; diff --git a/tests/utils.rs b/tests/utils.rs index be9748c..e9e5176 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -1,5 +1,5 @@ -use rustyCTBN::params; -use rustyCTBN::node; +use reCTBN::params; +use reCTBN::node; use std::collections::BTreeSet; pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node { From 
9b7e6836303a803ac0a3e6ba38bc009e253db953 Mon Sep 17 00:00:00 2001 From: Alessandro Bregoli Date: Fri, 15 Apr 2022 11:38:44 +0200 Subject: [PATCH 035/126] Removed `node.rs` --- src/ctbn.rs | 32 +-- src/lib.rs | 1 - src/network.rs | 7 +- src/node.rs | 25 -- src/parameter_learning.rs | 2 - src/params.rs | 12 +- src/structure_learning/score_function.rs | 2 +- src/tools.rs | 8 +- tests/ctbn.rs | 125 +++++--- tests/parameter_learning.rs | 350 +++++++++++++++-------- tests/params.rs | 31 +- tests/structure_learning.rs | 237 +++++++++------ tests/tools.rs | 54 ++-- tests/utils.rs | 11 +- 14 files changed, 539 insertions(+), 358 deletions(-) delete mode 100644 src/node.rs diff --git a/src/ctbn.rs b/src/ctbn.rs index 69196f8..2cede4a 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -1,6 +1,5 @@ use ndarray::prelude::*; -use crate::node; -use crate::params::{StateType, ParamsTrait}; +use crate::params::{StateType, Params, ParamsTrait}; use crate::network; use std::collections::BTreeSet; @@ -19,7 +18,6 @@ use std::collections::BTreeSet; /// /// use std::collections::BTreeSet; /// use reCTBN::network::Network; -/// use reCTBN::node; /// use reCTBN::params; /// use reCTBN::ctbn::*; /// @@ -29,16 +27,16 @@ use std::collections::BTreeSet; /// domain.insert(String::from("B")); /// /// //Create the parameters for a discrete node using the domain -/// let param = params::DiscreteStatesContinousTimeParams::new(domain); +/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain); /// /// //Create the node using the parameters -/// let X1 = node::Node::new(params::Params::DiscreteStatesContinousTime(param),String::from("X1")); +/// let X1 = params::Params::DiscreteStatesContinousTime(param); /// /// let mut domain = BTreeSet::new(); /// domain.insert(String::from("A")); /// domain.insert(String::from("B")); -/// let param = params::DiscreteStatesContinousTimeParams::new(domain); -/// let X2 = node::Node::new(params::Params::DiscreteStatesContinousTime(param), String::from("X2")); +/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain); +/// let X2 = params::Params::DiscreteStatesContinousTime(param); /// /// //Initialize a ctbn /// let mut net = CtbnNetwork::new(); @@ -56,7 +54,7 @@ use std::collections::BTreeSet; /// ``` pub struct CtbnNetwork { adj_matrix: Option>, - nodes: Vec + nodes: Vec } @@ -75,8 +73,8 @@ impl network::Network for CtbnNetwork { } - fn add_node(&mut self, mut n: node::Node) -> Result { - n.params.reset_params(); + fn add_node(&mut self, mut n: Params) -> Result { + n.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); Ok(self.nodes.len() -1) @@ -89,7 +87,7 @@ impl network::Network for CtbnNetwork { if let Some(network) = &mut self.adj_matrix { network[[parent, child]] = 1; - self.nodes[child].params.reset_params(); + self.nodes[child].reset_params(); } } @@ -101,12 +99,12 @@ impl network::Network for CtbnNetwork { self.nodes.len() } - fn get_node(&self, node_idx: usize) -> &node::Node{ + fn get_node(&self, node_idx: usize) -> &Params{ &self.nodes[node_idx] } - fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{ + fn get_node_mut(&mut self, node_idx: usize) -> &mut Params{ &mut self.nodes[node_idx] } @@ -114,8 +112,8 @@ impl network::Network for CtbnNetwork { fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize{ self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { if x.1 > &0 { - acc.0 += 
self.nodes[x.0].params.state_to_index(¤t_state[x.0]) * acc.1; - acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent(); + acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; + acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); } acc }).0 @@ -124,8 +122,8 @@ impl network::Network for CtbnNetwork { fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize { parent_set.iter().fold((0, 1), |mut acc, x| { - acc.0 += self.nodes[*x].params.state_to_index(¤t_state[*x]) * acc.1; - acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent(); + acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; + acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); acc }).0 } diff --git a/src/lib.rs b/src/lib.rs index ec12261..bcbde3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ #[macro_use] extern crate approx; -pub mod node; pub mod params; pub mod network; pub mod ctbn; diff --git a/src/network.rs b/src/network.rs index 3b6ce06..1c962b0 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,6 +1,5 @@ use thiserror::Error; use crate::params; -use crate::node; use std::collections::BTreeSet; /// Error types for trait Network @@ -15,14 +14,14 @@ pub enum NetworkError { ///The Network trait define the required methods for a structure used as pgm (such as ctbn). pub trait Network { fn initialize_adj_matrix(&mut self); - fn add_node(&mut self, n: node::Node) -> Result; + fn add_node(&mut self, n: params::Params) -> Result; fn add_edge(&mut self, parent: usize, child: usize); ///Get all the indices of the nodes contained inside the network fn get_node_indices(&self) -> std::ops::Range; fn get_number_of_nodes(&self) -> usize; - fn get_node(&self, node_idx: usize) -> &node::Node; - fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node; + fn get_node(&self, node_idx: usize) -> ¶ms::Params; + fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; ///Compute the index that must be used to access the parameters of a node given a specific ///configuration of the network. 
Usually, the only values really used in *current_state* are diff --git a/src/node.rs b/src/node.rs deleted file mode 100644 index 3d8815f..0000000 --- a/src/node.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::params::*; - - -pub struct Node { - pub params: Params, - pub label: String -} - -impl Node { - pub fn new(params: Params, label: String) -> Node { - Node{ - params: params, - label:label - } - } - -} - -impl PartialEq for Node { - fn eq(&self, other: &Node) -> bool{ - self.label == other.label - } -} - - diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index c4221cb..5270d9e 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -24,7 +24,6 @@ pub fn sufficient_statistics( //Get the number of values assumable by the node let node_domain = net .get_node(node.clone()) - .params .get_reserved_space_as_parent(); //Get the number of values assumable by each parent of the node @@ -32,7 +31,6 @@ pub fn sufficient_statistics( .iter() .map(|x| { net.get_node(x.clone()) - .params .get_reserved_space_as_parent() }) .collect(); diff --git a/src/params.rs b/src/params.rs index d80fb43..e632b1b 100644 --- a/src/params.rs +++ b/src/params.rs @@ -49,6 +49,9 @@ pub trait ParamsTrait { /// Validate parameters against domain fn validate_params(&self) -> Result<(), ParamsError>; + + /// Return a reference to the associated label + fn get_label(&self) -> &String; } /// The Params enum is the core element for building different types of nodes. The goal is to @@ -70,6 +73,7 @@ pub enum Params { /// - **residence_time**: permanence time in each possible states given a specific /// realization of the parent set pub struct DiscreteStatesContinousTimeParams { + label: String, domain: BTreeSet, cim: Option>, transitions: Option>, @@ -77,8 +81,9 @@ pub struct DiscreteStatesContinousTimeParams { } impl DiscreteStatesContinousTimeParams { - pub fn new(domain: BTreeSet) -> DiscreteStatesContinousTimeParams { + pub fn new(label: String, domain: BTreeSet) -> DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams { + label, domain, cim: Option::None, transitions: Option::None, @@ -244,4 +249,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { return Ok(()); } + + fn get_label(&self) -> &String { + &self.label + } + } diff --git a/src/structure_learning/score_function.rs b/src/structure_learning/score_function.rs index ad66b08..ea53db5 100644 --- a/src/structure_learning/score_function.rs +++ b/src/structure_learning/score_function.rs @@ -44,7 +44,7 @@ impl LogLikelihood { T: network::Network, { //Identify the type of node used - match &net.get_node(node).params { + match &net.get_node(node){ params::Params::DiscreteStatesContinousTime(_params) => { //Compute the sufficient statistics M (number of transistions) and T (residence //time) diff --git a/src/tools.rs b/src/tools.rs index b981f69..115fd67 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,5 +1,4 @@ use crate::network; -use crate::node; use crate::params; use crate::params::ParamsTrait; use ndarray::prelude::*; @@ -80,7 +79,7 @@ pub fn trajectory_generator( //Configuration of the process variables at time t initialized with an uniform //distribution. let mut current_state: Vec = net.get_node_indices() - .map(|x| net.get_node(x).params.get_random_state_uniform(&mut rng)) + .map(|x| net.get_node(x).get_random_state_uniform(&mut rng)) .collect(); //History of all the configurations of the process variables. 
let mut events: Vec> = Vec::new(); @@ -106,9 +105,8 @@ pub fn trajectory_generator( if let None = val { *val = Some( net.get_node(idx) - .params .get_random_residence_time( - net.get_node(idx).params.state_to_index(¤t_state[idx]), + net.get_node(idx).state_to_index(¤t_state[idx]), net.get_param_index_network(idx, ¤t_state), &mut rng, ) @@ -137,10 +135,8 @@ pub fn trajectory_generator( //Compute the new state of the transitioning variable. current_state[next_node_transition] = net .get_node(next_node_transition) - .params .get_random_state( net.get_node(next_node_transition) - .params .state_to_index(¤t_state[next_node_transition]), net.get_param_index_network(next_node_transition, ¤t_state), &mut rng, diff --git a/tests/ctbn.rs b/tests/ctbn.rs index 1458637..e5cad1e 100644 --- a/tests/ctbn.rs +++ b/tests/ctbn.rs @@ -1,10 +1,9 @@ mod utils; -use utils::generate_discrete_time_continous_node; +use reCTBN::ctbn::*; use reCTBN::network::Network; -use reCTBN::node; -use reCTBN::params; +use reCTBN::params::{self, ParamsTrait}; use std::collections::BTreeSet; -use reCTBN::ctbn::*; +use utils::generate_discrete_time_continous_node; #[test] fn define_simpe_ctbn() { @@ -15,15 +14,21 @@ fn define_simpe_ctbn() { #[test] fn add_node_to_ctbn() { let mut net = CtbnNetwork::new(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - assert_eq!(String::from("n1"), net.get_node(n1).label); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); } #[test] fn add_edge_to_ctbn() { let mut net = CtbnNetwork::new(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); net.add_edge(n1, n2); let cs = net.get_children_set(n1); assert_eq!(&n2, cs.iter().next().unwrap()); @@ -32,8 +37,12 @@ fn add_edge_to_ctbn() { #[test] fn children_and_parents() { let mut net = CtbnNetwork::new(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); net.add_edge(n1, n2); let cs = net.get_children_set(n1); assert_eq!(&n2, cs.iter().next().unwrap()); @@ -41,59 +50,81 @@ fn children_and_parents() { assert_eq!(&n1, ps.iter().next().unwrap()); } - #[test] fn compute_index_ctbn() { let mut net = CtbnNetwork::new(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); net.add_edge(n1, n2); 
net.add_edge(n3, n2); - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(1), - params::StateType::Discrete(1), - params::StateType::Discrete(1)]); + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(1), + params::StateType::Discrete(1), + ], + ); assert_eq!(3, idx); - - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(1), - params::StateType::Discrete(1)]); + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(1), + params::StateType::Discrete(1), + ], + ); assert_eq!(2, idx); - - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(1), - params::StateType::Discrete(1), - params::StateType::Discrete(0)]); + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(1), + params::StateType::Discrete(0), + ], + ); assert_eq!(1, idx); - } - - #[test] fn compute_index_from_custom_parent_set() { let mut net = CtbnNetwork::new(); - let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); - - - let idx = net.get_param_index_from_custom_parent_set(&vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(0), - params::StateType::Discrete(1)], - &BTreeSet::from([1])); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let _n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let _n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + let idx = net.get_param_index_from_custom_parent_set( + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(0), + params::StateType::Discrete(1), + ], + &BTreeSet::from([1]), + ); assert_eq!(0, idx); - - let idx = net.get_param_index_from_custom_parent_set(&vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(0), - params::StateType::Discrete(1)], - &BTreeSet::from([1,2])); + let idx = net.get_param_index_from_custom_parent_set( + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(0), + params::StateType::Discrete(1), + ], + &BTreeSet::from([1, 2]), + ); assert_eq!(2, idx); } diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 4e22c14..cd980d0 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -1,263 +1,365 @@ mod utils; use utils::*; -use reCTBN::parameter_learning::*; +use ndarray::arr3; use reCTBN::ctbn::*; use reCTBN::network::Network; -use reCTBN::node; -use reCTBN::params; -use reCTBN::tools::*; -use ndarray::arr3; +use reCTBN::parameter_learning::*; +use reCTBN::{params, tools::*}; use std::collections::BTreeSet; - #[macro_use] extern crate approx; - -fn learn_binary_cim (pl: T) { +fn learn_binary_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),2)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) 
.unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 1.0], [4.0, -4.0]], + [[-6.0, 6.0], [2.0, -2.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [2, 2, 2]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], - ]), 0.1)); + assert!(CIM.abs_diff_eq( + &arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), + 0.1 + )); } #[test] fn learn_binary_cim_MLE() { - let mle = MLE{}; + let mle = MLE {}; learn_binary_cim(mle); } - #[test] fn learn_binary_cim_BA() { - let ba = BayesianApproach{ - alpha: 1, - tau: 1.0}; + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_binary_cim(ba); } -fn learn_ternary_cim (pl: T) { +fn learn_ternary_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [3, 3, 3]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]), 0.1)); + assert!(CIM.abs_diff_eq( + &arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ]), + 0.1 + )); 
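    // (Editor's note, not part of the original patch: for the MLE the estimated
    // CIM is M[u, x, x'] / T[u, x] with each diagonal set to the negated row
    // sum, so the entry-wise comparison above is a direct consistency check of
    // the estimator on the 100 simulated trajectories; the Bayesian variant is
    // expected to land within the same tolerance, since alpha = 1 and tau = 1.0
    // presumably only add small pseudo-counts to M and T.)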
} - #[test] fn learn_ternary_cim_MLE() { - let mle = MLE{}; + let mle = MLE {}; learn_ternary_cim(mle); } - #[test] fn learn_ternary_cim_BA() { - let ba = BayesianApproach{ - alpha: 1, - tau: 1.0}; + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim(ba); } -fn learn_ternary_cim_no_parents (pl: T) { +fn learn_ternary_cim_no_parents(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let (CIM, M, T) = pl.fit(&net, &data, 0, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [1, 3, 3]); - assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]), 0.1)); + assert!(CIM.abs_diff_eq( + &arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), + 0.1 + )); } - #[test] fn learn_ternary_cim_no_parents_MLE() { - let mle = MLE{}; + let mle = MLE {}; learn_ternary_cim_no_parents(mle); } - #[test] fn learn_ternary_cim_no_parents_BA() { - let ba = BayesianApproach{ - alpha: 1, - tau: 1.0}; + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim_no_parents(ba); } - -fn learn_mixed_discrete_cim (pl: T) { +fn learn_mixed_discrete_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); let n3 = net - .add_node(generate_discrete_time_continous_node(String::from("n3"),4)) + .add_node(generate_discrete_time_continous_node(String::from("n3"), 4)) .unwrap(); net.add_edge(n1, n2); net.add_edge(n1, n3); net.add_edge(n2, n3); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, 
-2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - - match &mut net.get_node_mut(n3).params { + match &mut net.get_node_mut(n3) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], - [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], - - [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], - [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], - [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], - - [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], - [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], - [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, -0.8] + ], + [ + [-2.0, 1.0, 0.7, 0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ])) + ); } } - - let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], - [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], - - [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], - [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], - 
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], - - [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], - [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], - [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]), 0.1)); + assert!(CIM.abs_diff_eq( + &arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, -0.8] + ], + [ + [-2.0, 1.0, 0.7, 0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ]), + 0.1 + )); } #[test] fn learn_mixed_discrete_cim_MLE() { - let mle = MLE{}; + let mle = MLE {}; learn_mixed_discrete_cim(mle); } - #[test] fn learn_mixed_discrete_cim_BA() { - let ba = BayesianApproach{ - alpha: 1, - tau: 1.0}; + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_mixed_discrete_cim(ba); } diff --git a/tests/params.rs b/tests/params.rs index fab150b..c002d7b 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,16 +1,15 @@ use ndarray::prelude::*; -use reCTBN::params::*; -use std::collections::BTreeSet; -use rand_chacha::ChaCha8Rng; -use rand_chacha::rand_core::SeedableRng; +use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; +use reCTBN::params::{ParamsTrait, *}; mod utils; #[macro_use] extern crate approx; + fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { - let mut params = utils::generate_discrete_time_continous_param(3); + let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; @@ -18,6 +17,12 @@ fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTime params } +#[test] +fn test_get_label() { + let param = create_ternary_discrete_time_continous_param(); + assert_eq!(&String::from("A"), param.get_label()) +} + #[test] fn test_uniform_generation() { let param = create_ternary_discrete_time_continous_param(); @@ -79,15 +84,19 @@ fn test_validate_params_valid_cim() { #[test] fn test_validate_params_valid_cim_with_huge_values() { - let mut param = utils::generate_discrete_time_continous_param(3); - let cim = array![[[-2e10, 1e10, 1e10], [1.5e10, -3e10, 1.5e10], [1e10, 1e10, -2e10]]]; + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); + let cim = array![[ + [-2e10, 1e10, 1e10], + [1.5e10, -3e10, 1.5e10], + [1e10, 1e10, -2e10] + ]]; let result = param.set_cim(cim); assert_eq!(Ok(()), result); } #[test] fn test_validate_params_cim_not_initialized() { - let param = utils::generate_discrete_time_continous_param(3); + let param = 
utils::generate_discrete_time_continous_params("A".to_string(), 3); assert_eq!( Err(ParamsError::ParametersNotInitialized(String::from( "CIM not initialized", @@ -98,7 +107,7 @@ fn test_validate_params_cim_not_initialized() { #[test] fn test_validate_params_wrong_shape() { - let mut param = utils::generate_discrete_time_continous_param(4); + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 4); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let result = param.set_cim(cim); assert_eq!( @@ -111,7 +120,7 @@ fn test_validate_params_wrong_shape() { #[test] fn test_validate_params_positive_diag() { - let mut param = utils::generate_discrete_time_continous_param(3); + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let result = param.set_cim(cim); assert_eq!( @@ -124,7 +133,7 @@ fn test_validate_params_positive_diag() { #[test] fn test_validate_params_row_not_sum_to_zero() { - let mut param = utils::generate_discrete_time_continous_param(3); + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; let result = param.set_cim(cim); assert_eq!( diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 42f948a..c91f508 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -1,17 +1,14 @@ - mod utils; use utils::*; +use ndarray::{arr1, arr2, arr3}; use reCTBN::ctbn::*; use reCTBN::network::Network; -use reCTBN::tools::*; +use reCTBN::params; use reCTBN::structure_learning::score_function::*; -use reCTBN::structure_learning::score_based_algorithm::*; -use reCTBN::structure_learning::StructureLearningAlgorithm; -use ndarray::{arr1, arr2, arr3}; +use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm}; +use reCTBN::tools::*; use std::collections::BTreeSet; -use reCTBN::params; - #[macro_use] extern crate approx; @@ -20,80 +17,86 @@ extern crate approx; fn simple_score_test() { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); - let trj = Trajectory::new( - arr1(&[0.0,0.1,0.3]), - arr2(&[[0],[1],[1]])); + let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]])); let dataset = Dataset::new(vec![trj]); let ll = LogLikelihood::new(1, 1.0); - assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); - + assert_abs_diff_eq!( + 0.04257, + ll.call(&net, n1, &BTreeSet::new(), &dataset), + epsilon = 1e-3 + ); } - #[test] fn simple_bic() { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); - let trj = Trajectory::new( - arr1(&[0.0,0.1,0.3]), - arr2(&[[0],[1],[1]])); + let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]])); let dataset = Dataset::new(vec![trj]); let bic = BIC::new(1, 1.0); - assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); - + assert_abs_diff_eq!( + -0.65058, + bic.call(&net, n1, &BTreeSet::new(), &dataset), + epsilon = 1e-3 + ); } - - -fn check_compatibility_between_dataset_and_network (sl: T) { +fn 
check_compatibility_between_dataset_and_network(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); let mut net = CtbnNetwork::new(); let _n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let net = sl.fit_transform(net, &data); } - #[test] #[should_panic] pub fn check_compatibility_between_dataset_and_network_hill_climbing() { @@ -102,42 +105,49 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } -fn learn_ternary_net_2_nodes (sl: T) { +fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 
20.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } - #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -152,66 +162,117 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { learn_ternary_net_2_nodes(hl); } - fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); let n3 = net - .add_node(generate_discrete_time_continous_node(String::from("n3"),4)) + .add_node(generate_discrete_time_continous_node(String::from("n3"), 4)) .unwrap(); net.add_edge(n1, n2); net.add_edge(n1, n3); net.add_edge(n2, n3); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - - match &mut net.get_node_mut(n3).params { + match &mut net.get_node_mut(n3) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], - [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], - - [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], - [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], - [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], - - [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], - [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], - [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, 
-0.8] + ], + [ + [-2.0, 1.0, 0.7, 0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ])) + ); } } - - let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); return (net, data); } -fn learn_mixed_discrete_net_3_nodes (sl: T) { +fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -219,7 +280,6 @@ fn learn_mixed_discrete_net_3_nodes (sl: T) { assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); } - #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -234,9 +294,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } - - -fn learn_mixed_discrete_net_3_nodes_1_parent_constraint (sl: T) { +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -244,7 +302,6 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint { - param.set_cim(arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); + param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); } } - - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { param.set_cim(arr3(&[ - [[-1.0,1.0],[4.0,-4.0]], - [[-6.0,6.0],[2.0,-2.0]]])); + [[-1.0, 1.0], [4.0, -4.0]], + [[-6.0, 6.0], [2.0, -2.0]], + ])); } } - let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259)); assert_eq!(4, data.get_trajectories().len()); - assert_relative_eq!(1.0, data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len()-1]); + assert_relative_eq!( + 1.0, + data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len() - 1] + ); } - #[test] #[should_panic] - fn trajectory_wrong_shape() { +fn trajectory_wrong_shape() { let time = arr1(&[0.0, 0.2]); - let events = arr2(&[[0,3]]); + let events = arr2(&[[0, 3]]); Trajectory::new(time, events); } - #[test] #[should_panic] fn dataset_wrong_shape() { let time = arr1(&[0.0, 0.2]); - let events = arr2(&[[0,3], [1,2]]); + let events = arr2(&[[0, 3], [1, 2]]); let t1 = Trajectory::new(time, events); - let time = arr1(&[0.0, 0.2]); - let events = arr2(&[[0,3,3], [1,2,3]]); + let events = arr2(&[[0, 3, 3], [1, 2, 3]]); let t2 = Trajectory::new(time, events); Dataset::new(vec![t1, t2]); } diff --git a/tests/utils.rs b/tests/utils.rs index e9e5176..8648c46 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -1,16 +1,17 @@ use reCTBN::params; -use reCTBN::node; use std::collections::BTreeSet; -pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) 
-> node::Node { - node::Node::new(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) +pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params { + params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality)) } -pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ +pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ let domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); - params::DiscreteStatesContinousTimeParams::new(domain) + params::DiscreteStatesContinousTimeParams::new(label, domain) } + + From 2605bf38168b75f87d125091356af9673ad27fb1 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 11 May 2022 22:24:17 +0200 Subject: [PATCH 036/126] Implemented part of matrices comparison in chi square --- Cargo.toml | 3 +- src/parameter_learning.rs | 17 +++ src/structure_learning.rs | 2 + .../constraint_based_algorithm.rs | 5 + src/structure_learning/hypothesis_test.rs | 101 ++++++++++++++++++ tests/structure_learning.rs | 23 ++++ 6 files changed, 149 insertions(+), 2 deletions(-) create mode 100644 src/structure_learning/constraint_based_algorithm.rs create mode 100644 src/structure_learning/hypothesis_test.rs diff --git a/Cargo.toml b/Cargo.toml index 4779b47..56d0452 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] - -ndarray = {version="*", features=["approx"]} +ndarray = {version="*", features=["approx-0_5"]} thiserror = "*" rand = "*" bimap = "*" diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 5270d9e..19c0e4c 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -153,3 +153,20 @@ impl ParameterLearning for BayesianApproach { return (CIM, M, T); } } + + +pub struct Cache { + parameter_learning: P, +} + +impl Cache
<P: parameter_learning::ParameterLearning>
{ + pub fn fit( + &mut self, + net: &T, + dataset: &tools::Dataset, + node: usize, + parent_set: Option>, + ) -> (Array3, Array3, Array2) { + self.parameter_learning.fit(net, dataset, node, parent_set) + } +} diff --git a/src/structure_learning.rs b/src/structure_learning.rs index 8ba91df..b7db7ed 100644 --- a/src/structure_learning.rs +++ b/src/structure_learning.rs @@ -1,5 +1,7 @@ pub mod score_function; pub mod score_based_algorithm; +pub mod constraint_based_algorithm; +pub mod hypothesis_test; use crate::network; use crate::tools; diff --git a/src/structure_learning/constraint_based_algorithm.rs b/src/structure_learning/constraint_based_algorithm.rs new file mode 100644 index 0000000..0d8b655 --- /dev/null +++ b/src/structure_learning/constraint_based_algorithm.rs @@ -0,0 +1,5 @@ + +//pub struct CTPC { +// +//} + diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs new file mode 100644 index 0000000..fc5c86f --- /dev/null +++ b/src/structure_learning/hypothesis_test.rs @@ -0,0 +1,101 @@ +use ndarray::Array2; +use ndarray::Array3; +use ndarray::Axis; + +use crate::network; +use crate::parameter_learning; +use std::collections::BTreeSet; + +pub trait HypothesisTest { + + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + cache: parameter_learning::Cache
<P>
+ ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning; + +} + + +pub struct ChiSquare { + pub alpha: f64, +} + +pub struct F { + +} + +impl ChiSquare { + pub fn compare_matrices( + &self, i: usize, + M1: &Array3, + j: usize, + M2: &Array3 + ) -> bool { + // Bregoli, A., Scutari, M. and Stella, F., 2021. + // A constraint-based algorithm for the structural learning of + // continuous-time Bayesian networks. + // International Journal of Approximate Reasoning, 138, pp.105-122. + // + // M = M M = M + // 1 xx'|s 2 xx'|y,s + let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); + let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); + // __________________ + // / === + // / \ M + // / / xx'|s + // / === + // / x'ϵVal /X \ + // / \ i/ 1 + //K = / ------------------ L = - + // / === K + // / \ M + // / / xx'|y,s + // / === + // / x'ϵVal /X \ + // \ / \ i/ + // \/ + let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); + let K = K.mapv(f64::sqrt); + // Reshape to column vector. + let K = { + let n = K.len(); + K.into_shape((n, 1)).unwrap() + }; + let L = 1.0 / &K; + // ===== + // \ K . M - L . M + // \ 2 1 + // / --------------- + // / M + M + // ===== 2 1 + // x'ϵVal /X \ + // \ i/ + let X_2 = (( K * &M2 - L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1)).sum_axis(Axis(1)); + println!("X_2: {:?}", X_2); + true + } +} + +impl HypothesisTest for ChiSquare { + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + cache: parameter_learning::Cache
<P>
+ ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning { + todo!() + } +} diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index c91f508..be9c8d5 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -7,6 +7,7 @@ use reCTBN::network::Network; use reCTBN::params; use reCTBN::structure_learning::score_function::*; use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm}; +use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::tools::*; use std::collections::BTreeSet; @@ -315,3 +316,25 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() let hl = HillClimbing::new(bic, Some(1)); learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } + +#[test] +pub fn chi_square_compare_matrices () { + let i: usize = 1; + let M1 = arr3(&[ + [[ 1, 2, 3], + [ 4, 5, 6]], + [[ 22, 12, 90], + [3, 20, 40]], + [[ 1, 2, 3], + [ 4, 5, 6]], + [[ 7, 8, 9], + [10, 11, 12]] + ]); + let j: usize = 1; + let M2 = arr3(&[[[ 1, 2, 3], // -- 2 rows \_ + [ 4, 5, 6]], + [[ 7, 8, 9], + [10, 11, 12]]]); + let chi_sq = ChiSquare {alpha: 0.5}; + chi_sq.compare_matrices( i, &M1, j, &M2); +} From 4b35ae63101394a47afbd22e397830f6043b7027 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 19 May 2022 13:47:39 +0200 Subject: [PATCH 037/126] Implemented matrices comparison function in chi square --- Cargo.toml | 2 +- src/structure_learning/hypothesis_test.rs | 28 ++++++-- tests/structure_learning.rs | 80 ++++++++++++++++++----- 3 files changed, 88 insertions(+), 22 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 56d0452..553e294 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ndarray = {version="*", features=["approx-0_5"]} +ndarray = {version="*", features=["approx"]} thiserror = "*" rand = "*" bimap = "*" diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs index fc5c86f..6e06721 100644 --- a/src/structure_learning/hypothesis_test.rs +++ b/src/structure_learning/hypothesis_test.rs @@ -1,6 +1,7 @@ use ndarray::Array2; use ndarray::Array3; use ndarray::Axis; +use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::network; use crate::parameter_learning; @@ -24,7 +25,7 @@ pub trait HypothesisTest { pub struct ChiSquare { - pub alpha: f64, + alpha: f64, } pub struct F { @@ -32,6 +33,11 @@ pub struct F { } impl ChiSquare { + pub fn new( alpha: f64) -> ChiSquare { + ChiSquare { + alpha + } + } pub fn compare_matrices( &self, i: usize, M1: &Array3, @@ -42,6 +48,7 @@ impl ChiSquare { // A constraint-based algorithm for the structural learning of // continuous-time Bayesian networks. // International Journal of Approximate Reasoning, 138, pp.105-122. + // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm // // M = M M = M // 1 xx'|s 2 xx'|y,s @@ -70,17 +77,26 @@ impl ChiSquare { K.into_shape((n, 1)).unwrap() }; let L = 1.0 / &K; - // ===== - // \ K . M - L . M + // ===== 2 + // \ (K . M - L . 
M) // \ 2 1 // / --------------- // / M + M // ===== 2 1 // x'ϵVal /X \ // \ i/ - let X_2 = (( K * &M2 - L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1)).sum_axis(Axis(1)); - println!("X_2: {:?}", X_2); - true + let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1); + println!("M1: {:?}", M1); + println!("M2: {:?}", M2); + println!("L*M1: {:?}", (L * &M1)); + println!("K*M2: {:?}", (K * &M2)); + println!("X_2: {:?}", X_2); + X_2.diag_mut().fill(0.0); + let X_2 = X_2.sum_axis(Axis(1)); + let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); + println!("CHI^2: {:?}", n); + println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); + X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)) } } diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index be9c8d5..2c9645b 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -321,20 +321,70 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() pub fn chi_square_compare_matrices () { let i: usize = 1; let M1 = arr3(&[ - [[ 1, 2, 3], - [ 4, 5, 6]], - [[ 22, 12, 90], - [3, 20, 40]], - [[ 1, 2, 3], - [ 4, 5, 6]], - [[ 7, 8, 9], - [10, 11, 12]] + [[ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0]], + [[0, 12, 90], + [ 3, 0, 40], + [ 6, 40, 0]], + [[ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0]] ]); - let j: usize = 1; - let M2 = arr3(&[[[ 1, 2, 3], // -- 2 rows \_ - [ 4, 5, 6]], - [[ 7, 8, 9], - [10, 11, 12]]]); - let chi_sq = ChiSquare {alpha: 0.5}; - chi_sq.compare_matrices( i, &M1, j, &M2); + let j: usize = 0; + let M2 = arr3(&[ + [[ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); + let chi_sq = ChiSquare::new(0.1); + assert!(!chi_sq.compare_matrices( i, &M1, j, &M2)); +} + +#[test] +pub fn chi_square_compare_matrices_2 () { + let i: usize = 1; + let M1 = arr3(&[ + [[ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0]], + [[0, 20, 30], + [ 40, 0, 60], + [ 70, 80, 0]], + [[ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0]] + ]); + let j: usize = 0; + let M2 = arr3(&[ + [[ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); + let chi_sq = ChiSquare::new(0.1); + assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); +} + +#[test] +pub fn chi_square_compare_matrices_3 () { + let i: usize = 1; + let M1 = arr3(&[ + [[ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0]], + [[0, 21, 31], + [ 41, 0, 59], + [ 71, 79, 0]], + [[ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0]] + ]); + let j: usize = 0; + let M2 = arr3(&[ + [[ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); + let chi_sq = ChiSquare::new(0.1); + assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); } From 68ada89c0419af88eeca13541664b7799f76141f Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 13 Jun 2022 13:27:42 +0200 Subject: [PATCH 038/126] Expanded Hypothesis test --- src/parameter_learning.rs | 4 ++-- src/structure_learning/hypothesis_test.rs | 26 +++++++++++++++++++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 19c0e4c..6fff9d1 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -157,16 +157,16 @@ impl ParameterLearning for BayesianApproach { pub struct Cache { parameter_learning: P, + dataset: tools::Dataset, } impl Cache
<P: parameter_learning::ParameterLearning>
{ pub fn fit( &mut self, net: &T, - dataset: &tools::Dataset, node: usize, parent_set: Option>, ) -> (Array3, Array3, Array2) { - self.parameter_learning.fit(net, dataset, node, parent_set) + self.parameter_learning.fit(net, &self.dataset, node, parent_set) } } diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs index 6e06721..86500e5 100644 --- a/src/structure_learning/hypothesis_test.rs +++ b/src/structure_learning/hypothesis_test.rs @@ -5,6 +5,7 @@ use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::network; use crate::parameter_learning; +use crate::params::ParamsTrait; use std::collections::BTreeSet; pub trait HypothesisTest { @@ -15,7 +16,7 @@ pub trait HypothesisTest { child_node: usize, parent_node: usize, separation_set: &BTreeSet, - cache: parameter_learning::Cache
<P>
+ cache: &mut parameter_learning::Cache
<P>
) -> bool where T: network::Network, @@ -39,7 +40,8 @@ impl ChiSquare { } } pub fn compare_matrices( - &self, i: usize, + &self, + i: usize, M1: &Array3, j: usize, M2: &Array3 @@ -107,11 +109,27 @@ impl HypothesisTest for ChiSquare { child_node: usize, parent_node: usize, separation_set: &BTreeSet, - cache: parameter_learning::Cache
<P>
+ cache: &mut parameter_learning::Cache
<P>
) -> bool where T: network::Network, P: parameter_learning::ParameterLearning { - todo!() + // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM + // di dimensione nxn + // (CIM, M, T) + let ( _, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone())); + // + let mut extended_separation_set = separation_set.clone(); + extended_separation_set.insert(parent_node); + let ( _, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone())); + // Commentare qui + let partial_cardinality_product:usize = extended_separation_set.iter().take_while(|x| **x != parent_node).map(|x| net.get_node(*x).get_reserved_space_as_parent()).product(); + for idx_M_big in 0..M_big.shape()[0] { + let idx_M_small: usize = idx_M_big%partial_cardinality_product + (idx_M_big/(partial_cardinality_product*net.get_node(parent_node).get_reserved_space_as_parent()))*partial_cardinality_product; + if ! self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) { + return false; + } + } + return true; } } From f5756f71d31e86aeddaa179991b7d615f84b5ebf Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Fri, 29 Jul 2022 10:36:53 +0200 Subject: [PATCH 039/126] Changed the return type for ParameterLearning::fit from tuple to Param --- src/parameter_learning.rs | 41 +++++++++++++++--- src/params.rs | 15 +++++-- src/structure_learning/hypothesis_test.rs | 53 +++++++++++++---------- tests/parameter_learning.rs | 36 ++++++++------- 4 files changed, 95 insertions(+), 50 deletions(-) diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 6fff9d1..bf5b96a 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -12,7 +12,7 @@ pub trait ParameterLearning{ dataset: &tools::Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2); + ) -> Params; } pub fn sufficient_statistics( @@ -84,8 +84,7 @@ impl ParameterLearning for MLE { dataset: &tools::Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2) { - //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes + ) -> Params { //Use parent_set from parameter if present. Otherwise use parent_set from network. let parent_set = match parent_set { @@ -107,7 +106,21 @@ impl ParameterLearning for MLE { .for_each(|(mut C, diag)| { C.diag_mut().assign(&diag); }); - return (CIM, M, T); + + + + let mut n: Params = net.get_node(node).clone(); + + match n { + Params::DiscreteStatesContinousTime(ref mut dsct) => { + dsct.set_cim_unchecked(CIM); + dsct.set_transitions(M); + dsct.set_residence_time(T); + + + } + }; + return n; } } @@ -123,7 +136,7 @@ impl ParameterLearning for BayesianApproach { dataset: &tools::Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2) { + ) -> Params { //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes //Use parent_set from parameter if present. Otherwise use parent_set from network. @@ -150,7 +163,21 @@ impl ParameterLearning for BayesianApproach { .for_each(|(mut C, diag)| { C.diag_mut().assign(&diag); }); - return (CIM, M, T); + + + + let mut n: Params = net.get_node(node).clone(); + + match n { + Params::DiscreteStatesContinousTime(ref mut dsct) => { + dsct.set_cim_unchecked(CIM); + dsct.set_transitions(M); + dsct.set_residence_time(T); + + + } + }; + return n; } } @@ -166,7 +193,7 @@ impl Cache
<P: parameter_learning::ParameterLearning>
{ net: &T, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2) { + ) -> Params { self.parameter_learning.fit(net, &self.dataset, node, parent_set) } } diff --git a/src/params.rs b/src/params.rs index e632b1b..d9f307f 100644 --- a/src/params.rs +++ b/src/params.rs @@ -55,7 +55,8 @@ pub trait ParamsTrait { } /// The Params enum is the core element for building different types of nodes. The goal is to -/// define all the supported type of parameters. +/// define all the supported type of Parameters +#[derive(Clone)] #[enum_dispatch] pub enum Params { DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), @@ -72,11 +73,12 @@ pub enum Params { /// realization of the parent set /// - **residence_time**: permanence time in each possible states given a specific /// realization of the parent set +#[derive(Clone)] pub struct DiscreteStatesContinousTimeParams { label: String, domain: BTreeSet, cim: Option>, - transitions: Option>, + transitions: Option>, residence_time: Option>, } @@ -112,14 +114,19 @@ impl DiscreteStatesContinousTimeParams { } + ///Unchecked version of the setter function for CIM. + pub fn set_cim_unchecked(&mut self, cim: Array3) { + self.cim = Some(cim); + } + ///Getter function for transitions - pub fn get_transitions(&self) -> &Option> { + pub fn get_transitions(&self) -> &Option> { &self.transitions } ///Setter function for transitions - pub fn set_transitions(&mut self, transitions: Array3) { + pub fn set_transitions(&mut self, transitions: Array3) { self.transitions = Some(transitions); } diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs index 86500e5..eb6b570 100644 --- a/src/structure_learning/hypothesis_test.rs +++ b/src/structure_learning/hypothesis_test.rs @@ -5,46 +5,39 @@ use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::network; use crate::parameter_learning; -use crate::params::ParamsTrait; +use crate::params::*; use std::collections::BTreeSet; pub trait HypothesisTest { - fn call( &self, net: &T, child_node: usize, parent_node: usize, separation_set: &BTreeSet, - cache: &mut parameter_learning::Cache
diff --git a/src/params.rs b/src/params.rs
index e632b1b..d9f307f 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -55,7 +55,8 @@ pub trait ParamsTrait {
 }
 
 /// The Params enum is the core element for building different types of nodes. The goal is to
-/// define all the supported type of parameters.
+/// define all the supported types of parameters.
+#[derive(Clone)]
 #[enum_dispatch]
 pub enum Params {
     DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
@@ -72,11 +73,12 @@ pub enum Params {
 /// realization of the parent set
 /// - **residence_time**: permanence time in each possible state given a specific
 /// realization of the parent set
+#[derive(Clone)]
 pub struct DiscreteStatesContinousTimeParams {
     label: String,
     domain: BTreeSet<String>,
     cim: Option<Array3<f64>>,
-    transitions: Option<Array3<f64>>,
+    transitions: Option<Array3<usize>>,
     residence_time: Option<Array2<f64>>,
 }
 
@@ -112,14 +114,19 @@ impl DiscreteStatesContinousTimeParams {
     }
 
+    ///Unchecked version of the setter function for CIM.
+    pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
+        self.cim = Some(cim);
+    }
+
     ///Getter function for transitions
-    pub fn get_transitions(&self) -> &Option<Array3<f64>> {
+    pub fn get_transitions(&self) -> &Option<Array3<usize>> {
         &self.transitions
     }
 
     ///Setter function for transitions
-    pub fn set_transitions(&mut self, transitions: Array3<f64>) {
+    pub fn set_transitions(&mut self, transitions: Array3<usize>) {
         self.transitions = Some(transitions);
     }
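The distinction between the two setters matters: learners write CIMs that are only validated downstream, while `set_cim` rejects anything that is not a proper intensity matrix. A self-contained sketch of both behaviors; the domain and the matrix values are invented, the API is the one introduced in this series:

```rust
use std::collections::BTreeSet;
use ndarray::arr3;
use reCTBN::params::DiscreteStatesContinousTimeParams;

fn main() {
    let domain: BTreeSet<String> = ["A", "B"].iter().map(|s| s.to_string()).collect();
    let mut p = DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
    // A valid CIM: non-negative off-diagonal entries, rows summing to zero.
    assert_eq!(Ok(()), p.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
    // An invalid CIM (first row does not sum to zero) is rejected and
    // self.cim is cleared...
    assert!(p.set_cim(arr3(&[[[-3.0, 1.0], [2.0, -2.0]]])).is_err());
    assert!(p.get_cim().is_none());
    // ...while the unchecked setter stores it as-is.
    p.set_cim_unchecked(arr3(&[[[-3.0, 1.0], [2.0, -2.0]]]));
    assert!(p.get_cim().is_some());
}
```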
diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs
index 86500e5..eb6b570 100644
--- a/src/structure_learning/hypothesis_test.rs
+++ b/src/structure_learning/hypothesis_test.rs
@@ -5,46 +5,39 @@ use statrs::distribution::{ChiSquared, ContinuousCDF};
 
 use crate::network;
 use crate::parameter_learning;
-use crate::params::ParamsTrait;
+use crate::params::*;
 use std::collections::BTreeSet;
 
 pub trait HypothesisTest {
-
     fn call<T, P>(
         &self,
         net: &T,
         child_node: usize,
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
-        cache: &mut parameter_learning::Cache<P>,
+        cache: &mut parameter_learning::Cache<P>,
     ) -> bool
     where
         T: network::Network,
         P: parameter_learning::ParameterLearning;
-
 }
-
 pub struct ChiSquare {
     alpha: f64,
 }
-pub struct F {
-
-}
+pub struct F {}
 
 impl ChiSquare {
-    pub fn new( alpha: f64) -> ChiSquare {
-        ChiSquare {
-            alpha
-        }
+    pub fn new(alpha: f64) -> ChiSquare {
+        ChiSquare { alpha }
     }
     pub fn compare_matrices(
         &self,
         i: usize,
         M1: &Array3<usize>,
         j: usize,
-        M2: &Array3<usize>
+        M2: &Array3<usize>,
     ) -> bool {
         // Bregoli, A., Scutari, M. and Stella, F., 2021.
         // A constraint-based algorithm for the structural learning of
         // continuous-time Bayesian networks.
         // International Journal of Approximate Reasoning, 138, pp.105-122.
         //
         // The matrices M1[i] and M2[j] are compared row by row after rescaling
         // by K and L, where for every state x of X_i:
         //
         //   K_x = sqrt( sum_{x'} M1[x, x'] / sum_{x'} M2[x, x'] ),   L_x = 1 / K_x
         //
         //   X2_x = sum_{x' in Val(X_i), x' != x}
         //            (K_x * M2[x, x'] - L_x * M1[x, x'])^2 / (M1[x, x'] + M2[x, x'])
         let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
         let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
         let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
         let K = K.mapv(f64::sqrt);
         // Reshape K to a column vector so it rescales M2 row by row.
         let K = {
             let n = K.len();
             K.into_shape((n, 1)).unwrap()
         };
         let L = 1.0 / &K;
         // Critical value of the chi-squared distribution at significance alpha.
         let chi_sq = ChiSquared::new((M1.shape()[0] - 1) as f64)
             .unwrap()
             .inverse_cdf(1.0 - self.alpha);
-        let mut X_2 = ( &K * &M2 - &L * &M1 ).mapv(|a| a.powi(2)) / (&M2 + &M1);
+        let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
         println!("M1: {:?}", M1);
         println!("M2: {:?}", M2);
         println!("L*M1: {:?}", (L * &M1));
         println!("K*M2: {:?}", (K * &M2));
         println!("X_2: {:?}", X_2);
         X_2.diag_mut().fill(0.0);
         let ret = X_2.sum_axis(Axis(1)).iter().all(|x| x < &chi_sq);
         return ret;
     }
 }
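A self-contained numeric sketch of the statistic; the counts are the ones used later in the test suite, and the chi-squared threshold from statrs is left out, so this only shows that proportional matrices give a statistic of zero:

```rust
use ndarray::{array, Axis};

// Toy 3-state transition-count matrices for one parent configuration each;
// m2 is exactly m1 scaled by 10, i.e. the same underlying process.
fn main() {
    let m1 = array![[0.0, 2.0, 3.0], [4.0, 0.0, 6.0], [7.0, 8.0, 0.0]];
    let m2 = array![[0.0, 20.0, 30.0], [40.0, 0.0, 60.0], [70.0, 80.0, 0.0]];
    // Row-wise scaling: K_x = sqrt(sum m1[x,:] / sum m2[x,:]), L_x = 1 / K_x.
    let k = (m1.sum_axis(Axis(1)) / m2.sum_axis(Axis(1))).mapv(f64::sqrt);
    let k = k.into_shape((3, 1)).unwrap();
    let l = 1.0 / &k;
    let mut x2 = (&k * &m2 - &l * &m1).mapv(|a| a.powi(2)) / (&m2 + &m1);
    // The diagonal is 0/0 here and is not part of the statistic anyway.
    x2.diag_mut().fill(0.0);
    // Every row statistic is ~0, so the null hypothesis (same CIM) holds.
    println!("{:?}", x2.sum_axis(Axis(1)));
}
```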
 
 impl HypothesisTest for ChiSquare {
     fn call<T, P>(
         &self,
         net: &T,
         child_node: usize,
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
-        cache: &mut parameter_learning::Cache<P>,
+        cache: &mut parameter_learning::Cache<P>,
     ) -> bool
     where
         T: network::Network,
         P: parameter_learning::ParameterLearning,
     {
         // Fetch the parameters learned for child_node from the cache, using the
         // separation set as its parent set. Only the transition-count matrices M
         // (one n x n matrix per configuration of the parent set) are needed here.
         // (CIM, M, T)
-        let ( _, M_small, _) = cache.fit(net, child_node, Some(separation_set.clone()));
+        let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
+            Params::DiscreteStatesContinousTime(node) => node,
+        };
 
         let mut extended_separation_set = separation_set.clone();
         extended_separation_set.insert(parent_node);
-        let ( _, M_big, _) = cache.fit(net, child_node, Some(extended_separation_set.clone()));
+
+        let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) {
+            Params::DiscreteStatesContinousTime(node) => node,
+        };
         // Map each configuration index of the extended separation set onto the
         // corresponding configuration index of the original separation set, so
         // that the matching transition-count matrices can be compared.
-        let partial_cardinality_product:usize = extended_separation_set.iter().take_while(|x| **x != parent_node).map(|x| net.get_node(*x).get_reserved_space_as_parent()).product();
-        for idx_M_big in 0..M_big.shape()[0] {
-            let idx_M_small: usize = idx_M_big%partial_cardinality_product + (idx_M_big/(partial_cardinality_product*net.get_node(parent_node).get_reserved_space_as_parent()))*partial_cardinality_product;
-            if ! self.compare_matrices(idx_M_small, &M_small, idx_M_big, &M_big) {
+        let partial_cardinality_product: usize = extended_separation_set
+            .iter()
+            .take_while(|x| **x != parent_node)
+            .map(|x| net.get_node(*x).get_reserved_space_as_parent())
+            .product();
+        for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
+            let idx_M_small: usize = idx_M_big % partial_cardinality_product
+                + (idx_M_big
+                    / (partial_cardinality_product
+                        * net.get_node(parent_node).get_reserved_space_as_parent()))
+                    * partial_cardinality_product;
+            if !self.compare_matrices(
+                idx_M_small,
+                P_small.get_transitions().as_ref().unwrap(),
+                idx_M_big,
+                P_big.get_transitions().as_ref().unwrap(),
+            ) {
                 return false;
             }
         }
diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs
index cd980d0..1ce5d51 100644
--- a/tests/parameter_learning.rs
+++ b/tests/parameter_learning.rs
@@ -40,10 +40,11 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
     }
 
     let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
-    let (CIM, M, T) = pl.fit(&net, &data, 1, None);
-    print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
-    assert_eq!(CIM.shape(), [2, 2, 2]);
-    assert!(CIM.abs_diff_eq(
+    let p = match pl.fit(&net, &data, 1, None) {
+        params::Params::DiscreteStatesContinousTime(p) => p
+    };
+    assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
+    assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
         &arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]),
         0.1
     ));
@@ -98,10 +99,11 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
     }
 
     let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
-    let (CIM, M, T) = pl.fit(&net, &data, 1, None);
-    print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
-    assert_eq!(CIM.shape(), [3, 3, 3]);
-    assert!(CIM.abs_diff_eq(
+    let p = match pl.fit(&net, &data, 1, None){
+        params::Params::DiscreteStatesContinousTime(p) => p
+    };
+    assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
+    assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
         &arr3(&[
             [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
             [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
             [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
         ]),
         0.1
     ));
@@ -160,10 +162,11 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
     }
 
     let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
-    let (CIM, M, T) = pl.fit(&net, &data, 0, None);
-
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [1, 3, 3]); - assert!(CIM.abs_diff_eq( + let p = match pl.fit(&net, &data, 0, None){ + params::Params::DiscreteStatesContinousTime(p) => p + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( &arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), 0.1 )); @@ -288,10 +291,11 @@ fn learn_mixed_discrete_cim(pl: T) { } let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); - let (CIM, M, T) = pl.fit(&net, &data, 2, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [9, 4, 4]); - assert!(CIM.abs_diff_eq( + let p = match pl.fit(&net, &data, 2, None){ + params::Params::DiscreteStatesContinousTime(p) => p + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( &arr3(&[ [ [-1.0, 0.5, 0.3, 0.2], From c4f44f63ea862ccf43596f9b8b03cd8721c6398b Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 29 Jul 2022 12:53:38 +0200 Subject: [PATCH 040/126] Silenced `non_snake_case` warning globally in `src/`, solved some unused imports and other negligible leftovers --- src/lib.rs | 2 +- src/parameter_learning.rs | 5 +---- src/params.rs | 2 +- src/structure_learning/hypothesis_test.rs | 1 - 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index bcbde3f..1dcc637 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ +#![allow(non_snake_case)] #[cfg(test)] -#[macro_use] extern crate approx; pub mod params; diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index bf5b96a..ffd1db8 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -2,7 +2,6 @@ use crate::network; use crate::params::*; use crate::tools; use ndarray::prelude::*; -use ndarray::{concatenate, Slice}; use std::collections::BTreeSet; pub trait ParameterLearning{ @@ -137,15 +136,13 @@ impl ParameterLearning for BayesianApproach { node: usize, parent_set: Option>, ) -> Params { - //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes - //Use parent_set from parameter if present. Otherwise use parent_set from network. 
let parent_set = match parent_set { Some(p) => p, None => net.get_parent_set(node), }; - let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); + let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64; let tau: f64 = self.tau as f64 / M.shape()[0] as f64; diff --git a/src/params.rs b/src/params.rs index d9f307f..c2768b1 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,7 +1,7 @@ use enum_dispatch::enum_dispatch; use ndarray::prelude::*; use rand::Rng; -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet}; use thiserror::Error; use rand_chacha::ChaCha8Rng; diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs index eb6b570..f8eeb30 100644 --- a/src/structure_learning/hypothesis_test.rs +++ b/src/structure_learning/hypothesis_test.rs @@ -1,4 +1,3 @@ -use ndarray::Array2; use ndarray::Array3; use ndarray::Axis; use statrs::distribution::{ChiSquared, ContinuousCDF}; From d0515a3f2627f289bbe27f142068aad1f5a5cd9f Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 29 Jul 2022 13:27:38 +0200 Subject: [PATCH 041/126] Solved all warnings in `tests/` --- tests/parameter_learning.rs | 4 ++-- tests/params.rs | 4 +++- tests/structure_learning.rs | 4 +++- tests/tools.rs | 2 +- tests/utils.rs | 7 +------ 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 1ce5d51..b624e94 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -1,3 +1,5 @@ +#![allow(non_snake_case)] + mod utils; use utils::*; @@ -6,9 +8,7 @@ use reCTBN::ctbn::*; use reCTBN::network::Network; use reCTBN::parameter_learning::*; use reCTBN::{params, tools::*}; -use std::collections::BTreeSet; -#[macro_use] extern crate approx; fn learn_binary_cim(pl: T) { diff --git a/tests/params.rs b/tests/params.rs index c002d7b..e07121c 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -7,8 +7,8 @@ mod utils; #[macro_use] extern crate approx; - fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { + #![allow(unused_must_use)] let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; @@ -25,6 +25,7 @@ fn test_get_label() { #[test] fn test_uniform_generation() { + #![allow(irrefutable_let_patterns)] let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); @@ -44,6 +45,7 @@ fn test_uniform_generation() { #[test] fn test_random_generation_state() { + #![allow(irrefutable_let_patterns)] let param = create_ternary_discrete_time_continous_param(); let mut states = Array1::::zeros(10000); diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 2c9645b..790a4b6 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -1,3 +1,5 @@ +#![allow(non_snake_case)] + mod utils; use utils::*; @@ -95,7 +97,7 @@ fn check_compatibility_between_dataset_and_network params::Params { params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality)) } - pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ let domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); params::DiscreteStatesContinousTimeParams::new(label, domain) } - - - - - From 
fd335d73f6ffaa9272ba47f3ad642e59bcc4b4b0 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Fri, 29 Jul 2022 16:26:57 +0200
Subject: [PATCH 042/126] Nerfed `clippy` in GitHub Workflows; now the linting
 must pass, otherwise the pipeline fails

---
 .github/workflows/build.yml |  2 +-
 README.md                   | 10 ++++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 1f73a77..25ecf35 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -26,7 +26,7 @@ jobs:
         uses: actions-rs/clippy-check@v1
         with:
           token: ${{ secrets.GITHUB_TOKEN }}
-          args: --all-features
+          args: --all-targets -- -D warnings -A clippy::all -W clippy::correctness
       - name: Tests (test)
         uses: actions-rs/cargo@v1
         with:
diff --git a/README.md b/README.md
index 36009cf..b114188 100644
--- a/README.md
+++ b/README.md
@@ -37,8 +37,14 @@ To launch **tests**:
 cargo test
 ```
 
-To **lint**:
+To **lint** with `cargo check`:
 
 ```sh
-cargo check
+cargo check --all-targets
+```
+
+Or with `clippy`:
+
+```sh
+cargo clippy --all-targets -- -A clippy::all -W clippy::correctness
 ```
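In practice this gate means that `correctness`-class lints now fail CI while the stylistic lints stay silent. A made-up snippet of the kind of bug the `cmp_nan` correctness lint catches, not taken from this repository:

```rust
// Comparing against NaN is always false, so this function can never return
// true; `clippy::correctness` (lint `cmp_nan`) now turns this into a failure.
fn is_nan_wrong(x: f64) -> bool {
    x == f64::NAN
}

// The fix clippy suggests:
fn is_nan_right(x: f64) -> bool {
    x.is_nan()
}

fn main() {
    assert!(!is_nan_wrong(f64::NAN));
    assert!(is_nan_right(f64::NAN));
}
```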
From 58adca25033cde749d4a4614035b5b45c581fa6c Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Sat, 30 Jul 2022 16:27:55 +0200
Subject: [PATCH 043/126] Added `rustfmt` in GH Workflow, added `rustfmt.toml`
 and updated `README.md`

---
 .github/workflows/build.yml |  7 ++++++-
 README.md                   |  8 ++++++++
 rustfmt.toml                | 36 ++++++++++++++++++++++++++++++++++++
 3 files changed, 50 insertions(+), 1 deletion(-)
 create mode 100644 rustfmt.toml

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 25ecf35..b94d480 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -21,12 +21,17 @@ jobs:
         with:
           profile: minimal
           toolchain: stable
-          components: clippy
+          components: clippy, fmt
       - name: Linting (clippy)
         uses: actions-rs/clippy-check@v1
         with:
           token: ${{ secrets.GITHUB_TOKEN }}
           args: --all-targets -- -D warnings -A clippy::all -W clippy::correctness
+      - name: Formatting (rustfmt)
+        uses: actions-rs/cargo@v1
+        with:
+          command: fmt
+          args: --all -- --check --verbose
       - name: Tests (test)
         uses: actions-rs/cargo@v1
         with:
diff --git a/README.md b/README.md
index b114188..f928955 100644
--- a/README.md
+++ b/README.md
@@ -48,3 +48,11 @@ Or with `clippy`:
 ```sh
 cargo clippy --all-targets -- -A clippy::all -W clippy::correctness
 ```
+
+To check the **formatting**:
+
+> **NOTE:** remove `--check` to apply the changes to the file(s).
+
+```sh
+cargo fmt --all -- --check
+```
diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 0000000..3e7fb50
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1,36 @@
+# This file defines the Rust style for automatic reformatting.
+# See also https://rust-lang.github.io/rustfmt
+
+# NOTE: the unstable options will be uncommented when stabilized.
+
+# Version of the formatting rules to use.
+#version = "One"
+
+# Number of spaces per tab.
+tab_spaces = 4
+
+max_width = 100
+#comment_width = 80
+
+# Prevent carriage returns; only \n is admitted.
+newline_style = "Unix"
+
+# The "Default" setting has a heuristic which can split lines too aggressively.
+#use_small_heuristics = "Max"
+
+# How imports should be grouped into `use` statements.
+#imports_granularity = "Module"
+
+# How consecutive imports are grouped together.
+#group_imports = "StdExternalCrate"
+
+# Error if unable to get all lines within max_width, except for comments and
+# string literals.
+#error_on_line_overflow = true
+
+# Error if unable to get comments or string literals within max_width, or they
+# are left with trailing whitespaces.
+#error_on_unformatted = true
+
+# Files to ignore like third party code which is formatted upstream.
+#ignore = []

From d5058e9ed255934279a90066cde8572c9d1bbec5 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Sat, 30 Jul 2022 16:36:27 +0200
Subject: [PATCH 044/126] Fixed malformed `rustfmt` component name in GH
 Workflows

---
 .github/workflows/build.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index b94d480..bc5c583 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -21,7 +21,7 @@ jobs:
         with:
           profile: minimal
           toolchain: stable
-          components: clippy, fmt
+          components: clippy, rustfmt

From 780515707cd134ca11d53eb0712bba8a819e1b3c Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Sat, 30 Jul 2022 17:14:54 +0200
Subject: [PATCH 045/126] Refactored `src/` and `tests/` files to be compliant
 with `rustfmt`

---
 src/ctbn.rs                                   | 108 +++++++++---------
 src/lib.rs                                    |   7 +-
 src/network.rs                                |  19 +--
 src/parameter_learning.rs                     |  61 ++++------
 src/params.rs                                 |  54 ++++++---
 src/structure_learning.rs                     |   9 +-
 .../constraint_based_algorithm.rs             |   2 -
 src/structure_learning/hypothesis_test.rs     |  24 ++--
 .../score_based_algorithm.rs                  |   6 +-
 src/structure_learning/score_function.rs      |  62 +++++-----
 src/tools.rs                                  |  33 +++---
 tests/ctbn.rs                                 |   3 +-
 tests/parameter_learning.rs                   |  20 ++--
 tests/params.rs                               |   3 +-
 tests/structure_learning.rs                   |  75 ++++--------
 tests/utils.rs                                |  13 ++-
 16 files changed, 247 insertions(+), 252 deletions(-)

diff --git a/src/ctbn.rs b/src/ctbn.rs
index 2cede4a..e2f5dd7 100644
--- a/src/ctbn.rs
+++ b/src/ctbn.rs
@@ -1,10 +1,9 @@
-use ndarray::prelude::*;
-use crate::params::{StateType, Params, ParamsTrait};
-use crate::network;
 use std::collections::BTreeSet;
+
+use ndarray::prelude::*;
+
+use crate::network;
+use crate::params::{Params, ParamsTrait, StateType};
 
 ///CTBN network. It represents both the structure and the parameters of a CTBN.
CtbnNetwork is ///composed by the following elements: @@ -22,12 +21,12 @@ use std::collections::BTreeSet; /// use reCTBN::ctbn::*; /// /// //Create the domain for a discrete node -/// let mut domain = BTreeSet::new(); +/// let mut domain = BTreeSet::new(); /// domain.insert(String::from("A")); /// domain.insert(String::from("B")); /// /// //Create the parameters for a discrete node using the domain -/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain); +/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain); /// /// //Create the node using the parameters /// let X1 = params::Params::DiscreteStatesContinousTime(param); @@ -37,14 +36,14 @@ use std::collections::BTreeSet; /// domain.insert(String::from("B")); /// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain); /// let X2 = params::Params::DiscreteStatesContinousTime(param); -/// +/// /// //Initialize a ctbn /// let mut net = CtbnNetwork::new(); /// /// //Add nodes /// let X1 = net.add_node(X1).unwrap(); /// let X2 = net.add_node(X2).unwrap(); -/// +/// /// //Add an edge /// net.add_edge(X1, X2); /// @@ -54,30 +53,30 @@ use std::collections::BTreeSet; /// ``` pub struct CtbnNetwork { adj_matrix: Option>, - nodes: Vec + nodes: Vec, } - impl CtbnNetwork { pub fn new() -> CtbnNetwork { CtbnNetwork { adj_matrix: None, - nodes: Vec::new() + nodes: Vec::new(), } } } impl network::Network for CtbnNetwork { fn initialize_adj_matrix(&mut self) { - self.adj_matrix = Some(Array2::::zeros((self.nodes.len(), self.nodes.len()).f())); - + self.adj_matrix = Some(Array2::::zeros( + (self.nodes.len(), self.nodes.len()).f(), + )); } - fn add_node(&mut self, mut n: Params) -> Result { + fn add_node(&mut self, mut n: Params) -> Result { n.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); - Ok(self.nodes.len() -1) + Ok(self.nodes.len() - 1) } fn add_edge(&mut self, parent: usize, child: usize) { @@ -91,7 +90,7 @@ impl network::Network for CtbnNetwork { } } - fn get_node_indices(&self) -> std::ops::Range{ + fn get_node_indices(&self) -> std::ops::Range { 0..self.nodes.len() } @@ -99,64 +98,65 @@ impl network::Network for CtbnNetwork { self.nodes.len() } - fn get_node(&self, node_idx: usize) -> &Params{ + fn get_node(&self, node_idx: usize) -> &Params { &self.nodes[node_idx] } - - fn get_node_mut(&mut self, node_idx: usize) -> &mut Params{ + fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { &mut self.nodes[node_idx] } - - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize{ - self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { - if x.1 > &0 { - acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; - acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); - } - acc - }).0 + fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize { + self.adj_matrix + .as_ref() + .unwrap() + .column(node) + .iter() + .enumerate() + .fold((0, 1), |mut acc, x| { + if x.1 > &0 { + acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; + acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); + } + acc + }) + .0 } - - fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize { - parent_set.iter().fold((0, 1), |mut acc, x| { - acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; - acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); - acc - }).0 + fn get_param_index_from_custom_parent_set( + 
&self, + current_state: &Vec, + parent_set: &BTreeSet, + ) -> usize { + parent_set + .iter() + .fold((0, 1), |mut acc, x| { + acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; + acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); + acc + }) + .0 } fn get_parent_set(&self, node: usize) -> BTreeSet { - self.adj_matrix.as_ref() + self.adj_matrix + .as_ref() .unwrap() .column(node) .iter() .enumerate() - .filter_map(|(idx, x)| { - if x > &0 { - Some(idx) - } else { - None - } - }).collect() + .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) + .collect() } - fn get_children_set(&self, node: usize) -> BTreeSet{ - self.adj_matrix.as_ref() + fn get_children_set(&self, node: usize) -> BTreeSet { + self.adj_matrix + .as_ref() .unwrap() .row(node) .iter() .enumerate() - .filter_map(|(idx, x)| { - if x > &0 { - Some(idx) - } else { - None - } - }).collect() + .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) + .collect() } - } - diff --git a/src/lib.rs b/src/lib.rs index 1dcc637..8c57af2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,10 +2,9 @@ #[cfg(test)] extern crate approx; -pub mod params; -pub mod network; pub mod ctbn; -pub mod tools; +pub mod network; pub mod parameter_learning; +pub mod params; pub mod structure_learning; - +pub mod tools; diff --git a/src/network.rs b/src/network.rs index 1c962b0..cbae339 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,20 +1,21 @@ +use std::collections::BTreeSet; + use thiserror::Error; + use crate::params; -use std::collections::BTreeSet; /// Error types for trait Network #[derive(Error, Debug)] pub enum NetworkError { #[error("Error during node insertion")] - NodeInsertionError(String) + NodeInsertionError(String), } - ///Network ///The Network trait define the required methods for a structure used as pgm (such as ctbn). pub trait Network { fn initialize_adj_matrix(&mut self); - fn add_node(&mut self, n: params::Params) -> Result; + fn add_node(&mut self, n: params::Params) -> Result; fn add_edge(&mut self, parent: usize, child: usize); ///Get all the indices of the nodes contained inside the network @@ -26,13 +27,17 @@ pub trait Network { ///Compute the index that must be used to access the parameters of a node given a specific ///configuration of the network. Usually, the only values really used in *current_state* are ///the ones in the parent set of the *node*. - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize; - + fn get_param_index_network(&self, node: usize, current_state: &Vec) + -> usize; ///Compute the index that must be used to access the parameters of a node given a specific ///configuration of the network and a generic parent_set. Usually, the only values really used ///in *current_state* are the ones in the parent set of the *node*. 
-    fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<params::StateType>, parent_set: &BTreeSet<usize>) -> usize;
+    fn get_param_index_from_custom_parent_set(
+        &self,
+        current_state: &Vec<params::StateType>,
+        parent_set: &BTreeSet<usize>,
+    ) -> usize;
     fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
     fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
 }
diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs
index ffd1db8..10f0257 100644
--- a/src/parameter_learning.rs
+++ b/src/parameter_learning.rs
@@ -1,11 +1,12 @@
-use crate::network;
-use crate::params::*;
-use crate::tools;
-use ndarray::prelude::*;
 use std::collections::BTreeSet;
 
-pub trait ParameterLearning{
-    fn fit(
+use ndarray::prelude::*;
+
+use crate::params::*;
+use crate::{network, tools};
+
+pub trait ParameterLearning {
+    fn fit(
         &self,
         net: &T,
         dataset: &tools::Dataset,
@@ -14,24 +15,19 @@ pub trait ParameterLearning{
     ) -> Params;
 }
 
-pub fn sufficient_statistics(
+pub fn sufficient_statistics(
     net: &T,
     dataset: &tools::Dataset,
     node: usize,
-    parent_set: &BTreeSet<usize>
-    ) -> (Array3<usize>, Array2<f64>) {
+    parent_set: &BTreeSet<usize>,
+) -> (Array3<usize>, Array2<f64>) {
     //Get the number of values assumable by the node
-    let node_domain = net
-        .get_node(node.clone())
-        .get_reserved_space_as_parent();
+    let node_domain = net.get_node(node.clone()).get_reserved_space_as_parent();
 
     //Get the number of values assumable by each parent of the node
     let parentset_domain: Vec<usize> = parent_set
         .iter()
-        .map(|x| {
-            net.get_node(x.clone())
-                .get_reserved_space_as_parent()
-        })
+        .map(|x| net.get_node(x.clone()).get_reserved_space_as_parent())
         .collect();
 
     //Vector used to convert a specific configuration of the parent_set to the corresponding index
@@ -45,7 +41,7 @@ pub fn sufficient_statistics(
         vector_to_idx[*idx] = acc;
         acc * x
     });
-    
+
     //Number of transition given a specific configuration of the parent set
     let mut M: Array3<usize> =
         Array::zeros((parentset_domain.iter().product(), node_domain, node_domain));
@@ -70,13 +66,11 @@ pub fn sufficient_statistics(
     }
 
     return (M, T);
-
 }
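To make the meaning of M and T concrete, here is a hand-rolled sketch of what `sufficient_statistics` produces for a single node with no parents (so there is only one parent configuration). The five-point trajectory is invented; the real function additionally indexes by parent configuration:

```rust
use ndarray::prelude::*;

fn main() {
    // Invented trajectory: transition times and the node's state at each one.
    // The last row repeats the state at t_end, as trajectory_generator does.
    let time = vec![0.0, 0.5, 1.2, 1.4, 2.0];
    let states = vec![0usize, 1, 0, 1, 1];
    let n_states = 2;
    let mut m: Array3<usize> = Array::zeros((1, n_states, n_states)); // transition counts
    let mut t: Array2<f64> = Array::zeros((1, n_states)); // residence times
    for w in 0..time.len() - 1 {
        t[[0, states[w]]] += time[w + 1] - time[w];
        if states[w] != states[w + 1] {
            m[[0, states[w], states[w + 1]]] += 1;
        }
    }
    // m[0] = [[0, 2], [1, 0]], t[0] = [0.7, 1.3] for the data above.
    println!("M = {:?}\nT = {:?}", m, t);
}
```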
 
 pub struct MLE {}
 
 impl ParameterLearning for MLE {
-
     fn fit(
         &self,
         net: &T,
         dataset: &tools::Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
-
         //Use parent_set from parameter if present. Otherwise use parent_set from network.
         let parent_set = match parent_set {
             Some(p) => p,
             None => net.get_parent_set(node),
         };
-        
+
         let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
-        //Compute the CIM as M[i,x,y]/T[i,x]
+        //Compute the CIM as M[i,x,y]/T[i,x]
         let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
         CIM.axis_iter_mut(Axis(2))
             .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-            .for_each(|(mut C, m)| C.assign(&(&m/&T)));
+            .for_each(|(mut C, m)| C.assign(&(&m / &T)));
 
         //Set the diagonal of the inner matrices to the row sum multiplied by -1
         let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@@ -105,8 +98,6 @@ impl ParameterLearning for MLE {
             .for_each(|(mut C, diag)| {
                 C.diag_mut().assign(&diag);
             });
-        
-
         let mut n: Params = net.get_node(node).clone();
 
         match n {
             Params::DiscreteStatesContinousTime(ref mut dsct) => {
                 dsct.set_cim_unchecked(CIM);
                 dsct.set_transitions(M);
                 dsct.set_residence_time(T);
             }
         };
         return n;
     }
 }
 
 pub struct BayesianApproach {
     pub alpha: usize,
-    pub tau: f64
+    pub tau: f64,
 }
 
 impl ParameterLearning for BayesianApproach {
@@ -141,17 +130,17 @@ impl ParameterLearning for BayesianApproach {
             Some(p) => p,
             None => net.get_parent_set(node),
         };
-        
+
         let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
 
         let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64;
         let tau: f64 = self.tau as f64 / M.shape()[0] as f64;
 
-        //Compute the CIM as M[i,x,y]/T[i,x]
+        //Compute the CIM as M[i,x,y]/T[i,x]
         let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
         CIM.axis_iter_mut(Axis(2))
             .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
-            .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha)/&T.mapv(|y| y + tau))));
+            .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau))));
 
         //Set the diagonal of the inner matrices to the row sum multiplied by -1
         let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@@ -161,8 +150,6 @@ impl ParameterLearning for BayesianApproach {
             C.diag_mut().assign(&diag);
         });
-
-
         let mut n: Params = net.get_node(node).clone();
 
         match n {
             Params::DiscreteStatesContinousTime(ref mut dsct) => {
                 dsct.set_cim_unchecked(CIM);
                 dsct.set_transitions(M);
                 dsct.set_residence_time(T);
             }
         };
         return n;
     }
 }
 
-
 pub struct Cache<P: ParameterLearning> {
     parameter_learning: P,
     dataset: tools::Dataset,
 }
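The only difference between the two learners above is the smoothing applied to the sufficient statistics; a toy sketch of that step in isolation (counts and hyperparameters invented):

```rust
use ndarray::array;

// Hypothetical counts for one parent configuration of a binary node:
// m = transition counts, t = residence time per starting state.
fn main() {
    let m = array![[0.0, 2.0], [4.0, 0.0]];
    let t = array![[2.0], [1.0]];
    let (alpha, tau) = (1.0, 0.1);
    // BayesianApproach-style estimate: (M + alpha) / (T + tau), which keeps
    // the rates finite even for transitions that were never observed,
    // whereas the MLE would simply compute M / T.
    let cim = (m + alpha) / &(t + tau);
    println!("{:?}", cim);
}
```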
 impl<P: ParameterLearning> Cache<P>
 {
-    pub fn fit(
+    pub fn fit(
         &mut self,
         net: &T,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
-        self.parameter_learning.fit(net, &self.dataset, node, parent_set)
+        self.parameter_learning
+            .fit(net, &self.dataset, node, parent_set)
     }
 }
diff --git a/src/params.rs b/src/params.rs
index c2768b1..f994b99 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -1,9 +1,10 @@
+use std::collections::BTreeSet;
+
 use enum_dispatch::enum_dispatch;
 use ndarray::prelude::*;
 use rand::Rng;
-use std::collections::{BTreeSet};
-use thiserror::Error;
 use rand_chacha::ChaCha8Rng;
+use thiserror::Error;
 
 /// Error types for trait Params
 #[derive(Error, Debug, PartialEq)]
@@ -35,11 +36,21 @@ pub trait ParamsTrait {
 
     /// Randomly generate a residence time for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<f64, ParamsError>;
+    fn get_random_residence_time(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<f64, ParamsError>;
 
     /// Randomly generate a possible state for the given node taking into account the node state
     /// and its parent set.
-    fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result<StateType, ParamsError>;
+    fn get_random_state(
+        &self,
+        state: usize,
+        u: usize,
+        rng: &mut ChaCha8Rng,
+    ) -> Result<StateType, ParamsError>;
 
     /// Used by children of the node described by these parameters to reserve spaces in their CIMs.
     fn get_reserved_space_as_parent(&self) -> usize;
@@ -49,7 +60,7 @@ pub trait ParamsTrait {
 
     /// Validate parameters against domain
     fn validate_params(&self) -> Result<(), ParamsError>;
-    
+
     /// Return a reference to the associated label
     fn get_label(&self) -> &String;
 }
@@ -92,17 +103,17 @@ impl DiscreteStatesContinousTimeParams {
             residence_time: Option::None,
         }
     }
-    
+
     ///Getter function for CIM
     pub fn get_cim(&self) -> &Option<Array3<f64>> {
         &self.cim
-    }    
+    }
 
     ///Setter function for CIM.\\
     ///This function checks whether the CIM is valid using the validate_params method:
     ///- **Valid CIM inserted**: it replaces self.cim with the new value and returns Ok(())
    ///- **Invalid CIM inserted**: it sets self.cim to None and returns a ParamsError
-    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError>{
+    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
         self.cim = Some(cim);
         match self.validate_params() {
             Ok(()) => Ok(()),
@@ -113,7 +124,6 @@ impl DiscreteStatesContinousTimeParams {
         }
     }
-
     ///Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3) { self.cim = Some(cim); @@ -124,7 +134,6 @@ impl DiscreteStatesContinousTimeParams { &self.transitions } - ///Setter function for transitions pub fn set_transitions(&mut self, transitions: Array3) { self.transitions = Some(transitions); @@ -135,12 +144,10 @@ impl DiscreteStatesContinousTimeParams { &self.residence_time } - ///Setter function for residence_time pub fn set_residence_time(&mut self, residence_time: Array2) { self.residence_time = Some(residence_time); } - } impl ParamsTrait for DiscreteStatesContinousTimeParams { @@ -154,7 +161,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { StateType::Discrete(rng.gen_range(0..(self.domain.len()))) } - fn get_random_residence_time(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result { + fn get_random_residence_time( + &self, + state: usize, + u: usize, + rng: &mut ChaCha8Rng, + ) -> Result { // Generate a random residence time given the current state of the node and its parent set. // The method used is described in: // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates @@ -170,7 +182,12 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { } } - fn get_random_state(&self, state: usize, u: usize, rng: &mut ChaCha8Rng) -> Result { + fn get_random_state( + &self, + state: usize, + u: usize, + rng: &mut ChaCha8Rng, + ) -> Result { // Generate a random transition given the current state of the node and its parent set. // The method used is described in: // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution @@ -246,7 +263,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { } // Check if each row sum up to 0 - if cim.sum_axis(Axis(2)).iter() + if cim + .sum_axis(Axis(2)) + .iter() .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) { return Err(ParamsError::InvalidCIM(String::from( @@ -257,8 +276,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { return Ok(()); } - fn get_label(&self) -> &String { + fn get_label(&self) -> &String { &self.label } - } diff --git a/src/structure_learning.rs b/src/structure_learning.rs index b7db7ed..8b90cdf 100644 --- a/src/structure_learning.rs +++ b/src/structure_learning.rs @@ -1,12 +1,11 @@ -pub mod score_function; -pub mod score_based_algorithm; pub mod constraint_based_algorithm; pub mod hypothesis_test; -use crate::network; -use crate::tools; +pub mod score_based_algorithm; +pub mod score_function; +use crate::{network, tools}; pub trait StructureLearningAlgorithm { - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where T: network::Network; } diff --git a/src/structure_learning/constraint_based_algorithm.rs b/src/structure_learning/constraint_based_algorithm.rs index 0d8b655..b3fc3e1 100644 --- a/src/structure_learning/constraint_based_algorithm.rs +++ b/src/structure_learning/constraint_based_algorithm.rs @@ -1,5 +1,3 @@ - //pub struct CTPC { // //} - diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs index f8eeb30..5ddcc51 100644 --- a/src/structure_learning/hypothesis_test.rs +++ b/src/structure_learning/hypothesis_test.rs @@ -1,11 +1,10 @@ -use ndarray::Array3; -use ndarray::Axis; +use std::collections::BTreeSet; + +use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF}; -use crate::network; -use crate::parameter_learning; use crate::params::*; -use 
std::collections::BTreeSet; +use crate::{network, parameter_learning}; pub trait HypothesisTest { fn call( @@ -110,15 +109,15 @@ impl HypothesisTest for ChiSquare { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // di dimensione nxn // (CIM, M, T) - let P_small = match cache.fit(net, child_node, Some(separation_set.clone())){ - Params::DiscreteStatesContinousTime(node) => node + let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + Params::DiscreteStatesContinousTime(node) => node, }; // let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); - let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())){ - Params::DiscreteStatesContinousTime(node) => node + let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + Params::DiscreteStatesContinousTime(node) => node, }; // Commentare qui let partial_cardinality_product: usize = extended_separation_set @@ -132,7 +131,12 @@ impl HypothesisTest for ChiSquare { / (partial_cardinality_product * net.get_node(parent_node).get_reserved_space_as_parent())) * partial_cardinality_product; - if !self.compare_matrices(idx_M_small, P_small.get_transitions().as_ref().unwrap(), idx_M_big, P_big.get_transitions().as_ref().unwrap()) { + if !self.compare_matrices( + idx_M_small, + P_small.get_transitions().as_ref().unwrap(), + idx_M_big, + P_big.get_transitions().as_ref().unwrap(), + ) { return false; } } diff --git a/src/structure_learning/score_based_algorithm.rs b/src/structure_learning/score_based_algorithm.rs index fe4e4ff..cc8541a 100644 --- a/src/structure_learning/score_based_algorithm.rs +++ b/src/structure_learning/score_based_algorithm.rs @@ -1,8 +1,8 @@ -use crate::network; +use std::collections::BTreeSet; + use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::tools; -use std::collections::BTreeSet; +use crate::{network, tools}; pub struct HillClimbing { score_function: S, diff --git a/src/structure_learning/score_function.rs b/src/structure_learning/score_function.rs index ea53db5..b3b1597 100644 --- a/src/structure_learning/score_function.rs +++ b/src/structure_learning/score_function.rs @@ -1,10 +1,9 @@ -use crate::network; -use crate::parameter_learning; -use crate::params; -use crate::tools; +use std::collections::BTreeSet; + use ndarray::prelude::*; use statrs::function::gamma; -use std::collections::BTreeSet; + +use crate::{network, parameter_learning, params, tools}; pub trait ScoreFunction { fn call( @@ -25,7 +24,6 @@ pub struct LogLikelihood { impl LogLikelihood { pub fn new(alpha: usize, tau: f64) -> LogLikelihood { - //Tau must be >=0.0 if tau < 0.0 { panic!("tau must be >=0.0"); @@ -42,9 +40,9 @@ impl LogLikelihood { ) -> (f64, Array3) where T: network::Network, - { + { //Identify the type of node used - match &net.get_node(node){ + match &net.get_node(node) { params::Params::DiscreteStatesContinousTime(_params) => { //Compute the sufficient statistics M (number of transistions) and T (residence //time) @@ -55,35 +53,40 @@ impl LogLikelihood { let alpha = self.alpha as f64 / M.shape()[0] as f64; //Scale tau accordingly to the size of the parent set let tau = self.tau / M.shape()[0] as f64; - + //Compute the log likelihood for q - let log_ll_q:f64 = M + let log_ll_q: f64 = M .sum_axis(Axis(2)) .iter() .zip(T.iter()) .map(|(m, t)| { - gamma::ln_gamma(alpha + *m as f64 + 1.0) - + (alpha + 1.0) * f64::ln(tau) + 
gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau) - gamma::ln_gamma(alpha + 1.0) - (alpha + *m as f64 + 1.0) * f64::ln(tau + t) }) .sum(); - + //Compute the log likelihood for theta - let log_ll_theta: f64 = M.outer_iter() - .map(|x| x.outer_iter() - .map(|y| gamma::ln_gamma(alpha) - - gamma::ln_gamma(alpha + y.sum() as f64) - + y.iter().map(|z| - gamma::ln_gamma(alpha + *z as f64) - - gamma::ln_gamma(alpha)).sum::()).sum::()).sum(); + let log_ll_theta: f64 = M + .outer_iter() + .map(|x| { + x.outer_iter() + .map(|y| { + gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64) + + y.iter() + .map(|z| { + gamma::ln_gamma(alpha + *z as f64) + - gamma::ln_gamma(alpha) + }) + .sum::() + }) + .sum::() + }) + .sum(); (log_ll_theta + log_ll_q, M) } } } - - - } impl ScoreFunction for LogLikelihood { @@ -102,13 +105,13 @@ impl ScoreFunction for LogLikelihood { } pub struct BIC { - ll: LogLikelihood + ll: LogLikelihood, } impl BIC { pub fn new(alpha: usize, tau: f64) -> BIC { BIC { - ll: LogLikelihood::new(alpha, tau) + ll: LogLikelihood::new(alpha, tau), } } } @@ -122,14 +125,19 @@ impl ScoreFunction for BIC { dataset: &tools::Dataset, ) -> f64 where - T: network::Network { + T: network::Network, + { //Compute the log-likelihood let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); //Compute the number of parameters let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); //TODO: Optimize this //Compute the sample size - let sample_size: usize = dataset.get_trajectories().iter().map(|x| x.get_time().len() - 1).sum(); + let sample_size: usize = dataset + .get_trajectories() + .iter() + .map(|x| x.get_time().len() - 1) + .sum(); //Compute BIC ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 } diff --git a/src/tools.rs b/src/tools.rs index 115fd67..448b26f 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,10 +1,10 @@ -use crate::network; -use crate::params; -use crate::params::ParamsTrait; use ndarray::prelude::*; use rand_chacha::rand_core::SeedableRng; use rand_chacha::ChaCha8Rng; +use crate::params::ParamsTrait; +use crate::{network, params}; + pub struct Trajectory { time: Array1, events: Array2, @@ -19,7 +19,7 @@ impl Trajectory { } Trajectory { time, events } } - + pub fn get_time(&self) -> &Array1 { &self.time } @@ -35,7 +35,6 @@ pub struct Dataset { impl Dataset { pub fn new(trajectories: Vec) -> Dataset { - //All the trajectories in the same dataset must represent the same process. For this reason //each trajectory must represent the same number of variables. if trajectories @@ -58,18 +57,17 @@ pub fn trajectory_generator( t_end: f64, seed: Option, ) -> Dataset { - //Tmp growing vector containing generated trajectories. let mut trajectories: Vec = Vec::new(); - + //Random Generator object let mut rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. Some(seed) => SeedableRng::seed_from_u64(seed), //Otherwise create a new random generator using the method `from_entropy` - None => SeedableRng::from_entropy() + None => SeedableRng::from_entropy(), }; - + //Each iteration generate one trajectory for _ in 0..n_trajectories { //Current time of the sampling process @@ -78,15 +76,16 @@ pub fn trajectory_generator( let mut time: Vec = Vec::new(); //Configuration of the process variables at time t initialized with an uniform //distribution. 
- let mut current_state: Vec = net.get_node_indices() + let mut current_state: Vec = net + .get_node_indices() .map(|x| net.get_node(x).get_random_state_uniform(&mut rng)) .collect(); - //History of all the configurations of the process variables. + //History of all the configurations of the process variables. let mut events: Vec> = Vec::new(); //Vector containing to time to the next transition for each variable. let mut next_transitions: Vec> = net.get_node_indices().map(|_| Option::None).collect(); - + //Add the starting time for the trajectory. time.push(t.clone()); //Add the starting configuration of the trajectory. @@ -115,7 +114,7 @@ pub fn trajectory_generator( ); } } - + //Get the variable with the smallest transition time. let next_node_transition = next_transitions .iter() @@ -131,7 +130,7 @@ pub fn trajectory_generator( t = next_transitions[next_node_transition].unwrap().clone(); //Add the transition time to next time.push(t.clone()); - + //Compute the new state of the transitioning variable. current_state[next_node_transition] = net .get_node(next_node_transition) @@ -142,7 +141,7 @@ pub fn trajectory_generator( &mut rng, ) .unwrap(); - + //Add the new state to events events.push(Array::from_vec( current_state @@ -160,7 +159,7 @@ pub fn trajectory_generator( next_transitions[child] = None } } - + //Add current_state as last state. events.push( current_state @@ -172,7 +171,7 @@ pub fn trajectory_generator( ); //Add t_end as last time. time.push(t_end.clone()); - + //Add the sampled trajectory to trajectories. trajectories.push(Trajectory::new( Array::from_vec(time), diff --git a/tests/ctbn.rs b/tests/ctbn.rs index e5cad1e..63c9621 100644 --- a/tests/ctbn.rs +++ b/tests/ctbn.rs @@ -1,8 +1,9 @@ mod utils; +use std::collections::BTreeSet; + use reCTBN::ctbn::*; use reCTBN::network::Network; use reCTBN::params::{self, ParamsTrait}; -use std::collections::BTreeSet; use utils::generate_discrete_time_continous_node; #[test] diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index b624e94..0409402 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -1,13 +1,13 @@ #![allow(non_snake_case)] mod utils; -use utils::*; - use ndarray::arr3; use reCTBN::ctbn::*; use reCTBN::network::Network; use reCTBN::parameter_learning::*; -use reCTBN::{params, tools::*}; +use reCTBN::params; +use reCTBN::tools::*; +use utils::*; extern crate approx; @@ -41,7 +41,7 @@ fn learn_binary_cim(pl: T) { let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); let p = match pl.fit(&net, &data, 1, None) { - params::Params::DiscreteStatesContinousTime(p) => p + params::Params::DiscreteStatesContinousTime(p) => p, }; assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( @@ -99,8 +99,8 @@ fn learn_ternary_cim(pl: T) { } let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); - let p = match pl.fit(&net, &data, 1, None){ - params::Params::DiscreteStatesContinousTime(p) => p + let p = match pl.fit(&net, &data, 1, None) { + params::Params::DiscreteStatesContinousTime(p) => p, }; assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( @@ -162,8 +162,8 @@ fn learn_ternary_cim_no_parents(pl: T) { } let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); - let p = match pl.fit(&net, &data, 0, None){ - params::Params::DiscreteStatesContinousTime(p) => p + let p = match pl.fit(&net, &data, 0, None) { + 
params::Params::DiscreteStatesContinousTime(p) => p, }; assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( @@ -291,8 +291,8 @@ fn learn_mixed_discrete_cim(pl: T) { } let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); - let p = match pl.fit(&net, &data, 2, None){ - params::Params::DiscreteStatesContinousTime(p) => p + let p = match pl.fit(&net, &data, 2, None) { + params::Params::DiscreteStatesContinousTime(p) => p, }; assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( diff --git a/tests/params.rs b/tests/params.rs index e07121c..7f16f12 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,5 +1,6 @@ use ndarray::prelude::*; -use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; +use rand_chacha::rand_core::SeedableRng; +use rand_chacha::ChaCha8Rng; use reCTBN::params::{ParamsTrait, *}; mod utils; diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 790a4b6..ee5109e 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -1,17 +1,18 @@ #![allow(non_snake_case)] mod utils; -use utils::*; +use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; use reCTBN::ctbn::*; use reCTBN::network::Network; use reCTBN::params; -use reCTBN::structure_learning::score_function::*; -use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm}; use reCTBN::structure_learning::hypothesis_test::*; +use reCTBN::structure_learning::score_based_algorithm::*; +use reCTBN::structure_learning::score_function::*; +use reCTBN::structure_learning::StructureLearningAlgorithm; use reCTBN::tools::*; -use std::collections::BTreeSet; +use utils::*; #[macro_use] extern crate approx; @@ -320,73 +321,43 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() } #[test] -pub fn chi_square_compare_matrices () { +pub fn chi_square_compare_matrices() { let i: usize = 1; let M1 = arr3(&[ - [[ 0, 2, 3], - [ 4, 0, 6], - [ 7, 8, 0]], - [[0, 12, 90], - [ 3, 0, 40], - [ 6, 40, 0]], - [[ 0, 2, 3], - [ 4, 0, 6], - [ 44, 66, 0]] + [[0, 2, 3], [4, 0, 6], [7, 8, 0]], + [[0, 12, 90], [3, 0, 40], [6, 40, 0]], + [[0, 2, 3], [4, 0, 6], [44, 66, 0]], ]); let j: usize = 0; - let M2 = arr3(&[ - [[ 0, 200, 300], - [ 400, 0, 600], - [ 700, 800, 0]] - ]); + let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); let chi_sq = ChiSquare::new(0.1); - assert!(!chi_sq.compare_matrices( i, &M1, j, &M2)); + assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); } #[test] -pub fn chi_square_compare_matrices_2 () { +pub fn chi_square_compare_matrices_2() { let i: usize = 1; let M1 = arr3(&[ - [[ 0, 2, 3], - [ 4, 0, 6], - [ 7, 8, 0]], - [[0, 20, 30], - [ 40, 0, 60], - [ 70, 80, 0]], - [[ 0, 2, 3], - [ 4, 0, 6], - [ 44, 66, 0]] + [[0, 2, 3], [4, 0, 6], [7, 8, 0]], + [[0, 20, 30], [40, 0, 60], [70, 80, 0]], + [[0, 2, 3], [4, 0, 6], [44, 66, 0]], ]); let j: usize = 0; - let M2 = arr3(&[ - [[ 0, 200, 300], - [ 400, 0, 600], - [ 700, 800, 0]] - ]); + let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); let chi_sq = ChiSquare::new(0.1); - assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); + assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } #[test] -pub fn chi_square_compare_matrices_3 () { +pub fn chi_square_compare_matrices_3() { let i: usize = 1; let M1 = arr3(&[ - [[ 0, 2, 3], - [ 4, 0, 6], - [ 7, 8, 0]], - [[0, 21, 31], - [ 41, 0, 59], - [ 71, 79, 0]], - [[ 0, 2, 3], - [ 4, 0, 6], - [ 44, 66, 
0]] + [[0, 2, 3], [4, 0, 6], [7, 8, 0]], + [[0, 21, 31], [41, 0, 59], [71, 79, 0]], + [[0, 2, 3], [4, 0, 6], [44, 66, 0]], ]); let j: usize = 0; - let M2 = arr3(&[ - [[ 0, 200, 300], - [ 400, 0, 600], - [ 700, 800, 0]] - ]); + let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); let chi_sq = ChiSquare::new(0.1); - assert!(chi_sq.compare_matrices( i, &M1, j, &M2)); + assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } diff --git a/tests/utils.rs b/tests/utils.rs index 1449b1d..ed43215 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -1,12 +1,19 @@ -use reCTBN::params; use std::collections::BTreeSet; +use reCTBN::params; + #[allow(dead_code)] pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params { - params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality)) + params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params( + label, + cardinality, + )) } -pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ +pub fn generate_discrete_time_continous_params( + label: String, + cardinality: usize, +) -> params::DiscreteStatesContinousTimeParams { let domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); params::DiscreteStatesContinousTimeParams::new(label, domain) } From d8b94940a8f942253e1f00bbcf8f06ef45eb4425 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Sat, 30 Jul 2022 17:58:17 +0200 Subject: [PATCH 046/126] Added `rust-toolchain.toml` --- rust-toolchain.toml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 rust-toolchain.toml diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..367bc0b --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,7 @@ +# This file defines the Rust toolchain to use when a command is executed. +# See also https://rust-lang.github.io/rustup/overrides.html + +[toolchain] +channel = "stable" +components = [ "clippy", "rustfmt" ] +profile = "minimal" From a1c1448da7ccd0c1517514cc7b1d2078d6c17621 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Sun, 31 Jul 2022 13:39:54 +0200 Subject: [PATCH 047/126] Conformed all rank-3 tensors to the same notation and now `rustfmt` ignores `tests/` --- rustfmt.toml | 5 +- tests/parameter_learning.rs | 137 ++++++++++++++++++++++------- tests/structure_learning.rs | 166 ++++++++++++++++++++++++++++-------- tests/tools.rs | 17 +++- 4 files changed, 254 insertions(+), 71 deletions(-) diff --git a/rustfmt.toml b/rustfmt.toml index 3e7fb50..b6f1257 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -33,4 +33,7 @@ newline_style = "Unix" #error_on_unformatted = true # Files to ignore like third party code which is formatted upstream. 
-#ignore = [] +# Ignoring tests is a temporary measure due some issues regarding rank-3 tensors +ignore = [ + "tests/" +] diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 0409402..5de02d7 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -32,8 +32,14 @@ fn learn_binary_cim(pl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], ])) ); } @@ -45,7 +51,16 @@ fn learn_binary_cim(pl: T) { }; assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( - &arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), + &arr3(&[ + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], + ]), 0.1 )); } @@ -76,11 +91,13 @@ fn learn_ternary_cim(pl: T) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -90,9 +107,21 @@ fn learn_ternary_cim(pl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } @@ -105,9 +134,21 @@ fn learn_ternary_cim(pl: T) { assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( &arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ]), 0.1 )); @@ -139,11 +180,13 @@ fn learn_ternary_cim_no_parents(pl: T) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ] + ])) ); } } @@ -153,9 +196,21 @@ fn learn_ternary_cim_no_parents(pl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } @@ -167,7 +222,13 @@ fn learn_ternary_cim_no_parents(pl: T) { }; assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( - &arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), + &arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ]), 0.1 )); } @@ -204,11 +265,13 @@ fn learn_mixed_discrete_cim(pl: T) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - 
[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -218,9 +281,21 @@ fn learn_mixed_discrete_cim(pl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index ee5109e..81a4ed3 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -70,11 +70,13 @@ fn check_compatibility_between_dataset_and_network { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -84,9 +86,21 @@ fn check_compatibility_between_dataset_and_network(sl: T) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -137,9 +153,21 @@ fn learn_ternary_net_2_nodes(sl: T) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } @@ -186,11 +214,13 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!( Ok(()), - param.set_cim(arr3(&[[ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ]])) + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) ); } } @@ -200,9 +230,21 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { assert_eq!( Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], ])) ); } @@ -324,12 +366,30 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() pub fn chi_square_compare_matrices() { let i: usize = 1; let M1 = arr3(&[ - [[0, 2, 3], [4, 0, 6], [7, 8, 0]], - [[0, 12, 90], [3, 0, 40], [6, 40, 0]], - [[0, 2, 3], [4, 0, 6], [44, 66, 0]], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 12, 90], + [ 3, 0, 40], + [ 6, 40, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], ]); let j: usize = 0; - let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); + let M2 = arr3(&[ + [ + [ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0] + ], + ]); let chi_sq = ChiSquare::new(0.1); assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); 
} @@ -338,12 +398,28 @@ pub fn chi_square_compare_matrices() { pub fn chi_square_compare_matrices_2() { let i: usize = 1; let M1 = arr3(&[ - [[0, 2, 3], [4, 0, 6], [7, 8, 0]], - [[0, 20, 30], [40, 0, 60], [70, 80, 0]], - [[0, 2, 3], [4, 0, 6], [44, 66, 0]], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 20, 30], + [ 40, 0, 60], + [ 70, 80, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], ]); let j: usize = 0; - let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); + let M2 = arr3(&[ + [[ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); let chi_sq = ChiSquare::new(0.1); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -352,12 +428,30 @@ pub fn chi_square_compare_matrices_2() { pub fn chi_square_compare_matrices_3() { let i: usize = 1; let M1 = arr3(&[ - [[0, 2, 3], [4, 0, 6], [7, 8, 0]], - [[0, 21, 31], [41, 0, 59], [71, 79, 0]], - [[0, 2, 3], [4, 0, 6], [44, 66, 0]], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 21, 31], + [ 41, 0, 59], + [ 71, 79, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], ]); let j: usize = 0; - let M2 = arr3(&[[[0, 200, 300], [400, 0, 600], [700, 800, 0]]]); + let M2 = arr3(&[ + [ + [ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0] + ], + ]); let chi_sq = ChiSquare::new(0.1); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } diff --git a/tests/tools.rs b/tests/tools.rs index f7435f7..589b04e 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -29,15 +29,26 @@ fn run_sampling() { match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); + param.set_cim(arr3(&[ + [ + [-3.0, 3.0], + [2.0, -2.0] + ], + ])); } } match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { param.set_cim(arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], ])); } } From 1f59e5e3b7f6b809897b7f6977ee7cd11baf0aa7 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 1 Aug 2022 09:16:37 +0200 Subject: [PATCH 048/126] Added Rust nightly toolchain to GitHub Workflows --- .github/workflows/build.yml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bc5c583..fe5bd8b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,12 +16,20 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Setup (rust) + - name: Setup Rust stable (default) uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable + default: true components: clippy, rustfmt + - name: Setup Rust nightly + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + default: false + components: rustfmt - name: Linting (clippy) uses: actions-rs/clippy-check@v1 with: @@ -31,7 +39,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: fmt - args: --all -- --check --verbose + args: +nightly --all -- --check --verbose - name: Tests (test) uses: actions-rs/cargo@v1 with: From 6f2d83cf6be8c1c0845285a5d3615d7beee4dc71 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 1 Aug 2022 09:23:32 +0200 Subject: [PATCH 049/126] Added fmt to nightly toolchain in GitHub Workflows --- .github/workflows/build.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fe5bd8b..0e9a88c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ 
-38,8 +38,9 @@ jobs: - name: Formatting (rustfmt) uses: actions-rs/cargo@v1 with: + toolchain: nightly command: fmt - args: +nightly --all -- --check --verbose + args: --all -- --check --verbose - name: Tests (test) uses: actions-rs/cargo@v1 with: From bb2fe52a39b1f556526665399a05d5f1acadca8c Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Fri, 5 Aug 2022 17:08:24 +0200 Subject: [PATCH 050/126] Implemented ForwardSampler --- src/lib.rs | 1 + src/sampling.rs | 102 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 src/sampling.rs diff --git a/src/lib.rs b/src/lib.rs index 8c57af2..d40776a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,3 +8,4 @@ pub mod parameter_learning; pub mod params; pub mod structure_learning; pub mod tools; +pub mod sampling; diff --git a/src/sampling.rs b/src/sampling.rs new file mode 100644 index 0000000..a0bfaaf --- /dev/null +++ b/src/sampling.rs @@ -0,0 +1,102 @@ +use crate::{ + network::{self, Network}, + params::{self, ParamsTrait}, +}; +use rand::SeedableRng; +use rand_chacha::ChaCha8Rng; + +trait Sampler: Iterator { + fn reset(&mut self); +} + +pub struct ForwardSampler<'a, T> +where + T: Network, +{ + net: &'a T, + rng: ChaCha8Rng, + current_time: f64, + current_state: Vec, + next_transitions: Vec>, +} + +impl<'a, T: Network> ForwardSampler<'a, T> { + pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { + let mut rng: ChaCha8Rng = match seed { + //If a seed is present use it to initialize the random generator. + Some(seed) => SeedableRng::seed_from_u64(seed), + //Otherwise create a new random generator using the method `from_entropy` + None => SeedableRng::from_entropy(), + }; + let mut fs = ForwardSampler { + net: net, + rng: rng, + current_time: 0.0, + current_state: vec![], + next_transitions: vec![], + }; + fs.reset(); + return fs; + } +} + +impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { + type Item = (f64, Vec); + + fn next(&mut self) -> Option { + let ret_time = self.current_time.clone(); + let ret_state = self.current_state.clone(); + + for (idx, val) in self.next_transitions.iter_mut().enumerate() { + if let None = val { + *val = Some( + self.net + .get_node(idx) + .get_random_residence_time( + self.net + .get_node(idx) + .state_to_index(&self.current_state[idx]), + self.net.get_param_index_network(idx, &self.current_state), + &mut self.rng, + ) + .unwrap() + + self.current_time, + ); + } + } + + let next_node_transition = self.next_transitions + .iter() + .enumerate() + .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) + .unwrap() + .0; + + self.current_time = self.next_transitions[next_node_transition].unwrap().clone(); + + self.current_state[next_node_transition] = self.net + .get_node(next_node_transition) + .get_random_state( + self.net.get_node(next_node_transition) + .state_to_index(&self.current_state[next_node_transition]), + self.net.get_param_index_network(next_node_transition, &self.current_state), + &mut self.rng, + ) + .unwrap(); + + + Some((ret_time, ret_state)) + } +} + +impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { + fn reset(&mut self) { + self.current_time = 0.0; + self.current_state = self + .net + .get_node_indices() + .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) + .collect(); + self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect(); + } +} From 8953471570a37695915e4d7a58ac97723198b953 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 5 Aug 2022 21:21:45 +0200 Subject: 
[PATCH 051/126] Added tests to chi square call function and added a constructor to cache --- src/parameter_learning.rs | 6 ++++++ src/structure_learning/hypothesis_test.rs | 8 +++++++- tests/structure_learning.rs | 21 +++++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index 10f0257..bdb5d4a 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -169,6 +169,12 @@ pub struct Cache { } impl Cache

{ + pub fn new(parameter_learning: P, dataset: tools::Dataset) -> Cache
<P>
{ + Cache { + parameter_learning, + dataset, + } + } pub fn fit( &mut self, net: &T, diff --git a/src/structure_learning/hypothesis_test.rs b/src/structure_learning/hypothesis_test.rs index 5ddcc51..4f2ce18 100644 --- a/src/structure_learning/hypothesis_test.rs +++ b/src/structure_learning/hypothesis_test.rs @@ -30,6 +30,8 @@ impl ChiSquare { pub fn new(alpha: f64) -> ChiSquare { ChiSquare { alpha } } + // Returns true when the matrices are very similar, hence independent + // false when they are different, hence dependent pub fn compare_matrices( &self, i: usize, @@ -69,6 +71,7 @@ impl ChiSquare { let n = K.len(); K.into_shape((n, 1)).unwrap() }; + println!("K: {:?}", K); let L = 1.0 / &K; // ===== 2 // \ (K . M - L . M) // X = / ------------------ // ===== M + M // x'ϵVal /X \ // \ i/ let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); @@ -89,10 +92,13 @@ impl ChiSquare { let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); println!("CHI^2: {:?}", n); println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); - X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)) + let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)); + println!("test: {:?}", ret); + ret } } +// Returns false when the variables are dependent and true when they are independent impl HypothesisTest for ChiSquare { fn call( &self, diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 81a4ed2..a1667c2 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -6,6 +6,8 @@ use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; use reCTBN::ctbn::*; use reCTBN::network::Network; +use reCTBN::parameter_learning::BayesianApproach; +use reCTBN::parameter_learning::Cache; use reCTBN::params; use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::structure_learning::score_based_algorithm::*; @@ -455,3 +457,22 @@ pub fn chi_square_compare_matrices_3() { let chi_sq = ChiSquare::new(0.1); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } + + +#[test] +pub fn chi_square_call() { + + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let N3: usize = 2; + let N2: usize = 1; + let N1: usize = 0; + let separation_set = BTreeSet::new(); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let mut cache = Cache::new(parameter_learning, data); + let chi_sq = ChiSquare::new(0.0001); + + assert!(chi_sq.call(&net, N1, N3, &separation_set, &mut cache)); + assert!(!chi_sq.call(&net, N3, N1, &separation_set, &mut cache)); + assert!(!chi_sq.call(&net, N3, N2, &separation_set, &mut cache)); + assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache)); +} From 23d25be4bd6e5dcc0550c4f59a0055ba84ce3025 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 20 Sep 2022 11:16:40 +0200 Subject: [PATCH 052/126] Replaced `approx` with `approx-0_5` --- Cargo.toml | 4 ++-- tests/parameter_learning.rs | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 553e294..547a8b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ndarray = {version="*", features=["approx"]} +ndarray = {version="*", features=["approx-0_5"]} thiserror = "*" rand = "*" bimap = "*" enum_dispatch = "*" statrs = "*" rand_chacha = "*" [dev-dependencies] -approx = "*" +approx = { package = "approx", version = "0.5" } diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 5de02d7..70f998c 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -10,6 +10,7 @@ use reCTBN::tools::*; use utils::*;
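For orientation, the `Cache` constructor and the `ChiSquare::call` entry point introduced above compose as in the `chi_square_call` test; a condensed sketch (the `get_mixed_discrete_net_3_nodes_with_data` helper is the one defined in `tests/structure_learning.rs`):

    let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
    let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
    let mut cache = Cache::new(parameter_learning, data);
    let chi_sq = ChiSquare::new(0.0001);
    // `call` answers: is the child node independent of the candidate parent,
    // given the separation set? Here node 0 is tested against node 2.
    let independent = chi_sq.call(&net, 0, 2, &BTreeSet::new(), &mut cache);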
extern crate approx; +use crate::approx::AbsDiffEq; fn learn_binary_cim(pl: T) { let mut net = CtbnNetwork::new(); From 8163bfb2b0317b1f0fcd5b0a1b63588a8d7fd7c7 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 21 Sep 2022 10:39:01 +0200 Subject: [PATCH 053/126] Update trajectory generator --- src/sampling.rs | 8 ++- src/tools.rs | 118 ++++++------------------------------ tests/parameter_learning.rs | 2 +- 3 files changed, 28 insertions(+), 100 deletions(-) diff --git a/src/sampling.rs b/src/sampling.rs index a0bfaaf..9bbf569 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -5,7 +5,7 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -trait Sampler: Iterator { +pub trait Sampler: Iterator { fn reset(&mut self); } @@ -84,6 +84,12 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { ) .unwrap(); + self.next_transitions[next_node_transition] = None; + + for child in self.net.get_children_set(next_node_transition) { + self.next_transitions[child] = None; + } + Some((ret_time, ret_state)) } diff --git a/src/tools.rs b/src/tools.rs index 448b26f..0dfea9e 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -3,6 +3,7 @@ use rand_chacha::rand_core::SeedableRng; use rand_chacha::ChaCha8Rng; use crate::params::ParamsTrait; +use crate::sampling::{Sampler, ForwardSampler}; use crate::{network, params}; pub struct Trajectory { @@ -61,114 +62,28 @@ pub fn trajectory_generator( let mut trajectories: Vec = Vec::new(); //Random Generator object - let mut rng: ChaCha8Rng = match seed { - //If a seed is present use it to initialize the random generator. - Some(seed) => SeedableRng::seed_from_u64(seed), - //Otherwise create a new random generator using the method `from_entropy` - None => SeedableRng::from_entropy(), - }; + let mut sampler = ForwardSampler::new(net, seed); //Each iteration generate one trajectory for _ in 0..n_trajectories { - //Current time of the sampling process - let mut t = 0.0; //History of all the moments in which something changed let mut time: Vec = Vec::new(); //Configuration of the process variables at time t initialized with an uniform //distribution. - let mut current_state: Vec = net - .get_node_indices() - .map(|x| net.get_node(x).get_random_state_uniform(&mut rng)) - .collect(); - //History of all the configurations of the process variables. - let mut events: Vec> = Vec::new(); - //Vector containing to time to the next transition for each variable. - let mut next_transitions: Vec> = - net.get_node_indices().map(|_| Option::None).collect(); + let mut events: Vec> = Vec::new(); - //Add the starting time for the trajectory. - time.push(t.clone()); - //Add the starting configuration of the trajectory. - events.push( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - ); + //Current Time and Current State + let (mut t, mut current_state) = sampler.next().unwrap(); //Generate new samples until ending time is reached. while t < t_end { - //Generate the next transition time for each uninitialized variable. - for (idx, val) in next_transitions.iter_mut().enumerate() { - if let None = val { - *val = Some( - net.get_node(idx) - .get_random_residence_time( - net.get_node(idx).state_to_index(¤t_state[idx]), - net.get_param_index_network(idx, ¤t_state), - &mut rng, - ) - .unwrap() - + t, - ); - } - } - - //Get the variable with the smallest transition time. 
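The rewritten generator above leans on `ForwardSampler` being an endless iterator over `(time, state)` pairs, so the caller decides when to stop; a condensed sketch of how one trajectory is produced under that contract (`net`, `seed`, `t_end`, `time` and `events` as in the surrounding function):

    let mut sampler = ForwardSampler::new(net, seed);
    let (mut t, mut current_state) = sampler.next().unwrap();
    while t < t_end {
        time.push(t);
        events.push(current_state);
        (t, current_state) = sampler.next().unwrap();
    }
    // Restart from a fresh uniformly-sampled initial state for the next trajectory.
    sampler.reset();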
- let next_node_transition = next_transitions - .iter() - .enumerate() - .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) - .unwrap() - .0; - //Check if the next transition take place after the ending time. - if next_transitions[next_node_transition].unwrap() > t_end { - break; - } - //Get the time in which the next transition occurs. - t = next_transitions[next_node_transition].unwrap().clone(); - //Add the transition time to next - time.push(t.clone()); - - //Compute the new state of the transitioning variable. - current_state[next_node_transition] = net - .get_node(next_node_transition) - .get_random_state( - net.get_node(next_node_transition) - .state_to_index(¤t_state[next_node_transition]), - net.get_param_index_network(next_node_transition, ¤t_state), - &mut rng, - ) - .unwrap(); - - //Add the new state to events - events.push(Array::from_vec( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - )); - //Reset the next transition time for the transitioning node. - next_transitions[next_node_transition] = None; - - //Reset the next transition time for each child of the transitioning node. - for child in net.get_children_set(next_node_transition) { - next_transitions[child] = None - } + time.push(t); + events.push(current_state); + (t, current_state) = sampler.next().unwrap(); } - //Add current_state as last state. - events.push( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - ); + current_state = events.last().unwrap().clone(); + events.push(current_state); + //Add t_end as last time. time.push(t_end.clone()); @@ -176,11 +91,18 @@ pub fn trajectory_generator( trajectories.push(Trajectory::new( Array::from_vec(time), Array2::from_shape_vec( - (events.len(), current_state.len()), - events.iter().flatten().cloned().collect(), + (events.len(), events.last().unwrap().len()), + events + .iter() + .flatten() + .map(|x| match x { + params::StateType::Discrete(x) => x.clone(), + }) + .collect(), ) .unwrap(), )); + sampler.reset(); } //Return a dataset object with the sampled trajectories. Dataset::new(trajectories) diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 5de02d7..1e19ce7 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -427,7 +427,7 @@ fn learn_mixed_discrete_cim(pl: T) { [0.8, 0.6, 0.2, -1.6] ], ]), - 0.1 + 0.2 )); } From a07dc214ce72f25538eb0b86ab16b44f0422cae1 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 21 Sep 2022 10:50:12 +0200 Subject: [PATCH 054/126] Clippy errors solved --- src/sampling.rs | 4 ++-- src/tools.rs | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/sampling.rs b/src/sampling.rs index 9bbf569..19b2409 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -1,5 +1,5 @@ use crate::{ - network::{self, Network}, + network::{Network}, params::{self, ParamsTrait}, }; use rand::SeedableRng; @@ -22,7 +22,7 @@ where impl<'a, T: Network> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { - let mut rng: ChaCha8Rng = match seed { + let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. 
Some(seed) => SeedableRng::seed_from_u64(seed), //Otherwise create a new random generator using the method `from_entropy` diff --git a/src/tools.rs b/src/tools.rs index 0dfea9e..671a55a 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,8 +1,5 @@ use ndarray::prelude::*; -use rand_chacha::rand_core::SeedableRng; -use rand_chacha::ChaCha8Rng; -use crate::params::ParamsTrait; use crate::sampling::{Sampler, ForwardSampler}; use crate::{network, params}; From 622cd305d0b2ac2d376d74318f0c3aa6a5c1e487 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 21 Sep 2022 11:24:26 +0200 Subject: [PATCH 055/126] Formatting --- src/lib.rs | 2 +- src/sampling.rs | 21 ++++++++++++--------- src/tools.rs | 2 +- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d40776a..280bd21 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,6 @@ pub mod ctbn; pub mod network; pub mod parameter_learning; pub mod params; +pub mod sampling; pub mod structure_learning; pub mod tools; -pub mod sampling; diff --git a/src/sampling.rs b/src/sampling.rs index 19b2409..0660939 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -1,5 +1,5 @@ use crate::{ - network::{Network}, + network::Network, params::{self, ParamsTrait}, }; use rand::SeedableRng; @@ -65,32 +65,35 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { } } - let next_node_transition = self.next_transitions + let next_node_transition = self + .next_transitions .iter() .enumerate() .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) .unwrap() .0; - + self.current_time = self.next_transitions[next_node_transition].unwrap().clone(); - - self.current_state[next_node_transition] = self.net + + self.current_state[next_node_transition] = self + .net .get_node(next_node_transition) .get_random_state( - self.net.get_node(next_node_transition) + self.net + .get_node(next_node_transition) .state_to_index(&self.current_state[next_node_transition]), - self.net.get_param_index_network(next_node_transition, &self.current_state), + self.net + .get_param_index_network(next_node_transition, &self.current_state), &mut self.rng, ) .unwrap(); self.next_transitions[next_node_transition] = None; - + for child in self.net.get_children_set(next_node_transition) { self.next_transitions[child] = None; } - Some((ret_time, ret_state)) } } diff --git a/src/tools.rs b/src/tools.rs index 671a55a..70bbf76 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,6 +1,6 @@ use ndarray::prelude::*; -use crate::sampling::{Sampler, ForwardSampler}; +use crate::sampling::{ForwardSampler, Sampler}; use crate::{network, params}; pub struct Trajectory { From df99a2cf3ebbe8e9bed3c4c7de5e4d4696de3618 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 22 Sep 2022 09:35:04 +0200 Subject: [PATCH 056/126] cargo workspace --- Cargo.toml | 21 ++++--------------- reCTBN/Cargo.toml | 18 ++++++++++++++++ {src => reCTBN/src}/ctbn.rs | 0 {src => reCTBN/src}/lib.rs | 0 {src => reCTBN/src}/network.rs | 0 {src => reCTBN/src}/parameter_learning.rs | 0 {src => reCTBN/src}/params.rs | 0 {src => reCTBN/src}/sampling.rs | 0 {src => reCTBN/src}/structure_learning.rs | 0 .../constraint_based_algorithm.rs | 0 .../structure_learning/hypothesis_test.rs | 0 .../score_based_algorithm.rs | 0 .../src}/structure_learning/score_function.rs | 0 {src => reCTBN/src}/tools.rs | 0 {tests => reCTBN/tests}/ctbn.rs | 0 {tests => reCTBN/tests}/parameter_learning.rs | 0 {tests => reCTBN/tests}/params.rs | 0 {tests => reCTBN/tests}/structure_learning.rs | 0 {tests => 
reCTBN/tests}/tools.rs | 0 {tests => reCTBN/tests}/utils.rs | 0 20 files changed, 22 insertions(+), 17 deletions(-) create mode 100644 reCTBN/Cargo.toml rename {src => reCTBN/src}/ctbn.rs (100%) rename {src => reCTBN/src}/lib.rs (100%) rename {src => reCTBN/src}/network.rs (100%) rename {src => reCTBN/src}/parameter_learning.rs (100%) rename {src => reCTBN/src}/params.rs (100%) rename {src => reCTBN/src}/sampling.rs (100%) rename {src => reCTBN/src}/structure_learning.rs (100%) rename {src => reCTBN/src}/structure_learning/constraint_based_algorithm.rs (100%) rename {src => reCTBN/src}/structure_learning/hypothesis_test.rs (100%) rename {src => reCTBN/src}/structure_learning/score_based_algorithm.rs (100%) rename {src => reCTBN/src}/structure_learning/score_function.rs (100%) rename {src => reCTBN/src}/tools.rs (100%) rename {tests => reCTBN/tests}/ctbn.rs (100%) rename {tests => reCTBN/tests}/parameter_learning.rs (100%) rename {tests => reCTBN/tests}/params.rs (100%) rename {tests => reCTBN/tests}/structure_learning.rs (100%) rename {tests => reCTBN/tests}/tools.rs (100%) rename {tests => reCTBN/tests}/utils.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 547a8b8..53c74f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,18 +1,5 @@ -[package] -name = "reCTBN" -version = "0.1.0" -edition = "2021" +[workspace] -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -ndarray = {version="*", features=["approx-0_5"]} -thiserror = "*" -rand = "*" -bimap = "*" -enum_dispatch = "*" -statrs = "*" -rand_chacha = "*" - -[dev-dependencies] -approx = { package = "approx", version = "0.5" } +members = [ + "reCTBN", +] diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml new file mode 100644 index 0000000..547a8b8 --- /dev/null +++ b/reCTBN/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "reCTBN" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ndarray = {version="*", features=["approx-0_5"]} +thiserror = "*" +rand = "*" +bimap = "*" +enum_dispatch = "*" +statrs = "*" +rand_chacha = "*" + +[dev-dependencies] +approx = { package = "approx", version = "0.5" } diff --git a/src/ctbn.rs b/reCTBN/src/ctbn.rs similarity index 100% rename from src/ctbn.rs rename to reCTBN/src/ctbn.rs diff --git a/src/lib.rs b/reCTBN/src/lib.rs similarity index 100% rename from src/lib.rs rename to reCTBN/src/lib.rs diff --git a/src/network.rs b/reCTBN/src/network.rs similarity index 100% rename from src/network.rs rename to reCTBN/src/network.rs diff --git a/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs similarity index 100% rename from src/parameter_learning.rs rename to reCTBN/src/parameter_learning.rs diff --git a/src/params.rs b/reCTBN/src/params.rs similarity index 100% rename from src/params.rs rename to reCTBN/src/params.rs diff --git a/src/sampling.rs b/reCTBN/src/sampling.rs similarity index 100% rename from src/sampling.rs rename to reCTBN/src/sampling.rs diff --git a/src/structure_learning.rs b/reCTBN/src/structure_learning.rs similarity index 100% rename from src/structure_learning.rs rename to reCTBN/src/structure_learning.rs diff --git a/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs similarity index 100% rename from src/structure_learning/constraint_based_algorithm.rs rename to reCTBN/src/structure_learning/constraint_based_algorithm.rs diff --git 
a/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs similarity index 100% rename from src/structure_learning/hypothesis_test.rs rename to reCTBN/src/structure_learning/hypothesis_test.rs diff --git a/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs similarity index 100% rename from src/structure_learning/score_based_algorithm.rs rename to reCTBN/src/structure_learning/score_based_algorithm.rs diff --git a/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs similarity index 100% rename from src/structure_learning/score_function.rs rename to reCTBN/src/structure_learning/score_function.rs diff --git a/src/tools.rs b/reCTBN/src/tools.rs similarity index 100% rename from src/tools.rs rename to reCTBN/src/tools.rs diff --git a/tests/ctbn.rs b/reCTBN/tests/ctbn.rs similarity index 100% rename from tests/ctbn.rs rename to reCTBN/tests/ctbn.rs diff --git a/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs similarity index 100% rename from tests/parameter_learning.rs rename to reCTBN/tests/parameter_learning.rs diff --git a/tests/params.rs b/reCTBN/tests/params.rs similarity index 100% rename from tests/params.rs rename to reCTBN/tests/params.rs diff --git a/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs similarity index 100% rename from tests/structure_learning.rs rename to reCTBN/tests/structure_learning.rs diff --git a/tests/tools.rs b/reCTBN/tests/tools.rs similarity index 100% rename from tests/tools.rs rename to reCTBN/tests/tools.rs diff --git a/tests/utils.rs b/reCTBN/tests/utils.rs similarity index 100% rename from tests/utils.rs rename to reCTBN/tests/utils.rs From 174e85734ec567f23127f0bd3cce3dcc2154e265 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 12 Oct 2022 14:47:09 +0200 Subject: [PATCH 057/126] Added docs generation in GitHub Workflows --- .github/workflows/build.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0e9a88c..7cc300c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,7 +22,7 @@ jobs: profile: minimal toolchain: stable default: true - components: clippy, rustfmt + components: clippy, rustfmt, rust-docs - name: Setup Rust nightly uses: actions-rs/toolchain@v1 with: @@ -30,6 +30,11 @@ jobs: toolchain: nightly default: false components: rustfmt + - name: Docs (doc) + uses: actions-rs/cargo@v1 + with: + command: rustdoc + args: --package reCTBN -- --default-theme=ayu - name: Linting (clippy) uses: actions-rs/clippy-check@v1 with: From 672de56c3119e700583b9fcb44d3feed76ef81bb Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 12 Oct 2022 15:51:52 +0200 Subject: [PATCH 058/126] Added a brief description of the project in form of docstrings at the crate level --- reCTBN/src/lib.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 280bd21..da1aa06 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -1,3 +1,9 @@ +//! # reCTBN +//! +//! > **Note:** At the moment it's in pre-alpha state. 🧪⚗️💥 +//! +//! `reCTBN` is a Continuous Time Bayesian Networks Library written in Rust. 
🦀 + #![allow(non_snake_case)] #[cfg(test)] extern crate approx; From 616d5ec3d551269f184141f48a7a9d638fc49767 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 12 Oct 2022 16:33:47 +0200 Subject: [PATCH 059/126] Fixed and prettified some imprecisions to old docstrings and added new ones in `ctbn.rs` --- reCTBN/src/ctbn.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/reCTBN/src/ctbn.rs b/reCTBN/src/ctbn.rs index e2f5dd7..fae7f4d 100644 --- a/reCTBN/src/ctbn.rs +++ b/reCTBN/src/ctbn.rs @@ -5,16 +5,18 @@ use ndarray::prelude::*; use crate::network; use crate::params::{Params, ParamsTrait, StateType}; -///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is -///composed by the following elements: -///- **adj_metrix**: a 2d ndarray representing the adjacency matrix -///- **nodes**: a vector containing all the nodes and their parameters. -///The index of a node inside the vector is also used as index for the adj_matrix. +/// It represents both the structure and the parameters of a CTBN. /// -///# Examples +/// # Arguments /// -///``` -/// +/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix +/// * `nodes` - A vector containing all the nodes and their parameters. +/// +/// The index of a node inside the vector is also used as index for the `adj_matrix`. +/// +/// # Example +/// +/// ```rust /// use std::collections::BTreeSet; /// use reCTBN::network::Network; /// use reCTBN::params; @@ -66,12 +68,14 @@ impl CtbnNetwork { } impl network::Network for CtbnNetwork { + /// Initialize an Adjacency matrix. fn initialize_adj_matrix(&mut self) { self.adj_matrix = Some(Array2::::zeros( (self.nodes.len(), self.nodes.len()).f(), )); } + /// Add a new node. fn add_node(&mut self, mut n: Params) -> Result { n.reset_params(); self.adj_matrix = Option::None; @@ -79,6 +83,7 @@ impl network::Network for CtbnNetwork { Ok(self.nodes.len() - 1) } + /// Connect two nodes with a new edge. fn add_edge(&mut self, parent: usize, child: usize) { if let None = self.adj_matrix { self.initialize_adj_matrix(); @@ -94,6 +99,7 @@ impl network::Network for CtbnNetwork { 0..self.nodes.len() } + /// Get the number of nodes of the network. fn get_number_of_nodes(&self) -> usize { self.nodes.len() } @@ -138,6 +144,7 @@ impl network::Network for CtbnNetwork { .0 } + /// Get all the parents of the given node. fn get_parent_set(&self, node: usize) -> BTreeSet { self.adj_matrix .as_ref() @@ -149,6 +156,7 @@ impl network::Network for CtbnNetwork { .collect() } + /// Get all the children of the given node. fn get_children_set(&self, node: usize) -> BTreeSet { self.adj_matrix .as_ref() From ccced921495558a26703bf351310d5c8429ce7d0 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 17 Oct 2022 11:48:08 +0200 Subject: [PATCH 060/126] Fixed some docstrings in `params.rs` --- reCTBN/src/params.rs | 47 ++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index f994b99..d7bf8f8 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -23,9 +23,8 @@ pub enum StateType { Discrete(usize), } -/// Parameters -/// The Params trait is the core element for building different types of nodes. The goal is to -/// define the set of method required to describes a generic node. +/// This is a core element for building different types of nodes; the goal is to define the set of +/// methods required to describe a generic node.
#[enum_dispatch(Params)] pub trait ParamsTrait { fn reset_params(&mut self); @@ -65,8 +64,8 @@ pub trait ParamsTrait { fn get_label(&self) -> &String; } -/// The Params enum is the core element for building different types of nodes. The goal is to -/// define all the supported type of Parameters +/// It is a core element for building different types of nodes; the goal is to define all the +/// supported types of parameters #[derive(Clone)] #[enum_dispatch] pub enum Params { @@ -76,14 +75,14 @@ pub enum Params { DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), } /// DiscreteStatesContinousTime. /// This represents the parameters of a classical discrete node for ctbn and it's composed by the /// following elements: -/// - **domain**: an ordered and exhaustive set of possible states -/// - **cim**: Conditional Intensity Matrix -/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter -/// learning task and are composed by: -/// - **transitions**: number of transitions from one state to another given a specific -/// realization of the parent set -/// - **residence_time**: permanence time in each possible states given a specific -/// realization of the parent set +/// - `label` +/// - `domain`: an ordered and exhaustive set of possible states. +/// - `cim`: Conditional Intensity Matrix. +/// - `transitions`: number of transitions from one state to another given a specific realization +/// of the parent set; it is a sufficient statistic mainly used during the parameter learning +/// task. +/// - `residence_time`: permanence time in each possible state, given a specific realization of the +/// parent set; it is a sufficient statistic mainly used during the parameter learning task. #[derive(Clone)] pub struct DiscreteStatesContinousTimeParams { label: String, @@ -104,15 +103,17 @@ impl DiscreteStatesContinousTimeParams { } } - ///Getter function for CIM + /// Getter function for CIM pub fn get_cim(&self) -> &Option> { &self.cim } - ///Setter function for CIM.\\ - ///This function check if the cim is valid using the validate_params method. - ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(()) - ///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError + /// Setter function for CIM. + /// + /// This function check if the CIM is valid using the validate_params method: + /// - **Valid CIM inserted**: it substitute the CIM in self.cim and return `Ok(())`. + /// - **Invalid CIM inserted**: it replace the self.cim value with None and it returns + /// `ParamsError`. pub fn set_cim(&mut self, cim: Array3) -> Result<(), ParamsError> { self.cim = Some(cim); match self.validate_params() { @@ -124,27 +125,27 @@ impl DiscreteStatesContinousTimeParams { } } - ///Unchecked version of the setter function for CIM. + /// Unchecked version of the setter function for CIM. pub fn set_cim_unchecked(&mut self, cim: Array3) { self.cim = Some(cim); } - ///Getter function for transitions + /// Getter function for transitions. pub fn get_transitions(&self) -> &Option> { &self.transitions } - ///Setter function for transitions + /// Setter function for transitions. pub fn set_transitions(&mut self, transitions: Array3) { self.transitions = Some(transitions); } - ///Getter function for residence_time + /// Getter function for residence_time. pub fn get_residence_time(&self) -> &Option> { &self.residence_time } - ///Setter function for residence_time + /// Setter function for residence_time.
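The checked/unchecked split documented here can be exercised end to end; a minimal sketch, assuming the `generate_discrete_time_continous_param` helper from `tests/utils.rs` and CIM values taken from the test suite:

    let mut param = utils::generate_discrete_time_continous_param(3);
    // Rows sum to zero and the diagonal is negative, so validation passes.
    assert_eq!(
        Ok(()),
        param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]))
    );
    // A non-negative diagonal entry fails validation and the CIM is reset to `None`.
    assert!(param
        .set_cim(arr3(&[[[3.0, -2.0, -1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]))
        .is_err());
    assert_eq!(&None, param.get_cim());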
pub fn set_residence_time(&mut self, residence_time: Array2) { self.residence_time = Some(residence_time); } From 2153f46758b367f9803bc520fcf0d9ee7be2abbb Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 17 Oct 2022 12:03:14 +0200 Subject: [PATCH 061/126] Fixed small lint problem and added a paragraph to the README --- README.md | 8 ++++++++ reCTBN/src/params.rs | 3 +-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f928955..6a60dff 100644 --- a/README.md +++ b/README.md @@ -56,3 +56,11 @@ To check the **formatting**: ```sh cargo fmt --all -- --check ``` + +## Documentation + +To generate the **documentation**: + +```sh +cargo rustdoc --package reCTBN --open -- --default-theme=ayu +``` diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index d7bf8f8..65db06c 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -72,7 +72,6 @@ pub enum Params { DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), } -/// DiscreteStatesContinousTime. /// This represents the parameters of a classical discrete node for ctbn and it's composed by the /// following elements: /// - `label` @@ -109,7 +108,7 @@ impl DiscreteStatesContinousTimeParams { } /// Setter function for CIM. - /// + /// /// This function check if the CIM is valid using the validate_params method: /// - **Valid CIM inserted**: it substitute the CIM in self.cim and return `Ok(())`. /// - **Invalid CIM inserted**: it replace the self.cim value with None and it returns From 064b582833e25b903ebee32bc062967870272c78 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 17 Oct 2022 12:59:39 +0200 Subject: [PATCH 062/126] Added docstrings for the chi-squared test --- .../src/structure_learning/hypothesis_test.rs | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 4f2ce18..7083d38 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -20,6 +20,17 @@ pub trait HypothesisTest { P: parameter_learning::ParameterLearning; } +/// Does the chi-squared test (χ2 test). +/// +/// Used to determine if a difference between two sets of data is due to chance, or if it is due to +/// a relationship (dependence) between the variables. +/// +/// # Arguments +/// +/// * `alpha` - is the significance level, the probability to reject a true null hypothesis; +/// in other words is the risk of concluding that an association between the variables exists +/// when there is no actual association. + pub struct ChiSquare { alpha: f64, } @@ -30,8 +41,21 @@ impl ChiSquare { pub fn new(alpha: f64) -> ChiSquare { ChiSquare { alpha } } - // Returns true when the matrices are very similar, hence independent - // false when they are different, hence dependent + + /// Compare two matrices extracted from two 3rd-order tensors. + /// + /// # Arguments + /// + /// * `i` - Position of the matrix of `M1` to compare with `M2`. + /// * `M1` - 3rd-order tensor 1. + /// * `j` - Position of the matrix of `M2` to compare with `M1`. + /// * `M2` - 3rd-order tensor 2. + /// + /// # Returns + /// + /// * `true` - when the matrices `M1` and `M2` are very similar, then **dependent**. + /// * `false` - when the matrices `M1` and `M2` are too different, then **independent**.
+ pub fn compare_matrices( &self, i: usize, From 9ca8973550010a503b0c74cb6b1672886401ee22 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 17 Oct 2022 13:17:34 +0200 Subject: [PATCH 063/126] Harmonized some docstrings in `params.rs` --- reCTBN/src/params.rs | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 65db06c..e533f21 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -73,14 +73,17 @@ pub enum Params { } /// This represents the parameters of a classical discrete node for ctbn and it's composed by the -/// following elements: -/// - `label` -/// - `domain`: an ordered and exhaustive set of possible states. -/// - `cim`: Conditional Intensity Matrix. -/// - `transitions`: number of transitions from one state to another given a specific realization +/// following elements. +/// +/// # Arguments +/// +/// * `label` - node's variable name. +/// * `domain` - an ordered and exhaustive set of possible states. +/// * `cim` - Conditional Intensity Matrix. +/// * `transitions` - number of transitions from one state to another given a specific realization /// of the parent set; it is a sufficient statistic mainly used during the parameter learning /// task. -/// - `residence_time`: permanence time in each possible state, given a specific realization of the +/// * `residence_time` - residence time in each possible state, given a specific realization of the /// parent set; it is a sufficient statistic mainly used during the parameter learning task. #[derive(Clone)] pub struct DiscreteStatesContinousTimeParams { @@ -109,7 +112,7 @@ impl DiscreteStatesContinousTimeParams { /// Setter function for CIM. /// - /// This function check if the CIM is valid using the validate_params method: - /// - **Valid CIM inserted**: it substitute the CIM in self.cim and return `Ok(())`. - /// - **Invalid CIM inserted**: it replace the self.cim value with None and it returns + /// This function checks if the CIM is valid using the [`validate_params`](self::ParamsTrait::validate_params) method: + /// * **Valid CIM inserted** - it substitutes the CIM in `self.cim` and returns `Ok(())`. + /// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns /// `ParamsError`. pub fn set_cim(&mut self, cim: Array3) -> Result<(), ParamsError> { self.cim = Some(cim); match self.validate_params() { @@ -144,7 +147,7 @@ impl DiscreteStatesContinousTimeParams { &self.residence_time } - ///Setter function for residence_time. + /// Setter function for residence_time. pub fn set_residence_time(&mut self, residence_time: Array2) { self.residence_time = Some(residence_time); } From f7165d034598a14cc3f0460ee09ddbd31a831e95 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 17 Oct 2022 14:19:26 +0200 Subject: [PATCH 064/126] Added docstrings to `networks.rs` --- reCTBN/src/network.rs | 87 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 9 deletions(-) diff --git a/reCTBN/src/network.rs b/reCTBN/src/network.rs index cbae339..8fc8271 100644 --- a/reCTBN/src/network.rs +++ b/reCTBN/src/network.rs @@ -11,33 +11,102 @@ pub enum NetworkError { NodeInsertionError(String), } -///Network -///The Network trait define the required methods for a structure used as pgm (such as ctbn). +/// It defines the required methods for a structure used as a PGM (such as a CTBN).
pub trait Network { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; + /// Add a **directed edge** between two nodes of the network. + /// + /// # Arguments + /// + /// * `parent` - parent node. + /// * `child` - child node. fn add_edge(&mut self, parent: usize, child: usize); - ///Get all the indices of the nodes contained inside the network + /// Get all the indices of the nodes contained inside the network. fn get_node_indices(&self) -> std::ops::Range; + + /// Get the number of nodes contained in the network. fn get_number_of_nodes(&self) -> usize; + + /// Get the **node param**. + /// + /// # Arguments + /// + /// * `node_idx` - node index value. + /// + /// # Return + /// + /// * The selected **node param**. fn get_node(&self, node_idx: usize) -> &params::Params; + + /// Get the **node param**. + /// + /// # Arguments + /// + /// * `node_idx` - node index value. + /// + /// # Return + /// + /// * The selected **node mutable param**. fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; - ///Compute the index that must be used to access the parameters of a node given a specific - ///configuration of the network. Usually, the only values really used in *current_state* are - ///the ones in the parent set of the *node*. + /// Compute the index that must be used to access the parameters of a `node`, given a specific + /// configuration of the network. + /// + /// Usually, the only values really used in `current_state` are the ones in the parent set of + /// the `node`. + /// + /// # Arguments + /// + /// * `node` - selected node. + /// * `current_state` - current configuration of the network. + /// + /// # Return + /// + /// * Index of the `node` relative to the network. fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize; - ///Compute the index that must be used to access the parameters of a node given a specific - ///configuration of the network and a generic parent_set. Usually, the only values really used - ///in *current_state* are the ones in the parent set of the *node*. + /// Compute the index that must be used to access the parameters of a `node`, given a specific + /// configuration of the network and a generic `parent_set`. + /// + /// Usually, the only values really used in `current_state` are the ones in the parent set of + /// the `node`. + /// + /// # Arguments + /// + /// * `current_state` - current configuration of the network. + /// * `parent_set` - parent set of the selected `node`. + /// + /// # Return + /// + /// * Index of the `node` relative to the network. fn get_param_index_from_custom_parent_set( &self, current_state: &Vec, parent_set: &BTreeSet, ) -> usize; + + /// Get the **parent set** of a given **node**. + /// + /// # Arguments + /// + /// * `node` - node index value. + /// + /// # Return + /// + /// * The **parent set** of the selected node. fn get_parent_set(&self, node: usize) -> BTreeSet; + + /// Get the **children set** of a given **node**. + /// + /// # Arguments + /// + /// * `node` - node index value. + /// + /// # Return + /// + /// * The **children set** of the selected node.
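The graph-query half of the trait in action; a short sketch, assuming the `generate_discrete_time_continous_node` node-building helper from `tests/utils.rs`:

    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(utils::generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(utils::generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    assert_eq!(BTreeSet::from([n1]), net.get_parent_set(n2));
    assert_eq!(BTreeSet::from([n2]), net.get_children_set(n1));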
fn get_children_set(&self, node: usize) -> BTreeSet; } From 08623e28d4f479cfcacc3746c19e1ff903a0e5ef Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 18 Oct 2022 15:11:15 +0200 Subject: [PATCH 065/126] Added various docstrings notes to all rust modules --- reCTBN/src/ctbn.rs | 2 ++ reCTBN/src/network.rs | 5 ++++- reCTBN/src/parameter_learning.rs | 2 ++ reCTBN/src/params.rs | 2 ++ reCTBN/src/sampling.rs | 2 ++ reCTBN/src/structure_learning.rs | 2 ++ reCTBN/src/structure_learning/constraint_based_algorithm.rs | 2 ++ reCTBN/src/structure_learning/hypothesis_test.rs | 2 ++ reCTBN/src/structure_learning/score_based_algorithm.rs | 2 ++ reCTBN/src/structure_learning/score_function.rs | 2 ++ reCTBN/src/tools.rs | 2 ++ 11 files changed, 24 insertions(+), 1 deletion(-) diff --git a/reCTBN/src/ctbn.rs b/reCTBN/src/ctbn.rs index fae7f4d..2b01d14 100644 --- a/reCTBN/src/ctbn.rs +++ b/reCTBN/src/ctbn.rs @@ -1,3 +1,5 @@ +//! Continuous Time Bayesian Network + use std::collections::BTreeSet; use ndarray::prelude::*; diff --git a/reCTBN/src/network.rs b/reCTBN/src/network.rs index 8fc8271..fbdd2e6 100644 --- a/reCTBN/src/network.rs +++ b/reCTBN/src/network.rs @@ -1,3 +1,5 @@ +//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs + use std::collections::BTreeSet; use thiserror::Error; @@ -11,7 +13,8 @@ pub enum NetworkError { NodeInsertionError(String), } -/// It defines the required methods for a structure used as a PGM (such as a CTBN). +/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such +/// as a CTBN). pub trait Network { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index bdb5d4a..61d4dca 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,3 +1,5 @@ +//! Module containing methods used to learn the parameters. + use std::collections::BTreeSet; use ndarray::prelude::*; diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index e533f21..070c997 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -1,3 +1,5 @@ +//! Module containing methods to define different types of nodes. + use std::collections::BTreeSet; use enum_dispatch::enum_dispatch; diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 0660939..d435634 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,3 +1,5 @@ +//! Module containing methods for the sampling. + use crate::{ network::Network, params::{self, ParamsTrait}, diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index 8b90cdf..57fed1e 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -1,3 +1,5 @@ +//! Learn the structure of the network. + pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index b3fc3e1..670c8ed 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,3 +1,5 @@ +//! Module containing constraint based algorithms like CTPC and Hiton. 
+ //pub struct CTPC { // //} diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 7083d38..1404b8e 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -1,3 +1,5 @@ +//! Module containing an hypothesis test for constraint based algorithms like chi-squared test, F test, etc... + use std::collections::BTreeSet; use ndarray::{Array3, Axis}; diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index cc8541a..9e329eb 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -1,3 +1,5 @@ +//! Module containing score based algorithms like Hill Climbing and Tabu Search. + use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index b3b1597..cb6ad7b 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -1,3 +1,5 @@ +//! Module for score based algorithms containing score functions algorithms like Log Likelihood, BIC, etc... + use std::collections::BTreeSet; use ndarray::prelude::*; diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 70bbf76..aa48883 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -1,3 +1,5 @@ +//! Contains commonly used methods used across the crate. + use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; From 245b3b5d4598b8fa14c44ce85a830c336966d888 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 18 Oct 2022 15:34:37 +0200 Subject: [PATCH 066/126] Replaced the static docstrings at the crate level in `lib.rs` with the external file `README.md` --- reCTBN/src/lib.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index da1aa06..b1cf773 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -1,8 +1,4 @@ -//! # reCTBN -//! -//! > **Note:** At the moment it's in pre-alpha state. 🧪⚗️💥 -//! -//! `reCTBN` is a Continuous Time Bayesian Networks Library written in Rust. 🦀 +#![doc = include_str!("../../README.md")] #![allow(non_snake_case)] #[cfg(test)] extern crate approx; From 3522e1b6f633fd4ed0c9bee167ef604d869502ea Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 18 Oct 2022 15:40:59 +0200 Subject: [PATCH 067/126] Small formatting fix and rewording of a docstring --- reCTBN/src/lib.rs | 1 - reCTBN/src/structure_learning/hypothesis_test.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index b1cf773..db33ae4 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -1,5 +1,4 @@ #![doc = include_str!("../../README.md")] - #![allow(non_snake_case)] #[cfg(test)] extern crate approx; diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 1404b8e..f931e77 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -1,4 +1,4 @@ -//! Module containing an hypothesis test for constraint based algorithms like chi-squared test, F test, etc... +//! Module for constraint based algorithms containing hypothesis test algorithms like chi-squared test, F test, etc...
use std::collections::BTreeSet; From 832922922ade46f91904842c75841ceebd67e429 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 20 Oct 2022 08:58:02 +0200 Subject: [PATCH 068/126] Fixed error in chi2 compare matrices docstring --- reCTBN/src/structure_learning/hypothesis_test.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index f931e77..4e1cb24 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -55,8 +55,8 @@ impl ChiSquare { /// /// # Returns /// - /// * `true` - when the matrices `M1` and `M2` are very similar, then **dependent**. - /// * `false` - when the matrices `M1` and `M2` are too different, then **independent**. + /// * `true` - when the matrices `M1` and `M2` are very similar, then **independent**. + /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**. pub fn compare_matrices( From 0ae2168a9439635b91ffaa680db79569b3604c29 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 20 Oct 2022 09:36:02 +0200 Subject: [PATCH 069/126] Commented out some prints in chi2 compare matrices --- .../src/structure_learning/hypothesis_test.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 4e1cb24..6474155 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -97,7 +97,7 @@ impl ChiSquare { let n = K.len(); K.into_shape((n, 1)).unwrap() }; - println!("K: {:?}", K); + //println!("K: {:?}", K); let L = 1.0 / &K; // ===== 2 // \ (K . M - L .
M) @@ -108,18 +108,18 @@ impl ChiSquare { // x'ϵVal /X \ // \ i/ let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); - println!("M1: {:?}", M1); - println!("M2: {:?}", M2); - println!("L*M1: {:?}", (L * &M1)); - println!("K*M2: {:?}", (K * &M2)); - println!("X_2: {:?}", X_2); + //println!("M1: {:?}", M1); + //println!("M2: {:?}", M2); + //println!("L*M1: {:?}", (L * &M1)); + //println!("K*M2: {:?}", (K * &M2)); + //println!("X_2: {:?}", X_2); X_2.diag_mut().fill(0.0); let X_2 = X_2.sum_axis(Axis(1)); let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); - println!("CHI^2: {:?}", n); - println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); + //println!("CHI^2: {:?}", n); + //println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)); - println!("test: {:?}", ret); + //println!("test: {:?}", ret); ret } } From a92b605daaba623632c0d9f15ecbede72aa76f2b Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 20 Oct 2022 11:48:50 +0200 Subject: [PATCH 070/126] In `Cargo.toml` set the exact version for `thiserror` crate and the tilde requirement for all the others --- reCTBN/Cargo.toml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml index 547a8b8..b0a691b 100644 --- a/reCTBN/Cargo.toml +++ b/reCTBN/Cargo.toml @@ -6,13 +6,13 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ndarray = {version="*", features=["approx-0_5"]} -thiserror = "*" -rand = "*" -bimap = "*" -enum_dispatch = "*" -statrs = "*" -rand_chacha = "*" +ndarray = {version="~0.15", features=["approx-0_5"]} +thiserror = "1.0.37" +rand = "~0.8" +bimap = "~0.6" +enum_dispatch = "~0.3" +statrs = "~0.16" +rand_chacha = "~0.3" [dev-dependencies] -approx = { package = "approx", version = "0.5" } +approx = { package = "approx", version = "~0.5" } From ec72a6a2f9e6981da7407a2e7e4a66b2345d003c Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 28 Oct 2022 16:02:12 +0200 Subject: [PATCH 071/126] Defined the `compare_matrices` function for the F-test --- .../src/structure_learning/hypothesis_test.rs | 58 ++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 6474155..7534eaf 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -3,7 +3,7 @@ use std::collections::BTreeSet; use ndarray::{Array3, Axis}; -use statrs::distribution::{ChiSquared, ContinuousCDF}; +use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; use crate::{network, parameter_learning}; @@ -37,7 +37,61 @@ pub struct ChiSquare { alpha: f64, } -pub struct F {} +pub struct F { + alpha: f64, +} + +impl F { + pub fn new(alpha: f64) -> F { + F { alpha } + } + + pub fn compare_matrices( + &self, + i: usize, + M1: &Array3, + cim_1: &Array3, + j: usize, + M2: &Array3, + cim_2: &Array3, + ) -> bool { + let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); + let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); + let cim_1 = cim_1.index_axis(Axis(0), i); + let cim_2 = cim_2.index_axis(Axis(0), j); + let r1 = M1.sum_axis(Axis(1)); + let r2 = M2.sum_axis(Axis(1)); + let q1 = cim_1.diag(); + let q2 = cim_2.diag(); + for idx in 0..r1.shape()[0] { + let s = q2[idx] / q1[idx]; + let F = FisherSnedecor::new(r1[idx], r2[idx]); + let 
lim_sx = F.as_ref().expect("REASON").cdf(self.alpha / 2.0); + let lim_dx = F.as_ref().expect("REASON").cdf(1.0 - (self.alpha / 2.0)); + if s < lim_sx || s > lim_dx { + return false; + } + } + true + } +} + +impl HypothesisTest for F { + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + cache: &mut parameter_learning::Cache
<P>
, + ) -> bool + where + T: network::Network, + P: parameter_learning::ParameterLearning, + { + true + } +} impl ChiSquare { pub fn new(alpha: f64) -> ChiSquare { From c08f4e1985edf949049b09eceba059e7de0cb1f1 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 10 Nov 2022 14:10:17 +0100 Subject: [PATCH 072/126] Added F call function --- .../src/structure_learning/hypothesis_test.rs | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 7534eaf..75c0eac 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -89,7 +89,38 @@ impl HypothesisTest for F { T: network::Network, P: parameter_learning::ParameterLearning, { - true + let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + Params::DiscreteStatesContinousTime(node) => node, + }; + let mut extended_separation_set = separation_set.clone(); + extended_separation_set.insert(parent_node); + + let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + Params::DiscreteStatesContinousTime(node) => node, + }; + let partial_cardinality_product: usize = extended_separation_set + .iter() + .take_while(|x| **x != parent_node) + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] { + let idx_M_small: usize = idx_M_big % partial_cardinality_product + + (idx_M_big + / (partial_cardinality_product + * net.get_node(parent_node).get_reserved_space_as_parent())) + * partial_cardinality_product; + if !self.compare_matrices( + idx_M_small, + P_small.get_transitions().as_ref().unwrap(), + P_small.get_cim().as_ref().unwrap(), + idx_M_big, + P_big.get_transitions().as_ref().unwrap(), + P_big.get_cim().as_ref().unwrap(), + ) { + return false; + } + } + return true; } } From ed5471c7cf6d4a28e7486c3e8be0dd9e63cb79b5 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 14 Nov 2022 16:07:04 +0100 Subject: [PATCH 073/126] Added ctmp --- reCTBN/src/lib.rs | 3 +- reCTBN/src/parameter_learning.rs | 12 +- reCTBN/src/{network.rs => process.rs} | 5 +- reCTBN/src/{ => process}/ctbn.rs | 10 +- reCTBN/src/process/ctmp.rs | 106 +++++++++++++++ reCTBN/src/sampling.rs | 10 +- reCTBN/src/structure_learning.rs | 4 +- .../src/structure_learning/hypothesis_test.rs | 6 +- .../score_based_algorithm.rs | 4 +- .../src/structure_learning/score_function.rs | 10 +- reCTBN/src/tools.rs | 4 +- reCTBN/tests/ctbn.rs | 4 +- reCTBN/tests/ctmp.rs | 127 ++++++++++++++++++ reCTBN/tests/parameter_learning.rs | 4 +- reCTBN/tests/structure_learning.rs | 4 +- reCTBN/tests/tools.rs | 4 +- 16 files changed, 276 insertions(+), 41 deletions(-) rename reCTBN/src/{network.rs => process.rs} (98%) rename reCTBN/src/{ => process}/ctbn.rs (95%) create mode 100644 reCTBN/src/process/ctmp.rs create mode 100644 reCTBN/tests/ctmp.rs diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index db33ae4..6ab59cb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -3,10 +3,9 @@ #[cfg(test)] extern crate approx; -pub mod ctbn; -pub mod network; pub mod parameter_learning; pub mod params; pub mod sampling; pub mod structure_learning; pub mod tools; +pub mod process; diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 61d4dca..2aa518c 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -5,10 +5,10 @@ 
use std::collections::BTreeSet; use ndarray::prelude::*; use crate::params::*; -use crate::{network, tools}; +use crate::{process, tools}; pub trait ParameterLearning { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -17,7 +17,7 @@ pub trait ParameterLearning { ) -> Params; } -pub fn sufficient_statistics( +pub fn sufficient_statistics( net: &T, dataset: &tools::Dataset, node: usize, @@ -73,7 +73,7 @@ pub fn sufficient_statistics( pub struct MLE {} impl ParameterLearning for MLE { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -120,7 +120,7 @@ pub struct BayesianApproach { } impl ParameterLearning for BayesianApproach { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -177,7 +177,7 @@ impl Cache
<P>
{ dataset, } } - pub fn fit( + pub fn fit( &mut self, net: &T, node: usize, diff --git a/reCTBN/src/network.rs b/reCTBN/src/process.rs similarity index 98% rename from reCTBN/src/network.rs rename to reCTBN/src/process.rs index fbdd2e6..2b70b59 100644 --- a/reCTBN/src/network.rs +++ b/reCTBN/src/process.rs @@ -1,5 +1,8 @@ //! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs +pub mod ctbn; +pub mod ctmp; + use std::collections::BTreeSet; use thiserror::Error; @@ -15,7 +18,7 @@ pub enum NetworkError { /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). -pub trait Network { +pub trait NetworkProcess { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; /// Add an **directed edge** between a two nodes of the network. diff --git a/reCTBN/src/ctbn.rs b/reCTBN/src/process/ctbn.rs similarity index 95% rename from reCTBN/src/ctbn.rs rename to reCTBN/src/process/ctbn.rs index 2b01d14..c59d99d 100644 --- a/reCTBN/src/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; -use crate::network; +use crate::process; use crate::params::{Params, ParamsTrait, StateType}; /// It represents both the structure and the parameters of a CTBN. @@ -20,9 +20,9 @@ use crate::params::{Params, ParamsTrait, StateType}; /// /// ```rust /// use std::collections::BTreeSet; -/// use reCTBN::network::Network; +/// use reCTBN::process::NetworkProcess; /// use reCTBN::params; -/// use reCTBN::ctbn::*; +/// use reCTBN::process::ctbn::*; /// /// //Create the domain for a discrete node /// let mut domain = BTreeSet::new(); @@ -69,7 +69,7 @@ impl CtbnNetwork { } } -impl network::Network for CtbnNetwork { +impl process::NetworkProcess for CtbnNetwork { /// Initialize an Adjacency matrix. fn initialize_adj_matrix(&mut self) { self.adj_matrix = Some(Array2::::zeros( @@ -78,7 +78,7 @@ impl network::Network for CtbnNetwork { } /// Add a new node. 
- fn add_node(&mut self, mut n: Params) -> Result { + fn add_node(&mut self, mut n: Params) -> Result { n.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs new file mode 100644 index 0000000..b0b042a --- /dev/null +++ b/reCTBN/src/process/ctmp.rs @@ -0,0 +1,106 @@ +use std::collections::BTreeSet; + +use crate::{process, params::{Params, StateType}}; + +use super::NetworkProcess; + +pub struct CtmpProcess { + param: Option +} + +impl CtmpProcess { + pub fn new() -> CtmpProcess { + CtmpProcess { param: None } + } +} + +impl NetworkProcess for CtmpProcess { + fn initialize_adj_matrix(&mut self) { + unimplemented!("CtmpProcess has only one node") + } + + fn add_node(&mut self, n: crate::params::Params) -> Result { + match self.param { + None => { + self.param = Some(n); + Ok(0) + }, + Some(_) => Err(process::NetworkError::NodeInsertionError("CtmpProcess has only one node".to_string())) + } + } + + fn add_edge(&mut self, parent: usize, child: usize) { + unimplemented!("CtmpProcess has only one node") + } + + fn get_node_indices(&self) -> std::ops::Range { + match self.param { + None => 0..0, + Some(_) => 0..1 + } + } + + fn get_number_of_nodes(&self) -> usize { + match self.param { + None => 0, + Some(_) => 1 + } + } + + fn get_node(&self, node_idx: usize) -> &crate::params::Params { + if node_idx == 0 { + self.param.as_ref().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { + if node_idx == 0 { + self.param.as_mut().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_network(&self, node: usize, current_state: &Vec) + -> usize { + if node == 0 { + match current_state[0] { + StateType::Discrete(x) => x + } + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_from_custom_parent_set( + &self, + current_state: &Vec, + parent_set: &std::collections::BTreeSet, + ) -> usize { + unimplemented!("CtmpProcess has only one node") + } + + fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + }, + None => panic!("Uninitialized CtmpProcess") + } + } + + fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + }, + None => panic!("Uninitialized CtmpProcess") + } + } +} diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index d435634..050daeb 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,7 +1,7 @@ //! Module containing methods for the sampling. use crate::{ - network::Network, + process::NetworkProcess, params::{self, ParamsTrait}, }; use rand::SeedableRng; @@ -13,7 +13,7 @@ pub trait Sampler: Iterator { pub struct ForwardSampler<'a, T> where - T: Network, + T: NetworkProcess, { net: &'a T, rng: ChaCha8Rng, @@ -22,7 +22,7 @@ where next_transitions: Vec>, } -impl<'a, T: Network> ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. 
@@ -42,7 +42,7 @@ impl<'a, T: Network> ForwardSampler<'a, T> { } } -impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { type Item = (f64, Vec); fn next(&mut self) -> Option { @@ -100,7 +100,7 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { } } -impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { fn reset(&mut self) { self.current_time = 0.0; self.current_state = self diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index 57fed1e..b272e22 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -4,10 +4,10 @@ pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; pub mod score_function; -use crate::{network, tools}; +use crate::{process, tools}; pub trait StructureLearningAlgorithm { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network; + T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 6474155..4ec3377 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::params::*; -use crate::{network, parameter_learning}; +use crate::{process, parameter_learning}; pub trait HypothesisTest { fn call( @@ -18,7 +18,7 @@ pub trait HypothesisTest { cache: &mut parameter_learning::Cache
<P>
, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning; } @@ -135,7 +135,7 @@ impl HypothesisTest for ChiSquare { cache: &mut parameter_learning::Cache
<P>
, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index 9e329eb..16e9056 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{network, tools}; +use crate::{process, tools}; pub struct HillClimbing { score_function: S, @@ -23,7 +23,7 @@ impl HillClimbing { impl StructureLearningAlgorithm for HillClimbing { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network, + T: process::NetworkProcess, { //Check the coherence between dataset and network if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index cb6ad7b..8943478 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -5,7 +5,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use statrs::function::gamma; -use crate::{network, parameter_learning, params, tools}; +use crate::{process, parameter_learning, params, tools}; pub trait ScoreFunction { fn call( @@ -16,7 +16,7 @@ pub trait ScoreFunction { dataset: &tools::Dataset, ) -> f64 where - T: network::Network; + T: process::NetworkProcess; } pub struct LogLikelihood { @@ -41,7 +41,7 @@ impl LogLikelihood { dataset: &tools::Dataset, ) -> (f64, Array3) where - T: network::Network, + T: process::NetworkProcess, { //Identify the type of node used match &net.get_node(node) { @@ -100,7 +100,7 @@ impl ScoreFunction for LogLikelihood { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { self.compute_score(net, node, parent_set, dataset).0 } @@ -127,7 +127,7 @@ impl ScoreFunction for BIC { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { //Compute the log-likelihood let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index aa48883..6f2f648 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,7 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; -use crate::{network, params}; +use crate::{process, params}; pub struct Trajectory { time: Array1, @@ -51,7 +51,7 @@ impl Dataset { } } -pub fn trajectory_generator( +pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 63c9621..0ad0fc4 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,8 +1,8 @@ mod utils; use std::collections::BTreeSet; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; use utils::generate_discrete_time_continous_node; diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs new file mode 100644 index 0000000..31bc6df --- /dev/null +++ b/reCTBN/tests/ctmp.rs @@ -0,0 +1,127 @@ +mod utils; + +use std::collections::BTreeSet; + +use reCTBN::{ + params, + params::ParamsTrait, + 
process::{ctmp::*, NetworkProcess}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn define_simple_ctmp() { + let _ = CtmpProcess::new(); + assert!(true); +} + +#[test] +fn add_node_to_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); +} + +#[test] +fn add_two_nodes_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + match n2 { + Ok(_) => assert!(false), + Err(_) => assert!(true), + }; +} + +#[test] +#[should_panic] +fn add_edge_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + net.add_edge(0, 1) +} + +#[test] +fn childen_and_parents() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + assert_eq!(0, net.get_parent_set(0).len()); + assert_eq!(0, net.get_children_set(0).len()); +} + +#[test] +#[should_panic] +fn get_childen_panic() { + let mut net = CtmpProcess::new(); + net.get_children_set(0); +} + +#[test] +#[should_panic] +fn get_childen_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_children_set(1); +} + +#[test] +#[should_panic] +fn get_parent_panic() { + let mut net = CtmpProcess::new(); + net.get_parent_set(0); +} + +#[test] +#[should_panic] +fn get_parent_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_parent_set(1); +} + +#[test] +fn compute_index_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); + assert_eq!(6, idx); +} + +#[test] +#[should_panic] +fn compute_index_from_custom_parent_set_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let _idx = net.get_param_index_from_custom_parent_set( + &vec![params::StateType::Discrete(6)], + &BTreeSet::from([0]) + ); +} diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 7d09b07..2cbc185 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -2,8 +2,8 @@ mod utils; use ndarray::arr3; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; use reCTBN::tools::*; diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a1667c2..2ec64b2 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -4,8 +4,8 @@ mod utils; use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use 
reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::Cache; use reCTBN::params; diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 589b04e..806faef 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,6 +1,6 @@ use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params; use reCTBN::tools::*; From 9fbdf25149f6ebe7fcaded6e609e711783053736 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 16 Nov 2022 10:40:19 +0100 Subject: [PATCH 074/126] Fixed `chi_square_call` test, the test was passing, but only for pure chance --- reCTBN/tests/structure_learning.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a1667c2..a8cf3c6 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -108,7 +108,7 @@ fn check_compatibility_between_dataset_and_network Date: Wed, 16 Nov 2022 10:46:55 +0100 Subject: [PATCH 075/126] Slight optimization of `F::compare_matrices` --- reCTBN/src/structure_learning/hypothesis_test.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 75c0eac..9f7a518 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -65,9 +65,10 @@ impl F { let q2 = cim_2.diag(); for idx in 0..r1.shape()[0] { let s = q2[idx] / q1[idx]; - let F = FisherSnedecor::new(r1[idx], r2[idx]); - let lim_sx = F.as_ref().expect("REASON").cdf(self.alpha / 2.0); - let lim_dx = F.as_ref().expect("REASON").cdf(1.0 - (self.alpha / 2.0)); + let F = FisherSnedecor::new(r1[idx], r2[idx]).unwrap(); + let s = F.cdf(s); + let lim_sx = self.alpha / 2.0; + let lim_dx = 1.0 - (self.alpha / 2.0); if s < lim_sx || s > lim_dx { return false; } From 7c3cba50d4afb08c1087711ef8fba12a2351ad54 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 11:14:41 +0100 Subject: [PATCH 076/126] Implemented amalgamation --- reCTBN/src/process/ctbn.rs | 78 +++++++++++++++++++++++++++++++++++++- reCTBN/tests/ctbn.rs | 36 +++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c59d99d..3852c50 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,8 +4,11 @@ use std::collections::BTreeSet; use ndarray::prelude::*; +use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; use crate::process; -use crate::params::{Params, ParamsTrait, StateType}; + +use super::ctmp::CtmpProcess; +use super::NetworkProcess; /// It represents both the structure and the parameters of a CTBN. 
/// @@ -67,6 +70,79 @@ impl CtbnNetwork { nodes: Vec::new(), } } + + pub fn amalgamation(&self) -> CtmpProcess { + for v in self.nodes.iter() { + match v { + Params::DiscreteStatesContinousTime(_) => {} + _ => panic!("Unsupported node"), + } + } + + let variables_domain = + Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); + + let state_space = variables_domain.product(); + let variables_set = BTreeSet::from_iter(self.get_node_indices()); + let mut amalgamated_cim: Array3 = Array::zeros((1, state_space, state_space)); + + for idx_current_state in 0..state_space { + let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); + let current_state_statetype: Vec = current_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + for idx_node in 0..self.nodes.len() { + let p = match self.get_node(idx_node) { + Params::DiscreteStatesContinousTime(p) => p, + }; + for next_node_state in 0..variables_domain[idx_node] { + let mut next_state = current_state.clone(); + next_state[idx_node] = next_node_state; + + let next_state_statetype: Vec = next_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + let idx_next_state = self.get_param_index_from_custom_parent_set( + &next_state_statetype, + &variables_set, + ); + amalgamated_cim[[0, idx_current_state, idx_next_state]] += + p.get_cim().as_ref().unwrap()[[ + self.get_param_index_network(idx_node, ¤t_state_statetype), + current_state[idx_node], + next_node_state, + ]]; + } + } + } + + let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( + "ctmp".to_string(), + BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), + ); + + println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim); + + amalgamated_param.set_cim(amalgamated_cim).unwrap(); + + let mut ctmp = CtmpProcess::new(); + + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)).unwrap(); + return ctmp; + } + + pub fn idx_to_state(variables_domain: &Array1, state: usize) -> Array1 { + let mut state = state; + let mut array_state = Array1::zeros(variables_domain.shape()[0]); + for (idx, var) in variables_domain.indexed_iter() { + array_state[idx] = state % var; + state = state / var; + } + + return array_state; + } } impl process::NetworkProcess for CtbnNetwork { diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 0ad0fc4..fc17a94 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,7 +1,10 @@ mod utils; use std::collections::BTreeSet; +use std::f64::EPSILON; -use reCTBN::process::ctbn::*; +use approx::AbsDiffEq; +use ndarray::arr3; +use reCTBN::process::{ctbn::*, ctmp::*}; use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; use utils::generate_discrete_time_continous_node; @@ -129,3 +132,34 @@ fn compute_index_from_custom_parent_set() { ); assert_eq!(2, idx); } + +#[test] +fn simple_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + net.initialize_adj_matrix(); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); + } + } + + let ctmp = net.amalgamation(); + let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0){ + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) 
= &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + + assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); +} From 3a0151a9f62a85da46d916af0f35c5309b5be9e5 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 16 Nov 2022 12:44:06 +0100 Subject: [PATCH 077/126] Added test for F-test call function --- reCTBN/tests/structure_learning.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a8cf3c6..5c1db80 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -477,3 +477,23 @@ pub fn chi_square_call() { separation_set.insert(N1); assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache)); } + +#[test] +pub fn f_call() { + + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let N3: usize = 2; + let N2: usize = 1; + let N1: usize = 0; + let mut separation_set = BTreeSet::new(); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let mut cache = Cache::new(parameter_learning, data); + let f = F::new(0.000001); + + + assert!(f.call(&net, N1, N3, &separation_set, &mut cache)); + assert!(!f.call(&net, N3, N1, &separation_set, &mut cache)); + assert!(!f.call(&net, N3, N2, &separation_set, &mut cache)); + separation_set.insert(N1); + assert!(f.call(&net, N2, N3, &separation_set, &mut cache)); +} From 4a7a8c5fbab5c0addcd2785a2f585c6c32eb4637 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 14:06:30 +0100 Subject: [PATCH 078/126] Added more tests --- reCTBN/src/params.rs | 4 +- reCTBN/src/process/ctbn.rs | 10 +- reCTBN/tests/ctbn.rs | 228 ++++++++++++++++++++++++++++++++++++- 3 files changed, 234 insertions(+), 8 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 070c997..9f63860 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,11 +267,13 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } + let domain_size = domain_size as f64; + // Check if each row sum up to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) + .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 3852c50..a6be923 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -70,7 +70,12 @@ impl CtbnNetwork { nodes: Vec::new(), } } - + + ///Transform the **CTBN** into a **CTMP** + /// + /// # Return + /// + /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { for v in self.nodes.iter() { match v { @@ -123,8 +128,7 @@ impl CtbnNetwork { BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim); - + println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index fc17a94..a7752f2 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -4,9 +4,9 @@ use std::f64::EPSILON; use approx::AbsDiffEq; use ndarray::arr3; -use reCTBN::process::{ctbn::*, ctmp::*}; -use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; +use reCTBN::process::NetworkProcess; +use reCTBN::process::{ctbn::*, ctmp::*}; use 
utils::generate_discrete_time_continous_node; #[test] @@ -149,7 +149,7 @@ fn simple_amalgamation() { } let ctmp = net.amalgamation(); - let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0){ + let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0) { p.get_cim().as_ref().unwrap() } else { unreachable!(); @@ -160,6 +160,226 @@ fn simple_amalgamation() { unreachable!(); }; - assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); } + +#[test] +fn chain_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + let ctmp = net.amalgamation(); + + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + let p_ctmp_handmade = arr3(&[[ + [ + -1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} + +#[test] +fn chainfork_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + let n4 = net + .add_node(generate_discrete_time_continous_node(String::from("n4"), 2)) + .unwrap(); + + net.add_edge(n1, n3); + net.add_edge(n2, n3); + net.add_edge(n3, n4); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n3) { + 
params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n4) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + + let ctmp = net.amalgamation(); + + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + let p_ctmp_handmade = arr3(&[[ + [ + -2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); 
+} From 28ed1a40b32bb1a6629e5a67b379ed0b56c89861 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 14:52:39 +0100 Subject: [PATCH 079/126] Fix for clippy --- reCTBN/src/lib.rs | 2 +- reCTBN/src/process/ctbn.rs | 13 ++-- reCTBN/src/process/ctmp.rs | 60 +++++++++++-------- reCTBN/src/sampling.rs | 2 +- .../src/structure_learning/hypothesis_test.rs | 2 +- .../src/structure_learning/score_function.rs | 2 +- reCTBN/src/tools.rs | 2 +- reCTBN/tests/ctbn.rs | 4 +- reCTBN/tests/ctmp.rs | 6 +- 9 files changed, 52 insertions(+), 41 deletions(-) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 6ab59cb..c62c42e 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -5,7 +5,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; +pub mod process; pub mod sampling; pub mod structure_learning; pub mod tools; -pub mod process; diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index a6be923..7cb327d 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -70,7 +70,7 @@ impl CtbnNetwork { nodes: Vec::new(), } } - + ///Transform the **CTBN** into a **CTMP** /// /// # Return @@ -105,10 +105,8 @@ impl CtbnNetwork { let mut next_state = current_state.clone(); next_state[idx_node] = next_node_state; - let next_state_statetype: Vec = next_state - .iter() - .map(|x| StateType::Discrete(*x)) - .collect(); + let next_state_statetype: Vec = + next_state.iter().map(|x| StateType::Discrete(*x)).collect(); let idx_next_state = self.get_param_index_from_custom_parent_set( &next_state_statetype, &variables_set, @@ -127,13 +125,14 @@ impl CtbnNetwork { "ctmp".to_string(), BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - + println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); - ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)).unwrap(); + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) + .unwrap(); return ctmp; } diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index b0b042a..81509fa 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -1,11 +1,14 @@ use std::collections::BTreeSet; -use crate::{process, params::{Params, StateType}}; +use crate::{ + params::{Params, StateType}, + process, +}; use super::NetworkProcess; pub struct CtmpProcess { - param: Option + param: Option, } impl CtmpProcess { @@ -24,26 +27,28 @@ impl NetworkProcess for CtmpProcess { None => { self.param = Some(n); Ok(0) - }, - Some(_) => Err(process::NetworkError::NodeInsertionError("CtmpProcess has only one node".to_string())) + } + Some(_) => Err(process::NetworkError::NodeInsertionError( + "CtmpProcess has only one node".to_string(), + )), } } - fn add_edge(&mut self, parent: usize, child: usize) { + fn add_edge(&mut self, _parent: usize, _child: usize) { unimplemented!("CtmpProcess has only one node") } fn get_node_indices(&self) -> std::ops::Range { match self.param { None => 0..0, - Some(_) => 0..1 + Some(_) => 0..1, } } fn get_number_of_nodes(&self) -> usize { match self.param { None => 0, - Some(_) => 1 + Some(_) => 1, } } @@ -63,11 +68,14 @@ impl NetworkProcess for CtmpProcess { } } - fn get_param_index_network(&self, node: usize, current_state: &Vec) - -> usize { + fn get_param_index_network( + &self, + node: usize, + current_state: &Vec, + ) -> usize { if node == 0 { match current_state[0] { - StateType::Discrete(x) => x + StateType::Discrete(x) => x, } } else { unimplemented!("CtmpProcess 
has only one node") @@ -76,31 +84,35 @@ impl NetworkProcess for CtmpProcess { fn get_param_index_from_custom_parent_set( &self, - current_state: &Vec, - parent_set: &std::collections::BTreeSet, + _current_state: &Vec, + _parent_set: &std::collections::BTreeSet, ) -> usize { unimplemented!("CtmpProcess has only one node") } fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { match self.param { - Some(_) => if node == 0 { - BTreeSet::new() - } else { - unimplemented!("CtmpProcess has only one node") - }, - None => panic!("Uninitialized CtmpProcess") + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), } } fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { match self.param { - Some(_) => if node == 0 { - BTreeSet::new() - } else { - unimplemented!("CtmpProcess has only one node") - }, - None => panic!("Uninitialized CtmpProcess") + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), } } } diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 050daeb..0662994 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,8 +1,8 @@ //! Module containing methods for the sampling. use crate::{ - process::NetworkProcess, params::{self, ParamsTrait}, + process::NetworkProcess, }; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 4ec3377..344c995 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::params::*; -use crate::{process, parameter_learning}; +use crate::{parameter_learning, process}; pub trait HypothesisTest { fn call( diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index 8943478..f8b38b5 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -5,7 +5,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use statrs::function::gamma; -use crate::{process, parameter_learning, params, tools}; +use crate::{parameter_learning, params, process, tools}; pub trait ScoreFunction { fn call( diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 6f2f648..2e727e8 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,7 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; -use crate::{process, params}; +use crate::{params, process}; pub struct Trajectory { time: Array1, diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index a7752f2..7db2bae 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,12 +1,12 @@ mod utils; use std::collections::BTreeSet; -use std::f64::EPSILON; + use approx::AbsDiffEq; use ndarray::arr3; use reCTBN::params::{self, ParamsTrait}; use reCTBN::process::NetworkProcess; -use reCTBN::process::{ctbn::*, ctmp::*}; +use reCTBN::process::{ctbn::*}; use utils::generate_discrete_time_continous_node; #[test] diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs index 31bc6df..830bfe0 100644 --- a/reCTBN/tests/ctmp.rs +++ b/reCTBN/tests/ctmp.rs @@ -45,7 +45,7 @@ fn add_edge_to_ctmp() { let _n1 = net 
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); net.add_edge(0, 1) } @@ -64,7 +64,7 @@ fn childen_and_parents() { #[test] #[should_panic] fn get_childen_panic() { - let mut net = CtmpProcess::new(); + let net = CtmpProcess::new(); net.get_children_set(0); } @@ -81,7 +81,7 @@ fn get_childen_panic2() { #[test] #[should_panic] fn get_parent_panic() { - let mut net = CtmpProcess::new(); + let net = CtmpProcess::new(); net.get_parent_set(0); } From 44eaf8713fb9ce8f7bb05cc63af4bc625438e983 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 15:16:33 +0100 Subject: [PATCH 080/126] Fix for clippy --- reCTBN/src/process/ctbn.rs | 6 ------ reCTBN/tests/ctbn.rs | 31 +++++++++++-------------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 7cb327d..7473d4c 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -77,12 +77,6 @@ impl CtbnNetwork { /// /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { - for v in self.nodes.iter() { - match v { - Params::DiscreteStatesContinousTime(_) => {} - _ => panic!("Unsupported node"), - } - } let variables_domain = Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 7db2bae..3eb40d7 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -149,16 +149,10 @@ fn simple_amalgamation() { } let ctmp = net.amalgamation(); - let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0); + let p_ctbn = p_ctbn.get_cim().as_ref().unwrap(); + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); } @@ -211,11 +205,10 @@ fn chain_amalgamation() { let ctmp = net.amalgamation(); - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + + + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); let p_ctmp_handmade = arr3(&[[ [ @@ -308,11 +301,9 @@ fn chainfork_amalgamation() { let ctmp = net.amalgamation(); - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); let p_ctmp_handmade = arr3(&[[ [ From 38e744e034e52e277bab2c3c7052f9e796862d81 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 15:25:27 +0100 Subject: [PATCH 081/126] Fix fmt --- reCTBN/src/process/ctbn.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 7473d4c..c949afe 100644 --- 
a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -77,7 +77,6 @@ impl CtbnNetwork { /// /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { - let variables_domain = Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); From 1878f687d6198a16b56f618ea6e3945ef1703ee5 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 21 Nov 2022 16:34:39 +0100 Subject: [PATCH 082/126] Refactor of sampling --- reCTBN/src/sampling.rs | 13 ++++++++++--- reCTBN/src/tools.rs | 12 ++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 0662994..3bc0c6f 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -7,10 +7,17 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -pub trait Sampler: Iterator { +pub struct Sample { + pub t: f64, + pub state: Vec +} + +pub trait Sampler: Iterator { fn reset(&mut self); } + + pub struct ForwardSampler<'a, T> where T: NetworkProcess, @@ -43,7 +50,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { } impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { - type Item = (f64, Vec); + type Item = Sample; fn next(&mut self) -> Option { let ret_time = self.current_time.clone(); @@ -96,7 +103,7 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some((ret_time, ret_state)) + Some(Sample{t: ret_time, state: ret_state}) } } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 2e727e8..e749d69 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -72,15 +72,15 @@ pub fn trajectory_generator( let mut events: Vec> = Vec::new(); //Current Time and Current State - let (mut t, mut current_state) = sampler.next().unwrap(); + let mut sample = sampler.next().unwrap(); //Generate new samples until ending time is reached. - while t < t_end { - time.push(t); - events.push(current_state); - (t, current_state) = sampler.next().unwrap(); + while sample.t < t_end { + time.push(sample.t); + events.push(sample.state); + sample = sampler.next().unwrap(); } - current_state = events.last().unwrap().clone(); + let current_state = events.last().unwrap().clone(); events.push(current_state); //Add t_end as last time. 
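A minimal sketch of driving the refactored sampler directly (illustrative only; it assumes a `CtbnNetwork` whose CIMs are already set, as in the tests earlier in the series, and the helper `first_jump_times` is hypothetical):

use reCTBN::process::ctbn::CtbnNetwork;
use reCTBN::sampling::{ForwardSampler, Sampler};

// Each `next` now yields a `Sample { t, state }` instead of the old
// `(f64, Vec<StateType>)` tuple; `reset` rewinds the sampler between
// trajectories, as `trajectory_generator` does above.
fn first_jump_times(net: &CtbnNetwork, n_trajectories: usize) -> Vec<f64> {
    let mut sampler = ForwardSampler::new(net, Some(1994));
    let mut times = Vec::new();
    for _ in 0..n_trajectories {
        sampler.next().unwrap(); // the initial state, reported at t = 0.0
        times.push(sampler.next().unwrap().t); // time of the first transition
        sampler.reset();
    }
    times
}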
From 055eb7088e8bf5d139312e55806101a2738cac73 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 21 Nov 2022 17:34:32 +0100 Subject: [PATCH 083/126] Implemented FactoredRewardFunction --- reCTBN/src/lib.rs | 1 + reCTBN/src/process/ctbn.rs | 1 - reCTBN/src/reward_function.rs | 80 +++++++++++++++++++++++++++++++++ reCTBN/src/sampling.rs | 1 + reCTBN/tests/reward_function.rs | 30 +++++++++++++ 5 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 reCTBN/src/reward_function.rs create mode 100644 reCTBN/tests/reward_function.rs diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index c62c42e..1d25552 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -9,3 +9,4 @@ pub mod process; pub mod sampling; pub mod structure_learning; pub mod tools; +pub mod reward_function; diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c949afe..0b6161c 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -119,7 +119,6 @@ impl CtbnNetwork { BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs new file mode 100644 index 0000000..9ff09cc --- /dev/null +++ b/reCTBN/src/reward_function.rs @@ -0,0 +1,80 @@ +use crate::{process, sampling, params::{ParamsTrait, self}}; +use ndarray; + + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64 +} + +pub trait RewardFunction { + fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward; + fn initialize_from_network_process(p: &T) -> Self; +} + + +pub struct FactoredRewardFunction { + transition_reward: Vec>, + instantaneous_reward: Vec> +} + +impl FactoredRewardFunction { + pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2 { + &self.transition_reward[node_idx] + } + + pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2 { + &mut self.transition_reward[node_idx] + } + + pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1 { + &self.instantaneous_reward[node_idx] + } + + pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { + &mut self.instantaneous_reward[node_idx] + } + + +} + +impl RewardFunction for FactoredRewardFunction { + + fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward { + let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { + let x = match x {params::StateType::Discrete(x) => x}; + self.instantaneous_reward[idx][*x] + }).sum(); + if let Some(previous_state) = previous_state { + let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option { + let p = match p {params::StateType::Discrete(p) => p}; + let c = match c {params::StateType::Discrete(c) => c}; + if p != c { + Some(self.transition_reward[idx][[*p,*c]]) + } else { + None + } + }).unwrap_or(0.0); + Reward {transition_reward, instantaneous_reward} + } else { + Reward { transition_reward: 0.0, instantaneous_reward} + } + } + + fn initialize_from_network_process(p: &T) -> Self { + let mut transition_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; + for i in p.get_node_indices() { + //This works only for discrete nodes! 
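+            // For a discrete node the reserved space equals its cardinality,
+            // so it sizes both per-node reward arrays pushed below.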
+ let size: usize = p.get_node(i).get_reserved_space_as_parent(); + instantaneous_reward.push(ndarray::Array1::zeros(size)); + transition_reward.push(ndarray::Array2::zeros((size, size))); + } + + FactoredRewardFunction { transition_reward, instantaneous_reward } + + } + +} + diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 3bc0c6f..d5a1dbe 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -7,6 +7,7 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; +#[derive(Clone)] pub struct Sample { pub t: f64, pub state: Vec diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs new file mode 100644 index 0000000..7f73e6c --- /dev/null +++ b/reCTBN/tests/reward_function.rs @@ -0,0 +1,30 @@ +mod utils; + +use ndarray::*; +use utils::generate_discrete_time_continous_node; +use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; + + +#[test] +fn simple_factored_reward_function() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); + + let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; + let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + + assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); +} From f6015acce99e41582d3902dc7342556e3fe4a115 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 08:53:29 +0100 Subject: [PATCH 084/126] Added tests --- reCTBN/tests/reward_function.rs | 90 ++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index 7f73e6c..0c7fd9b 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -6,7 +6,7 @@ use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; #[test] -fn simple_factored_reward_function() { +fn simple_factored_reward_function_binary_node() { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) @@ -28,3 +28,91 @@ fn simple_factored_reward_function() { assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); } + + +#[test] +fn simple_factored_reward_function_ternary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + 
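+    // Rows of the transition matrix index the previous state, columns the
+    // next one; the diagonal is never read, since only actual jumps earn a
+    // transition reward.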
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; + let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + let s2 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2)]}; + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + + + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); +} + +#[test] +fn factored_reward_function_two_nodes() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + + rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); + + let s00 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]}; + let s01 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]}; + let s02 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]}; + + + let s10 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]}; + let s11 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]}; + let s12 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]}; + + assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + + + assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + 
assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); + + + assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + + + assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); +} From 68ef7ea7c3ad4f33849f1cdf84349939e2e4a6b7 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 09:30:59 +0100 Subject: [PATCH 085/126] Added comments --- reCTBN/src/lib.rs | 2 +- reCTBN/src/reward_function.rs | 120 ++++++++++++++++++++++++++-------- reCTBN/src/sampling.rs | 9 +-- 3 files changed, 98 insertions(+), 33 deletions(-) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 1d25552..8feddfb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -6,7 +6,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; pub mod process; +pub mod reward_function; pub mod sampling; pub mod structure_learning; pub mod tools; -pub mod reward_function; diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs index 9ff09cc..eeddd85 100644 --- a/reCTBN/src/reward_function.rs +++ b/reCTBN/src/reward_function.rs @@ -1,22 +1,62 @@ -use crate::{process, sampling, params::{ParamsTrait, self}}; +//! Module for dealing with reward functions + +use crate::{ + params::{self, ParamsTrait}, + process, sampling, +}; use ndarray; +/// Instantiation of reward function and instantaneous reward +/// +/// +/// # Arguments +/// +/// * `transition_reward`: reward obtained transitioning from one state to another +/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state #[derive(Debug, PartialEq)] pub struct Reward { pub transition_reward: f64, - pub instantaneous_reward: f64 + pub instantaneous_reward: f64, } +/// The trait RewardFunction describe the methods that all the reward functions must satisfy + pub trait RewardFunction { - fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward; + /// Given the current state and the previous state, it compute the reward. 
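+    /// The returned `Reward` bundles the `transition_reward` earned on entering `current_state` and the `instantaneous_reward` per unit of time spent in it.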
+ /// + /// # Arguments + /// + /// * `current_state`: the current state of the network represented as a `sampling::Sample` + /// * `previous_state`: an optional argument representing the previous state of the network + + fn call( + &self, + current_state: sampling::Sample, + previous_state: Option, + ) -> Reward; + + /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess + /// + /// # Arguments + /// + /// * `p`: any structure that implements the trait `process::NetworkProcess` fn initialize_from_network_process(p: &T) -> Self; } +/// Reward function over a factored state space +/// +/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node +/// of the underling `NetworkProcess` +/// +/// # Arguments +/// +/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition +/// reward of a node pub struct FactoredRewardFunction { transition_reward: Vec>, - instantaneous_reward: Vec> + instantaneous_reward: Vec>, } impl FactoredRewardFunction { @@ -35,36 +75,60 @@ impl FactoredRewardFunction { pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { &mut self.instantaneous_reward[node_idx] } - - } impl RewardFunction for FactoredRewardFunction { - - fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward { - let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { - let x = match x {params::StateType::Discrete(x) => x}; - self.instantaneous_reward[idx][*x] - }).sum(); + fn call( + &self, + current_state: sampling::Sample, + previous_state: Option, + ) -> Reward { + let instantaneous_reward: f64 = current_state + .state + .iter() + .enumerate() + .map(|(idx, x)| { + let x = match x { + params::StateType::Discrete(x) => x, + }; + self.instantaneous_reward[idx][*x] + }) + .sum(); if let Some(previous_state) = previous_state { - let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option { - let p = match p {params::StateType::Discrete(p) => p}; - let c = match c {params::StateType::Discrete(c) => c}; - if p != c { - Some(self.transition_reward[idx][[*p,*c]]) - } else { - None - } - }).unwrap_or(0.0); - Reward {transition_reward, instantaneous_reward} + let transition_reward = previous_state + .state + .iter() + .zip(current_state.state.iter()) + .enumerate() + .find_map(|(idx, (p, c))| -> Option { + let p = match p { + params::StateType::Discrete(p) => p, + }; + let c = match c { + params::StateType::Discrete(c) => c, + }; + if p != c { + Some(self.transition_reward[idx][[*p, *c]]) + } else { + None + } + }) + .unwrap_or(0.0); + Reward { + transition_reward, + instantaneous_reward, + } } else { - Reward { transition_reward: 0.0, instantaneous_reward} + Reward { + transition_reward: 0.0, + instantaneous_reward, + } } } fn initialize_from_network_process(p: &T) -> Self { let mut transition_reward: Vec> = vec![]; - let mut instantaneous_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; for i in p.get_node_indices() { //This works only for discrete nodes! 
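// For a discrete node, get_reserved_space_as_parent() returns its cardinality, so each node gets a |S| x |S| transition-reward matrix and a length-|S| instantaneous-reward vector.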
let size: usize = p.get_node(i).get_reserved_space_as_parent(); @@ -72,9 +136,9 @@ impl RewardFunction for FactoredRewardFunction { transition_reward.push(ndarray::Array2::zeros((size, size))); } - FactoredRewardFunction { transition_reward, instantaneous_reward } - + FactoredRewardFunction { + transition_reward, + instantaneous_reward, + } } - } - diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index d5a1dbe..a0a9fcb 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -10,15 +10,13 @@ use rand_chacha::ChaCha8Rng; #[derive(Clone)] pub struct Sample { pub t: f64, - pub state: Vec + pub state: Vec, } pub trait Sampler: Iterator { fn reset(&mut self); } - - pub struct ForwardSampler<'a, T> where T: NetworkProcess, @@ -104,7 +102,10 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some(Sample{t: ret_time, state: ret_state}) + Some(Sample { + t: ret_time, + state: ret_state, + }) } } From bcb64a161ad49204eea20142b7f803e06e72becb Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 10:02:21 +0100 Subject: [PATCH 086/126] Mini refactor. Introduced the type alias NetworkProcessState. --- reCTBN/src/process.rs | 6 ++++-- reCTBN/src/process/ctbn.rs | 10 +++++----- reCTBN/src/process/ctmp.rs | 12 ++++-------- reCTBN/src/reward_function.rs | 16 +++++++--------- reCTBN/src/sampling.rs | 8 ++++---- reCTBN/src/tools.rs | 2 +- reCTBN/tests/reward_function.rs | 24 ++++++++++++------------ 7 files changed, 37 insertions(+), 41 deletions(-) diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs index 2b70b59..dc297bc 100644 --- a/reCTBN/src/process.rs +++ b/reCTBN/src/process.rs @@ -16,6 +16,9 @@ pub enum NetworkError { NodeInsertionError(String), } +/// This type is used to represent a specific realization of a generic NetworkProcess +pub type NetworkProcessState = Vec; + /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). pub trait NetworkProcess { @@ -71,8 +74,7 @@ pub trait NetworkProcess { /// # Return /// /// * Index of the `node` relative to the network. - fn get_param_index_network(&self, node: usize, current_state: &Vec) - -> usize; + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; /// Compute the index that must be used to access the parameters of a `node`, given a specific /// configuration of the network and a generic `parent_set`. diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 0b6161c..162345e 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -8,7 +8,7 @@ use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, Stat use crate::process; use super::ctmp::CtmpProcess; -use super::NetworkProcess; +use super::{NetworkProcess, NetworkProcessState}; /// It represents both the structure and the parameters of a CTBN. 
/// @@ -86,7 +86,7 @@ impl CtbnNetwork { for idx_current_state in 0..state_space { let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); - let current_state_statetype: Vec = current_state + let current_state_statetype: NetworkProcessState = current_state .iter() .map(|x| StateType::Discrete(*x)) .collect(); @@ -98,7 +98,7 @@ impl CtbnNetwork { let mut next_state = current_state.clone(); next_state[idx_node] = next_node_state; - let next_state_statetype: Vec = + let next_state_statetype: NetworkProcessState = next_state.iter().map(|x| StateType::Discrete(*x)).collect(); let idx_next_state = self.get_param_index_from_custom_parent_set( &next_state_statetype, @@ -185,7 +185,7 @@ impl process::NetworkProcess for CtbnNetwork { &mut self.nodes[node_idx] } - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { self.adj_matrix .as_ref() .unwrap() @@ -204,7 +204,7 @@ impl process::NetworkProcess for CtbnNetwork { fn get_param_index_from_custom_parent_set( &self, - current_state: &Vec, + current_state: &NetworkProcessState, parent_set: &BTreeSet, ) -> usize { parent_set diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index 81509fa..41b8db6 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -5,7 +5,7 @@ use crate::{ process, }; -use super::NetworkProcess; +use super::{NetworkProcess, NetworkProcessState}; pub struct CtmpProcess { param: Option, @@ -68,11 +68,7 @@ impl NetworkProcess for CtmpProcess { } } - fn get_param_index_network( - &self, - node: usize, - current_state: &Vec, - ) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { if node == 0 { match current_state[0] { StateType::Discrete(x) => x, @@ -84,8 +80,8 @@ impl NetworkProcess for CtmpProcess { fn get_param_index_from_custom_parent_set( &self, - _current_state: &Vec, - _parent_set: &std::collections::BTreeSet, + _current_state: &NetworkProcessState, + _parent_set: &BTreeSet, ) -> usize { unimplemented!("CtmpProcess has only one node") } diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs index eeddd85..35e15c8 100644 --- a/reCTBN/src/reward_function.rs +++ b/reCTBN/src/reward_function.rs @@ -2,7 +2,7 @@ use crate::{ params::{self, ParamsTrait}, - process, sampling, + process, }; use ndarray; @@ -27,13 +27,13 @@ pub trait RewardFunction { /// /// # Arguments /// - /// * `current_state`: the current state of the network represented as a `sampling::Sample` + /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` /// * `previous_state`: an optional argument representing the previous state of the network fn call( &self, - current_state: sampling::Sample, - previous_state: Option, + current_state: process::NetworkProcessState, + previous_state: Option, ) -> Reward; /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess @@ -80,11 +80,10 @@ impl FactoredRewardFunction { impl RewardFunction for FactoredRewardFunction { fn call( &self, - current_state: sampling::Sample, - previous_state: Option, + current_state: process::NetworkProcessState, + previous_state: Option, ) -> Reward { let instantaneous_reward: f64 = current_state - .state .iter() .enumerate() .map(|(idx, x)| { @@ -96,9 +95,8 @@ impl RewardFunction for FactoredRewardFunction { .sum(); if let Some(previous_state) = 
previous_state { let transition_reward = previous_state - .state .iter() - .zip(current_state.state.iter()) + .zip(current_state.iter()) .enumerate() .find_map(|(idx, (p, c))| -> Option { let p = match p { diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index a0a9fcb..1384872 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,8 +1,8 @@ //! Module containing methods for the sampling. use crate::{ - params::{self, ParamsTrait}, - process::NetworkProcess, + params::ParamsTrait, + process::{NetworkProcess, NetworkProcessState}, }; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; @@ -10,7 +10,7 @@ use rand_chacha::ChaCha8Rng; #[derive(Clone)] pub struct Sample { pub t: f64, - pub state: Vec, + pub state: NetworkProcessState, } pub trait Sampler: Iterator { @@ -24,7 +24,7 @@ where net: &'a T, rng: ChaCha8Rng, current_time: f64, - current_state: Vec, + current_state: NetworkProcessState, next_transitions: Vec>, } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index e749d69..ecfeff9 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -69,7 +69,7 @@ pub fn trajectory_generator( let mut time: Vec = Vec::new(); //Configuration of the process variables at time t initialized with an uniform //distribution. - let mut events: Vec> = Vec::new(); + let mut events: Vec = Vec::new(); //Current Time and Current State let mut sample = sampler.next().unwrap(); diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index 0c7fd9b..dcc5e69 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -2,7 +2,7 @@ mod utils; use ndarray::*; use utils::generate_discrete_time_continous_node; -use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; +use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; #[test] @@ -16,8 +16,8 @@ fn simple_factored_reward_function_binary_node() { rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); - let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; - let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); @@ -41,9 +41,9 @@ fn simple_factored_reward_function_ternary_node() { rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); - let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; - let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; - let s2 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2)]}; + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); @@ -78,14 +78,14 @@ fn factored_reward_function_two_nodes() { 
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); - let s00 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]}; - let s01 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]}; - let s02 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]}; + let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; + let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; + let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; - let s10 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]}; - let s11 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]}; - let s12 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]}; + let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; + let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; + let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); From cac19b17565e06bb192298d7fc715c8c7701897b Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 25 Nov 2022 10:19:05 +0100 Subject: [PATCH 087/126] Aligned F-test to the new changes --- reCTBN/src/structure_learning/hypothesis_test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index aa37cfa..dd3bbf7 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -87,7 +87,7 @@ impl HypothesisTest for F { cache: &mut parameter_learning::Cache
<P>, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
From 6e90458418c8d3d6b9c9107faa8ec4ef77a238e6 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 28 Nov 2022 11:08:20 +0100 Subject: [PATCH 088/126] Laying grounds for CTPC --- .../constraint_based_algorithm.rs | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-)
diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 670c8ed..d931f78 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,5 +1,39 @@ //! Module containing constraint based algorithms like CTPC and Hiton. -//pub struct CTPC { -// -//} +use super::hypothesis_test::*; +use crate::structure_learning::StructureLearningAlgorithm; +use crate::{process, tools}; +use crate::parameter_learning::{Cache, ParameterLearning}; + +pub struct CTPC<P: ParameterLearning> { + Ftest: F, + Chi2test: ChiSquare, + cache: Cache<P>
, + +} + +impl<P: ParameterLearning> CTPC<P> { + pub fn new(Ftest: F, Chi2test: ChiSquare, cache: Cache<P>) -> CTPC<P> { + CTPC { + Chi2test, + Ftest, + cache, + } + } +} + +impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P>
{ + fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + where + T: process::NetworkProcess, + { + //Check the coherence between dataset and network + if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { + panic!("Dataset and Network must have the same number of variables.") + } + + //Make the network mutable. + let mut net = net; + net + } +} From 4fc5c1d4b56a0ecd643daa0c76080dc3354e88ce Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 28 Nov 2022 13:31:37 +0100 Subject: [PATCH 089/126] Refactored reward module --- reCTBN/src/lib.rs | 2 +- reCTBN/src/reward.rs | 41 ++++++++++++++++++++++ reCTBN/src/{ => reward}/reward_function.rs | 40 ++------------------- reCTBN/tests/reward_function.rs | 2 +- 4 files changed, 45 insertions(+), 40 deletions(-) create mode 100644 reCTBN/src/reward.rs rename reCTBN/src/{ => reward}/reward_function.rs (72%) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 8feddfb..1997fa6 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -6,7 +6,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; pub mod process; -pub mod reward_function; +pub mod reward; pub mod sampling; pub mod structure_learning; pub mod tools; diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs new file mode 100644 index 0000000..114ba03 --- /dev/null +++ b/reCTBN/src/reward.rs @@ -0,0 +1,41 @@ +pub mod reward_function; + +use crate::process; + +/// Instantiation of reward function and instantaneous reward +/// +/// +/// # Arguments +/// +/// * `transition_reward`: reward obtained transitioning from one state to another +/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64, +} + +/// The trait RewardFunction describe the methods that all the reward functions must satisfy + +pub trait RewardFunction { + /// Given the current state and the previous state, it compute the reward. 
+ /// + /// # Arguments + /// + /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` + /// * `previous_state`: an optional argument representing the previous state of the network + + fn call( + &self, + current_state: process::NetworkProcessState, + previous_state: Option, + ) -> Reward; + + /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess + /// + /// # Arguments + /// + /// * `p`: any structure that implements the trait `process::NetworkProcess` + fn initialize_from_network_process(p: &T) -> Self; +} diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward/reward_function.rs similarity index 72% rename from reCTBN/src/reward_function.rs rename to reCTBN/src/reward/reward_function.rs index 35e15c8..ae94ff1 100644 --- a/reCTBN/src/reward_function.rs +++ b/reCTBN/src/reward/reward_function.rs @@ -3,46 +3,10 @@ use crate::{ params::{self, ParamsTrait}, process, + reward::{Reward, RewardFunction}, }; -use ndarray; - -/// Instantiation of reward function and instantaneous reward -/// -/// -/// # Arguments -/// -/// * `transition_reward`: reward obtained transitioning from one state to another -/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state - -#[derive(Debug, PartialEq)] -pub struct Reward { - pub transition_reward: f64, - pub instantaneous_reward: f64, -} - -/// The trait RewardFunction describe the methods that all the reward functions must satisfy -pub trait RewardFunction { - /// Given the current state and the previous state, it compute the reward. - /// - /// # Arguments - /// - /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` - /// * `previous_state`: an optional argument representing the previous state of the network - - fn call( - &self, - current_state: process::NetworkProcessState, - previous_state: Option, - ) -> Reward; - - /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess - /// - /// # Arguments - /// - /// * `p`: any structure that implements the trait `process::NetworkProcess` - fn initialize_from_network_process(p: &T) -> Self; -} +use ndarray; /// Reward function over a factored state space /// diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index dcc5e69..03f2ab7 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -2,7 +2,7 @@ mod utils; use ndarray::*; use utils::generate_discrete_time_continous_node; -use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; +use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params}; #[test] From cecf16a771d6fd53ec7ad1b909c7310c9a8368d8 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 29 Nov 2022 09:43:12 +0100 Subject: [PATCH 090/126] Added sigle state evaluation --- reCTBN/src/reward.rs | 20 ++++- reCTBN/src/reward/reward_evaluation.rs | 71 ++++++++++++++++ reCTBN/src/reward/reward_function.rs | 4 +- reCTBN/src/sampling.rs | 27 +++++-- reCTBN/src/tools.rs | 3 +- reCTBN/tests/reward_evaluation.rs | 107 +++++++++++++++++++++++++ reCTBN/tests/reward_function.rs | 61 +++++++------- 7 files changed, 248 insertions(+), 45 deletions(-) create mode 100644 reCTBN/src/reward/reward_evaluation.rs create mode 100644 reCTBN/tests/reward_evaluation.rs diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs index 114ba03..1ea575c 100644 --- 
a/reCTBN/src/reward.rs +++ b/reCTBN/src/reward.rs @@ -1,6 +1,8 @@ pub mod reward_function; +pub mod reward_evaluation; use crate::process; +use ndarray; /// Instantiation of reward function and instantaneous reward /// @@ -28,8 +30,8 @@ pub trait RewardFunction { fn call( &self, - current_state: process::NetworkProcessState, - previous_state: Option, + current_state: &process::NetworkProcessState, + previous_state: Option<&process::NetworkProcessState>, ) -> Reward; /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess @@ -39,3 +41,17 @@ pub trait RewardFunction { /// * `p`: any structure that implements the trait `process::NetworkProcess` fn initialize_from_network_process(p: &T) -> Self; } + +pub trait RewardEvaluation { + fn call( + &self, + network_process: &N, + reward_function: &R, + ) -> ndarray::Array1; + fn call_state( + &self, + network_process: &N, + reward_function: &R, + state: &process::NetworkProcessState, + ) -> f64; +} diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs new file mode 100644 index 0000000..fca7c1a --- /dev/null +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -0,0 +1,71 @@ +use crate::{ + reward::RewardEvaluation, + sampling::{ForwardSampler, Sampler}, + process::NetworkProcessState +}; + +pub struct MonteCarloDiscountedRward { + n_iterations: usize, + end_time: f64, + discount_factor: f64, + seed: Option, +} + +impl MonteCarloDiscountedRward { + pub fn new( + n_iterations: usize, + end_time: f64, + discount_factor: f64, + seed: Option, + ) -> MonteCarloDiscountedRward { + MonteCarloDiscountedRward { + n_iterations, + end_time, + discount_factor, + seed, + } + } +} + +impl RewardEvaluation for MonteCarloDiscountedRward { + fn call( + &self, + network_process: &N, + reward_function: &R, + ) -> ndarray::Array1 { + todo!() + } + + fn call_state( + &self, + network_process: &N, + reward_function: &R, + state: &NetworkProcessState, + ) -> f64 { + let mut sampler = ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); + let mut ret = 0.0; + + for _i in 0..self.n_iterations { + sampler.reset(); + let mut previous = sampler.next().unwrap(); + while previous.t < self.end_time { + let current = sampler.next().unwrap(); + if current.t > self.end_time { + let r = reward_function.call(&previous.state, None); + let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) + - std::f64::consts::E.powf(-self.discount_factor * self.end_time); + ret += discount * r.instantaneous_reward; + } else { + let r = reward_function.call(&previous.state, Some(¤t.state)); + let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) + - std::f64::consts::E.powf(-self.discount_factor * current.t); + ret += discount * r.instantaneous_reward; + ret += std::f64::consts::E.powf(-self.discount_factor * current.t) * r.transition_reward; + } + previous = current; + } + } + + ret / self.n_iterations as f64 + } +} diff --git a/reCTBN/src/reward/reward_function.rs b/reCTBN/src/reward/reward_function.rs index ae94ff1..216df6a 100644 --- a/reCTBN/src/reward/reward_function.rs +++ b/reCTBN/src/reward/reward_function.rs @@ -44,8 +44,8 @@ impl FactoredRewardFunction { impl RewardFunction for FactoredRewardFunction { fn call( &self, - current_state: process::NetworkProcessState, - previous_state: Option, + current_state: &process::NetworkProcessState, + previous_state: Option<&process::NetworkProcessState>, ) -> Reward { let instantaneous_reward: f64 = current_state 
.iter() diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 1384872..73c6d78 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -26,10 +26,15 @@ where current_time: f64, current_state: NetworkProcessState, next_transitions: Vec>, + initial_state: Option, } impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { - pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { + pub fn new( + net: &'a T, + seed: Option, + initial_state: Option, + ) -> ForwardSampler<'a, T> { let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. Some(seed) => SeedableRng::seed_from_u64(seed), @@ -37,11 +42,12 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { None => SeedableRng::from_entropy(), }; let mut fs = ForwardSampler { - net: net, - rng: rng, + net, + rng, current_time: 0.0, current_state: vec![], next_transitions: vec![], + initial_state, }; fs.reset(); return fs; @@ -112,11 +118,16 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { fn reset(&mut self) { self.current_time = 0.0; - self.current_state = self - .net - .get_node_indices() - .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) - .collect(); + match &self.initial_state { + None => { + self.current_state = self + .net + .get_node_indices() + .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) + .collect() + } + Some(is) => self.current_state = is.clone(), + }; self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect(); } } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index ecfeff9..38ebd49 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -61,8 +61,7 @@ pub fn trajectory_generator( let mut trajectories: Vec = Vec::new(); //Random Generator object - - let mut sampler = ForwardSampler::new(net, seed); + let mut sampler = ForwardSampler::new(net, seed, None); //Each iteration generate one trajectory for _ in 0..n_trajectories { //History of all the moments in which something changed diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs new file mode 100644 index 0000000..f3938e7 --- /dev/null +++ b/reCTBN/tests/reward_evaluation.rs @@ -0,0 +1,107 @@ +mod utils; + +use approx::{abs_diff_eq, assert_abs_diff_eq}; +use ndarray::*; +use reCTBN::{ + params, + process::{ctbn::*, NetworkProcess, NetworkProcessState}, + reward::{reward_evaluation::*, reward_function::*, *}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn simple_factored_reward_function_binary_node_MC() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1) + .assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1) + .assign(&arr1(&[3.0, 3.0])); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); + } + } + + net.initialize_adj_matrix(); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + + let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215)); + assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s0), epsilon = 1e-2); + assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, 
&s1), epsilon = 1e-2); +} + +#[test] +fn simple_factored_reward_function_chain_MC() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + param + .set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]], + ])) + .unwrap(); + } + } + + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + param + .set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]], + ])) + .unwrap(); + } + } + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + rf.get_transition_reward_mut(n2) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + rf.get_transition_reward_mut(n3) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + let s000: NetworkProcessState = vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(0), + params::StateType::Discrete(0), + ]; + + let mc = MonteCarloDiscountedRward::new(10000, 100.0, 1.0, Some(215)); + assert_abs_diff_eq!(2.447, mc.call_state(&net, &rf, &s000), epsilon = 1e-1); +} diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index 03f2ab7..853efc9 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -18,15 +18,15 @@ fn simple_factored_reward_function_binary_node() { let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; - assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); } @@ -46,16 +46,16 @@ fn simple_factored_reward_function_ternary_node() { let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; - assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, 
instantaneous_reward: 3.0}); - assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); - assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); } #[test] @@ -77,7 +77,6 @@ fn factored_reward_function_two_nodes() { rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); - let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; @@ -87,32 +86,32 @@ fn factored_reward_function_two_nodes() { let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; - assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, 
instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); - assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); - assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); } From bb239aaa0c9b873ac6172a5b92fa85e159b98e16 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 1 Dec 2022 08:15:30 +0100 Subject: [PATCH 091/126] Implemented reward_evaluation for an entire process. 
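The joint state space is enumerated by decoding a flat index into one state per variable, mixed-radix style: the remainder with respect to the first variable's cardinality selects its state, and the quotient is then decoded against the next cardinality, and so on. A minimal standalone sketch of that decoding (`decode_state` is an illustrative helper, not part of the crate):

fn decode_state(mut idx: usize, cardinalities: &[usize]) -> Vec<usize> {
    // One digit per variable: the remainder selects the state,
    // the quotient is carried over to the next variable.
    cardinalities
        .iter()
        .map(|&card| {
            let state = idx % card;
            idx /= card;
            state
        })
        .collect()
}

// decode_state(5, &[2, 3]) == vec![1, 2]; iterating idx over 0..6 visits all six joint states.

This is the same computation performed by the fold over `variables_domain` in `evaluate_state_space` below.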
--- reCTBN/src/params.rs | 2 +- reCTBN/src/reward.rs | 9 ++++-- reCTBN/src/reward/reward_evaluation.rs | 41 +++++++++++++++++++++----- reCTBN/tests/reward_evaluation.rs | 17 ++++++++--- 4 files changed, 54 insertions(+), 15 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 9f63860..3d08273 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -20,7 +20,7 @@ pub enum ParamsError { } /// Allowed type of states -#[derive(Clone)] +#[derive(Clone, Hash, PartialEq, Eq, Debug)] pub enum StateType { Discrete(usize), } diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs index 1ea575c..b34db7f 100644 --- a/reCTBN/src/reward.rs +++ b/reCTBN/src/reward.rs @@ -1,6 +1,8 @@ pub mod reward_function; pub mod reward_evaluation; +use std::collections::HashMap; + use crate::process; use ndarray; @@ -43,12 +45,13 @@ pub trait RewardFunction { } pub trait RewardEvaluation { - fn call( + fn evaluate_state_space( &self, network_process: &N, reward_function: &R, - ) -> ndarray::Array1; - fn call_state( + ) -> HashMap; + + fn evaluate_state( &self, network_process: &N, reward_function: &R, diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs index fca7c1a..67baa5e 100644 --- a/reCTBN/src/reward/reward_evaluation.rs +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -1,7 +1,12 @@ +use std::collections::HashMap; + +use crate::params::{self, ParamsTrait}; +use crate::process; + use crate::{ + process::NetworkProcessState, reward::RewardEvaluation, sampling::{ForwardSampler, Sampler}, - process::NetworkProcessState }; pub struct MonteCarloDiscountedRward { @@ -28,21 +33,42 @@ impl MonteCarloDiscountedRward { } impl RewardEvaluation for MonteCarloDiscountedRward { - fn call( + fn evaluate_state_space( &self, network_process: &N, reward_function: &R, - ) -> ndarray::Array1 { - todo!() + ) -> HashMap { + let variables_domain: Vec> = network_process + .get_node_indices() + .map(|x| match network_process.get_node(x) { + params::Params::DiscreteStatesContinousTime(x) => + (0..x.get_reserved_space_as_parent()).map(|s| params::StateType::Discrete(s)).collect() + }).collect(); + + let n_states:usize = variables_domain.iter().map(|x| x.len()).product(); + + (0..n_states).map(|s| { + let state: process::NetworkProcessState = variables_domain.iter().fold((s, vec![]), |acc, x| { + let mut acc = acc; + let idx_s = acc.0%x.len(); + acc.1.push(x[idx_s].clone()); + acc.0 = acc.0 / x.len(); + acc + }).1; + + let r = self.evaluate_state(network_process, reward_function, &state); + (state, r) + }).collect() } - fn call_state( + fn evaluate_state( &self, network_process: &N, reward_function: &R, state: &NetworkProcessState, ) -> f64 { - let mut sampler = ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); + let mut sampler = + ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); let mut ret = 0.0; for _i in 0..self.n_iterations { @@ -60,7 +86,8 @@ impl RewardEvaluation for MonteCarloDiscountedRward { let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) - std::f64::consts::E.powf(-self.discount_factor * current.t); ret += discount * r.instantaneous_reward; - ret += std::f64::consts::E.powf(-self.discount_factor * current.t) * r.transition_reward; + ret += std::f64::consts::E.powf(-self.discount_factor * current.t) + * r.transition_reward; } previous = current; } diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs index f3938e7..1650507 100644 --- 
a/reCTBN/tests/reward_evaluation.rs +++ b/reCTBN/tests/reward_evaluation.rs @@ -34,8 +34,13 @@ fn simple_factored_reward_function_binary_node_MC() { let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215)); - assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s0), epsilon = 1e-2); - assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s1), epsilon = 1e-2); + assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); + assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); + + let rst = mc.evaluate_state_space(&net, &rf); + assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); + assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); + } #[test] @@ -102,6 +107,10 @@ fn simple_factored_reward_function_chain_MC() { params::StateType::Discrete(0), ]; - let mc = MonteCarloDiscountedRward::new(10000, 100.0, 1.0, Some(215)); - assert_abs_diff_eq!(2.447, mc.call_state(&net, &rf, &s000), epsilon = 1e-1); + let mc = MonteCarloDiscountedRward::new(1000, 10.0, 1.0, Some(215)); + assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); + + let rst = mc.evaluate_state_space(&net, &rf); + assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1); + } From 687f19ff1f7a3cec6abe2c8bd62429aab73ce284 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 1 Dec 2022 08:45:10 +0100 Subject: [PATCH 092/126] Added FiniteHorizon --- reCTBN/src/reward/reward_evaluation.rs | 87 +++++++++++++++++--------- reCTBN/tests/reward_evaluation.rs | 10 ++- 2 files changed, 65 insertions(+), 32 deletions(-) diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs index 67baa5e..9dcb6d2 100644 --- a/reCTBN/src/reward/reward_evaluation.rs +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -9,30 +9,35 @@ use crate::{ sampling::{ForwardSampler, Sampler}, }; -pub struct MonteCarloDiscountedRward { +pub enum RewardCriteria { + FiniteHorizon, + InfiniteHorizon {discount_factor: f64}, +} + +pub struct MonteCarloRward { n_iterations: usize, end_time: f64, - discount_factor: f64, + reward_criteria: RewardCriteria, seed: Option, } -impl MonteCarloDiscountedRward { +impl MonteCarloRward { pub fn new( n_iterations: usize, end_time: f64, - discount_factor: f64, + reward_criteria: RewardCriteria, seed: Option, - ) -> MonteCarloDiscountedRward { - MonteCarloDiscountedRward { + ) -> MonteCarloRward { + MonteCarloRward { n_iterations, end_time, - discount_factor, + reward_criteria, seed, } } } -impl RewardEvaluation for MonteCarloDiscountedRward { +impl RewardEvaluation for MonteCarloRward { fn evaluate_state_space( &self, network_process: &N, @@ -41,24 +46,32 @@ impl RewardEvaluation for MonteCarloDiscountedRward { let variables_domain: Vec> = network_process .get_node_indices() .map(|x| match network_process.get_node(x) { - params::Params::DiscreteStatesContinousTime(x) => - (0..x.get_reserved_space_as_parent()).map(|s| params::StateType::Discrete(s)).collect() - }).collect(); + params::Params::DiscreteStatesContinousTime(x) => (0..x + .get_reserved_space_as_parent()) + .map(|s| params::StateType::Discrete(s)) + .collect(), + }) + .collect(); + + let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); - let n_states:usize = variables_domain.iter().map(|x| x.len()).product(); - - (0..n_states).map(|s| { - let state: process::NetworkProcessState = variables_domain.iter().fold((s, vec![]), |acc, x| { - let mut acc = acc; - let idx_s = acc.0%x.len(); - 
acc.1.push(x[idx_s].clone()); - acc.0 = acc.0 / x.len(); - acc - }).1; + (0..n_states) + .map(|s| { + let state: process::NetworkProcessState = variables_domain + .iter() + .fold((s, vec![]), |acc, x| { + let mut acc = acc; + let idx_s = acc.0 % x.len(); + acc.1.push(x[idx_s].clone()); + acc.0 = acc.0 / x.len(); + acc + }) + .1; - let r = self.evaluate_state(network_process, reward_function, &state); - (state, r) - }).collect() + let r = self.evaluate_state(network_process, reward_function, &state); + (state, r) + }) + .collect() } fn evaluate_state( @@ -78,16 +91,30 @@ impl RewardEvaluation for MonteCarloDiscountedRward { let current = sampler.next().unwrap(); if current.t > self.end_time { let r = reward_function.call(&previous.state, None); - let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) - - std::f64::consts::E.powf(-self.discount_factor * self.end_time); + let discount = match self.reward_criteria { + RewardCriteria::FiniteHorizon => self.end_time - previous.t, + RewardCriteria::InfiniteHorizon {discount_factor} => { + std::f64::consts::E.powf(-discount_factor * previous.t) + - std::f64::consts::E.powf(-discount_factor * self.end_time) + } + }; ret += discount * r.instantaneous_reward; } else { let r = reward_function.call(&previous.state, Some(¤t.state)); - let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) - - std::f64::consts::E.powf(-self.discount_factor * current.t); + let discount = match self.reward_criteria { + RewardCriteria::FiniteHorizon => current.t-previous.t, + RewardCriteria::InfiniteHorizon {discount_factor} => { + std::f64::consts::E.powf(-discount_factor * previous.t) + - std::f64::consts::E.powf(-discount_factor * current.t) + } + }; ret += discount * r.instantaneous_reward; - ret += std::f64::consts::E.powf(-self.discount_factor * current.t) - * r.transition_reward; + ret += match self.reward_criteria { + RewardCriteria::FiniteHorizon => 1.0, + RewardCriteria::InfiniteHorizon {discount_factor} => { + std::f64::consts::E.powf(-discount_factor * current.t) + } + } * r.transition_reward; } previous = current; } diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs index 1650507..b2cfd29 100644 --- a/reCTBN/tests/reward_evaluation.rs +++ b/reCTBN/tests/reward_evaluation.rs @@ -33,7 +33,7 @@ fn simple_factored_reward_function_binary_node_MC() { let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; - let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215)); + let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); @@ -41,6 +41,12 @@ fn simple_factored_reward_function_binary_node_MC() { assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); + + let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::FiniteHorizon, Some(215)); + assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); + assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); + + } #[test] @@ -107,7 +113,7 @@ fn simple_factored_reward_function_chain_MC() { params::StateType::Discrete(0), ]; - let mc = MonteCarloDiscountedRward::new(1000, 10.0, 1.0, Some(215)); + let mc = MonteCarloRward::new(1000, 10.0, RewardCriteria::InfiniteHorizon { 
discount_factor: 1.0 }, Some(215)); assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); let rst = mc.evaluate_state_space(&net, &rf); From 414aa3186711b36613bd8980c6860435fac12257 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 5 Dec 2022 09:21:37 +0100 Subject: [PATCH 093/126] Bugfix --- reCTBN/src/parameter_learning.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 2aa518c..3f505f9 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -144,6 +144,12 @@ impl ParameterLearning for BayesianApproach { .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau)))); + + CIM.outer_iter_mut() + .for_each(|mut C| { + C.diag_mut().fill(0.0); + }); + //Set the diagonal of the inner matrices to the the row sum multiplied by -1 let tmp_diag_sum: Array2 = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); CIM.outer_iter_mut() From 9284ca5dd2facaa30a7439ca9e9f00d2778479a4 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 5 Dec 2022 15:24:32 +0100 Subject: [PATCH 094/126] Implemanted NeighborhoodRelativeReward --- reCTBN/src/reward/reward_evaluation.rs | 54 +++++++++++++++++++++++--- reCTBN/tests/reward_evaluation.rs | 6 +-- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs index 9dcb6d2..cb7b8f1 100644 --- a/reCTBN/src/reward/reward_evaluation.rs +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -14,21 +14,21 @@ pub enum RewardCriteria { InfiniteHorizon {discount_factor: f64}, } -pub struct MonteCarloRward { +pub struct MonteCarloReward { n_iterations: usize, end_time: f64, reward_criteria: RewardCriteria, seed: Option, } -impl MonteCarloRward { +impl MonteCarloReward { pub fn new( n_iterations: usize, end_time: f64, reward_criteria: RewardCriteria, seed: Option, - ) -> MonteCarloRward { - MonteCarloRward { + ) -> MonteCarloReward { + MonteCarloReward { n_iterations, end_time, reward_criteria, @@ -37,7 +37,7 @@ impl MonteCarloRward { } } -impl RewardEvaluation for MonteCarloRward { +impl RewardEvaluation for MonteCarloReward { fn evaluate_state_space( &self, network_process: &N, @@ -123,3 +123,47 @@ impl RewardEvaluation for MonteCarloRward { ret / self.n_iterations as f64 } } + +pub struct NeighborhoodRelativeReward { + inner_reward: RE +} + +impl NeighborhoodRelativeReward{ + pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward{ + NeighborhoodRelativeReward {inner_reward} + } +} + +impl RewardEvaluation for NeighborhoodRelativeReward { + fn evaluate_state_space( + &self, + network_process: &N, + reward_function: &R, + ) -> HashMap { + + let absolute_reward = self.inner_reward.evaluate_state_space(network_process, reward_function); + + //This approach optimize memory. Maybe optimizing execution time can be better. 
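+        // For each state k1, keep the largest ratio v1/v2 against any state k2 differing from k1 in at most one variable (count_diff < 2); max_val starts at 1.0 so the result is never below 1.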
+ absolute_reward.iter().map(|(k1, v1)| { + let mut max_val:f64 = 1.0; + absolute_reward.iter().for_each(|(k2,v2)| { + let count_diff:usize = k1.iter().zip(k2.iter()).map(|(s1, s2)| if s1 == s2 {0} else {1}).sum(); + if count_diff < 2 { + max_val = max_val.max(v1/v2); + } + + }); + (k1.clone(), max_val) + }).collect() + + } + + fn evaluate_state( + &self, + _network_process: &N, + _reward_function: &R, + _state: &process::NetworkProcessState, + ) -> f64 { + unimplemented!(); + } +} diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs index b2cfd29..63e9c98 100644 --- a/reCTBN/tests/reward_evaluation.rs +++ b/reCTBN/tests/reward_evaluation.rs @@ -33,7 +33,7 @@ fn simple_factored_reward_function_binary_node_MC() { let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; - let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); + let mc = MonteCarloReward::new(100, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); @@ -42,7 +42,7 @@ fn simple_factored_reward_function_binary_node_MC() { assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); - let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::FiniteHorizon, Some(215)); + let mc = MonteCarloReward::new(100, 10.0, RewardCriteria::FiniteHorizon, Some(215)); assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); @@ -113,7 +113,7 @@ fn simple_factored_reward_function_chain_MC() { params::StateType::Discrete(0), ]; - let mc = MonteCarloRward::new(1000, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); + let mc = MonteCarloReward::new(1000, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); let rst = mc.evaluate_state_space(&net, &rf); From 6d952f8c0741faf827ac725e328f1b6e8b4d5b8a Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 08:52:40 +0100 Subject: [PATCH 095/126] Added `itertools` a WIP version of CTPC and some hacky and temporary modifications --- reCTBN/Cargo.toml | 1 + reCTBN/src/structure_learning.rs | 2 +- .../constraint_based_algorithm.rs | 44 +++++++++++++++++-- .../score_based_algorithm.rs | 2 +- reCTBN/tests/structure_learning.rs | 8 ++-- 5 files changed, 48 insertions(+), 9 deletions(-) diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml index b0a691b..fdac697 100644 --- a/reCTBN/Cargo.toml +++ b/reCTBN/Cargo.toml @@ -13,6 +13,7 @@ bimap = "~0.6" enum_dispatch = "~0.3" statrs = "~0.16" rand_chacha = "~0.3" +itertools = "~0.10" [dev-dependencies] approx = { package = "approx", version = "~0.5" } diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index b272e22..d119ab2 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -7,7 +7,7 @@ pub mod score_function; use crate::{process, tools}; pub trait StructureLearningAlgorithm { - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T where T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs 
b/reCTBN/src/structure_learning/constraint_based_algorithm.rs
index d931f78..6fd5b79 100644
--- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs
+++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs
@@ -1,15 +1,18 @@
 //! Module containing constraint based algorithms like CTPC and Hiton.
 
+use itertools::Itertools;
+use std::collections::BTreeSet;
+use std::usize;
+
 use super::hypothesis_test::*;
+use crate::parameter_learning::{Cache, ParameterLearning};
 use crate::structure_learning::StructureLearningAlgorithm;
 use crate::{process, tools};
-use crate::parameter_learning::{Cache, ParameterLearning};
 
 pub struct CTPC<P: ParameterLearning> {
     Ftest: F,
     Chi2test: ChiSquare,
     cache: Cache<P>,
-
 }
 
 impl<P: ParameterLearning> CTPC<P> {
@@ -23,7 +26,7 @@ impl<P: ParameterLearning> CTPC<P> {
 }
 
 impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
-    fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
+    fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
     where
         T: process::NetworkProcess,
     {
@@ -34,6 +37,41 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P>
{ //Make the network mutable. let mut net = net; + + net.initialize_adj_matrix(); + + for child_node in net.get_node_indices() { + let mut candidate_parent_set: BTreeSet = net + .get_node_indices() + .into_iter() + .filter(|x| x != &child_node) + .collect(); + let mut b = 0; + while b < candidate_parent_set.len() { + for parent_node in candidate_parent_set.iter() { + for separation_set in candidate_parent_set + .iter() + .filter(|x| x != &parent_node) + .map(|x| *x) + .combinations(b) + { + let separation_set = separation_set.into_iter().collect(); + if self.Ftest.call( + &net, + child_node, + *parent_node, + &separation_set, + &mut self.cache, + ) && self.Chi2test.call(&net, child_node, *parent_node, &separation_set, &mut self.cache) { + candidate_parent_set.remove(&parent_node); + break; + } + } + } + b = b + 1; + } + } + net } } diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index 16e9056..d59f0c1 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -21,7 +21,7 @@ impl HillClimbing { } impl StructureLearningAlgorithm for HillClimbing { - fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T where T: process::NetworkProcess, { diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 6134510..4bf9027 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -58,7 +58,7 @@ fn simple_bic() { ); } -fn check_compatibility_between_dataset_and_network(sl: T) { +fn check_compatibility_between_dataset_and_network(mut sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -125,7 +125,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } -fn learn_ternary_net_2_nodes(sl: T) { +fn learn_ternary_net_2_nodes(mut sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -320,7 +320,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } -fn learn_mixed_discrete_net_3_nodes(sl: T) { +fn learn_mixed_discrete_net_3_nodes(mut sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -342,7 +342,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } -fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(mut sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); From 6d42d8a805c493aad5dcbdd63a4fad4bb638646c Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 12:54:08 +0100 Subject: [PATCH 096/126] Solved issue with `candidate_parent_set` variable in CTPC and added loop to fill the adjacency matrix --- .../constraint_based_algorithm.rs | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 6fd5b79..d94c793 100644 --- 
a/reCTBN/src/structure_learning/constraint_based_algorithm.rs
+++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs
@@ -48,6 +48,7 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
             .collect();
         let mut b = 0;
         while b < candidate_parent_set.len() {
+            let mut not_parent_node: usize = child_node;
             for parent_node in candidate_parent_set.iter() {
                 for separation_set in candidate_parent_set
                     .iter()
@@ -62,16 +63,30 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P>
{ *parent_node, &separation_set, &mut self.cache, - ) && self.Chi2test.call(&net, child_node, *parent_node, &separation_set, &mut self.cache) { - candidate_parent_set.remove(&parent_node); + ) && self.Chi2test.call( + &net, + child_node, + *parent_node, + &separation_set, + &mut self.cache, + ) { + not_parent_node = parent_node.clone(); break; } } + if not_parent_node != child_node { + break; + } + } + if not_parent_node != child_node { + candidate_parent_set.remove(¬_parent_node); } b = b + 1; } + for parent_node in candidate_parent_set.iter() { + net.add_edge(*parent_node, child_node); + } } - net } } From 8d0f9db289b8453bed413d2ba70ec1693b0c376d Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 17:21:53 +0100 Subject: [PATCH 097/126] WIP: Added tests for CTPC --- reCTBN/tests/structure_learning.rs | 79 ++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 4bf9027..c0deffd 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -10,6 +10,7 @@ use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::Cache; use reCTBN::params; use reCTBN::structure_learning::hypothesis_test::*; +use reCTBN::structure_learning::constraint_based_algorithm::*; use reCTBN::structure_learning::score_based_algorithm::*; use reCTBN::structure_learning::score_function::*; use reCTBN::structure_learning::StructureLearningAlgorithm; @@ -497,3 +498,81 @@ pub fn f_call() { separation_set.insert(N1); assert!(f.call(&net, N2, N3, &separation_set, &mut cache)); } + +#[test] +pub fn learn_ternary_net_2_nodes_ctpc() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let f = F::new(0.000001); + let chi_sq = ChiSquare::new(0.0001); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let cache = Cache::new(parameter_learning, data.clone()); + let mut ctpc = CTPC::new(f, chi_sq, cache); + + + let net = ctpc.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); + assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); +} + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc() { + let (_, data) = get_mixed_discrete_net_3_nodes_with_data(); + + let f = F::new(1e-24); + let chi_sq = ChiSquare::new(1e-24); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let cache = Cache::new(parameter_learning, data); + let ctpc = CTPC::new(f, chi_sq, cache); + + learn_mixed_discrete_net_3_nodes(ctpc); +} From 468ebf09cc330c2448b90ed27e090e1613336045 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 19 Dec 2022 17:24:23 
+0100
Subject: [PATCH 098/126] WIP: Added `#[derive(Clone)]` to `Dataset` and
 `Trajectory`

---
 reCTBN/src/tools.rs | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs
index ecfeff9..47a067d 100644
--- a/reCTBN/src/tools.rs
+++ b/reCTBN/src/tools.rs
@@ -5,6 +5,7 @@ use ndarray::prelude::*;
 use crate::sampling::{ForwardSampler, Sampler};
 use crate::{params, process};
 
+#[derive(Clone)]
 pub struct Trajectory {
     time: Array1<f64>,
     events: Array2<usize>,
@@ -29,6 +30,7 @@ impl Trajectory {
     }
 }
 
+#[derive(Clone)]
 pub struct Dataset {
     trajectories: Vec<Trajectory>,
 }

From ea5df7cad6485742905f2a0297a95cfc6cf2f801 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Tue, 20 Dec 2022 12:36:28 +0100
Subject: [PATCH 099/126] Solved another issue with `candidate_parent_set`
 variable in CTPC

---
 .../constraint_based_algorithm.rs | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs
index d94c793..8949aa5 100644
--- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs
+++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs
@@ -46,15 +46,15 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
                 .into_iter()
                 .filter(|x| x != &child_node)
                 .collect();
-            let mut b = 0;
-            while b < candidate_parent_set.len() {
-                let mut not_parent_node: usize = child_node;
+            let mut separation_set_size = 0;
+            while separation_set_size < candidate_parent_set.len() {
+                let mut candidate_parent_set_TMP = candidate_parent_set.clone();
                 for parent_node in candidate_parent_set.iter() {
                     for separation_set in candidate_parent_set
                         .iter()
                         .filter(|x| x != &parent_node)
                         .map(|x| *x)
-                        .combinations(b)
+                        .combinations(separation_set_size)
                     {
                         let separation_set = separation_set.into_iter().collect();
                         if self.Ftest.call(
@@ -70,18 +70,13 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P>
{ &separation_set, &mut self.cache, ) { - not_parent_node = parent_node.clone(); + candidate_parent_set_TMP.remove(parent_node); break; } } - if not_parent_node != child_node { - break; - } - } - if not_parent_node != child_node { - candidate_parent_set.remove(¬_parent_node); } - b = b + 1; + candidate_parent_set = candidate_parent_set_TMP; + separation_set_size += 1; } for parent_node in candidate_parent_set.iter() { net.add_edge(*parent_node, child_node); From 19856195c39e5428e906fbb0b7ab0ecb8e9e6394 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 21 Dec 2022 11:41:26 +0100 Subject: [PATCH 100/126] Refactored cache laying grounds for its node-centered implementation changing also its signature, propagated this change and refactored CTPC tests --- reCTBN/src/parameter_learning.rs | 28 ++--- reCTBN/src/structure_learning.rs | 4 +- .../constraint_based_algorithm.rs | 20 ++-- .../src/structure_learning/hypothesis_test.rs | 23 +++- .../score_based_algorithm.rs | 4 +- reCTBN/tests/structure_learning.rs | 112 ++++-------------- 6 files changed, 72 insertions(+), 119 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 2aa518c..f8a7664 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -5,13 +5,13 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use crate::params::*; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub trait ParameterLearning { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params; @@ -19,7 +19,7 @@ pub trait ParameterLearning { pub fn sufficient_statistics( net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: &BTreeSet, ) -> (Array3, Array2) { @@ -76,7 +76,7 @@ impl ParameterLearning for MLE { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { @@ -123,7 +123,7 @@ impl ParameterLearning for BayesianApproach { fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { @@ -165,25 +165,21 @@ impl ParameterLearning for BayesianApproach { } } -pub struct Cache { - parameter_learning: P, - dataset: tools::Dataset, +pub struct Cache<'a, P: ParameterLearning> { + parameter_learning: &'a P, } -impl Cache
<P> {
-    pub fn new(parameter_learning: P, dataset: tools::Dataset) -> Cache<P>
{ - Cache { - parameter_learning, - dataset, - } +impl<'a, P: ParameterLearning> Cache<'a, P> { + pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { + Cache { parameter_learning } } pub fn fit( &mut self, net: &T, + dataset: &Dataset, node: usize, parent_set: Option>, ) -> Params { - self.parameter_learning - .fit(net, &self.dataset, node, parent_set) + self.parameter_learning.fit(net, dataset, node, parent_set) } } diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index d119ab2..a4c6ea1 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -4,10 +4,10 @@ pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; pub mod score_function; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub trait StructureLearningAlgorithm { - fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &Dataset) -> T where T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 8949aa5..6d54fe7 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -6,27 +6,28 @@ use std::usize; use super::hypothesis_test::*; use crate::parameter_learning::{Cache, ParameterLearning}; +use crate::process; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{process, tools}; +use crate::tools::Dataset; pub struct CTPC { + parameter_learning: P, Ftest: F, Chi2test: ChiSquare, - cache: Cache
<P>,
 }
 
 impl<P: ParameterLearning> CTPC<P> {
-    pub fn new(Ftest: F, Chi2test: ChiSquare, cache: Cache<P>) -> CTPC<P> {
+    pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> {
         CTPC {
-            Chi2test,
+            parameter_learning,
             Ftest,
-            cache,
+            Chi2test,
         }
     }
 }
 
 impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
-    fn fit_transform<T>(&mut self, net: T, dataset: &tools::Dataset) -> T
+    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
     where
         T: process::NetworkProcess,
     {
@@ -41,6 +42,7 @@ impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P>
{ net.initialize_adj_matrix(); for child_node in net.get_node_indices() { + let mut cache = Cache::new(&self.parameter_learning); let mut candidate_parent_set: BTreeSet = net .get_node_indices() .into_iter() @@ -62,13 +64,15 @@ impl StructureLearningAlgorithm for CTPC
<P>
{ child_node, *parent_node, &separation_set, - &mut self.cache, + dataset, + &mut cache, ) && self.Chi2test.call( &net, child_node, *parent_node, &separation_set, - &mut self.cache, + dataset, + &mut cache, ) { candidate_parent_set_TMP.remove(parent_node); break; diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index dd3bbf7..dd683ab 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; -use crate::{parameter_learning, process}; +use crate::{parameter_learning, process, tools::Dataset}; pub trait HypothesisTest { fn call( @@ -15,6 +15,7 @@ pub trait HypothesisTest { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache
<P>
, ) -> bool where @@ -84,19 +85,25 @@ impl HypothesisTest for F { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache
@@ -84,19 +85,25 @@ impl HypothesisTest for F {
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
+        dataset: &Dataset,
         cache: &mut parameter_learning::Cache<P>
, ) -> bool where T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { - let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { Params::DiscreteStatesContinousTime(node) => node, }; let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); - let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + let P_big = match cache.fit( + net, + &dataset, + child_node, + Some(extended_separation_set.clone()), + ) { Params::DiscreteStatesContinousTime(node) => node, }; let partial_cardinality_product: usize = extended_separation_set @@ -218,6 +225,7 @@ impl HypothesisTest for ChiSquare { child_node: usize, parent_node: usize, separation_set: &BTreeSet, + dataset: &Dataset, cache: &mut parameter_learning::Cache
<P>
, ) -> bool where @@ -227,14 +235,19 @@ impl HypothesisTest for ChiSquare { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // di dimensione nxn // (CIM, M, T) - let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) { + let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { Params::DiscreteStatesContinousTime(node) => node, }; // let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); - let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) { + let P_big = match cache.fit( + net, + &dataset, + child_node, + Some(extended_separation_set.clone()), + ) { Params::DiscreteStatesContinousTime(node) => node, }; // Commentare qui diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index d59f0c1..d65ea88 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{process, tools}; +use crate::{process, tools::Dataset}; pub struct HillClimbing { score_function: S, @@ -21,7 +21,7 @@ impl HillClimbing { } impl StructureLearningAlgorithm for HillClimbing { - fn fit_transform(&mut self, net: T, dataset: &tools::Dataset) -> T + fn fit_transform(&self, net: T, dataset: &Dataset) -> T where T: process::NetworkProcess, { diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index c0deffd..6f97c9d 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -59,7 +59,7 @@ fn simple_bic() { ); } -fn check_compatibility_between_dataset_and_network(mut sl: T) { +fn check_compatibility_between_dataset_and_network(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -126,7 +126,7 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } -fn learn_ternary_net_2_nodes(mut sl: T) { +fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) @@ -321,7 +321,7 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } -fn learn_mixed_discrete_net_3_nodes(mut sl: T) { +fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -343,7 +343,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } -fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(mut sl: T) { +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -393,7 +393,7 @@ pub fn chi_square_compare_matrices() { [ 700, 800, 0] ], ]); - let chi_sq = ChiSquare::new(0.1); + let chi_sq = ChiSquare::new(1e-4); assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -423,7 +423,7 @@ pub fn chi_square_compare_matrices_2() { [ 400, 0, 600], [ 700, 800, 0]] ]); - let chi_sq = ChiSquare::new(0.1); + 
let chi_sq = ChiSquare::new(1e-4); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -455,7 +455,7 @@ pub fn chi_square_compare_matrices_3() { [ 700, 800, 0] ], ]); - let chi_sq = ChiSquare::new(0.1); + let chi_sq = ChiSquare::new(1e-4); assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); } @@ -469,14 +469,14 @@ pub fn chi_square_call() { let N1: usize = 0; let mut separation_set = BTreeSet::new(); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let mut cache = Cache::new(parameter_learning, data); - let chi_sq = ChiSquare::new(0.0001); + let mut cache = Cache::new(¶meter_learning); + let chi_sq = ChiSquare::new(1e-4); - assert!(chi_sq.call(&net, N1, N3, &separation_set, &mut cache)); - assert!(!chi_sq.call(&net, N3, N1, &separation_set, &mut cache)); - assert!(!chi_sq.call(&net, N3, N2, &separation_set, &mut cache)); + assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache)); + assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache)); + assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); - assert!(chi_sq.call(&net, N2, N3, &separation_set, &mut cache)); + assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache)); } #[test] @@ -488,91 +488,31 @@ pub fn f_call() { let N1: usize = 0; let mut separation_set = BTreeSet::new(); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let mut cache = Cache::new(parameter_learning, data); - let f = F::new(0.000001); + let mut cache = Cache::new(¶meter_learning); + let f = F::new(1e-6); - assert!(f.call(&net, N1, N3, &separation_set, &mut cache)); - assert!(!f.call(&net, N3, N1, &separation_set, &mut cache)); - assert!(!f.call(&net, N3, N2, &separation_set, &mut cache)); + assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache)); + assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache)); + assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); - assert!(f.call(&net, N2, N3, &separation_set, &mut cache)); + assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache)); } #[test] pub fn learn_ternary_net_2_nodes_ctpc() { - let mut net = CtbnNetwork::new(); - let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) - .unwrap(); - let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) - .unwrap(); - net.add_edge(n1, n2); - - match &mut net.get_node_mut(n1) { - params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!( - Ok(()), - param.set_cim(arr3(&[ - [ - [-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0] - ], - ])) - ); - } - } - - match &mut net.get_node_mut(n2) { - params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!( - Ok(()), - param.set_cim(arr3(&[ - [ - [-1.0, 0.5, 0.5], - [3.0, -4.0, 1.0], - [0.9, 0.1, -1.0] - ], - [ - [-6.0, 2.0, 4.0], - [1.5, -2.0, 0.5], - [3.0, 1.0, -4.0] - ], - [ - [-1.0, 0.1, 0.9], - [2.0, -2.5, 0.5], - [0.9, 0.1, -1.0] - ], - ])) - ); - } - } - - let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); - - let f = F::new(0.000001); - let chi_sq = ChiSquare::new(0.0001); + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let cache = Cache::new(parameter_learning, data.clone()); - let mut ctpc = CTPC::new(f, chi_sq, cache); - - - let net = ctpc.fit_transform(net, &data); - assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); - 
assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes(ctpc); } #[test] fn learn_mixed_discrete_net_3_nodes_ctpc() { - let (_, data) = get_mixed_discrete_net_3_nodes_with_data(); - - let f = F::new(1e-24); - let chi_sq = ChiSquare::new(1e-24); + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; - let cache = Cache::new(parameter_learning, data); - let ctpc = CTPC::new(f, chi_sq, cache); - + let ctpc = CTPC::new(parameter_learning, f, chi_sq); learn_mixed_discrete_net_3_nodes(ctpc); } From ea3e406bf14e40c622ffc40e02abb18e55d1e30a Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 29 Dec 2022 23:03:03 +0100 Subject: [PATCH 101/126] Implemented basic cache --- reCTBN/src/parameter_learning.rs | 17 ++++++++++++++--- reCTBN/tests/structure_learning.rs | 4 ++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index f8a7664..021b100 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,6 +1,6 @@ //! Module containing methods used to learn the parameters. -use std::collections::BTreeSet; +use std::collections::{BTreeSet,HashMap}; use ndarray::prelude::*; @@ -165,13 +165,15 @@ impl ParameterLearning for BayesianApproach { } } +// TODO: Move to constraint_based_algorithm.rs pub struct Cache<'a, P: ParameterLearning> { parameter_learning: &'a P, + cache_persistent: HashMap>, Params>, } impl<'a, P: ParameterLearning> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { - Cache { parameter_learning } + Cache { parameter_learning, cache_persistent: HashMap::new() } } pub fn fit( &mut self, @@ -180,6 +182,15 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { node: usize, parent_set: Option>, ) -> Params { - self.parameter_learning.fit(net, dataset, node, parent_set) + match self.cache_persistent.get(&parent_set) { + // TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = self.parameter_learning.fit(net, dataset, node, parent_set.clone()); + self.cache_persistent.insert(parent_set, params.clone()); + params + } + } } } diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 6f97c9d..a37f2b3 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -473,9 +473,11 @@ pub fn chi_square_call() { let chi_sq = ChiSquare::new(1e-4); assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache)); + let mut cache = Cache::new(¶meter_learning); assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache)); assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); + let mut cache = Cache::new(¶meter_learning); assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache)); } @@ -493,9 +495,11 @@ pub fn f_call() { assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache)); + let mut cache = Cache::new(¶meter_learning); assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache)); assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache)); separation_set.insert(N1); + let mut cache = Cache::new(¶meter_learning); assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache)); } From 
a0da3e2fe8fb8e3dde02ca2c375e2826623d4ee0 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 30 Dec 2022 17:47:32 +0100 Subject: [PATCH 102/126] Fixed formatting issue --- reCTBN/src/parameter_learning.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 021b100..73193ca 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,6 +1,6 @@ //! Module containing methods used to learn the parameters. -use std::collections::{BTreeSet,HashMap}; +use std::collections::{BTreeSet, HashMap}; use ndarray::prelude::*; @@ -173,7 +173,10 @@ pub struct Cache<'a, P: ParameterLearning> { impl<'a, P: ParameterLearning> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { - Cache { parameter_learning, cache_persistent: HashMap::new() } + Cache { + parameter_learning, + cache_persistent: HashMap::new(), + } } pub fn fit( &mut self, @@ -187,7 +190,9 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { // not cloning requires a minor and reasoned refactoring across the library Some(params) => params.clone(), None => { - let params = self.parameter_learning.fit(net, dataset, node, parent_set.clone()); + let params = self + .parameter_learning + .fit(net, dataset, node, parent_set.clone()); self.cache_persistent.insert(parent_set, params.clone()); params } From 867bf029345855637ea6a608cf9b3ae58d0937eb Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 30 Dec 2022 17:55:26 +0100 Subject: [PATCH 103/126] Greatly improved memory consumption in cache, stale data no longer pile up --- reCTBN/src/parameter_learning.rs | 56 +++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 73193ca..fcc47c4 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,6 +1,7 @@ //! Module containing methods used to learn the parameters. 
use std::collections::{BTreeSet, HashMap}; +use std::mem; use ndarray::prelude::*; @@ -168,14 +169,18 @@ impl ParameterLearning for BayesianApproach { // TODO: Move to constraint_based_algorithm.rs pub struct Cache<'a, P: ParameterLearning> { parameter_learning: &'a P, - cache_persistent: HashMap>, Params>, + cache_persistent_small: HashMap>, Params>, + cache_persistent_big: HashMap>, Params>, + parent_set_size_small: usize, } impl<'a, P: ParameterLearning> Cache<'a, P> { pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { Cache { parameter_learning, - cache_persistent: HashMap::new(), + cache_persistent_small: HashMap::new(), + cache_persistent_big: HashMap::new(), + parent_set_size_small: 0, } } pub fn fit( @@ -185,16 +190,43 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { node: usize, parent_set: Option>, ) -> Params { - match self.cache_persistent.get(&parent_set) { - // TODO: Bettern not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = self - .parameter_learning - .fit(net, dataset, node, parent_set.clone()); - self.cache_persistent.insert(parent_set, params.clone()); - params + let parent_set_len = parent_set.as_ref().unwrap().len(); + if parent_set_len > self.parent_set_size_small + 1 { + //self.cache_persistent_small = self.cache_persistent_big; + mem::swap( + &mut self.cache_persistent_small, + &mut self.cache_persistent_big, + ); + self.cache_persistent_big = HashMap::new(); + self.parent_set_size_small += 1; + } + + if parent_set_len > self.parent_set_size_small { + match self.cache_persistent_big.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_big.insert(parent_set, params.clone()); + params + } + } + } else { + match self.cache_persistent_small.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_small + .insert(parent_set, params.clone()); + params + } } } } From 4d3f9518e4137e911d55cc0f723e839e8f391752 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 4 Jan 2023 12:14:36 +0100 Subject: [PATCH 104/126] CTPC parallelization at the nodes level with `rayon` --- reCTBN/Cargo.toml | 1 + reCTBN/src/parameter_learning.rs | 2 +- reCTBN/src/process.rs | 2 +- .../src/structure_learning/constraint_based_algorithm.rs | 8 +++++++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml index fdac697..4749b23 100644 --- a/reCTBN/Cargo.toml +++ b/reCTBN/Cargo.toml @@ -14,6 +14,7 @@ enum_dispatch = "~0.3" statrs = "~0.16" rand_chacha = "~0.3" itertools = "~0.10" +rayon = "~1.6" [dev-dependencies] approx = { package = "approx", version = "~0.5" } diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index fcc47c4..ff6a7b9 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -8,7 +8,7 @@ use ndarray::prelude::*; use crate::params::*; use crate::{process, tools::Dataset}; -pub trait 
ParameterLearning { +pub trait ParameterLearning: Sync { fn fit( &self, net: &T, diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs index dc297bc..45c5e0a 100644 --- a/reCTBN/src/process.rs +++ b/reCTBN/src/process.rs @@ -21,7 +21,7 @@ pub type NetworkProcessState = Vec; /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). -pub trait NetworkProcess { +pub trait NetworkProcess: Sync { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; /// Add an **directed edge** between a two nodes of the network. diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 6d54fe7..634c144 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,6 +1,8 @@ //! Module containing constraint based algorithms like CTPC and Hiton. use itertools::Itertools; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::ParallelExtend; use std::collections::BTreeSet; use std::usize; @@ -41,7 +43,8 @@ impl StructureLearningAlgorithm for CTPC
<P>
{ net.initialize_adj_matrix(); - for child_node in net.get_node_indices() { + let mut learned_parent_sets: Vec<(usize, BTreeSet)> = vec![]; + learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| { let mut cache = Cache::new(&self.parameter_learning); let mut candidate_parent_set: BTreeSet = net .get_node_indices() @@ -82,6 +85,9 @@ impl StructureLearningAlgorithm for CTPC
<P>
{ candidate_parent_set = candidate_parent_set_TMP; separation_set_size += 1; } + (child_node, candidate_parent_set) + })); + for (child_node, candidate_parent_set) in learned_parent_sets { for parent_node in candidate_parent_set.iter() { net.add_edge(*parent_node, child_node); } From 5632833963ed3f27514c02bb42c711a79cc06b74 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 5 Jan 2023 10:53:59 +0100 Subject: [PATCH 105/126] Moved `Cache` to `constraint_based_algorithm.rs` --- reCTBN/src/parameter_learning.rs | 69 +----------------- .../constraint_based_algorithm.rs | 71 ++++++++++++++++++- .../src/structure_learning/hypothesis_test.rs | 7 +- reCTBN/tests/structure_learning.rs | 1 - 4 files changed, 74 insertions(+), 74 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index ff6a7b9..536a9d5 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,7 +1,6 @@ //! Module containing methods used to learn the parameters. -use std::collections::{BTreeSet, HashMap}; -use std::mem; +use std::collections::BTreeSet; use ndarray::prelude::*; @@ -165,69 +164,3 @@ impl ParameterLearning for BayesianApproach { return n; } } - -// TODO: Move to constraint_based_algorithm.rs -pub struct Cache<'a, P: ParameterLearning> { - parameter_learning: &'a P, - cache_persistent_small: HashMap>, Params>, - cache_persistent_big: HashMap>, Params>, - parent_set_size_small: usize, -} - -impl<'a, P: ParameterLearning> Cache<'a, P> { - pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { - Cache { - parameter_learning, - cache_persistent_small: HashMap::new(), - cache_persistent_big: HashMap::new(), - parent_set_size_small: 0, - } - } - pub fn fit( - &mut self, - net: &T, - dataset: &Dataset, - node: usize, - parent_set: Option>, - ) -> Params { - let parent_set_len = parent_set.as_ref().unwrap().len(); - if parent_set_len > self.parent_set_size_small + 1 { - //self.cache_persistent_small = self.cache_persistent_big; - mem::swap( - &mut self.cache_persistent_small, - &mut self.cache_persistent_big, - ); - self.cache_persistent_big = HashMap::new(); - self.parent_set_size_small += 1; - } - - if parent_set_len > self.parent_set_size_small { - match self.cache_persistent_big.get(&parent_set) { - // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = - self.parameter_learning - .fit(net, dataset, node, parent_set.clone()); - self.cache_persistent_big.insert(parent_set, params.clone()); - params - } - } - } else { - match self.cache_persistent_small.get(&parent_set) { - // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O - // not cloning requires a minor and reasoned refactoring across the library - Some(params) => params.clone(), - None => { - let params = - self.parameter_learning - .fit(net, dataset, node, parent_set.clone()); - self.cache_persistent_small - .insert(parent_set, params.clone()); - params - } - } - } - } -} diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index 634c144..f49b194 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -1,17 +1,84 @@ //! Module containing constraint based algorithms like CTPC and Hiton. 
+use crate::params::Params; use itertools::Itertools; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use rayon::prelude::ParallelExtend; -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashMap}; +use std::mem; use std::usize; use super::hypothesis_test::*; -use crate::parameter_learning::{Cache, ParameterLearning}; +use crate::parameter_learning::ParameterLearning; use crate::process; use crate::structure_learning::StructureLearningAlgorithm; use crate::tools::Dataset; +pub struct Cache<'a, P: ParameterLearning> { + parameter_learning: &'a P, + cache_persistent_small: HashMap>, Params>, + cache_persistent_big: HashMap>, Params>, + parent_set_size_small: usize, +} + +impl<'a, P: ParameterLearning> Cache<'a, P> { + pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { + Cache { + parameter_learning, + cache_persistent_small: HashMap::new(), + cache_persistent_big: HashMap::new(), + parent_set_size_small: 0, + } + } + pub fn fit( + &mut self, + net: &T, + dataset: &Dataset, + node: usize, + parent_set: Option>, + ) -> Params { + let parent_set_len = parent_set.as_ref().unwrap().len(); + if parent_set_len > self.parent_set_size_small + 1 { + //self.cache_persistent_small = self.cache_persistent_big; + mem::swap( + &mut self.cache_persistent_small, + &mut self.cache_persistent_big, + ); + self.cache_persistent_big = HashMap::new(); + self.parent_set_size_small += 1; + } + + if parent_set_len > self.parent_set_size_small { + match self.cache_persistent_big.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_big.insert(parent_set, params.clone()); + params + } + } + } else { + match self.cache_persistent_small.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_small + .insert(parent_set, params.clone()); + params + } + } + } + } +} + pub struct CTPC { parameter_learning: P, Ftest: F, diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index dd683ab..4c02929 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,6 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; +use crate::structure_learning::constraint_based_algorithm::Cache; use crate::{parameter_learning, process, tools::Dataset}; pub trait HypothesisTest { @@ -16,7 +17,7 @@ pub trait HypothesisTest { parent_node: usize, separation_set: &BTreeSet, dataset: &Dataset, - cache: &mut parameter_learning::Cache
<P>,
+        cache: &mut Cache<P>,
     ) -> bool
     where
         T: process::NetworkProcess,
@@ -86,7 +87,7 @@ impl HypothesisTest for F {
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
         dataset: &Dataset,
-        cache: &mut parameter_learning::Cache<P>,
+        cache: &mut Cache<P>,
     ) -> bool
     where
         T: process::NetworkProcess,
@@ -226,7 +227,7 @@ impl HypothesisTest for ChiSquare {
         parent_node: usize,
         separation_set: &BTreeSet<usize>,
         dataset: &Dataset,
-        cache: &mut parameter_learning::Cache<P>,
+        cache: &mut Cache<P>
, ) -> bool where T: process::NetworkProcess, diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a37f2b3..9a69b45 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -7,7 +7,6 @@ use ndarray::{arr1, arr2, arr3}; use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::BayesianApproach; -use reCTBN::parameter_learning::Cache; use reCTBN::params; use reCTBN::structure_learning::hypothesis_test::*; use reCTBN::structure_learning::constraint_based_algorithm::*; From ff235b4b7735e3796720a018843f061bbf39bd0c Mon Sep 17 00:00:00 2001 From: Alessandro Bregoli Date: Sat, 14 Jan 2023 14:20:37 +0100 Subject: [PATCH 106/126] Parallelize score based strucutre learning --- .../structure_learning/score_based_algorithm.rs | 16 ++++++++++++---- reCTBN/src/structure_learning/score_function.rs | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index d65ea88..6850027 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -6,6 +6,9 @@ use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; use crate::{process, tools::Dataset}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::ParallelExtend; + pub struct HillClimbing { score_function: S, max_parent_set: Option, @@ -36,8 +39,9 @@ impl StructureLearningAlgorithm for HillClimbing { let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); //Reset the adj matrix net.initialize_adj_matrix(); + let mut learned_parent_sets: Vec<(usize, BTreeSet)> = vec![]; //Iterate over each node to learn their parent set. - for node in net.get_node_indices() { + learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| { //Initialize an empty parent set. let mut parent_set: BTreeSet = BTreeSet::new(); //Compute the score for the empty parent set @@ -76,10 +80,14 @@ impl StructureLearningAlgorithm for HillClimbing { } } } - //Apply the learned parent_set to the network struct. 
- parent_set.iter().for_each(|p| net.add_edge(*p, node)); + (node, parent_set) + })); + + for (child_node, candidate_parent_set) in learned_parent_sets { + for parent_node in candidate_parent_set.iter() { + net.add_edge(*parent_node, child_node); + } } - return net; } } diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index f8b38b5..5a56594 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -7,7 +7,7 @@ use statrs::function::gamma; use crate::{parameter_learning, params, process, tools}; -pub trait ScoreFunction { +pub trait ScoreFunction: Sync { fn call( &self, net: &T, From 5d676be18033aee322142acef9e7d439e499071f Mon Sep 17 00:00:00 2001 From: Alessandro Bregoli Date: Mon, 16 Jan 2023 06:50:24 +0100 Subject: [PATCH 107/126] parallelized re --- reCTBN/src/reward.rs | 3 +-- reCTBN/src/reward/reward_evaluation.rs | 5 ++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs index b34db7f..f0edf2f 100644 --- a/reCTBN/src/reward.rs +++ b/reCTBN/src/reward.rs @@ -4,7 +4,6 @@ pub mod reward_evaluation; use std::collections::HashMap; use crate::process; -use ndarray; /// Instantiation of reward function and instantaneous reward /// @@ -22,7 +21,7 @@ pub struct Reward { /// The trait RewardFunction describe the methods that all the reward functions must satisfy -pub trait RewardFunction { +pub trait RewardFunction: Sync { /// Given the current state and the previous state, it compute the reward. /// /// # Arguments diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs index cb7b8f1..431efde 100644 --- a/reCTBN/src/reward/reward_evaluation.rs +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -1,8 +1,11 @@ use std::collections::HashMap; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; + use crate::params::{self, ParamsTrait}; use crate::process; + use crate::{ process::NetworkProcessState, reward::RewardEvaluation, @@ -55,7 +58,7 @@ impl RewardEvaluation for MonteCarloReward { let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); - (0..n_states) + (0..n_states).into_par_iter() .map(|s| { let state: process::NetworkProcessState = variables_domain .iter() From 7ec56914d91c38f27862b138728d9938651188a2 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 17 Jan 2023 21:43:54 +0100 Subject: [PATCH 108/126] Added doctest for CTPC --- .../constraint_based_algorithm.rs | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs index f49b194..f9cd820 100644 --- a/reCTBN/src/structure_learning/constraint_based_algorithm.rs +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -79,6 +79,190 @@ impl<'a, P: ParameterLearning> Cache<'a, P> { } } +/// Continuous-Time Peter Clark algorithm. +/// +/// A method to learn the structure of the network. +/// +/// # Arguments +/// +/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters. +/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test. +/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test. 
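+///
+/// For every node, CTPC starts from all the other nodes as candidate parents
+/// and tests each candidate against separation sets of increasing size; a
+/// candidate is discarded as soon as both the F-test and the chi-squared test
+/// accept the independence hypothesis, and the candidates that survive become
+/// the parent set of that node.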
+/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// use reCTBN::parameter_learning::BayesianApproach; +/// use reCTBN::structure_learning::StructureLearningAlgorithm; +/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare}; +/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC; +/// # +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("A")); +/// # domain.insert(String::from("B")); +/// # domain.insert(String::from("C")); +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain); +/// # //Create the node n1 using the parameters +/// # let n1 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("D")); +/// # domain.insert(String::from("E")); +/// # domain.insert(String::from("F")); +/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain); +/// # let n2 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("G")); +/// # domain.insert(String::from("H")); +/// # domain.insert(String::from("I")); +/// # domain.insert(String::from("F")); +/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain); +/// # let n3 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # // Initialize a ctbn +/// # let mut net = CtbnNetwork::new(); +/// # +/// # // Add the nodes and their edges +/// # let n1 = net.add_node(n1).unwrap(); +/// # let n2 = net.add_node(n2).unwrap(); +/// # let n3 = net.add_node(n3).unwrap(); +/// # net.add_edge(n1, n2); +/// # net.add_edge(n1, n3); +/// # net.add_edge(n2, n3); +/// # +/// # match &mut net.get_node_mut(n1) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-3.0, 2.0, 1.0], +/// # [1.5, -2.0, 0.5], +/// # [0.4, 0.6, -1.0] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # match &mut net.get_node_mut(n2) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-1.0, 0.5, 0.5], +/// # [3.0, -4.0, 1.0], +/// # [0.9, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 2.0, 4.0], +/// # [1.5, -2.0, 0.5], +/// # [3.0, 1.0, -4.0] +/// # ], +/// # [ +/// # [-1.0, 0.1, 0.9], +/// # [2.0, -2.5, 0.5], +/// # [0.9, 0.1, -1.0] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # match &mut net.get_node_mut(n3) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-1.0, 0.5, 0.3, 0.2], +/// # [0.5, -4.0, 2.5, 1.0], +/// # [2.5, 0.5, -4.0, 1.0], +/// # [0.7, 0.2, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 2.0, 3.0, 1.0], +/// # [1.5, -3.0, 0.5, 1.0], +/// # [2.0, 1.3, -5.0, 1.7], +/// # [2.5, 0.5, 1.0, -4.0] +/// # ], +/// # [ +/// # [-1.3, 0.3, 0.1, 0.9], +/// # [1.4, -4.0, 0.5, 2.1], +/// # [1.0, 1.5, -3.0, 0.5], +/// # [0.4, 0.3, 0.1, -0.8] +/// # ], +/// # [ +/// # [-2.0, 1.0, 0.7, 0.3], +/// # [1.3, -5.9, 2.7, 1.9], +/// # [2.0, 1.5, 
-4.0, 0.5], +/// # [0.2, 0.7, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 1.0, 2.0, 3.0], +/// # [0.5, -3.0, 1.0, 1.5], +/// # [1.4, 2.1, -4.3, 0.8], +/// # [0.5, 1.0, 2.5, -4.0] +/// # ], +/// # [ +/// # [-1.3, 0.9, 0.3, 0.1], +/// # [0.1, -1.3, 0.2, 1.0], +/// # [0.5, 1.0, -3.0, 1.5], +/// # [0.1, 0.4, 0.3, -0.8] +/// # ], +/// # [ +/// # [-2.0, 1.0, 0.6, 0.4], +/// # [2.6, -7.1, 1.4, 3.1], +/// # [5.0, 1.0, -8.0, 2.0], +/// # [1.4, 0.4, 0.2, -2.0] +/// # ], +/// # [ +/// # [-3.0, 1.0, 1.5, 0.5], +/// # [3.0, -6.0, 1.0, 2.0], +/// # [0.3, 0.5, -1.9, 1.1], +/// # [5.0, 1.0, 2.0, -8.0] +/// # ], +/// # [ +/// # [-2.6, 0.6, 0.2, 1.8], +/// # [2.0, -6.0, 3.0, 1.0], +/// # [0.1, 0.5, -1.3, 0.7], +/// # [0.8, 0.6, 0.2, -1.6] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # // Generate the trajectory +/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873)); +/// +/// // Initialize the hypothesis tests to pass to the CTPC with their +/// // respective significance level `alpha` +/// let f = F::new(1e-6); +/// let chi_sq = ChiSquare::new(1e-4); +/// // Use the bayesian approach to learn the parameters +/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; +/// +/// //Initialize CTPC +/// let ctpc = CTPC::new(parameter_learning, f, chi_sq); +/// +/// // Learn the structure of the network from the generated trajectory +/// let net = ctpc.fit_transform(net, &data); +/// # +/// # // Compare the generated network with the original one +/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); +/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +/// ``` pub struct CTPC { parameter_learning: P, Ftest: F, From c2df26c3e6835d87f8ea2a47fefad29cffe26fba Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 18 Jan 2023 14:18:24 +0100 Subject: [PATCH 109/126] Added docstrings for the F-test and removed some comments --- .../src/structure_learning/hypothesis_test.rs | 66 +++++++------------ 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 4c02929..311ec47 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -39,6 +39,17 @@ pub struct ChiSquare { alpha: f64, } +/// Does the F-test. +/// +/// Used to determine if a difference between two sets of data is due to chance, or if it is due to +/// a relationship (dependence) between the variables. +/// +/// # Arguments +/// +/// * `alpha` - is the significance level, the probability to reject a true null hypothesis; +/// in other words is the risk of concluding that an association between the variables exists +/// when there is no actual association. + pub struct F { alpha: f64, } @@ -48,6 +59,20 @@ impl F { F { alpha } } + /// Compare two matrices extracted from two 3rd-orer tensors. + /// + /// # Arguments + /// + /// * `i` - Position of the matrix of `M1` to compare with `M2`. + /// * `M1` - 3rd-order tensor 1. + /// * `j` - Position of the matrix of `M2` to compare with `M1`. + /// * `M2` - 3rd-order tensor 2. + /// + /// # Returns + /// + /// * `true` - when the matrices `M1` and `M2` are very similar, then **independendent**. + /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**. + pub fn compare_matrices( &self, i: usize, @@ -164,26 +189,8 @@ impl ChiSquare { // continuous-time Bayesian networks. 
// International Journal of Approximate Reasoning, 138, pp.105-122. // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm - // - // M = M M = M - // 1 xx'|s 2 xx'|y,s let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); - // __________________ - // / === - // / \ M - // / / xx'|s - // / === - // / x'ϵVal /X \ - // / \ i/ 1 - //K = / ------------------ L = - - // / === K - // / \ M - // / / xx'|y,s - // / === - // / x'ϵVal /X \ - // \ / \ i/ - // \/ let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); let K = K.mapv(f64::sqrt); // Reshape to column vector. @@ -191,34 +198,16 @@ impl ChiSquare { let n = K.len(); K.into_shape((n, 1)).unwrap() }; - //println!("K: {:?}", K); let L = 1.0 / &K; - // ===== 2 - // \ (K . M - L . M) - // \ 2 1 - // / --------------- - // / M + M - // ===== 2 1 - // x'ϵVal /X \ - // \ i/ let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); - //println!("M1: {:?}", M1); - //println!("M2: {:?}", M2); - //println!("L*M1: {:?}", (L * &M1)); - //println!("K*M2: {:?}", (K * &M2)); - //println!("X_2: {:?}", X_2); X_2.diag_mut().fill(0.0); let X_2 = X_2.sum_axis(Axis(1)); let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); - //println!("CHI^2: {:?}", n); - //println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x))); let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)); - //println!("test: {:?}", ret); ret } } -// ritorna false quando sono dipendenti e false quando sono indipendenti impl HypothesisTest for ChiSquare { fn call( &self, @@ -233,13 +222,9 @@ impl HypothesisTest for ChiSquare { T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { - // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM - // di dimensione nxn - // (CIM, M, T) let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { Params::DiscreteStatesContinousTime(node) => node, }; - // let mut extended_separation_set = separation_set.clone(); extended_separation_set.insert(parent_node); @@ -251,7 +236,6 @@ impl HypothesisTest for ChiSquare { ) { Params::DiscreteStatesContinousTime(node) => node, }; - // Commentare qui let partial_cardinality_product: usize = extended_separation_set .iter() .take_while(|x| **x != parent_node) From a077f738eeafc606dd1f791503b0d0c629e97aaa Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 26 Jan 2023 16:16:22 +0100 Subject: [PATCH 110/126] Added `StructureGen` struct for generating the structure of a `CtbnNetwork` --- reCTBN/src/tools.rs | 39 +++++++++++++++++++++++++++++++++++++++ reCTBN/tests/tools.rs | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 47a067d..c58403a 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -1,7 +1,11 @@ //! Contains commonly used methods used across the crate. use ndarray::prelude::*; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use crate::process::ctbn::CtbnNetwork; +use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -108,3 +112,38 @@ pub fn trajectory_generator( //Return a dataset object with the sampled trajectories. 
Dataset::new(trajectories) } + +pub struct StructureGen { + density: f64, + rng: ChaCha8Rng, +} + +impl StructureGen { + pub fn new(density: f64, seed: Option) -> StructureGen { + if density < 0.0 || density > 1.0 { + panic!( + "Density value must be between 1.0 and 0.0, got {}.", + density + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + StructureGen { density, rng } + } + + pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + let last_node_idx = net.get_node_indices().len(); + for parent in 0..last_node_idx { + for child in 0..last_node_idx { + if parent != child { + if self.rng.gen_bool(self.density) { + net.add_edge(parent, child); + } + } + } + } + net + } +} diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 806faef..ac64f8d 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -82,3 +82,43 @@ fn dataset_wrong_shape() { let t2 = Trajectory::new(time, events); Dataset::new(vec![t1, t2]); } + +#[test] +#[should_panic] +fn structure_gen_wrong_density() { + let density = 2.1; + StructureGen::new(density, None); +} + +#[test] +fn structure_gen_right_densities() { + for density in [1.0, 0.75, 0.5, 0.25, 0.0] { + StructureGen::new(density, None); + } +} + +#[test] +fn structure_gen_gen_structure() { + let mut net = CtbnNetwork::new(); + for node_label in 0..100 { + net.add_node( + utils::generate_discrete_time_continous_node( + node_label.to_string(), + 2, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let mut structure_generator = StructureGen::new(density, Some(7641630759785120)); + structure_generator.gen_structure(&mut net); + let mut edges = 0; + for node in net.get_node_indices(){ + edges += net.get_children_set(node).len() + } + let nodes = net.get_node_indices().len() as f64; + let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; + let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance + // As the way `gen_structure()` is implemented we can only reasonably + // expect the number of edges to be somewhere around the expected value. 
+ assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); +} From 4b994d8a19855dd387b98141cc4a154f8eb62521 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 26 Jan 2023 16:16:23 +0100 Subject: [PATCH 111/126] Renamed `StructureGen` with `UniformRandomGenerator` and defining the new trait `RandomGraphGenerator` --- reCTBN/src/tools.rs | 15 ++++++++++----- reCTBN/tests/tools.rs | 16 ++++++++-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index c58403a..599b420 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -113,13 +113,18 @@ pub fn trajectory_generator( Dataset::new(trajectories) } -pub struct StructureGen { +pub trait RandomGraphGenerator { + fn new(density: f64, seed: Option) -> Self; + fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork; +} + +pub struct UniformRandomGenerator { density: f64, rng: ChaCha8Rng, } -impl StructureGen { - pub fn new(density: f64, seed: Option) -> StructureGen { +impl RandomGraphGenerator for UniformRandomGenerator { + fn new(density: f64, seed: Option) -> UniformRandomGenerator { if density < 0.0 || density > 1.0 { panic!( "Density value must be between 1.0 and 0.0, got {}.", @@ -130,10 +135,10 @@ impl StructureGen { Some(seed) => SeedableRng::seed_from_u64(seed), None => SeedableRng::from_entropy(), }; - StructureGen { density, rng } + UniformRandomGenerator { density, rng } } - pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { let last_node_idx = net.get_node_indices().len(); for parent in 0..last_node_idx { for child in 0..last_node_idx { diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index ac64f8d..4c32de7 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -85,20 +85,20 @@ fn dataset_wrong_shape() { #[test] #[should_panic] -fn structure_gen_wrong_density() { +fn uniform_random_generator_wrong_density() { let density = 2.1; - StructureGen::new(density, None); + let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); } #[test] -fn structure_gen_right_densities() { +fn uniform_random_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { - StructureGen::new(density, None); + let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); } } #[test] -fn structure_gen_gen_structure() { +fn uniform_random_generator_generate_graph() { let mut net = CtbnNetwork::new(); for node_label in 0..100 { net.add_node( @@ -109,8 +109,8 @@ fn structure_gen_gen_structure() { ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator = StructureGen::new(density, Some(7641630759785120)); - structure_generator.gen_structure(&mut net); + let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + structure_generator.generate_graph(&mut net); let mut edges = 0; for node in net.get_node_indices(){ edges += net.get_children_set(node).len() @@ -118,7 +118,7 @@ fn structure_gen_gen_structure() { let nodes = net.get_node_indices().len() as f64; let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance - // As the way `gen_structure()` is implemented we can only reasonably + // As the way `generate_graph()` is implemented we can only reasonably // expect the number 
of edges to be somewhere around the expected value. assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); } From 434e671f0a114c95a7f18825cfe5cf42773d96a7 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 27 Jan 2023 12:43:47 +0100 Subject: [PATCH 112/126] Renamed `UniformRandomGenerator` to `UniformGraphGenerator`, replaced `CtbnNetwork` requirement with `NetworkProcess` in `RandomGraphGenerator`, some related tweaks --- reCTBN/src/tools.rs | 15 +++++++-------- reCTBN/tests/tools.rs | 10 +++++----- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 599b420..9222239 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -4,7 +4,6 @@ use ndarray::prelude::*; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use crate::process::ctbn::CtbnNetwork; use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -115,16 +114,16 @@ pub fn trajectory_generator( pub trait RandomGraphGenerator { fn new(density: f64, seed: Option) -> Self; - fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork; + fn generate_graph(&mut self, net: &mut T); } -pub struct UniformRandomGenerator { +pub struct UniformGraphGenerator { density: f64, rng: ChaCha8Rng, } -impl RandomGraphGenerator for UniformRandomGenerator { - fn new(density: f64, seed: Option) -> UniformRandomGenerator { +impl RandomGraphGenerator for UniformGraphGenerator { + fn new(density: f64, seed: Option) -> UniformGraphGenerator { if density < 0.0 || density > 1.0 { panic!( "Density value must be between 1.0 and 0.0, got {}.", @@ -135,10 +134,11 @@ impl RandomGraphGenerator for UniformRandomGenerator { Some(seed) => SeedableRng::seed_from_u64(seed), None => SeedableRng::from_entropy(), }; - UniformRandomGenerator { density, rng } + UniformGraphGenerator { density, rng } } - fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + fn generate_graph(&mut self, net: &mut T) { + net.initialize_adj_matrix(); let last_node_idx = net.get_node_indices().len(); for parent in 0..last_node_idx { for child in 0..last_node_idx { @@ -149,6 +149,5 @@ impl RandomGraphGenerator for UniformRandomGenerator { } } } - net } } diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 4c32de7..9a96959 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -87,13 +87,13 @@ fn dataset_wrong_shape() { #[should_panic] fn uniform_random_generator_wrong_density() { let density = 2.1; - let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } #[test] fn uniform_random_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { - let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } } @@ -109,7 +109,7 @@ fn uniform_random_generator_generate_graph() { ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); structure_generator.generate_graph(&mut net); let mut edges = 0; for node in net.get_node_indices(){ @@ -117,8 +117,8 @@ fn uniform_random_generator_generate_graph() 
{ }
     let nodes = net.get_node_indices().len() as f64;
     let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
-    let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance
+    let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance
     // As the way `generate_graph()` is implemented we can only reasonably
     // expect the number of edges to be somewhere around the expected value.
-    assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance));
+    assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
 }

From d6f0fb9623b16187bf4341707797e1b094e6eeba Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Sun, 29 Jan 2023 16:44:13 +0100
Subject: [PATCH 113/126] WIP: implementing `UniformParametersGenerator`

---
 reCTBN/src/tools.rs   | 92 ++++++++++++++++++++++++++++++++++++++++++-
 reCTBN/tests/tools.rs | 47 ++++++++++++++++++++--
 2 files changed, 135 insertions(+), 4 deletions(-)

diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs
index 9222239..7c438d5 100644
--- a/reCTBN/src/tools.rs
+++ b/reCTBN/src/tools.rs
@@ -1,9 +1,12 @@
 //! Contains commonly used methods used across the crate.
 
-use ndarray::prelude::*;
+use std::ops::{DivAssign, MulAssign, Range};
+
+use ndarray::{Array, Array1, Array2, Array3, Axis};
 use rand::{Rng, SeedableRng};
 use rand_chacha::ChaCha8Rng;
 
+use crate::params::ParamsTrait;
 use crate::process::NetworkProcess;
 use crate::sampling::{ForwardSampler, Sampler};
 use crate::{params, process};
@@ -151,3 +154,90 @@ impl RandomGraphGenerator for UniformGraphGenerator {
         }
     }
 }
+
+pub trait RandomParametersGenerator {
+    fn new(interval: Range<f64>, seed: Option<u64>) -> Self;
+    fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T);
+}
+
+pub struct UniformParametersGenerator {
+    interval: Range<f64>,
+    rng: ChaCha8Rng,
+}
+
+impl RandomParametersGenerator for UniformParametersGenerator {
+    fn new(interval: Range<f64>, seed: Option<u64>) -> UniformParametersGenerator {
+        if interval.start < 0.0 || interval.end < 0.0 {
+            panic!(
+                "Interval must be entirely greater than or equal to 0, got {}..{}.",
+                interval.start, interval.end
+            );
+        }
+        let rng: ChaCha8Rng = match seed {
+            Some(seed) => SeedableRng::seed_from_u64(seed),
+            None => SeedableRng::from_entropy(),
+        };
+        UniformParametersGenerator { interval, rng }
+    }
+    fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) {
+        for node in net.get_node_indices() {
+            let parent_set = net.get_parent_set(node);
+            let parent_set_state_space_cardinality: usize = parent_set
+                .iter()
+                .map(|x| net.get_node(*x).get_reserved_space_as_parent())
+                .product();
+            println!(
+                "parent_set_state_space_cardinality = {}",
+                parent_set_state_space_cardinality
+            );
+            let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent();
+            println!("node_domain_cardinality = {}", node_domain_cardinality);
+            let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64)
+                ..=(self.interval.end / node_domain_cardinality as f64);
+            println!("cim_single_param_range = {:?}", cim_single_param_range);
+
+            let mut cim = Array3::<f64>::from_shape_fn(
+                (
+                    parent_set_state_space_cardinality,
+                    node_domain_cardinality,
+                    node_domain_cardinality,
+                ),
+                |_| self.rng.gen(),
+            );
+
+            //let diagonal = cim.axis_iter(Axis(0));
+            cim.axis_iter_mut(Axis(0))
+                .for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}")));
+            cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
+                x.diag_mut().fill(0.0);
+                let sum_axis = x.sum_axis(Axis(0));
+                //let division = 1.0 / &sum_axis;
x.div_assign(&sum_axis); + println!("{}", x); + let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { + self.rng.gen_range(self.interval.clone()) + }); + x.mul_assign(&diag); + println!("{}", x); + x.diag_mut().assign(&-diag) + }); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.diag().iter().for_each(|x| println!("{x}"))); + + println!("Sum Axis"); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}"))); + println!("Matrices"); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.iter().for_each(|x| println!("{}", x))); + //.any(|x| x.diag().iter().any(|x| x >= &0.0)) + + //println!("{:?}", diagonal); + match &mut net.get_node_mut(node) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(cim)); + } + } + } + } +} diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 9a96959..e91cf04 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,3 +1,5 @@ +use std::ops::Range; + use ndarray::{arr1, arr2, arr3}; use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; @@ -85,20 +87,27 @@ fn dataset_wrong_shape() { #[test] #[should_panic] -fn uniform_random_generator_wrong_density() { +fn uniform_graph_generator_wrong_density_1() { let density = 2.1; let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } #[test] -fn uniform_random_generator_right_densities() { +#[should_panic] +fn uniform_graph_generator_wrong_density_2() { + let density = -0.5; + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); +} + +#[test] +fn uniform_graph_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } } #[test] -fn uniform_random_generator_generate_graph() { +fn uniform_graph_generator_generate_graph() { let mut net = CtbnNetwork::new(); for node_label in 0..100 { net.add_node( @@ -122,3 +131,35 @@ fn uniform_random_generator_generate_graph() { // expect the number of edges to be somewhere around the expected value. 
assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); } + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_1() { + let interval: Range = -2.0..-5.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); +} + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_2() { + let interval: Range = -1.0..0.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); +} + +#[test] +fn uniform_parameters_generator_right_densities() { + let mut net = CtbnNetwork::new(); + for node_label in 0..3 { + net.add_node( + utils::generate_discrete_time_continous_node( + node_label.to_string(), + 9, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + structure_generator.generate_graph(&mut net); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120)); + cim_generator.generate_parameters(&mut net); +} From f4e3c98c796aea5863f15a9969ca5500eb340f5d Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 30 Jan 2023 10:48:17 +0100 Subject: [PATCH 114/126] Implemented `UniformParametersGenerator` and its test --- reCTBN/src/tools.rs | 40 ++++++---------------------------------- reCTBN/tests/tools.rs | 21 +++++++++++++++++---- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 7c438d5..344c66c 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -179,23 +179,15 @@ impl RandomParametersGenerator for UniformParametersGenerator { }; UniformParametersGenerator { interval, rng } } + fn generate_parameters(&mut self, net: &mut T) { for node in net.get_node_indices() { - let parent_set = net.get_parent_set(node); - let parent_set_state_space_cardinality: usize = parent_set + let parent_set_state_space_cardinality: usize = net + .get_parent_set(node) .iter() .map(|x| net.get_node(*x).get_reserved_space_as_parent()) .product(); - println!( - "parent_set_state_space_cardinality = {}", - parent_set_state_space_cardinality - ); let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); - println!("node_domain_cardinality = {}", node_domain_cardinality); - let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64) - ..=(self.interval.end / node_domain_cardinality as f64); - println!("cim_single_param_range = {:?}", cim_single_param_range); - let mut cim = Array3::::from_shape_fn( ( parent_set_state_space_cardinality, @@ -204,38 +196,18 @@ impl RandomParametersGenerator for UniformParametersGenerator { ), |_| self.rng.gen(), ); - - //let diagonal = cim.axis_iter(Axis(0)); - cim.axis_iter_mut(Axis(0)) - .for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}"))); cim.axis_iter_mut(Axis(0)).for_each(|mut x| { x.diag_mut().fill(0.0); - let sum_axis = x.sum_axis(Axis(0)); - //let division = 1.0 / &sum_axis; - x.div_assign(&sum_axis); - println!("{}", x); + x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { self.rng.gen_range(self.interval.clone()) }); - x.mul_assign(&diag); - println!("{}", x); + x.mul_assign(&diag.clone().insert_axis(Axis(1))); x.diag_mut().assign(&-diag) }); - cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.diag().iter().for_each(|x| println!("{x}"))); - - println!("Sum Axis"); - 
cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}"))); - println!("Matrices"); - cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.iter().for_each(|x| println!("{}", x))); - //.any(|x| x.diag().iter().any(|x| x >= &0.0)) - - //println!("{:?}", diagonal); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(cim)); + param.set_cim_unchecked(cim); } } } diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index e91cf04..f04fb2a 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -149,17 +149,30 @@ fn uniform_parameters_generator_wrong_density_2() { #[test] fn uniform_parameters_generator_right_densities() { let mut net = CtbnNetwork::new(); - for node_label in 0..3 { + let nodes_cardinality = 0..5; + let nodes_domain_cardinality = 9; + for node_label in nodes_cardinality { net.add_node( utils::generate_discrete_time_continous_node( node_label.to_string(), - 9, + nodes_domain_cardinality, ) ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, seed); structure_generator.generate_graph(&mut net); - let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120)); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed); cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + match &mut net.get_node_mut(node) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(param.get_cim().clone().unwrap())); + } + } + } } From 0f61cbee4c6fc72a159a3b3bb91ca2a7cb553ccd Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 31 Jan 2023 09:30:40 +0100 Subject: [PATCH 115/126] Refactored CIM validation for `UniformParametersGenerator` test --- reCTBN/tests/tools.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index f04fb2a..59ed71c 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,6 +1,7 @@ use std::ops::Range; use ndarray::{arr1, arr2, arr3}; +use reCTBN::params::ParamsTrait; use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::params; @@ -167,12 +168,9 @@ fn uniform_parameters_generator_right_densities() { let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed); cim_generator.generate_parameters(&mut net); for node in net.get_node_indices() { - match &mut net.get_node_mut(node) { - params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!( - Ok(()), - param.set_cim(param.get_cim().clone().unwrap())); - } - } + assert_eq!( + Ok(()), + net.get_node(node).validate_params() + ); } } From 097dc25030732f7fc7d778aef6bb74d9cb8ac723 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 31 Jan 2023 11:28:26 +0100 Subject: [PATCH 116/126] Added tests for `UniformParametersGenerator` and `UniformGraphGenerator` against `CTMP`, plus some small refactoring to the other tests --- reCTBN/tests/tools.rs | 103 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 89 insertions(+), 14 deletions(-) diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 59ed71c..59d8f27 100644 --- 
a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -3,10 +3,13 @@ use std::ops::Range; use ndarray::{arr1, arr2, arr3}; use reCTBN::params::ParamsTrait; use reCTBN::process::ctbn::*; +use reCTBN::process::ctmp::*; use reCTBN::process::NetworkProcess; use reCTBN::params; use reCTBN::tools::*; +use utils::*; + #[macro_use] extern crate approx; @@ -90,36 +93,50 @@ fn dataset_wrong_shape() { #[should_panic] fn uniform_graph_generator_wrong_density_1() { let density = 2.1; - let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); } #[test] #[should_panic] fn uniform_graph_generator_wrong_density_2() { let density = -0.5; - let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); } #[test] fn uniform_graph_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { - let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); } } #[test] -fn uniform_graph_generator_generate_graph() { +fn uniform_graph_generator_generate_graph_ctbn() { let mut net = CtbnNetwork::new(); - for node_label in 0..100 { + let nodes_cardinality = 0..=100; + let nodes_domain_cardinality = 2; + for node_label in nodes_cardinality { net.add_node( utils::generate_discrete_time_continous_node( node_label.to_string(), - 2, + nodes_domain_cardinality, ) ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + Some(7641630759785120) + ); structure_generator.generate_graph(&mut net); let mut edges = 0; for node in net.get_node_indices(){ @@ -133,28 +150,54 @@ fn uniform_graph_generator_generate_graph() { assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); } +#[test] +#[should_panic] +fn uniform_graph_generator_generate_graph_ctmp() { + let mut net = CtmpProcess::new(); + let node_label = String::from("0"); + let node_domain_cardinality = 4; + net.add_node( + generate_discrete_time_continous_node( + node_label, + node_domain_cardinality + ) + ).unwrap(); + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + Some(7641630759785120) + ); + structure_generator.generate_graph(&mut net); +} + #[test] #[should_panic] fn uniform_parameters_generator_wrong_density_1() { let interval: Range = -2.0..-5.0; - let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + None + ); } #[test] #[should_panic] fn uniform_parameters_generator_wrong_density_2() { let interval: Range = -1.0..0.0; - let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + None + ); } #[test] -fn uniform_parameters_generator_right_densities() { +fn uniform_parameters_generator_right_densities_ctbn() { let mut net = CtbnNetwork::new(); - let nodes_cardinality = 0..5; + let nodes_cardinality = 0..=3; 
let nodes_domain_cardinality = 9; for node_label in nodes_cardinality { net.add_node( - utils::generate_discrete_time_continous_node( + generate_discrete_time_continous_node( node_label.to_string(), nodes_domain_cardinality, ) @@ -163,9 +206,41 @@ fn uniform_parameters_generator_right_densities() { let density = 1.0/3.0; let seed = Some(7641630759785120); let interval = 0.0..7.0; - let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, seed); + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + seed + ); structure_generator.generate_graph(&mut net); - let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + seed + ); + cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + assert_eq!( + Ok(()), + net.get_node(node).validate_params() + ); + } +} + +#[test] +fn uniform_parameters_generator_right_densities_ctmp() { + let mut net = CtmpProcess::new(); + let node_label = String::from("0"); + let node_domain_cardinality = 4; + net.add_node( + generate_discrete_time_continous_node( + node_label, + node_domain_cardinality + ) + ).unwrap(); + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + seed + ); cim_generator.generate_parameters(&mut net); for node in net.get_node_indices() { assert_eq!( From a01a9ef20107983667cc2c30f627c8fcf3662df5 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 31 Jan 2023 13:16:52 +0100 Subject: [PATCH 117/126] Recomputing the diagonal when generating parameters to counter the precision loss and increase `f64::EPSILON` calculating its square root instead of multiplying it with the node's `domain_size` --- reCTBN/src/params.rs | 4 +--- reCTBN/src/tools.rs | 5 ++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 9f63860..dc941e5 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,13 +267,11 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } - let domain_size = domain_size as f64; - // Check if each row sum up to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) + .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 344c66c..e9b9fd8 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -203,7 +203,10 @@ impl RandomParametersGenerator for UniformParametersGenerator { self.rng.gen_range(self.interval.clone()) }); x.mul_assign(&diag.clone().insert_axis(Axis(1))); - x.diag_mut().assign(&-diag) + // Recomputing the diagonal in order to reduce the issues caused by the loss of + // precision when validating the parameters. 
+ let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) }); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => { From e08d12ac1f243511befbc76c0c35c7dc03efd679 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 09:04:35 +0100 Subject: [PATCH 118/126] Added tests for structure learning algorithms using uniform graph and parameters generators as complementary to their handcrafted version --- reCTBN/tests/structure_learning.rs | 171 +++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 9a69b45..3d7e230 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -117,6 +117,50 @@ fn check_compatibility_between_dataset_and_network(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259)); + + let mut net = CtbnNetwork::new(); + let _n1 = net + .add_node( + generate_discrete_time_continous_node(String::from("0"), + 3) + ).unwrap(); + let _net = sl.fit_transform(net, &data); +} + #[test] #[should_panic] pub fn check_compatibility_between_dataset_and_network_hill_climbing() { @@ -125,6 +169,14 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } +#[test] +#[should_panic] +pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + check_compatibility_between_dataset_and_network_gen(hl); +} + fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -182,6 +234,25 @@ fn learn_ternary_net_2_nodes(sl: T) { assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } +fn learn_ternary_net_2_nodes_gen(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +} + #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -189,6 +260,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { learn_ternary_net_2_nodes(hl); } +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_ternary_net_2_nodes_gen(hl); +} + #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { let bic = BIC::new(1, 1.0); @@ -196,6 +274,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { learn_ternary_net_2_nodes(hl); } +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_ternary_net_2_nodes_gen(hl); +} + fn 
get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { let mut net = CtbnNetwork::new(); let n1 = net @@ -320,6 +405,30 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } +fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); + return (net, data); +} + fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); @@ -328,6 +437,14 @@ fn learn_mixed_discrete_net_3_nodes(sl: T) { assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); } +fn learn_mixed_discrete_net_3_nodes_gen(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -335,6 +452,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { learn_mixed_discrete_net_3_nodes(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { let bic = BIC::new(1, 1.0); @@ -342,6 +466,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); @@ -350,6 +481,14 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { let ll = LogLikelihood::new(1, 1.0); @@ -357,6 +496,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() { let bic = BIC::new(1, 1.0); @@ -364,6 +510,13 @@ pub 
fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + #[test] pub fn chi_square_compare_matrices() { let i: usize = 1; @@ -511,6 +664,15 @@ pub fn learn_ternary_net_2_nodes_ctpc() { learn_ternary_net_2_nodes(ctpc); } +#[test] +pub fn learn_ternary_net_2_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes_gen(ctpc); +} + #[test] fn learn_mixed_discrete_net_3_nodes_ctpc() { let f = F::new(1e-6); @@ -519,3 +681,12 @@ fn learn_mixed_discrete_net_3_nodes_ctpc() { let ctpc = CTPC::new(parameter_learning, f, chi_sq); learn_mixed_discrete_net_3_nodes(ctpc); } + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_mixed_discrete_net_3_nodes_gen(ctpc); +} From 430033afdb17a239ada1ccad16f3c32e3ce48234 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 11:20:13 +0100 Subject: [PATCH 119/126] Added tests for the learning of parameters using uniform graph and parameters generators as complementary to their handcrafted version --- reCTBN/tests/parameter_learning.rs | 203 +++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 2cbc185..0a09a2a 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -6,6 +6,7 @@ use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; +use reCTBN::params::Params::DiscreteStatesContinousTime; use reCTBN::tools::*; use utils::*; @@ -66,18 +67,78 @@ fn learn_binary_cim(pl: T) { )); } +fn generate_nodes( + net: &mut CtbnNetwork, + nodes_cardinality: usize, + nodes_domain_cardinality: usize +) { + for node_label in 0..nodes_cardinality { + net.add_node( + generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } +} + +fn learn_binary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 2); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_binary_cim_MLE() { let mle = MLE {}; learn_binary_cim(mle); } +#[test] +fn learn_binary_cim_MLE_gen() { + let mle = MLE {}; + learn_binary_cim_gen(mle); +} + #[test] fn learn_binary_cim_BA() { let ba = 
BayesianApproach { alpha: 1, tau: 1.0 }; learn_binary_cim(ba); } +#[test] +fn learn_binary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_binary_cim_gen(ba); +} + fn learn_ternary_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -155,18 +216,63 @@ fn learn_ternary_cim(pl: T) { )); } +fn learn_ternary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 4.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_ternary_cim_MLE() { let mle = MLE {}; learn_ternary_cim(mle); } +#[test] +fn learn_ternary_cim_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_gen(mle); +} + #[test] fn learn_ternary_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim(ba); } +#[test] +fn learn_ternary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_gen(ba); +} + fn learn_ternary_cim_no_parents(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -234,18 +340,63 @@ fn learn_ternary_cim_no_parents(pl: T) { )); } +fn learn_ternary_cim_no_parents_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(0) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 0, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_ternary_cim_no_parents_MLE() { let mle = MLE {}; learn_ternary_cim_no_parents(mle); } +#[test] +fn learn_ternary_cim_no_parents_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_no_parents_gen(mle); +} + #[test] fn learn_ternary_cim_no_parents_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim_no_parents(ba); } +#[test] +fn learn_ternary_cim_no_parents_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_no_parents_gen(ba); +} + fn learn_mixed_discrete_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -432,14 +583,66 @@ fn learn_mixed_discrete_cim(pl: T) { )); } +fn learn_mixed_discrete_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..8.0, + 
Some(6813071588535822)
+    );
+    cim_generator.generate_parameters(&mut net);
+
+    let p_gen = match net.get_node(2) {
+        DiscreteStatesContinousTime(p_gen) => p_gen,
+    };
+
+    let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
+    let p_tj = match pl.fit(&net, &data, 2, None) {
+        DiscreteStatesContinousTime(p_tj) => p_tj,
+    };
+
+    assert_eq!(
+        p_tj.get_cim().as_ref().unwrap().shape(),
+        p_gen.get_cim().as_ref().unwrap().shape()
+    );
+    assert!(
+        p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
+            &p_gen.get_cim().as_ref().unwrap(),
+            0.2
+        )
+    );
+}
+
 #[test]
 fn learn_mixed_discrete_cim_MLE() {
     let mle = MLE {};
     learn_mixed_discrete_cim(mle);
 }
 
+#[test]
+fn learn_mixed_discrete_cim_MLE_gen() {
+    let mle = MLE {};
+    learn_mixed_discrete_cim_gen(mle);
+}
+
 #[test]
 fn learn_mixed_discrete_cim_BA() {
     let ba = BayesianApproach { alpha: 1, tau: 1.0 };
     learn_mixed_discrete_cim(ba);
 }
+
+#[test]
+fn learn_mixed_discrete_cim_BA_gen() {
+    let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+    learn_mixed_discrete_cim_gen(ba);
+}

From 4884010ea97f1670f79ee2a2e9fe914b9ea65b80 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Wed, 1 Feb 2023 14:36:12 +0100
Subject: [PATCH 120/126] Added doctests for `UniformParametersGenerator` and `UniformGraphGenerator`

---
 reCTBN/src/tools.rs | 138 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 138 insertions(+)

diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs
index e9b9fd8..89c19a9 100644
--- a/reCTBN/src/tools.rs
+++ b/reCTBN/src/tools.rs
@@ -120,6 +120,72 @@ pub trait RandomGraphGenerator {
     fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T);
 }
 
+/// Graph Generator using a uniform distribution.
+///
+/// A method to generate a random graph with edges uniformly distributed.
+///
+/// # Arguments
+///
+/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`.
+/// * `rng` - is the random number generator.
+///
+/// # Example
+///
+/// ```rust
+/// # use std::collections::BTreeSet;
+/// # use ndarray::{arr1, arr2, arr3};
+/// # use reCTBN::params;
+/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
+/// # use reCTBN::tools::trajectory_generator;
+/// # use reCTBN::process::NetworkProcess;
+/// # use reCTBN::process::ctbn::CtbnNetwork;
+/// use reCTBN::tools::UniformGraphGenerator;
+/// use reCTBN::tools::RandomGraphGenerator;
+/// # let mut net = CtbnNetwork::new();
+/// # let nodes_cardinality = 8;
+/// # let domain_cardinality = 4;
+/// # for node in 0..nodes_cardinality {
+/// #     // Create the domain for a discrete node
+/// #     let mut domain = BTreeSet::new();
+/// #     for dvalue in 0..domain_cardinality {
+/// #         domain.insert(dvalue.to_string());
+/// #     }
+/// #     // Create the parameters for a discrete node using the domain
+/// #     let param = params::DiscreteStatesContinousTimeParams::new(
+/// #         node.to_string(),
+/// #         domain
+/// #     );
+/// #     // Create the node using the parameters
+/// #     let node = DiscreteStatesContinousTime(param);
+/// #     // Add the node to the network
+/// #     net.add_node(node).unwrap();
+/// # }
+///
+/// // Initialize the Graph Generator using the one with a
+/// // uniform distribution
+/// let density = 1.0/3.0;
+/// let seed = Some(7641630759785120);
+/// let mut structure_generator = UniformGraphGenerator::new(
+///     density,
+///     seed
+/// );
+///
+/// // Generate the graph directly on the network
+/// structure_generator.generate_graph(&mut net);
+/// # // Count all the edges generated in the network
+/// # let mut edges = 0;
+/// # for node in net.get_node_indices(){
+/// #     edges += net.get_children_set(node).len()
+/// # }
+/// # // Number of all the nodes in the network
+/// # let nodes = net.get_node_indices().len() as f64;
+/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
+/// # // ±10% of tolerance
+/// # let tolerance = ((expected_edges as f64)*0.10) as usize;
+/// # // As the way `generate_graph()` is implemented we can only reasonably
+/// # // expect the number of edges to be somewhere around the expected value.
+/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
+/// ```
 pub struct UniformGraphGenerator {
     density: f64,
     rng: ChaCha8Rng,
 }
@@ -140,6 +206,7 @@ impl RandomGraphGenerator for UniformGraphGenerator {
         UniformGraphGenerator { density, rng }
     }
 
+    /// Generate a uniformly distributed graph.
     fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T) {
         net.initialize_adj_matrix();
         let last_node_idx = net.get_node_indices().len();
         for parent in 0..last_node_idx {
             for child in 0..last_node_idx {
                 if parent != child {
                     if self.rng.gen_bool(self.density) {
                         net.add_edge(parent, child);
                     }
                 }
             }
         }
     }
 }
 
@@ -160,6 +227,76 @@ pub trait RandomParametersGenerator {
     fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T);
 }
 
+/// Parameters Generator using a uniform distribution.
+///
+/// A method to generate random parameters uniformly distributed.
+///
+/// # Arguments
+///
+/// * `interval` - is the interval of the random values of the CIM's diagonal; domain: `≥ 0.0`.
+/// * `rng` - is the random number generator.
+///
+/// # Example
+///
+/// ```rust
+/// # use std::collections::BTreeSet;
+/// # use ndarray::{arr1, arr2, arr3};
+/// # use reCTBN::params;
+/// # use reCTBN::params::ParamsTrait;
+/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
+/// # use reCTBN::process::NetworkProcess;
+/// # use reCTBN::process::ctbn::CtbnNetwork;
+/// # use reCTBN::tools::trajectory_generator;
+/// # use reCTBN::tools::RandomGraphGenerator;
+/// # use reCTBN::tools::UniformGraphGenerator;
+/// use reCTBN::tools::RandomParametersGenerator;
+/// use reCTBN::tools::UniformParametersGenerator;
+/// # let mut net = CtbnNetwork::new();
+/// # let nodes_cardinality = 8;
+/// # let domain_cardinality = 4;
+/// # for node in 0..nodes_cardinality {
+/// #     // Create the domain for a discrete node
+/// #     let mut domain = BTreeSet::new();
+/// #     for dvalue in 0..domain_cardinality {
+/// #         domain.insert(dvalue.to_string());
+/// #     }
+/// #     // Create the parameters for a discrete node using the domain
+/// #     let param = params::DiscreteStatesContinousTimeParams::new(
+/// #         node.to_string(),
+/// #         domain
+/// #     );
+/// #     // Create the node using the parameters
+/// #     let node = DiscreteStatesContinousTime(param);
+/// #     // Add the node to the network
+/// #     net.add_node(node).unwrap();
+/// # }
+/// #
+/// # // Initialize the Graph Generator using the one with a
+/// # // uniform distribution
+/// # let mut structure_generator = UniformGraphGenerator::new(
+/// #     1.0/3.0,
+/// #     Some(7641630759785120)
+/// # );
+/// #
+/// # // Generate the graph directly on the network
+/// # structure_generator.generate_graph(&mut net);
+///
+/// // Initialize the parameters generator with a uniform distribution
+/// let mut cim_generator = UniformParametersGenerator::new(
+///     0.0..7.0,
+///     Some(7641630759785120)
+/// );
+///
+/// // Generate CIMs with uniformly distributed parameters.
+/// cim_generator.generate_parameters(&mut net);
+/// #
+/// # for node in net.get_node_indices() {
+/// #     assert_eq!(
+/// #         Ok(()),
+/// #         net.get_node(node).validate_params()
+/// #     );
+/// # }
+/// ```
 pub struct UniformParametersGenerator {
     interval: Range<f64>,
     rng: ChaCha8Rng,
 }
@@ -180,6 +317,7 @@ impl RandomParametersGenerator for UniformParametersGenerator {
         UniformParametersGenerator { interval, rng }
     }
 
+    /// Generate CIMs with uniformly distributed parameters.
fn generate_parameters(&mut self, net: &mut T) { for node in net.get_node_indices() { let parent_set_state_space_cardinality: usize = net From 0639a755d0e74f7563f8d254609152b8f9480167 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 15:32:13 +0100 Subject: [PATCH 121/126] Refactored `generate_parameters` moving some code inside `match` statement --- reCTBN/src/tools.rs | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 89c19a9..0a48410 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -325,29 +325,29 @@ impl RandomParametersGenerator for UniformParametersGenerator { .iter() .map(|x| net.get_node(*x).get_reserved_space_as_parent()) .product(); - let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); - let mut cim = Array3::::from_shape_fn( - ( - parent_set_state_space_cardinality, - node_domain_cardinality, - node_domain_cardinality, - ), - |_| self.rng.gen(), - ); - cim.axis_iter_mut(Axis(0)).for_each(|mut x| { - x.diag_mut().fill(0.0); - x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); - let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { - self.rng.gen_range(self.interval.clone()) - }); - x.mul_assign(&diag.clone().insert_axis(Axis(1))); - // Recomputing the diagonal in order to reduce the issues caused by the loss of - // precision when validating the parameters. - let diag_sum = -x.sum_axis(Axis(1)); - x.diag_mut().assign(&diag_sum) - }); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => { + let node_domain_cardinality = param.get_reserved_space_as_parent(); + let mut cim = Array3::::from_shape_fn( + ( + parent_set_state_space_cardinality, + node_domain_cardinality, + node_domain_cardinality, + ), + |_| self.rng.gen(), + ); + cim.axis_iter_mut(Axis(0)).for_each(|mut x| { + x.diag_mut().fill(0.0); + x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); + let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { + self.rng.gen_range(self.interval.clone()) + }); + x.mul_assign(&diag.clone().insert_axis(Axis(1))); + // Recomputing the diagonal in order to reduce the issues caused by the + // loss of precision when validating the parameters. 
+ let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) + }); param.set_cim_unchecked(cim); } } From adb0f99419fb6dd2920d385f069e61f61052a8ed Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 7 Feb 2023 14:34:58 +0100 Subject: [PATCH 122/126] Added automatic stopping for MonteCarloReward --- reCTBN/src/parameter_learning.rs | 8 +- reCTBN/src/reward.rs | 2 +- reCTBN/src/reward/reward_evaluation.rs | 97 +++++++++++++------ .../score_based_algorithm.rs | 2 +- reCTBN/tests/reward_evaluation.rs | 6 +- 5 files changed, 73 insertions(+), 42 deletions(-) diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 7e45e58..3c34d06 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -144,11 +144,9 @@ impl ParameterLearning for BayesianApproach { .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau)))); - - CIM.outer_iter_mut() - .for_each(|mut C| { - C.diag_mut().fill(0.0); - }); + CIM.outer_iter_mut().for_each(|mut C| { + C.diag_mut().fill(0.0); + }); //Set the diagonal of the inner matrices to the the row sum multiplied by -1 let tmp_diag_sum: Array2 = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs index f0edf2f..910954c 100644 --- a/reCTBN/src/reward.rs +++ b/reCTBN/src/reward.rs @@ -1,5 +1,5 @@ -pub mod reward_function; pub mod reward_evaluation; +pub mod reward_function; use std::collections::HashMap; diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs index 431efde..3802489 100644 --- a/reCTBN/src/reward/reward_evaluation.rs +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -1,11 +1,11 @@ use std::collections::HashMap; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use statrs::distribution::ContinuousCDF; use crate::params::{self, ParamsTrait}; use crate::process; - use crate::{ process::NetworkProcessState, reward::RewardEvaluation, @@ -14,11 +14,13 @@ use crate::{ pub enum RewardCriteria { FiniteHorizon, - InfiniteHorizon {discount_factor: f64}, + InfiniteHorizon { discount_factor: f64 }, } pub struct MonteCarloReward { - n_iterations: usize, + max_iterations: usize, + max_err_stop: f64, + alpha_stop: f64, end_time: f64, reward_criteria: RewardCriteria, seed: Option, @@ -26,13 +28,17 @@ pub struct MonteCarloReward { impl MonteCarloReward { pub fn new( - n_iterations: usize, + max_iterations: usize, + max_err_stop: f64, + alpha_stop: f64, end_time: f64, reward_criteria: RewardCriteria, seed: Option, ) -> MonteCarloReward { MonteCarloReward { - n_iterations, + max_iterations, + max_err_stop, + alpha_stop, end_time, reward_criteria, seed, @@ -58,7 +64,8 @@ impl RewardEvaluation for MonteCarloReward { let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); - (0..n_states).into_par_iter() + (0..n_states) + .into_par_iter() .map(|s| { let state: process::NetworkProcessState = variables_domain .iter() @@ -85,10 +92,13 @@ impl RewardEvaluation for MonteCarloReward { ) -> f64 { let mut sampler = ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); - let mut ret = 0.0; + let mut expected_value = 0.0; + let mut squared_expected_value = 0.0; + let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); - for _i in 0..self.n_iterations { + for i in 0..self.max_iterations { sampler.reset(); + let mut ret = 0.0; let mut previous = sampler.next().unwrap(); while previous.t < self.end_time { let 
current = sampler.next().unwrap(); @@ -96,7 +106,7 @@ impl RewardEvaluation for MonteCarloReward { let r = reward_function.call(&previous.state, None); let discount = match self.reward_criteria { RewardCriteria::FiniteHorizon => self.end_time - previous.t, - RewardCriteria::InfiniteHorizon {discount_factor} => { + RewardCriteria::InfiniteHorizon { discount_factor } => { std::f64::consts::E.powf(-discount_factor * previous.t) - std::f64::consts::E.powf(-discount_factor * self.end_time) } @@ -105,8 +115,8 @@ impl RewardEvaluation for MonteCarloReward { } else { let r = reward_function.call(&previous.state, Some(¤t.state)); let discount = match self.reward_criteria { - RewardCriteria::FiniteHorizon => current.t-previous.t, - RewardCriteria::InfiniteHorizon {discount_factor} => { + RewardCriteria::FiniteHorizon => current.t - previous.t, + RewardCriteria::InfiniteHorizon { discount_factor } => { std::f64::consts::E.powf(-discount_factor * previous.t) - std::f64::consts::E.powf(-discount_factor * current.t) } @@ -114,51 +124,74 @@ impl RewardEvaluation for MonteCarloReward { ret += discount * r.instantaneous_reward; ret += match self.reward_criteria { RewardCriteria::FiniteHorizon => 1.0, - RewardCriteria::InfiniteHorizon {discount_factor} => { + RewardCriteria::InfiniteHorizon { discount_factor } => { std::f64::consts::E.powf(-discount_factor * current.t) } } * r.transition_reward; } previous = current; } + + let float_i = i as f64; + expected_value = + expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0); + squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0) + + ret.powi(2) / (float_i + 1.0); + + if i > 2 { + let var = + (float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2)); + if self.alpha_stop + - 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt()) + > 0.0 + { + return expected_value; + } + } } - ret / self.n_iterations as f64 + expected_value } } pub struct NeighborhoodRelativeReward { - inner_reward: RE + inner_reward: RE, } -impl NeighborhoodRelativeReward{ - pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward{ - NeighborhoodRelativeReward {inner_reward} +impl NeighborhoodRelativeReward { + pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward { + NeighborhoodRelativeReward { inner_reward } } } -impl RewardEvaluation for NeighborhoodRelativeReward { +impl RewardEvaluation for NeighborhoodRelativeReward { fn evaluate_state_space( &self, network_process: &N, reward_function: &R, ) -> HashMap { + let absolute_reward = self + .inner_reward + .evaluate_state_space(network_process, reward_function); - let absolute_reward = self.inner_reward.evaluate_state_space(network_process, reward_function); - //This approach optimize memory. Maybe optimizing execution time can be better. 
From 776b9aa030fb25ecce5216ad64abad6f8762c9ba Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Tue, 7 Feb 2023 16:34:57 +0100
Subject: [PATCH 123/126] clippy error solved

---
 reCTBN/tests/reward_evaluation.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs
index c0372b1..355341c 100644
--- a/reCTBN/tests/reward_evaluation.rs
+++ b/reCTBN/tests/reward_evaluation.rs
@@ -1,6 +1,6 @@
 mod utils;

-use approx::{abs_diff_eq, assert_abs_diff_eq};
+use approx::assert_abs_diff_eq;
 use ndarray::*;
 use reCTBN::{
 params,
 process::{ctbn::*, NetworkProcessState},
 reward::{reward_evaluation::*, reward_function::*, *},
 };
 use utils::generate_discrete_time_continous_node;

 #[test]
-fn simple_factored_reward_function_binary_node_MC() {
+fn simple_factored_reward_function_binary_node_mc() {
 let mut net = CtbnNetwork::new();
 let n1 = net
 .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
 .unwrap();
@@ -50,7 +50,7 @@ fn simple_factored_reward_function_binary_node_MC() {
 }

 #[test]
-fn simple_factored_reward_function_chain_MC() {
+fn simple_factored_reward_function_chain_mc() {
 let mut net = CtbnNetwork::new();
 let n1 = net
 .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
From c4da4ceadd4836a1947d0d100b252000018fc0d1 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Thu, 16 Feb 2023 14:56:20 +0100
Subject: [PATCH 124/126] Added `get_adj_matrix()`

---
 reCTBN/src/process.rs | 2 ++
 reCTBN/src/process/ctbn.rs | 5 +++++
 reCTBN/src/process/ctmp.rs | 5 +++++
 3 files changed, 12 insertions(+)

diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs
index 45c5e0a..f554af4 100644
--- a/reCTBN/src/process.rs
+++ b/reCTBN/src/process.rs
@@ -5,6 +5,7 @@ pub mod ctmp;

 use std::collections::BTreeSet;

+use ndarray::Array2;
 use thiserror::Error;

 use crate::params;
@@ -117,4 +118,5 @@ pub trait NetworkProcess: Sync {
 ///
 /// * The **children set** of the selected node.
 fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
+ fn get_adj_matrix(&self) -> Option<Array2<u16>>;
 }
diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs
index 162345e..9784776 100644
--- a/reCTBN/src/process/ctbn.rs
+++ b/reCTBN/src/process/ctbn.rs
@@ -240,4 +240,9 @@ impl process::NetworkProcess for CtbnNetwork {
 .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
 .collect()
 }
+
+ /// Get the Adjacency Matrix.
+ fn get_adj_matrix(&self) -> Option<Array2<u16>> {
+ self.adj_matrix.clone()
+ }
 }
diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs
index 41b8db6..4a346e7 100644
--- a/reCTBN/src/process/ctmp.rs
+++ b/reCTBN/src/process/ctmp.rs
@@ -1,5 +1,7 @@
 use std::collections::BTreeSet;

+use ndarray::Array2;
+
 use crate::{
 params::{Params, StateType},
 process,
@@ -111,4 +113,7 @@ impl NetworkProcess for CtmpProcess {
 None => panic!("Uninitialized CtmpProcess"),
 }
 }
+ fn get_adj_matrix(&self) -> Option<Array2<u16>> {
+ unimplemented!("CtmpProcess has only one node")
+ }
 }

From 4fd0ee040734c3b78227f3bbc8002842083be829 Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Thu, 16 Feb 2023 15:12:32 +0100
Subject: [PATCH 125/126] Function `get_adj_matrix()` is now specific to `CtbnNetwork` only

---
 reCTBN/src/process.rs | 2 --
 reCTBN/src/process/ctbn.rs | 9 ++++-----
 reCTBN/src/process/ctmp.rs | 5 -----
 3 files changed, 4 insertions(+), 12 deletions(-)

diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs
index f554af4..45c5e0a 100644
--- a/reCTBN/src/process.rs
+++ b/reCTBN/src/process.rs
@@ -5,7 +5,6 @@ pub mod ctmp;

 use std::collections::BTreeSet;

-use ndarray::Array2;
 use thiserror::Error;

 use crate::params;
@@ -118,5 +117,4 @@ pub trait NetworkProcess: Sync {
 ///
 /// * The **children set** of the selected node.
 fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
- fn get_adj_matrix(&self) -> Option<Array2<u16>>;
 }
diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs
index 9784776..6956ea0 100644
--- a/reCTBN/src/process/ctbn.rs
+++ b/reCTBN/src/process/ctbn.rs
@@ -138,6 +138,10 @@ impl CtbnNetwork {
 return array_state;
 }

+ /// Get the Adjacency Matrix.
+ pub fn get_adj_matrix(&self) -> Option<Array2<u16>> {
+ self.adj_matrix.clone()
+ }
 }

 impl process::NetworkProcess for CtbnNetwork {
@@ -240,9 +244,4 @@ impl process::NetworkProcess for CtbnNetwork {
 .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
 .collect()
 }
-
- /// Get the Adjacency Matrix.
- fn get_adj_matrix(&self) -> Option<Array2<u16>> {
- self.adj_matrix.clone()
- }
 }
diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs
index 4a346e7..41b8db6 100644
--- a/reCTBN/src/process/ctmp.rs
+++ b/reCTBN/src/process/ctmp.rs
@@ -1,7 +1,5 @@
 use std::collections::BTreeSet;

-use ndarray::Array2;
-
 use crate::{
 params::{Params, StateType},
 process,
@@ -111,7 +111,4 @@ impl NetworkProcess for CtmpProcess {
 None => panic!("Uninitialized CtmpProcess"),
 }
 }
- fn get_adj_matrix(&self) -> Option<Array2<u16>> {
- unimplemented!("CtmpProcess has only one node")
- }
 }

From 7a3ac6c9abd364dfc62d2a45717c95c32f78476f Mon Sep 17 00:00:00 2001
From: meliurwen
Date: Thu, 16 Feb 2023 16:36:07 +0100
Subject: [PATCH 126/126] Using `as_ref()` instead of `clone()` in `get_adj_matrix()`

---
 reCTBN/src/process/ctbn.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs
index 6956ea0..d93400d 100644
--- a/reCTBN/src/process/ctbn.rs
+++ b/reCTBN/src/process/ctbn.rs
@@ -139,8 +139,8 @@ impl CtbnNetwork {
 return array_state;
 }
 /// Get the Adjacency Matrix.
- pub fn get_adj_matrix(&self) -> Option<Array2<u16>> {
- self.adj_matrix.clone()
+ pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> {
+ self.adj_matrix.as_ref()
 }
 }
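With these last three patches, get_adj_matrix() reaches its final shape: an inherent method of CtbnNetwork rather than a NetworkProcess trait method (which CtmpProcess, having a single node, could only leave unimplemented), and it now returns a borrowed reference via as_ref() instead of cloning the whole matrix on every call. A hypothetical caller, sketched for illustration (the helper name and the edge-counting logic are invented, not part of the library):

    use reCTBN::process::ctbn::CtbnNetwork;

    /// Hypothetical helper: count the directed edges of a CTBN through the
    /// borrowed adjacency matrix returned by `get_adj_matrix()`.
    fn count_edges(net: &CtbnNetwork) -> usize {
        match net.get_adj_matrix() {
            // Borrowed view: no copy of the adjacency matrix is made.
            Some(adj) => adj.iter().filter(|&&x| x > 0).count(),
            // `None` when the adjacency matrix has not been initialized yet.
            None => 0,
        }
    }

Returning an Option over a reference keeps the method zero-copy, while callers that need ownership can still clone explicitly.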