diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 9f63860..dc941e5 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,13 +267,11 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } - let domain_size = domain_size as f64; - // Check if each row sum up to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) + .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 47a067d..0a48410 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -1,7 +1,13 @@ //! Contains commonly used methods used across the crate. -use ndarray::prelude::*; +use std::ops::{DivAssign, MulAssign, Range}; +use ndarray::{Array, Array1, Array2, Array3, Axis}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +use crate::params::ParamsTrait; +use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -108,3 +114,243 @@ pub fn trajectory_generator( //Return a dataset object with the sampled trajectories. Dataset::new(trajectories) } + +pub trait RandomGraphGenerator { + fn new(density: f64, seed: Option) -> Self; + fn generate_graph(&mut self, net: &mut T); +} + +/// Graph Generator using an uniform distribution. +/// +/// A method to generate a random graph with edges uniformly distributed. +/// +/// # Arguments +/// +/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`. +/// * `rng` - is the random numbers generator. +/// +/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::params::Params::DiscreteStatesContinousTime; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// use reCTBN::tools::UniformGraphGenerator; +/// use reCTBN::tools::RandomGraphGenerator; +/// # let mut net = CtbnNetwork::new(); +/// # let nodes_cardinality = 8; +/// # let domain_cardinality = 4; +/// # for node in 0..nodes_cardinality { +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # for dvalue in 0..domain_cardinality { +/// # domain.insert(dvalue.to_string()); +/// # } +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new( +/// # node.to_string(), +/// # domain +/// # ); +/// # //Create the node using the parameters +/// # let node = DiscreteStatesContinousTime(param); +/// # // Add the node to the network +/// # net.add_node(node).unwrap(); +/// # } +/// +/// // Initialize the Graph Generator using the one with an +/// // uniform distribution +/// let density = 1.0/3.0; +/// let seed = Some(7641630759785120); +/// let mut structure_generator = UniformGraphGenerator::new( +/// density, +/// seed +/// ); +/// +/// // Generate the graph directly on the network +/// structure_generator.generate_graph(&mut net); +/// # // Count all the edges generated in the network +/// # let mut edges = 0; +/// # for node in net.get_node_indices(){ +/// # edges += net.get_children_set(node).len() +/// # } +/// # // Number of all the nodes in the network +/// # let nodes = net.get_node_indices().len() as f64; +/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; +/// # // ±10% of tolerance +/// # let tolerance = ((expected_edges as f64)*0.10) as usize; +/// # // As the way `generate_graph()` is implemented we can only reasonably +/// # // expect the number of edges to be somewhere around the expected value. +/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); +/// ``` +pub struct UniformGraphGenerator { + density: f64, + rng: ChaCha8Rng, +} + +impl RandomGraphGenerator for UniformGraphGenerator { + fn new(density: f64, seed: Option) -> UniformGraphGenerator { + if density < 0.0 || density > 1.0 { + panic!( + "Density value must be between 1.0 and 0.0, got {}.", + density + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + UniformGraphGenerator { density, rng } + } + + /// Generate an uniformly distributed graph. + fn generate_graph(&mut self, net: &mut T) { + net.initialize_adj_matrix(); + let last_node_idx = net.get_node_indices().len(); + for parent in 0..last_node_idx { + for child in 0..last_node_idx { + if parent != child { + if self.rng.gen_bool(self.density) { + net.add_edge(parent, child); + } + } + } + } + } +} + +pub trait RandomParametersGenerator { + fn new(interval: Range, seed: Option) -> Self; + fn generate_parameters(&mut self, net: &mut T); +} + +/// Parameters Generator using an uniform distribution. +/// +/// A method to generate random parameters uniformly distributed. +/// +/// # Arguments +/// +/// * `interval` - is the interval of the random values oh the CIM's diagonal; domain: `≥ 0.0`. +/// * `rng` - is the random numbers generator. +/// +/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::params::ParamsTrait; +/// # use reCTBN::params::Params::DiscreteStatesContinousTime; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::tools::RandomGraphGenerator; +/// # use reCTBN::tools::UniformGraphGenerator; +/// use reCTBN::tools::RandomParametersGenerator; +/// use reCTBN::tools::UniformParametersGenerator; +/// # let mut net = CtbnNetwork::new(); +/// # let nodes_cardinality = 8; +/// # let domain_cardinality = 4; +/// # for node in 0..nodes_cardinality { +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # for dvalue in 0..domain_cardinality { +/// # domain.insert(dvalue.to_string()); +/// # } +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new( +/// # node.to_string(), +/// # domain +/// # ); +/// # //Create the node using the parameters +/// # let node = DiscreteStatesContinousTime(param); +/// # // Add the node to the network +/// # net.add_node(node).unwrap(); +/// # } +/// # +/// # // Initialize the Graph Generator using the one with an +/// # // uniform distribution +/// # let mut structure_generator = UniformGraphGenerator::new( +/// # 1.0/3.0, +/// # Some(7641630759785120) +/// # ); +/// # +/// # // Generate the graph directly on the network +/// # structure_generator.generate_graph(&mut net); +/// +/// // Initialize the parameters generator with uniform distributin +/// let mut cim_generator = UniformParametersGenerator::new( +/// 0.0..7.0, +/// Some(7641630759785120) +/// ); +/// +/// // Generate CIMs with uniformly distributed parameters. +/// cim_generator.generate_parameters(&mut net); +/// # +/// # for node in net.get_node_indices() { +/// # assert_eq!( +/// # Ok(()), +/// # net.get_node(node).validate_params() +/// # ); +/// } +/// ``` +pub struct UniformParametersGenerator { + interval: Range, + rng: ChaCha8Rng, +} + +impl RandomParametersGenerator for UniformParametersGenerator { + fn new(interval: Range, seed: Option) -> UniformParametersGenerator { + if interval.start < 0.0 || interval.end < 0.0 { + panic!( + "Interval must be entirely less or equal than 0, got {}..{}.", + interval.start, interval.end + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + UniformParametersGenerator { interval, rng } + } + + /// Generate CIMs with uniformly distributed parameters. + fn generate_parameters(&mut self, net: &mut T) { + for node in net.get_node_indices() { + let parent_set_state_space_cardinality: usize = net + .get_parent_set(node) + .iter() + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + match &mut net.get_node_mut(node) { + params::Params::DiscreteStatesContinousTime(param) => { + let node_domain_cardinality = param.get_reserved_space_as_parent(); + let mut cim = Array3::::from_shape_fn( + ( + parent_set_state_space_cardinality, + node_domain_cardinality, + node_domain_cardinality, + ), + |_| self.rng.gen(), + ); + cim.axis_iter_mut(Axis(0)).for_each(|mut x| { + x.diag_mut().fill(0.0); + x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); + let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { + self.rng.gen_range(self.interval.clone()) + }); + x.mul_assign(&diag.clone().insert_axis(Axis(1))); + // Recomputing the diagonal in order to reduce the issues caused by the + // loss of precision when validating the parameters. + let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) + }); + param.set_cim_unchecked(cim); + } + } + } + } +} diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 2cbc185..0a09a2a 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -6,6 +6,7 @@ use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; +use reCTBN::params::Params::DiscreteStatesContinousTime; use reCTBN::tools::*; use utils::*; @@ -66,18 +67,78 @@ fn learn_binary_cim(pl: T) { )); } +fn generate_nodes( + net: &mut CtbnNetwork, + nodes_cardinality: usize, + nodes_domain_cardinality: usize +) { + for node_label in 0..nodes_cardinality { + net.add_node( + generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } +} + +fn learn_binary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 2); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_binary_cim_MLE() { let mle = MLE {}; learn_binary_cim(mle); } +#[test] +fn learn_binary_cim_MLE_gen() { + let mle = MLE {}; + learn_binary_cim_gen(mle); +} + #[test] fn learn_binary_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_binary_cim(ba); } +#[test] +fn learn_binary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_binary_cim_gen(ba); +} + fn learn_ternary_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -155,18 +216,63 @@ fn learn_ternary_cim(pl: T) { )); } +fn learn_ternary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 4.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_ternary_cim_MLE() { let mle = MLE {}; learn_ternary_cim(mle); } +#[test] +fn learn_ternary_cim_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_gen(mle); +} + #[test] fn learn_ternary_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim(ba); } +#[test] +fn learn_ternary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_gen(ba); +} + fn learn_ternary_cim_no_parents(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -234,18 +340,63 @@ fn learn_ternary_cim_no_parents(pl: T) { )); } +fn learn_ternary_cim_no_parents_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(0) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 0, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_ternary_cim_no_parents_MLE() { let mle = MLE {}; learn_ternary_cim_no_parents(mle); } +#[test] +fn learn_ternary_cim_no_parents_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_no_parents_gen(mle); +} + #[test] fn learn_ternary_cim_no_parents_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim_no_parents(ba); } +#[test] +fn learn_ternary_cim_no_parents_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_no_parents_gen(ba); +} + fn learn_mixed_discrete_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -432,14 +583,66 @@ fn learn_mixed_discrete_cim(pl: T) { )); } +fn learn_mixed_discrete_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..8.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(2) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 2, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.2 + ) + ); +} + #[test] fn learn_mixed_discrete_cim_MLE() { let mle = MLE {}; learn_mixed_discrete_cim(mle); } +#[test] +fn learn_mixed_discrete_cim_MLE_gen() { + let mle = MLE {}; + learn_mixed_discrete_cim_gen(mle); +} + #[test] fn learn_mixed_discrete_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_mixed_discrete_cim(ba); } + +#[test] +fn learn_mixed_discrete_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_mixed_discrete_cim_gen(ba); +} diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 9a69b45..3d7e230 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -117,6 +117,50 @@ fn check_compatibility_between_dataset_and_network(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259)); + + let mut net = CtbnNetwork::new(); + let _n1 = net + .add_node( + generate_discrete_time_continous_node(String::from("0"), + 3) + ).unwrap(); + let _net = sl.fit_transform(net, &data); +} + #[test] #[should_panic] pub fn check_compatibility_between_dataset_and_network_hill_climbing() { @@ -125,6 +169,14 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } +#[test] +#[should_panic] +pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + check_compatibility_between_dataset_and_network_gen(hl); +} + fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -182,6 +234,25 @@ fn learn_ternary_net_2_nodes(sl: T) { assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } +fn learn_ternary_net_2_nodes_gen(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +} + #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -189,6 +260,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { learn_ternary_net_2_nodes(hl); } +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_ternary_net_2_nodes_gen(hl); +} + #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { let bic = BIC::new(1, 1.0); @@ -196,6 +274,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { learn_ternary_net_2_nodes(hl); } +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_ternary_net_2_nodes_gen(hl); +} + fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { let mut net = CtbnNetwork::new(); let n1 = net @@ -320,6 +405,30 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } +fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); + return (net, data); +} + fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); @@ -328,6 +437,14 @@ fn learn_mixed_discrete_net_3_nodes(sl: T) { assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); } +fn learn_mixed_discrete_net_3_nodes_gen(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -335,6 +452,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { learn_mixed_discrete_net_3_nodes(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { let bic = BIC::new(1, 1.0); @@ -342,6 +466,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); @@ -350,6 +481,14 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { let ll = LogLikelihood::new(1, 1.0); @@ -357,6 +496,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() { let bic = BIC::new(1, 1.0); @@ -364,6 +510,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + #[test] pub fn chi_square_compare_matrices() { let i: usize = 1; @@ -511,6 +664,15 @@ pub fn learn_ternary_net_2_nodes_ctpc() { learn_ternary_net_2_nodes(ctpc); } +#[test] +pub fn learn_ternary_net_2_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes_gen(ctpc); +} + #[test] fn learn_mixed_discrete_net_3_nodes_ctpc() { let f = F::new(1e-6); @@ -519,3 +681,12 @@ fn learn_mixed_discrete_net_3_nodes_ctpc() { let ctpc = CTPC::new(parameter_learning, f, chi_sq); learn_mixed_discrete_net_3_nodes(ctpc); } + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_mixed_discrete_net_3_nodes_gen(ctpc); +} diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 806faef..59d8f27 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,9 +1,15 @@ +use std::ops::Range; + use ndarray::{arr1, arr2, arr3}; +use reCTBN::params::ParamsTrait; use reCTBN::process::ctbn::*; +use reCTBN::process::ctmp::*; use reCTBN::process::NetworkProcess; use reCTBN::params; use reCTBN::tools::*; +use utils::*; + #[macro_use] extern crate approx; @@ -82,3 +88,164 @@ fn dataset_wrong_shape() { let t2 = Trajectory::new(time, events); Dataset::new(vec![t1, t2]); } + +#[test] +#[should_panic] +fn uniform_graph_generator_wrong_density_1() { + let density = 2.1; + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); +} + +#[test] +#[should_panic] +fn uniform_graph_generator_wrong_density_2() { + let density = -0.5; + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); +} + +#[test] +fn uniform_graph_generator_right_densities() { + for density in [1.0, 0.75, 0.5, 0.25, 0.0] { + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); + } +} + +#[test] +fn uniform_graph_generator_generate_graph_ctbn() { + let mut net = CtbnNetwork::new(); + let nodes_cardinality = 0..=100; + let nodes_domain_cardinality = 2; + for node_label in nodes_cardinality { + net.add_node( + utils::generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + Some(7641630759785120) + ); + structure_generator.generate_graph(&mut net); + let mut edges = 0; + for node in net.get_node_indices(){ + edges += net.get_children_set(node).len() + } + let nodes = net.get_node_indices().len() as f64; + let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; + let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance + // As the way `generate_graph()` is implemented we can only reasonably + // expect the number of edges to be somewhere around the expected value. + assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); +} + +#[test] +#[should_panic] +fn uniform_graph_generator_generate_graph_ctmp() { + let mut net = CtmpProcess::new(); + let node_label = String::from("0"); + let node_domain_cardinality = 4; + net.add_node( + generate_discrete_time_continous_node( + node_label, + node_domain_cardinality + ) + ).unwrap(); + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + Some(7641630759785120) + ); + structure_generator.generate_graph(&mut net); +} + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_1() { + let interval: Range = -2.0..-5.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + None + ); +} + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_2() { + let interval: Range = -1.0..0.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + None + ); +} + +#[test] +fn uniform_parameters_generator_right_densities_ctbn() { + let mut net = CtbnNetwork::new(); + let nodes_cardinality = 0..=3; + let nodes_domain_cardinality = 9; + for node_label in nodes_cardinality { + net.add_node( + generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + seed + ); + structure_generator.generate_graph(&mut net); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + seed + ); + cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + assert_eq!( + Ok(()), + net.get_node(node).validate_params() + ); + } +} + +#[test] +fn uniform_parameters_generator_right_densities_ctmp() { + let mut net = CtmpProcess::new(); + let node_label = String::from("0"); + let node_domain_cardinality = 4; + net.add_node( + generate_discrete_time_continous_node( + node_label, + node_domain_cardinality + ) + ).unwrap(); + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + seed + ); + cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + assert_eq!( + Ok(()), + net.get_node(node).validate_params() + ); + } +}