Implemented reward_evaluation for an entire process.

pull/87/head
AlessandroBregoli 2 years ago
parent cecf16a771
commit bb239aaa0c
  1. reCTBN/src/params.rs (2 changes)
  2. reCTBN/src/reward.rs (9 changes)
  3. reCTBN/src/reward/reward_evaluation.rs (41 changes)
  4. reCTBN/tests/reward_evaluation.rs (17 changes)

reCTBN/src/params.rs
@@ -20,7 +20,7 @@ pub enum ParamsError {
 }
 
 /// Allowed type of states
-#[derive(Clone)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum StateType {
     Discrete(usize),
 }
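
The extra derives are what make a state usable as a lookup key: a network state is a vector of `StateType` values, and the new `evaluate_state_space` below returns a `HashMap` keyed by such vectors, which requires `Hash + Eq`; `Debug` lets tests print states. A minimal standalone sketch with a stand-in enum (not the crate's own type):

```rust
use std::collections::HashMap;

// Stand-in enum carrying the same derives the commit adds to `StateType`.
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
enum StateType {
    Discrete(usize),
}

fn main() {
    // A joint state is a vector of per-variable states; Hash + Eq on the
    // element type make the whole Vec usable as a HashMap key.
    let mut values: HashMap<Vec<StateType>, f64> = HashMap::new();
    values.insert(vec![StateType::Discrete(0), StateType::Discrete(1)], 3.0);

    let query = vec![StateType::Discrete(0), StateType::Discrete(1)];
    println!("{:?} -> {:?}", query, values.get(&query));
}
```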

reCTBN/src/reward.rs
@@ -1,6 +1,8 @@
 pub mod reward_function;
+pub mod reward_evaluation;
 
+use std::collections::HashMap;
 
 use crate::process;
 use ndarray;
@@ -43,12 +45,13 @@ pub trait RewardFunction {
 }
 pub trait RewardEvaluation {
-    fn call<N: process::NetworkProcess, R: RewardFunction>(
+    fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
-    ) -> ndarray::Array1<f64>;
-
-    fn call_state<N: process::NetworkProcess, R: RewardFunction>(
+    ) -> HashMap<process::NetworkProcessState, f64>;
+
+    fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
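
With the rename, the trait now distinguishes evaluating the whole state space (`evaluate_state_space`, returning a map from each joint state to its value instead of an `ndarray::Array1<f64>`) from evaluating a single state (`evaluate_state`). A hypothetical helper, written as if inside the crate and assuming the module paths shown in this diff, illustrating how the two methods compose for any evaluator:

```rust
use crate::process::{NetworkProcess, NetworkProcessState};
use crate::reward::{RewardEvaluation, RewardFunction};

/// Hypothetical helper (not part of this commit): print the value of every
/// joint state, then return the value of one state of interest.
fn summarize<E, N, R>(evaluator: &E, net: &N, rf: &R, state: &NetworkProcessState) -> f64
where
    E: RewardEvaluation,
    N: NetworkProcess,
    R: RewardFunction,
{
    for (s, value) in evaluator.evaluate_state_space(net, rf) {
        println!("{:?} -> {}", s, value);
    }
    evaluator.evaluate_state(net, rf, state)
}
```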

reCTBN/src/reward/reward_evaluation.rs
@@ -1,7 +1,12 @@
+use std::collections::HashMap;
+
+use crate::params::{self, ParamsTrait};
 use crate::process;
+
 use crate::{
+    process::NetworkProcessState,
     reward::RewardEvaluation,
     sampling::{ForwardSampler, Sampler},
-    process::NetworkProcessState
 };
 
 pub struct MonteCarloDiscountedRward {
@@ -28,21 +33,42 @@ impl MonteCarloDiscountedRward {
 }
 
 impl RewardEvaluation for MonteCarloDiscountedRward {
-    fn call<N: crate::process::NetworkProcess, R: super::RewardFunction>(
+    fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
-    ) -> ndarray::Array1<f64> {
-        todo!()
+    ) -> HashMap<process::NetworkProcessState, f64> {
+        let variables_domain: Vec<Vec<params::StateType>> = network_process
+            .get_node_indices()
+            .map(|x| match network_process.get_node(x) {
+                params::Params::DiscreteStatesContinousTime(x) =>
+                    (0..x.get_reserved_space_as_parent()).map(|s| params::StateType::Discrete(s)).collect()
+            }).collect();
+
+        let n_states: usize = variables_domain.iter().map(|x| x.len()).product();
+
+        (0..n_states).map(|s| {
+            let state: process::NetworkProcessState = variables_domain.iter().fold((s, vec![]), |acc, x| {
+                let mut acc = acc;
+                let idx_s = acc.0 % x.len();
+                acc.1.push(x[idx_s].clone());
+                acc.0 = acc.0 / x.len();
+                acc
+            }).1;
+
+            let r = self.evaluate_state(network_process, reward_function, &state);
+            (state, r)
+        }).collect()
     }
 
-    fn call_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
+    fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
         &self,
         network_process: &N,
         reward_function: &R,
         state: &NetworkProcessState,
     ) -> f64 {
-        let mut sampler = ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
+        let mut sampler =
+            ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
        let mut ret = 0.0;
 
        for _i in 0..self.n_iterations {
@@ -60,7 +86,8 @@ impl RewardEvaluation for MonteCarloDiscountedRward {
                 let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t)
                     - std::f64::consts::E.powf(-self.discount_factor * current.t);
                 ret += discount * r.instantaneous_reward;
-                ret += std::f64::consts::E.powf(-self.discount_factor * current.t) * r.transition_reward;
+                ret += std::f64::consts::E.powf(-self.discount_factor * current.t)
+                    * r.transition_reward;
             }
             previous = current;
         }
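
The fold inside `evaluate_state_space` performs a mixed-radix decoding: each flat index in `0..n_states` is unpacked into one `StateType::Discrete` value per variable, with the first variable acting as the least significant digit. A standalone sketch of the same decoding on plain integers (illustration only, not the crate's code):

```rust
/// Decode a flat index into one state index per variable, least significant
/// variable first; this mirrors the `fold` used by `evaluate_state_space`.
fn decode_state(mut s: usize, cardinalities: &[usize]) -> Vec<usize> {
    let mut state = Vec::with_capacity(cardinalities.len());
    for &card in cardinalities {
        state.push(s % card); // this variable's state index
        s /= card;            // shift to the next variable
    }
    state
}

fn main() {
    // Three binary variables -> 2 * 2 * 2 = 8 joint states.
    let cards = [2usize, 2, 2];
    let n_states: usize = cards.iter().product();
    for s in 0..n_states {
        println!("{} -> {:?}", s, decode_state(s, &cards));
    }
}
```

In `evaluate_state`, each sampled trajectory accumulates exponentially discounted reward between consecutive events: the state held from `previous.t` to `current.t` contributes `(exp(-discount_factor * previous.t) - exp(-discount_factor * current.t)) * instantaneous_reward`, and the transition at `current.t` contributes `exp(-discount_factor * current.t) * transition_reward`, summed over `n_iterations` sampled trajectories.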

reCTBN/tests/reward_evaluation.rs
@@ -34,8 +34,13 @@ fn simple_factored_reward_function_binary_node_MC() {
     let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
 
     let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215));
-    assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s0), epsilon = 1e-2);
-    assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s1), epsilon = 1e-2);
+    assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
+    assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
+
+    let rst = mc.evaluate_state_space(&net, &rf);
+    assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
+    assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);
 }
 
 #[test]
@@ -102,6 +107,10 @@ fn simple_factored_reward_function_chain_MC() {
         params::StateType::Discrete(0),
     ];
 
-    let mc = MonteCarloDiscountedRward::new(10000, 100.0, 1.0, Some(215));
-    assert_abs_diff_eq!(2.447, mc.call_state(&net, &rf, &s000), epsilon = 1e-1);
+    let mc = MonteCarloDiscountedRward::new(1000, 10.0, 1.0, Some(215));
+    assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
+
+    let rst = mc.evaluate_state_space(&net, &rf);
+    assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1);
 }
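
Since these are integration tests under `reCTBN/tests/`, running something like `cargo test --test reward_evaluation` from the crate directory should exercise just this file (assuming the default Cargo test layout); the `Some(215)` argument to `MonteCarloDiscountedRward::new` appears to be an RNG seed, which keeps the Monte Carlo estimates reproducible across runs.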
