13 feature add ctmp #73

Merged
AlessandroBregoli merged 6 commits from 13-feature-add-ctmp into dev 2 years ago
  1. 3
      reCTBN/src/lib.rs
  2. 12
      reCTBN/src/parameter_learning.rs
  3. 4
      reCTBN/src/params.rs
  4. 5
      reCTBN/src/process.rs
  5. 84
      reCTBN/src/process/ctbn.rs
  6. 118
      reCTBN/src/process/ctmp.rs
  7. 10
      reCTBN/src/sampling.rs
  8. 4
      reCTBN/src/structure_learning.rs
  9. 6
      reCTBN/src/structure_learning/hypothesis_test.rs
  10. 4
      reCTBN/src/structure_learning/score_based_algorithm.rs
  11. 10
      reCTBN/src/structure_learning/score_function.rs
  12. 4
      reCTBN/src/tools.rs
  13. 249
      reCTBN/tests/ctbn.rs
  14. 127
      reCTBN/tests/ctmp.rs
  15. 4
      reCTBN/tests/parameter_learning.rs
  16. 4
      reCTBN/tests/structure_learning.rs
  17. 4
      reCTBN/tests/tools.rs

@ -3,10 +3,9 @@
#[cfg(test)] #[cfg(test)]
extern crate approx; extern crate approx;
pub mod ctbn;
pub mod network;
pub mod parameter_learning; pub mod parameter_learning;
pub mod params; pub mod params;
pub mod process;
pub mod sampling; pub mod sampling;
pub mod structure_learning; pub mod structure_learning;
pub mod tools; pub mod tools;

