Merge pull request #87 from AlessandroBregoli/76-feature-reward-evaluation

76 feature reward evaluation
Authored by AlessandroBregoli 2 years ago; committed via GitHub.
commit e638a627bb
Changed files (lines changed in parentheses):
  1. reCTBN/src/lib.rs (2)
  2. reCTBN/src/parameter_learning.rs (4)
  3. reCTBN/src/params.rs (2)
  4. reCTBN/src/reward.rs (59)
  5. reCTBN/src/reward/reward_evaluation.rs (205)
  6. reCTBN/src/reward/reward_function.rs (44)
  7. reCTBN/src/sampling.rs (27)
  8. reCTBN/src/structure_learning/score_based_algorithm.rs (16)
  9. reCTBN/src/structure_learning/score_function.rs (2)
  10. reCTBN/src/tools.rs (3)
  11. reCTBN/tests/reward_evaluation.rs (122)
  12. reCTBN/tests/reward_function.rs (63)

@@ -6,7 +6,7 @@ extern crate approx;
 pub mod parameter_learning;
 pub mod params;
 pub mod process;
-pub mod reward_function;
+pub mod reward;
 pub mod sampling;
 pub mod structure_learning;
 pub mod tools;

@@ -144,6 +144,10 @@ impl ParameterLearning for BayesianApproach {
             .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
             .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau))));
+
+        CIM.outer_iter_mut().for_each(|mut C| {
+            C.diag_mut().fill(0.0);
+        });

         //Set the diagonal of the inner matrices to the the row sum multiplied by -1
         let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
         CIM.outer_iter_mut()

@@ -20,7 +20,7 @@ pub enum ParamsError {
 }

 /// Allowed type of states
-#[derive(Clone)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum StateType {
     Discrete(usize),
 }
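
(The added Hash/PartialEq/Eq derives are what let a network state, a Vec<StateType> aliased as NetworkProcessState, be used as a HashMap key by the evaluation code below. A minimal sketch of that usage, assuming the reCTBN crate as a dependency; the inserted value is illustrative.)

use std::collections::HashMap;

use reCTBN::params::StateType;
use reCTBN::process::NetworkProcessState;

fn main() {
    // A state is one StateType per node; Hash + Eq make it usable as a key.
    let s0: NetworkProcessState = vec![StateType::Discrete(0)];

    let mut rewards: HashMap<NetworkProcessState, f64> = HashMap::new();
    rewards.insert(s0.clone(), 3.0);
    assert_eq!(rewards[&s0], 3.0);
}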

@@ -0,0 +1,59 @@
pub mod reward_evaluation;
pub mod reward_function;

use std::collections::HashMap;

use crate::process;

/// Instantiation of reward function and instantaneous reward
///
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)]
pub struct Reward {
    pub transition_reward: f64,
    pub instantaneous_reward: f64,
}

/// The trait RewardFunction describes the methods that all the reward functions must satisfy
pub trait RewardFunction: Sync {
    /// Given the current state and the previous state, it computes the reward.
    ///
    /// # Arguments
    ///
    /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
    /// * `previous_state`: an optional argument representing the previous state of the network
    fn call(
        &self,
        current_state: &process::NetworkProcessState,
        previous_state: Option<&process::NetworkProcessState>,
    ) -> Reward;

    /// Initialize the RewardFunction internals according to the structure of a NetworkProcess
    ///
    /// # Arguments
    ///
    /// * `p`: any structure that implements the trait `process::NetworkProcess`
    fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}

pub trait RewardEvaluation {
    fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64>;

    fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
        state: &process::NetworkProcessState,
    ) -> f64;
}
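
(To illustrate the trait surface introduced by this file, here is a minimal sketch of a custom reward function, assuming the reCTBN crate; ConstantReward and its rate field are hypothetical and not part of this PR.)

use reCTBN::process::{NetworkProcess, NetworkProcessState};
use reCTBN::reward::{Reward, RewardFunction};

struct ConstantReward {
    rate: f64,
}

impl RewardFunction for ConstantReward {
    fn call(
        &self,
        _current_state: &NetworkProcessState,
        _previous_state: Option<&NetworkProcessState>,
    ) -> Reward {
        // Same reward everywhere: no transition bonus, a constant rate over time.
        Reward {
            transition_reward: 0.0,
            instantaneous_reward: self.rate,
        }
    }

    fn initialize_from_network_process<T: NetworkProcess>(_p: &T) -> Self {
        // The network structure is ignored by this toy implementation.
        ConstantReward { rate: 1.0 }
    }
}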

@@ -0,0 +1,205 @@
use std::collections::HashMap;

use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use statrs::distribution::ContinuousCDF;

use crate::params::{self, ParamsTrait};
use crate::process;

use crate::{
    process::NetworkProcessState,
    reward::RewardEvaluation,
    sampling::{ForwardSampler, Sampler},
};

pub enum RewardCriteria {
    FiniteHorizon,
    InfiniteHorizon { discount_factor: f64 },
}

pub struct MonteCarloReward {
    max_iterations: usize,
    max_err_stop: f64,
    alpha_stop: f64,
    end_time: f64,
    reward_criteria: RewardCriteria,
    seed: Option<u64>,
}

impl MonteCarloReward {
    pub fn new(
        max_iterations: usize,
        max_err_stop: f64,
        alpha_stop: f64,
        end_time: f64,
        reward_criteria: RewardCriteria,
        seed: Option<u64>,
    ) -> MonteCarloReward {
        MonteCarloReward {
            max_iterations,
            max_err_stop,
            alpha_stop,
            end_time,
            reward_criteria,
            seed,
        }
    }
}

impl RewardEvaluation for MonteCarloReward {
    fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64> {
        let variables_domain: Vec<Vec<params::StateType>> = network_process
            .get_node_indices()
            .map(|x| match network_process.get_node(x) {
                params::Params::DiscreteStatesContinousTime(x) => (0..x
                    .get_reserved_space_as_parent())
                    .map(|s| params::StateType::Discrete(s))
                    .collect(),
            })
            .collect();

        let n_states: usize = variables_domain.iter().map(|x| x.len()).product();

        (0..n_states)
            .into_par_iter()
            .map(|s| {
                let state: process::NetworkProcessState = variables_domain
                    .iter()
                    .fold((s, vec![]), |acc, x| {
                        let mut acc = acc;
                        let idx_s = acc.0 % x.len();
                        acc.1.push(x[idx_s].clone());
                        acc.0 = acc.0 / x.len();
                        acc
                    })
                    .1;

                let r = self.evaluate_state(network_process, reward_function, &state);
                (state, r)
            })
            .collect()
    }

    fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
        state: &NetworkProcessState,
    ) -> f64 {
        let mut sampler =
            ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
        let mut expected_value = 0.0;
        let mut squared_expected_value = 0.0;
        let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap();

        for i in 0..self.max_iterations {
            sampler.reset();
            let mut ret = 0.0;
            let mut previous = sampler.next().unwrap();
            while previous.t < self.end_time {
                let current = sampler.next().unwrap();
                if current.t > self.end_time {
                    let r = reward_function.call(&previous.state, None);
                    let discount = match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => self.end_time - previous.t,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            std::f64::consts::E.powf(-discount_factor * previous.t)
                                - std::f64::consts::E.powf(-discount_factor * self.end_time)
                        }
                    };
                    ret += discount * r.instantaneous_reward;
                } else {
                    let r = reward_function.call(&previous.state, Some(&current.state));
                    let discount = match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => current.t - previous.t,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            std::f64::consts::E.powf(-discount_factor * previous.t)
                                - std::f64::consts::E.powf(-discount_factor * current.t)
                        }
                    };
                    ret += discount * r.instantaneous_reward;
                    ret += match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => 1.0,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            std::f64::consts::E.powf(-discount_factor * current.t)
                        }
                    } * r.transition_reward;
                }
                previous = current;
            }

            let float_i = i as f64;
            expected_value =
                expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0);
            squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0)
                + ret.powi(2) / (float_i + 1.0);

            if i > 2 {
                let var =
                    (float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2));
                if self.alpha_stop
                    - 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt())
                    > 0.0
                {
                    return expected_value;
                }
            }
        }

        expected_value
    }
}

pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> {
    inner_reward: RE,
}

impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> {
    pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> {
        NeighborhoodRelativeReward { inner_reward }
    }
}

impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> {
    fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64> {
        let absolute_reward = self
            .inner_reward
            .evaluate_state_space(network_process, reward_function);

        //This approach optimizes memory usage. Optimizing for execution time instead might be better.
        absolute_reward
            .iter()
            .map(|(k1, v1)| {
                let mut max_val: f64 = 1.0;
                absolute_reward.iter().for_each(|(k2, v2)| {
                    let count_diff: usize = k1
                        .iter()
                        .zip(k2.iter())
                        .map(|(s1, s2)| if s1 == s2 { 0 } else { 1 })
                        .sum();

                    if count_diff < 2 {
                        max_val = max_val.max(v1 / v2);
                    }
                });

                (k1.clone(), max_val)
            })
            .collect()
    }

    fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        _network_process: &N,
        _reward_function: &R,
        _state: &process::NetworkProcessState,
    ) -> f64 {
        unimplemented!();
    }
}
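
(The early-stopping check in evaluate_state is a CLT-based confidence test: sampling stops once the estimated probability that the running mean deviates from the true expected reward by more than max_err_stop falls below alpha_stop. A standalone sketch of that criterion, assuming only the statrs crate; should_stop and its arguments are illustrative names, not part of this PR.)

use statrs::distribution::{ContinuousCDF, Normal};

/// Returns true when, after `n` samples with running mean `mean` and running
/// mean of squares `mean_sq`, the CLT-based bound on the probability of being
/// more than `max_err` away from the true value drops below `alpha`.
fn should_stop(n: usize, mean: f64, mean_sq: f64, max_err: f64, alpha: f64) -> bool {
    let normal = Normal::new(0.0, 1.0).unwrap();
    let n = n as f64;
    // Bessel-corrected sample variance, as computed in MonteCarloReward::evaluate_state.
    let var = n / (n - 1.0) * (mean_sq - mean.powi(2));
    alpha - 2.0 * normal.cdf(-n.sqrt() * max_err / var.sqrt()) > 0.0
}

fn main() {
    // With 10_000 samples and a tiny sample variance, a 0.1 tolerance is easily met.
    assert!(should_stop(10_000, 3.0, 9.0001, 1e-1, 1e-1));
}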

@@ -3,46 +3,10 @@
 use crate::{
     params::{self, ParamsTrait},
     process,
+    reward::{Reward, RewardFunction},
 };

 use ndarray;

-/// Instantiation of reward function and instantaneous reward
-///
-///
-/// # Arguments
-///
-/// * `transition_reward`: reward obtained transitioning from one state to another
-/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
-#[derive(Debug, PartialEq)]
-pub struct Reward {
-    pub transition_reward: f64,
-    pub instantaneous_reward: f64,
-}
-
-/// The trait RewardFunction describe the methods that all the reward functions must satisfy
-pub trait RewardFunction {
-    /// Given the current state and the previous state, it compute the reward.
-    ///
-    /// # Arguments
-    ///
-    /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
-    /// * `previous_state`: an optional argument representing the previous state of the network
-    fn call(
-        &self,
-        current_state: process::NetworkProcessState,
-        previous_state: Option<process::NetworkProcessState>,
-    ) -> Reward;
-
-    /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
-    ///
-    /// # Arguments
-    ///
-    /// * `p`: any structure that implements the trait `process::NetworkProcess`
-    fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
-}
-
 /// Reward function over a factored state space
 ///
@@ -80,8 +44,8 @@ impl FactoredRewardFunction {
 impl RewardFunction for FactoredRewardFunction {
     fn call(
         &self,
-        current_state: process::NetworkProcessState,
-        previous_state: Option<process::NetworkProcessState>,
+        current_state: &process::NetworkProcessState,
+        previous_state: Option<&process::NetworkProcessState>,
     ) -> Reward {
         let instantaneous_reward: f64 = current_state
             .iter()

@@ -26,10 +26,15 @@ where
     current_time: f64,
     current_state: NetworkProcessState,
     next_transitions: Vec<Option<f64>>,
+    initial_state: Option<NetworkProcessState>,
 }

 impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
-    pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> {
+    pub fn new(
+        net: &'a T,
+        seed: Option<u64>,
+        initial_state: Option<NetworkProcessState>,
+    ) -> ForwardSampler<'a, T> {
         let rng: ChaCha8Rng = match seed {
             //If a seed is present use it to initialize the random generator.
             Some(seed) => SeedableRng::seed_from_u64(seed),
@@ -37,11 +42,12 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
             None => SeedableRng::from_entropy(),
         };
         let mut fs = ForwardSampler {
-            net: net,
-            rng: rng,
+            net,
+            rng,
             current_time: 0.0,
             current_state: vec![],
             next_transitions: vec![],
+            initial_state,
         };
         fs.reset();
         return fs;
@@ -112,11 +118,16 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
 impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> {
     fn reset(&mut self) {
         self.current_time = 0.0;
-        self.current_state = self
-            .net
-            .get_node_indices()
-            .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng))
-            .collect();
+        match &self.initial_state {
+            None => {
+                self.current_state = self
+                    .net
+                    .get_node_indices()
+                    .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng))
+                    .collect()
+            }
+            Some(is) => self.current_state = is.clone(),
+        };
         self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect();
     }
 }

@@ -6,6 +6,9 @@ use crate::structure_learning::score_function::ScoreFunction;
 use crate::structure_learning::StructureLearningAlgorithm;
 use crate::{process, tools::Dataset};
+
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+use rayon::prelude::ParallelExtend;

 pub struct HillClimbing<S: ScoreFunction> {
     score_function: S,
     max_parent_set: Option<usize>,
@@ -36,8 +39,9 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
         let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
         //Reset the adj matrix
         net.initialize_adj_matrix();
+        let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
         //Iterate over each node to learn their parent set.
-        for node in net.get_node_indices() {
+        learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| {
             //Initialize an empty parent set.
             let mut parent_set: BTreeSet<usize> = BTreeSet::new();
             //Compute the score for the empty parent set
@@ -76,10 +80,14 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
                     }
                 }
             }
-            //Apply the learned parent_set to the network struct.
-            parent_set.iter().for_each(|p| net.add_edge(*p, node));
-        }
+            (node, parent_set)
+        }));
+
+        for (child_node, candidate_parent_set) in learned_parent_sets {
+            for parent_node in candidate_parent_set.iter() {
+                net.add_edge(*parent_node, child_node);
+            }
+        }

         return net;
     }
 }
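
(The change above collects each node's learned parent set into a Vec via rayon's par_extend and only then mutates net sequentially, since the network cannot be mutated from multiple threads. A minimal sketch of this collect-then-apply pattern in isolation, assuming only the rayon crate; the closure body is a placeholder computation.)

use rayon::prelude::*;

fn main() {
    let mut results: Vec<(usize, usize)> = vec![];

    // Compute the independent per-node results in parallel...
    results.par_extend((0..8usize).into_par_iter().map(|node| (node, node * node)));

    // ...then apply them sequentially, where shared mutation is needed.
    for (node, value) in results {
        println!("node {} -> {}", node, value);
    }
}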

@@ -7,7 +7,7 @@ use statrs::function::gamma;
 use crate::{parameter_learning, params, process, tools};

-pub trait ScoreFunction {
+pub trait ScoreFunction: Sync {
     fn call<T>(
         &self,
         net: &T,

@@ -69,8 +69,7 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
     let mut trajectories: Vec<Trajectory> = Vec::new();
     //Random Generator object
-
-    let mut sampler = ForwardSampler::new(net, seed);
+    let mut sampler = ForwardSampler::new(net, seed, None);

     //Each iteration generate one trajectory
     for _ in 0..n_trajectories {
         //History of all the moments in which something changed

@@ -0,0 +1,122 @@
mod utils;

use approx::assert_abs_diff_eq;
use ndarray::*;
use reCTBN::{
    params,
    process::{ctbn::*, NetworkProcess, NetworkProcessState},
    reward::{reward_evaluation::*, reward_function::*, *},
};
use utils::generate_discrete_time_continous_node;

#[test]
fn simple_factored_reward_function_binary_node_mc() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();

    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]]));
    rf.get_instantaneous_reward_mut(n1)
        .assign(&arr1(&[3.0, 3.0]));

    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
        }
    }

    net.initialize_adj_matrix();

    let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
    let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];

    let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
    assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
    assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);

    let rst = mc.evaluate_state_space(&net, &rf);
    assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
    assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);

    let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215));
    assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
    assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
}

#[test]
fn simple_factored_reward_function_chain_mc() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
        .unwrap();

    net.add_edge(n1, n2);
    net.add_edge(n2, n3);

    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap();
        }
    }

    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param
                .set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]],
                ]))
                .unwrap();
        }
    }

    match &mut net.get_node_mut(n3) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param
                .set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]],
                ]))
                .unwrap();
        }
    }

    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
    rf.get_transition_reward_mut(n2)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
    rf.get_transition_reward_mut(n3)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));

    let s000: NetworkProcessState = vec![
        params::StateType::Discrete(1),
        params::StateType::Discrete(0),
        params::StateType::Discrete(0),
    ];

    let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
    assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);

    let rst = mc.evaluate_state_space(&net, &rf);
    assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1);
}

@@ -2,7 +2,7 @@ mod utils;
 use ndarray::*;
 use utils::generate_discrete_time_continous_node;

-use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params};
+use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params};

 #[test]
@@ -18,15 +18,15 @@ fn simple_factored_reward_function_binary_node() {
     let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
     let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];

-    assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
-    assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
-    assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
-    assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
-    assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
-    assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
+    assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
+    assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
+    assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
+    assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
+    assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
+    assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
 }
@@ -46,16 +46,16 @@ fn simple_factored_reward_function_ternary_node() {
     let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];

-    assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
-    assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
-    assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
-    assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
-    assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
-    assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
+    assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
+    assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
+    assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
+    assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
+    assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
+    assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
 }

 #[test]
@@ -77,7 +77,6 @@ fn factored_reward_function_two_nodes() {
     rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
     rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));

     let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
-
     let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
     let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
@@ -87,32 +86,32 @@ fn factored_reward_function_two_nodes() {
     let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
     let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];

-    assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
-    assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
-    assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
-    assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
-    assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
-    assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
-    assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
-    assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
-    assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
-    assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
-    assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
-    assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
-    assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
-    assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
-    assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
-    assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
-    assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
-    assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
+    assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
+    assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
+    assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
+    assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
+    assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
+    assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
+    assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
+    assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
+    assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
+    assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
+    assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
+    assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
+    assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
+    assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
+    assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
+    assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
+    assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
+    assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
 }
