From ed5471c7cf6d4a28e7486c3e8be0dd9e63cb79b5 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Mon, 14 Nov 2022 16:07:04 +0100 Subject: [PATCH 1/6] Added ctmp --- reCTBN/src/lib.rs | 3 +- reCTBN/src/parameter_learning.rs | 12 +- reCTBN/src/{network.rs => process.rs} | 5 +- reCTBN/src/{ => process}/ctbn.rs | 10 +- reCTBN/src/process/ctmp.rs | 106 +++++++++++++++ reCTBN/src/sampling.rs | 10 +- reCTBN/src/structure_learning.rs | 4 +- .../src/structure_learning/hypothesis_test.rs | 6 +- .../score_based_algorithm.rs | 4 +- .../src/structure_learning/score_function.rs | 10 +- reCTBN/src/tools.rs | 4 +- reCTBN/tests/ctbn.rs | 4 +- reCTBN/tests/ctmp.rs | 127 ++++++++++++++++++ reCTBN/tests/parameter_learning.rs | 4 +- reCTBN/tests/structure_learning.rs | 4 +- reCTBN/tests/tools.rs | 4 +- 16 files changed, 276 insertions(+), 41 deletions(-) rename reCTBN/src/{network.rs => process.rs} (98%) rename reCTBN/src/{ => process}/ctbn.rs (95%) create mode 100644 reCTBN/src/process/ctmp.rs create mode 100644 reCTBN/tests/ctmp.rs diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index db33ae4..6ab59cb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -3,10 +3,9 @@ #[cfg(test)] extern crate approx; -pub mod ctbn; -pub mod network; pub mod parameter_learning; pub mod params; pub mod sampling; pub mod structure_learning; pub mod tools; +pub mod process; diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 61d4dca..2aa518c 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -5,10 +5,10 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use crate::params::*; -use crate::{network, tools}; +use crate::{process, tools}; pub trait ParameterLearning { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -17,7 +17,7 @@ pub trait ParameterLearning { ) -> Params; } -pub fn sufficient_statistics( +pub fn sufficient_statistics( net: &T, dataset: &tools::Dataset, node: usize, @@ -73,7 +73,7 @@ pub fn sufficient_statistics( pub struct MLE {} impl ParameterLearning for MLE { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -120,7 +120,7 @@ pub struct BayesianApproach { } impl ParameterLearning for BayesianApproach { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -177,7 +177,7 @@ impl Cache

