Added sigle state evaluation

pull/87/head
AlessandroBregoli 2 years ago
parent 4fc5c1d4b5
commit cecf16a771
  1. 20
      reCTBN/src/reward.rs
  2. 71
      reCTBN/src/reward/reward_evaluation.rs
  3. 4
      reCTBN/src/reward/reward_function.rs
  4. 27
      reCTBN/src/sampling.rs
  5. 3
      reCTBN/src/tools.rs
  6. 107
      reCTBN/tests/reward_evaluation.rs
  7. 61
      reCTBN/tests/reward_function.rs

@ -1,6 +1,8 @@
pub mod reward_function;
pub mod reward_evaluation;
use crate::process;
use ndarray;
/// Instantiation of reward function and instantaneous reward
///
@ -28,8 +30,8 @@ pub trait RewardFunction {
fn call(
&self,
current_state: process::NetworkProcessState,
previous_state: Option<process::NetworkProcessState>,
current_state: &process::NetworkProcessState,
previous_state: Option<&process::NetworkProcessState>,
) -> Reward;
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
@ -39,3 +41,17 @@ pub trait RewardFunction {
/// * `p`: any structure that implements the trait `process::NetworkProcess`
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}
pub trait RewardEvaluation {
fn call<N: process::NetworkProcess, R: RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
) -> ndarray::Array1<f64>;
fn call_state<N: process::NetworkProcess, R: RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
state: &process::NetworkProcessState,
) -> f64;
}

@ -0,0 +1,71 @@
use crate::{
reward::RewardEvaluation,
sampling::{ForwardSampler, Sampler},
process::NetworkProcessState
};
pub struct MonteCarloDiscountedRward {
n_iterations: usize,
end_time: f64,
discount_factor: f64,
seed: Option<u64>,
}
impl MonteCarloDiscountedRward {
pub fn new(
n_iterations: usize,
end_time: f64,
discount_factor: f64,
seed: Option<u64>,
) -> MonteCarloDiscountedRward {
MonteCarloDiscountedRward {
n_iterations,
end_time,
discount_factor,
seed,
}
}
}
impl RewardEvaluation for MonteCarloDiscountedRward {
fn call<N: crate::process::NetworkProcess, R: super::RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
) -> ndarray::Array1<f64> {
todo!()
}
fn call_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
state: &NetworkProcessState,
) -> f64 {
let mut sampler = ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
let mut ret = 0.0;
for _i in 0..self.n_iterations {
sampler.reset();
let mut previous = sampler.next().unwrap();
while previous.t < self.end_time {
let current = sampler.next().unwrap();
if current.t > self.end_time {
let r = reward_function.call(&previous.state, None);
let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t)
- std::f64::consts::E.powf(-self.discount_factor * self.end_time);
ret += discount * r.instantaneous_reward;
} else {
let r = reward_function.call(&previous.state, Some(&current.state));
let discount = std::f64::consts::E.powf(-self.discount_factor * previous.t)
- std::f64::consts::E.powf(-self.discount_factor * current.t);
ret += discount * r.instantaneous_reward;
ret += std::f64::consts::E.powf(-self.discount_factor * current.t) * r.transition_reward;
}
previous = current;
}
}
ret / self.n_iterations as f64
}
}

@ -44,8 +44,8 @@ impl FactoredRewardFunction {
impl RewardFunction for FactoredRewardFunction {
fn call(
&self,
current_state: process::NetworkProcessState,
previous_state: Option<process::NetworkProcessState>,
current_state: &process::NetworkProcessState,
previous_state: Option<&process::NetworkProcessState>,
) -> Reward {
let instantaneous_reward: f64 = current_state
.iter()

@ -26,10 +26,15 @@ where
current_time: f64,
current_state: NetworkProcessState,
next_transitions: Vec<Option<f64>>,
initial_state: Option<NetworkProcessState>,
}
impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> {
pub fn new(
net: &'a T,
seed: Option<u64>,
initial_state: Option<NetworkProcessState>,
) -> ForwardSampler<'a, T> {
let rng: ChaCha8Rng = match seed {
//If a seed is present use it to initialize the random generator.
Some(seed) => SeedableRng::seed_from_u64(seed),
@ -37,11 +42,12 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
None => SeedableRng::from_entropy(),
};
let mut fs = ForwardSampler {
net: net,
rng: rng,
net,
rng,
current_time: 0.0,
current_state: vec![],
next_transitions: vec![],
initial_state,
};
fs.reset();
return fs;
@ -112,11 +118,16 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> {
fn reset(&mut self) {
self.current_time = 0.0;
self.current_state = self
.net
.get_node_indices()
.map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng))
.collect();
match &self.initial_state {
None => {
self.current_state = self
.net
.get_node_indices()
.map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng))
.collect()
}
Some(is) => self.current_state = is.clone(),
};
self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect();
}
}

@ -61,8 +61,7 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
let mut trajectories: Vec<Trajectory> = Vec::new();
//Random Generator object
let mut sampler = ForwardSampler::new(net, seed);
let mut sampler = ForwardSampler::new(net, seed, None);
//Each iteration generate one trajectory
for _ in 0..n_trajectories {
//History of all the moments in which something changed

@ -0,0 +1,107 @@
mod utils;
use approx::{abs_diff_eq, assert_abs_diff_eq};
use ndarray::*;
use reCTBN::{
params,
process::{ctbn::*, NetworkProcess, NetworkProcessState},
reward::{reward_evaluation::*, reward_function::*, *},
};
use utils::generate_discrete_time_continous_node;
#[test]
fn simple_factored_reward_function_binary_node_MC() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1)
.assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1)
.assign(&arr1(&[3.0, 3.0]));
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
}
}
net.initialize_adj_matrix();
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let mc = MonteCarloDiscountedRward::new(100, 10.0, 1.0, Some(215));
assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s0), epsilon = 1e-2);
assert_abs_diff_eq!(3.0, mc.call_state(&net, &rf, &s1), epsilon = 1e-2);
}
#[test]
fn simple_factored_reward_function_chain_MC() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap();
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
param
.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]],
]))
.unwrap();
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
param
.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]],
]))
.unwrap();
}
}
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
rf.get_transition_reward_mut(n2)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
rf.get_transition_reward_mut(n3)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
let s000: NetworkProcessState = vec![
params::StateType::Discrete(1),
params::StateType::Discrete(0),
params::StateType::Discrete(0),
];
let mc = MonteCarloDiscountedRward::new(10000, 100.0, 1.0, Some(215));
assert_abs_diff_eq!(2.447, mc.call_state(&net, &rf, &s000), epsilon = 1e-1);
}

@ -18,15 +18,15 @@ fn simple_factored_reward_function_binary_node() {
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
}
@ -46,16 +46,16 @@ fn simple_factored_reward_function_ternary_node() {
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
}
#[test]
@ -77,7 +77,6 @@ fn factored_reward_function_two_nodes() {
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
@ -87,32 +86,32 @@ fn factored_reward_function_two_nodes() {
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
}

Loading…
Cancel
Save