parent
1878f687d6
commit
055eb7088e
@ -0,0 +1,80 @@ |
|||||||
|
use crate::{process, sampling, params::{ParamsTrait, self}}; |
||||||
|
use ndarray; |
||||||
|
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)] |
||||||
|
pub struct Reward { |
||||||
|
pub transition_reward: f64, |
||||||
|
pub instantaneous_reward: f64 |
||||||
|
} |
||||||
|
|
||||||
|
pub trait RewardFunction { |
||||||
|
fn call(&self, current_state: sampling::Sample, previous_state: Option<sampling::Sample>) -> Reward; |
||||||
|
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
pub struct FactoredRewardFunction { |
||||||
|
transition_reward: Vec<ndarray::Array2<f64>>, |
||||||
|
instantaneous_reward: Vec<ndarray::Array1<f64>> |
||||||
|
} |
||||||
|
|
||||||
|
impl FactoredRewardFunction { |
||||||
|
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { |
||||||
|
&self.transition_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { |
||||||
|
&mut self.transition_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { |
||||||
|
&self.instantaneous_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { |
||||||
|
&mut self.instantaneous_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
impl RewardFunction for FactoredRewardFunction { |
||||||
|
|
||||||
|
fn call(&self, current_state: sampling::Sample, previous_state: Option<sampling::Sample>) -> Reward { |
||||||
|
let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { |
||||||
|
let x = match x {params::StateType::Discrete(x) => x}; |
||||||
|
self.instantaneous_reward[idx][*x] |
||||||
|
}).sum(); |
||||||
|
if let Some(previous_state) = previous_state { |
||||||
|
let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option<f64> { |
||||||
|
let p = match p {params::StateType::Discrete(p) => p}; |
||||||
|
let c = match c {params::StateType::Discrete(c) => c}; |
||||||
|
if p != c { |
||||||
|
Some(self.transition_reward[idx][[*p,*c]]) |
||||||
|
} else { |
||||||
|
None |
||||||
|
} |
||||||
|
}).unwrap_or(0.0); |
||||||
|
Reward {transition_reward, instantaneous_reward} |
||||||
|
} else { |
||||||
|
Reward { transition_reward: 0.0, instantaneous_reward} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { |
||||||
|
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; |
||||||
|
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![];
|
||||||
|
for i in p.get_node_indices() { |
||||||
|
//This works only for discrete nodes!
|
||||||
|
let size: usize = p.get_node(i).get_reserved_space_as_parent(); |
||||||
|
instantaneous_reward.push(ndarray::Array1::zeros(size)); |
||||||
|
transition_reward.push(ndarray::Array2::zeros((size, size))); |
||||||
|
} |
||||||
|
|
||||||
|
FactoredRewardFunction { transition_reward, instantaneous_reward } |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
} |
||||||
|
|
@ -0,0 +1,30 @@ |
|||||||
|
mod utils; |
||||||
|
|
||||||
|
use ndarray::*; |
||||||
|
use utils::generate_discrete_time_continous_node; |
||||||
|
use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); |
||||||
|
|
||||||
|
let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; |
||||||
|
let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; |
||||||
|
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||||
|
|
||||||
|
assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||||
|
} |
Loading…
Reference in new issue