parent
4fc5c1d4b5
commit
cecf16a771
@ -0,0 +1,71 @@ |
|||||||
|
use crate::{ |
||||||
|
reward::RewardEvaluation, |
||||||
|
sampling::{ForwardSampler, Sampler}, |
||||||
|
process::NetworkProcessState |
||||||
|
}; |
||||||
|
|
||||||
|
pub struct MonteCarloDiscountedRward { |
||||||
|
n_iterations: usize, |
||||||
|
end_time: f64, |
||||||
|
discount_factor: f64, |
||||||
|
seed: Option<u64>, |
||||||
|
} |
||||||
|
|
||||||
|
impl MonteCarloDiscountedRward { |
||||||
|
pub fn new( |
||||||
|
n_iterations: usize, |
||||||
|
end_time: f64, |
||||||
|
discount_factor: f64, |
||||||
|
seed: Option<u64>, |
||||||
|
) -> MonteCarloDiscountedRward { |
||||||
|
MonteCarloDiscountedRward { |
||||||
|
n_iterations, |
||||||
|
end_time, |
||||||
|
discount_factor, |
||||||
|
seed, |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl RewardEvaluation for MonteCarloDiscountedRward { |
||||||
|
fn call<N: crate::process::NetworkProcess, R: super::RewardFunction>( |
||||||
|
&self, |
||||||
|
network_process: &N, |
||||||
|
reward_function: &R, |
||||||
|
) -> ndarray::Array1<f64> { |
||||||
|
todo!() |
||||||
|
} |
||||||
|
|
||||||
|
fn call_state<N: crate::process::NetworkProcess, R: super::RewardFunction>( |
||||||
|
&self, |
||||||
|
network_process: &N, |
||||||
|
reward_function: &R, |
||||||
|
state: &NetworkProcessState, |
||||||
|
) -> f64 { |
||||||
|
let mut sampler = ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); |
||||||
|
let mut ret = 0.0; |
||||||
|
|
||||||
|
for _i in 0..self.n_iterations { |
||||||
|
sampler.reset(); |
||||||
|
let mut previous = sampler.next().unwrap(); |
||||||
|
while previous.t < self.end_time { |
||||||
|
let current = sampler.next().unwrap(); |
||||||
|
if current.t > self.end_time { |
||||||
|
let r = reward_function.call(&previous.state, None); |
||||||
|
let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) |
||||||
|
- std::f64::consts::E.powf(-self.discount_factor * self.end_time); |
||||||
|
ret += discount * r.instantaneous_reward; |
||||||
|
} else { |
||||||
|
let r = reward_function.call(&previous.state, Some(¤t.state)); |
||||||
|
let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) |
||||||
|
- std::f64::consts::E.powf(-self.discount_factor * current.t); |
||||||
|
ret += discount * r.instantaneous_reward; |
||||||
|
ret += std::f64::consts::E.powf(-self.discount_factor * current.t) * r.transition_reward; |
||||||
|
} |
||||||
|
previous = current; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
ret / self.n_iterations as f64 |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,107 @@ |
|||||||
|
mod utils; |
||||||
|
|
||||||
|
use approx::{abs_diff_eq, assert_abs_diff_eq}; |
||||||
|
use ndarray::*; |
||||||
|
use reCTBN::{ |
||||||
|
params, |
||||||
|
process::{ctbn::*, NetworkProcess, NetworkProcessState}, |
||||||
|
reward::{reward_evaluation::*, reward_function::*, *}, |
||||||
|
}; |
||||||
|
use utils::generate_discrete_time_continous_node; |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function_binary_node_MC() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1) |
||||||
|
.assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1) |
||||||
|
.assign(&arr1(&[3.0, 3.0])); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1) { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
net.initialize_adj_matrix(); |
||||||
|
|
||||||
|
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||||
|
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||||
|
|
||||||
|
let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215)); |
||||||
|
assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s0), epsilon = 1e-2); |
||||||
|
assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s1), epsilon = 1e-2); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function_chain_MC() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let n3 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
net.add_edge(n1, n2); |
||||||
|
net.add_edge(n2, n3); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1) { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2) { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param |
||||||
|
.set_cim(arr3(&[ |
||||||
|
[[-0.01, 0.01], [5.0, -5.0]], |
||||||
|
[[-5.0, 5.0], [0.01, -0.01]], |
||||||
|
])) |
||||||
|
.unwrap(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
match &mut net.get_node_mut(n3) { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param |
||||||
|
.set_cim(arr3(&[ |
||||||
|
[[-0.01, 0.01], [5.0, -5.0]], |
||||||
|
[[-5.0, 5.0], [0.01, -0.01]], |
||||||
|
])) |
||||||
|
.unwrap(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1) |
||||||
|
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||||
|
|
||||||
|
rf.get_transition_reward_mut(n2) |
||||||
|
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||||
|
|
||||||
|
rf.get_transition_reward_mut(n3) |
||||||
|
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||||
|
|
||||||
|
let s000: NetworkProcessState = vec![ |
||||||
|
params::StateType::Discrete(1), |
||||||
|
params::StateType::Discrete(0), |
||||||
|
params::StateType::Discrete(0), |
||||||
|
]; |
||||||
|
|
||||||
|
let mc = MonteCarloDiscountedRward::new(10000, 100.0, 1.0, Some(215)); |
||||||
|
assert_abs_diff_eq!(2.447, mc.call_state(&net, &rf, &s000), epsilon = 1e-1); |
||||||
|
} |
Loading…
Reference in new issue