From ed5471c7cf6d4a28e7486c3e8be0dd9e63cb79b5 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 14 Nov 2022 16:07:04 +0100 Subject: [PATCH 01/11] Added ctmp --- reCTBN/src/lib.rs | 3 +- reCTBN/src/parameter_learning.rs | 12 +- reCTBN/src/{network.rs => process.rs} | 5 +- reCTBN/src/{ => process}/ctbn.rs | 10 +- reCTBN/src/process/ctmp.rs | 106 +++++++++++++++ reCTBN/src/sampling.rs | 10 +- reCTBN/src/structure_learning.rs | 4 +- .../src/structure_learning/hypothesis_test.rs | 6 +- .../score_based_algorithm.rs | 4 +- .../src/structure_learning/score_function.rs | 10 +- reCTBN/src/tools.rs | 4 +- reCTBN/tests/ctbn.rs | 4 +- reCTBN/tests/ctmp.rs | 127 ++++++++++++++++++ reCTBN/tests/parameter_learning.rs | 4 +- reCTBN/tests/structure_learning.rs | 4 +- reCTBN/tests/tools.rs | 4 +- 16 files changed, 276 insertions(+), 41 deletions(-) rename reCTBN/src/{network.rs => process.rs} (98%) rename reCTBN/src/{ => process}/ctbn.rs (95%) create mode 100644 reCTBN/src/process/ctmp.rs create mode 100644 reCTBN/tests/ctmp.rs diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index db33ae4..6ab59cb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -3,10 +3,9 @@ #[cfg(test)] extern crate approx; -pub mod ctbn; -pub mod network; pub mod parameter_learning; pub mod params; pub mod sampling; pub mod structure_learning; pub mod tools; +pub mod process; diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 61d4dca..2aa518c 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -5,10 +5,10 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use crate::params::*; -use crate::{network, tools}; +use crate::{process, tools}; pub trait ParameterLearning { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -17,7 +17,7 @@ pub trait ParameterLearning { ) -> Params; } -pub fn sufficient_statistics( +pub fn sufficient_statistics( net: &T, dataset: &tools::Dataset, node: usize, @@ -73,7 +73,7 @@ pub fn sufficient_statistics( pub struct MLE {} impl ParameterLearning for MLE { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -120,7 +120,7 @@ pub struct BayesianApproach { } impl ParameterLearning for BayesianApproach { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -177,7 +177,7 @@ impl Cache

