Implemented FactoredRewardFunction

pull/75/head
AlessandroBregoli 2 years ago
parent 1878f687d6
commit 055eb7088e
  1. 1
      reCTBN/src/lib.rs
  2. 1
      reCTBN/src/process/ctbn.rs
  3. 80
      reCTBN/src/reward_function.rs
  4. 1
      reCTBN/src/sampling.rs
  5. 30
      reCTBN/tests/reward_function.rs

@ -9,3 +9,4 @@ pub mod process;
pub mod sampling;
pub mod structure_learning;
pub mod tools;
pub mod reward_function;

@ -119,7 +119,6 @@ impl CtbnNetwork {
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
);
println!("{:?}", amalgamated_cim);
amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new();

@ -0,0 +1,80 @@
use crate::{process, sampling, params::{ParamsTrait, self}};
use ndarray;
/// Reward value produced by a single [`RewardFunction::call`].
///
/// The type is a plain pair of `f64` values, so it is `Copy`: callers can keep
/// or accumulate a reward without giving up the original.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Reward {
    /// Reward attached to the change from the previous state to the current one.
    pub transition_reward: f64,
    /// Reward attached to the current state itself.
    pub instantaneous_reward: f64,
}
/// Interface of a reward function evaluated on sampled states.
pub trait RewardFunction {
    /// Computes the [`Reward`] for `current_state`.
    ///
    /// `previous_state` is the state that preceded `current_state`, or `None`
    /// when no previous state is available.
    fn call(
        &self,
        current_state: sampling::Sample,
        previous_state: Option<sampling::Sample>,
    ) -> Reward;

    /// Builds a reward function from the network process `p`.
    fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}
/// Reward function stored as one table per node.
///
/// Both vectors are indexed by node index.
pub struct FactoredRewardFunction {
    /// Per-node matrix, read at `[previous_state, current_state]`.
    transition_reward: Vec<ndarray::Array2<f64>>,
    /// Per-node vector, read at the node's current state.
    instantaneous_reward: Vec<ndarray::Array1<f64>>,
}
impl FactoredRewardFunction {
    /// Returns the transition-reward matrix of the node `node_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `node_idx` is out of bounds.
    pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> {
        let matrices = &self.transition_reward;
        &matrices[node_idx]
    }

    /// Returns the transition-reward matrix of the node `node_idx` for in-place editing.
    ///
    /// # Panics
    ///
    /// Panics if `node_idx` is out of bounds.
    pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> {
        let matrices = &mut self.transition_reward;
        &mut matrices[node_idx]
    }

    /// Returns the instantaneous-reward vector of the node `node_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `node_idx` is out of bounds.
    pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> {
        let vectors = &self.instantaneous_reward;
        &vectors[node_idx]
    }

    /// Returns the instantaneous-reward vector of the node `node_idx` for in-place editing.
    ///
    /// # Panics
    ///
    /// Panics if `node_idx` is out of bounds.
    pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
        let vectors = &mut self.instantaneous_reward;
        &mut vectors[node_idx]
    }
}
impl RewardFunction for FactoredRewardFunction {
    /// Evaluates the factored reward for `current_state`.
    ///
    /// The instantaneous reward is the sum, over every node, of that node's
    /// instantaneous-reward entry for its current state.
    ///
    /// When `previous_state` is given, the transition reward is taken from the
    /// first node (in index order) whose state differs between the two samples,
    /// read from that node's matrix at `[previous, current]`. It is `0.0` when
    /// no node differs or when `previous_state` is `None`.
    ///
    /// # Panics
    ///
    /// Panics if a node index or a state value falls outside the stored tables.
    fn call(
        &self,
        current_state: sampling::Sample,
        previous_state: Option<sampling::Sample>,
    ) -> Reward {
        let instantaneous_reward = current_state
            .state
            .iter()
            .enumerate()
            .map(|(node_idx, node_state)| {
                let params::StateType::Discrete(value) = node_state;
                self.instantaneous_reward[node_idx][*value]
            })
            .sum::<f64>();

        let transition_reward = match previous_state {
            None => 0.0,
            Some(previous_state) => previous_state
                .state
                .iter()
                .zip(current_state.state.iter())
                .enumerate()
                .filter_map(|(node_idx, (before, after))| {
                    let params::StateType::Discrete(before) = before;
                    let params::StateType::Discrete(after) = after;
                    // Only a node whose state changed contributes a transition reward.
                    (before != after).then(|| self.transition_reward[node_idx][[*before, *after]])
                })
                .next()
                .unwrap_or(0.0),
        };

        Reward {
            transition_reward,
            instantaneous_reward,
        }
    }

    /// Creates a reward function with all-zero tables for the nodes of `p`.
    ///
    /// For each node index reported by `p`, a zero vector of length `size` and a
    /// zero `size x size` matrix are allocated, where `size` is the value
    /// returned by the node's `get_reserved_space_as_parent`.
    fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self {
        //This works only for discrete nodes!
        let (instantaneous_reward, transition_reward): (
            Vec<ndarray::Array1<f64>>,
            Vec<ndarray::Array2<f64>>,
        ) = p
            .get_node_indices()
            .into_iter()
            .map(|node_idx| {
                let size: usize = p.get_node(node_idx).get_reserved_space_as_parent();
                (
                    ndarray::Array1::zeros(size),
                    ndarray::Array2::zeros((size, size)),
                )
            })
            .unzip();

        FactoredRewardFunction {
            transition_reward,
            instantaneous_reward,
        }
    }
}

@ -7,6 +7,7 @@ use crate::{
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
#[derive(Clone)]
pub struct Sample {
pub t: f64,
pub state: Vec<params::StateType>

@ -0,0 +1,30 @@
mod utils;
use ndarray::*;
use utils::generate_discrete_time_continous_node;
use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params};
/// Checks `FactoredRewardFunction::call` on a network with a single two-state
/// node, with and without a previous state.
#[test]
fn simple_factored_reward_function() {
    let mut net = CtbnNetwork::new();
    let node = generate_discrete_time_continous_node(String::from("n1"), 2);
    let n1 = net.add_node(node).unwrap();

    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[12.0, 1.0], [2.0, 12.0]]));
    rf.get_instantaneous_reward_mut(n1)
        .assign(&arr1(&[3.0, 5.0]));

    // Small builders for the samples and the expected rewards.
    let sample_at = |value: usize| reCTBN::sampling::Sample {
        t: 0.0,
        state: vec![params::StateType::Discrete(value)],
    };
    let expected = |transition_reward: f64, instantaneous_reward: f64| Reward {
        transition_reward,
        instantaneous_reward,
    };
    let s0 = sample_at(0);
    let s1 = sample_at(1);

    // Without a previous state only the instantaneous reward is non-zero.
    assert_eq!(rf.call(s0.clone(), None), expected(0.0, 3.0));
    assert_eq!(rf.call(s1.clone(), None), expected(0.0, 5.0));

    // A state change reads the matrix at [previous, current].
    assert_eq!(rf.call(s0.clone(), Some(s1.clone())), expected(2.0, 3.0));
    assert_eq!(rf.call(s1.clone(), Some(s0.clone())), expected(1.0, 5.0));

    // An unchanged state yields no transition reward.
    assert_eq!(rf.call(s0.clone(), Some(s0.clone())), expected(0.0, 3.0));
    assert_eq!(rf.call(s1.clone(), Some(s1.clone())), expected(0.0, 5.0));
}
Loading…
Cancel
Save