74 feature reward function #75

Merged
AlessandroBregoli merged 5 commits from 74-feature-reward-function into dev 2 years ago
  1. 1
      reCTBN/src/lib.rs
  2. 6
      reCTBN/src/process.rs
  3. 11
      reCTBN/src/process/ctbn.rs
  4. 12
      reCTBN/src/process/ctmp.rs
  5. 142
      reCTBN/src/reward_function.rs
  6. 21
      reCTBN/src/sampling.rs
  7. 14
      reCTBN/src/tools.rs
  8. 118
      reCTBN/tests/reward_function.rs

@ -6,6 +6,7 @@ extern crate approx;
pub mod parameter_learning;
pub mod params;
pub mod process;
pub mod reward_function;
pub mod sampling;
pub mod structure_learning;
pub mod tools;

@ -16,6 +16,9 @@ pub enum NetworkError {
NodeInsertionError(String),
}
/// This type is used to represent a specific realization of a generic NetworkProcess
pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
/// as a CTBN).
pub trait NetworkProcess {
@ -71,8 +74,7 @@ pub trait NetworkProcess {
/// # Return
///
/// * Index of the `node` relative to the network.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
-> usize;
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize;
/// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network and a generic `parent_set`.

@ -8,7 +8,7 @@ use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, Stat
use crate::process;
use super::ctmp::CtmpProcess;
use super::NetworkProcess;
use super::{NetworkProcess, NetworkProcessState};
/// It represents both the structure and the parameters of a CTBN.
///
@ -86,7 +86,7 @@ impl CtbnNetwork {
for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: Vec<StateType> = current_state
let current_state_statetype: NetworkProcessState = current_state
.iter()
.map(|x| StateType::Discrete(*x))
.collect();
@ -98,7 +98,7 @@ impl CtbnNetwork {
let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state;
let next_state_statetype: Vec<StateType> =
let next_state_statetype: NetworkProcessState =
next_state.iter().map(|x| StateType::Discrete(*x)).collect();
let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype,
@ -119,7 +119,6 @@ impl CtbnNetwork {
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
);
println!("{:?}", amalgamated_cim);
amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new();
@ -186,7 +185,7 @@ impl process::NetworkProcess for CtbnNetwork {
&mut self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize {
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
self.adj_matrix
.as_ref()
.unwrap()
@ -205,7 +204,7 @@ impl process::NetworkProcess for CtbnNetwork {
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<StateType>,
current_state: &NetworkProcessState,
parent_set: &BTreeSet<usize>,
) -> usize {
parent_set

@ -5,7 +5,7 @@ use crate::{
process,
};
use super::NetworkProcess;
use super::{NetworkProcess, NetworkProcessState};
pub struct CtmpProcess {
param: Option<Params>,
@ -68,11 +68,7 @@ impl NetworkProcess for CtmpProcess {
}
}
fn get_param_index_network(
&self,
node: usize,
current_state: &Vec<crate::params::StateType>,
) -> usize {
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
if node == 0 {
match current_state[0] {
StateType::Discrete(x) => x,
@ -84,8 +80,8 @@ impl NetworkProcess for CtmpProcess {
fn get_param_index_from_custom_parent_set(
&self,
_current_state: &Vec<crate::params::StateType>,
_parent_set: &std::collections::BTreeSet<usize>,
_current_state: &NetworkProcessState,
_parent_set: &BTreeSet<usize>,
) -> usize {
unimplemented!("CtmpProcess has only one node")
}

@ -0,0 +1,142 @@
//! Module for dealing with reward functions
use crate::{
params::{self, ParamsTrait},
process,
};
use ndarray;
/// Instantiation of reward function and instantaneous reward
///
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)]
pub struct Reward {
pub transition_reward: f64,
pub instantaneous_reward: f64,
}
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
pub trait RewardFunction {
/// Given the current state and the previous state, it compute the reward.
///
/// # Arguments
///
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
/// * `previous_state`: an optional argument representing the previous state of the network
fn call(
&self,
current_state: process::NetworkProcessState,
previous_state: Option<process::NetworkProcessState>,
) -> Reward;
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
///
/// # Arguments
///
/// * `p`: any structure that implements the trait `process::NetworkProcess`
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}
/// Reward function over a factored state space
///
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
/// of the underling `NetworkProcess`
///
/// # Arguments
///
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
/// reward of a node
pub struct FactoredRewardFunction {
transition_reward: Vec<ndarray::Array2<f64>>,
instantaneous_reward: Vec<ndarray::Array1<f64>>,
}
impl FactoredRewardFunction {
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> {
&self.transition_reward[node_idx]
}
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> {
&mut self.transition_reward[node_idx]
}
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> {
&self.instantaneous_reward[node_idx]
}
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
&mut self.instantaneous_reward[node_idx]
}
}
impl RewardFunction for FactoredRewardFunction {
fn call(
&self,
current_state: process::NetworkProcessState,
previous_state: Option<process::NetworkProcessState>,
) -> Reward {
let instantaneous_reward: f64 = current_state
.iter()
.enumerate()
.map(|(idx, x)| {
let x = match x {
params::StateType::Discrete(x) => x,
};
self.instantaneous_reward[idx][*x]
})
.sum();
if let Some(previous_state) = previous_state {
let transition_reward = previous_state
.iter()
.zip(current_state.iter())
.enumerate()
.find_map(|(idx, (p, c))| -> Option<f64> {
let p = match p {
params::StateType::Discrete(p) => p,
};
let c = match c {
params::StateType::Discrete(c) => c,
};
if p != c {
Some(self.transition_reward[idx][[*p, *c]])
} else {
None
}
})
.unwrap_or(0.0);
Reward {
transition_reward,
instantaneous_reward,
}
} else {
Reward {
transition_reward: 0.0,
instantaneous_reward,
}
}
}
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self {
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![];
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![];
for i in p.get_node_indices() {
//This works only for discrete nodes!
let size: usize = p.get_node(i).get_reserved_space_as_parent();
instantaneous_reward.push(ndarray::Array1::zeros(size));
transition_reward.push(ndarray::Array2::zeros((size, size)));
}
FactoredRewardFunction {
transition_reward,
instantaneous_reward,
}
}
}

@ -1,13 +1,19 @@
//! Module containing methods for the sampling.
use crate::{
params::{self, ParamsTrait},
process::NetworkProcess,
params::ParamsTrait,
process::{NetworkProcess, NetworkProcessState},
};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
pub trait Sampler: Iterator {
#[derive(Clone)]
pub struct Sample {
pub t: f64,
pub state: NetworkProcessState,
}
pub trait Sampler: Iterator<Item = Sample> {
fn reset(&mut self);
}
@ -18,7 +24,7 @@ where
net: &'a T,
rng: ChaCha8Rng,
current_time: f64,
current_state: Vec<params::StateType>,
current_state: NetworkProcessState,
next_transitions: Vec<Option<f64>>,
}
@ -43,7 +49,7 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
}
impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
type Item = (f64, Vec<params::StateType>);
type Item = Sample;
fn next(&mut self) -> Option<Self::Item> {
let ret_time = self.current_time.clone();
@ -96,7 +102,10 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
self.next_transitions[child] = None;
}
Some((ret_time, ret_state))
Some(Sample {
t: ret_time,
state: ret_state,
})
}
}

@ -69,18 +69,18 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
let mut time: Vec<f64> = Vec::new();
//Configuration of the process variables at time t initialized with an uniform
//distribution.
let mut events: Vec<Vec<params::StateType>> = Vec::new();
let mut events: Vec<process::NetworkProcessState> = Vec::new();
//Current Time and Current State
let (mut t, mut current_state) = sampler.next().unwrap();
let mut sample = sampler.next().unwrap();
//Generate new samples until ending time is reached.
while t < t_end {
time.push(t);
events.push(current_state);
(t, current_state) = sampler.next().unwrap();
while sample.t < t_end {
time.push(sample.t);
events.push(sample.state);
sample = sampler.next().unwrap();
}
current_state = events.last().unwrap().clone();
let current_state = events.last().unwrap().clone();
events.push(current_state);
//Add t_end as last time.

@ -0,0 +1,118 @@
mod utils;
use ndarray::*;
use utils::generate_discrete_time_continous_node;
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params};
#[test]
fn simple_factored_reward_function_binary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
}
#[test]
fn simple_factored_reward_function_ternary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
}
#[test]
fn factored_reward_function_two_nodes() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2);
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)];
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
}
Loading…
Cancel
Save