Added comments

pull/75/head
AlessandroBregoli 2 years ago
parent f6015acce9
commit 68ef7ea7c3
  1. reCTBN/src/lib.rs (2 lines changed)
  2. reCTBN/src/reward_function.rs (118 lines changed)
  3. reCTBN/src/sampling.rs (9 lines changed)

@@ -6,7 +6,7 @@ extern crate approx;
 pub mod parameter_learning;
 pub mod params;
 pub mod process;
+pub mod reward_function;
 pub mod sampling;
 pub mod structure_learning;
 pub mod tools;
-pub mod reward_function;

@@ -1,22 +1,62 @@
-use crate::{process, sampling, params::{ParamsTrait, self}};
+//! Module for dealing with reward functions
+use crate::{
+    params::{self, ParamsTrait},
+    process, sampling,
+};
 use ndarray;
+/// Reward produced by a reward function: a transition reward plus an instantaneous reward
+///
+/// # Arguments
+///
+/// * `transition_reward`: reward obtained transitioning from one state to another
+/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
 #[derive(Debug, PartialEq)]
 pub struct Reward {
     pub transition_reward: f64,
-    pub instantaneous_reward: f64
+    pub instantaneous_reward: f64,
 }
+/// The trait RewardFunction describes the methods that all reward functions must satisfy
 pub trait RewardFunction {
-    fn call(&self, current_state: sampling::Sample, previous_state: Option<sampling::Sample>) -> Reward;
+    /// Given the current state and the previous state, computes the reward.
+    ///
+    /// # Arguments
+    ///
+    /// * `current_state`: the current state of the network represented as a `sampling::Sample`
+    /// * `previous_state`: an optional argument representing the previous state of the network
+    fn call(
+        &self,
+        current_state: sampling::Sample,
+        previous_state: Option<sampling::Sample>,
+    ) -> Reward;
+    /// Initializes the RewardFunction internals according to the structure of a NetworkProcess
+    ///
+    /// # Arguments
+    ///
+    /// * `p`: any structure that implements the trait `process::NetworkProcess`
     fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
 }
+/// Reward function over a factored state space
+///
+/// The `FactoredRewardFunction` assumes the reward function is the sum of the rewards of each
+/// node of the underlying `NetworkProcess`
+///
+/// # Arguments
+///
+/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
+///   reward of a node
+/// * `instantaneous_reward`: a vector of one-dimensional arrays. Each array contains the
+///   instantaneous reward of a node
 pub struct FactoredRewardFunction {
     transition_reward: Vec<ndarray::Array2<f64>>,
-    instantaneous_reward: Vec<ndarray::Array1<f64>>
+    instantaneous_reward: Vec<ndarray::Array1<f64>>,
 }
 impl FactoredRewardFunction {
@@ -35,30 +75,54 @@ impl FactoredRewardFunction {
     pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
         &mut self.instantaneous_reward[node_idx]
     }
 }
 impl RewardFunction for FactoredRewardFunction {
-    fn call(&self, current_state: sampling::Sample, previous_state: Option<sampling::Sample>) -> Reward {
-        let instantaneous_reward: f64 = current_state.state.iter().enumerate().map(|(idx, x)| {
-            let x = match x {params::StateType::Discrete(x) => x};
-            self.instantaneous_reward[idx][*x]
-        }).sum();
+    fn call(
+        &self,
+        current_state: sampling::Sample,
+        previous_state: Option<sampling::Sample>,
+    ) -> Reward {
+        let instantaneous_reward: f64 = current_state
+            .state
+            .iter()
+            .enumerate()
+            .map(|(idx, x)| {
+                let x = match x {
+                    params::StateType::Discrete(x) => x,
+                };
+                self.instantaneous_reward[idx][*x]
+            })
+            .sum();
         if let Some(previous_state) = previous_state {
-            let transition_reward = previous_state.state.iter().zip(current_state.state.iter()).enumerate().find_map(|(idx,(p,c))|->Option<f64> {
-                let p = match p {params::StateType::Discrete(p) => p};
-                let c = match c {params::StateType::Discrete(c) => c};
-                if p != c {
-                    Some(self.transition_reward[idx][[*p,*c]])
-                } else {
-                    None
-                }
-            }).unwrap_or(0.0);
-            Reward {transition_reward, instantaneous_reward}
+            let transition_reward = previous_state
+                .state
+                .iter()
+                .zip(current_state.state.iter())
+                .enumerate()
+                .find_map(|(idx, (p, c))| -> Option<f64> {
+                    let p = match p {
+                        params::StateType::Discrete(p) => p,
+                    };
+                    let c = match c {
+                        params::StateType::Discrete(c) => c,
+                    };
+                    if p != c {
+                        Some(self.transition_reward[idx][[*p, *c]])
+                    } else {
+                        None
+                    }
+                })
+                .unwrap_or(0.0);
+            Reward {
+                transition_reward,
+                instantaneous_reward,
+            }
         } else {
-            Reward { transition_reward: 0.0, instantaneous_reward}
+            Reward {
+                transition_reward: 0.0,
+                instantaneous_reward,
+            }
         }
     }
@@ -72,9 +136,9 @@ impl RewardFunction for FactoredRewardFunction {
             transition_reward.push(ndarray::Array2::zeros((size, size)));
         }
-        FactoredRewardFunction { transition_reward, instantaneous_reward }
+        FactoredRewardFunction {
+            transition_reward,
+            instantaneous_reward,
+        }
     }
 }
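
The new doc comments describe a factored reward: an instantaneous reward rate summed over all nodes, plus a transition reward for the single node whose state changed. The snippet below is a minimal standalone sketch of those semantics, not part of this commit or of the crate: it mirrors `FactoredRewardFunction::call` using plain `ndarray` arrays and `usize` states in place of `params::StateType::Discrete`, and the helper name `factored_call` and the two-node example data are invented for illustration.

```rust
// Standalone sketch (not crate code): mirrors the factored reward computed by
// `FactoredRewardFunction::call`, with node states as plain `usize` values.
use ndarray::{arr1, arr2, Array1, Array2};

struct Reward {
    transition_reward: f64,
    instantaneous_reward: f64,
}

fn factored_call(
    instantaneous: &[Array1<f64>],
    transition: &[Array2<f64>],
    current: &[usize],
    previous: Option<&[usize]>,
) -> Reward {
    // Instantaneous reward: sum over nodes of the per-state reward rate.
    let instantaneous_reward: f64 = current
        .iter()
        .enumerate()
        .map(|(idx, &x)| instantaneous[idx][x])
        .sum();
    // Transition reward: take the reward of the first node whose state differs
    // from the previous sample, or 0.0 if there is no previous sample.
    let transition_reward = previous
        .map(|prev| {
            prev.iter()
                .zip(current.iter())
                .enumerate()
                .find_map(|(idx, (&p, &c))| (p != c).then(|| transition[idx][[p, c]]))
                .unwrap_or(0.0)
        })
        .unwrap_or(0.0);
    Reward {
        transition_reward,
        instantaneous_reward,
    }
}

fn main() {
    // Two binary nodes: a rate of 1.0 per unit of time while node 0 is in
    // state 1, plus a one-off reward of 5.0 when node 0 jumps from 0 to 1.
    let instantaneous = vec![arr1(&[0.0, 1.0]), arr1(&[0.0, 0.0])];
    let transition = vec![arr2(&[[0.0, 5.0], [0.0, 0.0]]), Array2::zeros((2, 2))];
    let r = factored_call(&instantaneous, &transition, &[1, 0][..], Some(&[0, 0][..]));
    assert_eq!(r.instantaneous_reward, 1.0);
    assert_eq!(r.transition_reward, 5.0);
    println!(
        "instantaneous = {}, transition = {}",
        r.instantaneous_reward, r.transition_reward
    );
}
```

As with the `find_map` in the diff above, at most one node contributes a transition reward, since a CTBN transition changes a single variable at a time.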

@@ -10,15 +10,13 @@ use rand_chacha::ChaCha8Rng;
 #[derive(Clone)]
 pub struct Sample {
     pub t: f64,
-    pub state: Vec<params::StateType>
+    pub state: Vec<params::StateType>,
 }
 pub trait Sampler: Iterator<Item = Sample> {
     fn reset(&mut self);
 }
 pub struct ForwardSampler<'a, T>
 where
     T: NetworkProcess,
@@ -104,7 +102,10 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
             self.next_transitions[child] = None;
         }
-        Some(Sample{t: ret_time, state: ret_state})
+        Some(Sample {
+            t: ret_time,
+            state: ret_state,
+        })
     }
 }
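
The `Sample` emitted by a `Sampler` is exactly what `RewardFunction::call` consumes. As a usage illustration only, here is a hypothetical compile-level sketch of walking a sampled trajectory and evaluating the reward of each step; it assumes the crate is imported as `reCTBN`, that concrete `Sampler` and `RewardFunction` values have already been constructed, and the helper name `rewards_along` and the `horizon` cutoff are invented for the example.

```rust
// Hypothetical helper, not part of the crate: walk a sampled trajectory and
// evaluate the reward of every step, pairing each `Sample` with its
// predecessor as `RewardFunction::call` expects.
use reCTBN::reward_function::{Reward, RewardFunction};
use reCTBN::sampling::{Sample, Sampler};

fn rewards_along<R, S>(reward: &R, sampler: S, horizon: f64) -> Vec<(f64, Reward)>
where
    R: RewardFunction,
    S: Sampler,
{
    let mut previous: Option<Sample> = None;
    let mut out = Vec::new();
    for current in sampler {
        // Stop once the sampled time exceeds the requested horizon.
        if current.t > horizon {
            break;
        }
        let r = reward.call(current.clone(), previous.clone());
        out.push((current.t, r));
        previous = Some(current);
    }
    out
}
```

For the `FactoredRewardFunction` above, the first step yields a zero `transition_reward` because `previous` is still `None`.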
