Added comments

pull/75/head
AlessandroBregoli 2 years ago
parent f6015acce9
commit 68ef7ea7c3
  1. 2
      reCTBN/src/lib.rs
  2. 120
      reCTBN/src/reward_function.rs
  3. 9
      reCTBN/src/sampling.rs

@ -6,7 +6,7 @@ extern crate approx;
pub mod parameter_learning; pub mod parameter_learning;
pub mod params; pub mod params;
pub mod process; pub mod process;
pub mod reward_function;
pub mod sampling; pub mod sampling;
pub mod structure_learning; pub mod structure_learning;
pub mod tools; pub mod tools;
pub mod reward_function;

@ -1,22 +1,62 @@
use crate::{process, sampling, params::{ParamsTrait, self}}; //! Module for dealing with reward functions
use crate::{
params::{self, ParamsTrait},
process, sampling,
};
use ndarray; use ndarray;
/// Instantiation of reward function and instantaneous reward
///
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct Reward { pub struct Reward {
pub transition_reward: f64, pub transition_reward: f64,
pub instantaneous_reward: f64 pub instantaneous_reward: f64,
} }
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
pub trait RewardFunction { pub trait RewardFunction {
fn call(&self, current_state: sampling::Sample, previous_state: Option<sampling::Sample>) -> Reward; /// Given the current state and the previous state, it compute the reward.
///
/// # Arguments
///
/// * `current_state`: the current state of the network represented as a `sampling::Sample`
/// * `previous_state`: an optional argument representing the previous state of the network
fn call(
&self,
current_state: sampling::Sample,
previous_state: Option<sampling::Sample>,
) -> Reward;
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
///
/// # Arguments
///
/// * `p`: any structure that implements the trait `process::NetworkProcess`
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
} }
/// Reward function over a factored state space
///
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
/// of the underling `NetworkProcess`
///
/// # Arguments
///
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
/// reward of a node
pub struct FactoredRewardFunction { pub struct FactoredRewardFunction {
transition_reward: Vec<ndarray::Array2<f64>>, transition_reward: Vec<ndarray::Array2<f64>>,
instantaneous_reward: Vec<ndarray::Array1<f64>> instantaneous_reward: Vec<ndarray::Array1<f64>>,
} }
impl FactoredRewardFunction { impl FactoredRewardFunction {
@ -35,36 +75,60 @@ impl FactoredRewardFunction {
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
&mut self.instantaneous_reward[node_idx] &mut self.instantaneous_reward[node_idx]
} }
} }
impl RewardFunction for FactoredRewardFunction { impl RewardFunction for FactoredRewardFunction {
fn call(
fn call(&self, current_state: sampling::Sample, previous_state: Option<sampling::Sample>) -> Reward { &self,
let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| { current_state: sampling::Sample,
let x = match x {params::StateType::Discrete(x) => x}; previous_state: Option<sampling::Sample>,
self.instantaneous_reward[idx][*x] ) -> Reward {
}).sum(); let instantaneous_reward: f64 = current_state
.state
.iter()
.enumerate()
.map(|(idx, x)| {
let x = match x {
params::StateType::Discrete(x) => x,
};
self.instantaneous_reward[idx][*x]
})
.sum();
if let Some(previous_state) = previous_state { if let Some(previous_state) = previous_state {
let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option<f64> { let transition_reward = previous_state
let p = match p {params::StateType::Discrete(p) => p}; .state
let c = match c {params::StateType::Discrete(c) => c}; .iter()
if p != c { .zip(current_state.state.iter())
Some(self.transition_reward[idx][[*p,*c]]) .enumerate()
} else { .find_map(|(idx, (p, c))| -> Option<f64> {
None let p = match p {
} params::StateType::Discrete(p) => p,
}).unwrap_or(0.0); };
Reward {transition_reward, instantaneous_reward} let c = match c {
params::StateType::Discrete(c) => c,
};
if p != c {
Some(self.transition_reward[idx][[*p, *c]])
} else {
None
}
})
.unwrap_or(0.0);
Reward {
transition_reward,
instantaneous_reward,
}
} else { } else {
Reward { transition_reward: 0.0, instantaneous_reward} Reward {
transition_reward: 0.0,
instantaneous_reward,
}
} }
} }
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self {
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![];
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![];
for i in p.get_node_indices() { for i in p.get_node_indices() {
//This works only for discrete nodes! //This works only for discrete nodes!
let size: usize = p.get_node(i).get_reserved_space_as_parent(); let size: usize = p.get_node(i).get_reserved_space_as_parent();
@ -72,9 +136,9 @@ impl RewardFunction for FactoredRewardFunction {
transition_reward.push(ndarray::Array2::zeros((size, size))); transition_reward.push(ndarray::Array2::zeros((size, size)));
} }
FactoredRewardFunction { transition_reward, instantaneous_reward } FactoredRewardFunction {
transition_reward,
instantaneous_reward,
}
} }
} }

@ -10,15 +10,13 @@ use rand_chacha::ChaCha8Rng;
#[derive(Clone)] #[derive(Clone)]
pub struct Sample { pub struct Sample {
pub t: f64, pub t: f64,
pub state: Vec<params::StateType> pub state: Vec<params::StateType>,
} }
pub trait Sampler: Iterator<Item = Sample> { pub trait Sampler: Iterator<Item = Sample> {
fn reset(&mut self); fn reset(&mut self);
} }
pub struct ForwardSampler<'a, T> pub struct ForwardSampler<'a, T>
where where
T: NetworkProcess, T: NetworkProcess,
@ -104,7 +102,10 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
self.next_transitions[child] = None; self.next_transitions[child] = None;
} }
Some(Sample{t: ret_time, state: ret_state}) Some(Sample {
t: ret_time,
state: ret_state,
})
} }
} }

Loading…
Cancel
Save