diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 8feddfb..1997fa6 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -6,7 +6,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; pub mod process; -pub mod reward_function; +pub mod reward; pub mod sampling; pub mod structure_learning; pub mod tools; diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 536a9d5..3c34d06 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -144,6 +144,10 @@ impl ParameterLearning for BayesianApproach { .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau)))); + CIM.outer_iter_mut().for_each(|mut C| { + C.diag_mut().fill(0.0); + }); + //Set the diagonal of the inner matrices to the row sum multiplied by -1 let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); CIM.outer_iter_mut() diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index dc941e5..ccbb750 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -20,7 +20,7 @@ pub enum ParamsError { } /// Allowed type of states -#[derive(Clone)] +#[derive(Clone, Hash, PartialEq, Eq, Debug)] pub enum StateType { Discrete(usize), } diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs new file mode 100644 index 0000000..910954c --- /dev/null +++ b/reCTBN/src/reward.rs @@ -0,0 +1,59 @@ +pub mod reward_evaluation; +pub mod reward_function; + +use std::collections::HashMap; + +use crate::process; + +/// Instantiation of reward function and instantaneous reward +/// +/// +/// # Arguments +/// +/// * `transition_reward`: reward obtained transitioning from one state to another +/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64, +} + +/// The trait RewardFunction describes the methods that all the reward functions must satisfy + +pub trait RewardFunction: Sync { + /// Given the current state and the previous state, it computes the reward.
+ /// + /// # Arguments + /// + /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` + /// * `previous_state`: an optional argument representing the previous state of the network + + fn call( + &self, + current_state: &process::NetworkProcessState, + previous_state: Option<&process::NetworkProcessState>, + ) -> Reward; + + /// Initialize the RewardFunction internals according to the structure of a NetworkProcess + /// + /// # Arguments + /// + /// * `p`: any structure that implements the trait `process::NetworkProcess` + fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; +} + +pub trait RewardEvaluation { + fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>( + &self, + network_process: &N, + reward_function: &R, + ) -> HashMap<process::NetworkProcessState, f64>; + + fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>( + &self, + network_process: &N, + reward_function: &R, + state: &process::NetworkProcessState, + ) -> f64; +} diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs new file mode 100644 index 0000000..3802489 --- /dev/null +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -0,0 +1,205 @@ +use std::collections::HashMap; + +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use statrs::distribution::ContinuousCDF; + +use crate::params::{self, ParamsTrait}; +use crate::process; + +use crate::{ + process::NetworkProcessState, + reward::RewardEvaluation, + sampling::{ForwardSampler, Sampler}, +}; + +pub enum RewardCriteria { + FiniteHorizon, + InfiniteHorizon { discount_factor: f64 }, +} + +pub struct MonteCarloReward { + max_iterations: usize, + max_err_stop: f64, + alpha_stop: f64, + end_time: f64, + reward_criteria: RewardCriteria, + seed: Option<u64>, +} + +impl MonteCarloReward { + pub fn new( + max_iterations: usize, + max_err_stop: f64, + alpha_stop: f64, + end_time: f64, + reward_criteria: RewardCriteria, + seed: Option<u64>, + ) -> MonteCarloReward { + MonteCarloReward { + max_iterations, + max_err_stop, + alpha_stop, + end_time, + reward_criteria, + seed, + } + } +} + +impl RewardEvaluation for MonteCarloReward { + fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( + &self, + network_process: &N, + reward_function: &R, + ) -> HashMap<process::NetworkProcessState, f64> { + let variables_domain: Vec<Vec<params::StateType>> = network_process + .get_node_indices() + .map(|x| match network_process.get_node(x) { + params::Params::DiscreteStatesContinousTime(x) => (0..x + .get_reserved_space_as_parent()) + .map(|s| params::StateType::Discrete(s)) + .collect(), + }) + .collect(); + + let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); + + (0..n_states) + .into_par_iter() + .map(|s| { + let state: process::NetworkProcessState = variables_domain + .iter() + .fold((s, vec![]), |acc, x| { + let mut acc = acc; + let idx_s = acc.0 % x.len(); + acc.1.push(x[idx_s].clone()); + acc.0 = acc.0 / x.len(); + acc + }) + .1; + + let r = self.evaluate_state(network_process, reward_function, &state); + (state, r) + }) + .collect() + } + + fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>( + &self, + network_process: &N, + reward_function: &R, + state: &NetworkProcessState, + ) -> f64 { + let mut sampler = + ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); + let mut expected_value = 0.0; + let mut squared_expected_value = 0.0; + let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); + + for i in 0..self.max_iterations { + sampler.reset(); + let mut ret = 0.0; + let mut previous = sampler.next().unwrap(); + while previous.t < self.end_time { + let current = sampler.next().unwrap(); + if current.t > self.end_time { + let r =
reward_function.call(&previous.state, None); + let discount = match self.reward_criteria { + RewardCriteria::FiniteHorizon => self.end_time - previous.t, + RewardCriteria::InfiniteHorizon { discount_factor } => { + std::f64::consts::E.powf(-discount_factor * previous.t) + - std::f64::consts::E.powf(-discount_factor * self.end_time) + } + }; + ret += discount * r.instantaneous_reward; + } else { + let r = reward_function.call(&previous.state, Some(&current.state)); + let discount = match self.reward_criteria { + RewardCriteria::FiniteHorizon => current.t - previous.t, + RewardCriteria::InfiniteHorizon { discount_factor } => { + std::f64::consts::E.powf(-discount_factor * previous.t) + - std::f64::consts::E.powf(-discount_factor * current.t) + } + }; + ret += discount * r.instantaneous_reward; + ret += match self.reward_criteria { + RewardCriteria::FiniteHorizon => 1.0, + RewardCriteria::InfiniteHorizon { discount_factor } => { + std::f64::consts::E.powf(-discount_factor * current.t) + } + } * r.transition_reward; + } + previous = current; + } + + let float_i = i as f64; + expected_value = + expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0); + squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0) + + ret.powi(2) / (float_i + 1.0); + + if i > 2 { + let var = + (float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2)); + if self.alpha_stop + - 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt()) + > 0.0 + { + return expected_value; + } + } + } + + expected_value + } +} + +pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> { + inner_reward: RE, +} + +impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> { + pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> { + NeighborhoodRelativeReward { inner_reward } + } +} + +impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> { + fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( + &self, + network_process: &N, + reward_function: &R, + ) -> HashMap<process::NetworkProcessState, f64> { + let absolute_reward = self + .inner_reward + .evaluate_state_space(network_process, reward_function); + + //This approach optimizes memory. Maybe optimizing execution time could be better.
+ absolute_reward + .iter() + .map(|(k1, v1)| { + let mut max_val: f64 = 1.0; + absolute_reward.iter().for_each(|(k2, v2)| { + let count_diff: usize = k1 + .iter() + .zip(k2.iter()) + .map(|(s1, s2)| if s1 == s2 { 0 } else { 1 }) + .sum(); + if count_diff < 2 { + max_val = max_val.max(v1 / v2); + } + }); + (k1.clone(), max_val) + }) + .collect() + } + + fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>( + &self, + _network_process: &N, + _reward_function: &R, + _state: &process::NetworkProcessState, + ) -> f64 { + unimplemented!(); + } +} diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward/reward_function.rs similarity index 70% rename from reCTBN/src/reward_function.rs rename to reCTBN/src/reward/reward_function.rs index 35e15c8..216df6a 100644 --- a/reCTBN/src/reward_function.rs +++ b/reCTBN/src/reward/reward_function.rs @@ -3,46 +3,10 @@ use crate::{ params::{self, ParamsTrait}, process, + reward::{Reward, RewardFunction}, }; -use ndarray; - -/// Instantiation of reward function and instantaneous reward -/// -/// -/// # Arguments -/// -/// * `transition_reward`: reward obtained transitioning from one state to another -/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state - -#[derive(Debug, PartialEq)] -pub struct Reward { - pub transition_reward: f64, - pub instantaneous_reward: f64, -} -/// The trait RewardFunction describe the methods that all the reward functions must satisfy - -pub trait RewardFunction { - /// Given the current state and the previous state, it compute the reward. - /// - /// # Arguments - /// - /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` - /// * `previous_state`: an optional argument representing the previous state of the network - - fn call( - &self, - current_state: process::NetworkProcessState, - previous_state: Option<process::NetworkProcessState>, - ) -> Reward; - - /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess - /// - /// # Arguments - /// - /// * `p`: any structure that implements the trait `process::NetworkProcess` - fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; -} +use ndarray; /// Reward function over a factored state space /// @@ -80,8 +44,8 @@ impl FactoredRewardFunction { impl RewardFunction for FactoredRewardFunction { fn call( &self, - current_state: process::NetworkProcessState, - previous_state: Option<process::NetworkProcessState>, + current_state: &process::NetworkProcessState, + previous_state: Option<&process::NetworkProcessState>, ) -> Reward { let instantaneous_reward: f64 = current_state .iter() diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 1384872..73c6d78 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -26,10 +26,15 @@ where current_time: f64, current_state: NetworkProcessState, next_transitions: Vec<Option<f64>>, + initial_state: Option<NetworkProcessState>, } impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { - pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> { + pub fn new( + net: &'a T, + seed: Option<u64>, + initial_state: Option<NetworkProcessState>, + ) -> ForwardSampler<'a, T> { let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator.
Some(seed) => SeedableRng::seed_from_u64(seed), @@ -37,11 +42,12 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { None => SeedableRng::from_entropy(), }; let mut fs = ForwardSampler { - net: net, - rng: rng, + net, + rng, current_time: 0.0, current_state: vec![], next_transitions: vec![], + initial_state, }; fs.reset(); return fs; @@ -112,11 +118,16 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { fn reset(&mut self) { self.current_time = 0.0; - self.current_state = self - .net - .get_node_indices() - .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) - .collect(); + match &self.initial_state { + None => { + self.current_state = self + .net + .get_node_indices() + .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) + .collect() + } + Some(is) => self.current_state = is.clone(), + }; self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect(); } } diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index d65ea88..9173b86 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -6,6 +6,9 @@ use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; use crate::{process, tools::Dataset}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::ParallelExtend; + pub struct HillClimbing<S: ScoreFunction> { score_function: S, max_parent_set: Option<usize>, } @@ -36,8 +39,9 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); //Reset the adj matrix net.initialize_adj_matrix(); + let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![]; //Iterate over each node to learn their parent set. - for node in net.get_node_indices() { + learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| { //Initialize an empty parent set. let mut parent_set: BTreeSet<usize> = BTreeSet::new(); //Compute the score for the empty parent set @@ -76,10 +80,14 @@ } } } - //Apply the learned parent_set to the network struct.
- parent_set.iter().for_each(|p| net.add_edge(*p, node)); - } + (node, parent_set) + })); + for (child_node, candidate_parent_set) in learned_parent_sets { + for parent_node in candidate_parent_set.iter() { + net.add_edge(*parent_node, child_node); + } + } return net; } } diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index f8b38b5..5a56594 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -7,7 +7,7 @@ use statrs::function::gamma; use crate::{parameter_learning, params, process, tools}; -pub trait ScoreFunction { +pub trait ScoreFunction: Sync { fn call( &self, net: &T, diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 0a48410..5085c43 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -69,8 +69,7 @@ pub fn trajectory_generator( let mut trajectories: Vec = Vec::new(); //Random Generator object - - let mut sampler = ForwardSampler::new(net, seed); + let mut sampler = ForwardSampler::new(net, seed, None); //Each iteration generate one trajectory for _ in 0..n_trajectories { //History of all the moments in which something changed diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs new file mode 100644 index 0000000..355341c --- /dev/null +++ b/reCTBN/tests/reward_evaluation.rs @@ -0,0 +1,122 @@ +mod utils; + +use approx::assert_abs_diff_eq; +use ndarray::*; +use reCTBN::{ + params, + process::{ctbn::*, NetworkProcess, NetworkProcessState}, + reward::{reward_evaluation::*, reward_function::*, *}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn simple_factored_reward_function_binary_node_mc() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1) + .assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1) + .assign(&arr1(&[3.0, 3.0])); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); + } + } + + net.initialize_adj_matrix(); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + + let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); + assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); + assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); + + let rst = mc.evaluate_state_space(&net, &rf); + assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); + assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); + + + let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215)); + assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); + assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); + + +} + +#[test] +fn simple_factored_reward_function_chain_mc() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); 
+ + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + param + .set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]], + ])) + .unwrap(); + } + } + + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + param + .set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]], + ])) + .unwrap(); + } + } + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + rf.get_transition_reward_mut(n2) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + rf.get_transition_reward_mut(n3) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + let s000: NetworkProcessState = vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(0), + params::StateType::Discrete(0), + ]; + + let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); + assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); + + let rst = mc.evaluate_state_space(&net, &rf); + assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1); + +} diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs index dcc5e69..853efc9 100644 --- a/reCTBN/tests/reward_function.rs +++ b/reCTBN/tests/reward_function.rs @@ -2,7 +2,7 @@ mod utils; use ndarray::*; use utils::generate_discrete_time_continous_node; -use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; +use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params}; #[test] @@ -18,15 +18,15 @@ fn simple_factored_reward_function_binary_node() { let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; - assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); } @@ -46,16 +46,16 @@ fn simple_factored_reward_function_ternary_node() { let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; - assert_eq!(rf.call(s0.clone(), 
Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); - assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); - assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); - assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); } #[test] @@ -77,7 +77,6 @@ fn factored_reward_function_two_nodes() { rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); - let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; @@ -87,32 +86,32 @@ fn factored_reward_function_two_nodes() { let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; - assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); - assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s02.clone(), 
Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); - assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); - assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); - assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); - assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); - assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); }
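Reviewer note: a minimal usage sketch (not part of the patch) of the reward-evaluation API this diff introduces. It assumes a `net: CtbnNetwork` and a `rf: FactoredRewardFunction` already set up as in the new `simple_factored_reward_function_binary_node_mc` test; the helper name `expected_rewards` and the constants are illustrative only.

use std::collections::HashMap;

use reCTBN::params;
use reCTBN::process::ctbn::CtbnNetwork;
use reCTBN::process::NetworkProcessState;
use reCTBN::reward::reward_evaluation::{MonteCarloReward, RewardCriteria};
use reCTBN::reward::reward_function::FactoredRewardFunction;
use reCTBN::reward::RewardEvaluation;

// Hypothetical helper, for illustration only.
fn expected_rewards(net: &CtbnNetwork, rf: &FactoredRewardFunction) {
    // Monte Carlo estimator: at most 10_000 sampled trajectories, 1e-1 stopping error,
    // 1e-1 alpha for the early-stopping test, horizon of 10.0 time units,
    // exponential discounting with factor 1.0, fixed seed for reproducibility.
    let mc = MonteCarloReward::new(
        10_000,
        1e-1,
        1e-1,
        10.0,
        RewardCriteria::InfiniteHorizon { discount_factor: 1.0 },
        Some(215),
    );

    // Expected discounted reward when the process starts in one specific state
    // (one StateType entry per node of the network; here a single binary node).
    let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
    let v0: f64 = mc.evaluate_state(net, rf, &s0);

    // The same evaluation for every state of the factored state space at once
    // (parallelized internally with rayon).
    let values: HashMap<NetworkProcessState, f64> = mc.evaluate_state_space(net, rf);
    println!("V(s0) = {}, number of states = {}", v0, values.len());
}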