Syncing from devpull/79/head
commit
2e49df0266
@ -0,0 +1,114 @@ |
||||
use std::collections::BTreeSet; |
||||
|
||||
use crate::{ |
||||
params::{Params, StateType}, |
||||
process, |
||||
}; |
||||
|
||||
use super::{NetworkProcess, NetworkProcessState}; |
||||
|
||||
pub struct CtmpProcess { |
||||
param: Option<Params>, |
||||
} |
||||
|
||||
impl CtmpProcess { |
||||
pub fn new() -> CtmpProcess { |
||||
CtmpProcess { param: None } |
||||
} |
||||
} |
||||
|
||||
impl NetworkProcess for CtmpProcess { |
||||
fn initialize_adj_matrix(&mut self) { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> { |
||||
match self.param { |
||||
None => { |
||||
self.param = Some(n); |
||||
Ok(0) |
||||
} |
||||
Some(_) => Err(process::NetworkError::NodeInsertionError( |
||||
"CtmpProcess has only one node".to_string(), |
||||
)), |
||||
} |
||||
} |
||||
|
||||
fn add_edge(&mut self, _parent: usize, _child: usize) { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||
match self.param { |
||||
None => 0..0, |
||||
Some(_) => 0..1, |
||||
} |
||||
} |
||||
|
||||
fn get_number_of_nodes(&self) -> usize { |
||||
match self.param { |
||||
None => 0, |
||||
Some(_) => 1, |
||||
} |
||||
} |
||||
|
||||
fn get_node(&self, node_idx: usize) -> &crate::params::Params { |
||||
if node_idx == 0 { |
||||
self.param.as_ref().unwrap() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { |
||||
if node_idx == 0 { |
||||
self.param.as_mut().unwrap() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
||||
if node == 0 { |
||||
match current_state[0] { |
||||
StateType::Discrete(x) => x, |
||||
} |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
_current_state: &NetworkProcessState, |
||||
_parent_set: &BTreeSet<usize>, |
||||
) -> usize { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
||||
match self.param { |
||||
Some(_) => { |
||||
if node == 0 { |
||||
BTreeSet::new() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
None => panic!("Uninitialized CtmpProcess"), |
||||
} |
||||
} |
||||
|
||||
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
||||
match self.param { |
||||
Some(_) => { |
||||
if node == 0 { |
||||
BTreeSet::new() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
None => panic!("Uninitialized CtmpProcess"), |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,142 @@ |
||||
//! Module for dealing with reward functions
|
||||
|
||||
use crate::{ |
||||
params::{self, ParamsTrait}, |
||||
process, |
||||
}; |
||||
use ndarray; |
||||
|
||||
/// Instantiation of reward function and instantaneous reward
|
||||
///
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transition_reward`: reward obtained transitioning from one state to another
|
||||
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
|
||||
|
||||
#[derive(Debug, PartialEq)] |
||||
pub struct Reward { |
||||
pub transition_reward: f64, |
||||
pub instantaneous_reward: f64, |
||||
} |
||||
|
||||
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
|
||||
|
||||
pub trait RewardFunction { |
||||
/// Given the current state and the previous state, it compute the reward.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
|
||||
/// * `previous_state`: an optional argument representing the previous state of the network
|
||||
|
||||
fn call( |
||||
&self, |
||||
current_state: process::NetworkProcessState, |
||||
previous_state: Option<process::NetworkProcessState>, |
||||
) -> Reward; |
||||
|
||||
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p`: any structure that implements the trait `process::NetworkProcess`
|
||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; |
||||
} |
||||
|
||||
/// Reward function over a factored state space
|
||||
///
|
||||
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
|
||||
/// of the underling `NetworkProcess`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
|
||||
/// reward of a node
|
||||
|
||||
pub struct FactoredRewardFunction { |
||||
transition_reward: Vec<ndarray::Array2<f64>>, |
||||
instantaneous_reward: Vec<ndarray::Array1<f64>>, |
||||
} |
||||
|
||||
impl FactoredRewardFunction { |
||||
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { |
||||
&self.transition_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { |
||||
&mut self.transition_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { |
||||
&self.instantaneous_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { |
||||
&mut self.instantaneous_reward[node_idx] |
||||
} |
||||
} |
||||
|
||||
impl RewardFunction for FactoredRewardFunction { |
||||
fn call( |
||||
&self, |
||||
current_state: process::NetworkProcessState, |
||||
previous_state: Option<process::NetworkProcessState>, |
||||
) -> Reward { |
||||
let instantaneous_reward: f64 = current_state |
||||
.iter() |
||||
.enumerate() |
||||
.map(|(idx, x)| { |
||||
let x = match x { |
||||
params::StateType::Discrete(x) => x, |
||||
}; |
||||
self.instantaneous_reward[idx][*x] |
||||
}) |
||||
.sum(); |
||||
if let Some(previous_state) = previous_state { |
||||
let transition_reward = previous_state |
||||
.iter() |
||||
.zip(current_state.iter()) |
||||
.enumerate() |
||||
.find_map(|(idx, (p, c))| -> Option<f64> { |
||||
let p = match p { |
||||
params::StateType::Discrete(p) => p, |
||||
}; |
||||
let c = match c { |
||||
params::StateType::Discrete(c) => c, |
||||
}; |
||||
if p != c { |
||||
Some(self.transition_reward[idx][[*p, *c]]) |
||||
} else { |
||||
None |
||||
} |
||||
}) |
||||
.unwrap_or(0.0); |
||||
Reward { |
||||
transition_reward, |
||||
instantaneous_reward, |
||||
} |
||||
} else { |
||||
Reward { |
||||
transition_reward: 0.0, |
||||
instantaneous_reward, |
||||
} |
||||
} |
||||
} |
||||
|
||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { |
||||
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; |
||||
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; |
||||
for i in p.get_node_indices() { |
||||
//This works only for discrete nodes!
|
||||
let size: usize = p.get_node(i).get_reserved_space_as_parent(); |
||||
instantaneous_reward.push(ndarray::Array1::zeros(size)); |
||||
transition_reward.push(ndarray::Array2::zeros((size, size))); |
||||
} |
||||
|
||||
FactoredRewardFunction { |
||||
transition_reward, |
||||
instantaneous_reward, |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,127 @@ |
||||
mod utils; |
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use reCTBN::{ |
||||
params, |
||||
params::ParamsTrait, |
||||
process::{ctmp::*, NetworkProcess}, |
||||
}; |
||||
use utils::generate_discrete_time_continous_node; |
||||
|
||||
#[test] |
||||
fn define_simple_ctmp() { |
||||
let _ = CtmpProcess::new(); |
||||
assert!(true); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_node_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_two_nodes_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
||||
|
||||
match n2 { |
||||
Ok(_) => assert!(false), |
||||
Err(_) => assert!(true), |
||||
}; |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn add_edge_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
||||
|
||||
net.add_edge(0, 1) |
||||
} |
||||
|
||||
#[test] |
||||
fn childen_and_parents() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
assert_eq!(0, net.get_parent_set(0).len()); |
||||
assert_eq!(0, net.get_children_set(0).len()); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_childen_panic() { |
||||
let net = CtmpProcess::new(); |
||||
net.get_children_set(0); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_childen_panic2() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
net.get_children_set(1); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_parent_panic() { |
||||
let net = CtmpProcess::new(); |
||||
net.get_parent_set(0); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_parent_panic2() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
net.get_parent_set(1); |
||||
} |
||||
|
||||
#[test] |
||||
fn compute_index_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node( |
||||
String::from("n1"), |
||||
10, |
||||
)) |
||||
.unwrap(); |
||||
|
||||
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); |
||||
assert_eq!(6, idx); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn compute_index_from_custom_parent_set_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node( |
||||
String::from("n1"), |
||||
10, |
||||
)) |
||||
.unwrap(); |
||||
|
||||
let _idx = net.get_param_index_from_custom_parent_set( |
||||
&vec![params::StateType::Discrete(6)], |
||||
&BTreeSet::from([0]) |
||||
); |
||||
} |
@ -0,0 +1,118 @@ |
||||
mod utils; |
||||
|
||||
use ndarray::*; |
||||
use utils::generate_discrete_time_continous_node; |
||||
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; |
||||
|
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_binary_node() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||
|
||||
assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_ternary_node() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; |
||||
|
||||
|
||||
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||
assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); |
||||
assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); |
||||
} |
||||
|
||||
#[test] |
||||
fn factored_reward_function_two_nodes() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||
|
||||
|
||||
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); |
||||
|
||||
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; |
||||
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; |
||||
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; |
||||
|
||||
|
||||
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; |
||||
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; |
||||
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; |
||||
|
||||
assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||
assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); |
||||
assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); |
||||
assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); |
||||
assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||
assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); |
||||
assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); |
||||
assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); |
||||
assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); |
||||
} |
Loading…
Reference in new issue