@ -5,10 +5,10 @@ use std::collections::BTreeSet;
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::params::*; use crate::params::*;
use crate::{network, tools}; use crate::{process, tools};
pub trait ParameterLearning { pub trait ParameterLearning {
fn fit<T: network::Network>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
@ -17,7 +17,7 @@ pub trait ParameterLearning {
) -> Params; ) -> Params;
} }
pub fn sufficient_statistics<T: network::Network>( pub fn sufficient_statistics<T: process::NetworkProcess>(
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
node: usize, node: usize,
@ -73,7 +73,7 @@ pub fn sufficient_statistics<T: network::Network>(
pub struct MLE {} pub struct MLE {}
impl ParameterLearning for MLE { impl ParameterLearning for MLE {
fn fit<T: network::Network>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
@ -120,7 +120,7 @@ pub struct BayesianApproach {
} }
impl ParameterLearning for BayesianApproach { impl ParameterLearning for BayesianApproach {
fn fit<T: network::Network>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &tools::Dataset,
@ -177,7 +177,7 @@ impl<P: ParameterLearning> Cache<P> {
dataset, dataset,
} }
} }
pub fn fit<T: network::Network>( pub fn fit<T: process::NetworkProcess>(
&mut self, &mut self,
net: &T, net: &T,
node: usize, node: usize,

@ -267,11 +267,13 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
))); )));
} }
let domain_size = domain_size as f64;
// Check if each row sum up to 0 // Check if each row sum up to 0
if cim if cim
.sum_axis(Axis(2)) .sum_axis(Axis(2))
.iter() .iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size)
{ {
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0", "The sum of each row must be 0",

@ -1,5 +1,8 @@
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs //! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
pub mod ctbn;
pub mod ctmp;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use thiserror::Error; use thiserror::Error;
@ -15,7 +18,7 @@ pub enum NetworkError {
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
/// as a CTBN). /// as a CTBN).
pub trait Network { pub trait NetworkProcess {
fn initialize_adj_matrix(&mut self); fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add an **directed edge** between a two nodes of the network. /// Add an **directed edge** between a two nodes of the network.

@ -4,8 +4,11 @@ use std::collections::BTreeSet;
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::network; use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
use crate::params::{Params, ParamsTrait, StateType}; use crate::process;
use super::ctmp::CtmpProcess;
use super::NetworkProcess;
/// It represents both the structure and the parameters of a CTBN. /// It represents both the structure and the parameters of a CTBN.
/// ///
@ -20,9 +23,9 @@ use crate::params::{Params, ParamsTrait, StateType};
/// ///
/// ```rust /// ```rust
/// use std::collections::BTreeSet; /// use std::collections::BTreeSet;
/// use reCTBN::network::Network; /// use reCTBN::process::NetworkProcess;
/// use reCTBN::params; /// use reCTBN::params;
/// use reCTBN::ctbn::*; /// use reCTBN::process::ctbn::*;
/// ///
/// //Create the domain for a discrete node /// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new(); /// let mut domain = BTreeSet::new();
@ -67,9 +70,78 @@ impl CtbnNetwork {
nodes: Vec::new(), nodes: Vec::new(),
} }
} }
///Transform the **CTBN** into a **CTMP**
///
/// # Return
///
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
pub fn amalgamation(&self) -> CtmpProcess {
let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
let state_space = variables_domain.product();
let variables_set = BTreeSet::from_iter(self.get_node_indices());
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space));
for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: Vec<StateType> = current_state
.iter()
.map(|x| StateType::Discrete(*x))
.collect();
for idx_node in 0..self.nodes.len() {
let p = match self.get_node(idx_node) {
Params::DiscreteStatesContinousTime(p) => p,
};
for next_node_state in 0..variables_domain[idx_node] {
let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state;
let next_state_statetype: Vec<StateType> =
next_state.iter().map(|x| StateType::Discrete(*x)).collect();
let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype,
&variables_set,
);
amalgamated_cim[[0, idx_current_state, idx_next_state]] +=
p.get_cim().as_ref().unwrap()[[
self.get_param_index_network(idx_node, &current_state_statetype),
current_state[idx_node],
next_node_state,
]];
}
}
}
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new(
"ctmp".to_string(),
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
);
println!("{:?}", amalgamated_cim);
amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new();
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param))
.unwrap();
return ctmp;
}
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> {
let mut state = state;
let mut array_state = Array1::zeros(variables_domain.shape()[0]);
for (idx, var) in variables_domain.indexed_iter() {
array_state[idx] = state % var;
state = state / var;
}
return array_state;
}
} }
impl network::Network for CtbnNetwork { impl process::NetworkProcess for CtbnNetwork {
/// Initialize an Adjacency matrix. /// Initialize an Adjacency matrix.
fn initialize_adj_matrix(&mut self) { fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros( self.adj_matrix = Some(Array2::<u16>::zeros(
@ -78,7 +150,7 @@ impl network::Network for CtbnNetwork {
} }
/// Add a new node. /// Add a new node.
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> { fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
n.reset_params(); n.reset_params();
self.adj_matrix = Option::None; self.adj_matrix = Option::None;
self.nodes.push(n); self.nodes.push(n);

@ -0,0 +1,118 @@
use std::collections::BTreeSet;
use crate::{
params::{Params, StateType},
process,
};
use super::NetworkProcess;
pub struct CtmpProcess {
param: Option<Params>,
}
impl CtmpProcess {
pub fn new() -> CtmpProcess {
CtmpProcess { param: None }
}
}
impl NetworkProcess for CtmpProcess {
fn initialize_adj_matrix(&mut self) {
unimplemented!("CtmpProcess has only one node")
}
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> {
match self.param {
None => {
self.param = Some(n);
Ok(0)
}
Some(_) => Err(process::NetworkError::NodeInsertionError(
"CtmpProcess has only one node".to_string(),
)),
}
}
fn add_edge(&mut self, _parent: usize, _child: usize) {
unimplemented!("CtmpProcess has only one node")
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
match self.param {
None => 0..0,
Some(_) => 0..1,
}
}
fn get_number_of_nodes(&self) -> usize {
match self.param {
None => 0,
Some(_) => 1,
}
}
fn get_node(&self, node_idx: usize) -> &crate::params::Params {
if node_idx == 0 {
self.param.as_ref().unwrap()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params {
if node_idx == 0 {
self.param.as_mut().unwrap()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_param_index_network(
&self,
node: usize,
current_state: &Vec<crate::params::StateType>,
) -> usize {
if node == 0 {
match current_state[0] {
StateType::Discrete(x) => x,
}
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_param_index_from_custom_parent_set(
&self,
_current_state: &Vec<crate::params::StateType>,
_parent_set: &std::collections::BTreeSet<usize>,
) -> usize {
unimplemented!("CtmpProcess has only one node")
}
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
match self.param {
Some(_) => {
if node == 0 {
BTreeSet::new()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
None => panic!("Uninitialized CtmpProcess"),
}
}
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
match self.param {
Some(_) => {
if node == 0 {
BTreeSet::new()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
None => panic!("Uninitialized CtmpProcess"),
}
}
}

@ -1,8 +1,8 @@
//! Module containing methods for the sampling. //! Module containing methods for the sampling.
use crate::{ use crate::{
network::Network,
params::{self, ParamsTrait}, params::{self, ParamsTrait},
process::NetworkProcess,
}; };
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
@ -13,7 +13,7 @@ pub trait Sampler: Iterator {
pub struct ForwardSampler<'a, T> pub struct ForwardSampler<'a, T>
where where
T: Network, T: NetworkProcess,
{ {
net: &'a T, net: &'a T,
rng: ChaCha8Rng, rng: ChaCha8Rng,
@ -22,7 +22,7 @@ where
next_transitions: Vec<Option<f64>>, next_transitions: Vec<Option<f64>>,
} }
impl<'a, T: Network> ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> {
let rng: ChaCha8Rng = match seed { let rng: ChaCha8Rng = match seed {
//If a seed is present use it to initialize the random generator. //If a seed is present use it to initialize the random generator.
@ -42,7 +42,7 @@ impl<'a, T: Network> ForwardSampler<'a, T> {
} }
} }
impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
type Item = (f64, Vec<params::StateType>); type Item = (f64, Vec<params::StateType>);
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
@ -100,7 +100,7 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> {
} }
} }
impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> {
fn reset(&mut self) { fn reset(&mut self) {
self.current_time = 0.0; self.current_time = 0.0;
self.current_state = self self.current_state = self

@ -4,10 +4,10 @@ pub mod constraint_based_algorithm;
pub mod hypothesis_test; pub mod hypothesis_test;
pub mod score_based_algorithm; pub mod score_based_algorithm;
pub mod score_function; pub mod score_function;
use crate::{network, tools}; use crate::{process, tools};
pub trait StructureLearningAlgorithm { pub trait StructureLearningAlgorithm {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where where
T: network::Network; T: process::NetworkProcess;
} }

@ -6,7 +6,7 @@ use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF}; use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::params::*; use crate::params::*;
use crate::{network, parameter_learning}; use crate::{parameter_learning, process};
pub trait HypothesisTest { pub trait HypothesisTest {
fn call<T, P>( fn call<T, P>(
@ -18,7 +18,7 @@ pub trait HypothesisTest {
cache: &mut parameter_learning::Cache<P>, cache: &mut parameter_learning::Cache<P>,
) -> bool ) -> bool
where where
T: network::Network, T: process::NetworkProcess,
P: parameter_learning::ParameterLearning; P: parameter_learning::ParameterLearning;
} }
@ -135,7 +135,7 @@ impl HypothesisTest for ChiSquare {
cache: &mut parameter_learning::Cache<P>, cache: &mut parameter_learning::Cache<P>,
) -> bool ) -> bool
where where
T: network::Network, T: process::NetworkProcess,
P: parameter_learning::ParameterLearning, P: parameter_learning::ParameterLearning,
{ {
// Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM

@ -4,7 +4,7 @@ use std::collections::BTreeSet;
use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm; use crate::structure_learning::StructureLearningAlgorithm;
use crate::{network, tools}; use crate::{process, tools};
pub struct HillClimbing<S: ScoreFunction> { pub struct HillClimbing<S: ScoreFunction> {
score_function: S, score_function: S,
@ -23,7 +23,7 @@ impl<S: ScoreFunction> HillClimbing<S> {
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where where
T: network::Network, T: process::NetworkProcess,
{ {
//Check the coherence between dataset and network //Check the coherence between dataset and network
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {

@ -5,7 +5,7 @@ use std::collections::BTreeSet;
use ndarray::prelude::*; use ndarray::prelude::*;
use statrs::function::gamma; use statrs::function::gamma;
use crate::{network, parameter_learning, params, tools}; use crate::{parameter_learning, params, process, tools};
pub trait ScoreFunction { pub trait ScoreFunction {
fn call<T>( fn call<T>(
@ -16,7 +16,7 @@ pub trait ScoreFunction {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> f64 ) -> f64
where where
T: network::Network; T: process::NetworkProcess;
} }
pub struct LogLikelihood { pub struct LogLikelihood {
@ -41,7 +41,7 @@ impl LogLikelihood {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> (f64, Array3<usize>) ) -> (f64, Array3<usize>)
where where
T: network::Network, T: process::NetworkProcess,
{ {
//Identify the type of node used //Identify the type of node used
match &net.get_node(node) { match &net.get_node(node) {
@ -100,7 +100,7 @@ impl ScoreFunction for LogLikelihood {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> f64 ) -> f64
where where
T: network::Network, T: process::NetworkProcess,
{ {
self.compute_score(net, node, parent_set, dataset).0 self.compute_score(net, node, parent_set, dataset).0
} }
@ -127,7 +127,7 @@ impl ScoreFunction for BIC {
dataset: &tools::Dataset, dataset: &tools::Dataset,
) -> f64 ) -> f64
where where
T: network::Network, T: process::NetworkProcess,
{ {
//Compute the log-likelihood //Compute the log-likelihood
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);

@ -3,7 +3,7 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::sampling::{ForwardSampler, Sampler}; use crate::sampling::{ForwardSampler, Sampler};
use crate::{network, params}; use crate::{params, process};
pub struct Trajectory { pub struct Trajectory {
time: Array1<f64>, time: Array1<f64>,
@ -51,7 +51,7 @@ impl Dataset {
} }
} }
pub fn trajectory_generator<T: network::Network>( pub fn trajectory_generator<T: process::NetworkProcess>(
net: &T, net: &T,
n_trajectories: u64, n_trajectories: u64,
t_end: f64, t_end: f64,

@ -1,9 +1,12 @@
mod utils; mod utils;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use reCTBN::ctbn::*;
use reCTBN::network::Network; use approx::AbsDiffEq;
use ndarray::arr3;
use reCTBN::params::{self, ParamsTrait}; use reCTBN::params::{self, ParamsTrait};
use reCTBN::process::NetworkProcess;
use reCTBN::process::{ctbn::*};
use utils::generate_discrete_time_continous_node; use utils::generate_discrete_time_continous_node;
#[test] #[test]
@ -129,3 +132,245 @@ fn compute_index_from_custom_parent_set() {
); );
assert_eq!(2, idx); assert_eq!(2, idx);
} }
#[test]
fn simple_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.initialize_adj_matrix();
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0);
let p_ctbn = p_ctbn.get_cim().as_ref().unwrap();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON));
}
#[test]
fn chain_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
let p_ctmp_handmade = arr3(&[[
[
-1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}
#[test]
fn chainfork_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
let n4 = net
.add_node(generate_discrete_time_continous_node(String::from("n4"), 2))
.unwrap();
net.add_edge(n1, n3);
net.add_edge(n2, n3);
net.add_edge(n3, n4);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n4) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
let p_ctmp_handmade = arr3(&[[
[
-2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
-5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}

@ -0,0 +1,127 @@
mod utils;
use std::collections::BTreeSet;
use reCTBN::{
params,
params::ParamsTrait,
process::{ctmp::*, NetworkProcess},
};
use utils::generate_discrete_time_continous_node;
#[test]
fn define_simple_ctmp() {
let _ = CtmpProcess::new();
assert!(true);
}
#[test]
fn add_node_to_ctmp() {
let mut net = CtmpProcess::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
}
#[test]
fn add_two_nodes_to_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
match n2 {
Ok(_) => assert!(false),
Err(_) => assert!(true),
};
}
#[test]
#[should_panic]
fn add_edge_to_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
net.add_edge(0, 1)
}
#[test]
fn childen_and_parents() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(0, net.get_parent_set(0).len());
assert_eq!(0, net.get_children_set(0).len());
}
#[test]
#[should_panic]
fn get_childen_panic() {
let net = CtmpProcess::new();
net.get_children_set(0);
}
#[test]
#[should_panic]
fn get_childen_panic2() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.get_children_set(1);
}
#[test]
#[should_panic]
fn get_parent_panic() {
let net = CtmpProcess::new();
net.get_parent_set(0);
}
#[test]
#[should_panic]
fn get_parent_panic2() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.get_parent_set(1);
}
#[test]
fn compute_index_ctmp() {
let mut net = CtmpProcess::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(
String::from("n1"),
10,
))
.unwrap();
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]);
assert_eq!(6, idx);
}
#[test]
#[should_panic]
fn compute_index_from_custom_parent_set_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(
String::from("n1"),
10,
))
.unwrap();
let _idx = net.get_param_index_from_custom_parent_set(
&vec![params::StateType::Discrete(6)],
&BTreeSet::from([0])
);
}

@ -2,8 +2,8 @@
mod utils; mod utils;
use ndarray::arr3; use ndarray::arr3;
use reCTBN::ctbn::*; use reCTBN::process::ctbn::*;
use reCTBN::network::Network; use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::*; use reCTBN::parameter_learning::*;
use reCTBN::params; use reCTBN::params;
use reCTBN::tools::*; use reCTBN::tools::*;

@ -4,8 +4,8 @@ mod utils;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3}; use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*; use reCTBN::process::ctbn::*;
use reCTBN::network::Network; use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::BayesianApproach;
use reCTBN::parameter_learning::Cache; use reCTBN::parameter_learning::Cache;
use reCTBN::params; use reCTBN::params;

@ -1,6 +1,6 @@
use ndarray::{arr1, arr2, arr3}; use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*; use reCTBN::process::ctbn::*;
use reCTBN::network::Network; use reCTBN::process::NetworkProcess;
use reCTBN::params; use reCTBN::params;
use reCTBN::tools::*; use reCTBN::tools::*;

Loading…
Cancel
Save