{ dataset, } } - pub fn fit( + pub fn fit( &mut self, net: &T, node: usize, diff --git a/reCTBN/src/network.rs b/reCTBN/src/process.rs similarity index 98% rename from reCTBN/src/network.rs rename to reCTBN/src/process.rs index fbdd2e6..2b70b59 100644 --- a/reCTBN/src/network.rs +++ b/reCTBN/src/process.rs @@ -1,5 +1,8 @@ //! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs +pub mod ctbn; +pub mod ctmp; + use std::collections::BTreeSet; use thiserror::Error; @@ -15,7 +18,7 @@ pub enum NetworkError { /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). -pub trait Network { +pub trait NetworkProcess { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; /// Add an **directed edge** between a two nodes of the network. diff --git a/reCTBN/src/ctbn.rs b/reCTBN/src/process/ctbn.rs similarity index 95% rename from reCTBN/src/ctbn.rs rename to reCTBN/src/process/ctbn.rs index 2b01d14..c59d99d 100644 --- a/reCTBN/src/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; -use crate::network; +use crate::process; use crate::params::{Params, ParamsTrait, StateType}; /// It represents both the structure and the parameters of a CTBN. @@ -20,9 +20,9 @@ use crate::params::{Params, ParamsTrait, StateType}; /// /// ```rust /// use std::collections::BTreeSet; -/// use reCTBN::network::Network; +/// use reCTBN::process::NetworkProcess; /// use reCTBN::params; -/// use reCTBN::ctbn::*; +/// use reCTBN::process::ctbn::*; /// /// //Create the domain for a discrete node /// let mut domain = BTreeSet::new(); @@ -69,7 +69,7 @@ impl CtbnNetwork { } } -impl network::Network for CtbnNetwork { +impl process::NetworkProcess for CtbnNetwork { /// Initialize an Adjacency matrix. fn initialize_adj_matrix(&mut self) { self.adj_matrix = Some(Array2::::zeros( @@ -78,7 +78,7 @@ impl network::Network for CtbnNetwork { } /// Add a new node. - fn add_node(&mut self, mut n: Params) -> Result { + fn add_node(&mut self, mut n: Params) -> Result { n.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs new file mode 100644 index 0000000..b0b042a --- /dev/null +++ b/reCTBN/src/process/ctmp.rs @@ -0,0 +1,106 @@ +use std::collections::BTreeSet; + +use crate::{process, params::{Params, StateType}}; + +use super::NetworkProcess; + +pub struct CtmpProcess { + param: Option +} + +impl CtmpProcess { + pub fn new() -> CtmpProcess { + CtmpProcess { param: None } + } +} + +impl NetworkProcess for CtmpProcess { + fn initialize_adj_matrix(&mut self) { + unimplemented!("CtmpProcess has only one node") + } + + fn add_node(&mut self, n: crate::params::Params) -> Result { + match self.param { + None => { + self.param = Some(n); + Ok(0) + }, + Some(_) => Err(process::NetworkError::NodeInsertionError("CtmpProcess has only one node".to_string())) + } + } + + fn add_edge(&mut self, parent: usize, child: usize) { + unimplemented!("CtmpProcess has only one node") + } + + fn get_node_indices(&self) -> std::ops::Range { + match self.param { + None => 0..0, + Some(_) => 0..1 + } + } + + fn get_number_of_nodes(&self) -> usize { + match self.param { + None => 0, + Some(_) => 1 + } + } + + fn get_node(&self, node_idx: usize) -> &crate::params::Params { + if node_idx == 0 { + self.param.as_ref().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { + if node_idx == 0 { + self.param.as_mut().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_network(&self, node: usize, current_state: &Vec) + -> usize { + if node == 0 { + match current_state[0] { + StateType::Discrete(x) => x + } + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_from_custom_parent_set( + &self, + current_state: &Vec, + parent_set: &std::collections::BTreeSet, + ) -> usize { + unimplemented!("CtmpProcess has only one node") + } + + fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + }, + None => panic!("Uninitialized CtmpProcess") + } + } + + fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + }, + None => panic!("Uninitialized CtmpProcess") + } + } +} diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index d435634..050daeb 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,7 +1,7 @@ //! Module containing methods for the sampling. use crate::{ - network::Network, + process::NetworkProcess, params::{self, ParamsTrait}, }; use rand::SeedableRng; @@ -13,7 +13,7 @@ pub trait Sampler: Iterator { pub struct ForwardSampler<'a, T> where - T: Network, + T: NetworkProcess, { net: &'a T, rng: ChaCha8Rng, @@ -22,7 +22,7 @@ where next_transitions: Vec>, } -impl<'a, T: Network> ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. @@ -42,7 +42,7 @@ impl<'a, T: Network> ForwardSampler<'a, T> { } } -impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { type Item = (f64, Vec); fn next(&mut self) -> Option { @@ -100,7 +100,7 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { } } -impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { fn reset(&mut self) { self.current_time = 0.0; self.current_state = self diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index 57fed1e..b272e22 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -4,10 +4,10 @@ pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; pub mod score_function; -use crate::{network, tools}; +use crate::{process, tools}; pub trait StructureLearningAlgorithm { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network; + T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 6474155..4ec3377 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::params::*; -use crate::{network, parameter_learning}; +use crate::{process, parameter_learning}; pub trait HypothesisTest { fn call( @@ -18,7 +18,7 @@ pub trait HypothesisTest { cache: &mut parameter_learning::Cache

, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning; } @@ -135,7 +135,7 @@ impl HypothesisTest for ChiSquare { cache: &mut parameter_learning::Cache

, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index 9e329eb..16e9056 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{network, tools}; +use crate::{process, tools}; pub struct HillClimbing { score_function: S, @@ -23,7 +23,7 @@ impl HillClimbing { impl StructureLearningAlgorithm for HillClimbing { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network, + T: process::NetworkProcess, { //Check the coherence between dataset and network if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index cb6ad7b..8943478 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -5,7 +5,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use statrs::function::gamma; -use crate::{network, parameter_learning, params, tools}; +use crate::{process, parameter_learning, params, tools}; pub trait ScoreFunction { fn call( @@ -16,7 +16,7 @@ pub trait ScoreFunction { dataset: &tools::Dataset, ) -> f64 where - T: network::Network; + T: process::NetworkProcess; } pub struct LogLikelihood { @@ -41,7 +41,7 @@ impl LogLikelihood { dataset: &tools::Dataset, ) -> (f64, Array3) where - T: network::Network, + T: process::NetworkProcess, { //Identify the type of node used match &net.get_node(node) { @@ -100,7 +100,7 @@ impl ScoreFunction for LogLikelihood { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { self.compute_score(net, node, parent_set, dataset).0 } @@ -127,7 +127,7 @@ impl ScoreFunction for BIC { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { //Compute the log-likelihood let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index aa48883..6f2f648 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,7 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; -use crate::{network, params}; +use crate::{process, params}; pub struct Trajectory { time: Array1, @@ -51,7 +51,7 @@ impl Dataset { } } -pub fn trajectory_generator( +pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 63c9621..0ad0fc4 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,8 +1,8 @@ mod utils; use std::collections::BTreeSet; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; use utils::generate_discrete_time_continous_node; diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs new file mode 100644 index 0000000..31bc6df --- /dev/null +++ b/reCTBN/tests/ctmp.rs @@ -0,0 +1,127 @@ +mod utils; + +use std::collections::BTreeSet; + +use reCTBN::{ + params, + params::ParamsTrait, + process::{ctmp::*, NetworkProcess}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn define_simple_ctmp() { + let _ = CtmpProcess::new(); + assert!(true); +} + +#[test] +fn add_node_to_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); +} + +#[test] +fn add_two_nodes_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + match n2 { + Ok(_) => assert!(false), + Err(_) => assert!(true), + }; +} + +#[test] +#[should_panic] +fn add_edge_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + net.add_edge(0, 1) +} + +#[test] +fn childen_and_parents() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + assert_eq!(0, net.get_parent_set(0).len()); + assert_eq!(0, net.get_children_set(0).len()); +} + +#[test] +#[should_panic] +fn get_childen_panic() { + let mut net = CtmpProcess::new(); + net.get_children_set(0); +} + +#[test] +#[should_panic] +fn get_childen_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_children_set(1); +} + +#[test] +#[should_panic] +fn get_parent_panic() { + let mut net = CtmpProcess::new(); + net.get_parent_set(0); +} + +#[test] +#[should_panic] +fn get_parent_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_parent_set(1); +} + +#[test] +fn compute_index_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); + assert_eq!(6, idx); +} + +#[test] +#[should_panic] +fn compute_index_from_custom_parent_set_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let _idx = net.get_param_index_from_custom_parent_set( + &vec![params::StateType::Discrete(6)], + &BTreeSet::from([0]) + ); +} diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 7d09b07..2cbc185 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -2,8 +2,8 @@ mod utils; use ndarray::arr3; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; use reCTBN::tools::*; diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a1667c2..2ec64b2 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -4,8 +4,8 @@ mod utils; use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::Cache; use reCTBN::params; diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 589b04e..806faef 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,6 +1,6 @@ use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params; use reCTBN::tools::*; From 7c3cba50d4afb08c1087711ef8fba12a2351ad54 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 11:14:41 +0100 Subject: [PATCH 02/11] Implemented amalgamation --- reCTBN/src/process/ctbn.rs | 78 +++++++++++++++++++++++++++++++++++++- reCTBN/tests/ctbn.rs | 36 +++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c59d99d..3852c50 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,8 +4,11 @@ use std::collections::BTreeSet; use ndarray::prelude::*; +use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; use crate::process; -use crate::params::{Params, ParamsTrait, StateType}; + +use super::ctmp::CtmpProcess; +use super::NetworkProcess; /// It represents both the structure and the parameters of a CTBN. /// @@ -67,6 +70,79 @@ impl CtbnNetwork { nodes: Vec::new(), } } + + pub fn amalgamation(&self) -> CtmpProcess { + for v in self.nodes.iter() { + match v { + Params::DiscreteStatesContinousTime(_) => {} + _ => panic!("Unsupported node"), + } + } + + let variables_domain = + Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); + + let state_space = variables_domain.product(); + let variables_set = BTreeSet::from_iter(self.get_node_indices()); + let mut amalgamated_cim: Array3 = Array::zeros((1, state_space, state_space)); + + for idx_current_state in 0..state_space { + let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); + let current_state_statetype: Vec = current_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + for idx_node in 0..self.nodes.len() { + let p = match self.get_node(idx_node) { + Params::DiscreteStatesContinousTime(p) => p, + }; + for next_node_state in 0..variables_domain[idx_node] { + let mut next_state = current_state.clone(); + next_state[idx_node] = next_node_state; + + let next_state_statetype: Vec = next_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + let idx_next_state = self.get_param_index_from_custom_parent_set( + &next_state_statetype, + &variables_set, + ); + amalgamated_cim[[0, idx_current_state, idx_next_state]] += + p.get_cim().as_ref().unwrap()[[ + self.get_param_index_network(idx_node, ¤t_state_statetype), + current_state[idx_node], + next_node_state, + ]]; + } + } + } + + let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( + "ctmp".to_string(), + BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), + ); + + println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim); + + amalgamated_param.set_cim(amalgamated_cim).unwrap(); + + let mut ctmp = CtmpProcess::new(); + + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)).unwrap(); + return ctmp; + } + + pub fn idx_to_state(variables_domain: &Array1, state: usize) -> Array1 { + let mut state = state; + let mut array_state = Array1::zeros(variables_domain.shape()[0]); + for (idx, var) in variables_domain.indexed_iter() { + array_state[idx] = state % var; + state = state / var; + } + + return array_state; + } } impl process::NetworkProcess for CtbnNetwork { diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 0ad0fc4..fc17a94 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,7 +1,10 @@ mod utils; use std::collections::BTreeSet; +use std::f64::EPSILON; -use reCTBN::process::ctbn::*; +use approx::AbsDiffEq; +use ndarray::arr3; +use reCTBN::process::{ctbn::*, ctmp::*}; use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; use utils::generate_discrete_time_continous_node; @@ -129,3 +132,34 @@ fn compute_index_from_custom_parent_set() { ); assert_eq!(2, idx); } + +#[test] +fn simple_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + net.initialize_adj_matrix(); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); + } + } + + let ctmp = net.amalgamation(); + let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0){ + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + + assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); +} From 4a7a8c5fbab5c0addcd2785a2f585c6c32eb4637 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 14:06:30 +0100 Subject: [PATCH 03/11] Added more tests --- reCTBN/src/params.rs | 4 +- reCTBN/src/process/ctbn.rs | 10 +- reCTBN/tests/ctbn.rs | 228 ++++++++++++++++++++++++++++++++++++- 3 files changed, 234 insertions(+), 8 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 070c997..9f63860 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,11 +267,13 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } + let domain_size = domain_size as f64; + // Check if each row sum up to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) + .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 3852c50..a6be923 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -70,7 +70,12 @@ impl CtbnNetwork { nodes: Vec::new(), } } - + + ///Transform the **CTBN** into a **CTMP** + /// + /// # Return + /// + /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { for v in self.nodes.iter() { match v { @@ -123,8 +128,7 @@ impl CtbnNetwork { BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim); - + println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index fc17a94..a7752f2 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -4,9 +4,9 @@ use std::f64::EPSILON; use approx::AbsDiffEq; use ndarray::arr3; -use reCTBN::process::{ctbn::*, ctmp::*}; -use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; +use reCTBN::process::NetworkProcess; +use reCTBN::process::{ctbn::*, ctmp::*}; use utils::generate_discrete_time_continous_node; #[test] @@ -149,7 +149,7 @@ fn simple_amalgamation() { } let ctmp = net.amalgamation(); - let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0){ + let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0) { p.get_cim().as_ref().unwrap() } else { unreachable!(); @@ -160,6 +160,226 @@ fn simple_amalgamation() { unreachable!(); }; - assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); } + +#[test] +fn chain_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + let ctmp = net.amalgamation(); + + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + let p_ctmp_handmade = arr3(&[[ + [ + -1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} + +#[test] +fn chainfork_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + let n4 = net + .add_node(generate_discrete_time_continous_node(String::from("n4"), 2)) + .unwrap(); + + net.add_edge(n1, n3); + net.add_edge(n2, n3); + net.add_edge(n3, n4); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n4) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + + let ctmp = net.amalgamation(); + + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + let p_ctmp_handmade = arr3(&[[ + [ + -2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} From 28ed1a40b32bb1a6629e5a67b379ed0b56c89861 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 14:52:39 +0100 Subject: [PATCH 04/11] Fix for clippy --- reCTBN/src/lib.rs | 2 +- reCTBN/src/process/ctbn.rs | 13 ++-- reCTBN/src/process/ctmp.rs | 60 +++++++++++-------- reCTBN/src/sampling.rs | 2 +- .../src/structure_learning/hypothesis_test.rs | 2 +- .../src/structure_learning/score_function.rs | 2 +- reCTBN/src/tools.rs | 2 +- reCTBN/tests/ctbn.rs | 4 +- reCTBN/tests/ctmp.rs | 6 +- 9 files changed, 52 insertions(+), 41 deletions(-) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 6ab59cb..c62c42e 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -5,7 +5,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; +pub mod process; pub mod sampling; pub mod structure_learning; pub mod tools; -pub mod process; diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index a6be923..7cb327d 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -70,7 +70,7 @@ impl CtbnNetwork { nodes: Vec::new(), } } - + ///Transform the **CTBN** into a **CTMP** /// /// # Return @@ -105,10 +105,8 @@ impl CtbnNetwork { let mut next_state = current_state.clone(); next_state[idx_node] = next_node_state; - let next_state_statetype: Vec = next_state - .iter() - .map(|x| StateType::Discrete(*x)) - .collect(); + let next_state_statetype: Vec = + next_state.iter().map(|x| StateType::Discrete(*x)).collect(); let idx_next_state = self.get_param_index_from_custom_parent_set( &next_state_statetype, &variables_set, @@ -127,13 +125,14 @@ impl CtbnNetwork { "ctmp".to_string(), BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - + println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); - ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)).unwrap(); + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) + .unwrap(); return ctmp; } diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index b0b042a..81509fa 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -1,11 +1,14 @@ use std::collections::BTreeSet; -use crate::{process, params::{Params, StateType}}; +use crate::{ + params::{Params, StateType}, + process, +}; use super::NetworkProcess; pub struct CtmpProcess { - param: Option + param: Option, } impl CtmpProcess { @@ -24,26 +27,28 @@ impl NetworkProcess for CtmpProcess { None => { self.param = Some(n); Ok(0) - }, - Some(_) => Err(process::NetworkError::NodeInsertionError("CtmpProcess has only one node".to_string())) + } + Some(_) => Err(process::NetworkError::NodeInsertionError( + "CtmpProcess has only one node".to_string(), + )), } } - fn add_edge(&mut self, parent: usize, child: usize) { + fn add_edge(&mut self, _parent: usize, _child: usize) { unimplemented!("CtmpProcess has only one node") } fn get_node_indices(&self) -> std::ops::Range { match self.param { None => 0..0, - Some(_) => 0..1 + Some(_) => 0..1, } } fn get_number_of_nodes(&self) -> usize { match self.param { None => 0, - Some(_) => 1 + Some(_) => 1, } } @@ -63,11 +68,14 @@ impl NetworkProcess for CtmpProcess { } } - fn get_param_index_network(&self, node: usize, current_state: &Vec) - -> usize { + fn get_param_index_network( + &self, + node: usize, + current_state: &Vec, + ) -> usize { if node == 0 { match current_state[0] { - StateType::Discrete(x) => x + StateType::Discrete(x) => x, } } else { unimplemented!("CtmpProcess has only one node") @@ -76,31 +84,35 @@ impl NetworkProcess for CtmpProcess { fn get_param_index_from_custom_parent_set( &self, - current_state: &Vec, - parent_set: &std::collections::BTreeSet, + _current_state: &Vec, + _parent_set: &std::collections::BTreeSet, ) -> usize { unimplemented!("CtmpProcess has only one node") } fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { match self.param { - Some(_) => if node == 0 { - BTreeSet::new() - } else { - unimplemented!("CtmpProcess has only one node") - }, - None => panic!("Uninitialized CtmpProcess") + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), } } fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { match self.param { - Some(_) => if node == 0 { - BTreeSet::new() - } else { - unimplemented!("CtmpProcess has only one node") - }, - None => panic!("Uninitialized CtmpProcess") + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), } } } diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 050daeb..0662994 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,8 +1,8 @@ //! Module containing methods for the sampling. use crate::{ - process::NetworkProcess, params::{self, ParamsTrait}, + process::NetworkProcess, }; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 4ec3377..344c995 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::params::*; -use crate::{process, parameter_learning}; +use crate::{parameter_learning, process}; pub trait HypothesisTest { fn call( diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index 8943478..f8b38b5 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -5,7 +5,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use statrs::function::gamma; -use crate::{process, parameter_learning, params, tools}; +use crate::{parameter_learning, params, process, tools}; pub trait ScoreFunction { fn call( diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 6f2f648..2e727e8 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,7 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; -use crate::{process, params}; +use crate::{params, process}; pub struct Trajectory { time: Array1, diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index a7752f2..7db2bae 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,12 +1,12 @@ mod utils; use std::collections::BTreeSet; -use std::f64::EPSILON; + use approx::AbsDiffEq; use ndarray::arr3; use reCTBN::params::{self, ParamsTrait}; use reCTBN::process::NetworkProcess; -use reCTBN::process::{ctbn::*, ctmp::*}; +use reCTBN::process::{ctbn::*}; use utils::generate_discrete_time_continous_node; #[test] diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs index 31bc6df..830bfe0 100644 --- a/reCTBN/tests/ctmp.rs +++ b/reCTBN/tests/ctmp.rs @@ -45,7 +45,7 @@ fn add_edge_to_ctmp() { let _n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); net.add_edge(0, 1) } @@ -64,7 +64,7 @@ fn childen_and_parents() { #[test] #[should_panic] fn get_childen_panic() { - let mut net = CtmpProcess::new(); + let net = CtmpProcess::new(); net.get_children_set(0); } @@ -81,7 +81,7 @@ fn get_childen_panic2() { #[test] #[should_panic] fn get_parent_panic() { - let mut net = CtmpProcess::new(); + let net = CtmpProcess::new(); net.get_parent_set(0); } From 44eaf8713fb9ce8f7bb05cc63af4bc625438e983 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 15:16:33 +0100 Subject: [PATCH 05/11] Fix for clippy --- reCTBN/src/process/ctbn.rs | 6 ------ reCTBN/tests/ctbn.rs | 31 +++++++++++-------------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 7cb327d..7473d4c 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -77,12 +77,6 @@ impl CtbnNetwork { /// /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { - for v in self.nodes.iter() { - match v { - Params::DiscreteStatesContinousTime(_) => {} - _ => panic!("Unsupported node"), - } - } let variables_domain = Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 7db2bae..3eb40d7 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -149,16 +149,10 @@ fn simple_amalgamation() { } let ctmp = net.amalgamation(); - let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0); + let p_ctbn = p_ctbn.get_cim().as_ref().unwrap(); + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); } @@ -211,11 +205,10 @@ fn chain_amalgamation() { let ctmp = net.amalgamation(); - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + + + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); let p_ctmp_handmade = arr3(&[[ [ @@ -308,11 +301,9 @@ fn chainfork_amalgamation() { let ctmp = net.amalgamation(); - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); let p_ctmp_handmade = arr3(&[[ [ From 38e744e034e52e277bab2c3c7052f9e796862d81 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 15:25:27 +0100 Subject: [PATCH 06/11] Fix fmt --- reCTBN/src/process/ctbn.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 7473d4c..c949afe 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -77,7 +77,6 @@ impl CtbnNetwork { /// /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { - let variables_domain = Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); From 1878f687d6198a16b56f618ea6e3945ef1703ee5 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 21 Nov 2022 16:34:39 +0100 Subject: [PATCH 07/11] Refactor of sampling --- reCTBN/src/sampling.rs | 13 ++++++++++--- reCTBN/src/tools.rs | 12 ++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 0662994..3bc0c6f 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -7,10 +7,17 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -pub trait Sampler: Iterator { +pub struct Sample { + pub t: f64, + pub state: Vec +} + +pub trait Sampler: Iterator { fn reset(&mut self); } + + pub struct ForwardSampler<'a, T> where T: NetworkProcess, @@ -43,7 +50,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { } impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { - type Item = (f64, Vec); + type Item = Sample; fn next(&mut self) -> Option { let ret_time = self.current_time.clone(); @@ -96,7 +103,7 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some((ret_time, ret_state)) + Some(Sample{t: ret_time, state: ret_state}) } } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 2e727e8..e749d69 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -72,15 +72,15 @@ pub fn trajectory_generator( let mut events: Vec> = Vec::new(); //Current Time and Current State - let (mut t, mut current_state) = sampler.next().unwrap(); + let mut sample = sampler.next().unwrap(); //Generate new samples until ending time is reached. - while t < t_end { - time.push(t); - events.push(current_state); - (t, current_state) = sampler.next().unwrap(); + while sample.t < t_end { + time.push(sample.t); + events.push(sample.state); + sample = sampler.next().unwrap(); } - current_state = events.last().unwrap().clone(); + let current_state = events.last().unwrap().clone(); events.push(current_state); //Add t_end as last time. From 055eb7088e8bf5d139312e55806101a2738cac73 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 21 Nov 2022 17:34:32 +0100 Subject: [PATCH 08/11] Implemented FactoredRewardFunction --- reCTBN/src/lib.rs | 1 + reCTBN/src/process/ctbn.rs | 1 - reCTBN/src/reward_function.rs | 80 +++++++++++++++++++++++++++++++++ reCTBN/src/sampling.rs | 1 + reCTBN/tests/reward_function.rs | 30 +++++++++++++ 5 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 reCTBN/src/reward_function.rs create mode 100644 reCTBN/tests/reward_function.rs diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index c62c42e..1d25552 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -9,3 +9,4 @@ pub mod process; pub mod sampling; pub mod structure_learning; pub mod tools; +pub mod reward_function; diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c949afe..0b6161c 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -119,7 +119,6 @@ impl CtbnNetwork { BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs new file mode 100644 index 0000000..9ff09cc --- /dev/null +++ b/reCTBN/src/reward_function.rs @@ -0,0 +1,80 @@ +use crate::{process, sampling, params::{ParamsTrait, self}}; +use ndarray; + + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64 +} + +pub trait RewardFunction { + fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward; + fn initialize_from_network_process(p: &T) -> Self; +} + + +pub struct FactoredRewardFunction { + transition_reward: Vec>, + instantaneous_reward: Vec> +} + +impl FactoredRewardFunction { + pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2 { + &self.transition_reward[node_idx] + } + + pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2 { + &mut self.transition_reward[node_idx] + } + + pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1 { + &self.instantaneous_reward[node_idx] + } + + pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { + &mut self.instantaneous_reward[node_idx] + } + + +} + +impl RewardFunction for FactoredRewardFunction { + + fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward { + let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { + let x = match x {params::StateType::Discrete(x) => x}; + self.instantaneous_reward[idx][*x] + }).sum(); + if let Some(previous_state) = previous_state { + let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option { + let p = match p {params::StateType::Discrete(p) => p}; + let c = match c {params::StateType::Discrete(c) => c}; + if p != c { + Some(self.transition_reward[idx][[*p,*c]]) + } else { + None + } + }).unwrap_or(0.0); + Reward {transition_reward, instantaneous_reward} + } else { + Reward { transition_reward: 0.0, instantaneous_reward} + } + } + + fn initialize_from_network_process(p: &T) -> Self { + let mut transition_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; + for i in p.get_node_indices() { + //This works only for discrete nodes! + let size: usize = p.get_node(i).get_reserved_space_as_parent(); + instantaneous_reward.push(ndarray::Array1::zeros(size)); + transition_reward.push(ndarray::Array2::zeros((size, size))); + } + + FactoredRewardFunction { transition_reward, instantaneous_reward } + + } + +} + diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 3bc0c6f..d5a1dbe 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -7,6 +7,7 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; +#[derive(Clone)] pub struct Sample { pub t: f64, pub state: Vec diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs new file mode 100644 index 0000000..7f73e6c --- /dev/null +++ b/reCTBN/tests/reward_function.rs @@ -0,0 +1,30 @@ +mod utils; + +use ndarray::*; +use utils::generate_discrete_time_continous_node; +use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; + + +#[test] +fn simple_factored_reward_function() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); + + let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; + let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + + assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); +} From f6015acce99e41582d3902dc7342556e3fe4a115 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 08:53:29 +0100 Subject: [PATCH 09/11] Added tests --- reCTBN/tests/reward_function.rs | 90 ++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index 7f73e6c..0c7fd9b 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -6,7 +6,7 @@ use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; #[test] -fn simple_factored_reward_function() { +fn simple_factored_reward_function_binary_node() { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) @@ -28,3 +28,91 @@ fn simple_factored_reward_function() { assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); } + + +#[test] +fn simple_factored_reward_function_ternary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; + let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + let s2 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2)]}; + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + + + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); +} + +#[test] +fn factored_reward_function_two_nodes() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + + rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); + + let s00 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]}; + let s01 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]}; + let s02 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]}; + + + let s10 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]}; + let s11 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]}; + let s12 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]}; + + assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + + + assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); + + + assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + + + assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); +} From 68ef7ea7c3ad4f33849f1cdf84349939e2e4a6b7 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 09:30:59 +0100 Subject: [PATCH 10/11] Added comments --- reCTBN/src/lib.rs | 2 +- reCTBN/src/reward_function.rs | 120 ++++++++++++++++++++++++++-------- reCTBN/src/sampling.rs | 9 +-- 3 files changed, 98 insertions(+), 33 deletions(-) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 1d25552..8feddfb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -6,7 +6,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; pub mod process; +pub mod reward_function; pub mod sampling; pub mod structure_learning; pub mod tools; -pub mod reward_function; diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs index 9ff09cc..eeddd85 100644 --- a/reCTBN/src/reward_function.rs +++ b/reCTBN/src/reward_function.rs @@ -1,22 +1,62 @@ -use crate::{process, sampling, params::{ParamsTrait, self}}; +//! Module for dealing with reward functions + +use crate::{ + params::{self, ParamsTrait}, + process, sampling, +}; use ndarray; +/// Instantiation of reward function and instantaneous reward +/// +/// +/// # Arguments +/// +/// * `transition_reward`: reward obtained transitioning from one state to another +/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state #[derive(Debug, PartialEq)] pub struct Reward { pub transition_reward: f64, - pub instantaneous_reward: f64 + pub instantaneous_reward: f64, } +/// The trait RewardFunction describe the methods that all the reward functions must satisfy + pub trait RewardFunction { - fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward; + /// Given the current state and the previous state, it compute the reward. + /// + /// # Arguments + /// + /// * `current_state`: the current state of the network represented as a `sampling::Sample` + /// * `previous_state`: an optional argument representing the previous state of the network + + fn call( + &self, + current_state: sampling::Sample, + previous_state: Option, + ) -> Reward; + + /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess + /// + /// # Arguments + /// + /// * `p`: any structure that implements the trait `process::NetworkProcess` fn initialize_from_network_process(p: &T) -> Self; } +/// Reward function over a factored state space +/// +/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node +/// of the underling `NetworkProcess` +/// +/// # Arguments +/// +/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition +/// reward of a node pub struct FactoredRewardFunction { transition_reward: Vec>, - instantaneous_reward: Vec> + instantaneous_reward: Vec>, } impl FactoredRewardFunction { @@ -35,36 +75,60 @@ impl FactoredRewardFunction { pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { &mut self.instantaneous_reward[node_idx] } - - } impl RewardFunction for FactoredRewardFunction { - - fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward { - let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { - let x = match x {params::StateType::Discrete(x) => x}; - self.instantaneous_reward[idx][*x] - }).sum(); + fn call( + &self, + current_state: sampling::Sample, + previous_state: Option, + ) -> Reward { + let instantaneous_reward: f64 = current_state + .state + .iter() + .enumerate() + .map(|(idx, x)| { + let x = match x { + params::StateType::Discrete(x) => x, + }; + self.instantaneous_reward[idx][*x] + }) + .sum(); if let Some(previous_state) = previous_state { - let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option { - let p = match p {params::StateType::Discrete(p) => p}; - let c = match c {params::StateType::Discrete(c) => c}; - if p != c { - Some(self.transition_reward[idx][[*p,*c]]) - } else { - None - } - }).unwrap_or(0.0); - Reward {transition_reward, instantaneous_reward} + let transition_reward = previous_state + .state + .iter() + .zip(current_state.state.iter()) + .enumerate() + .find_map(|(idx, (p, c))| -> Option { + let p = match p { + params::StateType::Discrete(p) => p, + }; + let c = match c { + params::StateType::Discrete(c) => c, + }; + if p != c { + Some(self.transition_reward[idx][[*p, *c]]) + } else { + None + } + }) + .unwrap_or(0.0); + Reward { + transition_reward, + instantaneous_reward, + } } else { - Reward { transition_reward: 0.0, instantaneous_reward} + Reward { + transition_reward: 0.0, + instantaneous_reward, + } } } fn initialize_from_network_process(p: &T) -> Self { let mut transition_reward: Vec> = vec![]; - let mut instantaneous_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; for i in p.get_node_indices() { //This works only for discrete nodes! let size: usize = p.get_node(i).get_reserved_space_as_parent(); @@ -72,9 +136,9 @@ impl RewardFunction for FactoredRewardFunction { transition_reward.push(ndarray::Array2::zeros((size, size))); } - FactoredRewardFunction { transition_reward, instantaneous_reward } - + FactoredRewardFunction { + transition_reward, + instantaneous_reward, + } } - } - diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index d5a1dbe..a0a9fcb 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -10,15 +10,13 @@ use rand_chacha::ChaCha8Rng; #[derive(Clone)] pub struct Sample { pub t: f64, - pub state: Vec + pub state: Vec, } pub trait Sampler: Iterator { fn reset(&mut self); } - - pub struct ForwardSampler<'a, T> where T: NetworkProcess, @@ -104,7 +102,10 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some(Sample{t: ret_time, state: ret_state}) + Some(Sample { + t: ret_time, + state: ret_state, + }) } } From bcb64a161ad49204eea20142b7f803e06e72becb Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 10:02:21 +0100 Subject: [PATCH 11/11] Mini refactor. Introduced the type alias NetworkProcessState. --- reCTBN/src/process.rs | 6 ++++-- reCTBN/src/process/ctbn.rs | 10 +++++----- reCTBN/src/process/ctmp.rs | 12 ++++-------- reCTBN/src/reward_function.rs | 16 +++++++--------- reCTBN/src/sampling.rs | 8 ++++---- reCTBN/src/tools.rs | 2 +- reCTBN/tests/reward_function.rs | 24 ++++++++++++------------ 7 files changed, 37 insertions(+), 41 deletions(-) diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs index 2b70b59..dc297bc 100644 --- a/reCTBN/src/process.rs +++ b/reCTBN/src/process.rs @@ -16,6 +16,9 @@ pub enum NetworkError { NodeInsertionError(String), } +/// This type is used to represent a specific realization of a generic NetworkProcess +pub type NetworkProcessState = Vec; + /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). pub trait NetworkProcess { @@ -71,8 +74,7 @@ pub trait NetworkProcess { /// # Return /// /// * Index of the `node` relative to the network. - fn get_param_index_network(&self, node: usize, current_state: &Vec) - -> usize; + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; /// Compute the index that must be used to access the parameters of a `node`, given a specific /// configuration of the network and a generic `parent_set`. diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 0b6161c..162345e 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -8,7 +8,7 @@ use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, Stat use crate::process; use super::ctmp::CtmpProcess; -use super::NetworkProcess; +use super::{NetworkProcess, NetworkProcessState}; /// It represents both the structure and the parameters of a CTBN. /// @@ -86,7 +86,7 @@ impl CtbnNetwork { for idx_current_state in 0..state_space { let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); - let current_state_statetype: Vec = current_state + let current_state_statetype: NetworkProcessState = current_state .iter() .map(|x| StateType::Discrete(*x)) .collect(); @@ -98,7 +98,7 @@ impl CtbnNetwork { let mut next_state = current_state.clone(); next_state[idx_node] = next_node_state; - let next_state_statetype: Vec = + let next_state_statetype: NetworkProcessState = next_state.iter().map(|x| StateType::Discrete(*x)).collect(); let idx_next_state = self.get_param_index_from_custom_parent_set( &next_state_statetype, @@ -185,7 +185,7 @@ impl process::NetworkProcess for CtbnNetwork { &mut self.nodes[node_idx] } - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { self.adj_matrix .as_ref() .unwrap() @@ -204,7 +204,7 @@ impl process::NetworkProcess for CtbnNetwork { fn get_param_index_from_custom_parent_set( &self, - current_state: &Vec, + current_state: &NetworkProcessState, parent_set: &BTreeSet, ) -> usize { parent_set diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index 81509fa..41b8db6 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -5,7 +5,7 @@ use crate::{ process, }; -use super::NetworkProcess; +use super::{NetworkProcess, NetworkProcessState}; pub struct CtmpProcess { param: Option, @@ -68,11 +68,7 @@ impl NetworkProcess for CtmpProcess { } } - fn get_param_index_network( - &self, - node: usize, - current_state: &Vec, - ) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { if node == 0 { match current_state[0] { StateType::Discrete(x) => x, @@ -84,8 +80,8 @@ impl NetworkProcess for CtmpProcess { fn get_param_index_from_custom_parent_set( &self, - _current_state: &Vec, - _parent_set: &std::collections::BTreeSet, + _current_state: &NetworkProcessState, + _parent_set: &BTreeSet, ) -> usize { unimplemented!("CtmpProcess has only one node") } diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs index eeddd85..35e15c8 100644 --- a/reCTBN/src/reward_function.rs +++ b/reCTBN/src/reward_function.rs @@ -2,7 +2,7 @@ use crate::{ params::{self, ParamsTrait}, - process, sampling, + process, }; use ndarray; @@ -27,13 +27,13 @@ pub trait RewardFunction { /// /// # Arguments /// - /// * `current_state`: the current state of the network represented as a `sampling::Sample` + /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` /// * `previous_state`: an optional argument representing the previous state of the network fn call( &self, - current_state: sampling::Sample, - previous_state: Option, + current_state: process::NetworkProcessState, + previous_state: Option, ) -> Reward; /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess @@ -80,11 +80,10 @@ impl FactoredRewardFunction { impl RewardFunction for FactoredRewardFunction { fn call( &self, - current_state: sampling::Sample, - previous_state: Option, + current_state: process::NetworkProcessState, + previous_state: Option, ) -> Reward { let instantaneous_reward: f64 = current_state - .state .iter() .enumerate() .map(|(idx, x)| { @@ -96,9 +95,8 @@ impl RewardFunction for FactoredRewardFunction { .sum(); if let Some(previous_state) = previous_state { let transition_reward = previous_state - .state .iter() - .zip(current_state.state.iter()) + .zip(current_state.iter()) .enumerate() .find_map(|(idx, (p, c))| -> Option { let p = match p { diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index a0a9fcb..1384872 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,8 +1,8 @@ //! Module containing methods for the sampling. use crate::{ - params::{self, ParamsTrait}, - process::NetworkProcess, + params::ParamsTrait, + process::{NetworkProcess, NetworkProcessState}, }; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; @@ -10,7 +10,7 @@ use rand_chacha::ChaCha8Rng; #[derive(Clone)] pub struct Sample { pub t: f64, - pub state: Vec, + pub state: NetworkProcessState, } pub trait Sampler: Iterator { @@ -24,7 +24,7 @@ where net: &'a T, rng: ChaCha8Rng, current_time: f64, - current_state: Vec, + current_state: NetworkProcessState, next_transitions: Vec>, } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index e749d69..ecfeff9 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -69,7 +69,7 @@ pub fn trajectory_generator( let mut time: Vec = Vec::new(); //Configuration of the process variables at time t initialized with an uniform //distribution. - let mut events: Vec> = Vec::new(); + let mut events: Vec = Vec::new(); //Current Time and Current State let mut sample = sampler.next().unwrap(); diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index 0c7fd9b..dcc5e69 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -2,7 +2,7 @@ mod utils; use ndarray::*; use utils::generate_discrete_time_continous_node; -use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; +use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; #[test] @@ -16,8 +16,8 @@ fn simple_factored_reward_function_binary_node() { rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); - let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; - let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); @@ -41,9 +41,9 @@ fn simple_factored_reward_function_ternary_node() { rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); - let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; - let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; - let s2 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2)]}; + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); @@ -78,14 +78,14 @@ fn factored_reward_function_two_nodes() { rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); - let s00 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]}; - let s01 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]}; - let s02 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]}; + let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; + let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; + let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; - let s10 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]}; - let s11 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]}; - let s12 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]}; + let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; + let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; + let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});