Mini refactor. Introduced the type alias NetworkProcessState.

pull/75/head
AlessandroBregoli 2 years ago
parent 68ef7ea7c3
commit bcb64a161a
  1. 6
      reCTBN/src/process.rs
  2. 10
      reCTBN/src/process/ctbn.rs
  3. 12
      reCTBN/src/process/ctmp.rs
  4. 16
      reCTBN/src/reward_function.rs
  5. 8
      reCTBN/src/sampling.rs
  6. 2
      reCTBN/src/tools.rs
  7. 24
      reCTBN/tests/reward_function.rs

@ -16,6 +16,9 @@ pub enum NetworkError {
NodeInsertionError(String), NodeInsertionError(String),
} }
/// This type is used to represent a specific realization of a generic NetworkProcess
pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
/// as a CTBN). /// as a CTBN).
pub trait NetworkProcess { pub trait NetworkProcess {
@ -71,8 +74,7 @@ pub trait NetworkProcess {
/// # Return /// # Return
/// ///
/// * Index of the `node` relative to the network. /// * Index of the `node` relative to the network.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize;
-> usize;
/// Compute the index that must be used to access the parameters of a `node`, given a specific /// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network and a generic `parent_set`. /// configuration of the network and a generic `parent_set`.

@ -8,7 +8,7 @@ use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, Stat
use crate::process; use crate::process;
use super::ctmp::CtmpProcess; use super::ctmp::CtmpProcess;
use super::NetworkProcess; use super::{NetworkProcess, NetworkProcessState};
/// It represents both the structure and the parameters of a CTBN. /// It represents both the structure and the parameters of a CTBN.
/// ///
@ -86,7 +86,7 @@ impl CtbnNetwork {
for idx_current_state in 0..state_space { for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: Vec<StateType> = current_state let current_state_statetype: NetworkProcessState = current_state
.iter() .iter()
.map(|x| StateType::Discrete(*x)) .map(|x| StateType::Discrete(*x))
.collect(); .collect();
@ -98,7 +98,7 @@ impl CtbnNetwork {
let mut next_state = current_state.clone(); let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state; next_state[idx_node] = next_node_state;
let next_state_statetype: Vec<StateType> = let next_state_statetype: NetworkProcessState =
next_state.iter().map(|x| StateType::Discrete(*x)).collect(); next_state.iter().map(|x| StateType::Discrete(*x)).collect();
let idx_next_state = self.get_param_index_from_custom_parent_set( let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype, &next_state_statetype,
@ -185,7 +185,7 @@ impl process::NetworkProcess for CtbnNetwork {
&mut self.nodes[node_idx] &mut self.nodes[node_idx]
} }
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize { fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
self.adj_matrix self.adj_matrix
.as_ref() .as_ref()
.unwrap() .unwrap()
@ -204,7 +204,7 @@ impl process::NetworkProcess for CtbnNetwork {
fn get_param_index_from_custom_parent_set( fn get_param_index_from_custom_parent_set(
&self, &self,
current_state: &Vec<StateType>, current_state: &NetworkProcessState,
parent_set: &BTreeSet<usize>, parent_set: &BTreeSet<usize>,
) -> usize { ) -> usize {
parent_set parent_set

@ -5,7 +5,7 @@ use crate::{
process, process,
}; };
use super::NetworkProcess; use super::{NetworkProcess, NetworkProcessState};
pub struct CtmpProcess { pub struct CtmpProcess {
param: Option<Params>, param: Option<Params>,
@ -68,11 +68,7 @@ impl NetworkProcess for CtmpProcess {
} }
} }
fn get_param_index_network( fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
&self,
node: usize,
current_state: &Vec<crate::params::StateType>,
) -> usize {
if node == 0 { if node == 0 {
match current_state[0] { match current_state[0] {
StateType::Discrete(x) => x, StateType::Discrete(x) => x,
@ -84,8 +80,8 @@ impl NetworkProcess for CtmpProcess {
fn get_param_index_from_custom_parent_set( fn get_param_index_from_custom_parent_set(
&self, &self,
_current_state: &Vec<crate::params::StateType>, _current_state: &NetworkProcessState,
_parent_set: &std::collections::BTreeSet<usize>, _parent_set: &BTreeSet<usize>,
) -> usize { ) -> usize {
unimplemented!("CtmpProcess has only one node") unimplemented!("CtmpProcess has only one node")
} }

@ -2,7 +2,7 @@
use crate::{ use crate::{
params::{self, ParamsTrait}, params::{self, ParamsTrait},
process, sampling, process,
}; };
use ndarray; use ndarray;
@ -27,13 +27,13 @@ pub trait RewardFunction {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `current_state`: the current state of the network represented as a `sampling::Sample` /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
/// * `previous_state`: an optional argument representing the previous state of the network /// * `previous_state`: an optional argument representing the previous state of the network
fn call( fn call(
&self, &self,
current_state: sampling::Sample, current_state: process::NetworkProcessState,
previous_state: Option<sampling::Sample>, previous_state: Option<process::NetworkProcessState>,
) -> Reward; ) -> Reward;
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess /// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
@ -80,11 +80,10 @@ impl FactoredRewardFunction {
impl RewardFunction for FactoredRewardFunction { impl RewardFunction for FactoredRewardFunction {
fn call( fn call(
&self, &self,
current_state: sampling::Sample, current_state: process::NetworkProcessState,
previous_state: Option<sampling::Sample>, previous_state: Option<process::NetworkProcessState>,
) -> Reward { ) -> Reward {
let instantaneous_reward: f64 = current_state let instantaneous_reward: f64 = current_state
.state
.iter() .iter()
.enumerate() .enumerate()
.map(|(idx, x)| { .map(|(idx, x)| {
@ -96,9 +95,8 @@ impl RewardFunction for FactoredRewardFunction {
.sum(); .sum();
if let Some(previous_state) = previous_state { if let Some(previous_state) = previous_state {
let transition_reward = previous_state let transition_reward = previous_state
.state
.iter() .iter()
.zip(current_state.state.iter()) .zip(current_state.iter())
.enumerate() .enumerate()
.find_map(|(idx, (p, c))| -> Option<f64> { .find_map(|(idx, (p, c))| -> Option<f64> {
let p = match p { let p = match p {

@ -1,8 +1,8 @@
//! Module containing methods for the sampling. //! Module containing methods for the sampling.
use crate::{ use crate::{
params::{self, ParamsTrait}, params::ParamsTrait,
process::NetworkProcess, process::{NetworkProcess, NetworkProcessState},
}; };
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
@ -10,7 +10,7 @@ use rand_chacha::ChaCha8Rng;
#[derive(Clone)] #[derive(Clone)]
pub struct Sample { pub struct Sample {
pub t: f64, pub t: f64,
pub state: Vec<params::StateType>, pub state: NetworkProcessState,
} }
pub trait Sampler: Iterator<Item = Sample> { pub trait Sampler: Iterator<Item = Sample> {
@ -24,7 +24,7 @@ where
net: &'a T, net: &'a T,
rng: ChaCha8Rng, rng: ChaCha8Rng,
current_time: f64, current_time: f64,
current_state: Vec<params::StateType>, current_state: NetworkProcessState,
next_transitions: Vec<Option<f64>>, next_transitions: Vec<Option<f64>>,
} }

@ -69,7 +69,7 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
let mut time: Vec<f64> = Vec::new(); let mut time: Vec<f64> = Vec::new();
//Configuration of the process variables at time t initialized with an uniform //Configuration of the process variables at time t initialized with an uniform
//distribution. //distribution.
let mut events: Vec<Vec<params::StateType>> = Vec::new(); let mut events: Vec<process::NetworkProcessState> = Vec::new();
//Current Time and Current State //Current Time and Current State
let mut sample = sampler.next().unwrap(); let mut sample = sampler.next().unwrap();

@ -2,7 +2,7 @@ mod utils;
use ndarray::*; use ndarray::*;
use utils::generate_discrete_time_continous_node; use utils::generate_discrete_time_continous_node;
use reCTBN::{process::{NetworkProcess, ctbn::*}, reward_function::*, params}; use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params};
#[test] #[test]
@ -16,8 +16,8 @@ fn simple_factored_reward_function_binary_node() {
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0]));
let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
@ -41,9 +41,9 @@ fn simple_factored_reward_function_ternary_node() {
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
let s0 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0)]}; let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1)]}; let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let s2 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2)]}; let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
@ -78,14 +78,14 @@ fn factored_reward_function_two_nodes() {
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
let s00 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]}; let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
let s01 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]}; let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
let s02 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]}; let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
let s10 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]}; let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)];
let s11 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]}; let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
let s12 = reCTBN::sampling::Sample { t: 0.0, state: vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]}; let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});

Loading…
Cancel
Save