From cecf16a771d6fd53ec7ad1b909c7310c9a8368d8 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 29 Nov 2022 09:43:12 +0100 Subject: [PATCH] Added single state evaluation --- reCTBN/src/reward.rs | 20 ++++- reCTBN/src/reward/reward_evaluation.rs | 71 ++++++++++++++++ reCTBN/src/reward/reward_function.rs | 4 +- reCTBN/src/sampling.rs | 27 +++++-- reCTBN/src/tools.rs | 3 +- reCTBN/tests/reward_evaluation.rs | 107 +++++++++++++++++++++++++ reCTBN/tests/reward_function.rs | 61 +++++++------- 7 files changed, 248 insertions(+), 45 deletions(-) create mode 100644 reCTBN/src/reward/reward_evaluation.rs create mode 100644 reCTBN/tests/reward_evaluation.rs diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs index 114ba03..1ea575c 100644 --- a/reCTBN/src/reward.rs +++ b/reCTBN/src/reward.rs @@ -1,6 +1,8 @@ pub mod reward_function; +pub mod reward_evaluation; use crate::process; +use ndarray; /// Instantiation of reward function and instantaneous reward /// @@ -28,8 +30,8 @@ pub trait RewardFunction { fn call( &self, - current_state: process::NetworkProcessState, - previous_state: Option, + current_state: &process::NetworkProcessState, + previous_state: Option<&process::NetworkProcessState>, ) -> Reward; /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess @@ -39,3 +41,17 @@ pub trait RewardFunction { /// * `p`: any structure that implements the trait `process::NetworkProcess` fn initialize_from_network_process(p: &T) -> Self; } + +pub trait RewardEvaluation { + fn call( + &self, + network_process: &N, + reward_function: &R, + ) -> ndarray::Array1; + fn call_state( + &self, + network_process: &N, + reward_function: &R, + state: &process::NetworkProcessState, + ) -> f64; +} diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs new file mode 100644 index 0000000..fca7c1a --- /dev/null +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -0,0 +1,71 @@ +use crate::{ + reward::RewardEvaluation, + 
sampling::{ForwardSampler, Sampler}, + process::NetworkProcessState +}; + +pub struct MonteCarloDiscountedRward { + n_iterations: usize, + end_time: f64, + discount_factor: f64, + seed: Option, +} + +impl MonteCarloDiscountedRward { + pub fn new( + n_iterations: usize, + end_time: f64, + discount_factor: f64, + seed: Option, + ) -> MonteCarloDiscountedRward { + MonteCarloDiscountedRward { + n_iterations, + end_time, + discount_factor, + seed, + } + } +} + +impl RewardEvaluation for MonteCarloDiscountedRward { + fn call( + &self, + network_process: &N, + reward_function: &R, + ) -> ndarray::Array1 { + todo!() + } + + fn call_state( + &self, + network_process: &N, + reward_function: &R, + state: &NetworkProcessState, + ) -> f64 { + let mut sampler = ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); + let mut ret = 0.0; + + for _i in 0..self.n_iterations { + sampler.reset(); + let mut previous = sampler.next().unwrap(); + while previous.t < self.end_time { + let current = sampler.next().unwrap(); + if current.t > self.end_time { + let r = reward_function.call(&previous.state, None); + let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) + - std::f64::consts::E.powf(-self.discount_factor * self.end_time); + ret += discount * r.instantaneous_reward; + } else { + let r = reward_function.call(&previous.state, Some(&current.state)); + let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t) + - std::f64::consts::E.powf(-self.discount_factor * current.t); + ret += discount * r.instantaneous_reward; + ret += std::f64::consts::E.powf(-self.discount_factor * current.t) * r.transition_reward; + } + previous = current; + } + } + + ret / self.n_iterations as f64 + } +} diff --git a/reCTBN/src/reward/reward_function.rs b/reCTBN/src/reward/reward_function.rs index ae94ff1..216df6a 100644 --- a/reCTBN/src/reward/reward_function.rs +++ b/reCTBN/src/reward/reward_function.rs @@ -44,8 +44,8 @@ impl FactoredRewardFunction { 
impl RewardFunction for FactoredRewardFunction { fn call( &self, - current_state: process::NetworkProcessState, - previous_state: Option, + current_state: &process::NetworkProcessState, + previous_state: Option<&process::NetworkProcessState>, ) -> Reward { let instantaneous_reward: f64 = current_state .iter() diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 1384872..73c6d78 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -26,10 +26,15 @@ where current_time: f64, current_state: NetworkProcessState, next_transitions: Vec>, + initial_state: Option, } impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { - pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { + pub fn new( + net: &'a T, + seed: Option, + initial_state: Option, + ) -> ForwardSampler<'a, T> { let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. Some(seed) => SeedableRng::seed_from_u64(seed), @@ -37,11 +42,12 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { None => SeedableRng::from_entropy(), }; let mut fs = ForwardSampler { - net: net, - rng: rng, + net, + rng, current_time: 0.0, current_state: vec![], next_transitions: vec![], + initial_state, }; fs.reset(); return fs; @@ -112,11 +118,16 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { fn reset(&mut self) { self.current_time = 0.0; - self.current_state = self - .net - .get_node_indices() - .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) - .collect(); + match &self.initial_state { + None => { + self.current_state = self + .net + .get_node_indices() + .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) + .collect() + } + Some(is) => self.current_state = is.clone(), + }; self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect(); } } diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index ecfeff9..38ebd49 
100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -61,8 +61,7 @@ pub fn trajectory_generator( let mut trajectories: Vec = Vec::new(); //Random Generator object - - let mut sampler = ForwardSampler::new(net, seed); + let mut sampler = ForwardSampler::new(net, seed, None); //Each iteration generate one trajectory for _ in 0..n_trajectories { //History of all the moments in which something changed diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs new file mode 100644 index 0000000..f3938e7 --- /dev/null +++ b/reCTBN/tests/reward_evaluation.rs @@ -0,0 +1,107 @@ +mod utils; + +use approx::{abs_diff_eq, assert_abs_diff_eq}; +use ndarray::*; +use reCTBN::{ + params, + process::{ctbn::*, NetworkProcess, NetworkProcessState}, + reward::{reward_evaluation::*, reward_function::*, *}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn simple_factored_reward_function_binary_node_MC() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1) + .assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1) + .assign(&arr1(&[3.0, 3.0])); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); + } + } + + net.initialize_adj_matrix(); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + + let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215)); + assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s0), epsilon = 1e-2); + assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s1), epsilon = 1e-2); +} + +#[test] +fn simple_factored_reward_function_chain_MC() { + let mut net = CtbnNetwork::new(); + let n1 = net + 
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + param + .set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]], + ])) + .unwrap(); + } + } + + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + param + .set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]], + ])) + .unwrap(); + } + } + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + rf.get_transition_reward_mut(n2) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + rf.get_transition_reward_mut(n3) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + let s000: NetworkProcessState = vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(0), + params::StateType::Discrete(0), + ]; + + let mc = MonteCarloDiscountedRward::new(10000, 100.0, 1.0, Some(215)); + assert_abs_diff_eq!(2.447, mc.call_state(&net, &rf, &s000), epsilon = 1e-1); +} diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index 03f2ab7..853efc9 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -18,15 +18,15 @@ fn simple_factored_reward_function_binary_node() { let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; - assert_eq!(rf.call(s0.clone(), 
None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); } @@ -46,16 +46,16 @@ fn simple_factored_reward_function_ternary_node() { let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; - assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + 
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); - assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); } #[test] @@ -77,7 +77,6 @@ fn factored_reward_function_two_nodes() { rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); - let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; @@ -87,32 +86,32 @@ fn factored_reward_function_two_nodes() { let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; - assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 
2.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s11.clone(), 
Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); - assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); - assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); }