@@ -9,30 +9,35 @@ use crate::{
     sampling::{ForwardSampler, Sampler},
 };
 
-pub struct MonteCarloDiscountedRward {
+pub enum RewardCriteria {
+    FiniteHorizon,
+    InfiniteHorizon {discount_factor: f64},
+}
+
+pub struct MonteCarloRward {
     n_iterations: usize,
     end_time: f64,
-    discount_factor: f64,
+    reward_criteria: RewardCriteria,
     seed: Option<u64>,
 }
 
-impl MonteCarloDiscountedRward {
+impl MonteCarloRward {
     pub fn new(
         n_iterations: usize,
         end_time: f64,
-        discount_factor: f64,
+        reward_criteria: RewardCriteria,
         seed: Option<u64>,
-    ) -> MonteCarloDiscountedRward {
-        MonteCarloDiscountedRward {
+    ) -> MonteCarloRward {
+        MonteCarloRward {
             n_iterations,
             end_time,
-            discount_factor,
+            reward_criteria,
             seed,
         }
     }
 }
 
-impl RewardEvaluation for MonteCarloDiscountedRward {
+impl RewardEvaluation for MonteCarloRward {
     fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
         &self,
         network_process: &N,
@@ -41,24 +46,32 @@ impl RewardEvaluation for MonteCarloDiscountedRward {
         let variables_domain: Vec<Vec<params::StateType>> = network_process
             .get_node_indices()
             .map(|x| match network_process.get_node(x) {
-                params::Params::DiscreteStatesContinousTime(x) =>
-                    (0..x.get_reserved_space_as_parent()).map(|s| params::StateType::Discrete(s)).collect()
-            }).collect();
+                params::Params::DiscreteStatesContinousTime(x) => (0..x
+                    .get_reserved_space_as_parent())
+                    .map(|s| params::StateType::Discrete(s))
+                    .collect(),
+            })
+            .collect();
 
-        let n_states:usize = variables_domain.iter().map(|x| x.len()).product();
+        let n_states: usize = variables_domain.iter().map(|x| x.len()).product();
 
-        (0..n_states).map(|s| {
-            let state: process::NetworkProcessState = variables_domain.iter().fold((s, vec![]), |acc, x| {
-                let mut acc = acc;
-                let idx_s = acc.0%x.len();
-                acc.1.push(x[idx_s].clone());
-                acc.0 = acc.0 / x.len();
-                acc
-            }).1;
+        (0..n_states)
+            .map(|s| {
+                let state: process::NetworkProcessState = variables_domain
+                    .iter()
+                    .fold((s, vec![]), |acc, x| {
+                        let mut acc = acc;
+                        let idx_s = acc.0 % x.len();
+                        acc.1.push(x[idx_s].clone());
+                        acc.0 = acc.0 / x.len();
+                        acc
+                    })
+                    .1;
 
-            let r = self.evaluate_state(network_process, reward_function, &state);
-            (state, r)
-        }).collect()
+                let r = self.evaluate_state(network_process, reward_function, &state);
+                (state, r)
+            })
+            .collect()
     }
 
     fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
@@ -78,16 +91,30 @@ impl RewardEvaluation for MonteCarloDiscountedRward {
                 let current = sampler.next().unwrap();
                 if current.t > self.end_time {
                     let r = reward_function.call(&previous.state, None);
-                    let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t)
-                        - std::f64::consts::E.powf(-self.discount_factor * self.end_time);
+                    let discount = match self.reward_criteria {
+                        RewardCriteria::FiniteHorizon => self.end_time - previous.t,
+                        RewardCriteria::InfiniteHorizon {discount_factor} => {
+                            std::f64::consts::E.powf(-discount_factor * previous.t)
+                                - std::f64::consts::E.powf(-discount_factor * self.end_time)
+                        }
+                    };
                     ret += discount * r.instantaneous_reward;
                 } else {
                     let r = reward_function.call(&previous.state, Some(&current.state));
-                    let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t)
-                        - std::f64::consts::E.powf(-self.discount_factor * current.t);
+                    let discount = match self.reward_criteria {
+                        RewardCriteria::FiniteHorizon => current.t-previous.t,
+                        RewardCriteria::InfiniteHorizon {discount_factor} => {
+                            std::f64::consts::E.powf(-discount_factor * previous.t)
+                                - std::f64::consts::E.powf(-discount_factor * current.t)
+                        }
+                    };
                     ret += discount * r.instantaneous_reward;
-                    ret += std::f64::consts::E.powf(-self.discount_factor * current.t)
-                        * r.transition_reward;
+                    ret += match self.reward_criteria {
+                        RewardCriteria::FiniteHorizon => 1.0,
+                        RewardCriteria::InfiniteHorizon {discount_factor} => {
+                            std::f64::consts::E.powf(-discount_factor * current.t)
+                        }
+                    } * r.transition_reward;
                 }
                 previous = current;
             }
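For reference, a minimal usage sketch of the API change above, not part of the diff: callers that previously passed a bare discount_factor to MonteCarloDiscountedRward::new now pick a RewardCriteria variant when constructing MonteCarloRward. The iteration count, horizon, discount factor, and seed values below are illustrative assumptions, as is the network/reward-function setup that is elided.

    // Hypothetical construction of the renamed evaluator under both criteria.
    // Values (100 iterations, end_time 10.0, seed 215, discount 1.0) are assumptions.
    let finite = MonteCarloRward::new(100, 10.0, RewardCriteria::FiniteHorizon, Some(215));
    let discounted = MonteCarloRward::new(
        100,
        10.0,
        RewardCriteria::InfiniteHorizon {discount_factor: 1.0},
        Some(215),
    );
    // Both evaluators still implement RewardEvaluation, so evaluate_state and
    // evaluate_state_space are called exactly as before; only the discounting
    // applied to instantaneous and transition rewards differs between variants.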