Merge branch 'dev' into '8-feature-constraint-based-structure-learning-algorithm-for-ctbn'

Syncing from dev
pull/79/head
Meliurwen 2 years ago
commit 2e49df0266
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 4
      reCTBN/src/lib.rs
  2. 12
      reCTBN/src/parameter_learning.rs
  3. 4
      reCTBN/src/params.rs
  4. 11
      reCTBN/src/process.rs
  5. 87
      reCTBN/src/process/ctbn.rs
  6. 114
      reCTBN/src/process/ctmp.rs
  7. 142
      reCTBN/src/reward_function.rs
  8. 29
      reCTBN/src/sampling.rs
  9. 4
      reCTBN/src/structure_learning.rs
  10. 6
      reCTBN/src/structure_learning/hypothesis_test.rs
  11. 4
      reCTBN/src/structure_learning/score_based_algorithm.rs
  12. 10
      reCTBN/src/structure_learning/score_function.rs
  13. 18
      reCTBN/src/tools.rs
  14. 249
      reCTBN/tests/ctbn.rs
  15. 127
      reCTBN/tests/ctmp.rs
  16. 4
      reCTBN/tests/parameter_learning.rs
  17. 118
      reCTBN/tests/reward_function.rs
  18. 4
      reCTBN/tests/structure_learning.rs
  19. 4
      reCTBN/tests/tools.rs

@ -3,10 +3,10 @@
#[cfg(test)] #[cfg(test)]
extern crate approx; extern crate approx;
pub mod ctbn;
pub mod network;
pub mod parameter_learning; pub mod parameter_learning;
pub mod params; pub mod params;
pub mod process;
pub mod reward_function;
pub mod sampling; pub mod sampling;
pub mod structure_learning; pub mod structure_learning;
pub mod tools; pub mod tools;

