diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs
index 67baa5e..9dcb6d2 100644
--- a/reCTBN/src/reward/reward_evaluation.rs
+++ b/reCTBN/src/reward/reward_evaluation.rs
@@ -9,30 +9,35 @@ use crate::{
     sampling::{ForwardSampler, Sampler},
 };
 
-pub struct MonteCarloDiscountedRward {
+pub enum RewardCriteria {
+    FiniteHorizon,
+    InfiniteHorizon {discount_factor: f64},
+}
+
+pub struct MonteCarloRward {
     n_iterations: usize,
     end_time: f64,
-    discount_factor: f64,
+    reward_criteria: RewardCriteria,
     seed: Option<u64>,
 }
 
-impl MonteCarloDiscountedRward {
+impl MonteCarloRward {
     pub fn new(
         n_iterations: usize,
         end_time: f64,
-        discount_factor: f64,
+        reward_criteria: RewardCriteria,
         seed: Option<u64>,
-    ) -> MonteCarloDiscountedRward {
-        MonteCarloDiscountedRward {
+    ) -> MonteCarloRward {
+        MonteCarloRward {
             n_iterations,
             end_time,
-            discount_factor,
+            reward_criteria,
             seed,
         }
     }
 }
 
-impl RewardEvaluation for MonteCarloDiscountedRward {
+impl RewardEvaluation for MonteCarloRward {
     fn evaluate_state_space(
         &self,
         network_process: &N,
@@ -41,24 +46,32 @@ impl RewardEvaluation for MonteCarloDiscountedRward {
         let variables_domain: Vec<Vec<params::StateType>> = network_process
             .get_node_indices()
             .map(|x| match network_process.get_node(x) {
-                params::Params::DiscreteStatesContinousTime(x) =>
-                    (0..x.get_reserved_space_as_parent()).map(|s| params::StateType::Discrete(s)).collect()
-            }).collect();
+                params::Params::DiscreteStatesContinousTime(x) => (0..x
+                    .get_reserved_space_as_parent())
+                    .map(|s| params::StateType::Discrete(s))
+                    .collect(),
+            })
+            .collect();
+
+        let n_states: usize = variables_domain.iter().map(|x| x.len()).product();
 
-        let n_states:usize = variables_domain.iter().map(|x| x.len()).product();
-
-        (0..n_states).map(|s| {
-            let state: process::NetworkProcessState = variables_domain.iter().fold((s, vec![]), |acc, x| {
-                let mut acc = acc;
-                let idx_s = acc.0%x.len();
-                acc.1.push(x[idx_s].clone());
-                acc.0 = acc.0 / x.len();
-                acc
-            }).1;
+        (0..n_states)
+            .map(|s| {
+                let state: process::NetworkProcessState = variables_domain
+                    .iter()
+                    .fold((s, vec![]), |acc, x| {
+                        let mut acc = acc;
+                        let idx_s = acc.0 % x.len();
+                        acc.1.push(x[idx_s].clone());
+                        acc.0 = acc.0 / x.len();
+                        acc
+                    })
+                    .1;
 
-            let r = self.evaluate_state(network_process, reward_function, &state);
-            (state, r)
-        }).collect()
+                let r = self.evaluate_state(network_process, reward_function, &state);
+                (state, r)
+            })
+            .collect()
     }
 
     fn evaluate_state(
         &self,
         network_process: &N,
@@ -78,16 +91,30 @@ impl RewardEvaluation for MonteCarloDiscountedRward {
                 let current = sampler.next().unwrap();
                 if current.t > self.end_time {
                     let r = reward_function.call(&previous.state, None);
-                    let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t)
-                        - std::f64::consts::E.powf(-self.discount_factor * self.end_time);
+                    let discount = match self.reward_criteria {
+                        RewardCriteria::FiniteHorizon => self.end_time - previous.t,
+                        RewardCriteria::InfiniteHorizon {discount_factor} => {
+                            std::f64::consts::E.powf(-discount_factor * previous.t)
+                                - std::f64::consts::E.powf(-discount_factor * self.end_time)
+                        }
+                    };
                     ret += discount * r.instantaneous_reward;
                 } else {
                     let r = reward_function.call(&previous.state, Some(&current.state));
-                    let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t)
-                        - std::f64::consts::E.powf(-self.discount_factor * current.t);
+                    let discount = match self.reward_criteria {
+                        RewardCriteria::FiniteHorizon => current.t-previous.t,
+                        RewardCriteria::InfiniteHorizon {discount_factor} => {
+                            std::f64::consts::E.powf(-discount_factor * previous.t)
+                                - std::f64::consts::E.powf(-discount_factor * current.t)
+                        }
+                    };
                     ret += discount * r.instantaneous_reward;
-                    ret += std::f64::consts::E.powf(-self.discount_factor * current.t)
-                        * r.transition_reward;
+                    ret += match self.reward_criteria {
+                        RewardCriteria::FiniteHorizon => 1.0,
+                        RewardCriteria::InfiniteHorizon {discount_factor} => {
+                            std::f64::consts::E.powf(-discount_factor * current.t)
+                        }
+                    } * r.transition_reward;
                 }
                 previous = current;
             }
diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs
index 1650507..b2cfd29 100644
--- a/reCTBN/tests/reward_evaluation.rs
+++ b/reCTBN/tests/reward_evaluation.rs
@@ -33,7 +33,7 @@ fn simple_factored_reward_function_binary_node_MC() {
     let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
     let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
 
-    let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215));
+    let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
     assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
     assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
 
@@ -41,6 +41,12 @@ fn simple_factored_reward_function_binary_node_MC() {
 
     assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
     assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);
+
+    let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::FiniteHorizon, Some(215));
+    assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
+    assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
+
+
 }
 
 #[test]
@@ -107,7 +113,7 @@ fn simple_factored_reward_function_chain_MC() {
         params::StateType::Discrete(0),
     ];
 
-    let mc = MonteCarloDiscountedRward::new(1000, 10.0, 1.0, Some(215));
+    let mc = MonteCarloRward::new(1000, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
     assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
 
     let rst = mc.evaluate_state_space(&net, &rf);
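
Note, not part of the patch: a minimal standalone sketch of the weighting scheme the two RewardCriteria variants introduce, restating the match expressions from evaluate_state above. FiniteHorizon integrates the reward rate undiscounted and counts each transition reward with weight 1.0; InfiniteHorizon applies exponential discounting exp(-discount_factor * t). The helper names sojourn_weight and transition_weight are hypothetical and exist only for this illustration; the patch computes the same expressions inline.

// Illustration only; the enum mirrors the one added by the patch.
pub enum RewardCriteria {
    FiniteHorizon,
    InfiniteHorizon { discount_factor: f64 },
}

/// Weight applied to the instantaneous reward accumulated while the process
/// stays in one state over the interval [t_start, t_end].
fn sojourn_weight(criteria: &RewardCriteria, t_start: f64, t_end: f64) -> f64 {
    match criteria {
        // Undiscounted: the reward rate is weighted by the plain interval length.
        &RewardCriteria::FiniteHorizon => t_end - t_start,
        // Exponentially discounted, as in the InfiniteHorizon arms of the patch.
        &RewardCriteria::InfiniteHorizon { discount_factor } => {
            std::f64::consts::E.powf(-discount_factor * t_start)
                - std::f64::consts::E.powf(-discount_factor * t_end)
        }
    }
}

/// Weight applied to the transition reward collected at the jump time t.
fn transition_weight(criteria: &RewardCriteria, t: f64) -> f64 {
    match criteria {
        &RewardCriteria::FiniteHorizon => 1.0,
        &RewardCriteria::InfiniteHorizon { discount_factor } => {
            std::f64::consts::E.powf(-discount_factor * t)
        }
    }
}

fn main() {
    let finite = RewardCriteria::FiniteHorizon;
    let infinite = RewardCriteria::InfiniteHorizon { discount_factor: 1.0 };
    // A sojourn from t = 0.5 to t = 1.5 followed by a transition at t = 1.5.
    println!("{} {}", sojourn_weight(&finite, 0.5, 1.5), sojourn_weight(&infinite, 0.5, 1.5));
    println!("{} {}", transition_weight(&finite, 1.5), transition_weight(&infinite, 1.5));
}

This is consistent with the new test expectations above: the same setup that evaluates to roughly 3.0 under RewardCriteria::InfiniteHorizon { discount_factor: 1.0 } accumulates to roughly 30.0 under RewardCriteria::FiniteHorizon with end_time = 10.0.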