Added FiniteHorizon

pull/87/head
AlessandroBregoli 2 years ago
parent bb239aaa0c
commit 687f19ff1f
  1. 73
      reCTBN/src/reward/reward_evaluation.rs
  2. 10
      reCTBN/tests/reward_evaluation.rs

@ -9,30 +9,35 @@ use crate::{
sampling::{ForwardSampler, Sampler}, sampling::{ForwardSampler, Sampler},
}; };
pub struct MonteCarloDiscountedRward { pub enum RewardCriteria {
FiniteHorizon,
InfiniteHorizon {discount_factor: f64},
}
pub struct MonteCarloRward {
n_iterations: usize, n_iterations: usize,
end_time: f64, end_time: f64,
discount_factor: f64, reward_criteria: RewardCriteria,
seed: Option<u64>, seed: Option<u64>,
} }
impl MonteCarloDiscountedRward { impl MonteCarloRward {
pub fn new( pub fn new(
n_iterations: usize, n_iterations: usize,
end_time: f64, end_time: f64,
discount_factor: f64, reward_criteria: RewardCriteria,
seed: Option<u64>, seed: Option<u64>,
) -> MonteCarloDiscountedRward { ) -> MonteCarloRward {
MonteCarloDiscountedRward { MonteCarloRward {
n_iterations, n_iterations,
end_time, end_time,
discount_factor, reward_criteria,
seed, seed,
} }
} }
} }
impl RewardEvaluation for MonteCarloDiscountedRward { impl RewardEvaluation for MonteCarloRward {
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
&self, &self,
network_process: &N, network_process: &N,
@ -41,24 +46,32 @@ impl RewardEvaluation for MonteCarloDiscountedRward {
let variables_domain: Vec<Vec<params::StateType>> = network_process let variables_domain: Vec<Vec<params::StateType>> = network_process
.get_node_indices() .get_node_indices()
.map(|x| match network_process.get_node(x) { .map(|x| match network_process.get_node(x) {
params::Params::DiscreteStatesContinousTime(x) => params::Params::DiscreteStatesContinousTime(x) => (0..x
(0..x.get_reserved_space_as_parent()).map(|s| params::StateType::Discrete(s)).collect() .get_reserved_space_as_parent())
}).collect(); .map(|s| params::StateType::Discrete(s))
.collect(),
})
.collect();
let n_states:usize = variables_domain.iter().map(|x| x.len()).product(); let n_states: usize = variables_domain.iter().map(|x| x.len()).product();
(0..n_states).map(|s| { (0..n_states)
let state: process::NetworkProcessState = variables_domain.iter().fold((s, vec![]), |acc, x| { .map(|s| {
let state: process::NetworkProcessState = variables_domain
.iter()
.fold((s, vec![]), |acc, x| {
let mut acc = acc; let mut acc = acc;
let idx_s = acc.0%x.len(); let idx_s = acc.0 % x.len();
acc.1.push(x[idx_s].clone()); acc.1.push(x[idx_s].clone());
acc.0 = acc.0 / x.len(); acc.0 = acc.0 / x.len();
acc acc
}).1; })
.1;
let r = self.evaluate_state(network_process, reward_function, &state); let r = self.evaluate_state(network_process, reward_function, &state);
(state, r) (state, r)
}).collect() })
.collect()
} }
fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>( fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
@ -78,16 +91,30 @@ impl RewardEvaluation for MonteCarloDiscountedRward {
let current = sampler.next().unwrap(); let current = sampler.next().unwrap();
if current.t > self.end_time { if current.t > self.end_time {
let r = reward_function.call(&previous.state, None); let r = reward_function.call(&previous.state, None);
let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) let discount = match self.reward_criteria {
- std::f64::consts::E.powf(-self.discount_factor * self.end_time); RewardCriteria::FiniteHorizon => self.end_time - previous.t,
RewardCriteria::InfiniteHorizon {discount_factor} => {
std::f64::consts::E.powf(-discount_factor * previous.t)
- std::f64::consts::E.powf(-discount_factor * self.end_time)
}
};
ret += discount * r.instantaneous_reward; ret += discount * r.instantaneous_reward;
} else { } else {
let r = reward_function.call(&previous.state, Some(&current.state)); let r = reward_function.call(&previous.state, Some(&current.state));
let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) let discount = match self.reward_criteria {
- std::f64::consts::E.powf(-self.discount_factor * current.t); RewardCriteria::FiniteHorizon => current.t-previous.t,
RewardCriteria::InfiniteHorizon {discount_factor} => {
std::f64::consts::E.powf(-discount_factor * previous.t)
- std::f64::consts::E.powf(-discount_factor * current.t)
}
};
ret += discount * r.instantaneous_reward; ret += discount * r.instantaneous_reward;
ret += std::f64::consts::E.powf(-self.discount_factor * current.t) ret += match self.reward_criteria {
* r.transition_reward; RewardCriteria::FiniteHorizon => 1.0,
RewardCriteria::InfiniteHorizon {discount_factor} => {
std::f64::consts::E.powf(-discount_factor * current.t)
}
} * r.transition_reward;
} }
previous = current; previous = current;
} }

@ -33,7 +33,7 @@ fn simple_factored_reward_function_binary_node_MC() {
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215)); let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
@ -41,6 +41,12 @@ fn simple_factored_reward_function_binary_node_MC() {
assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);
let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::FiniteHorizon, Some(215));
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
} }
#[test] #[test]
@ -107,7 +113,7 @@ fn simple_factored_reward_function_chain_MC() {
params::StateType::Discrete(0), params::StateType::Discrete(0),
]; ];
let mc = MonteCarloDiscountedRward::new(1000, 10.0, 1.0, Some(215)); let mc = MonteCarloRward::new(1000, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
let rst = mc.evaluate_state_space(&net, &rf); let rst = mc.evaluate_state_space(&net, &rf);

Loading…
Cancel
Save