CTPC {
+ pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC
{
+ CTPC {
+ parameter_learning,
+ Ftest,
+ Chi2test,
+ }
+ }
+}
+
+impl<P: parameter_learning::ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
+    /// Learn the structure of `net` from `dataset` with the CTPC
+    /// constraint-based algorithm and return the network with the learned
+    /// edges added.
+    ///
+    /// # Panics
+    ///
+    /// Panics when the number of network nodes differs from the number of
+    /// variables in the dataset.
+    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
+    where
+        T: process::NetworkProcess,
+    {
+        //Check the coherence between dataset and network
+        if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
+            panic!("Dataset and Network must have the same number of variables.")
+        }
+
+        //Make the network mutable.
+        let mut net = net;
+
+        net.initialize_adj_matrix();
+
+        //Learn the parent set of every node in parallel; edges are added
+        //sequentially afterwards because `net` cannot be mutated concurrently.
+        let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
+        learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| {
+            let mut cache = Cache::new(&self.parameter_learning);
+            //Start from the complete candidate parent set (every other node).
+            let mut candidate_parent_set: BTreeSet<usize> = net
+                .get_node_indices()
+                .into_iter()
+                .filter(|x| x != &child_node)
+                .collect();
+            let mut separation_set_size = 0;
+            //Grow the separation set size until it reaches the candidate set size.
+            while separation_set_size < candidate_parent_set.len() {
+                let mut candidate_parent_set_TMP = candidate_parent_set.clone();
+                for parent_node in candidate_parent_set.iter() {
+                    for separation_set in candidate_parent_set
+                        .iter()
+                        .filter(|x| x != &parent_node)
+                        .map(|x| *x)
+                        .combinations(separation_set_size)
+                    {
+                        let separation_set = separation_set.into_iter().collect();
+                        //A candidate parent is dropped as soon as BOTH tests
+                        //declare independence given some separation set.
+                        if self.Ftest.call(
+                            &net,
+                            child_node,
+                            *parent_node,
+                            &separation_set,
+                            dataset,
+                            &mut cache,
+                        ) && self.Chi2test.call(
+                            &net,
+                            child_node,
+                            *parent_node,
+                            &separation_set,
+                            dataset,
+                            &mut cache,
+                        ) {
+                            candidate_parent_set_TMP.remove(parent_node);
+                            break;
+                        }
+                    }
+                }
+                candidate_parent_set = candidate_parent_set_TMP;
+                separation_set_size += 1;
+            }
+            (child_node, candidate_parent_set)
+        }));
+        for (child_node, candidate_parent_set) in learned_parent_sets {
+            for parent_node in candidate_parent_set.iter() {
+                net.add_edge(*parent_node, child_node);
+            }
+        }
+        net
+    }
+}
diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs
new file mode 100644
index 0000000..311ec47
--- /dev/null
+++ b/reCTBN/src/structure_learning/hypothesis_test.rs
@@ -0,0 +1,261 @@
+//! Module for constraint based algorithms containing hypothesis test algorithms like chi-squared test, F test, etc...
+
+use std::collections::BTreeSet;
+
+use ndarray::{Array3, Axis};
+use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
+
+use crate::params::*;
+use crate::structure_learning::constraint_based_algorithm::Cache;
+use crate::{parameter_learning, process, tools::Dataset};
+
+pub trait HypothesisTest {
+    /// Test whether `child_node` is conditionally independent of
+    /// `parent_node` given `separation_set`, using parameters learned from
+    /// `dataset` (memoized through `cache`).
+    ///
+    /// Returns `true` when the two nodes are judged independent.
+    fn call<T, P>(
+        &self,
+        net: &T,
+        child_node: usize,
+        parent_node: usize,
+        separation_set: &BTreeSet<usize>,
+        dataset: &Dataset,
+        cache: &mut Cache<P>,
+    ) -> bool
+    where
+        T: process::NetworkProcess,
+        P: parameter_learning::ParameterLearning;
+}
+
+/// Does the chi-squared test (χ2 test).
+///
+/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
+/// a relationship (dependence) between the variables.
+///
+/// # Arguments
+///
+/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
+/// in other words is the risk of concluding that an association between the variables exists
+/// when there is no actual association.
+
+pub struct ChiSquare {
+    // Significance level of the test (probability of rejecting a true null
+    // hypothesis); see the doc comment above.
+    alpha: f64,
+}
+
+/// Does the F-test.
+///
+/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
+/// a relationship (dependence) between the variables.
+///
+/// # Arguments
+///
+/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
+/// in other words is the risk of concluding that an association between the variables exists
+/// when there is no actual association.
+
+pub struct F {
+    // Significance level of the test (probability of rejecting a true null
+    // hypothesis); see the doc comment above.
+    alpha: f64,
+}
+
+impl F {
+    pub fn new(alpha: f64) -> F {
+        F { alpha }
+    }
+
+    /// Compare two matrices extracted from two 3rd-order tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `i` - Position of the matrix of `M1` to compare with `M2`.
+    /// * `M1` - 3rd-order tensor 1 (transition counts).
+    /// * `cim_1` - 3rd-order tensor with the CIMs matching `M1`.
+    /// * `j` - Position of the matrix of `M2` to compare with `M1`.
+    /// * `M2` - 3rd-order tensor 2 (transition counts).
+    /// * `cim_2` - 3rd-order tensor with the CIMs matching `M2`.
+    ///
+    /// # Returns
+    ///
+    /// * `true` - when the matrices `M1` and `M2` are very similar, then **independent**.
+    /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
+    pub fn compare_matrices(
+        &self,
+        i: usize,
+        M1: &Array3<usize>,
+        cim_1: &Array3<f64>,
+        j: usize,
+        M2: &Array3<usize>,
+        cim_2: &Array3<f64>,
+    ) -> bool {
+        let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
+        let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
+        let cim_1 = cim_1.index_axis(Axis(0), i);
+        let cim_2 = cim_2.index_axis(Axis(0), j);
+        let r1 = M1.sum_axis(Axis(1));
+        let r2 = M2.sum_axis(Axis(1));
+        let q1 = cim_1.diag();
+        let q2 = cim_2.diag();
+        //Two-sided F-test on the ratio of the exit rates of each state.
+        for idx in 0..r1.shape()[0] {
+            let s = q2[idx] / q1[idx];
+            let F = FisherSnedecor::new(r1[idx], r2[idx]).unwrap();
+            let s = F.cdf(s);
+            let lim_sx = self.alpha / 2.0;
+            let lim_dx = 1.0 - (self.alpha / 2.0);
+            if s < lim_sx || s > lim_dx {
+                return false;
+            }
+        }
+        true
+    }
+}
+
+impl HypothesisTest for F {
+    fn call<T, P>(
+        &self,
+        net: &T,
+        child_node: usize,
+        parent_node: usize,
+        separation_set: &BTreeSet<usize>,
+        dataset: &Dataset,
+        cache: &mut Cache<P>,
+    ) -> bool
+    where
+        T: process::NetworkProcess,
+        P: parameter_learning::ParameterLearning,
+    {
+        //Parameters learned conditioning on the separation set only.
+        let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
+            Params::DiscreteStatesContinousTime(node) => node,
+        };
+        let mut extended_separation_set = separation_set.clone();
+        extended_separation_set.insert(parent_node);
+
+        //Parameters learned conditioning on separation set plus candidate parent.
+        let P_big = match cache.fit(
+            net,
+            &dataset,
+            child_node,
+            Some(extended_separation_set.clone()),
+        ) {
+            Params::DiscreteStatesContinousTime(node) => node,
+        };
+        //Cardinality product of the parents ordered before `parent_node`,
+        //used to map each CIM index of P_big onto the matching one of P_small.
+        let partial_cardinality_product: usize = extended_separation_set
+            .iter()
+            .take_while(|x| **x != parent_node)
+            .map(|x| net.get_node(*x).get_reserved_space_as_parent())
+            .product();
+        for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
+            let idx_M_small: usize = idx_M_big % partial_cardinality_product
+                + (idx_M_big
+                    / (partial_cardinality_product
+                        * net.get_node(parent_node).get_reserved_space_as_parent()))
+                    * partial_cardinality_product;
+            if !self.compare_matrices(
+                idx_M_small,
+                P_small.get_transitions().as_ref().unwrap(),
+                P_small.get_cim().as_ref().unwrap(),
+                idx_M_big,
+                P_big.get_transitions().as_ref().unwrap(),
+                P_big.get_cim().as_ref().unwrap(),
+            ) {
+                return false;
+            }
+        }
+        return true;
+    }
+}
+
+impl ChiSquare {
+    pub fn new(alpha: f64) -> ChiSquare {
+        ChiSquare { alpha }
+    }
+
+    /// Compare two matrices extracted from two 3rd-order tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `i` - Position of the matrix of `M1` to compare with `M2`.
+    /// * `M1` - 3rd-order tensor 1.
+    /// * `j` - Position of the matrix of `M2` to compare with `M1`.
+    /// * `M2` - 3rd-order tensor 2.
+    ///
+    /// # Returns
+    ///
+    /// * `true` - when the matrices `M1` and `M2` are very similar, then **independent**.
+    /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
+    pub fn compare_matrices(
+        &self,
+        i: usize,
+        M1: &Array3<usize>,
+        j: usize,
+        M2: &Array3<usize>,
+    ) -> bool {
+        // Bregoli, A., Scutari, M. and Stella, F., 2021.
+        // A constraint-based algorithm for the structural learning of
+        // continuous-time Bayesian networks.
+        // International Journal of Approximate Reasoning, 138, pp.105-122.
+        // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
+        let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
+        let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
+        //Scaling factor between the two samples (ratio of the row totals).
+        let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
+        let K = K.mapv(f64::sqrt);
+        // Reshape to column vector.
+        let K = {
+            let n = K.len();
+            K.into_shape((n, 1)).unwrap()
+        };
+        let L = 1.0 / &K;
+        //Two-sample chi-squared statistic, row by row; the diagonal is
+        //excluded because self-transitions are not defined.
+        let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
+        X_2.diag_mut().fill(0.0);
+        let X_2 = X_2.sum_axis(Axis(1));
+        let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
+        let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
+        ret
+    }
+}
+
+impl HypothesisTest for ChiSquare {
+    fn call<T, P>(
+        &self,
+        net: &T,
+        child_node: usize,
+        parent_node: usize,
+        separation_set: &BTreeSet<usize>,
+        dataset: &Dataset,
+        cache: &mut Cache<P>,
+    ) -> bool
+    where
+        T: process::NetworkProcess,
+        P: parameter_learning::ParameterLearning,
+    {
+        //Parameters learned conditioning on the separation set only.
+        let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
+            Params::DiscreteStatesContinousTime(node) => node,
+        };
+        let mut extended_separation_set = separation_set.clone();
+        extended_separation_set.insert(parent_node);
+
+        //Parameters learned conditioning on separation set plus candidate parent.
+        let P_big = match cache.fit(
+            net,
+            &dataset,
+            child_node,
+            Some(extended_separation_set.clone()),
+        ) {
+            Params::DiscreteStatesContinousTime(node) => node,
+        };
+        //Cardinality product of the parents ordered before `parent_node`,
+        //used to map each CIM index of P_big onto the matching one of P_small.
+        let partial_cardinality_product: usize = extended_separation_set
+            .iter()
+            .take_while(|x| **x != parent_node)
+            .map(|x| net.get_node(*x).get_reserved_space_as_parent())
+            .product();
+        for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
+            let idx_M_small: usize = idx_M_big % partial_cardinality_product
+                + (idx_M_big
+                    / (partial_cardinality_product
+                        * net.get_node(parent_node).get_reserved_space_as_parent()))
+                    * partial_cardinality_product;
+            if !self.compare_matrices(
+                idx_M_small,
+                P_small.get_transitions().as_ref().unwrap(),
+                idx_M_big,
+                P_big.get_transitions().as_ref().unwrap(),
+            ) {
+                return false;
+            }
+        }
+        return true;
+    }
+}
diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs
new file mode 100644
index 0000000..9173b86
--- /dev/null
+++ b/reCTBN/src/structure_learning/score_based_algorithm.rs
@@ -0,0 +1,93 @@
+//! Module containing score based algorithms like Hill Climbing and Tabu Search.
+
+use std::collections::BTreeSet;
+
+use crate::structure_learning::score_function::ScoreFunction;
+use crate::structure_learning::StructureLearningAlgorithm;
+use crate::{process, tools::Dataset};
+
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+use rayon::prelude::ParallelExtend;
+
+pub struct HillClimbing<S: ScoreFunction> {
+    /// Score function used to evaluate candidate parent sets.
+    score_function: S,
+    /// Optional upper bound on the cardinality of each parent set.
+    max_parent_set: Option<usize>,
+}
+
+impl<S: ScoreFunction> HillClimbing<S> {
+    /// Build a Hill Climbing learner from a score function and an optional
+    /// limit on the parent-set size (`None` means unbounded).
+    pub fn new(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
+        HillClimbing {
+            score_function,
+            max_parent_set,
+        }
+    }
+}
+
+impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
+    /// Learn the structure of `net` from `dataset` by greedy hill climbing
+    /// on the parent set of each node, and return the network with the
+    /// learned edges added.
+    ///
+    /// # Panics
+    ///
+    /// Panics when the number of network nodes differs from the number of
+    /// variables in the dataset.
+    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
+    where
+        T: process::NetworkProcess,
+    {
+        //Check the coherence between dataset and network
+        if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
+            panic!("Dataset and Network must have the same number of variables.")
+        }
+
+        //Make the network mutable.
+        let mut net = net;
+        //Check if the max_parent_set constraint is present.
+        let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
+        //Reset the adj matrix
+        net.initialize_adj_matrix();
+        let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
+        //Iterate over each node to learn their parent set.
+        learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| {
+            //Initialize an empty parent set.
+            let mut parent_set: BTreeSet<usize> = BTreeSet::new();
+            //Compute the score for the empty parent set
+            let mut current_score = self.score_function.call(&net, node, &parent_set, dataset);
+            //Set the old score to -\infty.
+            let mut old_score = f64::NEG_INFINITY;
+            //Iterate until convergence
+            while current_score > old_score {
+                //Save the current_score.
+                old_score = current_score;
+                //Iterate over each node.
+                for parent in net.get_node_indices() {
+                    //Continue if the parent and the node are the same.
+                    if parent == node {
+                        continue;
+                    }
+                    //Try to remove parent from the parent_set.
+                    let is_removed = parent_set.remove(&parent);
+                    //If parent was not in the parent_set add it.
+                    if !is_removed && parent_set.len() < max_parent_set {
+                        parent_set.insert(parent);
+                    }
+                    //Compute the score with the modified parent_set.
+                    let tmp_score = self.score_function.call(&net, node, &parent_set, dataset);
+                    //If tmp_score is worse than current_score revert the change to the parent set
+                    if tmp_score < current_score {
+                        if is_removed {
+                            parent_set.insert(parent);
+                        } else {
+                            parent_set.remove(&parent);
+                        }
+                    }
+                    //Otherwise save the computed score as current_score
+                    else {
+                        current_score = tmp_score;
+                    }
+                }
+            }
+            (node, parent_set)
+        }));
+
+        for (child_node, candidate_parent_set) in learned_parent_sets {
+            for parent_node in candidate_parent_set.iter() {
+                net.add_edge(*parent_node, child_node);
+            }
+        }
+        net
+    }
+}
diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs
new file mode 100644
index 0000000..5a56594
--- /dev/null
+++ b/reCTBN/src/structure_learning/score_function.rs
@@ -0,0 +1,146 @@
+//! Module for score based algorithms containing score functions algorithms like Log Likelihood, BIC, etc...
+
+use std::collections::BTreeSet;
+
+use ndarray::prelude::*;
+use statrs::function::gamma;
+
+use crate::{parameter_learning, params, process, tools};
+
+pub trait ScoreFunction: Sync {
+    /// Compute the score of `node` having `parent_set` as its parents,
+    /// given the trajectories in `dataset`. Higher is better.
+    fn call<T>(
+        &self,
+        net: &T,
+        node: usize,
+        parent_set: &BTreeSet<usize>,
+        dataset: &tools::Dataset,
+    ) -> f64
+    where
+        T: process::NetworkProcess;
+}
+
+pub struct LogLikelihood {
+    /// Prior hyperparameter over the transition counts.
+    alpha: usize,
+    /// Prior hyperparameter over the residence times; must be >= 0.0.
+    tau: f64,
+}
+
+impl LogLikelihood {
+    pub fn new(alpha: usize, tau: f64) -> LogLikelihood {
+        //Tau must be >=0.0
+        if tau < 0.0 {
+            panic!("tau must be >=0.0");
+        }
+        LogLikelihood { alpha, tau }
+    }
+
+    /// Compute the marginal log-likelihood of `node` given `parent_set`,
+    /// returning the score together with the sufficient statistic `M`
+    /// (transition counts), which callers such as BIC reuse.
+    fn compute_score<T>(
+        &self,
+        net: &T,
+        node: usize,
+        parent_set: &BTreeSet<usize>,
+        dataset: &tools::Dataset,
+    ) -> (f64, Array3<usize>)
+    where
+        T: process::NetworkProcess,
+    {
+        //Identify the type of node used
+        match &net.get_node(node) {
+            params::Params::DiscreteStatesContinousTime(_params) => {
+                //Compute the sufficient statistics M (number of transitions) and T (residence
+                //time)
+                let (M, T) =
+                    parameter_learning::sufficient_statistics(net, dataset, node, parent_set);
+
+                //Scale alpha accordingly to the size of the parent set
+                let alpha = self.alpha as f64 / M.shape()[0] as f64;
+                //Scale tau accordingly to the size of the parent set
+                let tau = self.tau / M.shape()[0] as f64;
+
+                //Compute the log likelihood for q
+                let log_ll_q: f64 = M
+                    .sum_axis(Axis(2))
+                    .iter()
+                    .zip(T.iter())
+                    .map(|(m, t)| {
+                        gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau)
+                            - gamma::ln_gamma(alpha + 1.0)
+                            - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
+                    })
+                    .sum();
+
+                //Compute the log likelihood for theta
+                let log_ll_theta: f64 = M
+                    .outer_iter()
+                    .map(|x| {
+                        x.outer_iter()
+                            .map(|y| {
+                                gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64)
+                                    + y.iter()
+                                        .map(|z| {
+                                            gamma::ln_gamma(alpha + *z as f64)
+                                                - gamma::ln_gamma(alpha)
+                                        })
+                                        .sum::<f64>()
+                            })
+                            .sum::<f64>()
+                    })
+                    .sum();
+                (log_ll_theta + log_ll_q, M)
+            }
+        }
+    }
+}
+
+impl ScoreFunction for LogLikelihood {
+    fn call<T>(
+        &self,
+        net: &T,
+        node: usize,
+        parent_set: &BTreeSet<usize>,
+        dataset: &tools::Dataset,
+    ) -> f64
+    where
+        T: process::NetworkProcess,
+    {
+        //Only the score is needed here; the sufficient statistic is dropped.
+        self.compute_score(net, node, parent_set, dataset).0
+    }
+}
+
+/// Bayesian Information Criterion score, built on top of the
+/// log-likelihood score with a model-complexity penalty.
+pub struct BIC {
+    /// Underlying log-likelihood score.
+    ll: LogLikelihood,
+}
+
+impl BIC {
+    /// Build a `BIC` score with the hyperparameters `alpha` and `tau` of the
+    /// underlying log-likelihood score.
+    pub fn new(alpha: usize, tau: f64) -> BIC {
+        let ll = LogLikelihood::new(alpha, tau);
+        BIC { ll }
+    }
+}
+
+impl ScoreFunction for BIC {
+    fn call<T>(
+        &self,
+        net: &T,
+        node: usize,
+        parent_set: &BTreeSet<usize>,
+        dataset: &tools::Dataset,
+    ) -> f64
+    where
+        T: process::NetworkProcess,
+    {
+        //Compute the log-likelihood
+        let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
+        //Compute the number of parameters
+        let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
+        //TODO: Optimize this
+        //Compute the sample size
+        let sample_size: usize = dataset
+            .get_trajectories()
+            .iter()
+            .map(|x| x.get_time().len() - 1)
+            .sum();
+        //Compute BIC
+        ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
+    }
+}
diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs
new file mode 100644
index 0000000..5085c43
--- /dev/null
+++ b/reCTBN/src/tools.rs
@@ -0,0 +1,355 @@
+//! Contains commonly used methods used across the crate.
+
+use std::ops::{DivAssign, MulAssign, Range};
+
+use ndarray::{Array, Array1, Array2, Array3, Axis};
+use rand::{Rng, SeedableRng};
+use rand_chacha::ChaCha8Rng;
+
+use crate::params::ParamsTrait;
+use crate::process::NetworkProcess;
+use crate::sampling::{ForwardSampler, Sampler};
+use crate::{params, process};
+
+#[derive(Clone)]
+pub struct Trajectory {
+    /// Time instant of each recorded event.
+    time: Array1<f64>,
+    /// State of every variable at each event (one row per event).
+    events: Array2<usize>,
+}
+
+impl Trajectory {
+    pub fn new(time: Array1<f64>, events: Array2<usize>) -> Trajectory {
+        //Events and time are two part of the same trajectory. For this reason they must have the
+        //same number of sample.
+        if time.shape()[0] != events.shape()[0] {
+            panic!("time.shape[0] must be equal to events.shape[0]");
+        }
+        Trajectory { time, events }
+    }
+
+    pub fn get_time(&self) -> &Array1<f64> {
+        &self.time
+    }
+
+    pub fn get_events(&self) -> &Array2<usize> {
+        &self.events
+    }
+}
+
+#[derive(Clone)]
+pub struct Dataset {
+    /// Trajectories sampled from the same process.
+    trajectories: Vec<Trajectory>,
+}
+
+impl Dataset {
+    pub fn new(trajectories: Vec<Trajectory>) -> Dataset {
+        //All the trajectories in the same dataset must represent the same process. For this reason
+        //each trajectory must represent the same number of variables.
+        if trajectories
+            .iter()
+            .any(|x| trajectories[0].get_events().shape()[1] != x.get_events().shape()[1])
+        {
+            panic!("All the trajectories must represent the same number of variables");
+        }
+        Dataset { trajectories }
+    }
+
+    pub fn get_trajectories(&self) -> &Vec<Trajectory> {
+        &self.trajectories
+    }
+}
+
+/// Sample `n_trajectories` trajectories from `net` up to time `t_end`,
+/// optionally seeding the sampler for reproducibility.
+pub fn trajectory_generator<T: process::NetworkProcess>(
+    net: &T,
+    n_trajectories: u64,
+    t_end: f64,
+    seed: Option<u64>,
+) -> Dataset {
+    //Tmp growing vector containing generated trajectories.
+    let mut trajectories: Vec<Trajectory> = Vec::new();
+
+    //Random Generator object
+    let mut sampler = ForwardSampler::new(net, seed, None);
+    //Each iteration generate one trajectory
+    for _ in 0..n_trajectories {
+        //History of all the moments in which something changed
+        let mut time: Vec<f64> = Vec::new();
+        //Configuration of the process variables at time t initialized with an uniform
+        //distribution.
+        let mut events: Vec<Vec<params::StateType>> = Vec::new();
+
+        //Current Time and Current State
+        let mut sample = sampler.next().unwrap();
+        //Generate new samples until ending time is reached.
+        while sample.t < t_end {
+            time.push(sample.t);
+            events.push(sample.state);
+            sample = sampler.next().unwrap();
+        }
+
+        //Duplicate the last state so the trajectory is well defined at t_end.
+        let current_state = events.last().unwrap().clone();
+        events.push(current_state);
+
+        //Add t_end as last time.
+        time.push(t_end.clone());
+
+        //Add the sampled trajectory to trajectories.
+        trajectories.push(Trajectory::new(
+            Array::from_vec(time),
+            Array2::from_shape_vec(
+                (events.len(), events.last().unwrap().len()),
+                events
+                    .iter()
+                    .flatten()
+                    .map(|x| match x {
+                        params::StateType::Discrete(x) => x.clone(),
+                    })
+                    .collect(),
+            )
+            .unwrap(),
+        ));
+        sampler.reset();
+    }
+    //Return a dataset object with the sampled trajectories.
+    Dataset::new(trajectories)
+}
+
+pub trait RandomGraphGenerator {
+    /// Build a generator with the given edge `density` and optional RNG `seed`.
+    fn new(density: f64, seed: Option<u64>) -> Self;
+    /// Generate a random graph in place on `net`.
+    fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T);
+}
+
+/// Graph Generator using an uniform distribution.
+///
+/// A method to generate a random graph with edges uniformly distributed.
+///
+/// # Arguments
+///
+/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`.
+/// * `rng` - is the random numbers generator.
+///
+/// # Example
+///
+/// ```rust
+/// # use std::collections::BTreeSet;
+/// # use ndarray::{arr1, arr2, arr3};
+/// # use reCTBN::params;
+/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
+/// # use reCTBN::tools::trajectory_generator;
+/// # use reCTBN::process::NetworkProcess;
+/// # use reCTBN::process::ctbn::CtbnNetwork;
+/// use reCTBN::tools::UniformGraphGenerator;
+/// use reCTBN::tools::RandomGraphGenerator;
+/// # let mut net = CtbnNetwork::new();
+/// # let nodes_cardinality = 8;
+/// # let domain_cardinality = 4;
+/// # for node in 0..nodes_cardinality {
+/// # // Create the domain for a discrete node
+/// # let mut domain = BTreeSet::new();
+/// # for dvalue in 0..domain_cardinality {
+/// # domain.insert(dvalue.to_string());
+/// # }
+/// # // Create the parameters for a discrete node using the domain
+/// # let param = params::DiscreteStatesContinousTimeParams::new(
+/// # node.to_string(),
+/// # domain
+/// # );
+/// # //Create the node using the parameters
+/// # let node = DiscreteStatesContinousTime(param);
+/// # // Add the node to the network
+/// # net.add_node(node).unwrap();
+/// # }
+///
+/// // Initialize the Graph Generator using the one with an
+/// // uniform distribution
+/// let density = 1.0/3.0;
+/// let seed = Some(7641630759785120);
+/// let mut structure_generator = UniformGraphGenerator::new(
+/// density,
+/// seed
+/// );
+///
+/// // Generate the graph directly on the network
+/// structure_generator.generate_graph(&mut net);
+/// # // Count all the edges generated in the network
+/// # let mut edges = 0;
+/// # for node in net.get_node_indices(){
+/// # edges += net.get_children_set(node).len()
+/// # }
+/// # // Number of all the nodes in the network
+/// # let nodes = net.get_node_indices().len() as f64;
+/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
+/// # // ±10% of tolerance
+/// # let tolerance = ((expected_edges as f64)*0.10) as usize;
+/// # // As the way `generate_graph()` is implemented we can only reasonably
+/// # // expect the number of edges to be somewhere around the expected value.
+/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
+/// ```
+pub struct UniformGraphGenerator {
+    /// Fraction of possible edges to generate; must lie in [0.0, 1.0].
+    density: f64,
+    /// Seedable random number generator.
+    rng: ChaCha8Rng,
+}
+
+impl RandomGraphGenerator for UniformGraphGenerator {
+    fn new(density: f64, seed: Option<u64>) -> UniformGraphGenerator {
+        if density < 0.0 || density > 1.0 {
+            panic!(
+                "Density value must be between 1.0 and 0.0, got {}.",
+                density
+            );
+        }
+        let rng: ChaCha8Rng = match seed {
+            Some(seed) => SeedableRng::seed_from_u64(seed),
+            None => SeedableRng::from_entropy(),
+        };
+        UniformGraphGenerator { density, rng }
+    }
+
+    /// Generate an uniformly distributed graph.
+    fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T) {
+        net.initialize_adj_matrix();
+        let last_node_idx = net.get_node_indices().len();
+        for parent in 0..last_node_idx {
+            for child in 0..last_node_idx {
+                //Self-loops are never added.
+                if parent != child {
+                    if self.rng.gen_bool(self.density) {
+                        net.add_edge(parent, child);
+                    }
+                }
+            }
+        }
+    }
+}
+
+pub trait RandomParametersGenerator {
+    /// Build a generator drawing rates from `interval`, with optional RNG `seed`.
+    fn new(interval: Range<f64>, seed: Option<u64>) -> Self;
+    /// Generate random CIM parameters in place for every node of `net`.
+    fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T);
+}
+
+/// Parameters Generator using an uniform distribution.
+///
+/// A method to generate random parameters uniformly distributed.
+///
+/// # Arguments
+///
+/// * `interval` - is the interval of the random values of the CIM's diagonal; domain: `≥ 0.0`.
+/// * `rng` - is the random numbers generator.
+///
+/// # Example
+///
+/// ```rust
+/// # use std::collections::BTreeSet;
+/// # use ndarray::{arr1, arr2, arr3};
+/// # use reCTBN::params;
+/// # use reCTBN::params::ParamsTrait;
+/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
+/// # use reCTBN::process::NetworkProcess;
+/// # use reCTBN::process::ctbn::CtbnNetwork;
+/// # use reCTBN::tools::trajectory_generator;
+/// # use reCTBN::tools::RandomGraphGenerator;
+/// # use reCTBN::tools::UniformGraphGenerator;
+/// use reCTBN::tools::RandomParametersGenerator;
+/// use reCTBN::tools::UniformParametersGenerator;
+/// # let mut net = CtbnNetwork::new();
+/// # let nodes_cardinality = 8;
+/// # let domain_cardinality = 4;
+/// # for node in 0..nodes_cardinality {
+/// # // Create the domain for a discrete node
+/// # let mut domain = BTreeSet::new();
+/// # for dvalue in 0..domain_cardinality {
+/// # domain.insert(dvalue.to_string());
+/// # }
+/// # // Create the parameters for a discrete node using the domain
+/// # let param = params::DiscreteStatesContinousTimeParams::new(
+/// # node.to_string(),
+/// # domain
+/// # );
+/// # //Create the node using the parameters
+/// # let node = DiscreteStatesContinousTime(param);
+/// # // Add the node to the network
+/// # net.add_node(node).unwrap();
+/// # }
+/// #
+/// # // Initialize the Graph Generator using the one with an
+/// # // uniform distribution
+/// # let mut structure_generator = UniformGraphGenerator::new(
+/// # 1.0/3.0,
+/// # Some(7641630759785120)
+/// # );
+/// #
+/// # // Generate the graph directly on the network
+/// # structure_generator.generate_graph(&mut net);
+///
+/// // Initialize the parameters generator with an uniform distribution
+/// let mut cim_generator = UniformParametersGenerator::new(
+/// 0.0..7.0,
+/// Some(7641630759785120)
+/// );
+///
+/// // Generate CIMs with uniformly distributed parameters.
+/// cim_generator.generate_parameters(&mut net);
+/// #
+/// # for node in net.get_node_indices() {
+/// # assert_eq!(
+/// # Ok(()),
+/// # net.get_node(node).validate_params()
+/// # );
+/// }
+/// ```
+pub struct UniformParametersGenerator {
+    /// Interval from which the CIM diagonal rates are drawn; bounds must be >= 0.0.
+    interval: Range<f64>,
+    /// Seedable random number generator.
+    rng: ChaCha8Rng,
+}
+
+impl RandomParametersGenerator for UniformParametersGenerator {
+    fn new(interval: Range<f64>, seed: Option<u64>) -> UniformParametersGenerator {
+        if interval.start < 0.0 || interval.end < 0.0 {
+            panic!(
+                "Interval must be entirely greater or equal than 0.0, got {}..{}.",
+                interval.start, interval.end
+            );
+        }
+        let rng: ChaCha8Rng = match seed {
+            Some(seed) => SeedableRng::seed_from_u64(seed),
+            None => SeedableRng::from_entropy(),
+        };
+        UniformParametersGenerator { interval, rng }
+    }
+
+    /// Generate CIMs with uniformly distributed parameters.
+    fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) {
+        for node in net.get_node_indices() {
+            //Number of parent configurations of this node.
+            let parent_set_state_space_cardinality: usize = net
+                .get_parent_set(node)
+                .iter()
+                .map(|x| net.get_node(*x).get_reserved_space_as_parent())
+                .product();
+            match &mut net.get_node_mut(node) {
+                params::Params::DiscreteStatesContinousTime(param) => {
+                    let node_domain_cardinality = param.get_reserved_space_as_parent();
+                    //Start from random entries, then normalize each matrix.
+                    let mut cim = Array3::<f64>::from_shape_fn(
+                        (
+                            parent_set_state_space_cardinality,
+                            node_domain_cardinality,
+                            node_domain_cardinality,
+                        ),
+                        |_| self.rng.gen(),
+                    );
+                    cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
+                        x.diag_mut().fill(0.0);
+                        //Normalize the off-diagonal entries row-wise, then
+                        //scale each row by a rate drawn from `interval`.
+                        x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1)));
+                        let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
+                            self.rng.gen_range(self.interval.clone())
+                        });
+                        x.mul_assign(&diag.clone().insert_axis(Axis(1)));
+                        // Recomputing the diagonal in order to reduce the issues caused by the
+                        // loss of precision when validating the parameters.
+                        let diag_sum = -x.sum_axis(Axis(1));
+                        x.diag_mut().assign(&diag_sum)
+                    });
+                    param.set_cim_unchecked(cim);
+                }
+            }
+        }
+    }
+}
diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs
new file mode 100644
index 0000000..3eb40d7
--- /dev/null
+++ b/reCTBN/tests/ctbn.rs
@@ -0,0 +1,376 @@
+mod utils;
+use std::collections::BTreeSet;
+
+
+use approx::AbsDiffEq;
+use ndarray::arr3;
+use reCTBN::params::{self, ParamsTrait};
+use reCTBN::process::NetworkProcess;
+use reCTBN::process::{ctbn::*};
+use utils::generate_discrete_time_continous_node;
+
+#[test]
+fn define_simple_ctbn() {
+ let _ = CtbnNetwork::new();
+ assert!(true);
+}
+
+#[test]
+fn add_node_to_ctbn() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
+}
+
+#[test]
+fn add_edge_to_ctbn() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+ net.add_edge(n1, n2);
+ let cs = net.get_children_set(n1);
+ assert_eq!(&n2, cs.iter().next().unwrap());
+}
+
+#[test]
+fn children_and_parents() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+ net.add_edge(n1, n2);
+ let cs = net.get_children_set(n1);
+ assert_eq!(&n2, cs.iter().next().unwrap());
+ let ps = net.get_parent_set(n2);
+ assert_eq!(&n1, ps.iter().next().unwrap());
+}
+
+#[test]
+fn compute_index_ctbn() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+ let n3 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
+ .unwrap();
+ net.add_edge(n1, n2);
+ net.add_edge(n3, n2);
+ let idx = net.get_param_index_network(
+ n2,
+ &vec![
+ params::StateType::Discrete(1),
+ params::StateType::Discrete(1),
+ params::StateType::Discrete(1),
+ ],
+ );
+ assert_eq!(3, idx);
+
+ let idx = net.get_param_index_network(
+ n2,
+ &vec![
+ params::StateType::Discrete(0),
+ params::StateType::Discrete(1),
+ params::StateType::Discrete(1),
+ ],
+ );
+ assert_eq!(2, idx);
+
+ let idx = net.get_param_index_network(
+ n2,
+ &vec![
+ params::StateType::Discrete(1),
+ params::StateType::Discrete(1),
+ params::StateType::Discrete(0),
+ ],
+ );
+ assert_eq!(1, idx);
+}
+
+#[test]
+fn compute_index_from_custom_parent_set() {
+ let mut net = CtbnNetwork::new();
+ let _n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let _n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+ let _n3 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
+ .unwrap();
+
+ let idx = net.get_param_index_from_custom_parent_set(
+ &vec![
+ params::StateType::Discrete(0),
+ params::StateType::Discrete(0),
+ params::StateType::Discrete(1),
+ ],
+ &BTreeSet::from([1]),
+ );
+ assert_eq!(0, idx);
+
+ let idx = net.get_param_index_from_custom_parent_set(
+ &vec![
+ params::StateType::Discrete(0),
+ params::StateType::Discrete(0),
+ params::StateType::Discrete(1),
+ ],
+ &BTreeSet::from([1, 2]),
+ );
+ assert_eq!(2, idx);
+}
+
+#[test]
+fn simple_amalgamation() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+
+ net.initialize_adj_matrix();
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
+ }
+ }
+
+ let ctmp = net.amalgamation();
+ let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0);
+ let p_ctbn = p_ctbn.get_cim().as_ref().unwrap();
+ let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
+ let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
+
+ assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON));
+}
+
+#[test]
+fn chain_amalgamation() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+ let n3 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
+ .unwrap();
+
+ net.add_edge(n1, n2);
+ net.add_edge(n2, n3);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [[-0.01, 0.01], [5.0, -5.0]],
+ [[-5.0, 5.0], [0.01, -0.01]]
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n3) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [[-0.01, 0.01], [5.0, -5.0]],
+ [[-5.0, 5.0], [0.01, -0.01]]
+ ]))
+ );
+ }
+ }
+
+ let ctmp = net.amalgamation();
+
+
+
+ let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
+ let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
+
+ let p_ctmp_handmade = arr3(&[[
+ [
+ -1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
+ ],
+ [
+ 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00,
+ ],
+ ]]);
+
+ assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
+}
+
+#[test]
+fn chainfork_amalgamation() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+ let n3 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
+ .unwrap();
+ let n4 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n4"), 2))
+ .unwrap();
+
+ net.add_edge(n1, n3);
+ net.add_edge(n2, n3);
+ net.add_edge(n3, n4);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
+ }
+ }
+
+ match &mut net.get_node_mut(n3) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [[-0.01, 0.01], [5.0, -5.0]],
+ [[-0.01, 0.01], [5.0, -5.0]],
+ [[-0.01, 0.01], [5.0, -5.0]],
+ [[-5.0, 5.0], [0.01, -0.01]]
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n4) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [[-0.01, 0.01], [5.0, -5.0]],
+ [[-5.0, 5.0], [0.01, -0.01]]
+ ]))
+ );
+ }
+ }
+
+
+ let ctmp = net.amalgamation();
+
+ let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
+
+ let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
+
+ let p_ctmp_handmade = arr3(&[[
+ [
+ -2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
+ 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
+ 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
+ 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
+ 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00,
+ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01,
+ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01,
+ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00,
+ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
+ ],
+ [
+ 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
+ 0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
+ 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
+ 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
+ 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01,
+ ],
+ [
+ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02,
+ 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00,
+ ],
+ ]]);
+
+ assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
+}
diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs
new file mode 100644
index 0000000..830bfe0
--- /dev/null
+++ b/reCTBN/tests/ctmp.rs
@@ -0,0 +1,127 @@
+mod utils;
+
+use std::collections::BTreeSet;
+
+use reCTBN::{
+ params,
+ params::ParamsTrait,
+ process::{ctmp::*, NetworkProcess},
+};
+use utils::generate_discrete_time_continous_node;
+
+#[test]
+fn define_simple_ctmp() {
+ let _ = CtmpProcess::new();
+ assert!(true);
+}
+
+#[test]
+fn add_node_to_ctmp() {
+ let mut net = CtmpProcess::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
+}
+
+#[test]
+fn add_two_nodes_to_ctmp() {
+ let mut net = CtmpProcess::new();
+ let _n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
+
+ match n2 {
+ Ok(_) => assert!(false),
+ Err(_) => assert!(true),
+ };
+}
+
+#[test]
+#[should_panic]
+fn add_edge_to_ctmp() {
+ let mut net = CtmpProcess::new();
+ let _n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
+
+ net.add_edge(0, 1)
+}
+
+#[test]
+fn children_and_parents() {
+ let mut net = CtmpProcess::new();
+ let _n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+
+ assert_eq!(0, net.get_parent_set(0).len());
+ assert_eq!(0, net.get_children_set(0).len());
+}
+
+#[test]
+#[should_panic]
+fn get_children_panic() {
+ let net = CtmpProcess::new();
+ net.get_children_set(0);
+}
+
+#[test]
+#[should_panic]
+fn get_children_panic2() {
+ let mut net = CtmpProcess::new();
+ let _n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ net.get_children_set(1);
+}
+
+#[test]
+#[should_panic]
+fn get_parent_panic() {
+ let net = CtmpProcess::new();
+ net.get_parent_set(0);
+}
+
+#[test]
+#[should_panic]
+fn get_parent_panic2() {
+ let mut net = CtmpProcess::new();
+ let _n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ net.get_parent_set(1);
+}
+
+#[test]
+fn compute_index_ctmp() {
+ let mut net = CtmpProcess::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(
+ String::from("n1"),
+ 10,
+ ))
+ .unwrap();
+
+ let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]);
+ assert_eq!(6, idx);
+}
+
+#[test]
+#[should_panic]
+fn compute_index_from_custom_parent_set_ctmp() {
+ let mut net = CtmpProcess::new();
+ let _n1 = net
+ .add_node(generate_discrete_time_continous_node(
+ String::from("n1"),
+ 10,
+ ))
+ .unwrap();
+
+ let _idx = net.get_param_index_from_custom_parent_set(
+ &vec![params::StateType::Discrete(6)],
+ &BTreeSet::from([0])
+ );
+}
diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs
new file mode 100644
index 0000000..0a09a2a
--- /dev/null
+++ b/reCTBN/tests/parameter_learning.rs
@@ -0,0 +1,648 @@
+#![allow(non_snake_case)]
+
+mod utils;
+use ndarray::arr3;
+use reCTBN::process::ctbn::*;
+use reCTBN::process::NetworkProcess;
+use reCTBN::parameter_learning::*;
+use reCTBN::params;
+use reCTBN::params::Params::DiscreteStatesContinousTime;
+use reCTBN::tools::*;
+use utils::*;
+
+extern crate approx;
+use crate::approx::AbsDiffEq;
+
+fn learn_binary_cim<T: ParameterLearning>(pl: T) {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+ net.add_edge(n1, n2);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 1.0],
+ [4.0, -4.0]
+ ],
+ [
+ [-6.0, 6.0],
+ [2.0, -2.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
+ let p = match pl.fit(&net, &data, 1, None) {
+ params::Params::DiscreteStatesContinousTime(p) => p,
+ };
+ assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
+ assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
+ &arr3(&[
+ [
+ [-1.0, 1.0],
+ [4.0, -4.0]
+ ],
+ [
+ [-6.0, 6.0],
+ [2.0, -2.0]
+ ],
+ ]),
+ 0.1
+ ));
+}
+
+fn generate_nodes(
+ net: &mut CtbnNetwork,
+ nodes_cardinality: usize,
+ nodes_domain_cardinality: usize
+) {
+ for node_label in 0..nodes_cardinality {
+ net.add_node(
+ generate_discrete_time_continous_node(
+ node_label.to_string(),
+ nodes_domain_cardinality,
+ )
+ ).unwrap();
+ }
+}
+
+fn learn_binary_cim_gen<T: ParameterLearning>(pl: T) {
+ let mut net = CtbnNetwork::new();
+ generate_nodes(&mut net, 2, 2);
+
+ net.add_edge(0, 1);
+
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ 1.0..6.0,
+ Some(6813071588535822)
+ );
+ cim_generator.generate_parameters(&mut net);
+
+ let p_gen = match net.get_node(1) {
+ DiscreteStatesContinousTime(p_gen) => p_gen,
+ };
+
+ let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
+ let p_tj = match pl.fit(&net, &data, 1, None) {
+ DiscreteStatesContinousTime(p_tj) => p_tj,
+ };
+
+ assert_eq!(
+ p_tj.get_cim().as_ref().unwrap().shape(),
+ p_gen.get_cim().as_ref().unwrap().shape()
+ );
+ assert!(
+ p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
+ &p_gen.get_cim().as_ref().unwrap(),
+ 0.1
+ )
+ );
+}
+
+#[test]
+fn learn_binary_cim_MLE() {
+ let mle = MLE {};
+ learn_binary_cim(mle);
+}
+
+#[test]
+fn learn_binary_cim_MLE_gen() {
+ let mle = MLE {};
+ learn_binary_cim_gen(mle);
+}
+
+#[test]
+fn learn_binary_cim_BA() {
+ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+ learn_binary_cim(ba);
+}
+
+#[test]
+fn learn_binary_cim_BA_gen() {
+ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+ learn_binary_cim_gen(ba);
+}
+
+fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
+ .unwrap();
+ net.add_edge(n1, n2);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-3.0, 2.0, 1.0],
+ [1.5, -2.0, 0.5],
+ [0.4, 0.6, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 0.5, 0.5],
+ [3.0, -4.0, 1.0],
+ [0.9, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 4.0],
+ [1.5, -2.0, 0.5],
+ [3.0, 1.0, -4.0]
+ ],
+ [
+ [-1.0, 0.1, 0.9],
+ [2.0, -2.5, 0.5],
+ [0.9, 0.1, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
+ let p = match pl.fit(&net, &data, 1, None) {
+ params::Params::DiscreteStatesContinousTime(p) => p,
+ };
+ assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
+ assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
+ &arr3(&[
+ [
+ [-1.0, 0.5, 0.5],
+ [3.0, -4.0, 1.0],
+ [0.9, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 4.0],
+ [1.5, -2.0, 0.5],
+ [3.0, 1.0, -4.0]
+ ],
+ [
+ [-1.0, 0.1, 0.9],
+ [2.0, -2.5, 0.5],
+ [0.9, 0.1, -1.0]
+ ],
+ ]),
+ 0.1
+ ));
+}
+
+fn learn_ternary_cim_gen<T: ParameterLearning>(pl: T) {
+ let mut net = CtbnNetwork::new();
+ generate_nodes(&mut net, 2, 3);
+
+ net.add_edge(0, 1);
+
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ 4.0..6.0,
+ Some(6813071588535822)
+ );
+ cim_generator.generate_parameters(&mut net);
+
+ let p_gen = match net.get_node(1) {
+ DiscreteStatesContinousTime(p_gen) => p_gen,
+ };
+
+ let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
+ let p_tj = match pl.fit(&net, &data, 1, None) {
+ DiscreteStatesContinousTime(p_tj) => p_tj,
+ };
+
+ assert_eq!(
+ p_tj.get_cim().as_ref().unwrap().shape(),
+ p_gen.get_cim().as_ref().unwrap().shape()
+ );
+ assert!(
+ p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
+ &p_gen.get_cim().as_ref().unwrap(),
+ 0.1
+ )
+ );
+}
+
+#[test]
+fn learn_ternary_cim_MLE() {
+ let mle = MLE {};
+ learn_ternary_cim(mle);
+}
+
+#[test]
+fn learn_ternary_cim_MLE_gen() {
+ let mle = MLE {};
+ learn_ternary_cim_gen(mle);
+}
+
+#[test]
+fn learn_ternary_cim_BA() {
+ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+ learn_ternary_cim(ba);
+}
+
+#[test]
+fn learn_ternary_cim_BA_gen() {
+ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+ learn_ternary_cim_gen(ba);
+}
+
+fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
+ .unwrap();
+ net.add_edge(n1, n2);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-3.0, 2.0, 1.0],
+ [1.5, -2.0, 0.5],
+ [0.4, 0.6, -1.0]
+ ]
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 0.5, 0.5],
+ [3.0, -4.0, 1.0],
+ [0.9, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 4.0],
+ [1.5, -2.0, 0.5],
+ [3.0, 1.0, -4.0]
+ ],
+ [
+ [-1.0, 0.1, 0.9],
+ [2.0, -2.5, 0.5],
+ [0.9, 0.1, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
+ let p = match pl.fit(&net, &data, 0, None) {
+ params::Params::DiscreteStatesContinousTime(p) => p,
+ };
+ assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
+ assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
+ &arr3(&[
+ [
+ [-3.0, 2.0, 1.0],
+ [1.5, -2.0, 0.5],
+ [0.4, 0.6, -1.0]
+ ],
+ ]),
+ 0.1
+ ));
+}
+
+fn learn_ternary_cim_no_parents_gen<T: ParameterLearning>(pl: T) {
+ let mut net = CtbnNetwork::new();
+ generate_nodes(&mut net, 2, 3);
+
+ net.add_edge(0, 1);
+
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ 1.0..6.0,
+ Some(6813071588535822)
+ );
+ cim_generator.generate_parameters(&mut net);
+
+ let p_gen = match net.get_node(0) {
+ DiscreteStatesContinousTime(p_gen) => p_gen,
+ };
+
+ let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
+ let p_tj = match pl.fit(&net, &data, 0, None) {
+ DiscreteStatesContinousTime(p_tj) => p_tj,
+ };
+
+ assert_eq!(
+ p_tj.get_cim().as_ref().unwrap().shape(),
+ p_gen.get_cim().as_ref().unwrap().shape()
+ );
+ assert!(
+ p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
+ &p_gen.get_cim().as_ref().unwrap(),
+ 0.1
+ )
+ );
+}
+
+#[test]
+fn learn_ternary_cim_no_parents_MLE() {
+ let mle = MLE {};
+ learn_ternary_cim_no_parents(mle);
+}
+
+#[test]
+fn learn_ternary_cim_no_parents_MLE_gen() {
+ let mle = MLE {};
+ learn_ternary_cim_no_parents_gen(mle);
+}
+
+#[test]
+fn learn_ternary_cim_no_parents_BA() {
+ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+ learn_ternary_cim_no_parents(ba);
+}
+
+#[test]
+fn learn_ternary_cim_no_parents_BA_gen() {
+ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+ learn_ternary_cim_no_parents_gen(ba);
+}
+
+fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
+ .unwrap();
+
+ let n3 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n3"), 4))
+ .unwrap();
+ net.add_edge(n1, n2);
+ net.add_edge(n1, n3);
+ net.add_edge(n2, n3);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-3.0, 2.0, 1.0],
+ [1.5, -2.0, 0.5],
+ [0.4, 0.6, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 0.5, 0.5],
+ [3.0, -4.0, 1.0],
+ [0.9, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 4.0],
+ [1.5, -2.0, 0.5],
+ [3.0, 1.0, -4.0]
+ ],
+ [
+ [-1.0, 0.1, 0.9],
+ [2.0, -2.5, 0.5],
+ [0.9, 0.1, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n3) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 0.5, 0.3, 0.2],
+ [0.5, -4.0, 2.5, 1.0],
+ [2.5, 0.5, -4.0, 1.0],
+ [0.7, 0.2, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 3.0, 1.0],
+ [1.5, -3.0, 0.5, 1.0],
+ [2.0, 1.3, -5.0, 1.7],
+ [2.5, 0.5, 1.0, -4.0]
+ ],
+ [
+ [-1.3, 0.3, 0.1, 0.9],
+ [1.4, -4.0, 0.5, 2.1],
+ [1.0, 1.5, -3.0, 0.5],
+ [0.4, 0.3, 0.1, -0.8]
+ ],
+ [
+ [-2.0, 1.0, 0.7, 0.3],
+ [1.3, -5.9, 2.7, 1.9],
+ [2.0, 1.5, -4.0, 0.5],
+ [0.2, 0.7, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 1.0, 2.0, 3.0],
+ [0.5, -3.0, 1.0, 1.5],
+ [1.4, 2.1, -4.3, 0.8],
+ [0.5, 1.0, 2.5, -4.0]
+ ],
+ [
+ [-1.3, 0.9, 0.3, 0.1],
+ [0.1, -1.3, 0.2, 1.0],
+ [0.5, 1.0, -3.0, 1.5],
+ [0.1, 0.4, 0.3, -0.8]
+ ],
+ [
+ [-2.0, 1.0, 0.6, 0.4],
+ [2.6, -7.1, 1.4, 3.1],
+ [5.0, 1.0, -8.0, 2.0],
+ [1.4, 0.4, 0.2, -2.0]
+ ],
+ [
+ [-3.0, 1.0, 1.5, 0.5],
+ [3.0, -6.0, 1.0, 2.0],
+ [0.3, 0.5, -1.9, 1.1],
+ [5.0, 1.0, 2.0, -8.0]
+ ],
+ [
+ [-2.6, 0.6, 0.2, 1.8],
+ [2.0, -6.0, 3.0, 1.0],
+ [0.1, 0.5, -1.3, 0.7],
+ [0.8, 0.6, 0.2, -1.6]
+ ],
+ ]))
+ );
+ }
+ }
+
+ let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
+ let p = match pl.fit(&net, &data, 2, None) {
+ params::Params::DiscreteStatesContinousTime(p) => p,
+ };
+ assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]);
+ assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
+ &arr3(&[
+ [
+ [-1.0, 0.5, 0.3, 0.2],
+ [0.5, -4.0, 2.5, 1.0],
+ [2.5, 0.5, -4.0, 1.0],
+ [0.7, 0.2, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 3.0, 1.0],
+ [1.5, -3.0, 0.5, 1.0],
+ [2.0, 1.3, -5.0, 1.7],
+ [2.5, 0.5, 1.0, -4.0]
+ ],
+ [
+ [-1.3, 0.3, 0.1, 0.9],
+ [1.4, -4.0, 0.5, 2.1],
+ [1.0, 1.5, -3.0, 0.5],
+ [0.4, 0.3, 0.1, -0.8]
+ ],
+ [
+ [-2.0, 1.0, 0.7, 0.3],
+ [1.3, -5.9, 2.7, 1.9],
+ [2.0, 1.5, -4.0, 0.5],
+ [0.2, 0.7, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 1.0, 2.0, 3.0],
+ [0.5, -3.0, 1.0, 1.5],
+ [1.4, 2.1, -4.3, 0.8],
+ [0.5, 1.0, 2.5, -4.0]
+ ],
+ [
+ [-1.3, 0.9, 0.3, 0.1],
+ [0.1, -1.3, 0.2, 1.0],
+ [0.5, 1.0, -3.0, 1.5],
+ [0.1, 0.4, 0.3, -0.8]
+ ],
+ [
+ [-2.0, 1.0, 0.6, 0.4],
+ [2.6, -7.1, 1.4, 3.1],
+ [5.0, 1.0, -8.0, 2.0],
+ [1.4, 0.4, 0.2, -2.0]
+ ],
+ [
+ [-3.0, 1.0, 1.5, 0.5],
+ [3.0, -6.0, 1.0, 2.0],
+ [0.3, 0.5, -1.9, 1.1],
+ [5.0, 1.0, 2.0, -8.0]
+ ],
+ [
+ [-2.6, 0.6, 0.2, 1.8],
+ [2.0, -6.0, 3.0, 1.0],
+ [0.1, 0.5, -1.3, 0.7],
+ [0.8, 0.6, 0.2, -1.6]
+ ],
+ ]),
+ 0.2
+ ));
+}
+
+fn learn_mixed_discrete_cim_gen<T: ParameterLearning>(pl: T) {
+ let mut net = CtbnNetwork::new();
+ generate_nodes(&mut net, 2, 3);
+ net.add_node(
+ generate_discrete_time_continous_node(
+ String::from("3"),
+ 4
+ )
+ ).unwrap();
+ net.add_edge(0, 1);
+ net.add_edge(0, 2);
+ net.add_edge(1, 2);
+
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ 1.0..8.0,
+ Some(6813071588535822)
+ );
+ cim_generator.generate_parameters(&mut net);
+
+ let p_gen = match net.get_node(2) {
+ DiscreteStatesContinousTime(p_gen) => p_gen,
+ };
+
+ let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
+ let p_tj = match pl.fit(&net, &data, 2, None) {
+ DiscreteStatesContinousTime(p_tj) => p_tj,
+ };
+
+ assert_eq!(
+ p_tj.get_cim().as_ref().unwrap().shape(),
+ p_gen.get_cim().as_ref().unwrap().shape()
+ );
+ assert!(
+ p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
+ &p_gen.get_cim().as_ref().unwrap(),
+ 0.2
+ )
+ );
+}
+
+#[test]
+fn learn_mixed_discrete_cim_MLE() {
+ let mle = MLE {};
+ learn_mixed_discrete_cim(mle);
+}
+
+#[test]
+fn learn_mixed_discrete_cim_MLE_gen() {
+ let mle = MLE {};
+ learn_mixed_discrete_cim_gen(mle);
+}
+
+#[test]
+fn learn_mixed_discrete_cim_BA() {
+ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+ learn_mixed_discrete_cim(ba);
+}
+
+#[test]
+fn learn_mixed_discrete_cim_BA_gen() {
+ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
+ learn_mixed_discrete_cim_gen(ba);
+}
diff --git a/reCTBN/tests/params.rs b/reCTBN/tests/params.rs
new file mode 100644
index 0000000..7f16f12
--- /dev/null
+++ b/reCTBN/tests/params.rs
@@ -0,0 +1,148 @@
+use ndarray::prelude::*;
+use rand_chacha::rand_core::SeedableRng;
+use rand_chacha::ChaCha8Rng;
+use reCTBN::params::{ParamsTrait, *};
+
+mod utils;
+
+#[macro_use]
+extern crate approx;
+
+fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
+ #![allow(unused_must_use)]
+ let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3);
+
+ let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
+
+ params.set_cim(cim);
+ params
+}
+
+#[test]
+fn test_get_label() {
+ let param = create_ternary_discrete_time_continous_param();
+ assert_eq!(&String::from("A"), param.get_label())
+}
+
+#[test]
+fn test_uniform_generation() {
+ #![allow(irrefutable_let_patterns)]
+ let param = create_ternary_discrete_time_continous_param();
+ let mut states = Array1::<usize>::zeros(10000);
+
+ let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
+
+ states.mapv_inplace(|_| {
+ if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) {
+ val
+ } else {
+ panic!()
+ }
+ });
+ let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
+
+ assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01);
+}
+
+#[test]
+fn test_random_generation_state() {
+ #![allow(irrefutable_let_patterns)]
+ let param = create_ternary_discrete_time_continous_param();
+ let mut states = Array1::<usize>::zeros(10000);
+
+ let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
+
+ states.mapv_inplace(|_| {
+ if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() {
+ val
+ } else {
+ panic!()
+ }
+ });
+ let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0;
+ let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0;
+
+ assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01);
+ assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01);
+}
+
+#[test]
+fn test_random_generation_residence_time() {
+ let param = create_ternary_discrete_time_continous_param();
+ let mut states = Array1::<f64>::zeros(10000);
+
+ let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
+
+ states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap());
+
+ assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01);
+}
+
+#[test]
+fn test_validate_params_valid_cim() {
+ let param = create_ternary_discrete_time_continous_param();
+
+ assert_eq!(Ok(()), param.validate_params());
+}
+
+#[test]
+fn test_validate_params_valid_cim_with_huge_values() {
+ let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
+ let cim = array![[
+ [-2e10, 1e10, 1e10],
+ [1.5e10, -3e10, 1.5e10],
+ [1e10, 1e10, -2e10]
+ ]];
+ let result = param.set_cim(cim);
+ assert_eq!(Ok(()), result);
+}
+
+#[test]
+fn test_validate_params_cim_not_initialized() {
+ let param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
+ assert_eq!(
+ Err(ParamsError::ParametersNotInitialized(String::from(
+ "CIM not initialized",
+ ))),
+ param.validate_params()
+ );
+}
+
+#[test]
+fn test_validate_params_wrong_shape() {
+ let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 4);
+ let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
+ let result = param.set_cim(cim);
+ assert_eq!(
+ Err(ParamsError::InvalidCIM(String::from(
+ "Incompatible shape [1, 3, 3] with domain 4"
+ ))),
+ result
+ );
+}
+
+#[test]
+fn test_validate_params_positive_diag() {
+ let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
+ let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
+ let result = param.set_cim(cim);
+ assert_eq!(
+ Err(ParamsError::InvalidCIM(String::from(
+ "The diagonal of each cim must be non-positive",
+ ))),
+ result
+ );
+}
+
+#[test]
+fn test_validate_params_row_not_sum_to_zero() {
+ let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
+ let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
+ let result = param.set_cim(cim);
+ assert_eq!(
+ Err(ParamsError::InvalidCIM(String::from(
+ "The sum of each row must be 0"
+ ))),
+ result
+ );
+}
diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs
new file mode 100644
index 0000000..355341c
--- /dev/null
+++ b/reCTBN/tests/reward_evaluation.rs
@@ -0,0 +1,122 @@
+mod utils;
+
+use approx::assert_abs_diff_eq;
+use ndarray::*;
+use reCTBN::{
+ params,
+ process::{ctbn::*, NetworkProcess, NetworkProcessState},
+ reward::{reward_evaluation::*, reward_function::*, *},
+};
+use utils::generate_discrete_time_continous_node;
+
+#[test]
+fn simple_factored_reward_function_binary_node_mc() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+
+ let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
+ rf.get_transition_reward_mut(n1)
+ .assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]]));
+ rf.get_instantaneous_reward_mut(n1)
+ .assign(&arr1(&[3.0, 3.0]));
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
+ }
+ }
+
+ net.initialize_adj_matrix();
+
+ let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
+ let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
+
+ let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
+ assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
+ assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
+
+ let rst = mc.evaluate_state_space(&net, &rf);
+ assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
+ assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);
+
+
+ let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215));
+ assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
+ assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
+
+
+}
+
+#[test]
+fn simple_factored_reward_function_chain_mc() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+
+ let n3 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
+ .unwrap();
+
+ net.add_edge(n1, n2);
+ net.add_edge(n2, n3);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap();
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ param
+ .set_cim(arr3(&[
+ [[-0.01, 0.01], [5.0, -5.0]],
+ [[-5.0, 5.0], [0.01, -0.01]],
+ ]))
+ .unwrap();
+ }
+ }
+
+
+ match &mut net.get_node_mut(n3) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ param
+ .set_cim(arr3(&[
+ [[-0.01, 0.01], [5.0, -5.0]],
+ [[-5.0, 5.0], [0.01, -0.01]],
+ ]))
+ .unwrap();
+ }
+ }
+
+
+ let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
+ rf.get_transition_reward_mut(n1)
+ .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
+
+ rf.get_transition_reward_mut(n2)
+ .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
+
+ rf.get_transition_reward_mut(n3)
+ .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
+
+ let s000: NetworkProcessState = vec![
+ params::StateType::Discrete(1),
+ params::StateType::Discrete(0),
+ params::StateType::Discrete(0),
+ ];
+
+ let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
+ assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
+
+ let rst = mc.evaluate_state_space(&net, &rf);
+ assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1);
+
+}
diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs
new file mode 100644
index 0000000..853efc9
--- /dev/null
+++ b/reCTBN/tests/reward_function.rs
@@ -0,0 +1,117 @@
+mod utils;
+
+use ndarray::*;
+use utils::generate_discrete_time_continous_node;
+use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params};
+
+
+#[test]
+fn simple_factored_reward_function_binary_node() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+
+ let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
+ rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
+ rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0]));
+
+ let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
+ let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
+ assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
+ assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
+
+
+ assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
+ assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
+
+ assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
+ assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
+}
+
+
+#[test]
+fn simple_factored_reward_function_ternary_node() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+
+ let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
+ rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
+ rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
+
+ let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
+ let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
+ let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];
+
+
+ assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
+ assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
+
+
+ assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
+ assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
+
+
+ assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
+ assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
+}
+
+#[test]
+fn factored_reward_function_two_nodes() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
+ .unwrap();
+ net.add_edge(n1, n2);
+
+
+ let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
+ rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
+ rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
+
+
+ rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
+ rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
+ let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
+ let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
+ let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
+
+
+ let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)];
+ let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
+ let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
+
+ assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
+ assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
+ assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
+
+
+ assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
+ assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
+ assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
+
+
+ assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
+ assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
+ assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
+
+
+ assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
+ assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
+ assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
+
+
+ assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
+ assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
+ assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
+
+
+ assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
+ assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
+ assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
+}
diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs
new file mode 100644
index 0000000..3d7e230
--- /dev/null
+++ b/reCTBN/tests/structure_learning.rs
@@ -0,0 +1,692 @@
+#![allow(non_snake_case)]
+
+mod utils;
+use std::collections::BTreeSet;
+
+use ndarray::{arr1, arr2, arr3};
+use reCTBN::process::ctbn::*;
+use reCTBN::process::NetworkProcess;
+use reCTBN::parameter_learning::BayesianApproach;
+use reCTBN::params;
+use reCTBN::structure_learning::hypothesis_test::*;
+use reCTBN::structure_learning::constraint_based_algorithm::*;
+use reCTBN::structure_learning::score_based_algorithm::*;
+use reCTBN::structure_learning::score_function::*;
+use reCTBN::structure_learning::StructureLearningAlgorithm;
+use reCTBN::tools::*;
+use utils::*;
+
+#[macro_use]
+extern crate approx;
+
+#[test]
+fn simple_score_test() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+
+ let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]]));
+
+ let dataset = Dataset::new(vec![trj]);
+
+ let ll = LogLikelihood::new(1, 1.0);
+
+ assert_abs_diff_eq!(
+ 0.04257,
+ ll.call(&net, n1, &BTreeSet::new(), &dataset),
+ epsilon = 1e-3
+ );
+}
+
+#[test]
+fn simple_bic() {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
+ .unwrap();
+
+ let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]]));
+
+ let dataset = Dataset::new(vec![trj]);
+ let bic = BIC::new(1, 1.0);
+
+ assert_abs_diff_eq!(
+ -0.65058,
+ bic.call(&net, n1, &BTreeSet::new(), &dataset),
+ epsilon = 1e-3
+ );
+}
+
+fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
+ .unwrap();
+ net.add_edge(n1, n2);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-3.0, 2.0, 1.0],
+ [1.5, -2.0, 0.5],
+ [0.4, 0.6, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 0.5, 0.5],
+ [3.0, -4.0, 1.0],
+ [0.9, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 4.0],
+ [1.5, -2.0, 0.5],
+ [3.0, 1.0, -4.0]
+ ],
+ [
+ [-1.0, 0.1, 0.9],
+ [2.0, -2.5, 0.5],
+ [0.9, 0.1, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259));
+
+ let mut net = CtbnNetwork::new();
+ let _n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+ let _net = sl.fit_transform(net, &data);
+}
+
+fn generate_nodes(
+ net: &mut CtbnNetwork,
+ nodes_cardinality: usize,
+ nodes_domain_cardinality: usize
+) {
+ for node_label in 0..nodes_cardinality {
+ net.add_node(
+ generate_discrete_time_continous_node(
+ node_label.to_string(),
+ nodes_domain_cardinality,
+ )
+ ).unwrap();
+ }
+}
+
+fn check_compatibility_between_dataset_and_network_gen<T: StructureLearningAlgorithm>(sl: T) {
+ let mut net = CtbnNetwork::new();
+ generate_nodes(&mut net, 2, 3);
+ net.add_node(
+ generate_discrete_time_continous_node(
+ String::from("3"),
+ 4
+ )
+ ).unwrap();
+
+ net.add_edge(0, 1);
+
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ 0.0..7.0,
+ Some(6813071588535822)
+ );
+ cim_generator.generate_parameters(&mut net);
+
+ let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259));
+
+ let mut net = CtbnNetwork::new();
+ let _n1 = net
+ .add_node(
+ generate_discrete_time_continous_node(String::from("0"),
+ 3)
+ ).unwrap();
+ let _net = sl.fit_transform(net, &data);
+}
+
+#[test]
+#[should_panic]
+pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
+ let ll = LogLikelihood::new(1, 1.0);
+ let hl = HillClimbing::new(ll, None);
+ check_compatibility_between_dataset_and_network(hl);
+}
+
+#[test]
+#[should_panic]
+pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() {
+ let ll = LogLikelihood::new(1, 1.0);
+ let hl = HillClimbing::new(ll, None);
+ check_compatibility_between_dataset_and_network_gen(hl);
+}
+
+fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
+ .unwrap();
+ net.add_edge(n1, n2);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-3.0, 2.0, 1.0],
+ [1.5, -2.0, 0.5],
+ [0.4, 0.6, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 0.5, 0.5],
+ [3.0, -4.0, 1.0],
+ [0.9, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 4.0],
+ [1.5, -2.0, 0.5],
+ [3.0, 1.0, -4.0]
+ ],
+ [
+ [-1.0, 0.1, 0.9],
+ [2.0, -2.5, 0.5],
+ [0.9, 0.1, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
+
+ let net = sl.fit_transform(net, &data);
+ assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
+ assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
+}
+
+fn learn_ternary_net_2_nodes_gen<T: StructureLearningAlgorithm>(sl: T) {
+ let mut net = CtbnNetwork::new();
+ generate_nodes(&mut net, 2, 3);
+
+ net.add_edge(0, 1);
+
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ 0.0..7.0,
+ Some(6813071588535822)
+ );
+ cim_generator.generate_parameters(&mut net);
+
+ let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
+
+ let net = sl.fit_transform(net, &data);
+ assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
+ assert_eq!(BTreeSet::new(), net.get_parent_set(0));
+}
+
+#[test]
+pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
+ let ll = LogLikelihood::new(1, 1.0);
+ let hl = HillClimbing::new(ll, None);
+ learn_ternary_net_2_nodes(hl);
+}
+
+#[test]
+pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() {
+ let ll = LogLikelihood::new(1, 1.0);
+ let hl = HillClimbing::new(ll, None);
+ learn_ternary_net_2_nodes_gen(hl);
+}
+
+#[test]
+pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
+ let bic = BIC::new(1, 1.0);
+ let hl = HillClimbing::new(bic, None);
+ learn_ternary_net_2_nodes(hl);
+}
+
+#[test]
+pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() {
+ let bic = BIC::new(1, 1.0);
+ let hl = HillClimbing::new(bic, None);
+ learn_ternary_net_2_nodes_gen(hl);
+}
+
+fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
+ .unwrap();
+ let n2 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
+ .unwrap();
+
+ let n3 = net
+ .add_node(generate_discrete_time_continous_node(String::from("n3"), 4))
+ .unwrap();
+ net.add_edge(n1, n2);
+ net.add_edge(n1, n3);
+ net.add_edge(n2, n3);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-3.0, 2.0, 1.0],
+ [1.5, -2.0, 0.5],
+ [0.4, 0.6, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 0.5, 0.5],
+ [3.0, -4.0, 1.0],
+ [0.9, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 4.0],
+ [1.5, -2.0, 0.5],
+ [3.0, 1.0, -4.0]
+ ],
+ [
+ [-1.0, 0.1, 0.9],
+ [2.0, -2.5, 0.5],
+ [0.9, 0.1, -1.0]
+ ],
+ ]))
+ );
+ }
+ }
+
+ match &mut net.get_node_mut(n3) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ assert_eq!(
+ Ok(()),
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 0.5, 0.3, 0.2],
+ [0.5, -4.0, 2.5, 1.0],
+ [2.5, 0.5, -4.0, 1.0],
+ [0.7, 0.2, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 2.0, 3.0, 1.0],
+ [1.5, -3.0, 0.5, 1.0],
+ [2.0, 1.3, -5.0, 1.7],
+ [2.5, 0.5, 1.0, -4.0]
+ ],
+ [
+ [-1.3, 0.3, 0.1, 0.9],
+ [1.4, -4.0, 0.5, 2.1],
+ [1.0, 1.5, -3.0, 0.5],
+ [0.4, 0.3, 0.1, -0.8]
+ ],
+ [
+ [-2.0, 1.0, 0.7, 0.3],
+ [1.3, -5.9, 2.7, 1.9],
+ [2.0, 1.5, -4.0, 0.5],
+ [0.2, 0.7, 0.1, -1.0]
+ ],
+ [
+ [-6.0, 1.0, 2.0, 3.0],
+ [0.5, -3.0, 1.0, 1.5],
+ [1.4, 2.1, -4.3, 0.8],
+ [0.5, 1.0, 2.5, -4.0]
+ ],
+ [
+ [-1.3, 0.9, 0.3, 0.1],
+ [0.1, -1.3, 0.2, 1.0],
+ [0.5, 1.0, -3.0, 1.5],
+ [0.1, 0.4, 0.3, -0.8]
+ ],
+ [
+ [-2.0, 1.0, 0.6, 0.4],
+ [2.6, -7.1, 1.4, 3.1],
+ [5.0, 1.0, -8.0, 2.0],
+ [1.4, 0.4, 0.2, -2.0]
+ ],
+ [
+ [-3.0, 1.0, 1.5, 0.5],
+ [3.0, -6.0, 1.0, 2.0],
+ [0.3, 0.5, -1.9, 1.1],
+ [5.0, 1.0, 2.0, -8.0]
+ ],
+ [
+ [-2.6, 0.6, 0.2, 1.8],
+ [2.0, -6.0, 3.0, 1.0],
+ [0.1, 0.5, -1.3, 0.7],
+ [0.8, 0.6, 0.2, -1.6]
+ ],
+ ]))
+ );
+ }
+ }
+
+ let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259));
+ return (net, data);
+}
+
+fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) {
+ let mut net = CtbnNetwork::new();
+ generate_nodes(&mut net, 2, 3);
+ net.add_node(
+ generate_discrete_time_continous_node(
+ String::from("3"),
+ 4
+ )
+ ).unwrap();
+
+ net.add_edge(0, 1);
+ net.add_edge(0, 2);
+ net.add_edge(1, 2);
+
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ 0.0..7.0,
+ Some(6813071588535822)
+ );
+ cim_generator.generate_parameters(&mut net);
+
+ let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259));
+ return (net, data);
+}
+
+fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
+ let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
+ let net = sl.fit_transform(net, &data);
+ assert_eq!(BTreeSet::new(), net.get_parent_set(0));
+ assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
+ assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
+}
+
+fn learn_mixed_discrete_net_3_nodes_gen<T: StructureLearningAlgorithm>(sl: T) {
+ let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen();
+ let net = sl.fit_transform(net, &data);
+ assert_eq!(BTreeSet::new(), net.get_parent_set(0));
+ assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
+ assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
+ let ll = LogLikelihood::new(1, 1.0);
+ let hl = HillClimbing::new(ll, None);
+ learn_mixed_discrete_net_3_nodes(hl);
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() {
+ let ll = LogLikelihood::new(1, 1.0);
+ let hl = HillClimbing::new(ll, None);
+ learn_mixed_discrete_net_3_nodes_gen(hl);
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
+ let bic = BIC::new(1, 1.0);
+ let hl = HillClimbing::new(bic, None);
+ learn_mixed_discrete_net_3_nodes(hl);
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() {
+ let bic = BIC::new(1, 1.0);
+ let hl = HillClimbing::new(bic, None);
+ learn_mixed_discrete_net_3_nodes_gen(hl);
+}
+
+fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
+ let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
+ let net = sl.fit_transform(net, &data);
+ assert_eq!(BTreeSet::new(), net.get_parent_set(0));
+ assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
+ assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
+}
+
+fn learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen<T: StructureLearningAlgorithm>(sl: T) {
+ let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen();
+ let net = sl.fit_transform(net, &data);
+ assert_eq!(BTreeSet::new(), net.get_parent_set(0));
+ assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
+ assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
+ let ll = LogLikelihood::new(1, 1.0);
+ let hl = HillClimbing::new(ll, Some(1));
+ learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() {
+ let ll = LogLikelihood::new(1, 1.0);
+ let hl = HillClimbing::new(ll, Some(1));
+ learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl);
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() {
+ let bic = BIC::new(1, 1.0);
+ let hl = HillClimbing::new(bic, Some(1));
+ learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
+}
+
+#[test]
+pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() {
+ let bic = BIC::new(1, 1.0);
+ let hl = HillClimbing::new(bic, Some(1));
+ learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl);
+}
+
+#[test]
+pub fn chi_square_compare_matrices() {
+ let i: usize = 1;
+ let M1 = arr3(&[
+ [
+ [ 0, 2, 3],
+ [ 4, 0, 6],
+ [ 7, 8, 0]
+ ],
+ [
+ [0, 12, 90],
+ [ 3, 0, 40],
+ [ 6, 40, 0]
+ ],
+ [
+ [ 0, 2, 3],
+ [ 4, 0, 6],
+ [ 44, 66, 0]
+ ],
+ ]);
+ let j: usize = 0;
+ let M2 = arr3(&[
+ [
+ [ 0, 200, 300],
+ [ 400, 0, 600],
+ [ 700, 800, 0]
+ ],
+ ]);
+ let chi_sq = ChiSquare::new(1e-4);
+ assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
+}
+
+#[test]
+pub fn chi_square_compare_matrices_2() {
+ let i: usize = 1;
+ let M1 = arr3(&[
+ [
+ [ 0, 2, 3],
+ [ 4, 0, 6],
+ [ 7, 8, 0]
+ ],
+ [
+ [0, 20, 30],
+ [ 40, 0, 60],
+ [ 70, 80, 0]
+ ],
+ [
+ [ 0, 2, 3],
+ [ 4, 0, 6],
+ [ 44, 66, 0]
+ ],
+ ]);
+ let j: usize = 0;
+ let M2 = arr3(&[
+ [[ 0, 200, 300],
+ [ 400, 0, 600],
+ [ 700, 800, 0]]
+ ]);
+ let chi_sq = ChiSquare::new(1e-4);
+ assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
+}
+
+#[test]
+pub fn chi_square_compare_matrices_3() {
+ let i: usize = 1;
+ let M1 = arr3(&[
+ [
+ [ 0, 2, 3],
+ [ 4, 0, 6],
+ [ 7, 8, 0]
+ ],
+ [
+ [0, 21, 31],
+ [ 41, 0, 59],
+ [ 71, 79, 0]
+ ],
+ [
+ [ 0, 2, 3],
+ [ 4, 0, 6],
+ [ 44, 66, 0]
+ ],
+ ]);
+ let j: usize = 0;
+ let M2 = arr3(&[
+ [
+ [ 0, 200, 300],
+ [ 400, 0, 600],
+ [ 700, 800, 0]
+ ],
+ ]);
+ let chi_sq = ChiSquare::new(1e-4);
+ assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
+}
+
+
+#[test]
+pub fn chi_square_call() {
+
+ let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
+ let N3: usize = 2;
+ let N2: usize = 1;
+ let N1: usize = 0;
+ let mut separation_set = BTreeSet::new();
+ let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
+ let mut cache = Cache::new(&parameter_learning);
+ let chi_sq = ChiSquare::new(1e-4);
+
+ assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache));
+ let mut cache = Cache::new(&parameter_learning);
+ assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
+ assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache));
+ separation_set.insert(N1);
+ let mut cache = Cache::new(&parameter_learning);
+ assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache));
+}
+
+#[test]
+pub fn f_call() {
+
+ let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
+ let N3: usize = 2;
+ let N2: usize = 1;
+ let N1: usize = 0;
+ let mut separation_set = BTreeSet::new();
+ let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
+ let mut cache = Cache::new(&parameter_learning);
+ let f = F::new(1e-6);
+
+
+ assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache));
+ let mut cache = Cache::new(&parameter_learning);
+ assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache));
+ assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache));
+ separation_set.insert(N1);
+ let mut cache = Cache::new(&parameter_learning);
+ assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache));
+}
+
+#[test]
+pub fn learn_ternary_net_2_nodes_ctpc() {
+ let f = F::new(1e-6);
+ let chi_sq = ChiSquare::new(1e-4);
+ let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
+ let ctpc = CTPC::new(parameter_learning, f, chi_sq);
+ learn_ternary_net_2_nodes(ctpc);
+}
+
+#[test]
+pub fn learn_ternary_net_2_nodes_ctpc_gen() {
+ let f = F::new(1e-6);
+ let chi_sq = ChiSquare::new(1e-4);
+ let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
+ let ctpc = CTPC::new(parameter_learning, f, chi_sq);
+ learn_ternary_net_2_nodes_gen(ctpc);
+}
+
+#[test]
+fn learn_mixed_discrete_net_3_nodes_ctpc() {
+ let f = F::new(1e-6);
+ let chi_sq = ChiSquare::new(1e-4);
+ let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
+ let ctpc = CTPC::new(parameter_learning, f, chi_sq);
+ learn_mixed_discrete_net_3_nodes(ctpc);
+}
+
+#[test]
+fn learn_mixed_discrete_net_3_nodes_ctpc_gen() {
+ let f = F::new(1e-6);
+ let chi_sq = ChiSquare::new(1e-4);
+ let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
+ let ctpc = CTPC::new(parameter_learning, f, chi_sq);
+ learn_mixed_discrete_net_3_nodes_gen(ctpc);
+}
diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs
new file mode 100644
index 0000000..59d8f27
--- /dev/null
+++ b/reCTBN/tests/tools.rs
@@ -0,0 +1,251 @@
+use std::ops::Range;
+
+use ndarray::{arr1, arr2, arr3};
+use reCTBN::params::ParamsTrait;
+use reCTBN::process::ctbn::*;
+use reCTBN::process::ctmp::*;
+use reCTBN::process::NetworkProcess;
+use reCTBN::params;
+use reCTBN::tools::*;
+
+use utils::*;
+
+#[macro_use]
+extern crate approx;
+
+mod utils;
+
+#[test]
+fn run_sampling() {
+ #![allow(unused_must_use)]
+ let mut net = CtbnNetwork::new();
+ let n1 = net
+ .add_node(utils::generate_discrete_time_continous_node(
+ String::from("n1"),
+ 2,
+ ))
+ .unwrap();
+ let n2 = net
+ .add_node(utils::generate_discrete_time_continous_node(
+ String::from("n2"),
+ 2,
+ ))
+ .unwrap();
+ net.add_edge(n1, n2);
+
+ match &mut net.get_node_mut(n1) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ param.set_cim(arr3(&[
+ [
+ [-3.0, 3.0],
+ [2.0, -2.0]
+ ],
+ ]));
+ }
+ }
+
+ match &mut net.get_node_mut(n2) {
+ params::Params::DiscreteStatesContinousTime(param) => {
+ param.set_cim(arr3(&[
+ [
+ [-1.0, 1.0],
+ [4.0, -4.0]
+ ],
+ [
+ [-6.0, 6.0],
+ [2.0, -2.0]
+ ],
+ ]));
+ }
+ }
+
+ let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259));
+
+ assert_eq!(4, data.get_trajectories().len());
+ assert_relative_eq!(
+ 1.0,
+ data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len() - 1]
+ );
+}
+
+#[test]
+#[should_panic]
+fn trajectory_wrong_shape() {
+ let time = arr1(&[0.0, 0.2]);
+ let events = arr2(&[[0, 3]]);
+ Trajectory::new(time, events);
+}
+
+#[test]
+#[should_panic]
+fn dataset_wrong_shape() {
+ let time = arr1(&[0.0, 0.2]);
+ let events = arr2(&[[0, 3], [1, 2]]);
+ let t1 = Trajectory::new(time, events);
+
+ let time = arr1(&[0.0, 0.2]);
+ let events = arr2(&[[0, 3, 3], [1, 2, 3]]);
+ let t2 = Trajectory::new(time, events);
+ Dataset::new(vec![t1, t2]);
+}
+
+#[test]
+#[should_panic]
+fn uniform_graph_generator_wrong_density_1() {
+ let density = 2.1;
+ let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
+ density,
+ None
+ );
+}
+
+#[test]
+#[should_panic]
+fn uniform_graph_generator_wrong_density_2() {
+ let density = -0.5;
+ let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
+ density,
+ None
+ );
+}
+
+#[test]
+fn uniform_graph_generator_right_densities() {
+ for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
+ let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
+ density,
+ None
+ );
+ }
+}
+
+#[test]
+fn uniform_graph_generator_generate_graph_ctbn() {
+ let mut net = CtbnNetwork::new();
+ let nodes_cardinality = 0..=100;
+ let nodes_domain_cardinality = 2;
+ for node_label in nodes_cardinality {
+ net.add_node(
+ utils::generate_discrete_time_continous_node(
+ node_label.to_string(),
+ nodes_domain_cardinality,
+ )
+ ).unwrap();
+ }
+ let density = 1.0/3.0;
+ let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
+ density,
+ Some(7641630759785120)
+ );
+ structure_generator.generate_graph(&mut net);
+ let mut edges = 0;
+ for node in net.get_node_indices(){
+ edges += net.get_children_set(node).len()
+ }
+ let nodes = net.get_node_indices().len() as f64;
+ let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
+ let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance
+ // As the way `generate_graph()` is implemented we can only reasonably
+ // expect the number of edges to be somewhere around the expected value.
+ assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
+}
+
+#[test]
+#[should_panic]
+fn uniform_graph_generator_generate_graph_ctmp() {
+ let mut net = CtmpProcess::new();
+ let node_label = String::from("0");
+ let node_domain_cardinality = 4;
+ net.add_node(
+ generate_discrete_time_continous_node(
+ node_label,
+ node_domain_cardinality
+ )
+ ).unwrap();
+ let density = 1.0/3.0;
+ let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
+ density,
+ Some(7641630759785120)
+ );
+ structure_generator.generate_graph(&mut net);
+}
+
+#[test]
+#[should_panic]
+fn uniform_parameters_generator_wrong_density_1() {
+ let interval: Range<f64> = -2.0..-5.0;
+ let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ interval,
+ None
+ );
+}
+
+#[test]
+#[should_panic]
+fn uniform_parameters_generator_wrong_density_2() {
+ let interval: Range<f64> = -1.0..0.0;
+ let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ interval,
+ None
+ );
+}
+
+#[test]
+fn uniform_parameters_generator_right_densities_ctbn() {
+ let mut net = CtbnNetwork::new();
+ let nodes_cardinality = 0..=3;
+ let nodes_domain_cardinality = 9;
+ for node_label in nodes_cardinality {
+ net.add_node(
+ generate_discrete_time_continous_node(
+ node_label.to_string(),
+ nodes_domain_cardinality,
+ )
+ ).unwrap();
+ }
+ let density = 1.0/3.0;
+ let seed = Some(7641630759785120);
+ let interval = 0.0..7.0;
+ let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
+ density,
+ seed
+ );
+ structure_generator.generate_graph(&mut net);
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ interval,
+ seed
+ );
+ cim_generator.generate_parameters(&mut net);
+ for node in net.get_node_indices() {
+ assert_eq!(
+ Ok(()),
+ net.get_node(node).validate_params()
+ );
+ }
+}
+
+#[test]
+fn uniform_parameters_generator_right_densities_ctmp() {
+ let mut net = CtmpProcess::new();
+ let node_label = String::from("0");
+ let node_domain_cardinality = 4;
+ net.add_node(
+ generate_discrete_time_continous_node(
+ node_label,
+ node_domain_cardinality
+ )
+ ).unwrap();
+ let seed = Some(7641630759785120);
+ let interval = 0.0..7.0;
+ let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
+ interval,
+ seed
+ );
+ cim_generator.generate_parameters(&mut net);
+ for node in net.get_node_indices() {
+ assert_eq!(
+ Ok(()),
+ net.get_node(node).validate_params()
+ );
+ }
+}
diff --git a/reCTBN/tests/utils.rs b/reCTBN/tests/utils.rs
new file mode 100644
index 0000000..ed43215
--- /dev/null
+++ b/reCTBN/tests/utils.rs
@@ -0,0 +1,19 @@
+use std::collections::BTreeSet;
+
+use reCTBN::params;
+
+#[allow(dead_code)]
+pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params {
+ params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(
+ label,
+ cardinality,
+ ))
+}
+
+pub fn generate_discrete_time_continous_params(
+ label: String,
+ cardinality: usize,
+) -> params::DiscreteStatesContinousTimeParams {
+ let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
+ params::DiscreteStatesContinousTimeParams::new(label, domain)
+}
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
new file mode 100644
index 0000000..367bc0b
--- /dev/null
+++ b/rust-toolchain.toml
@@ -0,0 +1,7 @@
+# This file defines the Rust toolchain to use when a command is executed.
+# See also https://rust-lang.github.io/rustup/overrides.html
+
+[toolchain]
+channel = "stable"
+components = [ "clippy", "rustfmt" ]
+profile = "minimal"
diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 0000000..b6f1257
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1,39 @@
+# This file defines the Rust style for automatic reformatting.
+# See also https://rust-lang.github.io/rustfmt
+
+# NOTE: the unstable options will be uncommented when stabilized.
+
+# Version of the formatting rules to use.
+#version = "One"
+
+# Number of spaces per tab.
+tab_spaces = 4
+
+max_width = 100
+#comment_width = 80
+
+# Prevent carriage returns; only \n is allowed.
+newline_style = "Unix"
+
+# The "Default" setting has a heuristic which can split lines too aggressively.
+#use_small_heuristics = "Max"
+
+# How imports should be grouped into `use` statements.
+#imports_granularity = "Module"
+
+# How consecutive imports are grouped together.
+#group_imports = "StdExternalCrate"
+
+# Error if unable to get all lines within max_width, except for comments and
+# string literals.
+#error_on_line_overflow = true
+
+# Error if unable to get comments or string literals within max_width, or they
+# are left with trailing whitespaces.
+#error_on_unformatted = true
+
+# Files to ignore like third party code which is formatted upstream.
+# Ignoring tests is a temporary measure due to some issues regarding rank-3 tensors
+ignore = [
+ "tests/"
+]
diff --git a/src/ctbn.rs b/src/ctbn.rs
deleted file mode 100644
index 9cabe20..0000000
--- a/src/ctbn.rs
+++ /dev/null
@@ -1,164 +0,0 @@
-use ndarray::prelude::*;
-use crate::node;
-use crate::params::{StateType, ParamsTrait};
-use crate::network;
-use std::collections::BTreeSet;
-
-
-
-
-///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
-///composed by the following elements:
-///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
-///- **nodes**: a vector containing all the nodes and their parameters.
-///The index of a node inside the vector is also used as index for the adj_matrix.
-///
-///# Examples
-///
-///```
-///
-/// use std::collections::BTreeSet;
-/// use rustyCTBN::network::Network;
-/// use rustyCTBN::node;
-/// use rustyCTBN::params;
-/// use rustyCTBN::ctbn::*;
-///
-/// //Create the domain for a discrete node
-/// let mut domain = BTreeSet::new();
-/// domain.insert(String::from("A"));
-/// domain.insert(String::from("B"));
-///
-/// //Create the parameters for a discrete node using the domain
-/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
-///
-/// //Create the node using the parameters
-/// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1"));
-///
-/// let mut domain = BTreeSet::new();
-/// domain.insert(String::from("A"));
-/// domain.insert(String::from("B"));
-/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
-/// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2"));
-///
-/// //Initialize a ctbn
-/// let mut net = CtbnNetwork::init();
-///
-/// //Add nodes
-/// let X1 = net.add_node(X1).unwrap();
-/// let X2 = net.add_node(X2).unwrap();
-///
-/// //Add an edge
-/// net.add_edge(X1, X2);
-///
-/// //Get all the children of node X1
-/// let cs = net.get_children_set(X1);
-/// assert_eq!(&X2, cs.iter().next().unwrap());
-/// ```
-pub struct CtbnNetwork {
- adj_matrix: Option>,
- nodes: Vec
-}
-
-
-impl CtbnNetwork {
- pub fn init() -> CtbnNetwork {
- CtbnNetwork {
- adj_matrix: None,
- nodes: Vec::new()
- }
- }
-}
-
-impl network::Network for CtbnNetwork {
- fn initialize_adj_matrix(&mut self) {
- self.adj_matrix = Some(Array2::::zeros((self.nodes.len(), self.nodes.len()).f()));
-
- }
-
- fn add_node(&mut self, mut n: node::Node) -> Result {
- n.params.reset_params();
- self.adj_matrix = Option::None;
- self.nodes.push(n);
- Ok(self.nodes.len() -1)
- }
-
- fn add_edge(&mut self, parent: usize, child: usize) {
- if let None = self.adj_matrix {
- self.initialize_adj_matrix();
- }
-
- if let Some(network) = &mut self.adj_matrix {
- network[[parent, child]] = 1;
- self.nodes[child].params.reset_params();
- }
- }
-
- fn get_node_indices(&self) -> std::ops::Range{
- 0..self.nodes.len()
- }
-
- fn get_number_of_nodes(&self) -> usize {
- self.nodes.len()
- }
-
- fn get_node(&self, node_idx: usize) -> &node::Node{
- &self.nodes[node_idx]
- }
-
-
- fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{
- &mut self.nodes[node_idx]
- }
-
-
- fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize{
- self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
- if x.1 > &0 {
- acc.0 += self.nodes[x.0].params.state_to_index(¤t_state[x.0]) * acc.1;
- acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent();
- }
- acc
- }).0
- }
-
-
- fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize {
- parent_set.iter().fold((0, 1), |mut acc, x| {
- acc.0 += self.nodes[*x].params.state_to_index(¤t_state[*x]) * acc.1;
- acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent();
- acc
- }).0
- }
-
- fn get_parent_set(&self, node: usize) -> BTreeSet {
- self.adj_matrix.as_ref()
- .unwrap()
- .column(node)
- .iter()
- .enumerate()
- .filter_map(|(idx, x)| {
- if x > &0 {
- Some(idx)
- } else {
- None
- }
- }).collect()
- }
-
- fn get_children_set(&self, node: usize) -> BTreeSet{
- self.adj_matrix.as_ref()
- .unwrap()
- .row(node)
- .iter()
- .enumerate()
- .filter_map(|(idx, x)| {
- if x > &0 {
- Some(idx)
- } else {
- None
- }
- }).collect()
- }
-
-}
-
diff --git a/src/lib.rs b/src/lib.rs
deleted file mode 100644
index 65e4b11..0000000
--- a/src/lib.rs
+++ /dev/null
@@ -1,11 +0,0 @@
-#[cfg(test)]
-#[macro_use]
-extern crate approx;
-
-pub mod node;
-pub mod params;
-pub mod network;
-pub mod ctbn;
-pub mod tools;
-pub mod parameter_learning;
-
diff --git a/src/network.rs b/src/network.rs
deleted file mode 100644
index 3b6ce06..0000000
--- a/src/network.rs
+++ /dev/null
@@ -1,39 +0,0 @@
-use thiserror::Error;
-use crate::params;
-use crate::node;
-use std::collections::BTreeSet;
-
-/// Error types for trait Network
-#[derive(Error, Debug)]
-pub enum NetworkError {
- #[error("Error during node insertion")]
- NodeInsertionError(String)
-}
-
-
-///Network
-///The Network trait define the required methods for a structure used as pgm (such as ctbn).
-pub trait Network {
- fn initialize_adj_matrix(&mut self);
- fn add_node(&mut self, n: node::Node) -> Result;
- fn add_edge(&mut self, parent: usize, child: usize);
-
- ///Get all the indices of the nodes contained inside the network
- fn get_node_indices(&self) -> std::ops::Range;
- fn get_number_of_nodes(&self) -> usize;
- fn get_node(&self, node_idx: usize) -> &node::Node;
- fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node;
-
- ///Compute the index that must be used to access the parameters of a node given a specific
- ///configuration of the network. Usually, the only values really used in *current_state* are
- ///the ones in the parent set of the *node*.
- fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize;
-
-
- ///Compute the index that must be used to access the parameters of a node given a specific
- ///configuration of the network and a generic parent_set. Usually, the only values really used
- ///in *current_state* are the ones in the parent set of the *node*.
- fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize;
- fn get_parent_set(&self, node: usize) -> BTreeSet;
- fn get_children_set(&self, node: usize) -> BTreeSet;
-}
diff --git a/src/node.rs b/src/node.rs
deleted file mode 100644
index 7ed21ba..0000000
--- a/src/node.rs
+++ /dev/null
@@ -1,25 +0,0 @@
-use crate::params::*;
-
-
-pub struct Node {
- pub params: Params,
- pub label: String
-}
-
-impl Node {
- pub fn init(params: Params, label: String) -> Node {
- Node{
- params: params,
- label:label
- }
- }
-
-}
-
-impl PartialEq for Node {
- fn eq(&self, other: &Node) -> bool{
- self.label == other.label
- }
-}
-
-
diff --git a/src/params.rs b/src/params.rs
deleted file mode 100644
index c5a9acf..0000000
--- a/src/params.rs
+++ /dev/null
@@ -1,161 +0,0 @@
-use ndarray::prelude::*;
-use rand::Rng;
-use std::collections::{BTreeSet, HashMap};
-use thiserror::Error;
-use enum_dispatch::enum_dispatch;
-
-/// Error types for trait Params
-#[derive(Error, Debug)]
-pub enum ParamsError {
- #[error("Unsupported method")]
- UnsupportedMethod(String),
- #[error("Paramiters not initialized")]
- ParametersNotInitialized(String),
-}
-
-/// Allowed type of states
-#[derive(Clone)]
-pub enum StateType {
- Discrete(usize),
-}
-
-/// Parameters
-/// The Params trait is the core element for building different types of nodes. The goal is to
-/// define the set of method required to describes a generic node.
-#[enum_dispatch(Params)]
-pub trait ParamsTrait {
- fn reset_params(&mut self);
-
- /// Randomly generate a possible state of the node disregarding the state of the node and it's
- /// parents.
- fn get_random_state_uniform(&self) -> StateType;
-
- /// Randomly generate a residence time for the given node taking into account the node state
- /// and its parent set.
- fn get_random_residence_time(&self, state: usize, u: usize) -> Result;
-
- /// Randomly generate a possible state for the given node taking into account the node state
- /// and its parent set.
- fn get_random_state(&self, state: usize, u: usize) -> Result;
-
- /// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
- fn get_reserved_space_as_parent(&self) -> usize;
-
- /// Index used by discrete node to represents their states as usize.
- fn state_to_index(&self, state: &StateType) -> usize;
-}
-
-/// The Params enum is the core element for building different types of nodes. The goal is to
-/// define all the supported type of parameters.
-#[enum_dispatch]
-pub enum Params {
- DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
-}
-
-
-/// DiscreteStatesContinousTime.
-/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
-/// following elements:
-/// - **domain**: an ordered and exhaustive set of possible states
-/// - **cim**: Conditional Intensity Matrix
-/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
-/// learning task and are composed by:
-/// - **transitions**: number of transitions from one state to another given a specific
-/// realization of the parent set
-/// - **residence_time**: permanence time in each possible states given a specific
-/// realization of the parent set
-pub struct DiscreteStatesContinousTimeParams {
- pub domain: BTreeSet,
- pub cim: Option>,
- pub transitions: Option>,
- pub residence_time: Option>,
-}
-
-impl DiscreteStatesContinousTimeParams {
- pub fn init(domain: BTreeSet) -> DiscreteStatesContinousTimeParams {
- DiscreteStatesContinousTimeParams {
- domain: domain,
- cim: Option::None,
- transitions: Option::None,
- residence_time: Option::None,
- }
- }
-}
-
-impl ParamsTrait for DiscreteStatesContinousTimeParams {
- fn reset_params(&mut self) {
- self.cim = Option::None;
- self.transitions = Option::None;
- self.residence_time = Option::None;
- }
-
- fn get_random_state_uniform(&self) -> StateType {
- let mut rng = rand::thread_rng();
- StateType::Discrete(rng.gen_range(0..(self.domain.len())))
- }
-
- fn get_random_residence_time(&self, state: usize, u: usize) -> Result {
- // Generate a random residence time given the current state of the node and its parent set.
- // The method used is described in:
- // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
- match &self.cim {
- Option::Some(cim) => {
- let mut rng = rand::thread_rng();
- let lambda = cim[[u, state, state]] * -1.0;
- let x: f64 = rng.gen_range(0.0..=1.0);
- Ok(-x.ln() / lambda)
- }
- Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
- "CIM not initialized",
- ))),
- }
- }
-
- fn get_random_state(&self, state: usize, u: usize) -> Result {
- // Generate a random transition given the current state of the node and its parent set.
- // The method used is described in:
- // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
- match &self.cim {
- Option::Some(cim) => {
- let mut rng = rand::thread_rng();
- let lambda = cim[[u, state, state]] * -1.0;
- let urand: f64 = rng.gen_range(0.0..=1.0);
-
- let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold(
- (0, 0.0),
- |mut acc, ele| {
- if &acc.1 + ele < urand && ele > &0.0 {
- acc.0 += 1;
- }
- if ele > &0.0 {
- acc.1 += ele;
- }
- acc
- },
- );
-
- let next_state = if next_state.0 < state {
- next_state.0
- } else {
- next_state.0 + 1
- };
-
- Ok(StateType::Discrete(next_state))
- }
- Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
- "CIM not initialized",
- ))),
- }
- }
-
- fn get_reserved_space_as_parent(&self) -> usize {
- self.domain.len()
- }
-
- fn state_to_index(&self, state: &StateType) -> usize {
- match state {
- StateType::Discrete(val) => val.clone() as usize,
- }
- }
-}
-
diff --git a/src/tools.rs b/src/tools.rs
deleted file mode 100644
index 27438f9..0000000
--- a/src/tools.rs
+++ /dev/null
@@ -1,119 +0,0 @@
-use crate::network;
-use crate::node;
-use crate::params;
-use crate::params::ParamsTrait;
-use ndarray::prelude::*;
-
-pub struct Trajectory {
- pub time: Array1,
- pub events: Array2,
-}
-
-pub struct Dataset {
- pub trajectories: Vec,
-}
-
-pub fn trajectory_generator(
- net: &T,
- n_trajectories: u64,
- t_end: f64,
-) -> Dataset {
- let mut dataset = Dataset {
- trajectories: Vec::new(),
- };
-
- let node_idx: Vec<_> = net.get_node_indices().collect();
- for _ in 0..n_trajectories {
- let mut t = 0.0;
- let mut time: Vec = Vec::new();
- let mut events: Vec> = Vec::new();
- let mut current_state: Vec = node_idx
- .iter()
- .map(|x| net.get_node(*x).params.get_random_state_uniform())
- .collect();
- let mut next_transitions: Vec