76 feature reward evaluation #87
Merged
AlessandroBregoli merged 11 commits from 76-feature-reward-evaluation into dev 2 years ago
@@ -0,0 +1,59 @@

pub mod reward_evaluation;
pub mod reward_function;

use std::collections::HashMap;

use crate::process;

/// Output of a reward function: the reward obtained from a transition together
/// with the instantaneous reward.
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)]
pub struct Reward {
    pub transition_reward: f64,
    pub instantaneous_reward: f64,
}

/// The trait `RewardFunction` describes the methods that all reward functions must satisfy.
pub trait RewardFunction: Sync {
    /// Given the current state and the previous state, it computes the reward.
    ///
    /// # Arguments
    ///
    /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
    /// * `previous_state`: an optional argument representing the previous state of the network
    fn call(
        &self,
        current_state: &process::NetworkProcessState,
        previous_state: Option<&process::NetworkProcessState>,
    ) -> Reward;

    /// Initialize the `RewardFunction` internals according to the structure of a `NetworkProcess`.
    ///
    /// # Arguments
    ///
    /// * `p`: any structure that implements the trait `process::NetworkProcess`
    fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}

pub trait RewardEvaluation {
    fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64>;

    fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
        state: &process::NetworkProcessState,
    ) -> f64;
}
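For orientation only (this sketch is not part of the diff): a minimal type implementing `RewardFunction` could look as follows. `ConstantReward` and its fields are hypothetical names, and the sketch assumes it lives inside the crate so that the `crate::process` and `crate::reward` paths used in the module above resolve.

```rust
use crate::process;
use crate::reward::{Reward, RewardFunction};

/// Hypothetical reward function: a fixed reward per unit of time in every state
/// and a fixed reward on every transition (illustrative only).
pub struct ConstantReward {
    per_time: f64,
    per_transition: f64,
}

impl RewardFunction for ConstantReward {
    fn call(
        &self,
        _current_state: &process::NetworkProcessState,
        previous_state: Option<&process::NetworkProcessState>,
    ) -> Reward {
        Reward {
            // Count the transition reward only when a previous state is supplied.
            transition_reward: if previous_state.is_some() {
                self.per_transition
            } else {
                0.0
            },
            instantaneous_reward: self.per_time,
        }
    }

    fn initialize_from_network_process<T: process::NetworkProcess>(_p: &T) -> Self {
        // A constant reward does not depend on the network structure.
        ConstantReward {
            per_time: 1.0,
            per_transition: 0.0,
        }
    }
}
```

A realistic implementation would derive both components from the states passed to `call`, as the `FactoredRewardFunction` used in the tests further down does.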
@@ -0,0 +1,205 @@

use std::collections::HashMap;

use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use statrs::distribution::ContinuousCDF;

use crate::params::{self, ParamsTrait};
use crate::process;

use crate::{
    process::NetworkProcessState,
    reward::RewardEvaluation,
    sampling::{ForwardSampler, Sampler},
};

pub enum RewardCriteria {
    FiniteHorizon,
    InfiniteHorizon { discount_factor: f64 },
}

pub struct MonteCarloReward {
    max_iterations: usize,
    max_err_stop: f64,
    alpha_stop: f64,
    end_time: f64,
    reward_criteria: RewardCriteria,
    seed: Option<u64>,
}

impl MonteCarloReward {
    pub fn new(
        max_iterations: usize,
        max_err_stop: f64,
        alpha_stop: f64,
        end_time: f64,
        reward_criteria: RewardCriteria,
        seed: Option<u64>,
    ) -> MonteCarloReward {
        MonteCarloReward {
            max_iterations,
            max_err_stop,
            alpha_stop,
            end_time,
            reward_criteria,
            seed,
        }
    }
}

impl RewardEvaluation for MonteCarloReward {
    fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64> {
        // Domain of every variable in the network.
        let variables_domain: Vec<Vec<params::StateType>> = network_process
            .get_node_indices()
            .map(|x| match network_process.get_node(x) {
                params::Params::DiscreteStatesContinousTime(x) => (0..x
                    .get_reserved_space_as_parent())
                    .map(|s| params::StateType::Discrete(s))
                    .collect(),
            })
            .collect();

        let n_states: usize = variables_domain.iter().map(|x| x.len()).product();

        (0..n_states)
            .into_par_iter()
            .map(|s| {
                // Decode the flat index `s` into a full network state.
                let state: process::NetworkProcessState = variables_domain
                    .iter()
                    .fold((s, vec![]), |acc, x| {
                        let mut acc = acc;
                        let idx_s = acc.0 % x.len();
                        acc.1.push(x[idx_s].clone());
                        acc.0 = acc.0 / x.len();
                        acc
                    })
                    .1;

                let r = self.evaluate_state(network_process, reward_function, &state);
                (state, r)
            })
            .collect()
    }

    fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
        state: &NetworkProcessState,
    ) -> f64 {
        let mut sampler =
            ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
        let mut expected_value = 0.0;
        let mut squared_expected_value = 0.0;
        let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap();

        for i in 0..self.max_iterations {
            sampler.reset();
            let mut ret = 0.0;
            let mut previous = sampler.next().unwrap();
            while previous.t < self.end_time {
                let current = sampler.next().unwrap();
                if current.t > self.end_time {
                    // The sampled transition falls beyond the horizon: only the
                    // instantaneous reward up to `end_time` contributes.
                    let r = reward_function.call(&previous.state, None);
                    let discount = match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => self.end_time - previous.t,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            std::f64::consts::E.powf(-discount_factor * previous.t)
                                - std::f64::consts::E.powf(-discount_factor * self.end_time)
                        }
                    };
                    ret += discount * r.instantaneous_reward;
                } else {
                    let r = reward_function.call(&previous.state, Some(&current.state));
                    let discount = match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => current.t - previous.t,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            std::f64::consts::E.powf(-discount_factor * previous.t)
                                - std::f64::consts::E.powf(-discount_factor * current.t)
                        }
                    };
                    ret += discount * r.instantaneous_reward;
                    ret += match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => 1.0,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            std::f64::consts::E.powf(-discount_factor * current.t)
                        }
                    } * r.transition_reward;
                }
                previous = current;
            }

            // Online update of the running mean and mean of squares.
            let float_i = i as f64;
            expected_value =
                expected_value * float_i / (float_i + 1.0) + ret / (float_i + 1.0);
            squared_expected_value = squared_expected_value * float_i / (float_i + 1.0)
                + ret.powi(2) / (float_i + 1.0);

            if i > 2 {
                // Early stop once the Monte Carlo error is small enough at the
                // requested confidence level (see the note below this impl).
                let var =
                    (float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2));
                if self.alpha_stop
                    - 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt())
                    > 0.0
                {
                    return expected_value;
                }
            }
        }

        expected_value
    }
}
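A reading note, not part of the diff: with $n = i + 1$ completed trajectories, sample standard deviation $\hat{\sigma}_n = \sqrt{\texttt{var}}$, $\varepsilon =$ `max_err_stop` and $\alpha =$ `alpha_stop`, the early-stop test above can be read, under a normal approximation of the running mean, as stopping once

$$
2\,\Phi\!\left(-\frac{\sqrt{n}\,\varepsilon}{\hat{\sigma}_n}\right) < \alpha,
$$

i.e. once the estimated probability that the running mean deviates from the true expected reward by more than $\varepsilon$ drops below $\alpha$.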
pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> {
    inner_reward: RE,
}

impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> {
    pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> {
        NeighborhoodRelativeReward { inner_reward }
    }
}

impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> {
    fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64> {
        let absolute_reward = self
            .inner_reward
            .evaluate_state_space(network_process, reward_function);

        // This approach optimizes memory usage. Optimizing execution time instead might be better.
        absolute_reward
            .iter()
            .map(|(k1, v1)| {
                let mut max_val: f64 = 1.0;
                absolute_reward.iter().for_each(|(k2, v2)| {
                    // Count in how many variables the two states differ.
                    let count_diff: usize = k1
                        .iter()
                        .zip(k2.iter())
                        .map(|(s1, s2)| if s1 == s2 { 0 } else { 1 })
                        .sum();
                    if count_diff < 2 {
                        max_val = max_val.max(v1 / v2);
                    }
                });
                (k1.clone(), max_val)
            })
            .collect()
    }

    fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        _network_process: &N,
        _reward_function: &R,
        _state: &process::NetworkProcessState,
    ) -> f64 {
        unimplemented!();
    }
}
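A usage sketch, not part of the diff: combining the two evaluators might look like the function below. The import paths mirror the test file that follows; the helper function and its name are hypothetical.

```rust
use std::collections::HashMap;

use reCTBN::process::{NetworkProcess, NetworkProcessState};
use reCTBN::reward::reward_evaluation::{
    MonteCarloReward, NeighborhoodRelativeReward, RewardCriteria,
};
use reCTBN::reward::{RewardEvaluation, RewardFunction};

/// For every state, report the largest ratio between its Monte Carlo value and
/// that of any state differing from it in at most one variable (never below 1.0).
fn neighborhood_relative_values<N: NetworkProcess, R: RewardFunction>(
    net: &N,
    rf: &R,
) -> HashMap<NetworkProcessState, f64> {
    let mc = MonteCarloReward::new(
        10_000, // max_iterations
        1e-1,   // max_err_stop
        1e-1,   // alpha_stop
        10.0,   // end_time
        RewardCriteria::InfiniteHorizon { discount_factor: 1.0 },
        Some(215), // seed
    );
    NeighborhoodRelativeReward::new(mc).evaluate_state_space(net, rf)
}
```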
@@ -0,0 +1,122 @@

mod utils;

use approx::assert_abs_diff_eq;
use ndarray::*;
use reCTBN::{
    params,
    process::{ctbn::*, NetworkProcess, NetworkProcessState},
    reward::{reward_evaluation::*, reward_function::*, *},
};
use utils::generate_discrete_time_continous_node;

#[test]
fn simple_factored_reward_function_binary_node_mc() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();

    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]]));
    rf.get_instantaneous_reward_mut(n1)
        .assign(&arr1(&[3.0, 3.0]));

    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
        }
    }

    net.initialize_adj_matrix();

    let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
    let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];

    let mc = MonteCarloReward::new(
        10000,
        1e-1,
        1e-1,
        10.0,
        RewardCriteria::InfiniteHorizon { discount_factor: 1.0 },
        Some(215),
    );
    assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
    assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);

    let rst = mc.evaluate_state_space(&net, &rf);
    assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
    assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);

    let mc = MonteCarloReward::new(
        10000,
        1e-1,
        1e-1,
        10.0,
        RewardCriteria::FiniteHorizon,
        Some(215),
    );
    assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
    assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
}
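A sanity check on the asserted values (a note, not part of the PR): the transition reward is zero and the instantaneous reward is 3 in both states, so the return does not depend on the sampled trajectory. Over the horizon `end_time = 10.0`, the infinite-horizon (discount factor 1.0) and finite-horizon accumulations are

$$
\int_0^{10} 3\,e^{-t}\,dt = 3\left(1 - e^{-10}\right) \approx 3.0,
\qquad
\int_0^{10} 3\,dt = 30.0,
$$

matching the expected values in the assertions above.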
#[test]
fn simple_factored_reward_function_chain_mc() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();

    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();

    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
        .unwrap();

    net.add_edge(n1, n2);
    net.add_edge(n2, n3);

    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap();
        }
    }

    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param
                .set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]],
                ]))
                .unwrap();
        }
    }

    match &mut net.get_node_mut(n3) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param
                .set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]],
                ]))
                .unwrap();
        }
    }

    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));

    rf.get_transition_reward_mut(n2)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));

    rf.get_transition_reward_mut(n3)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));

    let s000: NetworkProcessState = vec![
        params::StateType::Discrete(1),
        params::StateType::Discrete(0),
        params::StateType::Discrete(0),
    ];

    let mc = MonteCarloReward::new(
        10000,
        1e-1,
        1e-1,
        10.0,
        RewardCriteria::InfiniteHorizon { discount_factor: 1.0 },
        Some(215),
    );
    assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);

    let rst = mc.evaluate_state_space(&net, &rf);
    assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1);
}