From 9284ca5dd2facaa30a7439ca9e9f00d2778479a4 Mon Sep 17 00:00:00 2001
From: AlessandroBregoli
Date: Mon, 5 Dec 2022 15:24:32 +0100
Subject: [PATCH] Implemented NeighborhoodRelativeReward

---
 reCTBN/src/reward/reward_evaluation.rs | 54 +++++++++++++++++++++++---
 reCTBN/tests/reward_evaluation.rs      |  6 +--
 2 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs
index 9dcb6d2..cb7b8f1 100644
--- a/reCTBN/src/reward/reward_evaluation.rs
+++ b/reCTBN/src/reward/reward_evaluation.rs
@@ -14,21 +14,21 @@ pub enum RewardCriteria {
     InfiniteHorizon {discount_factor: f64},
 }
 
-pub struct MonteCarloRward {
+pub struct MonteCarloReward {
     n_iterations: usize,
     end_time: f64,
     reward_criteria: RewardCriteria,
     seed: Option<u64>,
 }
 
-impl MonteCarloRward {
+impl MonteCarloReward {
     pub fn new(
         n_iterations: usize,
         end_time: f64,
         reward_criteria: RewardCriteria,
         seed: Option<u64>,
-    ) -> MonteCarloRward {
-        MonteCarloRward {
+    ) -> MonteCarloReward {
+        MonteCarloReward {
             n_iterations,
             end_time,
             reward_criteria,
@@ -37,7 +37,7 @@ impl MonteCarloRward {
     }
 }
 
-impl RewardEvaluation for MonteCarloRward {
+impl RewardEvaluation for MonteCarloReward {
     fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
     ) -> HashMap<process::NetworkProcessState, f64> {
@@ -123,3 +123,47 @@ impl RewardEvaluation for MonteCarloRward {
         ret / self.n_iterations as f64
     }
 }
+
+pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> {
+    inner_reward: RE
+}
+
+impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> {
+    pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> {
+        NeighborhoodRelativeReward { inner_reward }
+    }
+}
+
+impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> {
+    fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
+        &self,
+        network_process: &N,
+        reward_function: &R,
+    ) -> HashMap<process::NetworkProcessState, f64> {
+
+        let absolute_reward = self.inner_reward.evaluate_state_space(network_process, reward_function);
+
+        //This approach optimizes memory. Maybe optimizing execution time would be better.
+        absolute_reward.iter().map(|(k1, v1)| {
+            let mut max_val:f64 = 1.0;
+            absolute_reward.iter().for_each(|(k2,v2)| {
+                let count_diff:usize = k1.iter().zip(k2.iter()).map(|(s1, s2)| if s1 == s2 {0} else {1}).sum();
+                if count_diff < 2 {
+                    max_val = max_val.max(v1/v2);
+                }
+
+            });
+            (k1.clone(), max_val)
+        }).collect()
+
+    }
+
+    fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>(
+        &self,
+        _network_process: &N,
+        _reward_function: &R,
+        _state: &process::NetworkProcessState,
+    ) -> f64 {
+        unimplemented!();
+    }
+}
diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs
index b2cfd29..63e9c98 100644
--- a/reCTBN/tests/reward_evaluation.rs
+++ b/reCTBN/tests/reward_evaluation.rs
@@ -33,7 +33,7 @@ fn simple_factored_reward_function_binary_node_MC() {
     let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
     let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
 
-    let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
+    let mc = MonteCarloReward::new(100, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
     assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
     assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
 
     let rst = mc.evaluate_state_space(&net, &rf);
     assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
     assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);
 
 
-    let mc = MonteCarloRward::new(100, 10.0, RewardCriteria::FiniteHorizon, Some(215));
+    let mc = MonteCarloReward::new(100, 10.0, RewardCriteria::FiniteHorizon, Some(215));
     assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
     assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
 
@@ -113,7 +113,7 @@ fn simple_factored_reward_function_chain_MC() {
         params::StateType::Discrete(0),
     ];
 
-    let mc = MonteCarloRward::new(1000, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
+    let mc = MonteCarloReward::new(1000, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
     assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
 
     let rst = mc.evaluate_state_space(&net, &rf);
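
Usage sketch (not applied by this patch): how the new NeighborhoodRelativeReward wrapper composes with MonteCarloReward. The module paths, trait locations, and the helper function below are assumptions inferred from the file layout in this patch, not the crate's documented API; `net` and `rf` are placeholders for a network process and reward function built as in reCTBN/tests/reward_evaluation.rs.

// Assumed module paths; adjust to the actual crate layout if they differ.
use std::collections::HashMap;

use reCTBN::process::{self, NetworkProcessState};
use reCTBN::reward::reward_evaluation::{MonteCarloReward, NeighborhoodRelativeReward, RewardCriteria};
use reCTBN::reward::{RewardEvaluation, RewardFunction};

// `net` and `rf` are hypothetical placeholders for a network process and a
// reward function constructed as in the existing tests.
fn relative_reward_sketch<N: process::NetworkProcess, R: RewardFunction>(
    net: &N,
    rf: &R,
) -> HashMap<NetworkProcessState, f64> {
    // Inner evaluator: Monte Carlo estimate of each state's absolute reward
    // (100 iterations, end time 10.0, InfiniteHorizon with discount_factor 1.0, fixed seed).
    let mc = MonteCarloReward::new(
        100,
        10.0,
        RewardCriteria::InfiniteHorizon { discount_factor: 1.0 },
        Some(215),
    );

    // Wrapper added by this patch: for each state it returns the largest ratio
    // between its absolute reward and the reward of any state differing from it
    // in at most one variable (the `count_diff < 2` check above).
    let nrr = NeighborhoodRelativeReward::new(mc);
    nrr.evaluate_state_space(net, rf)
}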