Merge branch '74-feature-reward-function' into 'dev'

Feature reward function
pull/83/head
Meliurwen 2 years ago
commit 3f80f07e9f
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 1
      reCTBN/src/lib.rs
  2. 6
      reCTBN/src/process.rs
  3. 11
      reCTBN/src/process/ctbn.rs
  4. 12
      reCTBN/src/process/ctmp.rs
  5. 142
      reCTBN/src/reward_function.rs
  6. 21
      reCTBN/src/sampling.rs
  7. 14
      reCTBN/src/tools.rs
  8. 118
      reCTBN/tests/reward_function.rs

@ -6,6 +6,7 @@ extern crate approx;
pub mod parameter_learning; pub mod parameter_learning;
pub mod params; pub mod params;
pub mod process; pub mod process;
pub mod reward_function;
pub mod sampling; pub mod sampling;
pub mod structure_learning; pub mod structure_learning;
pub mod tools; pub mod tools;

@ -16,6 +16,9 @@ pub enum NetworkError {
NodeInsertionError(String), NodeInsertionError(String),
} }
/// This type is used to represent a specific realization of a generic NetworkProcess
pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
/// as a CTBN). /// as a CTBN).
pub trait NetworkProcess { pub trait NetworkProcess {
@ -71,8 +74,7 @@ pub trait NetworkProcess {
/// # Return /// # Return
/// ///
/// * Index of the `node` relative to the network. /// * Index of the `node` relative to the network.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize;
-> usize;
/// Compute the index that must be used to access the parameters of a `node`, given a specific /// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network and a generic `parent_set`. /// configuration of the network and a generic `parent_set`.

@ -8,7 +8,7 @@ use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, Stat
use crate::process; use crate::process;
use super::ctmp::CtmpProcess; use super::ctmp::CtmpProcess;
use super::NetworkProcess; use super::{NetworkProcess, NetworkProcessState};
/// It represents both the structure and the parameters of a CTBN. /// It represents both the structure and the parameters of a CTBN.
/// ///
@ -86,7 +86,7 @@ impl CtbnNetwork {
for idx_current_state in 0..state_space { for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: Vec<StateType> = current_state let current_state_statetype: NetworkProcessState = current_state
.iter() .iter()
.map(|x| StateType::Discrete(*x)) .map(|x| StateType::Discrete(*x))
.collect(); .collect();
@ -98,7 +98,7 @@ impl CtbnNetwork {
let mut next_state = current_state.clone(); let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state; next_state[idx_node] = next_node_state;
let next_state_statetype: Vec<StateType> = let next_state_statetype: NetworkProcessState =
next_state.iter().map(|x| StateType::Discrete(*x)).collect(); next_state.iter().map(|x| StateType::Discrete(*x)).collect();
let idx_next_state = self.get_param_index_from_custom_parent_set( let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype, &next_state_statetype,
@ -119,7 +119,6 @@ impl CtbnNetwork {
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
); );
println!("{:?}", amalgamated_cim);
amalgamated_param.set_cim(amalgamated_cim).unwrap(); amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new(); let mut ctmp = CtmpProcess::new();
@ -186,7 +185,7 @@ impl process::NetworkProcess for CtbnNetwork {
&mut self.nodes[node_idx] &mut self.nodes[node_idx]
} }
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize { fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
self.adj_matrix self.adj_matrix
.as_ref() .as_ref()
.unwrap() .unwrap()
@ -205,7 +204,7 @@ impl process::NetworkProcess for CtbnNetwork {
fn get_param_index_from_custom_parent_set( fn get_param_index_from_custom_parent_set(
&self, &self,
current_state: &Vec<StateType>, current_state: &NetworkProcessState,
parent_set: &BTreeSet<usize>, parent_set: &BTreeSet<usize>,
) -> usize { ) -> usize {
parent_set parent_set

@ -5,7 +5,7 @@ use crate::{
process, process,
}; };
use super::NetworkProcess; use super::{NetworkProcess, NetworkProcessState};
pub struct CtmpProcess { pub struct CtmpProcess {
param: Option<Params>, param: Option<Params>,
@ -68,11 +68,7 @@ impl NetworkProcess for CtmpProcess {
} }
} }
fn get_param_index_network( fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
&self,
node: usize,
current_state: &Vec<crate::params::StateType>,
) -> usize {
if node == 0 { if node == 0 {
match current_state[0] { match current_state[0] {
StateType::Discrete(x) => x, StateType::Discrete(x) => x,
@ -84,8 +80,8 @@ impl NetworkProcess for CtmpProcess {
fn get_param_index_from_custom_parent_set( fn get_param_index_from_custom_parent_set(
&self, &self,
_current_state: &Vec<crate::params::StateType>, _current_state: &NetworkProcessState,
_parent_set: &std::collections::BTreeSet<usize>, _parent_set: &BTreeSet<usize>,
) -> usize { ) -> usize {
unimplemented!("CtmpProcess has only one node") unimplemented!("CtmpProcess has only one node")
} }

@ -0,0 +1,142 @@
//! Module for dealing with reward functions
use crate::{
params::{self, ParamsTrait},
process,
};
use ndarray;
/// Instantiation of reward function and instantaneous reward
///
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)]
pub struct Reward {
pub transition_reward: f64,
pub instantaneous_reward: f64,
}
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
pub trait RewardFunction {
/// Given the current state and the previous state, it compute the reward.
///
/// # Arguments
///
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
/// * `previous_state`: an optional argument representing the previous state of the network
fn call(
&self,
current_state: process::NetworkProcessState,
previous_state: Option<process::NetworkProcessState>,
) -> Reward;
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
///
/// # Arguments
///
/// * `p`: any structure that implements the trait `process::NetworkProcess`
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}
/// Reward function over a factored state space
///
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
/// of the underling `NetworkProcess`
///
/// # Arguments
///
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
/// reward of a node
pub struct FactoredRewardFunction {
transition_reward: Vec<ndarray::Array2<f64>>,
instantaneous_reward: Vec<ndarray::Array1<f64>>,
}
impl FactoredRewardFunction {
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> {
&self.transition_reward[node_idx]
}
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> {
&mut self.transition_reward[node_idx]
}
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> {
&self.instantaneous_reward[node_idx]
}
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
&mut self.instantaneous_reward[node_idx]
}
}
impl RewardFunction for FactoredRewardFunction {
fn call(
&self,
current_state: process::NetworkProcessState,
previous_state: Option<process::NetworkProcessState>,
) -> Reward {
let instantaneous_reward: f64 = current_state
.iter()
.enumerate()
.map(|(idx, x)| {
let x = match x {
params::StateType::Discrete(x) => x,
};
self.instantaneous_reward[idx][*x]
})
.sum();
if let Some(previous_state) = previous_state {
let transition_reward = previous_state
.iter()
.zip(current_state.iter())
.enumerate()
.find_map(|(idx, (p, c))| -> Option<f64> {
let p = match p {
params::StateType::Discrete(p) => p,
};
let c = match c {
params::StateType::Discrete(c) => c,
};
if p != c {
Some(self.transition_reward[idx][[*p, *c]])
} else {
None
}
})
.unwrap_or(0.0);
Reward {
transition_reward,
instantaneous_reward,
}
} else {
Reward {
transition_reward: 0.0,
instantaneous_reward,
}
}
}
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self {
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![];
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![];
for i in p.get_node_indices() {
//This works only for discrete nodes!
let size: usize = p.get_node(i).get_reserved_space_as_parent();
instantaneous_reward.push(ndarray::Array1::zeros(size));
transition_reward.push(ndarray::Array2::zeros((size, size)));
}
FactoredRewardFunction {
transition_reward,
instantaneous_reward,
}
}
}

@ -1,13 +1,19 @@
//! Module containing methods for the sampling. //! Module containing methods for the sampling.
use crate::{ use crate::{
params::{self, ParamsTrait}, params::ParamsTrait,
process::NetworkProcess, process::{NetworkProcess, NetworkProcessState},
}; };
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
pub trait Sampler: Iterator { #[derive(Clone)]
pub struct Sample {
pub t: f64,
pub state: NetworkProcessState,
}
pub trait Sampler: Iterator<Item = Sample> {
fn reset(&mut self); fn reset(&mut self);
} }
@ -18,7 +24,7 @@ where
net: &'a T, net: &'a T,
rng: ChaCha8Rng, rng: ChaCha8Rng,
current_time: f64, current_time: f64,
current_state: Vec<params::StateType>, current_state: NetworkProcessState,
next_transitions: Vec<Option<f64>>, next_transitions: Vec<Option<f64>>,
} }
@ -43,7 +49,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
} }
impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
type Item = (f64, Vec<params::StateType>); type Item = Sample;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let ret_time = self.current_time.clone(); let ret_time = self.current_time.clone();
@ -96,7 +102,10 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
self.next_transitions[child] = None; self.next_transitions[child] = None;
} }
Some((ret_time, ret_state)) Some(Sample {
t: ret_time,
state: ret_state,
})
} }
} }

