commit
3f80f07e9f
@ -0,0 +1,142 @@ |
|||||||
|
//! Module for dealing with reward functions
|
||||||
|
|
||||||
|
use crate::{ |
||||||
|
params::{self, ParamsTrait}, |
||||||
|
process, |
||||||
|
}; |
||||||
|
use ndarray; |
||||||
|
|
||||||
|
/// Instantiation of reward function and instantaneous reward
|
||||||
|
///
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `transition_reward`: reward obtained transitioning from one state to another
|
||||||
|
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)] |
||||||
|
pub struct Reward { |
||||||
|
pub transition_reward: f64, |
||||||
|
pub instantaneous_reward: f64, |
||||||
|
} |
||||||
|
|
||||||
|
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
|
||||||
|
|
||||||
|
pub trait RewardFunction { |
||||||
|
/// Given the current state and the previous state, it compute the reward.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
|
||||||
|
/// * `previous_state`: an optional argument representing the previous state of the network
|
||||||
|
|
||||||
|
fn call( |
||||||
|
&self, |
||||||
|
current_state: process::NetworkProcessState, |
||||||
|
previous_state: Option<process::NetworkProcessState>, |
||||||
|
) -> Reward; |
||||||
|
|
||||||
|
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `p`: any structure that implements the trait `process::NetworkProcess`
|
||||||
|
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; |
||||||
|
} |
||||||
|
|
||||||
|
/// Reward function over a factored state space
|
||||||
|
///
|
||||||
|
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
|
||||||
|
/// of the underling `NetworkProcess`
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
|
||||||
|
/// reward of a node
|
||||||
|
|
||||||
|
pub struct FactoredRewardFunction { |
||||||
|
transition_reward: Vec<ndarray::Array2<f64>>, |
||||||
|
instantaneous_reward: Vec<ndarray::Array1<f64>>, |
||||||
|
} |
||||||
|
|
||||||
|
impl FactoredRewardFunction { |
||||||
|
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { |
||||||
|
&self.transition_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { |
||||||
|
&mut self.transition_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { |
||||||
|
&self.instantaneous_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { |
||||||
|
&mut self.instantaneous_reward[node_idx] |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl RewardFunction for FactoredRewardFunction { |
||||||
|
fn call( |
||||||
|
&self, |
||||||
|
current_state: process::NetworkProcessState, |
||||||
|
previous_state: Option<process::NetworkProcessState>, |
||||||
|
) -> Reward { |
||||||
|
let instantaneous_reward: f64 = current_state |
||||||
|
.iter() |
||||||
|
.enumerate() |
||||||
|
.map(|(idx, x)| { |
||||||
|
let x = match x { |
||||||
|
params::StateType::Discrete(x) => x, |
||||||
|
}; |
||||||
|
self.instantaneous_reward[idx][*x] |
||||||
|
}) |
||||||
|
.sum(); |
||||||
|
if let Some(previous_state) = previous_state { |
||||||
|
let transition_reward = previous_state |
||||||
|
.iter() |
||||||
|
.zip(current_state.iter()) |
||||||
|
.enumerate() |
||||||
|
.find_map(|(idx, (p, c))| -> Option<f64> { |
||||||
|
let p = match p { |
||||||
|
params::StateType::Discrete(p) => p, |
||||||
|
}; |
||||||
|
let c = match c { |
||||||
|
params::StateType::Discrete(c) => c, |
||||||
|
}; |
||||||
|
if p != c { |
||||||
|
Some(self.transition_reward[idx][[*p, *c]]) |
||||||
|
} else { |
||||||
|
None |
||||||
|
} |
||||||
|
}) |
||||||
|
.unwrap_or(0.0); |
||||||
|
Reward { |
||||||
|
transition_reward, |
||||||
|
instantaneous_reward, |
||||||
|
} |
||||||
|
} else { |
||||||
|
Reward { |
||||||
|
transition_reward: 0.0, |
||||||
|
instantaneous_reward, |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { |
||||||
|
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; |
||||||
|
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; |
||||||
|
for i in p.get_node_indices() { |
||||||
|
//This works only for discrete nodes!
|
||||||
|
let size: usize = p.get_node(i).get_reserved_space_as_parent(); |
||||||
|
instantaneous_reward.push(ndarray::Array1::zeros(size)); |
||||||
|
transition_reward.push(ndarray::Array2::zeros((size, size))); |
||||||
|
} |
||||||
|
|
||||||
|
FactoredRewardFunction { |
||||||
|
transition_reward, |
||||||
|
instantaneous_reward, |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,118 @@ |
|||||||
|
mod utils; |
||||||
|
|
||||||
|
use ndarray::*; |
||||||
|
use utils::generate_discrete_time_continous_node; |
||||||
|
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function_binary_node() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); |
||||||
|
|
||||||
|
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||||
|
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||||
|
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||||
|
|
||||||
|
assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function_ternary_node() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||||
|
|
||||||
|
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||||
|
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||||
|
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||||
|
assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); |
||||||
|
assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn factored_reward_function_two_nodes() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||||
|
|
||||||
|
|
||||||
|
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); |
||||||
|
|
||||||
|
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; |
||||||
|
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; |
||||||
|
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; |
||||||
|
|
||||||
|
|
||||||
|
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; |
||||||
|
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; |
||||||
|
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; |
||||||
|
|
||||||
|
assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||||
|
assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); |
||||||
|
assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||||
|
assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); |
||||||
|
assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); |
||||||
|
assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); |
||||||
|
assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||||
|
assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); |
||||||
|
assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||||
|
assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); |
||||||
|
assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); |
||||||
|
assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); |
||||||
|
assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); |
||||||
|
} |
Loading…
Reference in new issue