diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index c62c42e..1d25552 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -9,3 +9,4 @@ pub mod process; pub mod sampling; pub mod structure_learning; pub mod tools; +pub mod reward_function; diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c949afe..0b6161c 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -119,7 +119,6 @@ impl CtbnNetwork { BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs new file mode 100644 index 0000000..9ff09cc --- /dev/null +++ b/reCTBN/src/reward_function.rs @@ -0,0 +1,80 @@ +use crate::{process, sampling, params::{ParamsTrait, self}}; +use ndarray; + + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64 +} + +pub trait RewardFunction { + fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward; + fn initialize_from_network_process(p: &T) -> Self; +} + + +pub struct FactoredRewardFunction { + transition_reward: Vec>, + instantaneous_reward: Vec> +} + +impl FactoredRewardFunction { + pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2 { + &self.transition_reward[node_idx] + } + + pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2 { + &mut self.transition_reward[node_idx] + } + + pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1 { + &self.instantaneous_reward[node_idx] + } + + pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { + &mut self.instantaneous_reward[node_idx] + } + + +} + +impl RewardFunction for FactoredRewardFunction { + + fn call(&self, current_state: sampling::Sample, previous_state: Option) -> Reward { + let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { + let x = match x {params::StateType::Discrete(x) => x}; + self.instantaneous_reward[idx][*x] + }).sum(); + if let Some(previous_state) = previous_state { + let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option { + let p = match p {params::StateType::Discrete(p) => p}; + let c = match c {params::StateType::Discrete(c) => c}; + if p != c { + Some(self.transition_reward[idx][[*p,*c]]) + } else { + None + } + }).unwrap_or(0.0); + Reward {transition_reward, instantaneous_reward} + } else { + Reward { transition_reward: 0.0, instantaneous_reward} + } + } + + fn initialize_from_network_process(p: &T) -> Self { + let mut transition_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; + for i in p.get_node_indices() { + //This works only for discrete nodes! + let size: usize = p.get_node(i).get_reserved_space_as_parent(); + instantaneous_reward.push(ndarray::Array1::zeros(size)); + transition_reward.push(ndarray::Array2::zeros((size, size))); + } + + FactoredRewardFunction { transition_reward, instantaneous_reward } + + } + +} + diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 3bc0c6f..d5a1dbe 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -7,6 +7,7 @@ use crate::{ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; +#[derive(Clone)] pub struct Sample { pub t: f64, pub state: Vec diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs new file mode 100644 index 0000000..7f73e6c --- /dev/null +++ b/reCTBN/tests/reward_function.rs @@ -0,0 +1,30 @@ +mod utils; + +use ndarray::*; +use utils::generate_discrete_time_continous_node; +use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; + + +#[test] +fn simple_factored_reward_function() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); + + let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; + let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; + assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + + assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); +}