diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index db33ae4..8feddfb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -3,10 +3,10 @@ #[cfg(test)] extern crate approx; -pub mod ctbn; -pub mod network; pub mod parameter_learning; pub mod params; +pub mod process; +pub mod reward_function; pub mod sampling; pub mod structure_learning; pub mod tools; diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 61d4dca..2aa518c 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -5,10 +5,10 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use crate::params::*; -use crate::{network, tools}; +use crate::{process, tools}; pub trait ParameterLearning { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -17,7 +17,7 @@ pub trait ParameterLearning { ) -> Params; } -pub fn sufficient_statistics( +pub fn sufficient_statistics( net: &T, dataset: &tools::Dataset, node: usize, @@ -73,7 +73,7 @@ pub fn sufficient_statistics( pub struct MLE {} impl ParameterLearning for MLE { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -120,7 +120,7 @@ pub struct BayesianApproach { } impl ParameterLearning for BayesianApproach { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -177,7 +177,7 @@ impl Cache
<P>
{ dataset, } } - pub fn fit<T: network::Network>( + pub fn fit<T: process::NetworkProcess>( &mut self, net: &T, node: usize, diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 070c997..9f63860 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,11 +267,13 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } + let domain_size = domain_size as f64; + // Check if each row sums to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) + .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/network.rs b/reCTBN/src/process.rs similarity index 93% rename from reCTBN/src/network.rs rename to reCTBN/src/process.rs index fbdd2e6..dc297bc 100644 --- a/reCTBN/src/network.rs +++ b/reCTBN/src/process.rs @@ -1,5 +1,8 @@ //! Defines methods for dealing with Probabilistic Graphical Models like CTBNs +pub mod ctbn; +pub mod ctmp; + use std::collections::BTreeSet; use thiserror::Error; @@ -13,9 +16,12 @@ pub enum NetworkError { NodeInsertionError(String), } +/// This type is used to represent a specific realization of a generic `NetworkProcess` +pub type NetworkProcessState = Vec<params::StateType>; + /// It defines the required methods for a structure used as a Probabilistic Graphical Model (such /// as a CTBN). -pub trait Network { +pub trait NetworkProcess { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; /// Add a **directed edge** between two nodes of the network. @@ -68,8 +74,7 @@ pub trait Network { /// # Return /// /// * Index of the `node` relative to the network. - fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) - -> usize; + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; /// Compute the index that must be used to access the parameters of a `node`, given a specific /// configuration of the network and a generic `parent_set`. diff --git a/reCTBN/src/ctbn.rs b/reCTBN/src/process/ctbn.rs similarity index 58% rename from reCTBN/src/ctbn.rs rename to reCTBN/src/process/ctbn.rs index 2b01d14..162345e 100644 --- a/reCTBN/src/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,8 +4,11 @@ use std::collections::BTreeSet; use ndarray::prelude::*; -use crate::network; -use crate::params::{Params, ParamsTrait, StateType}; +use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; +use crate::process; + +use super::ctmp::CtmpProcess; +use super::{NetworkProcess, NetworkProcessState}; /// It represents both the structure and the parameters of a CTBN.
/// @@ -20,9 +23,9 @@ use crate::params::{Params, ParamsTrait, StateType}; /// /// ```rust /// use std::collections::BTreeSet; -/// use reCTBN::network::Network; +/// use reCTBN::process::NetworkProcess; /// use reCTBN::params; -/// use reCTBN::ctbn::*; +/// use reCTBN::process::ctbn::*; /// /// //Create the domain for a discrete node /// let mut domain = BTreeSet::new(); @@ -67,9 +70,77 @@ impl CtbnNetwork { nodes: Vec::new(), } } + + ///Transform the **CTBN** into a **CTMP** + /// + /// # Return + /// + /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork + pub fn amalgamation(&self) -> CtmpProcess { + let variables_domain = + Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); + + let state_space = variables_domain.product(); + let variables_set = BTreeSet::from_iter(self.get_node_indices()); + let mut amalgamated_cim: Array3 = Array::zeros((1, state_space, state_space)); + + for idx_current_state in 0..state_space { + let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); + let current_state_statetype: NetworkProcessState = current_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + for idx_node in 0..self.nodes.len() { + let p = match self.get_node(idx_node) { + Params::DiscreteStatesContinousTime(p) => p, + }; + for next_node_state in 0..variables_domain[idx_node] { + let mut next_state = current_state.clone(); + next_state[idx_node] = next_node_state; + + let next_state_statetype: NetworkProcessState = + next_state.iter().map(|x| StateType::Discrete(*x)).collect(); + let idx_next_state = self.get_param_index_from_custom_parent_set( + &next_state_statetype, + &variables_set, + ); + amalgamated_cim[[0, idx_current_state, idx_next_state]] += + p.get_cim().as_ref().unwrap()[[ + self.get_param_index_network(idx_node, ¤t_state_statetype), + current_state[idx_node], + next_node_state, + ]]; + } + } + } + + let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( + "ctmp".to_string(), + BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), + ); + + amalgamated_param.set_cim(amalgamated_cim).unwrap(); + + let mut ctmp = CtmpProcess::new(); + + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) + .unwrap(); + return ctmp; + } + + pub fn idx_to_state(variables_domain: &Array1, state: usize) -> Array1 { + let mut state = state; + let mut array_state = Array1::zeros(variables_domain.shape()[0]); + for (idx, var) in variables_domain.indexed_iter() { + array_state[idx] = state % var; + state = state / var; + } + + return array_state; + } } -impl network::Network for CtbnNetwork { +impl process::NetworkProcess for CtbnNetwork { /// Initialize an Adjacency matrix. fn initialize_adj_matrix(&mut self) { self.adj_matrix = Some(Array2::::zeros( @@ -78,7 +149,7 @@ impl network::Network for CtbnNetwork { } /// Add a new node. 
- fn add_node(&mut self, mut n: Params) -> Result { + fn add_node(&mut self, mut n: Params) -> Result { n.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); @@ -114,7 +185,7 @@ impl network::Network for CtbnNetwork { &mut self.nodes[node_idx] } - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize { + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { self.adj_matrix .as_ref() .unwrap() @@ -133,7 +204,7 @@ impl network::Network for CtbnNetwork { fn get_param_index_from_custom_parent_set( &self, - current_state: &Vec, + current_state: &NetworkProcessState, parent_set: &BTreeSet, ) -> usize { parent_set diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs new file mode 100644 index 0000000..41b8db6 --- /dev/null +++ b/reCTBN/src/process/ctmp.rs @@ -0,0 +1,114 @@ +use std::collections::BTreeSet; + +use crate::{ + params::{Params, StateType}, + process, +}; + +use super::{NetworkProcess, NetworkProcessState}; + +pub struct CtmpProcess { + param: Option, +} + +impl CtmpProcess { + pub fn new() -> CtmpProcess { + CtmpProcess { param: None } + } +} + +impl NetworkProcess for CtmpProcess { + fn initialize_adj_matrix(&mut self) { + unimplemented!("CtmpProcess has only one node") + } + + fn add_node(&mut self, n: crate::params::Params) -> Result { + match self.param { + None => { + self.param = Some(n); + Ok(0) + } + Some(_) => Err(process::NetworkError::NodeInsertionError( + "CtmpProcess has only one node".to_string(), + )), + } + } + + fn add_edge(&mut self, _parent: usize, _child: usize) { + unimplemented!("CtmpProcess has only one node") + } + + fn get_node_indices(&self) -> std::ops::Range { + match self.param { + None => 0..0, + Some(_) => 0..1, + } + } + + fn get_number_of_nodes(&self) -> usize { + match self.param { + None => 0, + Some(_) => 1, + } + } + + fn get_node(&self, node_idx: usize) -> &crate::params::Params { + if node_idx == 0 { + self.param.as_ref().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { + if node_idx == 0 { + self.param.as_mut().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { + if node == 0 { + match current_state[0] { + StateType::Discrete(x) => x, + } + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_from_custom_parent_set( + &self, + _current_state: &NetworkProcessState, + _parent_set: &BTreeSet, + ) -> usize { + unimplemented!("CtmpProcess has only one node") + } + + fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), + } + } + + fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), + } + } +} diff --git a/reCTBN/src/reward_function.rs b/reCTBN/src/reward_function.rs new file mode 100644 index 0000000..35e15c8 --- /dev/null +++ b/reCTBN/src/reward_function.rs @@ -0,0 +1,142 @@ +//! 
Module for dealing with reward functions + +use crate::{ + params::{self, ParamsTrait}, + process, +}; +use ndarray; + +/// The reward computed by a `RewardFunction`: a transition reward and an instantaneous reward +/// +/// # Arguments +/// +/// * `transition_reward`: reward obtained transitioning from one state to another +/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64, +} + +/// The trait `RewardFunction` describes the methods that every reward function must implement + +pub trait RewardFunction { + /// Given the current state and the previous state, it computes the reward. + /// + /// # Arguments + /// + /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` + /// * `previous_state`: an optional argument representing the previous state of the network + + fn call( + &self, + current_state: process::NetworkProcessState, + previous_state: Option<process::NetworkProcessState>, + ) -> Reward; + + /// Initialize the internals of the `RewardFunction` according to the structure of a `NetworkProcess` + /// + /// # Arguments + /// + /// * `p`: any structure that implements the trait `process::NetworkProcess` + fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; +} + +/// Reward function over a factored state space +/// +/// The `FactoredRewardFunction` assumes that the reward function is the sum of the rewards of each node +/// of the underlying `NetworkProcess` +/// +/// # Arguments +/// +/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition +/// reward of a node +/// * `instantaneous_reward`: a vector of one-dimensional arrays. Each array contains the instantaneous +/// reward of a node + +pub struct FactoredRewardFunction { + transition_reward: Vec<ndarray::Array2<f64>>, + instantaneous_reward: Vec<ndarray::Array1<f64>>, +} + +impl FactoredRewardFunction { + pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { + &self.transition_reward[node_idx] + } + + pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { + &mut self.transition_reward[node_idx] + } + + pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { + &self.instantaneous_reward[node_idx] + } + + pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { + &mut self.instantaneous_reward[node_idx] + } +} + +impl RewardFunction for FactoredRewardFunction { + fn call( + &self, + current_state: process::NetworkProcessState, + previous_state: Option<process::NetworkProcessState>, + ) -> Reward { + let instantaneous_reward: f64 = current_state + .iter() + .enumerate() + .map(|(idx, x)| { + let x = match x { + params::StateType::Discrete(x) => x, + }; + self.instantaneous_reward[idx][*x] + }) + .sum(); + if let Some(previous_state) = previous_state { + let transition_reward = previous_state + .iter() + .zip(current_state.iter()) + .enumerate() + .find_map(|(idx, (p, c))| -> Option<f64> { + let p = match p { + params::StateType::Discrete(p) => p, + }; + let c = match c { + params::StateType::Discrete(c) => c, + }; + if p != c { + Some(self.transition_reward[idx][[*p, *c]]) + } else { + None + } + }) + .unwrap_or(0.0); + Reward { + transition_reward, + instantaneous_reward, + } + } else { + Reward { + transition_reward: 0.0, + instantaneous_reward, + } + } + } + + fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { + let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; + let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; + for i in p.get_node_indices() { + //This works only for discrete nodes!
+ let size: usize = p.get_node(i).get_reserved_space_as_parent(); + instantaneous_reward.push(ndarray::Array1::zeros(size)); + transition_reward.push(ndarray::Array2::zeros((size, size))); + } + + FactoredRewardFunction { + transition_reward, + instantaneous_reward, + } + } +} diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index d435634..1384872 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,28 +1,34 @@ //! Module containing methods for the sampling. use crate::{ - network::Network, - params::{self, ParamsTrait}, + params::ParamsTrait, + process::{NetworkProcess, NetworkProcessState}, }; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -pub trait Sampler: Iterator { +#[derive(Clone)] +pub struct Sample { + pub t: f64, + pub state: NetworkProcessState, +} + +pub trait Sampler: Iterator { fn reset(&mut self); } pub struct ForwardSampler<'a, T> where - T: Network, + T: NetworkProcess, { net: &'a T, rng: ChaCha8Rng, current_time: f64, - current_state: Vec, + current_state: NetworkProcessState, next_transitions: Vec>, } -impl<'a, T: Network> ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. @@ -42,8 +48,8 @@ impl<'a, T: Network> ForwardSampler<'a, T> { } } -impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { - type Item = (f64, Vec); +impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { + type Item = Sample; fn next(&mut self) -> Option { let ret_time = self.current_time.clone(); @@ -96,11 +102,14 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { self.next_transitions[child] = None; } - Some((ret_time, ret_state)) + Some(Sample { + t: ret_time, + state: ret_state, + }) } } -impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { fn reset(&mut self) { self.current_time = 0.0; self.current_state = self diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index 57fed1e..b272e22 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -4,10 +4,10 @@ pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; pub mod score_function; -use crate::{network, tools}; +use crate::{process, tools}; pub trait StructureLearningAlgorithm { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network; + T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 9f7a518..aa37cfa 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use crate::params::*; -use crate::{network, parameter_learning}; +use crate::{parameter_learning, process}; pub trait HypothesisTest { fn call( @@ -18,7 +18,7 @@ pub trait HypothesisTest { cache: &mut parameter_learning::Cache
<P>
, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning; } @@ -221,7 +221,7 @@ impl HypothesisTest for ChiSquare { cache: &mut parameter_learning::Cache
<P>
, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index 9e329eb..16e9056 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{network, tools}; +use crate::{process, tools}; pub struct HillClimbing { score_function: S, @@ -23,7 +23,7 @@ impl HillClimbing { impl StructureLearningAlgorithm for HillClimbing { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network, + T: process::NetworkProcess, { //Check the coherence between dataset and network if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index cb6ad7b..f8b38b5 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -5,7 +5,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use statrs::function::gamma; -use crate::{network, parameter_learning, params, tools}; +use crate::{parameter_learning, params, process, tools}; pub trait ScoreFunction { fn call( @@ -16,7 +16,7 @@ pub trait ScoreFunction { dataset: &tools::Dataset, ) -> f64 where - T: network::Network; + T: process::NetworkProcess; } pub struct LogLikelihood { @@ -41,7 +41,7 @@ impl LogLikelihood { dataset: &tools::Dataset, ) -> (f64, Array3) where - T: network::Network, + T: process::NetworkProcess, { //Identify the type of node used match &net.get_node(node) { @@ -100,7 +100,7 @@ impl ScoreFunction for LogLikelihood { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { self.compute_score(net, node, parent_set, dataset).0 } @@ -127,7 +127,7 @@ impl ScoreFunction for BIC { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { //Compute the log-likelihood let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index aa48883..ecfeff9 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,7 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; -use crate::{network, params}; +use crate::{params, process}; pub struct Trajectory { time: Array1, @@ -51,7 +51,7 @@ impl Dataset { } } -pub fn trajectory_generator( +pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, @@ -69,18 +69,18 @@ pub fn trajectory_generator( let mut time: Vec = Vec::new(); //Configuration of the process variables at time t initialized with an uniform //distribution. - let mut events: Vec> = Vec::new(); + let mut events: Vec = Vec::new(); //Current Time and Current State - let (mut t, mut current_state) = sampler.next().unwrap(); + let mut sample = sampler.next().unwrap(); //Generate new samples until ending time is reached. 
- while t < t_end { - time.push(t); - events.push(current_state); - (t, current_state) = sampler.next().unwrap(); + while sample.t < t_end { + time.push(sample.t); + events.push(sample.state); + sample = sampler.next().unwrap(); } - current_state = events.last().unwrap().clone(); + let current_state = events.last().unwrap().clone(); events.push(current_state); //Add t_end as last time. diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 63c9621..3eb40d7 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,9 +1,12 @@ mod utils; use std::collections::BTreeSet; -use reCTBN::ctbn::*; -use reCTBN::network::Network; + +use approx::AbsDiffEq; +use ndarray::arr3; use reCTBN::params::{self, ParamsTrait}; +use reCTBN::process::NetworkProcess; +use reCTBN::process::{ctbn::*}; use utils::generate_discrete_time_continous_node; #[test] @@ -129,3 +132,245 @@ fn compute_index_from_custom_parent_set() { ); assert_eq!(2, idx); } + +#[test] +fn simple_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + net.initialize_adj_matrix(); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); + } + } + + let ctmp = net.amalgamation(); + let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0); + let p_ctbn = p_ctbn.get_cim().as_ref().unwrap(); + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); + + assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); +} + +#[test] +fn chain_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + let ctmp = net.amalgamation(); + + + + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); + + let p_ctmp_handmade = arr3(&[[ + [ + -1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 1.00e-02, 
0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} + +#[test] +fn chainfork_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + let n4 = net + .add_node(generate_discrete_time_continous_node(String::from("n4"), 2)) + .unwrap(); + + net.add_edge(n1, n3); + net.add_edge(n2, n3); + net.add_edge(n3, n4); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n4) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + + let ctmp = net.amalgamation(); + + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); + + let p_ctmp_handmade = arr3(&[[ + [ + -2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 
0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs new file mode 100644 index 0000000..830bfe0 --- /dev/null +++ b/reCTBN/tests/ctmp.rs @@ -0,0 +1,127 @@ +mod utils; + +use std::collections::BTreeSet; + +use reCTBN::{ + params, + params::ParamsTrait, + process::{ctmp::*, NetworkProcess}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn define_simple_ctmp() { + let _ = CtmpProcess::new(); + assert!(true); +} + +#[test] +fn add_node_to_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); +} + +#[test] +fn add_two_nodes_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + match n2 { + Ok(_) => assert!(false), + Err(_) => assert!(true), + }; +} + +#[test] +#[should_panic] +fn add_edge_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + net.add_edge(0, 1) +} + +#[test] +fn childen_and_parents() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + assert_eq!(0, net.get_parent_set(0).len()); + assert_eq!(0, net.get_children_set(0).len()); +} + +#[test] +#[should_panic] +fn get_childen_panic() { + let net = CtmpProcess::new(); + net.get_children_set(0); +} + +#[test] +#[should_panic] +fn get_childen_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_children_set(1); +} + +#[test] +#[should_panic] +fn get_parent_panic() { + let net = CtmpProcess::new(); + net.get_parent_set(0); +} + +#[test] +#[should_panic] +fn get_parent_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_parent_set(1); +} + +#[test] +fn compute_index_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + 
.add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); + assert_eq!(6, idx); +} + +#[test] +#[should_panic] +fn compute_index_from_custom_parent_set_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let _idx = net.get_param_index_from_custom_parent_set( + &vec![params::StateType::Discrete(6)], + &BTreeSet::from([0]) + ); +} diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 7d09b07..2cbc185 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -2,8 +2,8 @@ mod utils; use ndarray::arr3; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; use reCTBN::tools::*; diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs new file mode 100644 index 0000000..dcc5e69 --- /dev/null +++ b/reCTBN/tests/reward_function.rs @@ -0,0 +1,118 @@ +mod utils; + +use ndarray::*; +use utils::generate_discrete_time_continous_node; +use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params}; + + +#[test] +fn simple_factored_reward_function_binary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + + assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); +} + + +#[test] +fn simple_factored_reward_function_ternary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; + + + assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + + + assert_eq!(rf.call(s1.clone(), 
Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); +} + +#[test] +fn factored_reward_function_two_nodes() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + + rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); + + let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; + let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; + let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; + + + let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; + let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; + let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; + + assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + + + assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); + + + assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + + + assert_eq!(rf.call(s12.clone(), 
Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); +} diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 5c1db80..6134510 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -4,8 +4,8 @@ mod utils; use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::Cache; use reCTBN::params; diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 589b04e..806faef 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,6 +1,6 @@ use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params; use reCTBN::tools::*;
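The following sketch is not part of the patch: it shows how the pieces introduced above fit together — building a one-node CTBN through the relocated `process::ctbn` module, amalgamating it into a `CtmpProcess`, and evaluating a `FactoredRewardFunction` on `NetworkProcessState` values. It only uses APIs visible in this diff and mirrors the new tests; the node label and domain values ("n1", "A", "B") are arbitrary illustration choices.

use std::collections::BTreeSet;

use ndarray::{arr1, arr2, arr3};
use reCTBN::params;
use reCTBN::process::{ctbn::CtbnNetwork, NetworkProcess, NetworkProcessState};
use reCTBN::reward_function::{FactoredRewardFunction, Reward, RewardFunction};

fn main() {
    // A binary node built directly from DiscreteStatesContinousTimeParams,
    // as in the doc-example of process/ctbn.rs.
    let mut domain = BTreeSet::new();
    domain.insert(String::from("A"));
    domain.insert(String::from("B"));
    let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);

    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(params::Params::DiscreteStatesContinousTime(param))
        .unwrap();
    net.initialize_adj_matrix();

    // Set a valid CIM: every row sums to zero.
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(p) => {
            p.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
        }
    }

    // Amalgamation: for a single-node CTBN the resulting CTMP has the same CIM.
    let ctmp = net.amalgamation();
    let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
    assert!(p_ctmp.get_cim().is_some());

    // A factored reward: one transition-reward matrix and one
    // instantaneous-reward vector per node.
    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[0.0, 1.0], [2.0, 0.0]]));
    rf.get_instantaneous_reward_mut(n1)
        .assign(&arr1(&[3.0, 5.0]));

    let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
    let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];

    // Transitioning 0 -> 1 pays the (0, 1) entry of the transition matrix plus
    // the instantaneous reward of the state we end up in.
    assert_eq!(
        rf.call(s1.clone(), Some(s0.clone())),
        Reward { transition_reward: 1.0, instantaneous_reward: 5.0 }
    );
}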
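A second sketch, also not part of the patch, for the reworked sampler API: `ForwardSampler` now yields the `Sample { t, state }` struct instead of a `(f64, Vec<StateType>)` tuple and is generic over any `NetworkProcess`, so the same code can be driven by a CTBN or by the `CtmpProcess` produced by `amalgamation`. The seed argument is assumed to be `Option<u64>` (its type parameter is elided in the rendered diff above), and `net` is assumed to be a fully parameterized network such as the one built in the previous sketch.

use reCTBN::params::StateType;
use reCTBN::process::NetworkProcess;
use reCTBN::sampling::{ForwardSampler, Sampler};

/// Print the first few sampled transitions of any network process.
fn print_first_transitions<T: NetworkProcess>(net: &T) {
    // A fixed seed (assumed to be Option<u64>) makes the trajectory reproducible.
    let mut sampler = ForwardSampler::new(net, Some(1994));

    for sample in sampler.by_ref().take(5) {
        // NetworkProcessState is a Vec of StateType; unwrap the discrete values.
        let state: Vec<usize> = sample
            .state
            .iter()
            .map(|s| match s {
                StateType::Discrete(x) => *x,
            })
            .collect();
        println!("t = {:.4}, state = {:?}", sample.t, state);
    }

    // `reset` restarts the sampler from t = 0 with a fresh initial state.
    sampler.reset();
}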