From a077f738eeafc606dd1f791503b0d0c629e97aaa Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 26 Jan 2023 16:16:22 +0100 Subject: [PATCH 01/12] Added `StructureGen` struct for generating the structure of a `CtbnNetwork` --- reCTBN/src/tools.rs | 39 +++++++++++++++++++++++++++++++++++++++ reCTBN/tests/tools.rs | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 47a067d..c58403a 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -1,7 +1,11 @@ //! Contains commonly used methods used across the crate. use ndarray::prelude::*; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use crate::process::ctbn::CtbnNetwork; +use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -108,3 +112,38 @@ pub fn trajectory_generator( //Return a dataset object with the sampled trajectories. Dataset::new(trajectories) } + +pub struct StructureGen { + density: f64, + rng: ChaCha8Rng, +} + +impl StructureGen { + pub fn new(density: f64, seed: Option) -> StructureGen { + if density < 0.0 || density > 1.0 { + panic!( + "Density value must be between 1.0 and 0.0, got {}.", + density + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + StructureGen { density, rng } + } + + pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + let last_node_idx = net.get_node_indices().len(); + for parent in 0..last_node_idx { + for child in 0..last_node_idx { + if parent != child { + if self.rng.gen_bool(self.density) { + net.add_edge(parent, child); + } + } + } + } + net + } +} diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 806faef..ac64f8d 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -82,3 +82,43 @@ fn dataset_wrong_shape() { let t2 = Trajectory::new(time, events); Dataset::new(vec![t1, t2]); } + +#[test] +#[should_panic] +fn structure_gen_wrong_density() { + let density = 2.1; + StructureGen::new(density, None); +} + +#[test] +fn structure_gen_right_densities() { + for density in [1.0, 0.75, 0.5, 0.25, 0.0] { + StructureGen::new(density, None); + } +} + +#[test] +fn structure_gen_gen_structure() { + let mut net = CtbnNetwork::new(); + for node_label in 0..100 { + net.add_node( + utils::generate_discrete_time_continous_node( + node_label.to_string(), + 2, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let mut structure_generator = StructureGen::new(density, Some(7641630759785120)); + structure_generator.gen_structure(&mut net); + let mut edges = 0; + for node in net.get_node_indices(){ + edges += net.get_children_set(node).len() + } + let nodes = net.get_node_indices().len() as f64; + let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; + let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance + // As the way `gen_structure()` is implemented we can only reasonably + // expect the number of edges to be somewhere around the expected value. + assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); +} From 4b994d8a19855dd387b98141cc4a154f8eb62521 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 26 Jan 2023 16:16:23 +0100 Subject: [PATCH 02/12] Renamed `StructureGen` with `UniformRandomGenerator` and defining the new trait `RandomGraphGenerator` --- reCTBN/src/tools.rs | 15 ++++++++++----- reCTBN/tests/tools.rs | 16 ++++++++-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index c58403a..599b420 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -113,13 +113,18 @@ pub fn trajectory_generator( Dataset::new(trajectories) } -pub struct StructureGen { +pub trait RandomGraphGenerator { + fn new(density: f64, seed: Option) -> Self; + fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork; +} + +pub struct UniformRandomGenerator { density: f64, rng: ChaCha8Rng, } -impl StructureGen { - pub fn new(density: f64, seed: Option) -> StructureGen { +impl RandomGraphGenerator for UniformRandomGenerator { + fn new(density: f64, seed: Option) -> UniformRandomGenerator { if density < 0.0 || density > 1.0 { panic!( "Density value must be between 1.0 and 0.0, got {}.", @@ -130,10 +135,10 @@ impl StructureGen { Some(seed) => SeedableRng::seed_from_u64(seed), None => SeedableRng::from_entropy(), }; - StructureGen { density, rng } + UniformRandomGenerator { density, rng } } - pub fn gen_structure<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { let last_node_idx = net.get_node_indices().len(); for parent in 0..last_node_idx { for child in 0..last_node_idx { diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index ac64f8d..4c32de7 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -85,20 +85,20 @@ fn dataset_wrong_shape() { #[test] #[should_panic] -fn structure_gen_wrong_density() { +fn uniform_random_generator_wrong_density() { let density = 2.1; - StructureGen::new(density, None); + let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); } #[test] -fn structure_gen_right_densities() { +fn uniform_random_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { - StructureGen::new(density, None); + let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); } } #[test] -fn structure_gen_gen_structure() { +fn uniform_random_generator_generate_graph() { let mut net = CtbnNetwork::new(); for node_label in 0..100 { net.add_node( @@ -109,8 +109,8 @@ fn structure_gen_gen_structure() { ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator = StructureGen::new(density, Some(7641630759785120)); - structure_generator.gen_structure(&mut net); + let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + structure_generator.generate_graph(&mut net); let mut edges = 0; for node in net.get_node_indices(){ edges += net.get_children_set(node).len() @@ -118,7 +118,7 @@ fn structure_gen_gen_structure() { let nodes = net.get_node_indices().len() as f64; let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance - // As the way `gen_structure()` is implemented we can only reasonably + // As the way `generate_graph()` is implemented we can only reasonably // expect the number of edges to be somewhere around the expected value. assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); } From 434e671f0a114c95a7f18825cfe5cf42773d96a7 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Fri, 27 Jan 2023 12:43:47 +0100 Subject: [PATCH 03/12] Renamed `UniformRandomGenerator` to `UniformGraphGenerator`, replaced `CtbnNetwork` requirement with `NetworkProcess` in `RandomGraphGenerator`, some related tweaks --- reCTBN/src/tools.rs | 15 +++++++-------- reCTBN/tests/tools.rs | 10 +++++----- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 599b420..9222239 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -4,7 +4,6 @@ use ndarray::prelude::*; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use crate::process::ctbn::CtbnNetwork; use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -115,16 +114,16 @@ pub fn trajectory_generator( pub trait RandomGraphGenerator { fn new(density: f64, seed: Option) -> Self; - fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork; + fn generate_graph(&mut self, net: &mut T); } -pub struct UniformRandomGenerator { +pub struct UniformGraphGenerator { density: f64, rng: ChaCha8Rng, } -impl RandomGraphGenerator for UniformRandomGenerator { - fn new(density: f64, seed: Option) -> UniformRandomGenerator { +impl RandomGraphGenerator for UniformGraphGenerator { + fn new(density: f64, seed: Option) -> UniformGraphGenerator { if density < 0.0 || density > 1.0 { panic!( "Density value must be between 1.0 and 0.0, got {}.", @@ -135,10 +134,11 @@ impl RandomGraphGenerator for UniformRandomGenerator { Some(seed) => SeedableRng::seed_from_u64(seed), None => SeedableRng::from_entropy(), }; - UniformRandomGenerator { density, rng } + UniformGraphGenerator { density, rng } } - fn generate_graph<'a>(&'a mut self, net: &'a mut CtbnNetwork) -> &CtbnNetwork { + fn generate_graph(&mut self, net: &mut T) { + net.initialize_adj_matrix(); let last_node_idx = net.get_node_indices().len(); for parent in 0..last_node_idx { for child in 0..last_node_idx { @@ -149,6 +149,5 @@ impl RandomGraphGenerator for UniformRandomGenerator { } } } - net } } diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 4c32de7..9a96959 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -87,13 +87,13 @@ fn dataset_wrong_shape() { #[should_panic] fn uniform_random_generator_wrong_density() { let density = 2.1; - let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } #[test] fn uniform_random_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { - let _structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } } @@ -109,7 +109,7 @@ fn uniform_random_generator_generate_graph() { ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator: UniformRandomGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); structure_generator.generate_graph(&mut net); let mut edges = 0; for node in net.get_node_indices(){ @@ -117,8 +117,8 @@ fn uniform_random_generator_generate_graph() { } let nodes = net.get_node_indices().len() as f64; let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; - let tolerance = ((expected_edges as f64)/100.0*5.0) as usize; // ±5% of tolerance + let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance // As the way `generate_graph()` is implemented we can only reasonably // expect the number of edges to be somewhere around the expected value. - assert!((expected_edges - tolerance) < edges && edges < (expected_edges + tolerance)); + assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); } From d6f0fb9623b16187bf4341707797e1b094e6eeba Mon Sep 17 00:00:00 2001 From: meliurwen Date: Sun, 29 Jan 2023 16:44:13 +0100 Subject: [PATCH 04/12] WIP: implementing `UniformParametersGenerator` --- reCTBN/src/tools.rs | 92 ++++++++++++++++++++++++++++++++++++++++++- reCTBN/tests/tools.rs | 47 ++++++++++++++++++++-- 2 files changed, 135 insertions(+), 4 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 9222239..7c438d5 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -1,9 +1,12 @@ //! Contains commonly used methods used across the crate. -use ndarray::prelude::*; +use std::ops::{DivAssign, MulAssign, Range}; + +use ndarray::{Array, Array1, Array2, Array3, Axis}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; +use crate::params::ParamsTrait; use crate::process::NetworkProcess; use crate::sampling::{ForwardSampler, Sampler}; use crate::{params, process}; @@ -151,3 +154,90 @@ impl RandomGraphGenerator for UniformGraphGenerator { } } } + +pub trait RandomParametersGenerator { + fn new(interval: Range, seed: Option) -> Self; + fn generate_parameters(&mut self, net: &mut T); +} + +pub struct UniformParametersGenerator { + interval: Range, + rng: ChaCha8Rng, +} + +impl RandomParametersGenerator for UniformParametersGenerator { + fn new(interval: Range, seed: Option) -> UniformParametersGenerator { + if interval.start < 0.0 || interval.end < 0.0 { + panic!( + "Interval must be entirely less or equal than 0, got {}..{}.", + interval.start, interval.end + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + UniformParametersGenerator { interval, rng } + } + fn generate_parameters(&mut self, net: &mut T) { + for node in net.get_node_indices() { + let parent_set = net.get_parent_set(node); + let parent_set_state_space_cardinality: usize = parent_set + .iter() + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + println!( + "parent_set_state_space_cardinality = {}", + parent_set_state_space_cardinality + ); + let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); + println!("node_domain_cardinality = {}", node_domain_cardinality); + let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64) + ..=(self.interval.end / node_domain_cardinality as f64); + println!("cim_single_param_range = {:?}", cim_single_param_range); + + let mut cim = Array3::::from_shape_fn( + ( + parent_set_state_space_cardinality, + node_domain_cardinality, + node_domain_cardinality, + ), + |_| self.rng.gen(), + ); + + //let diagonal = cim.axis_iter(Axis(0)); + cim.axis_iter_mut(Axis(0)) + .for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}"))); + cim.axis_iter_mut(Axis(0)).for_each(|mut x| { + x.diag_mut().fill(0.0); + let sum_axis = x.sum_axis(Axis(0)); + //let division = 1.0 / &sum_axis; + x.div_assign(&sum_axis); + println!("{}", x); + let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { + self.rng.gen_range(self.interval.clone()) + }); + x.mul_assign(&diag); + println!("{}", x); + x.diag_mut().assign(&-diag) + }); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.diag().iter().for_each(|x| println!("{x}"))); + + println!("Sum Axis"); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}"))); + println!("Matrices"); + cim.axis_iter_mut(Axis(0)) + .for_each(|x| x.iter().for_each(|x| println!("{}", x))); + //.any(|x| x.diag().iter().any(|x| x >= &0.0)) + + //println!("{:?}", diagonal); + match &mut net.get_node_mut(node) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(cim)); + } + } + } + } +} diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 9a96959..e91cf04 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,3 +1,5 @@ +use std::ops::Range; + use ndarray::{arr1, arr2, arr3}; use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; @@ -85,20 +87,27 @@ fn dataset_wrong_shape() { #[test] #[should_panic] -fn uniform_random_generator_wrong_density() { +fn uniform_graph_generator_wrong_density_1() { let density = 2.1; let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } #[test] -fn uniform_random_generator_right_densities() { +#[should_panic] +fn uniform_graph_generator_wrong_density_2() { + let density = -0.5; + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); +} + +#[test] +fn uniform_graph_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); } } #[test] -fn uniform_random_generator_generate_graph() { +fn uniform_graph_generator_generate_graph() { let mut net = CtbnNetwork::new(); for node_label in 0..100 { net.add_node( @@ -122,3 +131,35 @@ fn uniform_random_generator_generate_graph() { // expect the number of edges to be somewhere around the expected value. assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); } + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_1() { + let interval: Range = -2.0..-5.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); +} + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_2() { + let interval: Range = -1.0..0.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); +} + +#[test] +fn uniform_parameters_generator_right_densities() { + let mut net = CtbnNetwork::new(); + for node_label in 0..3 { + net.add_node( + utils::generate_discrete_time_continous_node( + node_label.to_string(), + 9, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + structure_generator.generate_graph(&mut net); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120)); + cim_generator.generate_parameters(&mut net); +} From f4e3c98c796aea5863f15a9969ca5500eb340f5d Mon Sep 17 00:00:00 2001 From: meliurwen Date: Mon, 30 Jan 2023 10:48:17 +0100 Subject: [PATCH 05/12] Implemented `UniformParametersGenerator` and its test --- reCTBN/src/tools.rs | 40 ++++++---------------------------------- reCTBN/tests/tools.rs | 21 +++++++++++++++++---- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 7c438d5..344c66c 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -179,23 +179,15 @@ impl RandomParametersGenerator for UniformParametersGenerator { }; UniformParametersGenerator { interval, rng } } + fn generate_parameters(&mut self, net: &mut T) { for node in net.get_node_indices() { - let parent_set = net.get_parent_set(node); - let parent_set_state_space_cardinality: usize = parent_set + let parent_set_state_space_cardinality: usize = net + .get_parent_set(node) .iter() .map(|x| net.get_node(*x).get_reserved_space_as_parent()) .product(); - println!( - "parent_set_state_space_cardinality = {}", - parent_set_state_space_cardinality - ); let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); - println!("node_domain_cardinality = {}", node_domain_cardinality); - let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64) - ..=(self.interval.end / node_domain_cardinality as f64); - println!("cim_single_param_range = {:?}", cim_single_param_range); - let mut cim = Array3::::from_shape_fn( ( parent_set_state_space_cardinality, @@ -204,38 +196,18 @@ impl RandomParametersGenerator for UniformParametersGenerator { ), |_| self.rng.gen(), ); - - //let diagonal = cim.axis_iter(Axis(0)); - cim.axis_iter_mut(Axis(0)) - .for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}"))); cim.axis_iter_mut(Axis(0)).for_each(|mut x| { x.diag_mut().fill(0.0); - let sum_axis = x.sum_axis(Axis(0)); - //let division = 1.0 / &sum_axis; - x.div_assign(&sum_axis); - println!("{}", x); + x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { self.rng.gen_range(self.interval.clone()) }); - x.mul_assign(&diag); - println!("{}", x); + x.mul_assign(&diag.clone().insert_axis(Axis(1))); x.diag_mut().assign(&-diag) }); - cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.diag().iter().for_each(|x| println!("{x}"))); - - println!("Sum Axis"); - cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}"))); - println!("Matrices"); - cim.axis_iter_mut(Axis(0)) - .for_each(|x| x.iter().for_each(|x| println!("{}", x))); - //.any(|x| x.diag().iter().any(|x| x >= &0.0)) - - //println!("{:?}", diagonal); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(cim)); + param.set_cim_unchecked(cim); } } } diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index e91cf04..f04fb2a 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -149,17 +149,30 @@ fn uniform_parameters_generator_wrong_density_2() { #[test] fn uniform_parameters_generator_right_densities() { let mut net = CtbnNetwork::new(); - for node_label in 0..3 { + let nodes_cardinality = 0..5; + let nodes_domain_cardinality = 9; + for node_label in nodes_cardinality { net.add_node( utils::generate_discrete_time_continous_node( node_label.to_string(), - 9, + nodes_domain_cardinality, ) ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, seed); structure_generator.generate_graph(&mut net); - let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(0.0..7.0, Some(7641630759785120)); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed); cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + match &mut net.get_node_mut(node) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(param.get_cim().clone().unwrap())); + } + } + } } From 0f61cbee4c6fc72a159a3b3bb91ca2a7cb553ccd Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 31 Jan 2023 09:30:40 +0100 Subject: [PATCH 06/12] Refactored CIM validation for `UniformParametersGenerator` test --- reCTBN/tests/tools.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index f04fb2a..59ed71c 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,6 +1,7 @@ use std::ops::Range; use ndarray::{arr1, arr2, arr3}; +use reCTBN::params::ParamsTrait; use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::params; @@ -167,12 +168,9 @@ fn uniform_parameters_generator_right_densities() { let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed); cim_generator.generate_parameters(&mut net); for node in net.get_node_indices() { - match &mut net.get_node_mut(node) { - params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!( - Ok(()), - param.set_cim(param.get_cim().clone().unwrap())); - } - } + assert_eq!( + Ok(()), + net.get_node(node).validate_params() + ); } } From 097dc25030732f7fc7d778aef6bb74d9cb8ac723 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 31 Jan 2023 11:28:26 +0100 Subject: [PATCH 07/12] Added tests for `UniformParametersGenerator` and `UniformGraphGenerator` against `CTMP`, plus some small refactoring to the other tests --- reCTBN/tests/tools.rs | 103 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 89 insertions(+), 14 deletions(-) diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 59ed71c..59d8f27 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -3,10 +3,13 @@ use std::ops::Range; use ndarray::{arr1, arr2, arr3}; use reCTBN::params::ParamsTrait; use reCTBN::process::ctbn::*; +use reCTBN::process::ctmp::*; use reCTBN::process::NetworkProcess; use reCTBN::params; use reCTBN::tools::*; +use utils::*; + #[macro_use] extern crate approx; @@ -90,36 +93,50 @@ fn dataset_wrong_shape() { #[should_panic] fn uniform_graph_generator_wrong_density_1() { let density = 2.1; - let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); } #[test] #[should_panic] fn uniform_graph_generator_wrong_density_2() { let density = -0.5; - let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); } #[test] fn uniform_graph_generator_right_densities() { for density in [1.0, 0.75, 0.5, 0.25, 0.0] { - let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, None); + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); } } #[test] -fn uniform_graph_generator_generate_graph() { +fn uniform_graph_generator_generate_graph_ctbn() { let mut net = CtbnNetwork::new(); - for node_label in 0..100 { + let nodes_cardinality = 0..=100; + let nodes_domain_cardinality = 2; + for node_label in nodes_cardinality { net.add_node( utils::generate_discrete_time_continous_node( node_label.to_string(), - 2, + nodes_domain_cardinality, ) ).unwrap(); } let density = 1.0/3.0; - let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, Some(7641630759785120)); + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + Some(7641630759785120) + ); structure_generator.generate_graph(&mut net); let mut edges = 0; for node in net.get_node_indices(){ @@ -133,28 +150,54 @@ fn uniform_graph_generator_generate_graph() { assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); } +#[test] +#[should_panic] +fn uniform_graph_generator_generate_graph_ctmp() { + let mut net = CtmpProcess::new(); + let node_label = String::from("0"); + let node_domain_cardinality = 4; + net.add_node( + generate_discrete_time_continous_node( + node_label, + node_domain_cardinality + ) + ).unwrap(); + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + Some(7641630759785120) + ); + structure_generator.generate_graph(&mut net); +} + #[test] #[should_panic] fn uniform_parameters_generator_wrong_density_1() { let interval: Range = -2.0..-5.0; - let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + None + ); } #[test] #[should_panic] fn uniform_parameters_generator_wrong_density_2() { let interval: Range = -1.0..0.0; - let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, None); + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + None + ); } #[test] -fn uniform_parameters_generator_right_densities() { +fn uniform_parameters_generator_right_densities_ctbn() { let mut net = CtbnNetwork::new(); - let nodes_cardinality = 0..5; + let nodes_cardinality = 0..=3; let nodes_domain_cardinality = 9; for node_label in nodes_cardinality { net.add_node( - utils::generate_discrete_time_continous_node( + generate_discrete_time_continous_node( node_label.to_string(), nodes_domain_cardinality, ) @@ -163,9 +206,41 @@ fn uniform_parameters_generator_right_densities() { let density = 1.0/3.0; let seed = Some(7641630759785120); let interval = 0.0..7.0; - let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(density, seed); + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + seed + ); structure_generator.generate_graph(&mut net); - let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(interval, seed); + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + seed + ); + cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + assert_eq!( + Ok(()), + net.get_node(node).validate_params() + ); + } +} + +#[test] +fn uniform_parameters_generator_right_densities_ctmp() { + let mut net = CtmpProcess::new(); + let node_label = String::from("0"); + let node_domain_cardinality = 4; + net.add_node( + generate_discrete_time_continous_node( + node_label, + node_domain_cardinality + ) + ).unwrap(); + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + seed + ); cim_generator.generate_parameters(&mut net); for node in net.get_node_indices() { assert_eq!( From a01a9ef20107983667cc2c30f627c8fcf3662df5 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Tue, 31 Jan 2023 13:16:52 +0100 Subject: [PATCH 08/12] Recomputing the diagonal when generating parameters to counter the precision loss and increase `f64::EPSILON` calculating its square root instead of multiplying it with the node's `domain_size` --- reCTBN/src/params.rs | 4 +--- reCTBN/src/tools.rs | 5 ++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 9f63860..dc941e5 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,13 +267,11 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } - let domain_size = domain_size as f64; - // Check if each row sum up to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) + .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 344c66c..e9b9fd8 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -203,7 +203,10 @@ impl RandomParametersGenerator for UniformParametersGenerator { self.rng.gen_range(self.interval.clone()) }); x.mul_assign(&diag.clone().insert_axis(Axis(1))); - x.diag_mut().assign(&-diag) + // Recomputing the diagonal in order to reduce the issues caused by the loss of + // precision when validating the parameters. + let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) }); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => { From e08d12ac1f243511befbc76c0c35c7dc03efd679 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 09:04:35 +0100 Subject: [PATCH 09/12] Added tests for structure learning algorithms using uniform graph and parameters generators as complementary to their handcrafted version --- reCTBN/tests/structure_learning.rs | 171 +++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index 9a69b45..3d7e230 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -117,6 +117,50 @@ fn check_compatibility_between_dataset_and_network(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259)); + + let mut net = CtbnNetwork::new(); + let _n1 = net + .add_node( + generate_discrete_time_continous_node(String::from("0"), + 3) + ).unwrap(); + let _net = sl.fit_transform(net, &data); +} + #[test] #[should_panic] pub fn check_compatibility_between_dataset_and_network_hill_climbing() { @@ -125,6 +169,14 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } +#[test] +#[should_panic] +pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + check_compatibility_between_dataset_and_network_gen(hl); +} + fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -182,6 +234,25 @@ fn learn_ternary_net_2_nodes(sl: T) { assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } +fn learn_ternary_net_2_nodes_gen(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +} + #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -189,6 +260,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { learn_ternary_net_2_nodes(hl); } +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_ternary_net_2_nodes_gen(hl); +} + #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { let bic = BIC::new(1, 1.0); @@ -196,6 +274,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { learn_ternary_net_2_nodes(hl); } +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_ternary_net_2_nodes_gen(hl); +} + fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { let mut net = CtbnNetwork::new(); let n1 = net @@ -320,6 +405,30 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { return (net, data); } +fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); + return (net, data); +} + fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); @@ -328,6 +437,14 @@ fn learn_mixed_discrete_net_3_nodes(sl: T) { assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); } +fn learn_mixed_discrete_net_3_nodes_gen(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -335,6 +452,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { learn_mixed_discrete_net_3_nodes(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { let bic = BIC::new(1, 1.0); @@ -342,6 +466,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); @@ -350,6 +481,14 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { let ll = LogLikelihood::new(1, 1.0); @@ -357,6 +496,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() { let bic = BIC::new(1, 1.0); @@ -364,6 +510,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); } +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + #[test] pub fn chi_square_compare_matrices() { let i: usize = 1; @@ -511,6 +664,15 @@ pub fn learn_ternary_net_2_nodes_ctpc() { learn_ternary_net_2_nodes(ctpc); } +#[test] +pub fn learn_ternary_net_2_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes_gen(ctpc); +} + #[test] fn learn_mixed_discrete_net_3_nodes_ctpc() { let f = F::new(1e-6); @@ -519,3 +681,12 @@ fn learn_mixed_discrete_net_3_nodes_ctpc() { let ctpc = CTPC::new(parameter_learning, f, chi_sq); learn_mixed_discrete_net_3_nodes(ctpc); } + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_mixed_discrete_net_3_nodes_gen(ctpc); +} From 430033afdb17a239ada1ccad16f3c32e3ce48234 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 11:20:13 +0100 Subject: [PATCH 10/12] Added tests for the learning of parameters using uniform graph and parameters generators as complementary to their handcrafted version --- reCTBN/tests/parameter_learning.rs | 203 +++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 2cbc185..0a09a2a 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -6,6 +6,7 @@ use reCTBN::process::ctbn::*; use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; +use reCTBN::params::Params::DiscreteStatesContinousTime; use reCTBN::tools::*; use utils::*; @@ -66,18 +67,78 @@ fn learn_binary_cim(pl: T) { )); } +fn generate_nodes( + net: &mut CtbnNetwork, + nodes_cardinality: usize, + nodes_domain_cardinality: usize +) { + for node_label in 0..nodes_cardinality { + net.add_node( + generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } +} + +fn learn_binary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 2); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_binary_cim_MLE() { let mle = MLE {}; learn_binary_cim(mle); } +#[test] +fn learn_binary_cim_MLE_gen() { + let mle = MLE {}; + learn_binary_cim_gen(mle); +} + #[test] fn learn_binary_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_binary_cim(ba); } +#[test] +fn learn_binary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_binary_cim_gen(ba); +} + fn learn_ternary_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -155,18 +216,63 @@ fn learn_ternary_cim(pl: T) { )); } +fn learn_ternary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 4.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_ternary_cim_MLE() { let mle = MLE {}; learn_ternary_cim(mle); } +#[test] +fn learn_ternary_cim_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_gen(mle); +} + #[test] fn learn_ternary_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim(ba); } +#[test] +fn learn_ternary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_gen(ba); +} + fn learn_ternary_cim_no_parents(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -234,18 +340,63 @@ fn learn_ternary_cim_no_parents(pl: T) { )); } +fn learn_ternary_cim_no_parents_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(0) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 0, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + #[test] fn learn_ternary_cim_no_parents_MLE() { let mle = MLE {}; learn_ternary_cim_no_parents(mle); } +#[test] +fn learn_ternary_cim_no_parents_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_no_parents_gen(mle); +} + #[test] fn learn_ternary_cim_no_parents_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim_no_parents(ba); } +#[test] +fn learn_ternary_cim_no_parents_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_no_parents_gen(ba); +} + fn learn_mixed_discrete_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net @@ -432,14 +583,66 @@ fn learn_mixed_discrete_cim(pl: T) { )); } +fn learn_mixed_discrete_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..8.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(2) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 2, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.2 + ) + ); +} + #[test] fn learn_mixed_discrete_cim_MLE() { let mle = MLE {}; learn_mixed_discrete_cim(mle); } +#[test] +fn learn_mixed_discrete_cim_MLE_gen() { + let mle = MLE {}; + learn_mixed_discrete_cim_gen(mle); +} + #[test] fn learn_mixed_discrete_cim_BA() { let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_mixed_discrete_cim(ba); } + +#[test] +fn learn_mixed_discrete_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_mixed_discrete_cim_gen(ba); +} From 4884010ea97f1670f79ee2a2e9fe914b9ea65b80 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 14:36:12 +0100 Subject: [PATCH 11/12] Added doctests for `UniformParametersGenerator` and `UniformGraphGenerator` --- reCTBN/src/tools.rs | 138 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index e9b9fd8..89c19a9 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -120,6 +120,72 @@ pub trait RandomGraphGenerator { fn generate_graph(&mut self, net: &mut T); } +/// Graph Generator using an uniform distribution. +/// +/// A method to generate a random graph with edges uniformly distributed. +/// +/// # Arguments +/// +/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`. +/// * `rng` - is the random numbers generator. +/// +/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::params::Params::DiscreteStatesContinousTime; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// use reCTBN::tools::UniformGraphGenerator; +/// use reCTBN::tools::RandomGraphGenerator; +/// # let mut net = CtbnNetwork::new(); +/// # let nodes_cardinality = 8; +/// # let domain_cardinality = 4; +/// # for node in 0..nodes_cardinality { +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # for dvalue in 0..domain_cardinality { +/// # domain.insert(dvalue.to_string()); +/// # } +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new( +/// # node.to_string(), +/// # domain +/// # ); +/// # //Create the node using the parameters +/// # let node = DiscreteStatesContinousTime(param); +/// # // Add the node to the network +/// # net.add_node(node).unwrap(); +/// # } +/// +/// // Initialize the Graph Generator using the one with an +/// // uniform distribution +/// let density = 1.0/3.0; +/// let seed = Some(7641630759785120); +/// let mut structure_generator = UniformGraphGenerator::new( +/// density, +/// seed +/// ); +/// +/// // Generate the graph directly on the network +/// structure_generator.generate_graph(&mut net); +/// # // Count all the edges generated in the network +/// # let mut edges = 0; +/// # for node in net.get_node_indices(){ +/// # edges += net.get_children_set(node).len() +/// # } +/// # // Number of all the nodes in the network +/// # let nodes = net.get_node_indices().len() as f64; +/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; +/// # // ±10% of tolerance +/// # let tolerance = ((expected_edges as f64)*0.10) as usize; +/// # // As the way `generate_graph()` is implemented we can only reasonably +/// # // expect the number of edges to be somewhere around the expected value. +/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); +/// ``` pub struct UniformGraphGenerator { density: f64, rng: ChaCha8Rng, @@ -140,6 +206,7 @@ impl RandomGraphGenerator for UniformGraphGenerator { UniformGraphGenerator { density, rng } } + /// Generate an uniformly distributed graph. fn generate_graph(&mut self, net: &mut T) { net.initialize_adj_matrix(); let last_node_idx = net.get_node_indices().len(); @@ -160,6 +227,76 @@ pub trait RandomParametersGenerator { fn generate_parameters(&mut self, net: &mut T); } +/// Parameters Generator using an uniform distribution. +/// +/// A method to generate random parameters uniformly distributed. +/// +/// # Arguments +/// +/// * `interval` - is the interval of the random values oh the CIM's diagonal; domain: `≥ 0.0`. +/// * `rng` - is the random numbers generator. +/// +/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::params::ParamsTrait; +/// # use reCTBN::params::Params::DiscreteStatesContinousTime; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::tools::RandomGraphGenerator; +/// # use reCTBN::tools::UniformGraphGenerator; +/// use reCTBN::tools::RandomParametersGenerator; +/// use reCTBN::tools::UniformParametersGenerator; +/// # let mut net = CtbnNetwork::new(); +/// # let nodes_cardinality = 8; +/// # let domain_cardinality = 4; +/// # for node in 0..nodes_cardinality { +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # for dvalue in 0..domain_cardinality { +/// # domain.insert(dvalue.to_string()); +/// # } +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new( +/// # node.to_string(), +/// # domain +/// # ); +/// # //Create the node using the parameters +/// # let node = DiscreteStatesContinousTime(param); +/// # // Add the node to the network +/// # net.add_node(node).unwrap(); +/// # } +/// # +/// # // Initialize the Graph Generator using the one with an +/// # // uniform distribution +/// # let mut structure_generator = UniformGraphGenerator::new( +/// # 1.0/3.0, +/// # Some(7641630759785120) +/// # ); +/// # +/// # // Generate the graph directly on the network +/// # structure_generator.generate_graph(&mut net); +/// +/// // Initialize the parameters generator with uniform distributin +/// let mut cim_generator = UniformParametersGenerator::new( +/// 0.0..7.0, +/// Some(7641630759785120) +/// ); +/// +/// // Generate CIMs with uniformly distributed parameters. +/// cim_generator.generate_parameters(&mut net); +/// # +/// # for node in net.get_node_indices() { +/// # assert_eq!( +/// # Ok(()), +/// # net.get_node(node).validate_params() +/// # ); +/// } +/// ``` pub struct UniformParametersGenerator { interval: Range, rng: ChaCha8Rng, @@ -180,6 +317,7 @@ impl RandomParametersGenerator for UniformParametersGenerator { UniformParametersGenerator { interval, rng } } + /// Generate CIMs with uniformly distributed parameters. fn generate_parameters(&mut self, net: &mut T) { for node in net.get_node_indices() { let parent_set_state_space_cardinality: usize = net From 0639a755d0e74f7563f8d254609152b8f9480167 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 1 Feb 2023 15:32:13 +0100 Subject: [PATCH 12/12] Refactored `generate_parameters` moving some code inside `match` statement --- reCTBN/src/tools.rs | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 89c19a9..0a48410 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -325,29 +325,29 @@ impl RandomParametersGenerator for UniformParametersGenerator { .iter() .map(|x| net.get_node(*x).get_reserved_space_as_parent()) .product(); - let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); - let mut cim = Array3::::from_shape_fn( - ( - parent_set_state_space_cardinality, - node_domain_cardinality, - node_domain_cardinality, - ), - |_| self.rng.gen(), - ); - cim.axis_iter_mut(Axis(0)).for_each(|mut x| { - x.diag_mut().fill(0.0); - x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); - let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { - self.rng.gen_range(self.interval.clone()) - }); - x.mul_assign(&diag.clone().insert_axis(Axis(1))); - // Recomputing the diagonal in order to reduce the issues caused by the loss of - // precision when validating the parameters. - let diag_sum = -x.sum_axis(Axis(1)); - x.diag_mut().assign(&diag_sum) - }); match &mut net.get_node_mut(node) { params::Params::DiscreteStatesContinousTime(param) => { + let node_domain_cardinality = param.get_reserved_space_as_parent(); + let mut cim = Array3::::from_shape_fn( + ( + parent_set_state_space_cardinality, + node_domain_cardinality, + node_domain_cardinality, + ), + |_| self.rng.gen(), + ); + cim.axis_iter_mut(Axis(0)).for_each(|mut x| { + x.diag_mut().fill(0.0); + x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); + let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { + self.rng.gen_range(self.interval.clone()) + }); + x.mul_assign(&diag.clone().insert_axis(Axis(1))); + // Recomputing the diagonal in order to reduce the issues caused by the + // loss of precision when validating the parameters. + let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) + }); param.set_cim_unchecked(cim); } }