diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index c62c42e..8feddfb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -6,6 +6,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; pub mod process; +pub mod reward_function; pub mod sampling; pub mod structure_learning; pub mod tools; diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs index 2b70b59..dc297bc 100644 --- a/reCTBN/src/process.rs +++ b/reCTBN/src/process.rs @@ -16,6 +16,9 @@ pub enum NetworkError { NodeInsertionError(String), } +/// This type is used to represent a specific realization of a generic NetworkProcess +pub type NetworkProcessState = Vec; + /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). pub trait NetworkProcess { @@ -71,8 +74,7 @@ pub trait NetworkProcess { /// # Return /// /// * Index of the `node` relative to the network. - fn get_param_index_network(&self, node: usize, current_state: &Vec) - -> usize; + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; /// Compute the index that must be used to access the parameters of a `node`, given a specific /// configuration of the network and a generic `parent_set`. diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c949afe..162345e 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -8,7 +8,7 @@ use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, Stat use crate::process; use super::ctmp::CtmpProcess; -use super::NetworkProcess; +use super::{NetworkProcess, NetworkProcessState}; /// It represents both the structure and the parameters of a CTBN. /// @@ -86,7 +86,7 @@ impl CtbnNetwork { for idx_current_state in 0..state_space { let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); - let current_state_statetype: Vec = current_state + let current_state_statetype: NetworkProcessState = current_state .iter() .map(|x| StateType::Discrete(*x)) .collect(); @@ -98,7 +98,7 @@ impl CtbnNetwork { let mut next_state = current_state.clone(); next_state[idx_node] = next_node_state; - let next_state_statetype: Vec = + let next_state_statetype: NetworkProcessState = next_state.iter().map(|x| StateType::Discrete(*x)).collect(); let idx_next_state = self.get_param_index_from_custom_parent_set( &next_state_statetype, @@ -119,7 +119,6 @@ impl CtbnNetwork { BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); @@ -186,7 +185,7 @@ impl process::NetworkProcess for CtbnNetwork { &mut self.nodes[node_idx] } - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { self.adj_matrix .as_ref() .unwrap() @@ -205,7 +204,7 @@ impl process::NetworkProcess for CtbnNetwork { fn get_param_index_from_custom_parent_set( &self, - current_state: &Vec, + current_state: &NetworkProcessState, parent_set: &BTreeSet, ) -> usize { parent_set diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index 81509fa..41b8db6 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -5,7 +5,7 @@ use crate::{ process, }; -use super::NetworkProcess; +use super::{NetworkProcess, NetworkProcessState}; pub struct CtmpProcess { param: Option, @@ -68,11 +68,7 @@ impl NetworkProcess for CtmpProcess { } } - fn get_param_index_network( - &self, - node: usize, - current_state: &Vec, - ) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { if node == 0 { match current_state[0] { StateType::Discrete(x) => x, @@ -84,8 +80,8 @@ impl NetworkProcess for CtmpProcess { fn get_param_index_from_custom_parent_set( &self, - _current_state: &Vec, - _parent_set: &std::collections::BTreeSet, + _current_state: &NetworkProcessState, + _parent_set: &BTreeSet, ) -> usize { unimplemented!("CtmpProcess has only one node") } diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs new file mode 100644 index 0000000..35e15c8 --- /dev/null +++ b/reCTBN/src/reward_function.rs @@ -0,0 +1,142 @@ +//! Module for dealing with reward functions + +use crate::{ + params::{self, ParamsTrait}, + process, +}; +use ndarray; + +/// Instantiation of reward function and instantaneous reward +/// +/// +/// # Arguments +/// +/// * `transition_reward`: reward obtained transitioning from one state to another +/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64, +} + +/// The trait RewardFunction describe the methods that all the reward functions must satisfy + +pub trait RewardFunction { + /// Given the current state and the previous state, it compute the reward. + /// + /// # Arguments + /// + /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` + /// * `previous_state`: an optional argument representing the previous state of the network + + fn call( + &self, + current_state: process::NetworkProcessState, + previous_state: Option, + ) -> Reward; + + /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess + /// + /// # Arguments + /// + /// * `p`: any structure that implements the trait `process::NetworkProcess` + fn initialize_from_network_process(p: &T) -> Self; +} + +/// Reward function over a factored state space +/// +/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node +/// of the underling `NetworkProcess` +/// +/// # Arguments +/// +/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition +/// reward of a node + +pub struct FactoredRewardFunction { + transition_reward: Vec>, + instantaneous_reward: Vec>, +} + +impl FactoredRewardFunction { + pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2 { + &self.transition_reward[node_idx] + } + + pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2 { + &mut self.transition_reward[node_idx] + } + + pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1 { + &self.instantaneous_reward[node_idx] + } + + pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { + &mut self.instantaneous_reward[node_idx] + } +} + +impl RewardFunction for FactoredRewardFunction { + fn call( + &self, + current_state: process::NetworkProcessState, + previous_state: Option, + ) -> Reward { + let instantaneous_reward: f64 = current_state + .iter() + .enumerate() + .map(|(idx, x)| { + let x = match x { + params::StateType::Discrete(x) => x, + }; + self.instantaneous_reward[idx][*x] + }) + .sum(); + if let Some(previous_state) = previous_state { + let transition_reward = previous_state + .iter() + .zip(current_state.iter()) + .enumerate() + .find_map(|(idx, (p, c))| -> Option { + let p = match p { + params::StateType::Discrete(p) => p, + }; + let c = match c { + params::StateType::Discrete(c) => c, + }; + if p != c { + Some(self.transition_reward[idx][[*p, *c]]) + } else { + None + } + }) + .unwrap_or(0.0); + Reward { + transition_reward, + instantaneous_reward, + } + } else { + Reward { + transition_reward: 0.0, + instantaneous_reward, + } + } + } + + fn initialize_from_network_process(p: &T) -> Self { + let mut transition_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; + for i in p.get_node_indices() { + //This works only for discrete nodes! + let size: usize = p.get_node(i).get_reserved_space_as_parent(); + instantaneous_reward.push(ndarray::Array1::zeros(size)); + transition_reward.push(ndarray::Array2::zeros((size, size))); + } + + FactoredRewardFunction { + transition_reward, + instantaneous_reward, + } + } +} diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 0662994..1384872 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,13 +1,19 @@ //! Module containing methods for the sampling. use crate::{ - params::{self, ParamsTrait}, - process::NetworkProcess, + params::ParamsTrait, + process::{NetworkProcess, NetworkProcessState}, }; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -pub trait Sampler: Iterator { +#[derive(Clone)] +pub struct Sample { + pub t: f64, + pub state: NetworkProcessState, +} + +pub trait Sampler: Iterator { fn reset(&mut self); } @@ -18,7 +24,7 @@ where net: &'a T, rng: ChaCha8Rng, current_time: f64, - current_state: Vec, + current_state: NetworkProcessState, next_transitions: Vec>, } @@ -43,7 +49,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { } impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { - type Item = (f64, Vec); + type Item = Sample; fn next(&mut self) -> Option { let ret_time = self.current_time.clone(); @@ -96,7 +102,10 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some((ret_time, ret_state)) + Some(Sample { + t: ret_time, + state: ret_state, + }) } } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 2e727e8..ecfeff9 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -69,18 +69,18 @@ pub fn trajectory_generator( let mut time: Vec = Vec::new(); //Configuration of the process variables at time t initialized with an uniform //distribution. - let mut events: Vec> = Vec::new(); + let mut events: Vec = Vec::new(); //Current Time and Current State - let (mut t, mut current_state) = sampler.next().unwrap(); + let mut sample = sampler.next().unwrap(); //Generate new samples until ending time is reached. - while t < t_end { - time.push(t); - events.push(current_state); - (t, current_state) = sampler.next().unwrap(); + while sample.t < t_end { + time.push(sample.t); + events.push(sample.state); + sample = sampler.next().unwrap(); } - current_state = events.last().unwrap().clone(); + let current_state = events.last().unwrap().clone(); events.push(current_state); //Add t_end as last time. diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs new file mode 100644 index 0000000..dcc5e69 --- /dev/null +++ b/reCTBN/tests/reward_function.rs @@ -0,0 +1,118 @@ +mod utils; + +use ndarray::*; +use utils::generate_discrete_time_continous_node; +use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; + + +#[test] +fn simple_factored_reward_function_binary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + + assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); +} + + +#[test] +fn simple_factored_reward_function_ternary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + + + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); +} + +#[test] +fn factored_reward_function_two_nodes() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + + rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); + + let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; + let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; + let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; + + + let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; + let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; + let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; + + assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + + + assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); + + + assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + + + assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); +}