From 1878f687d6198a16b56f618ea6e3945ef1703ee5 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 21 Nov 2022 16:34:39 +0100 Subject: [PATCH 1/5] Refactor of sampling --- reCTBN/src/sampling.rs | 13 ++++++++++--- reCTBN/src/tools.rs | 12 ++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 0662994..3bc0c6f 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -7,10 +7,17 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -pub trait Sampler: Iterator { +pub struct Sample { + pub t: f64, + pub state: Vec +} + +pub trait Sampler: Iterator { fn reset(&mut self); } + + pub struct ForwardSampler<'a, T> where T: NetworkProcess, @@ -43,7 +50,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { } impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { - type Item = (f64, Vec); + type Item = Sample; fn next(&mut self) -> Option { let ret_time = self.current_time.clone(); @@ -96,7 +103,7 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some((ret_time, ret_state)) + Some(Sample{t: ret_time, state: ret_state}) } } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 2e727e8..e749d69 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -72,15 +72,15 @@ pub fn trajectory_generator( let mut events: Vec> = Vec::new(); //Current Time and Current State - let (mut t, mut current_state) = sampler.next().unwrap(); + let mut sample = sampler.next().unwrap(); //Generate new samples until ending time is reached. - while t < t_end { - time.push(t); - events.push(current_state); - (t, current_state) = sampler.next().unwrap(); + while sample.t < t_end { + time.push(sample.t); + events.push(sample.state); + sample = sampler.next().unwrap(); } - current_state = events.last().unwrap().clone(); + let current_state = events.last().unwrap().clone(); events.push(current_state); //Add t_end as last time. From 055eb7088e8bf5d139312e55806101a2738cac73 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 21 Nov 2022 17:34:32 +0100 Subject: [PATCH 2/5] Implemented FactoredRewardFunction --- reCTBN/src/lib.rs | 1 + reCTBN/src/process/ctbn.rs | 1 - reCTBN/src/reward_function.rs | 80 +++++++++++++++++++++++++++++++++ reCTBN/src/sampling.rs | 1 + reCTBN/tests/reward_function.rs | 30 +++++++++++++ 5 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 reCTBN/src/reward_function.rs create mode 100644 reCTBN/tests/reward_function.rs diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index c62c42e..1d25552 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -9,3 +9,4 @@ pub mod process; pub mod sampling; pub mod structure_learning; pub mod tools; +pub mod reward_function; diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c949afe..0b6161c 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -119,7 +119,6 @@ impl CtbnNetwork { BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs new file mode 100644 index 0000000..9ff09cc --- /dev/null +++ b/reCTBN/src/reward_function.rs @@ -0,0 +1,80 @@ +use crate::{process, sampling, params::{ParamsTrait, self}}; +use ndarray; + + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64 +} + +pub trait RewardFunction { + fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward; + fn initialize_from_network_process(p: &T) -> Self; +} + + +pub struct FactoredRewardFunction { + transition_reward: Vec>, + instantaneous_reward: Vec> +} + +impl FactoredRewardFunction { + pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2 { + &self.transition_reward[node_idx] + } + + pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2 { + &mut self.transition_reward[node_idx] + } + + pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1 { + &self.instantaneous_reward[node_idx] + } + + pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { + &mut self.instantaneous_reward[node_idx] + } + + +} + +impl RewardFunction for FactoredRewardFunction { + + fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward { + let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { + let x = match x {params::StateType::Discrete(x) => x}; + self.instantaneous_reward[idx][*x] + }).sum(); + if let Some(previous_state) = previous_state { + let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option { + let p = match p {params::StateType::Discrete(p) => p}; + let c = match c {params::StateType::Discrete(c) => c}; + if p != c { + Some(self.transition_reward[idx][[*p,*c]]) + } else { + None + } + }).unwrap_or(0.0); + Reward {transition_reward, instantaneous_reward} + } else { + Reward { transition_reward: 0.0, instantaneous_reward} + } + } + + fn initialize_from_network_process(p: &T) -> Self { + let mut transition_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; + for i in p.get_node_indices() { + //This works only for discrete nodes! + let size: usize = p.get_node(i).get_reserved_space_as_parent(); + instantaneous_reward.push(ndarray::Array1::zeros(size)); + transition_reward.push(ndarray::Array2::zeros((size, size))); + } + + FactoredRewardFunction { transition_reward, instantaneous_reward } + + } + +} + diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 3bc0c6f..d5a1dbe 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -7,6 +7,7 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; +#[derive(Clone)] pub struct Sample { pub t: f64, pub state: Vec diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs new file mode 100644 index 0000000..7f73e6c --- /dev/null +++ b/reCTBN/tests/reward_function.rs @@ -0,0 +1,30 @@ +mod utils; + +use ndarray::*; +use utils::generate_discrete_time_continous_node; +use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; + + +#[test] +fn simple_factored_reward_function() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); + + let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; + let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + + assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); +} From f6015acce99e41582d3902dc7342556e3fe4a115 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 08:53:29 +0100 Subject: [PATCH 3/5] Added tests --- reCTBN/tests/reward_function.rs | 90 ++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index 7f73e6c..0c7fd9b 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -6,7 +6,7 @@ use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; #[test] -fn simple_factored_reward_function() { +fn simple_factored_reward_function_binary_node() { let mut net = CtbnNetwork::new(); let n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) @@ -28,3 +28,91 @@ fn simple_factored_reward_function() { assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); } + + +#[test] +fn simple_factored_reward_function_ternary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; + let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + let s2 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2)]}; + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + + + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); +} + +#[test] +fn factored_reward_function_two_nodes() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + + rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); + + let s00 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]}; + let s01 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]}; + let s02 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]}; + + + let s10 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]}; + let s11 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]}; + let s12 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]}; + + assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + + + assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); + + + assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + + + assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); +} From 68ef7ea7c3ad4f33849f1cdf84349939e2e4a6b7 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 09:30:59 +0100 Subject: [PATCH 4/5] Added comments --- reCTBN/src/lib.rs | 2 +- reCTBN/src/reward_function.rs | 120 ++++++++++++++++++++++++++-------- reCTBN/src/sampling.rs | 9 +-- 3 files changed, 98 insertions(+), 33 deletions(-) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 1d25552..8feddfb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -6,7 +6,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; pub mod process; +pub mod reward_function; pub mod sampling; pub mod structure_learning; pub mod tools; -pub mod reward_function; diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs index 9ff09cc..eeddd85 100644 --- a/reCTBN/src/reward_function.rs +++ b/reCTBN/src/reward_function.rs @@ -1,22 +1,62 @@ -use crate::{process, sampling, params::{ParamsTrait, self}}; +//! Module for dealing with reward functions + +use crate::{ + params::{self, ParamsTrait}, + process, sampling, +}; use ndarray; +/// Instantiation of reward function and instantaneous reward +/// +/// +/// # Arguments +/// +/// * `transition_reward`: reward obtained transitioning from one state to another +/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state #[derive(Debug, PartialEq)] pub struct Reward { pub transition_reward: f64, - pub instantaneous_reward: f64 + pub instantaneous_reward: f64, } +/// The trait RewardFunction describe the methods that all the reward functions must satisfy + pub trait RewardFunction { - fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward; + /// Given the current state and the previous state, it compute the reward. + /// + /// # Arguments + /// + /// * `current_state`: the current state of the network represented as a `sampling::Sample` + /// * `previous_state`: an optional argument representing the previous state of the network + + fn call( + &self, + current_state: sampling::Sample, + previous_state: Option, + ) -> Reward; + + /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess + /// + /// # Arguments + /// + /// * `p`: any structure that implements the trait `process::NetworkProcess` fn initialize_from_network_process(p: &T) -> Self; } +/// Reward function over a factored state space +/// +/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node +/// of the underling `NetworkProcess` +/// +/// # Arguments +/// +/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition +/// reward of a node pub struct FactoredRewardFunction { transition_reward: Vec>, - instantaneous_reward: Vec> + instantaneous_reward: Vec>, } impl FactoredRewardFunction { @@ -35,36 +75,60 @@ impl FactoredRewardFunction { pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { &mut self.instantaneous_reward[node_idx] } - - } impl RewardFunction for FactoredRewardFunction { - - fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward { - let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { - let x = match x {params::StateType::Discrete(x) => x}; - self.instantaneous_reward[idx][*x] - }).sum(); + fn call( + &self, + current_state: sampling::Sample, + previous_state: Option, + ) -> Reward { + let instantaneous_reward: f64 = current_state + .state + .iter() + .enumerate() + .map(|(idx, x)| { + let x = match x { + params::StateType::Discrete(x) => x, + }; + self.instantaneous_reward[idx][*x] + }) + .sum(); if let Some(previous_state) = previous_state { - let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option { - let p = match p {params::StateType::Discrete(p) => p}; - let c = match c {params::StateType::Discrete(c) => c}; - if p != c { - Some(self.transition_reward[idx][[*p,*c]]) - } else { - None - } - }).unwrap_or(0.0); - Reward {transition_reward, instantaneous_reward} + let transition_reward = previous_state + .state + .iter() + .zip(current_state.state.iter()) + .enumerate() + .find_map(|(idx, (p, c))| -> Option { + let p = match p { + params::StateType::Discrete(p) => p, + }; + let c = match c { + params::StateType::Discrete(c) => c, + }; + if p != c { + Some(self.transition_reward[idx][[*p, *c]]) + } else { + None + } + }) + .unwrap_or(0.0); + Reward { + transition_reward, + instantaneous_reward, + } } else { - Reward { transition_reward: 0.0, instantaneous_reward} + Reward { + transition_reward: 0.0, + instantaneous_reward, + } } } fn initialize_from_network_process(p: &T) -> Self { let mut transition_reward: Vec> = vec![]; - let mut instantaneous_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; for i in p.get_node_indices() { //This works only for discrete nodes! let size: usize = p.get_node(i).get_reserved_space_as_parent(); @@ -72,9 +136,9 @@ impl RewardFunction for FactoredRewardFunction { transition_reward.push(ndarray::Array2::zeros((size, size))); } - FactoredRewardFunction { transition_reward, instantaneous_reward } - + FactoredRewardFunction { + transition_reward, + instantaneous_reward, + } } - } - diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index d5a1dbe..a0a9fcb 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -10,15 +10,13 @@ use rand_chacha::ChaCha8Rng; #[derive(Clone)] pub struct Sample { pub t: f64, - pub state: Vec + pub state: Vec, } pub trait Sampler: Iterator { fn reset(&mut self); } - - pub struct ForwardSampler<'a, T> where T: NetworkProcess, @@ -104,7 +102,10 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some(Sample{t: ret_time, state: ret_state}) + Some(Sample { + t: ret_time, + state: ret_state, + }) } } From bcb64a161ad49204eea20142b7f803e06e72becb Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 22 Nov 2022 10:02:21 +0100 Subject: [PATCH 5/5] Mini refactor. Introduced the type alias NetworkProcessState. --- reCTBN/src/process.rs | 6 ++++-- reCTBN/src/process/ctbn.rs | 10 +++++----- reCTBN/src/process/ctmp.rs | 12 ++++-------- reCTBN/src/reward_function.rs | 16 +++++++--------- reCTBN/src/sampling.rs | 8 ++++---- reCTBN/src/tools.rs | 2 +- reCTBN/tests/reward_function.rs | 24 ++++++++++++------------ 7 files changed, 37 insertions(+), 41 deletions(-) diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs index 2b70b59..dc297bc 100644 --- a/reCTBN/src/process.rs +++ b/reCTBN/src/process.rs @@ -16,6 +16,9 @@ pub enum NetworkError { NodeInsertionError(String), } +/// This type is used to represent a specific realization of a generic NetworkProcess +pub type NetworkProcessState = Vec; + /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). pub trait NetworkProcess { @@ -71,8 +74,7 @@ pub trait NetworkProcess { /// # Return /// /// * Index of the `node` relative to the network. - fn get_param_index_network(&self, node: usize, current_state: &Vec) - -> usize; + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; /// Compute the index that must be used to access the parameters of a `node`, given a specific /// configuration of the network and a generic `parent_set`. diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 0b6161c..162345e 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -8,7 +8,7 @@ use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, Stat use crate::process; use super::ctmp::CtmpProcess; -use super::NetworkProcess; +use super::{NetworkProcess, NetworkProcessState}; /// It represents both the structure and the parameters of a CTBN. /// @@ -86,7 +86,7 @@ impl CtbnNetwork { for idx_current_state in 0..state_space { let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); - let current_state_statetype: Vec = current_state + let current_state_statetype: NetworkProcessState = current_state .iter() .map(|x| StateType::Discrete(*x)) .collect(); @@ -98,7 +98,7 @@ impl CtbnNetwork { let mut next_state = current_state.clone(); next_state[idx_node] = next_node_state; - let next_state_statetype: Vec = + let next_state_statetype: NetworkProcessState = next_state.iter().map(|x| StateType::Discrete(*x)).collect(); let idx_next_state = self.get_param_index_from_custom_parent_set( &next_state_statetype, @@ -185,7 +185,7 @@ impl process::NetworkProcess for CtbnNetwork { &mut self.nodes[node_idx] } - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { self.adj_matrix .as_ref() .unwrap() @@ -204,7 +204,7 @@ impl process::NetworkProcess for CtbnNetwork { fn get_param_index_from_custom_parent_set( &self, - current_state: &Vec, + current_state: &NetworkProcessState, parent_set: &BTreeSet, ) -> usize { parent_set diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index 81509fa..41b8db6 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -5,7 +5,7 @@ use crate::{ process, }; -use super::NetworkProcess; +use super::{NetworkProcess, NetworkProcessState}; pub struct CtmpProcess { param: Option, @@ -68,11 +68,7 @@ impl NetworkProcess for CtmpProcess { } } - fn get_param_index_network( - &self, - node: usize, - current_state: &Vec, - ) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { if node == 0 { match current_state[0] { StateType::Discrete(x) => x, @@ -84,8 +80,8 @@ impl NetworkProcess for CtmpProcess { fn get_param_index_from_custom_parent_set( &self, - _current_state: &Vec, - _parent_set: &std::collections::BTreeSet, + _current_state: &NetworkProcessState, + _parent_set: &BTreeSet, ) -> usize { unimplemented!("CtmpProcess has only one node") } diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs index eeddd85..35e15c8 100644 --- a/reCTBN/src/reward_function.rs +++ b/reCTBN/src/reward_function.rs @@ -2,7 +2,7 @@ use crate::{ params::{self, ParamsTrait}, - process, sampling, + process, }; use ndarray; @@ -27,13 +27,13 @@ pub trait RewardFunction { /// /// # Arguments /// - /// * `current_state`: the current state of the network represented as a `sampling::Sample` + /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` /// * `previous_state`: an optional argument representing the previous state of the network fn call( &self, - current_state: sampling::Sample, - previous_state: Option, + current_state: process::NetworkProcessState, + previous_state: Option, ) -> Reward; /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess @@ -80,11 +80,10 @@ impl FactoredRewardFunction { impl RewardFunction for FactoredRewardFunction { fn call( &self, - current_state: sampling::Sample, - previous_state: Option, + current_state: process::NetworkProcessState, + previous_state: Option, ) -> Reward { let instantaneous_reward: f64 = current_state - .state .iter() .enumerate() .map(|(idx, x)| { @@ -96,9 +95,8 @@ impl RewardFunction for FactoredRewardFunction { .sum(); if let Some(previous_state) = previous_state { let transition_reward = previous_state - .state .iter() - .zip(current_state.state.iter()) + .zip(current_state.iter()) .enumerate() .find_map(|(idx, (p, c))| -> Option { let p = match p { diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index a0a9fcb..1384872 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,8 +1,8 @@ //! Module containing methods for the sampling. use crate::{ - params::{self, ParamsTrait}, - process::NetworkProcess, + params::ParamsTrait, + process::{NetworkProcess, NetworkProcessState}, }; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; @@ -10,7 +10,7 @@ use rand_chacha::ChaCha8Rng; #[derive(Clone)] pub struct Sample { pub t: f64, - pub state: Vec, + pub state: NetworkProcessState, } pub trait Sampler: Iterator { @@ -24,7 +24,7 @@ where net: &'a T, rng: ChaCha8Rng, current_time: f64, - current_state: Vec, + current_state: NetworkProcessState, next_transitions: Vec>, } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index e749d69..ecfeff9 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -69,7 +69,7 @@ pub fn trajectory_generator( let mut time: Vec = Vec::new(); //Configuration of the process variables at time t initialized with an uniform //distribution. - let mut events: Vec> = Vec::new(); + let mut events: Vec = Vec::new(); //Current Time and Current State let mut sample = sampler.next().unwrap(); diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index 0c7fd9b..dcc5e69 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -2,7 +2,7 @@ mod utils; use ndarray::*; use utils::generate_discrete_time_continous_node; -use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; +use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; #[test] @@ -16,8 +16,8 @@ fn simple_factored_reward_function_binary_node() { rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); - let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; - let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); @@ -41,9 +41,9 @@ fn simple_factored_reward_function_ternary_node() { rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); - let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; - let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; - let s2 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2)]}; + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); @@ -78,14 +78,14 @@ fn factored_reward_function_two_nodes() { rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); - let s00 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]}; - let s01 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]}; - let s02 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]}; + let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; + let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; + let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; - let s10 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]}; - let s11 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]}; - let s12 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]}; + let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; + let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; + let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});