diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs
index 9f63860..3d08273 100644
--- a/reCTBN/src/params.rs
+++ b/reCTBN/src/params.rs
@@ -20,7 +20,7 @@ pub enum ParamsError {
 }
 
 /// Allowed type of states
-#[derive(Clone)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum StateType {
     Discrete(usize),
 }
diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs
index 1ea575c..b34db7f 100644
--- a/reCTBN/src/reward.rs
+++ b/reCTBN/src/reward.rs
@@ -1,6 +1,8 @@
 pub mod reward_function;
 pub mod reward_evaluation;
 
+use std::collections::HashMap;
+
 use crate::process;
 use ndarray;
 
@@ -43,12 +45,13 @@ pub trait RewardFunction {
 }
 
 pub trait RewardEvaluation {
-    fn call<N: process::NetworkProcess, R: RewardFunction>(
+    fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
-    ) -> ndarray::Array1<f64>;
-    fn call_state<N: process::NetworkProcess, R: RewardFunction>(
+    ) -> HashMap<process::NetworkProcessState, f64>;
+
+    fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs
index fca7c1a..67baa5e 100644
--- a/reCTBN/src/reward/reward_evaluation.rs
+++ b/reCTBN/src/reward/reward_evaluation.rs
@@ -1,7 +1,12 @@
+use std::collections::HashMap;
+
+use crate::params::{self, ParamsTrait};
+use crate::process;
+
 use crate::{
+    process::NetworkProcessState,
     reward::RewardEvaluation,
     sampling::{ForwardSampler, Sampler},
-    process::NetworkProcessState
 };
 
 pub struct MonteCarloDiscountedRward {
@@ -28,21 +33,42 @@ impl MonteCarloDiscountedRward {
 }
 
 impl RewardEvaluation for MonteCarloDiscountedRward {
-    fn call<N: crate::process::NetworkProcess, R: crate::reward::RewardFunction>(
+    fn evaluate_state_space<N: process::NetworkProcess, R: crate::reward::RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
-    ) -> ndarray::Array1<f64> {
-        todo!()
+    ) -> HashMap<NetworkProcessState, f64> {
+        let variables_domain: Vec<Vec<params::StateType>> = network_process
+            .get_node_indices()
+            .map(|x| match network_process.get_node(x) {
+                params::Params::DiscreteStatesContinousTime(x) =>
+                    (0..x.get_reserved_space_as_parent()).map(|s| params::StateType::Discrete(s)).collect()
+            }).collect();
+
+        let n_states:usize = variables_domain.iter().map(|x| x.len()).product();
+
+        (0..n_states).map(|s| {
+            let state: process::NetworkProcessState = variables_domain.iter().fold((s, vec![]), |acc, x| {
+                let mut acc = acc;
+                let idx_s = acc.0%x.len();
+                acc.1.push(x[idx_s].clone());
+                acc.0 = acc.0 / x.len();
+                acc
+            }).1;
+
+            let r = self.evaluate_state(network_process, reward_function, &state);
+            (state, r)
+        }).collect()
     }
 
-    fn call_state<N: crate::process::NetworkProcess, R: crate::reward::RewardFunction>(
+    fn evaluate_state<N: process::NetworkProcess, R: crate::reward::RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
         state: &NetworkProcessState,
     ) -> f64 {
-        let mut sampler = ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
+        let mut sampler =
+            ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
         let mut ret = 0.0;
 
         for _i in 0..self.n_iterations {
@@ -60,7 +86,8 @@ impl RewardEvaluation for MonteCarloDiscountedRward {
                 let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t)
                     - std::f64::consts::E.powf(-self.discount_factor * current.t);
                 ret += discount * r.instantaneous_reward;
-                ret += std::f64::consts::E.powf(-self.discount_factor * current.t) * r.transition_reward;
+                ret += std::f64::consts::E.powf(-self.discount_factor * current.t)
+                    * r.transition_reward;
             }
             previous = current;
         }
diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs
index f3938e7..1650507 100644
--- a/reCTBN/tests/reward_evaluation.rs
+++ b/reCTBN/tests/reward_evaluation.rs
@@ -34,8 +34,13 @@ fn simple_factored_reward_function_binary_node_MC() {
     let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
 
     let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215));
-    assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s0), epsilon = 1e-2);
-    assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s1), epsilon = 1e-2);
+    assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
+    assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
+
+    let rst = mc.evaluate_state_space(&net, &rf);
+    assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
+    assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);
+
 }
 
 #[test]
@@ -102,6 +107,10 @@ fn simple_factored_reward_function_chain_MC() {
         params::StateType::Discrete(0),
     ];
 
-    let mc = MonteCarloDiscountedRward::new(10000, 100.0, 1.0, Some(215));
-    assert_abs_diff_eq!(2.447, mc.call_state(&net, &rf, &s000), epsilon = 1e-1);
+    let mc = MonteCarloDiscountedRward::new(1000, 10.0, 1.0, Some(215));
+    assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
+
+    let rst = mc.evaluate_state_space(&net, &rf);
+    assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1);
+
 }
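
Note (not part of the patch): the new `evaluate_state_space` enumerates the joint state space by treating each flat index in `0..n_states` as a mixed-radix number whose digits are the per-variable states; the fold peels off one digit per variable with `% x.len()` and `/ x.len()`. A minimal standalone sketch of that decoding, where `decode_state` and the made-up cardinalities are illustrative only, not reCTBN API (reCTBN derives the domain sizes from each node's `get_reserved_space_as_parent()`):

// Decode flat index `s` into one state per variable, least significant
// variable first, mirroring the fold inside `evaluate_state_space`.
fn decode_state(mut s: usize, cardinalities: &[usize]) -> Vec<usize> {
    cardinalities
        .iter()
        .map(|&card| {
            let state = s % card; // this variable's "digit"
            s /= card;            // shift to the next radix position
            state
        })
        .collect()
}

fn main() {
    // Three binary variables -> 2 * 2 * 2 = 8 joint states.
    let cardinalities = [2usize, 2, 2];
    let n_states: usize = cardinalities.iter().product();
    for s in 0..n_states {
        println!("{} -> {:?}", s, decode_state(s, &cardinalities));
    }
    // Prints 0 -> [0, 0, 0], 1 -> [1, 0, 0], 2 -> [0, 1, 0], ...
}

Each decoded vector corresponds one-to-one to a key of the returned `HashMap<NetworkProcessState, f64>`, which is why this patch also derives `Hash`, `PartialEq`, and `Eq` on `params::StateType`.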
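Note on the discounting in `evaluate_state` (a reading of the code above, not part of the patch): with rate rho = `self.discount_factor`, an instantaneous reward r held constant over a sojourn [t_prev, t_curr] contributes

    r * (e^(-rho*t_prev) - e^(-rho*t_curr)) = integral over [t_prev, t_curr] of rho * e^(-rho*t) * r dt,

which is exactly the `discount * r.instantaneous_reward` term, while the reward for the transition itself is discounted at the jump instant: e^(-rho*t_curr) * r.transition_reward.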