diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs index db33ae4..6ab59cb 100644 --- a/reCTBN/src/lib.rs +++ b/reCTBN/src/lib.rs @@ -3,10 +3,9 @@ #[cfg(test)] extern crate approx; -pub mod ctbn; -pub mod network; pub mod parameter_learning; pub mod params; pub mod sampling; pub mod structure_learning; pub mod tools; +pub mod process; diff --git a/reCTBN/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs index 61d4dca..2aa518c 100644 --- a/reCTBN/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -5,10 +5,10 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use crate::params::*; -use crate::{network, tools}; +use crate::{process, tools}; pub trait ParameterLearning { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -17,7 +17,7 @@ pub trait ParameterLearning { ) -> Params; } -pub fn sufficient_statistics( +pub fn sufficient_statistics( net: &T, dataset: &tools::Dataset, node: usize, @@ -73,7 +73,7 @@ pub fn sufficient_statistics( pub struct MLE {} impl ParameterLearning for MLE { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -120,7 +120,7 @@ pub struct BayesianApproach { } impl ParameterLearning for BayesianApproach { - fn fit( + fn fit( &self, net: &T, dataset: &tools::Dataset, @@ -177,7 +177,7 @@ impl Cache

{ dataset, } } - pub fn fit( + pub fn fit( &mut self, net: &T, node: usize, diff --git a/reCTBN/src/network.rs b/reCTBN/src/process.rs similarity index 98% rename from reCTBN/src/network.rs rename to reCTBN/src/process.rs index fbdd2e6..2b70b59 100644 --- a/reCTBN/src/network.rs +++ b/reCTBN/src/process.rs @@ -1,5 +1,8 @@ //! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs +pub mod ctbn; +pub mod ctmp; + use std::collections::BTreeSet; use thiserror::Error; @@ -15,7 +18,7 @@ pub enum NetworkError { /// It defines the required methods for a structure used as a Probabilistic Graphical Models (such /// as a CTBN). -pub trait Network { +pub trait NetworkProcess { fn initialize_adj_matrix(&mut self); fn add_node(&mut self, n: params::Params) -> Result; /// Add an **directed edge** between a two nodes of the network. diff --git a/reCTBN/src/ctbn.rs b/reCTBN/src/process/ctbn.rs similarity index 95% rename from reCTBN/src/ctbn.rs rename to reCTBN/src/process/ctbn.rs index 2b01d14..c59d99d 100644 --- a/reCTBN/src/ctbn.rs +++ b/reCTBN/src/process/ctbn.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; -use crate::network; +use crate::process; use crate::params::{Params, ParamsTrait, StateType}; /// It represents both the structure and the parameters of a CTBN. @@ -20,9 +20,9 @@ use crate::params::{Params, ParamsTrait, StateType}; /// /// ```rust /// use std::collections::BTreeSet; -/// use reCTBN::network::Network; +/// use reCTBN::process::NetworkProcess; /// use reCTBN::params; -/// use reCTBN::ctbn::*; +/// use reCTBN::process::ctbn::*; /// /// //Create the domain for a discrete node /// let mut domain = BTreeSet::new(); @@ -69,7 +69,7 @@ impl CtbnNetwork { } } -impl network::Network for CtbnNetwork { +impl process::NetworkProcess for CtbnNetwork { /// Initialize an Adjacency matrix. fn initialize_adj_matrix(&mut self) { self.adj_matrix = Some(Array2::::zeros( @@ -78,7 +78,7 @@ impl network::Network for CtbnNetwork { } /// Add a new node. - fn add_node(&mut self, mut n: Params) -> Result { + fn add_node(&mut self, mut n: Params) -> Result { n.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs new file mode 100644 index 0000000..b0b042a --- /dev/null +++ b/reCTBN/src/process/ctmp.rs @@ -0,0 +1,106 @@ +use std::collections::BTreeSet; + +use crate::{process, params::{Params, StateType}}; + +use super::NetworkProcess; + +pub struct CtmpProcess { + param: Option +} + +impl CtmpProcess { + pub fn new() -> CtmpProcess { + CtmpProcess { param: None } + } +} + +impl NetworkProcess for CtmpProcess { + fn initialize_adj_matrix(&mut self) { + unimplemented!("CtmpProcess has only one node") + } + + fn add_node(&mut self, n: crate::params::Params) -> Result { + match self.param { + None => { + self.param = Some(n); + Ok(0) + }, + Some(_) => Err(process::NetworkError::NodeInsertionError("CtmpProcess has only one node".to_string())) + } + } + + fn add_edge(&mut self, parent: usize, child: usize) { + unimplemented!("CtmpProcess has only one node") + } + + fn get_node_indices(&self) -> std::ops::Range { + match self.param { + None => 0..0, + Some(_) => 0..1 + } + } + + fn get_number_of_nodes(&self) -> usize { + match self.param { + None => 0, + Some(_) => 1 + } + } + + fn get_node(&self, node_idx: usize) -> &crate::params::Params { + if node_idx == 0 { + self.param.as_ref().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { + if node_idx == 0 { + self.param.as_mut().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_network(&self, node: usize, current_state: &Vec) + -> usize { + if node == 0 { + match current_state[0] { + StateType::Discrete(x) => x + } + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_from_custom_parent_set( + &self, + current_state: &Vec, + parent_set: &std::collections::BTreeSet, + ) -> usize { + unimplemented!("CtmpProcess has only one node") + } + + fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + }, + None => panic!("Uninitialized CtmpProcess") + } + } + + fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + }, + None => panic!("Uninitialized CtmpProcess") + } + } +} diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs index d435634..050daeb 100644 --- a/reCTBN/src/sampling.rs +++ b/reCTBN/src/sampling.rs @@ -1,7 +1,7 @@ //! Module containing methods for the sampling. use crate::{ - network::Network, + process::NetworkProcess, params::{self, ParamsTrait}, }; use rand::SeedableRng; @@ -13,7 +13,7 @@ pub trait Sampler: Iterator { pub struct ForwardSampler<'a, T> where - T: Network, + T: NetworkProcess, { net: &'a T, rng: ChaCha8Rng, @@ -22,7 +22,7 @@ where next_transitions: Vec>, } -impl<'a, T: Network> ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { pub fn new(net: &'a T, seed: Option) -> ForwardSampler<'a, T> { let rng: ChaCha8Rng = match seed { //If a seed is present use it to initialize the random generator. @@ -42,7 +42,7 @@ impl<'a, T: Network> ForwardSampler<'a, T> { } } -impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { type Item = (f64, Vec); fn next(&mut self) -> Option { @@ -100,7 +100,7 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> { } } -impl<'a, T: Network> Sampler for ForwardSampler<'a, T> { +impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { fn reset(&mut self) { self.current_time = 0.0; self.current_state = self diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs index 57fed1e..b272e22 100644 --- a/reCTBN/src/structure_learning.rs +++ b/reCTBN/src/structure_learning.rs @@ -4,10 +4,10 @@ pub mod constraint_based_algorithm; pub mod hypothesis_test; pub mod score_based_algorithm; pub mod score_function; -use crate::{network, tools}; +use crate::{process, tools}; pub trait StructureLearningAlgorithm { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network; + T: process::NetworkProcess; } diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs index 6474155..4ec3377 100644 --- a/reCTBN/src/structure_learning/hypothesis_test.rs +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -6,7 +6,7 @@ use ndarray::{Array3, Axis}; use statrs::distribution::{ChiSquared, ContinuousCDF}; use crate::params::*; -use crate::{network, parameter_learning}; +use crate::{process, parameter_learning}; pub trait HypothesisTest { fn call( @@ -18,7 +18,7 @@ pub trait HypothesisTest { cache: &mut parameter_learning::Cache

, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning; } @@ -135,7 +135,7 @@ impl HypothesisTest for ChiSquare { cache: &mut parameter_learning::Cache

, ) -> bool where - T: network::Network, + T: process::NetworkProcess, P: parameter_learning::ParameterLearning, { // Prendo dalla cache l'apprendimento dei parametri, che sarebbe una CIM diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs index 9e329eb..16e9056 100644 --- a/reCTBN/src/structure_learning/score_based_algorithm.rs +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -4,7 +4,7 @@ use std::collections::BTreeSet; use crate::structure_learning::score_function::ScoreFunction; use crate::structure_learning::StructureLearningAlgorithm; -use crate::{network, tools}; +use crate::{process, tools}; pub struct HillClimbing { score_function: S, @@ -23,7 +23,7 @@ impl HillClimbing { impl StructureLearningAlgorithm for HillClimbing { fn fit_transform(&self, net: T, dataset: &tools::Dataset) -> T where - T: network::Network, + T: process::NetworkProcess, { //Check the coherence between dataset and network if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs index cb6ad7b..8943478 100644 --- a/reCTBN/src/structure_learning/score_function.rs +++ b/reCTBN/src/structure_learning/score_function.rs @@ -5,7 +5,7 @@ use std::collections::BTreeSet; use ndarray::prelude::*; use statrs::function::gamma; -use crate::{network, parameter_learning, params, tools}; +use crate::{process, parameter_learning, params, tools}; pub trait ScoreFunction { fn call( @@ -16,7 +16,7 @@ pub trait ScoreFunction { dataset: &tools::Dataset, ) -> f64 where - T: network::Network; + T: process::NetworkProcess; } pub struct LogLikelihood { @@ -41,7 +41,7 @@ impl LogLikelihood { dataset: &tools::Dataset, ) -> (f64, Array3) where - T: network::Network, + T: process::NetworkProcess, { //Identify the type of node used match &net.get_node(node) { @@ -100,7 +100,7 @@ impl ScoreFunction for LogLikelihood { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { self.compute_score(net, node, parent_set, dataset).0 } @@ -127,7 +127,7 @@ impl ScoreFunction for BIC { dataset: &tools::Dataset, ) -> f64 where - T: network::Network, + T: process::NetworkProcess, { //Compute the log-likelihood let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index aa48883..6f2f648 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,7 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; -use crate::{network, params}; +use crate::{process, params}; pub struct Trajectory { time: Array1, @@ -51,7 +51,7 @@ impl Dataset { } } -pub fn trajectory_generator( +pub fn trajectory_generator( net: &T, n_trajectories: u64, t_end: f64, diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs index 63c9621..0ad0fc4 100644 --- a/reCTBN/tests/ctbn.rs +++ b/reCTBN/tests/ctbn.rs @@ -1,8 +1,8 @@ mod utils; use std::collections::BTreeSet; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params::{self, ParamsTrait}; use utils::generate_discrete_time_continous_node; diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs new file mode 100644 index 0000000..31bc6df --- /dev/null +++ b/reCTBN/tests/ctmp.rs @@ -0,0 +1,127 @@ +mod utils; + +use std::collections::BTreeSet; + +use reCTBN::{ + params, + params::ParamsTrait, + process::{ctmp::*, NetworkProcess}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn define_simple_ctmp() { + let _ = CtmpProcess::new(); + assert!(true); +} + +#[test] +fn add_node_to_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); +} + +#[test] +fn add_two_nodes_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + match n2 { + Ok(_) => assert!(false), + Err(_) => assert!(true), + }; +} + +#[test] +#[should_panic] +fn add_edge_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + net.add_edge(0, 1) +} + +#[test] +fn childen_and_parents() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + assert_eq!(0, net.get_parent_set(0).len()); + assert_eq!(0, net.get_children_set(0).len()); +} + +#[test] +#[should_panic] +fn get_childen_panic() { + let mut net = CtmpProcess::new(); + net.get_children_set(0); +} + +#[test] +#[should_panic] +fn get_childen_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_children_set(1); +} + +#[test] +#[should_panic] +fn get_parent_panic() { + let mut net = CtmpProcess::new(); + net.get_parent_set(0); +} + +#[test] +#[should_panic] +fn get_parent_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_parent_set(1); +} + +#[test] +fn compute_index_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); + assert_eq!(6, idx); +} + +#[test] +#[should_panic] +fn compute_index_from_custom_parent_set_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let _idx = net.get_param_index_from_custom_parent_set( + &vec![params::StateType::Discrete(6)], + &BTreeSet::from([0]) + ); +} diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs index 7d09b07..2cbc185 100644 --- a/reCTBN/tests/parameter_learning.rs +++ b/reCTBN/tests/parameter_learning.rs @@ -2,8 +2,8 @@ mod utils; use ndarray::arr3; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::*; use reCTBN::params; use reCTBN::tools::*; diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs index a1667c2..2ec64b2 100644 --- a/reCTBN/tests/structure_learning.rs +++ b/reCTBN/tests/structure_learning.rs @@ -4,8 +4,8 @@ mod utils; use std::collections::BTreeSet; use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::parameter_learning::BayesianApproach; use reCTBN::parameter_learning::Cache; use reCTBN::params; diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs index 589b04e..806faef 100644 --- a/reCTBN/tests/tools.rs +++ b/reCTBN/tests/tools.rs @@ -1,6 +1,6 @@ use ndarray::{arr1, arr2, arr3}; -use reCTBN::ctbn::*; -use reCTBN::network::Network; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; use reCTBN::params; use reCTBN::tools::*;