@ -69,18 +69,18 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
let mut time: Vec<f64> = Vec::new(); let mut time: Vec<f64> = Vec::new();
//Configuration of the process variables at time t initialized with an uniform //Configuration of the process variables at time t initialized with an uniform
//distribution. //distribution.
let mut events: Vec<Vec<params::StateType>> = Vec::new(); let mut events: Vec<process::NetworkProcessState> = Vec::new();
//Current Time and Current State //Current Time and Current State
let (mut t, mut current_state) = sampler.next().unwrap(); let mut sample = sampler.next().unwrap();
//Generate new samples until ending time is reached. //Generate new samples until ending time is reached.
while t < t_end { while sample.t < t_end {
time.push(t); time.push(sample.t);
events.push(current_state); events.push(sample.state);
(t, current_state) = sampler.next().unwrap(); sample = sampler.next().unwrap();
} }
current_state = events.last().unwrap().clone(); let current_state = events.last().unwrap().clone();
events.push(current_state); events.push(current_state);
//Add t_end as last time. //Add t_end as last time.

@ -0,0 +1,118 @@
mod utils;
use ndarray::*;
use utils::generate_discrete_time_continous_node;
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params};
#[test]
fn simple_factored_reward_function_binary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
}
#[test]
fn simple_factored_reward_function_ternary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
}
#[test]
fn factored_reward_function_two_nodes() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2);
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)];
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
}
Loading…
Cancel
Save