commit
3f80f07e9f
@ -0,0 +1,142 @@ |
||||
//! Module for dealing with reward functions
|
||||
|
||||
use crate::{ |
||||
params::{self, ParamsTrait}, |
||||
process, |
||||
}; |
||||
use ndarray; |
||||
|
||||
/// Instantiation of reward function and instantaneous reward
|
||||
///
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transition_reward`: reward obtained transitioning from one state to another
|
||||
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
|
||||
|
||||
#[derive(Debug, PartialEq)] |
||||
pub struct Reward { |
||||
pub transition_reward: f64, |
||||
pub instantaneous_reward: f64, |
||||
} |
||||
|
||||
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
|
||||
|
||||
pub trait RewardFunction { |
||||
/// Given the current state and the previous state, it compute the reward.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
|
||||
/// * `previous_state`: an optional argument representing the previous state of the network
|
||||
|
||||
fn call( |
||||
&self, |
||||
current_state: process::NetworkProcessState, |
||||
previous_state: Option<process::NetworkProcessState>, |
||||
) -> Reward; |
||||
|
||||
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p`: any structure that implements the trait `process::NetworkProcess`
|
||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; |
||||
} |
||||
|
||||
/// Reward function over a factored state space
|
||||
///
|
||||
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
|
||||
/// of the underling `NetworkProcess`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
|
||||
/// reward of a node
|
||||
|
||||
pub struct FactoredRewardFunction { |
||||
transition_reward: Vec<ndarray::Array2<f64>>, |
||||
instantaneous_reward: Vec<ndarray::Array1<f64>>, |
||||
} |
||||
|
||||
impl FactoredRewardFunction { |
||||
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { |
||||
&self.transition_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { |
||||
&mut self.transition_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { |
||||
&self.instantaneous_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { |
||||
&mut self.instantaneous_reward[node_idx] |
||||
} |
||||
} |
||||
|
||||
impl RewardFunction for FactoredRewardFunction { |
||||
fn call( |
||||
&self, |
||||
current_state: process::NetworkProcessState, |
||||
previous_state: Option<process::NetworkProcessState>, |
||||
) -> Reward { |
||||
let instantaneous_reward: f64 = current_state |
||||
.iter() |
||||
.enumerate() |
||||
.map(|(idx, x)| { |
||||
let x = match x { |
||||
params::StateType::Discrete(x) => x, |
||||
}; |
||||
self.instantaneous_reward[idx][*x] |
||||
}) |
||||
.sum(); |
||||
if let Some(previous_state) = previous_state { |
||||
let transition_reward = previous_state |
||||
.iter() |
||||
.zip(current_state.iter()) |
||||
.enumerate() |
||||
.find_map(|(idx, (p, c))| -> Option<f64> { |
||||
let p = match p { |
||||
params::StateType::Discrete(p) => p, |
||||
}; |
||||
let c = match c { |
||||
params::StateType::Discrete(c) => c, |
||||
}; |
||||
if p != c { |
||||
Some(self.transition_reward[idx][[*p, *c]]) |
||||
} else { |
||||
None |
||||
} |
||||
}) |
||||
.unwrap_or(0.0); |
||||
Reward { |
||||
transition_reward, |
||||
instantaneous_reward, |
||||
} |
||||
} else { |
||||
Reward { |
||||
transition_reward: 0.0, |
||||
instantaneous_reward, |
||||
} |
||||
} |
||||
} |
||||
|
||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { |
||||
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; |
||||
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; |
||||
for i in p.get_node_indices() { |
||||
//This works only for discrete nodes!
|
||||
let size: usize = p.get_node(i).get_reserved_space_as_parent(); |
||||
instantaneous_reward.push(ndarray::Array1::zeros(size)); |
||||
transition_reward.push(ndarray::Array2::zeros((size, size))); |
||||
} |
||||
|
||||
FactoredRewardFunction { |
||||
transition_reward, |
||||
instantaneous_reward, |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,118 @@ |
||||
mod utils; |
||||
|
||||
use ndarray::*; |
||||
use utils::generate_discrete_time_continous_node; |
||||
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; |
||||
|
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_binary_node() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||
|
||||
assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_ternary_node() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; |
||||
|
||||
|
||||
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||
assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); |
||||
assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); |
||||
} |
||||
|
||||
#[test] |
||||
fn factored_reward_function_two_nodes() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||
|
||||
|
||||
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); |
||||
|
||||
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; |
||||
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; |
||||
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; |
||||
|
||||
|
||||
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; |
||||
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; |
||||
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; |
||||
|
||||
assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||
assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); |
||||
assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); |
||||
assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); |
||||
assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||
assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); |
||||
assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); |
||||
assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); |
||||
assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); |
||||
} |
Loading…
Reference in new issue