@ -5,10 +5,10 @@ use std::collections::BTreeSet;
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::params::*; use crate::params::*;
use crate::{network, tools}; use crate::{process, tools};
pub trait ParameterLearning { pub trait ParameterLearning {
fn fit<T: network::Network>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
@ -17,7 +17,7 @@ pub trait ParameterLearning {
) -> Params; ) -> Params;
} }
pub fn sufficient_statistics<T: network::Network>( pub fn sufficient_statistics<T: process::NetworkProcess>(
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
node: usize, node: usize,
@ -73,7 +73,7 @@ pub fn sufficient_statistics<T: network::Network>(
pub struct MLE {} pub struct MLE {}
impl ParameterLearning for MLE { impl ParameterLearning for MLE {
fn fit<T: network::Network>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
@ -120,7 +120,7 @@ pub struct BayesianApproach {
} }
impl ParameterLearning for BayesianApproach { impl ParameterLearning for BayesianApproach {
fn fit<T: network::Network>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
@ -177,7 +177,7 @@ impl<P: ParameterLearning> Cache<P> {
dataset, dataset,
} }
} }
pub fn fit<T: network::Network>( pub fn fit<T: process::NetworkProcess>(
&mut self, &mut self,
net: &T, net: &T,
node: usize, node: usize,

@ -267,11 +267,13 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
))); )));
} }
let domain_size = domain_size as f64;
// Check if each row sum up to 0 // Check if each row sum up to 0
if cim if cim
.sum_axis(Axis(2)) .sum_axis(Axis(2))
.iter() .iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size)
{ {
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0", "The sum of each row must be 0",

@ -1,5 +1,8 @@
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs //! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
pub mod ctbn;
pub mod ctmp;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use thiserror::Error; use thiserror::Error;
@ -13,9 +16,12 @@ pub enum NetworkError {
NodeInsertionError(String), NodeInsertionError(String),
} }
/// This type is used to represent a specific realization of a generic NetworkProcess
pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
/// as a CTBN). /// as a CTBN).
pub trait Network { pub trait NetworkProcess {
fn initialize_adj_matrix(&mut self); fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add an **directed edge** between a two nodes of the network. /// Add an **directed edge** between a two nodes of the network.
@ -68,8 +74,7 @@ pub trait Network {
/// # Return /// # Return
/// ///
/// * Index of the `node` relative to the network. /// * Index of the `node` relative to the network.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize;
-> usize;
/// Compute the index that must be used to access the parameters of a `node`, given a specific /// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network and a generic `parent_set`. /// configuration of the network and a generic `parent_set`.

@ -4,8 +4,11 @@ use std::collections::BTreeSet;
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::network; use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
use crate::params::{Params, ParamsTrait, StateType}; use crate::process;
use super::ctmp::CtmpProcess;
use super::{NetworkProcess, NetworkProcessState};
/// It represents both the structure and the parameters of a CTBN. /// It represents both the structure and the parameters of a CTBN.
/// ///
@ -20,9 +23,9 @@ use crate::params::{Params, ParamsTrait, StateType};
/// ///
/// ```rust /// ```rust
/// use std::collections::BTreeSet; /// use std::collections::BTreeSet;
/// use reCTBN::network::Network; /// use reCTBN::process::NetworkProcess;
/// use reCTBN::params; /// use reCTBN::params;
/// use reCTBN::ctbn::*; /// use reCTBN::process::ctbn::*;
/// ///
/// //Create the domain for a discrete node /// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new(); /// let mut domain = BTreeSet::new();
@ -67,9 +70,77 @@ impl CtbnNetwork {
nodes: Vec::new(), nodes: Vec::new(),
} }
} }
///Transform the **CTBN** into a **CTMP**
///
/// # Return
///
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
pub fn amalgamation(&self) -> CtmpProcess {
let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
let state_space = variables_domain.product();
let variables_set = BTreeSet::from_iter(self.get_node_indices());
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space));
for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: NetworkProcessState = current_state
.iter()
.map(|x| StateType::Discrete(*x))
.collect();
for idx_node in 0..self.nodes.len() {
let p = match self.get_node(idx_node) {
Params::DiscreteStatesContinousTime(p) => p,
};
for next_node_state in 0..variables_domain[idx_node] {
let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state;
let next_state_statetype: NetworkProcessState =
next_state.iter().map(|x| StateType::Discrete(*x)).collect();
let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype,
&variables_set,
);
amalgamated_cim[[0, idx_current_state, idx_next_state]] +=
p.get_cim().as_ref().unwrap()[[
self.get_param_index_network(idx_node, &current_state_statetype),
current_state[idx_node],
next_node_state,
]];
}
}
}
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new(
"ctmp".to_string(),
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
);
amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new();
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param))
.unwrap();
return ctmp;
}
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> {
let mut state = state;
let mut array_state = Array1::zeros(variables_domain.shape()[0]);
for (idx, var) in variables_domain.indexed_iter() {
array_state[idx] = state % var;
state = state / var;
}
return array_state;
}
} }
impl network::Network for CtbnNetwork { impl process::NetworkProcess for CtbnNetwork {
/// Initialize an Adjacency matrix. /// Initialize an Adjacency matrix.
fn initialize_adj_matrix(&mut self) { fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros( self.adj_matrix = Some(Array2::<u16>::zeros(
@ -78,7 +149,7 @@ impl network::Network for CtbnNetwork {
} }
/// Add a new node. /// Add a new node.
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> { fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
n.reset_params(); n.reset_params();
self.adj_matrix = Option::None; self.adj_matrix = Option::None;
self.nodes.push(n); self.nodes.push(n);
@ -114,7 +185,7 @@ impl network::Network for CtbnNetwork {
&mut self.nodes[node_idx] &mut self.nodes[node_idx]
} }
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize { fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
self.adj_matrix self.adj_matrix
.as_ref() .as_ref()
.unwrap() .unwrap()
@ -133,7 +204,7 @@ impl network::Network for CtbnNetwork {
fn get_param_index_from_custom_parent_set( fn get_param_index_from_custom_parent_set(
&self, &self,
current_state: &Vec<StateType>, current_state: &NetworkProcessState,
parent_set: &BTreeSet<usize>, parent_set: &BTreeSet<usize>,
) -> usize { ) -> usize {
parent_set parent_set

@ -0,0 +1,114 @@
use std::collections::BTreeSet;
use crate::{
params::{Params, StateType},
process,
};
use super::{NetworkProcess, NetworkProcessState};
pub struct CtmpProcess {
param: Option<Params>,
}
impl CtmpProcess {
pub fn new() -> CtmpProcess {
CtmpProcess { param: None }
}
}
impl NetworkProcess for CtmpProcess {
fn initialize_adj_matrix(&mut self) {
unimplemented!("CtmpProcess has only one node")
}
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> {
match self.param {
None => {
self.param = Some(n);
Ok(0)
}
Some(_) => Err(process::NetworkError::NodeInsertionError(
"CtmpProcess has only one node".to_string(),
)),
}
}
fn add_edge(&mut self, _parent: usize, _child: usize) {
unimplemented!("CtmpProcess has only one node")
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
match self.param {
None => 0..0,
Some(_) => 0..1,
}
}
fn get_number_of_nodes(&self) -> usize {
match self.param {
None => 0,
Some(_) => 1,
}
}
fn get_node(&self, node_idx: usize) -> &crate::params::Params {
if node_idx == 0 {
self.param.as_ref().unwrap()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params {
if node_idx == 0 {
self.param.as_mut().unwrap()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
if node == 0 {
match current_state[0] {
StateType::Discrete(x) => x,
}
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_param_index_from_custom_parent_set(
&self,
_current_state: &NetworkProcessState,
_parent_set: &BTreeSet<usize>,
) -> usize {
unimplemented!("CtmpProcess has only one node")
}
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
match self.param {
Some(_) => {
if node == 0 {
BTreeSet::new()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
None => panic!("Uninitialized CtmpProcess"),
}
}
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
match self.param {
Some(_) => {
if node == 0 {
BTreeSet::new()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
None => panic!("Uninitialized CtmpProcess"),
}
}
}

@ -0,0 +1,142 @@
//! Module for dealing with reward functions
use crate::{
params::{self, ParamsTrait},
process,
};
use ndarray;
/// Instantiation of reward function and instantaneous reward
///
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)]
pub struct Reward {
pub transition_reward: f64,
pub instantaneous_reward: f64,
}
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
pub trait RewardFunction {
/// Given the current state and the previous state, it compute the reward.
///
/// # Arguments
///
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
/// * `previous_state`: an optional argument representing the previous state of the network
fn call(
&self,
current_state: process::NetworkProcessState,
previous_state: Option<process::NetworkProcessState>,
) -> Reward;
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
///
/// # Arguments
///
/// * `p`: any structure that implements the trait `process::NetworkProcess`
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}
/// Reward function over a factored state space
///
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
/// of the underling `NetworkProcess`
///
/// # Arguments
///
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
/// reward of a node
pub struct FactoredRewardFunction {
transition_reward: Vec<ndarray::Array2<f64>>,
instantaneous_reward: Vec<ndarray::Array1<f64>>,
}
impl FactoredRewardFunction {
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> {
&self.transition_reward[node_idx]
}
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> {
&mut self.transition_reward[node_idx]
}
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> {
&self.instantaneous_reward[node_idx]
}
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
&mut self.instantaneous_reward[node_idx]
}
}
impl RewardFunction for FactoredRewardFunction {
fn call(
&self,
current_state: process::NetworkProcessState,
previous_state: Option<process::NetworkProcessState>,
) -> Reward {
let instantaneous_reward: f64 = current_state
.iter()
.enumerate()
.map(|(idx, x)| {
let x = match x {
params::StateType::Discrete(x) => x,
};
self.instantaneous_reward[idx][*x]
})
.sum();
if let Some(previous_state) = previous_state {
let transition_reward = previous_state
.iter()
.zip(current_state.iter())
.enumerate()
.find_map(|(idx, (p, c))| -> Option<f64> {
let p = match p {
params::StateType::Discrete(p) => p,
};
let c = match c {
params::StateType::Discrete(c) => c,
};
if p != c {
Some(self.transition_reward[idx][[*p, *c]])
} else {
None
}
})
.unwrap_or(0.0);
Reward {
transition_reward,
instantaneous_reward,
}
} else {
Reward {
transition_reward: 0.0,
instantaneous_reward,
}
}
}
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self {
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![];
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![];
for i in p.get_node_indices() {
//This works only for discrete nodes!
let size: usize = p.get_node(i).get_reserved_space_as_parent();
instantaneous_reward.push(ndarray::Array1::zeros(size));
transition_reward.push(ndarray::Array2::zeros((size, size)));
}
FactoredRewardFunction {
transition_reward,
instantaneous_reward,
}
}
}

@ -1,28 +1,34 @@
//! Module containing methods for the sampling. //! Module containing methods for the sampling.
use crate::{ use crate::{
network::Network, params::ParamsTrait,
params::{self, ParamsTrait}, process::{NetworkProcess, NetworkProcessState},
}; };
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
pub trait Sampler: Iterator { #[derive(Clone)]
pub struct Sample {
pub t: f64,
pub state: NetworkProcessState,
}
pub trait Sampler: Iterator<Item = Sample> {
fn reset(&mut self); fn reset(&mut self);
} }
pub struct ForwardSampler<'a, T> pub struct ForwardSampler<'a, T>
where where
T: Network, T: NetworkProcess,
{ {
net: &'a T, net: &'a T,
rng: ChaCha8Rng, rng: ChaCha8Rng,
current_time: f64, current_time: f64,
current_state: Vec<params::StateType>, current_state: NetworkProcessState,
next_transitions: Vec<Option<f64>>, next_transitions: Vec<Option<f64>>,
} }
impl<'a, T: Network> ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> {
let rng: ChaCha8Rng = match seed { let rng: ChaCha8Rng = match seed {
//If a seed is present use it to initialize the random generator. //If a seed is present use it to initialize the random generator.
@ -42,8 +48,8 @@ impl<'a, T: Network> ForwardSampler<'a, T> {
} }
} }
impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
type Item = (f64, Vec<params::StateType>); type Item = Sample;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let ret_time = self.current_time.clone(); let ret_time = self.current_time.clone();
@ -96,11 +102,14 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> {
self.next_transitions[child] = None; self.next_transitions[child] = None;
} }
Some((ret_time, ret_state)) Some(Sample {
t: ret_time,
state: ret_state,
})
} }
} }
impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> {
fn reset(&mut self) { fn reset(&mut self) {
self.current_time = 0.0; self.current_time = 0.0;
self.current_state = self self.current_state = self

@ -4,10 +4,10 @@ pub mod constraint_based_algorithm;
pub mod hypothesis_test; pub mod hypothesis_test;
pub mod score_based_algorithm; pub mod score_based_algorithm;
pub mod score_function; pub mod score_function;
use crate::{network, tools}; use crate::{process, tools};
pub trait StructureLearningAlgorithm { pub trait StructureLearningAlgorithm {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where where
T: network::Network; T: process::NetworkProcess;
} }

@ -6,7 +6,7 @@ use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
use crate::params::*; use crate::params::*;
use crate::{network, parameter_learning}; use crate::{parameter_learning, process};
pub trait HypothesisTest { pub trait HypothesisTest {
fn call<T, P>( fn call<T, P>(
@ -18,7 +18,7 @@ pub trait HypothesisTest {
cache: &mut parameter_learning::Cache<P>, cache: &mut parameter_learning::Cache<P>,
) -> bool ) -> bool
where where
T: network::Network, T: process::NetworkProcess,
P: parameter_learning::ParameterLearning; P: parameter_learning::ParameterLearning;
} }
@ -221,7 +221,7 @@ impl HypothesisTest for ChiSquare {
cache: &mut parameter_learning::Cache<P>, cache: &mut parameter_learning::Cache<P>,
) -> bool ) -> bool
where where
T: network::Network, T: process::NetworkProcess,
P: parameter_learning::ParameterLearning, P: parameter_learning::ParameterLearning,
{ {
// Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM

@ -4,7 +4,7 @@ use std::collections::BTreeSet;
use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm; use crate::structure_learning::StructureLearningAlgorithm;
use crate::{network, tools}; use crate::{process, tools};
pub struct HillClimbing<S: ScoreFunction> { pub struct HillClimbing<S: ScoreFunction> {
score_function: S, score_function: S,
@ -23,7 +23,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where where
T: network::Network, T: process::NetworkProcess,
{ {
//Check the coherence between dataset and network //Check the coherence between dataset and network
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {

@ -5,7 +5,7 @@ use std::collections::BTreeSet;
use ndarray::prelude::*; use ndarray::prelude::*;
use statrs::function::gamma; use statrs::function::gamma;
use crate::{network, parameter_learning, params, tools}; use crate::{parameter_learning, params, process, tools};
pub trait ScoreFunction { pub trait ScoreFunction {
fn call<T>( fn call<T>(
@ -16,7 +16,7 @@ pub trait ScoreFunction {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> f64 ) -> f64
where where
T: network::Network; T: process::NetworkProcess;
} }
pub struct LogLikelihood { pub struct LogLikelihood {
@ -41,7 +41,7 @@ impl LogLikelihood {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> (f64, Array3<usize>) ) -> (f64, Array3<usize>)
where where
T: network::Network, T: process::NetworkProcess,
{ {
//Identify the type of node used //Identify the type of node used
match &net.get_node(node) { match &net.get_node(node) {
@ -100,7 +100,7 @@ impl ScoreFunction for LogLikelihood {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> f64 ) -> f64
where where
T: network::Network, T: process::NetworkProcess,
{ {
self.compute_score(net, node, parent_set, dataset).0 self.compute_score(net, node, parent_set, dataset).0
} }
@ -127,7 +127,7 @@ impl ScoreFunction for BIC {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> f64 ) -> f64
where where
T: network::Network, T: process::NetworkProcess,
{ {
//Compute the log-likelihood //Compute the log-likelihood
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);

@ -3,7 +3,7 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::sampling::{ForwardSampler, Sampler}; use crate::sampling::{ForwardSampler, Sampler};
use crate::{network, params}; use crate::{params, process};
pub struct Trajectory { pub struct Trajectory {
time: Array1<f64>, time: Array1<f64>,
@ -51,7 +51,7 @@ impl Dataset {
} }
} }
pub fn trajectory_generator<T: network::Network>( pub fn trajectory_generator<T: process::NetworkProcess>(
net: &T, net: &T,
n_trajectories: u64, n_trajectories: u64,
t_end: f64, t_end: f64,
@ -69,18 +69,18 @@ pub fn trajectory_generator<T: network::Network>(
let mut time: Vec<f64> = Vec::new(); let mut time: Vec<f64> = Vec::new();
//Configuration of the process variables at time t initialized with an uniform //Configuration of the process variables at time t initialized with an uniform
//distribution. //distribution.
let mut events: Vec<Vec<params::StateType>> = Vec::new(); let mut events: Vec<process::NetworkProcessState> = Vec::new();
//Current Time and Current State //Current Time and Current State
let (mut t, mut current_state) = sampler.next().unwrap(); let mut sample = sampler.next().unwrap();
//Generate new samples until ending time is reached. //Generate new samples until ending time is reached.
while t < t_end { while sample.t < t_end {
time.push(t); time.push(sample.t);
events.push(current_state); events.push(sample.state);
(t, current_state) = sampler.next().unwrap(); sample = sampler.next().unwrap();
} }
current_state = events.last().unwrap().clone(); let current_state = events.last().unwrap().clone();
events.push(current_state); events.push(current_state);
//Add t_end as last time. //Add t_end as last time.

@ -1,9 +1,12 @@
mod utils; mod utils;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use reCTBN::ctbn::*;
use reCTBN::network::Network; use approx::AbsDiffEq;
use ndarray::arr3;
use reCTBN::params::{self, ParamsTrait}; use reCTBN::params::{self, ParamsTrait};
use reCTBN::process::NetworkProcess;
use reCTBN::process::{ctbn::*};
use utils::generate_discrete_time_continous_node; use utils::generate_discrete_time_continous_node;
#[test] #[test]
@ -129,3 +132,245 @@ fn compute_index_from_custom_parent_set() {
); );
assert_eq!(2, idx); assert_eq!(2, idx);
} }
#[test]
fn simple_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.initialize_adj_matrix();
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0);
let p_ctbn = p_ctbn.get_cim().as_ref().unwrap();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON));
}
#[test]
fn chain_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
let p_ctmp_handmade = arr3(&[[
[
-1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}
#[test]
fn chainfork_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
let n4 = net
.add_node(generate_discrete_time_continous_node(String::from("n4"), 2))
.unwrap();
net.add_edge(n1, n3);
net.add_edge(n2, n3);
net.add_edge(n3, n4);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n4) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
let p_ctmp_handmade = arr3(&[[
[
-2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
-5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}

@ -0,0 +1,127 @@
mod utils;
use std::collections::BTreeSet;
use reCTBN::{
params,
params::ParamsTrait,
process::{ctmp::*, NetworkProcess},
};
use utils::generate_discrete_time_continous_node;
#[test]
fn define_simple_ctmp() {
let _ = CtmpProcess::new();
assert!(true);
}
#[test]
fn add_node_to_ctmp() {
let mut net = CtmpProcess::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
}
#[test]
fn add_two_nodes_to_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
match n2 {
Ok(_) => assert!(false),
Err(_) => assert!(true),
};
}
#[test]
#[should_panic]
fn add_edge_to_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
net.add_edge(0, 1)
}
#[test]
fn childen_and_parents() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(0, net.get_parent_set(0).len());
assert_eq!(0, net.get_children_set(0).len());
}
#[test]
#[should_panic]
fn get_childen_panic() {
let net = CtmpProcess::new();
net.get_children_set(0);
}
#[test]
#[should_panic]
fn get_childen_panic2() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.get_children_set(1);
}
#[test]
#[should_panic]
fn get_parent_panic() {
let net = CtmpProcess::new();
net.get_parent_set(0);
}
#[test]
#[should_panic]
fn get_parent_panic2() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.get_parent_set(1);
}
#[test]
fn compute_index_ctmp() {
let mut net = CtmpProcess::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(
String::from("n1"),
10,
))
.unwrap();
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]);
assert_eq!(6, idx);
}
#[test]
#[should_panic]
fn compute_index_from_custom_parent_set_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(
String::from("n1"),
10,
))
.unwrap();
let _idx = net.get_param_index_from_custom_parent_set(
&vec![params::StateType::Discrete(6)],
&BTreeSet::from([0])
);
}

@ -2,8 +2,8 @@
mod utils; mod utils;
use ndarray::arr3; use ndarray::arr3;
use reCTBN::ctbn::*; use reCTBN::process::ctbn::*;
use reCTBN::network::Network; use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::*; use reCTBN::parameter_learning::*;
use reCTBN::params; use reCTBN::params;
use reCTBN::tools::*; use reCTBN::tools::*;

@ -0,0 +1,118 @@
mod utils;
use ndarray::*;
use utils::generate_discrete_time_continous_node;
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward_function::*, params};
#[test]
fn simple_factored_reward_function_binary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
assert_eq!(rf.call(s0.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s0.clone(), Some(s0.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s1.clone())), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
}
#[test]
fn simple_factored_reward_function_ternary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];
assert_eq!(rf.call(s0.clone(), Some(s1.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s0.clone(), Some(s2.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(s1.clone(), Some(s0.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s1.clone(), Some(s2.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(s2.clone(), Some(s0.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
assert_eq!(rf.call(s2.clone(), Some(s1.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
}
#[test]
fn factored_reward_function_two_nodes() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2);
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)];
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
assert_eq!(rf.call(s00.clone(), Some(s01.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s02.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s00.clone(), Some(s10.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(s01.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s01.clone(), Some(s02.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s01.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s02.clone(), Some(s00.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s02.clone(), Some(s01.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s02.clone(), Some(s12.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(s10.clone(), Some(s11.clone())), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s10.clone(), Some(s12.clone())), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s10.clone(), Some(s00.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(s11.clone(), Some(s10.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s11.clone(), Some(s12.clone())), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s11.clone(), Some(s01.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(s12.clone(), Some(s10.clone())), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(s12.clone(), Some(s11.clone())), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(s12.clone(), Some(s02.clone())), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
}

@ -4,8 +4,8 @@ mod utils;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3}; use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*; use reCTBN::process::ctbn::*;
use reCTBN::network::Network; use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::BayesianApproach;
use reCTBN::parameter_learning::Cache; use reCTBN::parameter_learning::Cache;
use reCTBN::params; use reCTBN::params;

@ -1,6 +1,6 @@
use ndarray::{arr1, arr2, arr3}; use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*; use reCTBN::process::ctbn::*;
use reCTBN::network::Network; use reCTBN::process::NetworkProcess;
use reCTBN::params; use reCTBN::params;
use reCTBN::tools::*; use reCTBN::tools::*;

Loading…
Cancel
Save