{ dataset, } } - pub fn fit( + pub fn fit( &mut self, net: &T, node: usize, diff --git a/reCTBN/src/network.rs b/reCTBN/src/process.rs similarity index 98% rename from reCTBN/src/network.rs rename to reCTBN/src/process.rs index fbdd2e6..2b70b59 100644 --- a/reCTBN/src/network.rs +++ b/reCTBN/src/process.rs @@ -1,5 +1,8 @@ //! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs +pub mod ctbn; +pub mod ctmp; + use std::collections::BTreeSet; use thiserror::Error; @@ -15,7 +18,7 @@ pub enum NetworkError { /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). -pub trait Network { +pub trait NetworkProcess { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; /// Add an **directed edge** between a two nodes of the network. diff --git a/reCTBN/src/ctbn.rs b/reCTBN/src/process/ctbn.rs similarity index 95% rename from reCTBN/src/ctbn.rs rename to reCTBN/src/process/ctbn.rs index 2b01d14..c59d99d 100644 --- a/reCTBN/src/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; -use crate::network; +use crate::process; use crate::params::{Params, ParamsTrait, StateType}; /// It represents both the structure and the parameters of a CTBN. @@ -20,9 +20,9 @@ use crate::params::{Params, ParamsTrait, StateType}; /// /// ```rust /// use std::collections::BTreeSet; -/// use reCTBN::network::Network; +/// use reCTBN::process::NetworkProcess; /// use reCTBN::params; -/// use reCTBN::ctbn::*; +/// use reCTBN::process::ctbn::*; /// /// //Create the domain for a discrete node /// let mut domain = BTreeSet::new(); @@ -69,7 +69,7 @@ impl CtbnNetwork { } } -impl network::Network for CtbnNetwork { +impl process::NetworkProcess for CtbnNetwork { /// Initialize an Adjacency matrix. fn initialize_adj_matrix(&mut self) { self.adj_matrix = Some(Array2::::zeros( @@ -78,7 +78,7 @@ impl network::Network for CtbnNetwork { } /// Add a new node. - fn add_node(&mut self, mut n: Params) -> Result { + fn add_node(&mut self, mut n: Params) -> Result { n.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs new file mode 100644 index 0000000..b0b042a --- /dev/null +++ b/reCTBN/src/process/ctmp.rs @@ -0,0 +1,106 @@ +use std::collections::BTreeSet; + +use crate::{process, params::{Params, StateType}}; + +use super::NetworkProcess; + +pub struct CtmpProcess { + param: Option +} + +impl CtmpProcess { + pub fn new() -> CtmpProcess { + CtmpProcess { param: None } + } +} + +impl NetworkProcess for CtmpProcess { + fn initialize_adj_matrix(&mut self) { + unimplemented!("CtmpProcess has only one node") + } + + fn add_node(&mut self, n: crate::params::Params) -> Result { + match self.param { + None => { + self.param = Some(n); + Ok(0) + }, + Some(_) => Err(process::NetworkError::NodeInsertionError("CtmpProcess has only one node".to_string())) + } + } + + fn add_edge(&mut self, parent: usize, child: usize) { + unimplemented!("CtmpProcess has only one node") + } + + fn get_node_indices(&self) -> std::ops::Range { + match self.param { + None => 0..0, + Some(_) => 0..1 + } + } + + fn get_number_of_nodes(&self) -> usize { + match self.param { + None => 0, + Some(_) => 1 + } + } + + fn get_node(&self, node_idx: usize) -> &crate::params::Params { + if node_idx == 0 { + self.param.as_ref().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { + if node_idx == 0 { + self.param.as_mut().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_network(&self, node: usize, current_state: &Vec) + -> usize { + if node == 0 { + match current_state[0] { + StateType::Discrete(x) => x + } + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_from_custom_parent_set( + &self, + current_state: &Vec, + parent_set: &std::collections::BTreeSet, + ) -> usize { + unimplemented!("CtmpProcess has only one node") + } + + fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + }, + None => panic!("Uninitialized CtmpProcess") + } + } + + fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + }, + None => panic!("Uninitialized CtmpProcess") + } + } +} diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index d435634..050daeb 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,7 +1,7 @@ //! Module containing methods for the sampling. use crate::{ - network::Network, + process::NetworkProcess, params::{self, ParamsTrait}, }; use rand::SeedableRng; @@ -13,7 +13,7 @@ pub trait Sampler: Iterator { pub struct ForwardSampler<'a, T> where - T: Network, + T: NetworkProcess, { net: &'a T, rng: ChaCha8Rng, @@ -22,7 +22,7 @@ where next_transitions: Vec>, } -impl<'a, T: Network> ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. @@ -42,7 +42,7 @@ impl<'a, T: Network> ForwardSampler<'a, T> { } } -impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { type Item = (f64, Vec); fn next(&mut self) -> Option { @@ -100,7 +100,7 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { } } -impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { fn reset(&mut self) { self.current_time = 0.0; self.current_state = self diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index 57fed1e..b272e22 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -4,10 +4,10 @@ pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; pub mod score_function; -use crate::{network, tools}; +use crate::{process, tools}; pub trait StructureLearningAlgorithm { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network; + T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 6474155..4ec3377 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::params::*; -use crate::{network, parameter_learning}; +use crate::{process, parameter_learning}; pub trait HypothesisTest { fn call( @@ -18,7 +18,7 @@ pub trait HypothesisTest { cache: &mut parameter_learning::Cache

, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning; } @@ -135,7 +135,7 @@ impl HypothesisTest for ChiSquare { cache: &mut parameter_learning::Cache

, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index 9e329eb..16e9056 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{network, tools}; +use crate::{process, tools}; pub struct HillClimbing { score_function: S, @@ -23,7 +23,7 @@ impl HillClimbing { impl StructureLearningAlgorithm for HillClimbing { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network, + T: process::NetworkProcess, { //Check the coherence between dataset and network if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index cb6ad7b..8943478 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -5,7 +5,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use statrs::function::gamma; -use crate::{network, parameter_learning, params, tools}; +use crate::{process, parameter_learning, params, tools}; pub trait ScoreFunction { fn call( @@ -16,7 +16,7 @@ pub trait ScoreFunction { dataset: &tools::Dataset, ) -> f64 where - T: network::Network; + T: process::NetworkProcess; } pub struct LogLikelihood { @@ -41,7 +41,7 @@ impl LogLikelihood { dataset: &tools::Dataset, ) -> (f64, Array3) where - T: network::Network, + T: process::NetworkProcess, { //Identify the type of node used match &net.get_node(node) { @@ -100,7 +100,7 @@ impl ScoreFunction for LogLikelihood { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { self.compute_score(net, node, parent_set, dataset).0 } @@ -127,7 +127,7 @@ impl ScoreFunction for BIC { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { //Compute the log-likelihood let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index aa48883..6f2f648 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,7 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; -use crate::{network, params}; +use crate::{process, params}; pub struct Trajectory { time: Array1, @@ -51,7 +51,7 @@ impl Dataset { } } -pub fn trajectory_generator( +pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 63c9621..0ad0fc4 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,8 +1,8 @@ mod utils; use std::collections::BTreeSet; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; use utils::generate_discrete_time_continous_node; diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs new file mode 100644 index 0000000..31bc6df --- /dev/null +++ b/reCTBN/tests/ctmp.rs @@ -0,0 +1,127 @@ +mod utils; + +use std::collections::BTreeSet; + +use reCTBN::{ + params, + params::ParamsTrait, + process::{ctmp::*, NetworkProcess}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn define_simple_ctmp() { + let _ = CtmpProcess::new(); + assert!(true); +} + +#[test] +fn add_node_to_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); +} + +#[test] +fn add_two_nodes_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + match n2 { + Ok(_) => assert!(false), + Err(_) => assert!(true), + }; +} + +#[test] +#[should_panic] +fn add_edge_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + net.add_edge(0, 1) +} + +#[test] +fn childen_and_parents() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + assert_eq!(0, net.get_parent_set(0).len()); + assert_eq!(0, net.get_children_set(0).len()); +} + +#[test] +#[should_panic] +fn get_childen_panic() { + let mut net = CtmpProcess::new(); + net.get_children_set(0); +} + +#[test] +#[should_panic] +fn get_childen_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_children_set(1); +} + +#[test] +#[should_panic] +fn get_parent_panic() { + let mut net = CtmpProcess::new(); + net.get_parent_set(0); +} + +#[test] +#[should_panic] +fn get_parent_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_parent_set(1); +} + +#[test] +fn compute_index_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); + assert_eq!(6, idx); +} + +#[test] +#[should_panic] +fn compute_index_from_custom_parent_set_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let _idx = net.get_param_index_from_custom_parent_set( + &vec![params::StateType::Discrete(6)], + &BTreeSet::from([0]) + ); +} diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 7d09b07..2cbc185 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -2,8 +2,8 @@ mod utils; use ndarray::arr3; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; use reCTBN::tools::*; diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a1667c2..2ec64b2 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -4,8 +4,8 @@ mod utils; use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::Cache; use reCTBN::params; diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 589b04e..806faef 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,6 +1,6 @@ use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params; use reCTBN::tools::*; From 7c3cba50d4afb08c1087711ef8fba12a2351ad54 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 11:14:41 +0100 Subject: [PATCH 2/6] Implemented amalgamation --- reCTBN/src/process/ctbn.rs | 78 +++++++++++++++++++++++++++++++++++++- reCTBN/tests/ctbn.rs | 36 +++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index c59d99d..3852c50 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,8 +4,11 @@ use std::collections::BTreeSet; use ndarray::prelude::*; +use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; use crate::process; -use crate::params::{Params, ParamsTrait, StateType}; + +use super::ctmp::CtmpProcess; +use super::NetworkProcess; /// It represents both the structure and the parameters of a CTBN. /// @@ -67,6 +70,79 @@ impl CtbnNetwork { nodes: Vec::new(), } } + + pub fn amalgamation(&self) -> CtmpProcess { + for v in self.nodes.iter() { + match v { + Params::DiscreteStatesContinousTime(_) => {} + _ => panic!("Unsupported node"), + } + } + + let variables_domain = + Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); + + let state_space = variables_domain.product(); + let variables_set = BTreeSet::from_iter(self.get_node_indices()); + let mut amalgamated_cim: Array3 = Array::zeros((1, state_space, state_space)); + + for idx_current_state in 0..state_space { + let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); + let current_state_statetype: Vec = current_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + for idx_node in 0..self.nodes.len() { + let p = match self.get_node(idx_node) { + Params::DiscreteStatesContinousTime(p) => p, + }; + for next_node_state in 0..variables_domain[idx_node] { + let mut next_state = current_state.clone(); + next_state[idx_node] = next_node_state; + + let next_state_statetype: Vec = next_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + let idx_next_state = self.get_param_index_from_custom_parent_set( + &next_state_statetype, + &variables_set, + ); + amalgamated_cim[[0, idx_current_state, idx_next_state]] += + p.get_cim().as_ref().unwrap()[[ + self.get_param_index_network(idx_node, ¤t_state_statetype), + current_state[idx_node], + next_node_state, + ]]; + } + } + } + + let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( + "ctmp".to_string(), + BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), + ); + + println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim); + + amalgamated_param.set_cim(amalgamated_cim).unwrap(); + + let mut ctmp = CtmpProcess::new(); + + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)).unwrap(); + return ctmp; + } + + pub fn idx_to_state(variables_domain: &Array1, state: usize) -> Array1 { + let mut state = state; + let mut array_state = Array1::zeros(variables_domain.shape()[0]); + for (idx, var) in variables_domain.indexed_iter() { + array_state[idx] = state % var; + state = state / var; + } + + return array_state; + } } impl process::NetworkProcess for CtbnNetwork { diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 0ad0fc4..fc17a94 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,7 +1,10 @@ mod utils; use std::collections::BTreeSet; +use std::f64::EPSILON; -use reCTBN::process::ctbn::*; +use approx::AbsDiffEq; +use ndarray::arr3; +use reCTBN::process::{ctbn::*, ctmp::*}; use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; use utils::generate_discrete_time_continous_node; @@ -129,3 +132,34 @@ fn compute_index_from_custom_parent_set() { ); assert_eq!(2, idx); } + +#[test] +fn simple_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + net.initialize_adj_matrix(); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); + } + } + + let ctmp = net.amalgamation(); + let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0){ + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + + assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); +} From 4a7a8c5fbab5c0addcd2785a2f585c6c32eb4637 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 14:06:30 +0100 Subject: [PATCH 3/6] Added more tests --- reCTBN/src/params.rs | 4 +- reCTBN/src/process/ctbn.rs | 10 +- reCTBN/tests/ctbn.rs | 228 ++++++++++++++++++++++++++++++++++++- 3 files changed, 234 insertions(+), 8 deletions(-) diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs index 070c997..9f63860 100644 --- a/reCTBN/src/params.rs +++ b/reCTBN/src/params.rs @@ -267,11 +267,13 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { ))); } + let domain_size = domain_size as f64; + // Check if each row sum up to 0 if cim .sum_axis(Axis(2)) .iter() - .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0) + .any(|x| f64::abs(x.clone()) > f64::EPSILON * domain_size) { return Err(ParamsError::InvalidCIM(String::from( "The sum of each row must be 0", diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 3852c50..a6be923 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -70,7 +70,12 @@ impl CtbnNetwork { nodes: Vec::new(), } } - + + ///Transform the **CTBN** into a **CTMP** + /// + /// # Return + /// + /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { for v in self.nodes.iter() { match v { @@ -123,8 +128,7 @@ impl CtbnNetwork { BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - println!("state space: {} - #nodes: {}\n{:?}", &state_space, self.nodes.len(), &amalgamated_cim); - + println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index fc17a94..a7752f2 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -4,9 +4,9 @@ use std::f64::EPSILON; use approx::AbsDiffEq; use ndarray::arr3; -use reCTBN::process::{ctbn::*, ctmp::*}; -use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; +use reCTBN::process::NetworkProcess; +use reCTBN::process::{ctbn::*, ctmp::*}; use utils::generate_discrete_time_continous_node; #[test] @@ -149,7 +149,7 @@ fn simple_amalgamation() { } let ctmp = net.amalgamation(); - let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0){ + let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0) { p.get_cim().as_ref().unwrap() } else { unreachable!(); @@ -160,6 +160,226 @@ fn simple_amalgamation() { unreachable!(); }; - assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); } + +#[test] +fn chain_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + let ctmp = net.amalgamation(); + + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + let p_ctmp_handmade = arr3(&[[ + [ + -1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} + +#[test] +fn chainfork_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + let n4 = net + .add_node(generate_discrete_time_continous_node(String::from("n4"), 2)) + .unwrap(); + + net.add_edge(n1, n3); + net.add_edge(n2, n3); + net.add_edge(n3, n4); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n4) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + + let ctmp = net.amalgamation(); + + let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { + p.get_cim().as_ref().unwrap() + } else { + unreachable!(); + }; + + let p_ctmp_handmade = arr3(&[[ + [ + -2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} From 28ed1a40b32bb1a6629e5a67b379ed0b56c89861 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 14:52:39 +0100 Subject: [PATCH 4/6] Fix for clippy --- reCTBN/src/lib.rs | 2 +- reCTBN/src/process/ctbn.rs | 13 ++-- reCTBN/src/process/ctmp.rs | 60 +++++++++++-------- reCTBN/src/sampling.rs | 2 +- .../src/structure_learning/hypothesis_test.rs | 2 +- .../src/structure_learning/score_function.rs | 2 +- reCTBN/src/tools.rs | 2 +- reCTBN/tests/ctbn.rs | 4 +- reCTBN/tests/ctmp.rs | 6 +- 9 files changed, 52 insertions(+), 41 deletions(-) diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index 6ab59cb..c62c42e 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -5,7 +5,7 @@ extern crate approx; pub mod parameter_learning; pub mod params; +pub mod process; pub mod sampling; pub mod structure_learning; pub mod tools; -pub mod process; diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index a6be923..7cb327d 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -70,7 +70,7 @@ impl CtbnNetwork { nodes: Vec::new(), } } - + ///Transform the **CTBN** into a **CTMP** /// /// # Return @@ -105,10 +105,8 @@ impl CtbnNetwork { let mut next_state = current_state.clone(); next_state[idx_node] = next_node_state; - let next_state_statetype: Vec = next_state - .iter() - .map(|x| StateType::Discrete(*x)) - .collect(); + let next_state_statetype: Vec = + next_state.iter().map(|x| StateType::Discrete(*x)).collect(); let idx_next_state = self.get_param_index_from_custom_parent_set( &next_state_statetype, &variables_set, @@ -127,13 +125,14 @@ impl CtbnNetwork { "ctmp".to_string(), BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), ); - + println!("{:?}", amalgamated_cim); amalgamated_param.set_cim(amalgamated_cim).unwrap(); let mut ctmp = CtmpProcess::new(); - ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)).unwrap(); + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) + .unwrap(); return ctmp; } diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs index b0b042a..81509fa 100644 --- a/reCTBN/src/process/ctmp.rs +++ b/reCTBN/src/process/ctmp.rs @@ -1,11 +1,14 @@ use std::collections::BTreeSet; -use crate::{process, params::{Params, StateType}}; +use crate::{ + params::{Params, StateType}, + process, +}; use super::NetworkProcess; pub struct CtmpProcess { - param: Option + param: Option, } impl CtmpProcess { @@ -24,26 +27,28 @@ impl NetworkProcess for CtmpProcess { None => { self.param = Some(n); Ok(0) - }, - Some(_) => Err(process::NetworkError::NodeInsertionError("CtmpProcess has only one node".to_string())) + } + Some(_) => Err(process::NetworkError::NodeInsertionError( + "CtmpProcess has only one node".to_string(), + )), } } - fn add_edge(&mut self, parent: usize, child: usize) { + fn add_edge(&mut self, _parent: usize, _child: usize) { unimplemented!("CtmpProcess has only one node") } fn get_node_indices(&self) -> std::ops::Range { match self.param { None => 0..0, - Some(_) => 0..1 + Some(_) => 0..1, } } fn get_number_of_nodes(&self) -> usize { match self.param { None => 0, - Some(_) => 1 + Some(_) => 1, } } @@ -63,11 +68,14 @@ impl NetworkProcess for CtmpProcess { } } - fn get_param_index_network(&self, node: usize, current_state: &Vec) - -> usize { + fn get_param_index_network( + &self, + node: usize, + current_state: &Vec, + ) -> usize { if node == 0 { match current_state[0] { - StateType::Discrete(x) => x + StateType::Discrete(x) => x, } } else { unimplemented!("CtmpProcess has only one node") @@ -76,31 +84,35 @@ impl NetworkProcess for CtmpProcess { fn get_param_index_from_custom_parent_set( &self, - current_state: &Vec, - parent_set: &std::collections::BTreeSet, + _current_state: &Vec, + _parent_set: &std::collections::BTreeSet, ) -> usize { unimplemented!("CtmpProcess has only one node") } fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { match self.param { - Some(_) => if node == 0 { - BTreeSet::new() - } else { - unimplemented!("CtmpProcess has only one node") - }, - None => panic!("Uninitialized CtmpProcess") + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), } } fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { match self.param { - Some(_) => if node == 0 { - BTreeSet::new() - } else { - unimplemented!("CtmpProcess has only one node") - }, - None => panic!("Uninitialized CtmpProcess") + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), } } } diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index 050daeb..0662994 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,8 +1,8 @@ //! Module containing methods for the sampling. use crate::{ - process::NetworkProcess, params::{self, ParamsTrait}, + process::NetworkProcess, }; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 4ec3377..344c995 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::params::*; -use crate::{process, parameter_learning}; +use crate::{parameter_learning, process}; pub trait HypothesisTest { fn call( diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index 8943478..f8b38b5 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -5,7 +5,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use statrs::function::gamma; -use crate::{process, parameter_learning, params, tools}; +use crate::{parameter_learning, params, process, tools}; pub trait ScoreFunction { fn call( diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 6f2f648..2e727e8 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,7 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; -use crate::{process, params}; +use crate::{params, process}; pub struct Trajectory { time: Array1, diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index a7752f2..7db2bae 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,12 +1,12 @@ mod utils; use std::collections::BTreeSet; -use std::f64::EPSILON; + use approx::AbsDiffEq; use ndarray::arr3; use reCTBN::params::{self, ParamsTrait}; use reCTBN::process::NetworkProcess; -use reCTBN::process::{ctbn::*, ctmp::*}; +use reCTBN::process::{ctbn::*}; use utils::generate_discrete_time_continous_node; #[test] diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs index 31bc6df..830bfe0 100644 --- a/reCTBN/tests/ctmp.rs +++ b/reCTBN/tests/ctmp.rs @@ -45,7 +45,7 @@ fn add_edge_to_ctmp() { let _n1 = net .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); net.add_edge(0, 1) } @@ -64,7 +64,7 @@ fn childen_and_parents() { #[test] #[should_panic] fn get_childen_panic() { - let mut net = CtmpProcess::new(); + let net = CtmpProcess::new(); net.get_children_set(0); } @@ -81,7 +81,7 @@ fn get_childen_panic2() { #[test] #[should_panic] fn get_parent_panic() { - let mut net = CtmpProcess::new(); + let net = CtmpProcess::new(); net.get_parent_set(0); } From 44eaf8713fb9ce8f7bb05cc63af4bc625438e983 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 15:16:33 +0100 Subject: [PATCH 5/6] Fix for clippy --- reCTBN/src/process/ctbn.rs | 6 ------ reCTBN/tests/ctbn.rs | 31 +++++++++++-------------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 7cb327d..7473d4c 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -77,12 +77,6 @@ impl CtbnNetwork { /// /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { - for v in self.nodes.iter() { - match v { - Params::DiscreteStatesContinousTime(_) => {} - _ => panic!("Unsupported node"), - } - } let variables_domain = Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 7db2bae..3eb40d7 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -149,16 +149,10 @@ fn simple_amalgamation() { } let ctmp = net.amalgamation(); - let p_ctbn = if let params::Params::DiscreteStatesContinousTime(p) = &net.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0); + let p_ctbn = p_ctbn.get_cim().as_ref().unwrap(); + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); } @@ -211,11 +205,10 @@ fn chain_amalgamation() { let ctmp = net.amalgamation(); - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + + + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); let p_ctmp_handmade = arr3(&[[ [ @@ -308,11 +301,9 @@ fn chainfork_amalgamation() { let ctmp = net.amalgamation(); - let p_ctmp = if let params::Params::DiscreteStatesContinousTime(p) = &ctmp.get_node(0) { - p.get_cim().as_ref().unwrap() - } else { - unreachable!(); - }; + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); let p_ctmp_handmade = arr3(&[[ [ From 38e744e034e52e277bab2c3c7052f9e796862d81 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 16 Nov 2022 15:25:27 +0100 Subject: [PATCH 6/6] Fix fmt --- reCTBN/src/process/ctbn.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs index 7473d4c..c949afe 100644 --- a/reCTBN/src/process/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -77,7 +77,6 @@ impl CtbnNetwork { /// /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork pub fn amalgamation(&self) -> CtmpProcess { - let variables_domain = Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));