diff --git a/.github/ISSUE_TEMPLATE/meta_request.md b/.github/ISSUE_TEMPLATE/meta_request.md new file mode 100644 index 0000000..d80ccde --- /dev/null +++ b/.github/ISSUE_TEMPLATE/meta_request.md @@ -0,0 +1,26 @@ +--- +name: 📑 Meta request +about: Suggest an idea or a change for this same repository +title: '[Meta] ' +labels: 'meta' +assignees: '' + +--- + +## Description + +As a X, I want to Y, so Z. + +## Acceptance Criteria + +* Criteria 1 +* Criteria 2 + +## Checklist + +* [ ] Element 1 +* [ ] Element 2 + +## (Optional) Extra info + +None diff --git a/.github/ISSUE_TEMPLATE/refactor_request.md b/.github/ISSUE_TEMPLATE/refactor_request.md new file mode 100644 index 0000000..9a4d090 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/refactor_request.md @@ -0,0 +1,26 @@ +--- +name: ⚙️ Refactor request +about: Suggest a refactor for this project +title: '[Refactor] ' +labels: enhancement, refactor +assignees: '' + +--- + +## Description + +As a X, I want to Y, so Z. + +## Acceptance Criteria + +* Criteria 1 +* Criteria 2 + +## Checklist + +* [ ] Element 1 +* [ ] Element 2 + +## (Optional) Extra info + +None diff --git a/.github/labels.yml b/.github/labels.yml new file mode 100644 index 0000000..0129d00 --- /dev/null +++ b/.github/labels.yml @@ -0,0 +1,36 @@ +- name: "bug" + color: "d73a4a" + description: "Something isn't working" +- name: "enhancement" + color: "a2eeef" + description: "New feature or request" +- name: "refactor" + color: "B06E16" + description: "Change in the structure" +- name: "documentation" + color: "0075ca" + description: "Improvements or additions to documentation" +- name: "meta" + color: "1D76DB" + description: "Something related to the project itself" + +- name: "duplicate" + color: "cfd3d7" + description: "This issue or pull request already exists" + +- name: "help wanted" + color: "008672" + description: "Extra help is needed" +- name: "urgent" + color: "D93F0B" + description: "" +- name: "wontfix" + color: "ffffff" + description: "This will not 
be worked on" +- name: "invalid" + color: "e4e669" + description: "This doesn't seem right" + +- name: "question" + color: "d876e3" + description: "Further information is requested" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7e51286..063a3e0 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,4 +1,4 @@ -# Pull/Merge Request into master dev + ## Description diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..7cc300c --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,53 @@ +name: build + +on: + push: + branches: [ main, dev ] + pull_request: + branches: [ dev ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Setup Rust stable (default) + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + default: true + components: clippy, rustfmt, rust-docs + - name: Setup Rust nightly + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + default: false + components: rustfmt + - name: Docs (doc) + uses: actions-rs/cargo@v1 + with: + command: rustdoc + args: --package reCTBN -- --default-theme=ayu + - name: Linting (clippy) + uses: actions-rs/clippy-check@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + args: --all-targets -- -D warnings -A clippy::all -W clippy::correctness + - name: Formatting (rustfmt) + uses: actions-rs/cargo@v1 + with: + toolchain: nightly + command: fmt + args: --all -- --check --verbose + - name: Tests (test) + uses: actions-rs/cargo@v1 + with: + command: test + args: --tests diff --git a/.github/workflows/labels.yml b/.github/workflows/labels.yml new file mode 100644 index 0000000..2d5bc59 --- /dev/null +++ b/.github/workflows/labels.yml @@ -0,0 +1,23 @@ +name: meta-github + +on: + push: + branches: + - dev + +jobs: + labeler: + runs-on: ubuntu-latest + steps: + - + name: 
Checkout + uses: actions/checkout@v2 + - + name: Run Labeler + if: success() + uses: crazy-max/ghaction-github-labeler@v3 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + yaml-file: .github/labels.yml + skip-delete: false + dry-run: false diff --git a/.gitignore b/.gitignore index 96ef6c0..c640ca5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target Cargo.lock +.vscode diff --git a/Cargo.toml b/Cargo.toml index 3aa7c53..53c74f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,17 +1,5 @@ -[package] -name = "rustyCTBN" -version = "0.1.0" -edition = "2021" +[workspace] -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] - -ndarray = {version="*", features=["approx"]} -thiserror = "*" -rand = "*" -bimap = "*" -enum_dispatch = "*" - -[dev-dependencies] -approx = "*" +members = [ + "reCTBN", +] diff --git a/README.md b/README.md index be62df2..6a60dff 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@
-# rustyCTBN +# reCTBN
@@ -37,8 +37,30 @@ To launch **tests**: cargo test ``` -To **lint**: +To **lint** with `cargo check`: ```sh -cargo check +cargo check --all-targets +``` + +Or with `clippy`: + +```sh +cargo clippy --all-targets -- -A clippy::all -W clippy::correctness +``` + +To check the **formatting**: + +> **NOTE:** remove `--check` to apply the changes to the file(s). + +```sh +cargo fmt --all -- --check +``` + +## Documentation + +To generate the **documentation**: + +```sh +cargo rustdoc --package reCTBN --open -- --default-theme=ayu ``` diff --git a/reCTBN/Cargo.toml b/reCTBN/Cargo.toml new file mode 100644 index 0000000..4749b23 --- /dev/null +++ b/reCTBN/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "reCTBN" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ndarray = {version="~0.15", features=["approx-0_5"]} +thiserror = "1.0.37" +rand = "~0.8" +bimap = "~0.6" +enum_dispatch = "~0.3" +statrs = "~0.16" +rand_chacha = "~0.3" +itertools = "~0.10" +rayon = "~1.6" + +[dev-dependencies] +approx = { package = "approx", version = "~0.5" } diff --git a/reCTBN/src/lib.rs b/reCTBN/src/lib.rs new file mode 100644 index 0000000..1997fa6 --- /dev/null +++ b/reCTBN/src/lib.rs @@ -0,0 +1,12 @@ +#![doc = include_str!("../../README.md")] +#![allow(non_snake_case)] +#[cfg(test)] +extern crate approx; + +pub mod parameter_learning; +pub mod params; +pub mod process; +pub mod reward; +pub mod sampling; +pub mod structure_learning; +pub mod tools; diff --git a/src/parameter_learning.rs b/reCTBN/src/parameter_learning.rs similarity index 60% rename from src/parameter_learning.rs rename to reCTBN/src/parameter_learning.rs index 67ea07f..3c34d06 100644 --- a/src/parameter_learning.rs +++ b/reCTBN/src/parameter_learning.rs @@ -1,40 +1,35 @@ -use crate::network; -use crate::params::*; -use crate::tools; -use ndarray::prelude::*; -use ndarray::{concatenate, Slice}; +//! 
Module containing methods used to learn the parameters. + use std::collections::BTreeSet; -pub trait ParameterLearning{ - fn fit( +use ndarray::prelude::*; + +use crate::params::*; +use crate::{process, tools::Dataset}; + +pub trait ParameterLearning: Sync { + fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2); + ) -> Params; } -pub fn sufficient_statistics( +pub fn sufficient_statistics( net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, - parent_set: &BTreeSet - ) -> (Array3, Array2) { + parent_set: &BTreeSet, +) -> (Array3, Array2) { //Get the number of values assumable by the node - let node_domain = net - .get_node(node.clone()) - .params - .get_reserved_space_as_parent(); + let node_domain = net.get_node(node.clone()).get_reserved_space_as_parent(); //Get the number of values assumable by each parent of the node let parentset_domain: Vec = parent_set .iter() - .map(|x| { - net.get_node(x.clone()) - .params - .get_reserved_space_as_parent() - }) + .map(|x| net.get_node(x.clone()).get_reserved_space_as_parent()) .collect(); //Vector used to convert a specific configuration of the parent_set to the corresponding index @@ -48,7 +43,7 @@ pub fn sufficient_statistics( vector_to_idx[*idx] = acc; acc * x }); - + //Number of transition given a specific configuration of the parent set let mut M: Array3 = Array::zeros((parentset_domain.iter().product(), node_domain, node_domain)); @@ -57,12 +52,12 @@ pub fn sufficient_statistics( let mut T: Array2 = Array::zeros((parentset_domain.iter().product(), node_domain)); //Compute the sufficient statistics - for trj in dataset.trajectories.iter() { - for idx in 0..(trj.time.len() - 1) { - let t1 = trj.time[idx]; - let t2 = trj.time[idx + 1]; - let ev1 = trj.events.row(idx); - let ev2 = trj.events.row(idx + 1); + for trj in dataset.get_trajectories().iter() { + for idx in 0..(trj.get_time().len() - 1) { + let t1 = 
trj.get_time()[idx]; + let t2 = trj.get_time()[idx + 1]; + let ev1 = trj.get_events().row(idx); + let ev2 = trj.get_events().row(idx + 1); let idx1 = vector_to_idx.dot(&ev1); T[[idx1, ev1[node]]] += t2 - t1; @@ -73,34 +68,30 @@ pub fn sufficient_statistics( } return (M, T); - } pub struct MLE {} impl ParameterLearning for MLE { - - fn fit( + fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, Array2) { - //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes - + ) -> Params { //Use parent_set from parameter if present. Otherwise use parent_set from network. let parent_set = match parent_set { Some(p) => p, None => net.get_parent_set(node), }; - + let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); - //Compute the CIM as M[i,x,y]/T[i,x] + //Compute the CIM as M[i,x,y]/T[i,x] let mut CIM: Array3 = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); CIM.axis_iter_mut(Axis(2)) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) - .for_each(|(mut C, m)| C.assign(&(&m/&T))); + .for_each(|(mut C, m)| C.assign(&(&m / &T))); //Set the diagonal of the inner matrices to the the row sum multiplied by -1 let tmp_diag_sum: Array2 = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); @@ -109,39 +100,53 @@ impl ParameterLearning for MLE { .for_each(|(mut C, diag)| { C.diag_mut().assign(&diag); }); - return (CIM, M, T); + + let mut n: Params = net.get_node(node).clone(); + + match n { + Params::DiscreteStatesContinousTime(ref mut dsct) => { + dsct.set_cim_unchecked(CIM); + dsct.set_transitions(M); + dsct.set_residence_time(T); + } + }; + return n; } } pub struct BayesianApproach { - pub default_alpha: usize, - pub default_tau: f64 + pub alpha: usize, + pub tau: f64, } impl ParameterLearning for BayesianApproach { - fn fit( + fn fit( &self, net: &T, - dataset: &tools::Dataset, + dataset: &Dataset, node: usize, parent_set: Option>, - ) -> (Array3, Array3, 
Array2) { - //TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes - + ) -> Params { //Use parent_set from parameter if present. Otherwise use parent_set from network. let parent_set = match parent_set { Some(p) => p, None => net.get_parent_set(node), }; - - let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); - M.mapv_inplace(|x|{x + self.default_alpha}); - T.mapv_inplace(|x|{x + self.default_tau}); - //Compute the CIM as M[i,x,y]/T[i,x] + + let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); + + let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64; + let tau: f64 = self.tau as f64 / M.shape()[0] as f64; + + //Compute the CIM as M[i,x,y]/T[i,x] let mut CIM: Array3 = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); CIM.axis_iter_mut(Axis(2)) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) - .for_each(|(mut C, m)| C.assign(&(&m/&T))); + .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau)))); + + CIM.outer_iter_mut().for_each(|mut C| { + C.diag_mut().fill(0.0); + }); //Set the diagonal of the inner matrices to the the row sum multiplied by -1 let tmp_diag_sum: Array2 = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); @@ -150,6 +155,16 @@ impl ParameterLearning for BayesianApproach { .for_each(|(mut C, diag)| { C.diag_mut().assign(&diag); }); - return (CIM, M, T); + + let mut n: Params = net.get_node(node).clone(); + + match n { + Params::DiscreteStatesContinousTime(ref mut dsct) => { + dsct.set_cim_unchecked(CIM); + dsct.set_transitions(M); + dsct.set_residence_time(T); + } + }; + return n; } } diff --git a/reCTBN/src/params.rs b/reCTBN/src/params.rs new file mode 100644 index 0000000..ccbb750 --- /dev/null +++ b/reCTBN/src/params.rs @@ -0,0 +1,287 @@ +//! Module containing methods to define different types of nodes. 
+ +use std::collections::BTreeSet; + +use enum_dispatch::enum_dispatch; +use ndarray::prelude::*; +use rand::Rng; +use rand_chacha::ChaCha8Rng; +use thiserror::Error; + +/// Error types for trait Params +#[derive(Error, Debug, PartialEq)] +pub enum ParamsError { + #[error("Unsupported method")] + UnsupportedMethod(String), + #[error("Paramiters not initialized")] + ParametersNotInitialized(String), + #[error("Invalid cim for parameter")] + InvalidCIM(String), +} + +/// Allowed type of states +#[derive(Clone, Hash, PartialEq, Eq, Debug)] +pub enum StateType { + Discrete(usize), +} + +/// This is a core element for building different types of nodes; the goal is to define the set of +/// methods required to describes a generic node. +#[enum_dispatch(Params)] +pub trait ParamsTrait { + fn reset_params(&mut self); + + /// Randomly generate a possible state of the node disregarding the state of the node and it's + /// parents. + fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType; + + /// Randomly generate a residence time for the given node taking into account the node state + /// and its parent set. + fn get_random_residence_time( + &self, + state: usize, + u: usize, + rng: &mut ChaCha8Rng, + ) -> Result; + + /// Randomly generate a possible state for the given node taking into account the node state + /// and its parent set. + fn get_random_state( + &self, + state: usize, + u: usize, + rng: &mut ChaCha8Rng, + ) -> Result; + + /// Used by childern of the node described by this parameters to reserve spaces in their CIMs. + fn get_reserved_space_as_parent(&self) -> usize; + + /// Index used by discrete node to represents their states as usize. 
+ fn state_to_index(&self, state: &StateType) -> usize; + + /// Validate parameters against domain + fn validate_params(&self) -> Result<(), ParamsError>; + + /// Return a reference to the associated label + fn get_label(&self) -> &String; +} + +/// Is a core element for building different types of nodes; the goal is to define all the +/// supported type of Parameters +#[derive(Clone)] +#[enum_dispatch] +pub enum Params { + DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), +} + +/// This represents the parameters of a classical discrete node for ctbn and it's composed by the +/// following elements. +/// +/// # Arguments +/// +/// * `label` - node's variable name. +/// * `domain` - an ordered and exhaustive set of possible states. +/// * `cim` - Conditional Intensity Matrix. +/// * `transitions` - number of transitions from one state to another given a specific realization +/// of the parent set; is a sufficient statistics are mainly used during the parameter learning +/// task. +/// * `residence_time` - residence time in each possible state, given a specific realization of the +/// parent set; is a sufficient statistics are mainly used during the parameter learning task. +#[derive(Clone)] +pub struct DiscreteStatesContinousTimeParams { + label: String, + domain: BTreeSet, + cim: Option>, + transitions: Option>, + residence_time: Option>, +} + +impl DiscreteStatesContinousTimeParams { + pub fn new(label: String, domain: BTreeSet) -> DiscreteStatesContinousTimeParams { + DiscreteStatesContinousTimeParams { + label, + domain, + cim: Option::None, + transitions: Option::None, + residence_time: Option::None, + } + } + + /// Getter function for CIM + pub fn get_cim(&self) -> &Option> { + &self.cim + } + + /// Setter function for CIM. + /// + /// This function checks if the CIM is valid using the [`validate_params`](self::ParamsTrait::validate_params) method: + /// * **Valid CIM inserted** - it substitutes the CIM in `self.cim` and returns `Ok(())`. 
+ /// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns + /// `ParamsError`. + pub fn set_cim(&mut self, cim: Array3) -> Result<(), ParamsError> { + self.cim = Some(cim); + match self.validate_params() { + Ok(()) => Ok(()), + Err(e) => { + self.cim = None; + Err(e) + } + } + } + + /// Unchecked version of the setter function for CIM. + pub fn set_cim_unchecked(&mut self, cim: Array3) { + self.cim = Some(cim); + } + + /// Getter function for transitions. + pub fn get_transitions(&self) -> &Option> { + &self.transitions + } + + /// Setter function for transitions. + pub fn set_transitions(&mut self, transitions: Array3) { + self.transitions = Some(transitions); + } + + /// Getter function for residence_time. + pub fn get_residence_time(&self) -> &Option> { + &self.residence_time + } + + /// Setter function for residence_time. + pub fn set_residence_time(&mut self, residence_time: Array2) { + self.residence_time = Some(residence_time); + } +} + +impl ParamsTrait for DiscreteStatesContinousTimeParams { + fn reset_params(&mut self) { + self.cim = Option::None; + self.transitions = Option::None; + self.residence_time = Option::None; + } + + fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType { + StateType::Discrete(rng.gen_range(0..(self.domain.len()))) + } + + fn get_random_residence_time( + &self, + state: usize, + u: usize, + rng: &mut ChaCha8Rng, + ) -> Result { + // Generate a random residence time given the current state of the node and its parent set. 
+ // The method used is described in: + // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates + match &self.cim { + Option::Some(cim) => { + let lambda = cim[[u, state, state]] * -1.0; + let x: f64 = rng.gen_range(0.0..=1.0); + Ok(-x.ln() / lambda) + } + Option::None => Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))), + } + } + + fn get_random_state( + &self, + state: usize, + u: usize, + rng: &mut ChaCha8Rng, + ) -> Result { + // Generate a random transition given the current state of the node and its parent set. + // The method used is described in: + // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution + match &self.cim { + Option::Some(cim) => { + let lambda = cim[[u, state, state]] * -1.0; + let urand: f64 = rng.gen_range(0.0..=1.0); + + let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold( + (0, 0.0), + |mut acc, ele| { + if &acc.1 + ele < urand && ele > &0.0 { + acc.0 += 1; + } + if ele > &0.0 { + acc.1 += ele; + } + acc + }, + ); + + let next_state = if next_state.0 < state { + next_state.0 + } else { + next_state.0 + 1 + }; + + Ok(StateType::Discrete(next_state)) + } + Option::None => Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))), + } + } + + fn get_reserved_space_as_parent(&self) -> usize { + self.domain.len() + } + + fn state_to_index(&self, state: &StateType) -> usize { + match state { + StateType::Discrete(val) => val.clone() as usize, + } + } + + fn validate_params(&self) -> Result<(), ParamsError> { + let domain_size = self.domain.len(); + + // Check if the cim is initialized + if let None = self.cim { + return Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))); + } + let cim = self.cim.as_ref().unwrap(); + // Check if the inner dimensions of the cim are equal to the cardinality of the variable + if cim.shape()[1] != domain_size || 
cim.shape()[2] != domain_size { + return Err(ParamsError::InvalidCIM(format!( + "Incompatible shape {:?} with domain {:?}", + cim.shape(), + domain_size + ))); + } + + // Check if the diagonal of each cim is non-positive + if cim + .axis_iter(Axis(0)) + .any(|x| x.diag().iter().any(|x| x >= &0.0)) + { + return Err(ParamsError::InvalidCIM(String::from( + "The diagonal of each cim must be non-positive", + ))); + } + + // Check if each row sum up to 0 + if cim + .sum_axis(Axis(2)) + .iter() + .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) + { + return Err(ParamsError::InvalidCIM(String::from( + "The sum of each row must be 0", + ))); + } + + return Ok(()); + } + + fn get_label(&self) -> &String { + &self.label + } +} diff --git a/reCTBN/src/process.rs b/reCTBN/src/process.rs new file mode 100644 index 0000000..45c5e0a --- /dev/null +++ b/reCTBN/src/process.rs @@ -0,0 +1,120 @@ +//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs + +pub mod ctbn; +pub mod ctmp; + +use std::collections::BTreeSet; + +use thiserror::Error; + +use crate::params; + +/// Error types for trait Network +#[derive(Error, Debug)] +pub enum NetworkError { + #[error("Error during node insertion")] + NodeInsertionError(String), +} + +/// This type is used to represent a specific realization of a generic NetworkProcess +pub type NetworkProcessState = Vec; + +/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such +/// as a CTBN). +pub trait NetworkProcess: Sync { + fn initialize_adj_matrix(&mut self); + fn add_node(&mut self, n: params::Params) -> Result; + /// Add an **directed edge** between a two nodes of the network. + /// + /// # Arguments + /// + /// * `parent` - parent node. + /// * `child` - child node. + fn add_edge(&mut self, parent: usize, child: usize); + + /// Get all the indices of the nodes contained inside the network. 
+ fn get_node_indices(&self) -> std::ops::Range; + + /// Get the numbers of nodes contained in the network. + fn get_number_of_nodes(&self) -> usize; + + /// Get the **node param**. + /// + /// # Arguments + /// + /// * `node_idx` - node index value. + /// + /// # Return + /// + /// * The selected **node param**. + fn get_node(&self, node_idx: usize) -> ¶ms::Params; + + /// Get the **node param**. + /// + /// # Arguments + /// + /// * `node_idx` - node index value. + /// + /// # Return + /// + /// * The selected **node mutable param**. + fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; + + /// Compute the index that must be used to access the parameters of a `node`, given a specific + /// configuration of the network. + /// + /// Usually, the only values really used in `current_state` are the ones in the parent set of + /// the `node`. + /// + /// # Arguments + /// + /// * `node` - selected node. + /// * `current_state` - current configuration of the network. + /// + /// # Return + /// + /// * Index of the `node` relative to the network. + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; + + /// Compute the index that must be used to access the parameters of a `node`, given a specific + /// configuration of the network and a generic `parent_set`. + /// + /// Usually, the only values really used in `current_state` are the ones in the parent set of + /// the `node`. + /// + /// # Arguments + /// + /// * `current_state` - current configuration of the network. + /// * `parent_set` - parent set of the selected `node`. + /// + /// # Return + /// + /// * Index of the `node` relative to the network. + fn get_param_index_from_custom_parent_set( + &self, + current_state: &Vec, + parent_set: &BTreeSet, + ) -> usize; + + /// Get the **parent set** of a given **node**. + /// + /// # Arguments + /// + /// * `node` - node index value. + /// + /// # Return + /// + /// * The **parent set** of the selected node. 
+ fn get_parent_set(&self, node: usize) -> BTreeSet; + + /// Get the **children set** of a given **node**. + /// + /// # Arguments + /// + /// * `node` - node index value. + /// + /// # Return + /// + /// * The **children set** of the selected node. + fn get_children_set(&self, node: usize) -> BTreeSet; +} diff --git a/reCTBN/src/process/ctbn.rs b/reCTBN/src/process/ctbn.rs new file mode 100644 index 0000000..d93400d --- /dev/null +++ b/reCTBN/src/process/ctbn.rs @@ -0,0 +1,247 @@ +//! Continuous Time Bayesian Network + +use std::collections::BTreeSet; + +use ndarray::prelude::*; + +use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; +use crate::process; + +use super::ctmp::CtmpProcess; +use super::{NetworkProcess, NetworkProcessState}; + +/// It represents both the structure and the parameters of a CTBN. +/// +/// # Arguments +/// +/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix +/// * `nodes` - A vector containing all the nodes and their parameters. +/// +/// The index of a node inside the vector is also used as index for the `adj_matrix`. 
+/// +/// # Example +/// +/// ```rust +/// use std::collections::BTreeSet; +/// use reCTBN::process::NetworkProcess; +/// use reCTBN::params; +/// use reCTBN::process::ctbn::*; +/// +/// //Create the domain for a discrete node +/// let mut domain = BTreeSet::new(); +/// domain.insert(String::from("A")); +/// domain.insert(String::from("B")); +/// +/// //Create the parameters for a discrete node using the domain +/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain); +/// +/// //Create the node using the parameters +/// let X1 = params::Params::DiscreteStatesContinousTime(param); +/// +/// let mut domain = BTreeSet::new(); +/// domain.insert(String::from("A")); +/// domain.insert(String::from("B")); +/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain); +/// let X2 = params::Params::DiscreteStatesContinousTime(param); +/// +/// //Initialize a ctbn +/// let mut net = CtbnNetwork::new(); +/// +/// //Add nodes +/// let X1 = net.add_node(X1).unwrap(); +/// let X2 = net.add_node(X2).unwrap(); +/// +/// //Add an edge +/// net.add_edge(X1, X2); +/// +/// //Get all the children of node X1 +/// let cs = net.get_children_set(X1); +/// assert_eq!(&X2, cs.iter().next().unwrap()); +/// ``` +pub struct CtbnNetwork { + adj_matrix: Option>, + nodes: Vec, +} + +impl CtbnNetwork { + pub fn new() -> CtbnNetwork { + CtbnNetwork { + adj_matrix: None, + nodes: Vec::new(), + } + } + + ///Transform the **CTBN** into a **CTMP** + /// + /// # Return + /// + /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork + pub fn amalgamation(&self) -> CtmpProcess { + let variables_domain = + Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); + + let state_space = variables_domain.product(); + let variables_set = BTreeSet::from_iter(self.get_node_indices()); + let mut amalgamated_cim: Array3 = Array::zeros((1, state_space, state_space)); + + for idx_current_state in 0..state_space { + 
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); + let current_state_statetype: NetworkProcessState = current_state + .iter() + .map(|x| StateType::Discrete(*x)) + .collect(); + for idx_node in 0..self.nodes.len() { + let p = match self.get_node(idx_node) { + Params::DiscreteStatesContinousTime(p) => p, + }; + for next_node_state in 0..variables_domain[idx_node] { + let mut next_state = current_state.clone(); + next_state[idx_node] = next_node_state; + + let next_state_statetype: NetworkProcessState = + next_state.iter().map(|x| StateType::Discrete(*x)).collect(); + let idx_next_state = self.get_param_index_from_custom_parent_set( + &next_state_statetype, + &variables_set, + ); + amalgamated_cim[[0, idx_current_state, idx_next_state]] += + p.get_cim().as_ref().unwrap()[[ + self.get_param_index_network(idx_node, ¤t_state_statetype), + current_state[idx_node], + next_node_state, + ]]; + } + } + } + + let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( + "ctmp".to_string(), + BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), + ); + + amalgamated_param.set_cim(amalgamated_cim).unwrap(); + + let mut ctmp = CtmpProcess::new(); + + ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) + .unwrap(); + return ctmp; + } + + pub fn idx_to_state(variables_domain: &Array1, state: usize) -> Array1 { + let mut state = state; + let mut array_state = Array1::zeros(variables_domain.shape()[0]); + for (idx, var) in variables_domain.indexed_iter() { + array_state[idx] = state % var; + state = state / var; + } + + return array_state; + } + /// Get the Adjacency Matrix. + pub fn get_adj_matrix(&self) -> Option<&Array2> { + self.adj_matrix.as_ref() + } +} + +impl process::NetworkProcess for CtbnNetwork { + /// Initialize an Adjacency matrix. + fn initialize_adj_matrix(&mut self) { + self.adj_matrix = Some(Array2::::zeros( + (self.nodes.len(), self.nodes.len()).f(), + )); + } + + /// Add a new node. 
+ fn add_node(&mut self, mut n: Params) -> Result { + n.reset_params(); + self.adj_matrix = Option::None; + self.nodes.push(n); + Ok(self.nodes.len() - 1) + } + + /// Connect two nodes with a new edge. + fn add_edge(&mut self, parent: usize, child: usize) { + if let None = self.adj_matrix { + self.initialize_adj_matrix(); + } + + if let Some(network) = &mut self.adj_matrix { + network[[parent, child]] = 1; + self.nodes[child].reset_params(); + } + } + + fn get_node_indices(&self) -> std::ops::Range { + 0..self.nodes.len() + } + + /// Get the number of nodes of the network. + fn get_number_of_nodes(&self) -> usize { + self.nodes.len() + } + + fn get_node(&self, node_idx: usize) -> &Params { + &self.nodes[node_idx] + } + + fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { + &mut self.nodes[node_idx] + } + + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { + self.adj_matrix + .as_ref() + .unwrap() + .column(node) + .iter() + .enumerate() + .fold((0, 1), |mut acc, x| { + if x.1 > &0 { + acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; + acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); + } + acc + }) + .0 + } + + fn get_param_index_from_custom_parent_set( + &self, + current_state: &NetworkProcessState, + parent_set: &BTreeSet, + ) -> usize { + parent_set + .iter() + .fold((0, 1), |mut acc, x| { + acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; + acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); + acc + }) + .0 + } + + /// Get all the parents of the given node. + fn get_parent_set(&self, node: usize) -> BTreeSet { + self.adj_matrix + .as_ref() + .unwrap() + .column(node) + .iter() + .enumerate() + .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) + .collect() + } + + /// Get all the children of the given node. 
+ fn get_children_set(&self, node: usize) -> BTreeSet { + self.adj_matrix + .as_ref() + .unwrap() + .row(node) + .iter() + .enumerate() + .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) + .collect() + } +} diff --git a/reCTBN/src/process/ctmp.rs b/reCTBN/src/process/ctmp.rs new file mode 100644 index 0000000..41b8db6 --- /dev/null +++ b/reCTBN/src/process/ctmp.rs @@ -0,0 +1,114 @@ +use std::collections::BTreeSet; + +use crate::{ + params::{Params, StateType}, + process, +}; + +use super::{NetworkProcess, NetworkProcessState}; + +pub struct CtmpProcess { + param: Option, +} + +impl CtmpProcess { + pub fn new() -> CtmpProcess { + CtmpProcess { param: None } + } +} + +impl NetworkProcess for CtmpProcess { + fn initialize_adj_matrix(&mut self) { + unimplemented!("CtmpProcess has only one node") + } + + fn add_node(&mut self, n: crate::params::Params) -> Result { + match self.param { + None => { + self.param = Some(n); + Ok(0) + } + Some(_) => Err(process::NetworkError::NodeInsertionError( + "CtmpProcess has only one node".to_string(), + )), + } + } + + fn add_edge(&mut self, _parent: usize, _child: usize) { + unimplemented!("CtmpProcess has only one node") + } + + fn get_node_indices(&self) -> std::ops::Range { + match self.param { + None => 0..0, + Some(_) => 0..1, + } + } + + fn get_number_of_nodes(&self) -> usize { + match self.param { + None => 0, + Some(_) => 1, + } + } + + fn get_node(&self, node_idx: usize) -> &crate::params::Params { + if node_idx == 0 { + self.param.as_ref().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { + if node_idx == 0 { + self.param.as_mut().unwrap() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { + if node == 0 { + match current_state[0] { + StateType::Discrete(x) => x, + } + } else { + 
unimplemented!("CtmpProcess has only one node") + } + } + + fn get_param_index_from_custom_parent_set( + &self, + _current_state: &NetworkProcessState, + _parent_set: &BTreeSet, + ) -> usize { + unimplemented!("CtmpProcess has only one node") + } + + fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), + } + } + + fn get_children_set(&self, node: usize) -> std::collections::BTreeSet { + match self.param { + Some(_) => { + if node == 0 { + BTreeSet::new() + } else { + unimplemented!("CtmpProcess has only one node") + } + } + None => panic!("Uninitialized CtmpProcess"), + } + } +} diff --git a/reCTBN/src/reward.rs b/reCTBN/src/reward.rs new file mode 100644 index 0000000..910954c --- /dev/null +++ b/reCTBN/src/reward.rs @@ -0,0 +1,59 @@ +pub mod reward_evaluation; +pub mod reward_function; + +use std::collections::HashMap; + +use crate::process; + +/// Instantiation of reward function and instantaneous reward +/// +/// +/// # Arguments +/// +/// * `transition_reward`: reward obtained transitioning from one state to another +/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state + +#[derive(Debug, PartialEq)] +pub struct Reward { + pub transition_reward: f64, + pub instantaneous_reward: f64, +} + +/// The trait RewardFunction describes the methods that all the reward functions must satisfy + +pub trait RewardFunction: Sync { + /// Given the current state and the previous state, it computes the reward. 
+ /// + /// # Arguments + /// + /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState` + /// * `previous_state`: an optional argument representing the previous state of the network + + fn call( + &self, + current_state: &process::NetworkProcessState, + previous_state: Option<&process::NetworkProcessState>, + ) -> Reward; + + /// Initialize the RewardFunction internals according to the structure of a NetworkProcess + /// + /// # Arguments + /// + /// * `p`: any structure that implements the trait `process::NetworkProcess` + fn initialize_from_network_process(p: &T) -> Self; +} + +pub trait RewardEvaluation { + fn evaluate_state_space( + &self, + network_process: &N, + reward_function: &R, + ) -> HashMap; + + fn evaluate_state( + &self, + network_process: &N, + reward_function: &R, + state: &process::NetworkProcessState, + ) -> f64; +} diff --git a/reCTBN/src/reward/reward_evaluation.rs b/reCTBN/src/reward/reward_evaluation.rs new file mode 100644 index 0000000..3802489 --- /dev/null +++ b/reCTBN/src/reward/reward_evaluation.rs @@ -0,0 +1,205 @@ +use std::collections::HashMap; + +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use statrs::distribution::ContinuousCDF; + +use crate::params::{self, ParamsTrait}; +use crate::process; + +use crate::{ + process::NetworkProcessState, + reward::RewardEvaluation, + sampling::{ForwardSampler, Sampler}, +}; + +pub enum RewardCriteria { + FiniteHorizon, + InfiniteHorizon { discount_factor: f64 }, +} + +pub struct MonteCarloReward { + max_iterations: usize, + max_err_stop: f64, + alpha_stop: f64, + end_time: f64, + reward_criteria: RewardCriteria, + seed: Option, +} + +impl MonteCarloReward { + pub fn new( + max_iterations: usize, + max_err_stop: f64, + alpha_stop: f64, + end_time: f64, + reward_criteria: RewardCriteria, + seed: Option, + ) -> MonteCarloReward { + MonteCarloReward { + max_iterations, + max_err_stop, + alpha_stop, + end_time, + reward_criteria, + 
seed, + } + } +} + +impl RewardEvaluation for MonteCarloReward { + fn evaluate_state_space( + &self, + network_process: &N, + reward_function: &R, + ) -> HashMap { + let variables_domain: Vec> = network_process + .get_node_indices() + .map(|x| match network_process.get_node(x) { + params::Params::DiscreteStatesContinousTime(x) => (0..x + .get_reserved_space_as_parent()) + .map(|s| params::StateType::Discrete(s)) + .collect(), + }) + .collect(); + + let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); + + (0..n_states) + .into_par_iter() + .map(|s| { + let state: process::NetworkProcessState = variables_domain + .iter() + .fold((s, vec![]), |acc, x| { + let mut acc = acc; + let idx_s = acc.0 % x.len(); + acc.1.push(x[idx_s].clone()); + acc.0 = acc.0 / x.len(); + acc + }) + .1; + + let r = self.evaluate_state(network_process, reward_function, &state); + (state, r) + }) + .collect() + } + + fn evaluate_state( + &self, + network_process: &N, + reward_function: &R, + state: &NetworkProcessState, + ) -> f64 { + let mut sampler = + ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); + let mut expected_value = 0.0; + let mut squared_expected_value = 0.0; + let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); + + for i in 0..self.max_iterations { + sampler.reset(); + let mut ret = 0.0; + let mut previous = sampler.next().unwrap(); + while previous.t < self.end_time { + let current = sampler.next().unwrap(); + if current.t > self.end_time { + let r = reward_function.call(&previous.state, None); + let discount = match self.reward_criteria { + RewardCriteria::FiniteHorizon => self.end_time - previous.t, + RewardCriteria::InfiniteHorizon { discount_factor } => { + std::f64::consts::E.powf(-discount_factor * previous.t) + - std::f64::consts::E.powf(-discount_factor * self.end_time) + } + }; + ret += discount * r.instantaneous_reward; + } else { + let r = reward_function.call(&previous.state, Some(¤t.state)); + let 
discount = match self.reward_criteria { + RewardCriteria::FiniteHorizon => current.t - previous.t, + RewardCriteria::InfiniteHorizon { discount_factor } => { + std::f64::consts::E.powf(-discount_factor * previous.t) + - std::f64::consts::E.powf(-discount_factor * current.t) + } + }; + ret += discount * r.instantaneous_reward; + ret += match self.reward_criteria { + RewardCriteria::FiniteHorizon => 1.0, + RewardCriteria::InfiniteHorizon { discount_factor } => { + std::f64::consts::E.powf(-discount_factor * current.t) + } + } * r.transition_reward; + } + previous = current; + } + + let float_i = i as f64; + expected_value = + expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0); + squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0) + + ret.powi(2) / (float_i + 1.0); + + if i > 2 { + let var = + (float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2)); + if self.alpha_stop + - 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt()) + > 0.0 + { + return expected_value; + } + } + } + + expected_value + } +} + +pub struct NeighborhoodRelativeReward { + inner_reward: RE, +} + +impl NeighborhoodRelativeReward { + pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward { + NeighborhoodRelativeReward { inner_reward } + } +} + +impl RewardEvaluation for NeighborhoodRelativeReward { + fn evaluate_state_space( + &self, + network_process: &N, + reward_function: &R, + ) -> HashMap { + let absolute_reward = self + .inner_reward + .evaluate_state_space(network_process, reward_function); + + //This approach optimize memory. Maybe optimizing execution time can be better. 
+ absolute_reward + .iter() + .map(|(k1, v1)| { + let mut max_val: f64 = 1.0; + absolute_reward.iter().for_each(|(k2, v2)| { + let count_diff: usize = k1 + .iter() + .zip(k2.iter()) + .map(|(s1, s2)| if s1 == s2 { 0 } else { 1 }) + .sum(); + if count_diff < 2 { + max_val = max_val.max(v1 / v2); + } + }); + (k1.clone(), max_val) + }) + .collect() + } + + fn evaluate_state( + &self, + _network_process: &N, + _reward_function: &R, + _state: &process::NetworkProcessState, + ) -> f64 { + unimplemented!(); + } +} diff --git a/reCTBN/src/reward/reward_function.rs b/reCTBN/src/reward/reward_function.rs new file mode 100644 index 0000000..216df6a --- /dev/null +++ b/reCTBN/src/reward/reward_function.rs @@ -0,0 +1,106 @@ +//! Module for dealing with reward functions + +use crate::{ + params::{self, ParamsTrait}, + process, + reward::{Reward, RewardFunction}, +}; + +use ndarray; + +/// Reward function over a factored state space +/// +/// The `FactoredRewardFunction` assumes the reward function is the sum of the reward of each node +/// of the underlying `NetworkProcess` +/// +/// # Arguments +/// +/// * `transition_reward`: a vector of two-dimensional arrays. 
Each array contains the transition +/// reward of a node + +pub struct FactoredRewardFunction { + transition_reward: Vec>, + instantaneous_reward: Vec>, +} + +impl FactoredRewardFunction { + pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2 { + &self.transition_reward[node_idx] + } + + pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2 { + &mut self.transition_reward[node_idx] + } + + pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1 { + &self.instantaneous_reward[node_idx] + } + + pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1 { + &mut self.instantaneous_reward[node_idx] + } +} + +impl RewardFunction for FactoredRewardFunction { + fn call( + &self, + current_state: &process::NetworkProcessState, + previous_state: Option<&process::NetworkProcessState>, + ) -> Reward { + let instantaneous_reward: f64 = current_state + .iter() + .enumerate() + .map(|(idx, x)| { + let x = match x { + params::StateType::Discrete(x) => x, + }; + self.instantaneous_reward[idx][*x] + }) + .sum(); + if let Some(previous_state) = previous_state { + let transition_reward = previous_state + .iter() + .zip(current_state.iter()) + .enumerate() + .find_map(|(idx, (p, c))| -> Option { + let p = match p { + params::StateType::Discrete(p) => p, + }; + let c = match c { + params::StateType::Discrete(c) => c, + }; + if p != c { + Some(self.transition_reward[idx][[*p, *c]]) + } else { + None + } + }) + .unwrap_or(0.0); + Reward { + transition_reward, + instantaneous_reward, + } + } else { + Reward { + transition_reward: 0.0, + instantaneous_reward, + } + } + } + + fn initialize_from_network_process(p: &T) -> Self { + let mut transition_reward: Vec> = vec![]; + let mut instantaneous_reward: Vec> = vec![]; + for i in p.get_node_indices() { + //This works only for discrete nodes! 
+ let size: usize = p.get_node(i).get_reserved_space_as_parent(); + instantaneous_reward.push(ndarray::Array1::zeros(size)); + transition_reward.push(ndarray::Array2::zeros((size, size))); + } + + FactoredRewardFunction { + transition_reward, + instantaneous_reward, + } + } +} diff --git a/reCTBN/src/sampling.rs b/reCTBN/src/sampling.rs new file mode 100644 index 0000000..73c6d78 --- /dev/null +++ b/reCTBN/src/sampling.rs @@ -0,0 +1,133 @@ +//! Module containing methods for the sampling. + +use crate::{ + params::ParamsTrait, + process::{NetworkProcess, NetworkProcessState}, +}; +use rand::SeedableRng; +use rand_chacha::ChaCha8Rng; + +#[derive(Clone)] +pub struct Sample { + pub t: f64, + pub state: NetworkProcessState, +} + +pub trait Sampler: Iterator { + fn reset(&mut self); +} + +pub struct ForwardSampler<'a, T> +where + T: NetworkProcess, +{ + net: &'a T, + rng: ChaCha8Rng, + current_time: f64, + current_state: NetworkProcessState, + next_transitions: Vec>, + initial_state: Option, +} + +impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { + pub fn new( + net: &'a T, + seed: Option, + initial_state: Option, + ) -> ForwardSampler<'a, T> { + let rng: ChaCha8Rng = match seed { + //If a seed is present use it to initialize the random generator. 
+ Some(seed) => SeedableRng::seed_from_u64(seed), + //Otherwise create a new random generator using the method `from_entropy` + None => SeedableRng::from_entropy(), + }; + let mut fs = ForwardSampler { + net, + rng, + current_time: 0.0, + current_state: vec![], + next_transitions: vec![], + initial_state, + }; + fs.reset(); + return fs; + } +} + +impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { + type Item = Sample; + + fn next(&mut self) -> Option { + let ret_time = self.current_time.clone(); + let ret_state = self.current_state.clone(); + + for (idx, val) in self.next_transitions.iter_mut().enumerate() { + if let None = val { + *val = Some( + self.net + .get_node(idx) + .get_random_residence_time( + self.net + .get_node(idx) + .state_to_index(&self.current_state[idx]), + self.net.get_param_index_network(idx, &self.current_state), + &mut self.rng, + ) + .unwrap() + + self.current_time, + ); + } + } + + let next_node_transition = self + .next_transitions + .iter() + .enumerate() + .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) + .unwrap() + .0; + + self.current_time = self.next_transitions[next_node_transition].unwrap().clone(); + + self.current_state[next_node_transition] = self + .net + .get_node(next_node_transition) + .get_random_state( + self.net + .get_node(next_node_transition) + .state_to_index(&self.current_state[next_node_transition]), + self.net + .get_param_index_network(next_node_transition, &self.current_state), + &mut self.rng, + ) + .unwrap(); + + self.next_transitions[next_node_transition] = None; + + for child in self.net.get_children_set(next_node_transition) { + self.next_transitions[child] = None; + } + + Some(Sample { + t: ret_time, + state: ret_state, + }) + } +} + +impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { + fn reset(&mut self) { + self.current_time = 0.0; + match &self.initial_state { + None => { + self.current_state = self + .net + .get_node_indices() + .map(|x| 
self.net.get_node(x).get_random_state_uniform(&mut self.rng)) + .collect() + } + Some(is) => self.current_state = is.clone(), + }; + self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect(); + } +} diff --git a/reCTBN/src/structure_learning.rs b/reCTBN/src/structure_learning.rs new file mode 100644 index 0000000..a4c6ea1 --- /dev/null +++ b/reCTBN/src/structure_learning.rs @@ -0,0 +1,13 @@ +//! Learn the structure of the network. + +pub mod constraint_based_algorithm; +pub mod hypothesis_test; +pub mod score_based_algorithm; +pub mod score_function; +use crate::{process, tools::Dataset}; + +pub trait StructureLearningAlgorithm { + fn fit_transform(&self, net: T, dataset: &Dataset) -> T + where + T: process::NetworkProcess; +} diff --git a/reCTBN/src/structure_learning/constraint_based_algorithm.rs b/reCTBN/src/structure_learning/constraint_based_algorithm.rs new file mode 100644 index 0000000..f9cd820 --- /dev/null +++ b/reCTBN/src/structure_learning/constraint_based_algorithm.rs @@ -0,0 +1,348 @@ +//! Module containing constraint based algorithms like CTPC and Hiton. 
+ +use crate::params::Params; +use itertools::Itertools; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::ParallelExtend; +use std::collections::{BTreeSet, HashMap}; +use std::mem; +use std::usize; + +use super::hypothesis_test::*; +use crate::parameter_learning::ParameterLearning; +use crate::process; +use crate::structure_learning::StructureLearningAlgorithm; +use crate::tools::Dataset; + +pub struct Cache<'a, P: ParameterLearning> { + parameter_learning: &'a P, + cache_persistent_small: HashMap>, Params>, + cache_persistent_big: HashMap>, Params>, + parent_set_size_small: usize, +} + +impl<'a, P: ParameterLearning> Cache<'a, P> { + pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { + Cache { + parameter_learning, + cache_persistent_small: HashMap::new(), + cache_persistent_big: HashMap::new(), + parent_set_size_small: 0, + } + } + pub fn fit( + &mut self, + net: &T, + dataset: &Dataset, + node: usize, + parent_set: Option>, + ) -> Params { + let parent_set_len = parent_set.as_ref().unwrap().len(); + if parent_set_len > self.parent_set_size_small + 1 { + //self.cache_persistent_small = self.cache_persistent_big; + mem::swap( + &mut self.cache_persistent_small, + &mut self.cache_persistent_big, + ); + self.cache_persistent_big = HashMap::new(); + self.parent_set_size_small += 1; + } + + if parent_set_len > self.parent_set_size_small { + match self.cache_persistent_big.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_big.insert(parent_set, params.clone()); + params + } + } + } else { + match self.cache_persistent_small.get(&parent_set) { + // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O + // not cloning requires a minor and 
reasoned refactoring across the library + Some(params) => params.clone(), + None => { + let params = + self.parameter_learning + .fit(net, dataset, node, parent_set.clone()); + self.cache_persistent_small + .insert(parent_set, params.clone()); + params + } + } + } + } +} + +/// Continuous-Time Peter Clark algorithm. +/// +/// A method to learn the structure of the network. +/// +/// # Arguments +/// +/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters. +/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hypothesis test. +/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test. +/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// use reCTBN::parameter_learning::BayesianApproach; +/// use reCTBN::structure_learning::StructureLearningAlgorithm; +/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare}; +/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC; +/// # +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("A")); +/// # domain.insert(String::from("B")); +/// # domain.insert(String::from("C")); +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain); +/// # //Create the node n1 using the parameters +/// # let n1 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("D")); +/// # domain.insert(String::from("E")); +/// # domain.insert(String::from("F")); +/// # let param = 
params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain); +/// # let n2 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # let mut domain = BTreeSet::new(); +/// # domain.insert(String::from("G")); +/// # domain.insert(String::from("H")); +/// # domain.insert(String::from("I")); +/// # domain.insert(String::from("F")); +/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain); +/// # let n3 = params::Params::DiscreteStatesContinousTime(param); +/// # +/// # // Initialize a ctbn +/// # let mut net = CtbnNetwork::new(); +/// # +/// # // Add the nodes and their edges +/// # let n1 = net.add_node(n1).unwrap(); +/// # let n2 = net.add_node(n2).unwrap(); +/// # let n3 = net.add_node(n3).unwrap(); +/// # net.add_edge(n1, n2); +/// # net.add_edge(n1, n3); +/// # net.add_edge(n2, n3); +/// # +/// # match &mut net.get_node_mut(n1) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-3.0, 2.0, 1.0], +/// # [1.5, -2.0, 0.5], +/// # [0.4, 0.6, -1.0] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # match &mut net.get_node_mut(n2) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-1.0, 0.5, 0.5], +/// # [3.0, -4.0, 1.0], +/// # [0.9, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 2.0, 4.0], +/// # [1.5, -2.0, 0.5], +/// # [3.0, 1.0, -4.0] +/// # ], +/// # [ +/// # [-1.0, 0.1, 0.9], +/// # [2.0, -2.5, 0.5], +/// # [0.9, 0.1, -1.0] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # match &mut net.get_node_mut(n3) { +/// # params::Params::DiscreteStatesContinousTime(param) => { +/// # assert_eq!( +/// # Ok(()), +/// # param.set_cim(arr3(&[ +/// # [ +/// # [-1.0, 0.5, 0.3, 0.2], +/// # [0.5, -4.0, 2.5, 1.0], +/// # [2.5, 0.5, -4.0, 1.0], +/// # [0.7, 0.2, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 2.0, 3.0, 
1.0], +/// # [1.5, -3.0, 0.5, 1.0], +/// # [2.0, 1.3, -5.0, 1.7], +/// # [2.5, 0.5, 1.0, -4.0] +/// # ], +/// # [ +/// # [-1.3, 0.3, 0.1, 0.9], +/// # [1.4, -4.0, 0.5, 2.1], +/// # [1.0, 1.5, -3.0, 0.5], +/// # [0.4, 0.3, 0.1, -0.8] +/// # ], +/// # [ +/// # [-2.0, 1.0, 0.7, 0.3], +/// # [1.3, -5.9, 2.7, 1.9], +/// # [2.0, 1.5, -4.0, 0.5], +/// # [0.2, 0.7, 0.1, -1.0] +/// # ], +/// # [ +/// # [-6.0, 1.0, 2.0, 3.0], +/// # [0.5, -3.0, 1.0, 1.5], +/// # [1.4, 2.1, -4.3, 0.8], +/// # [0.5, 1.0, 2.5, -4.0] +/// # ], +/// # [ +/// # [-1.3, 0.9, 0.3, 0.1], +/// # [0.1, -1.3, 0.2, 1.0], +/// # [0.5, 1.0, -3.0, 1.5], +/// # [0.1, 0.4, 0.3, -0.8] +/// # ], +/// # [ +/// # [-2.0, 1.0, 0.6, 0.4], +/// # [2.6, -7.1, 1.4, 3.1], +/// # [5.0, 1.0, -8.0, 2.0], +/// # [1.4, 0.4, 0.2, -2.0] +/// # ], +/// # [ +/// # [-3.0, 1.0, 1.5, 0.5], +/// # [3.0, -6.0, 1.0, 2.0], +/// # [0.3, 0.5, -1.9, 1.1], +/// # [5.0, 1.0, 2.0, -8.0] +/// # ], +/// # [ +/// # [-2.6, 0.6, 0.2, 1.8], +/// # [2.0, -6.0, 3.0, 1.0], +/// # [0.1, 0.5, -1.3, 0.7], +/// # [0.8, 0.6, 0.2, -1.6] +/// # ], +/// # ])) +/// # ); +/// # } +/// # } +/// # +/// # // Generate the trajectory +/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873)); +/// +/// // Initialize the hypothesis tests to pass to the CTPC with their +/// // respective significance level `alpha` +/// let f = F::new(1e-6); +/// let chi_sq = ChiSquare::new(1e-4); +/// // Use the bayesian approach to learn the parameters +/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; +/// +/// //Initialize CTPC +/// let ctpc = CTPC::new(parameter_learning, f, chi_sq); +/// +/// // Learn the structure of the network from the generated trajectory +/// let net = ctpc.fit_transform(net, &data); +/// # +/// # // Compare the generated network with the original one +/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); +/// # 
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +/// ``` +pub struct CTPC { + parameter_learning: P, + Ftest: F, + Chi2test: ChiSquare, +} + +impl CTPC

{ + pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC

{ + CTPC { + parameter_learning, + Ftest, + Chi2test, + } + } +} + +impl StructureLearningAlgorithm for CTPC

{ + fn fit_transform(&self, net: T, dataset: &Dataset) -> T + where + T: process::NetworkProcess, + { + //Check the coherence between dataset and network + if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { + panic!("Dataset and Network must have the same number of variables.") + } + + //Make the network mutable. + let mut net = net; + + net.initialize_adj_matrix(); + + let mut learned_parent_sets: Vec<(usize, BTreeSet)> = vec![]; + learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| { + let mut cache = Cache::new(&self.parameter_learning); + let mut candidate_parent_set: BTreeSet = net + .get_node_indices() + .into_iter() + .filter(|x| x != &child_node) + .collect(); + let mut separation_set_size = 0; + while separation_set_size < candidate_parent_set.len() { + let mut candidate_parent_set_TMP = candidate_parent_set.clone(); + for parent_node in candidate_parent_set.iter() { + for separation_set in candidate_parent_set + .iter() + .filter(|x| x != &parent_node) + .map(|x| *x) + .combinations(separation_set_size) + { + let separation_set = separation_set.into_iter().collect(); + if self.Ftest.call( + &net, + child_node, + *parent_node, + &separation_set, + dataset, + &mut cache, + ) && self.Chi2test.call( + &net, + child_node, + *parent_node, + &separation_set, + dataset, + &mut cache, + ) { + candidate_parent_set_TMP.remove(parent_node); + break; + } + } + } + candidate_parent_set = candidate_parent_set_TMP; + separation_set_size += 1; + } + (child_node, candidate_parent_set) + })); + for (child_node, candidate_parent_set) in learned_parent_sets { + for parent_node in candidate_parent_set.iter() { + net.add_edge(*parent_node, child_node); + } + } + net + } +} diff --git a/reCTBN/src/structure_learning/hypothesis_test.rs b/reCTBN/src/structure_learning/hypothesis_test.rs new file mode 100644 index 0000000..311ec47 --- /dev/null +++ b/reCTBN/src/structure_learning/hypothesis_test.rs @@ -0,0 
+1,261 @@ +//! Module for constraint based algorithms containing hypothesis test algorithms like chi-squared test, F test, etc... + +use std::collections::BTreeSet; + +use ndarray::{Array3, Axis}; +use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; + +use crate::params::*; +use crate::structure_learning::constraint_based_algorithm::Cache; +use crate::{parameter_learning, process, tools::Dataset}; + +pub trait HypothesisTest { + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + dataset: &Dataset, + cache: &mut Cache

, + ) -> bool + where + T: process::NetworkProcess, + P: parameter_learning::ParameterLearning; +} + +/// Does the chi-squared test (χ2 test). +/// +/// Used to determine if a difference between two sets of data is due to chance, or if it is due to +/// a relationship (dependence) between the variables. +/// +/// # Arguments +/// +/// * `alpha` - is the significance level, the probability to reject a true null hypothesis; +/// in other words is the risk of concluding that an association between the variables exists +/// when there is no actual association. + +pub struct ChiSquare { + alpha: f64, +} + +/// Does the F-test. +/// +/// Used to determine if a difference between two sets of data is due to chance, or if it is due to +/// a relationship (dependence) between the variables. +/// +/// # Arguments +/// +/// * `alpha` - is the significance level, the probability to reject a true null hypothesis; +/// in other words is the risk of concluding that an association between the variables exists +/// when there is no actual association. + +pub struct F { + alpha: f64, +} + +impl F { + pub fn new(alpha: f64) -> F { + F { alpha } + } + + /// Compare two matrices extracted from two 3rd-order tensors. + /// + /// # Arguments + /// + /// * `i` - Position of the matrix of `M1` to compare with `M2`. + /// * `M1` - 3rd-order tensor 1. + /// * `j` - Position of the matrix of `M2` to compare with `M1`. + /// * `M2` - 3rd-order tensor 2. + /// + /// # Returns + /// + /// * `true` - when the matrices `M1` and `M2` are very similar, then **independent**. + /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**. 
+ + pub fn compare_matrices( + &self, + i: usize, + M1: &Array3, + cim_1: &Array3, + j: usize, + M2: &Array3, + cim_2: &Array3, + ) -> bool { + // Select the `i`-th / `j`-th matrix of each tensor; the `M` entries are cast to f64. + let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); + let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); + let cim_1 = cim_1.index_axis(Axis(0), i); + let cim_2 = cim_2.index_axis(Axis(0), j); + // Row sums of `M1` and `M2`: passed below as the two parameters of the + // Fisher-Snedecor distribution. + let r1 = M1.sum_axis(Axis(1)); + let r2 = M2.sum_axis(Axis(1)); + let q1 = cim_1.diag(); + let q2 = cim_2.diag(); + for idx in 0..r1.shape()[0] { + // Ratio between the two diagonal entries of row `idx`. + let s = q2[idx] / q1[idx]; + // NOTE(review): `unwrap()` panics if `FisherSnedecor::new` rejects its parameters, + // presumably when a row sum is zero - confirm callers rule that out. + let F = FisherSnedecor::new(r1[idx], r2[idx]).unwrap(); + let s = F.cdf(s); + // Two-sided check: row `idx` fails when the CDF value lies in either `alpha / 2` tail. + let lim_sx = self.alpha / 2.0; + let lim_dx = 1.0 - (self.alpha / 2.0); + if s < lim_sx || s > lim_dx { + return false; + } + } + true + } +} + +impl HypothesisTest for F { + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + dataset: &Dataset, + cache: &mut Cache

, + ) -> bool + where + T: process::NetworkProcess, + P: parameter_learning::ParameterLearning, + { + // Fit `child_node` through the cache, passing `separation_set` alone. + let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { + Params::DiscreteStatesContinousTime(node) => node, + }; + let mut extended_separation_set = separation_set.clone(); + extended_separation_set.insert(parent_node); + + // Fit `child_node` again, this time passing `separation_set` plus `parent_node`. + let P_big = match cache.fit( + net, + &dataset, + child_node, + Some(extended_separation_set.clone()), + ) { + Params::DiscreteStatesContinousTime(node) => node, + }; + // Product of `get_reserved_space_as_parent()` over the nodes that precede `parent_node` + // in the ordered extended set. + let partial_cardinality_product: usize = extended_separation_set + .iter() + .take_while(|x| **x != parent_node) + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] { + // Index of the matrix of `P_small` that is paired with matrix `idx_M_big` of `P_big`. + let idx_M_small: usize = idx_M_big % partial_cardinality_product + + (idx_M_big + / (partial_cardinality_product + * net.get_node(parent_node).get_reserved_space_as_parent())) + * partial_cardinality_product; + // A single failing comparison is enough to declare dependence. + if !self.compare_matrices( + idx_M_small, + P_small.get_transitions().as_ref().unwrap(), + P_small.get_cim().as_ref().unwrap(), + idx_M_big, + P_big.get_transitions().as_ref().unwrap(), + P_big.get_cim().as_ref().unwrap(), + ) { + return false; + } + } + true + } +} + +impl ChiSquare { + /// Creates a chi-squared test that uses `alpha` as its significance level. + pub fn new(alpha: f64) -> ChiSquare { + ChiSquare { alpha } + } + + /// Compare two matrices extracted from two 3rd-order tensors. + /// + /// # Arguments + /// + /// * `i` - Position of the matrix of `M1` to compare with `M2`. + /// * `M1` - 3rd-order tensor 1. + /// * `j` - Position of the matrix of `M2` to compare with `M1`. + /// * `M2` - 3rd-order tensor 2. + /// + /// # Returns + /// + /// * `true` - when the matrices `M1` and `M2` are very similar, then **independent**. + /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**. 
+ + pub fn compare_matrices( + &self, + i: usize, + M1: &Array3, + j: usize, + M2: &Array3, + ) -> bool { + // Bregoli, A., Scutari, M. and Stella, F., 2021. + // A constraint-based algorithm for the structural learning of + // continuous-time Bayesian networks. + // International Journal of Approximate Reasoning, 138, pp.105-122. + // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm + let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); + let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); + // Per-row scaling factor: square root of the ratio between the row sums of `M1` and `M2`. + let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); + let K = K.mapv(f64::sqrt); + // Reshape to column vector. + let K = { + let n = K.len(); + K.into_shape((n, 1)).unwrap() + }; + let L = 1.0 / &K; + let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); + // The diagonal entries are excluded from the statistic. + X_2.diag_mut().fill(0.0); + let X_2 = X_2.sum_axis(Axis(1)); + // NOTE(review): `unwrap()` panics if `ChiSquared::new` rejects its parameter, presumably + // when `X_2.dim()` is 1 and the degrees of freedom become 0 - confirm. + let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); + // Independent only if every row statistic stays below the `1 - alpha` quantile. + X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)) + } +} + +impl HypothesisTest for ChiSquare { + fn call( + &self, + net: &T, + child_node: usize, + parent_node: usize, + separation_set: &BTreeSet, + dataset: &Dataset, + cache: &mut Cache

, + ) -> bool + where + T: process::NetworkProcess, + P: parameter_learning::ParameterLearning, + { + // Fit `child_node` through the cache, passing `separation_set` alone. + let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { + Params::DiscreteStatesContinousTime(node) => node, + }; + let mut extended_separation_set = separation_set.clone(); + extended_separation_set.insert(parent_node); + + // Fit `child_node` again, this time passing `separation_set` plus `parent_node`. + let P_big = match cache.fit( + net, + &dataset, + child_node, + Some(extended_separation_set.clone()), + ) { + Params::DiscreteStatesContinousTime(node) => node, + }; + // Product of `get_reserved_space_as_parent()` over the nodes that precede `parent_node` + // in the ordered extended set. + let partial_cardinality_product: usize = extended_separation_set + .iter() + .take_while(|x| **x != parent_node) + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] { + // Index of the matrix of `P_small` that is paired with matrix `idx_M_big` of `P_big`. + let idx_M_small: usize = idx_M_big % partial_cardinality_product + + (idx_M_big + / (partial_cardinality_product + * net.get_node(parent_node).get_reserved_space_as_parent())) + * partial_cardinality_product; + // A single failing comparison is enough to declare dependence. + if !self.compare_matrices( + idx_M_small, + P_small.get_transitions().as_ref().unwrap(), + idx_M_big, + P_big.get_transitions().as_ref().unwrap(), + ) { + return false; + } + } + true + } +} diff --git a/reCTBN/src/structure_learning/score_based_algorithm.rs b/reCTBN/src/structure_learning/score_based_algorithm.rs new file mode 100644 index 0000000..9173b86 --- /dev/null +++ b/reCTBN/src/structure_learning/score_based_algorithm.rs @@ -0,0 +1,93 @@ +//! Module containing score based algorithms like Hill Climbing and Tabu Search. 
+ +use std::collections::BTreeSet; + +use crate::structure_learning::score_function::ScoreFunction; +use crate::structure_learning::StructureLearningAlgorithm; +use crate::{process, tools::Dataset}; + +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rayon::prelude::ParallelExtend; + +/// Hill-climbing structure learning driven by a score function. +/// +/// `max_parent_set`, when set, caps the number of parents a node may receive. +pub struct HillClimbing { + score_function: S, + max_parent_set: Option, +} + +impl HillClimbing { + /// Builds the learner from a score function and an optional parent-set size limit. + pub fn new(score_function: S, max_parent_set: Option) -> HillClimbing { + HillClimbing { + score_function, + max_parent_set, + } + } +} + +impl StructureLearningAlgorithm for HillClimbing { + fn fit_transform(&self, net: T, dataset: &Dataset) -> T + where + T: process::NetworkProcess, + { + //Check the coherence between dataset and network + //NOTE(review): indexing `[0]` panics on a dataset without trajectories. + if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { + panic!("Dataset and Network must have the same number of variables.") + } + + //Make the network mutable. + let mut net = net; + //Check if the max_parent_set constraint is present. + let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); + //Reset the adj matrix + net.initialize_adj_matrix(); + let mut learned_parent_sets: Vec<(usize, BTreeSet)> = vec![]; + //Iterate over each node to learn their parent set. + learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| { + //Initialize an empty parent set. + let mut parent_set: BTreeSet = BTreeSet::new(); + //Compute the score for the empty parent set + let mut current_score = self.score_function.call(&net, node, &parent_set, dataset); + //Set the old score to -\infty. + let mut old_score = f64::NEG_INFINITY; + //Iterate until convergence + while current_score > old_score { + //Save the current_score. + old_score = current_score; + //Iterate over each node. + for parent in net.get_node_indices() { + //Continue if the parent and the node are the same. + if parent == node { + continue; + } + //Try to remove parent from the parent_set. 
+ let is_removed = parent_set.remove(&parent); + //If parent was not in the parent_set add it. + if !is_removed && parent_set.len() < max_parent_set { + parent_set.insert(parent); + } + //Compute the score with the modified parent_set. + let tmp_score = self.score_function.call(&net, node, &parent_set, dataset); + //If tmp_score is worse than current_score revert the change to the parent set + if tmp_score < current_score { + if is_removed { + parent_set.insert(parent); + } else { + parent_set.remove(&parent); + } + } + //Otherwise save the computed score as current_score + else { + current_score = tmp_score; + } + } + } + (node, parent_set) + })); + + //Write the learned parent sets back into the network as edges. + for (child_node, candidate_parent_set) in learned_parent_sets { + for parent_node in candidate_parent_set.iter() { + net.add_edge(*parent_node, child_node); + } + } + net + } +} diff --git a/reCTBN/src/structure_learning/score_function.rs b/reCTBN/src/structure_learning/score_function.rs new file mode 100644 index 0000000..5a56594 --- /dev/null +++ b/reCTBN/src/structure_learning/score_function.rs @@ -0,0 +1,146 @@ +//! Module for score based algorithms containing score functions algorithms like Log Likelihood, BIC, etc... 
+ +use std::collections::BTreeSet; + +use ndarray::prelude::*; +use statrs::function::gamma; + +use crate::{parameter_learning, params, process, tools}; + +/// Score assigned to a candidate parent set of a node, given a dataset. +pub trait ScoreFunction: Sync { + fn call( + &self, + net: &T, + node: usize, + parent_set: &BTreeSet, + dataset: &tools::Dataset, + ) -> f64 + where + T: process::NetworkProcess; +} + +/// Log-likelihood based score with hyperparameters `alpha` and `tau`. +pub struct LogLikelihood { + alpha: usize, + tau: f64, +} + +impl LogLikelihood { + /// Creates the score function. + /// + /// # Panics + /// + /// Panics if `tau` is negative. + pub fn new(alpha: usize, tau: f64) -> LogLikelihood { + //Tau must be >=0.0 + if tau < 0.0 { + panic!("tau must be >=0.0"); + } + LogLikelihood { alpha, tau } + } + + /// Returns the score together with the `M` statistics it was computed from. + fn compute_score( + &self, + net: &T, + node: usize, + parent_set: &BTreeSet, + dataset: &tools::Dataset, + ) -> (f64, Array3) + where + T: process::NetworkProcess, + { + //Identify the type of node used + match &net.get_node(node) { + params::Params::DiscreteStatesContinousTime(_params) => { + //Compute the sufficient statistics M (number of transitions) and T (residence + //time) + let (M, T) = + parameter_learning::sufficient_statistics(net, dataset, node, parent_set); + + //Scale alpha accordingly to the size of the parent set + let alpha = self.alpha as f64 / M.shape()[0] as f64; + //Scale tau accordingly to the size of the parent set + let tau = self.tau / M.shape()[0] as f64; + + //Compute the log likelihood for q + //NOTE(review): `tau == 0.0` passes the constructor check but makes `f64::ln(tau)` + //evaluate to -inf here - confirm whether 0.0 is meant to be accepted. + let log_ll_q: f64 = M + .sum_axis(Axis(2)) + .iter() + .zip(T.iter()) + .map(|(m, t)| { + gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau) + - gamma::ln_gamma(alpha + 1.0) + - (alpha + *m as f64 + 1.0) * f64::ln(tau + t) + }) + .sum(); + + //Compute the log likelihood for theta + let log_ll_theta: f64 = M + .outer_iter() + .map(|x| { + x.outer_iter() + .map(|y| { + gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64) + + y.iter() + .map(|z| { + gamma::ln_gamma(alpha + *z as f64) + - gamma::ln_gamma(alpha) + }) + .sum::() + }) + .sum::() + }) + .sum(); + (log_ll_theta + log_ll_q, M) + } + } + } +} + +impl ScoreFunction for 
LogLikelihood { + fn call( + &self, + net: &T, + node: usize, + parent_set: &BTreeSet, + dataset: &tools::Dataset, + ) -> f64 + where + T: process::NetworkProcess, + { + self.compute_score(net, node, parent_set, dataset).0 + } +} + +pub struct BIC { + ll: LogLikelihood, +} + +impl BIC { + pub fn new(alpha: usize, tau: f64) -> BIC { + BIC { + ll: LogLikelihood::new(alpha, tau), + } + } +} + +impl ScoreFunction for BIC { + fn call( + &self, + net: &T, + node: usize, + parent_set: &BTreeSet, + dataset: &tools::Dataset, + ) -> f64 + where + T: process::NetworkProcess, + { + //Compute the log-likelihood + let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); + //Compute the number of parameters + let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); + //TODO: Optimize this + //Compute the sample size + let sample_size: usize = dataset + .get_trajectories() + .iter() + .map(|x| x.get_time().len() - 1) + .sum(); + //Compute BIC + ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 + } +} diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs new file mode 100644 index 0000000..5085c43 --- /dev/null +++ b/reCTBN/src/tools.rs @@ -0,0 +1,355 @@ +//! Contains commonly used methods used across the crate. + +use std::ops::{DivAssign, MulAssign, Range}; + +use ndarray::{Array, Array1, Array2, Array3, Axis}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +use crate::params::ParamsTrait; +use crate::process::NetworkProcess; +use crate::sampling::{ForwardSampler, Sampler}; +use crate::{params, process}; + +#[derive(Clone)] +pub struct Trajectory { + time: Array1, + events: Array2, +} + +impl Trajectory { + pub fn new(time: Array1, events: Array2) -> Trajectory { + //Events and time are two part of the same trajectory. For this reason they must have the + //same number of sample. 
+ if time.shape()[0] != events.shape()[0] { + panic!("time.shape[0] must be equal to events.shape[0]"); + } + Trajectory { time, events } + } + + pub fn get_time(&self) -> &Array1 { + &self.time + } + + pub fn get_events(&self) -> &Array2 { + &self.events + } +} + +/// Collection of trajectories that all describe the same number of variables. +#[derive(Clone)] +pub struct Dataset { + trajectories: Vec, +} + +impl Dataset { + pub fn new(trajectories: Vec) -> Dataset { + //All the trajectories in the same dataset must represent the same process. For this reason + //each trajectory must represent the same number of variables. + if trajectories + .iter() + .any(|x| trajectories[0].get_events().shape()[1] != x.get_events().shape()[1]) + { + panic!("All the trajectories mus represents the same number of variables"); + } + Dataset { trajectories } + } + + pub fn get_trajectories(&self) -> &Vec { + &self.trajectories + } +} + +/// Samples `n_trajectories` trajectories from `net` with a `ForwardSampler`, each one closed +/// at time `t_end`. +/// +/// # Panics +/// +/// Panics if, for a trajectory, the sampler yields no sample with `t < t_end` +/// (`events.last().unwrap()` on an empty vector). +pub fn trajectory_generator( + net: &T, + n_trajectories: u64, + t_end: f64, + seed: Option, +) -> Dataset { + //Tmp growing vector containing generated trajectories. + let mut trajectories: Vec = Vec::new(); + + //Random Generator object + let mut sampler = ForwardSampler::new(net, seed, None); + //Each iteration generate one trajectory + for _ in 0..n_trajectories { + //History of all the moments in which something changed + let mut time: Vec = Vec::new(); + //Configuration of the process variables at time t initialized with an uniform + //distribution. + let mut events: Vec = Vec::new(); + + //Current Time and Current State + let mut sample = sampler.next().unwrap(); + //Generate new samples until ending time is reached. + while sample.t < t_end { + time.push(sample.t); + events.push(sample.state); + sample = sampler.next().unwrap(); + } + + //Repeat the last recorded state so that it is paired with t_end. + let current_state = events.last().unwrap().clone(); + events.push(current_state); + + //Add t_end as last time. + time.push(t_end); + + //Add the sampled trajectory to trajectories. 
+ trajectories.push(Trajectory::new( + Array::from_vec(time), + Array2::from_shape_vec( + (events.len(), events.last().unwrap().len()), + events + .iter() + .flatten() + .map(|x| match x { + params::StateType::Discrete(x) => x.clone(), + }) + .collect(), + ) + .unwrap(), + )); + sampler.reset(); + } + //Return a dataset object with the sampled trajectories. + Dataset::new(trajectories) +} + +/// Interface of the generators that build a random graph on a network. +/// +/// `new` receives the edge `density` and an optional `seed`; `generate_graph` writes the +/// generated structure into `net`. +pub trait RandomGraphGenerator { + fn new(density: f64, seed: Option) -> Self; + fn generate_graph(&mut self, net: &mut T); +} + +/// Graph Generator using an uniform distribution. +/// +/// A method to generate a random graph with edges uniformly distributed. +/// +/// # Arguments +/// +/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`. +/// * `rng` - is the random numbers generator. +/// +/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::params::Params::DiscreteStatesContinousTime; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// use reCTBN::tools::UniformGraphGenerator; +/// use reCTBN::tools::RandomGraphGenerator; +/// # let mut net = CtbnNetwork::new(); +/// # let nodes_cardinality = 8; +/// # let domain_cardinality = 4; +/// # for node in 0..nodes_cardinality { +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # for dvalue in 0..domain_cardinality { +/// # domain.insert(dvalue.to_string()); +/// # } +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new( +/// # node.to_string(), +/// # domain +/// # ); +/// # //Create the node using the parameters +/// # let node = DiscreteStatesContinousTime(param); +/// # // Add the node to the network +/// # net.add_node(node).unwrap(); +/// # } +/// +/// // Initialize 
the Graph Generator using the one with an +/// // uniform distribution +/// let density = 1.0/3.0; +/// let seed = Some(7641630759785120); +/// let mut structure_generator = UniformGraphGenerator::new( +/// density, +/// seed +/// ); +/// +/// // Generate the graph directly on the network +/// structure_generator.generate_graph(&mut net); +/// # // Count all the edges generated in the network +/// # let mut edges = 0; +/// # for node in net.get_node_indices(){ +/// # edges += net.get_children_set(node).len() +/// # } +/// # // Number of all the nodes in the network +/// # let nodes = net.get_node_indices().len() as f64; +/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; +/// # // ±10% of tolerance +/// # let tolerance = ((expected_edges as f64)*0.10) as usize; +/// # // As the way `generate_graph()` is implemented we can only reasonably +/// # // expect the number of edges to be somewhere around the expected value. +/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); +/// ``` +pub struct UniformGraphGenerator { + density: f64, + rng: ChaCha8Rng, +} + +impl RandomGraphGenerator for UniformGraphGenerator { + /// Creates the generator; a `None` seed initializes the RNG from entropy. + /// + /// # Panics + /// + /// Panics if `density` is not inside `0.0..=1.0` (NaN included). + fn new(density: f64, seed: Option) -> UniformGraphGenerator { + // Written as a range check so that NaN is rejected here as well, instead of + // reaching `gen_bool` later. + if !(0.0..=1.0).contains(&density) { + panic!( + "Density value must be between 1.0 and 0.0, got {}.", + density + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + UniformGraphGenerator { density, rng } + } + + /// Generate an uniformly distributed graph. 
+ fn generate_graph(&mut self, net: &mut T) { + net.initialize_adj_matrix(); + let last_node_idx = net.get_node_indices().len(); + for parent in 0..last_node_idx { + for child in 0..last_node_idx { + if parent != child { + if self.rng.gen_bool(self.density) { + net.add_edge(parent, child); + } + } + } + } + } +} + +pub trait RandomParametersGenerator { + fn new(interval: Range, seed: Option) -> Self; + fn generate_parameters(&mut self, net: &mut T); +} + +/// Parameters Generator using an uniform distribution. +/// +/// A method to generate random parameters uniformly distributed. +/// +/// # Arguments +/// +/// * `interval` - is the interval of the random values oh the CIM's diagonal; domain: `≥ 0.0`. +/// * `rng` - is the random numbers generator. +/// +/// # Example +/// +/// ```rust +/// # use std::collections::BTreeSet; +/// # use ndarray::{arr1, arr2, arr3}; +/// # use reCTBN::params; +/// # use reCTBN::params::ParamsTrait; +/// # use reCTBN::params::Params::DiscreteStatesContinousTime; +/// # use reCTBN::process::NetworkProcess; +/// # use reCTBN::process::ctbn::CtbnNetwork; +/// # use reCTBN::tools::trajectory_generator; +/// # use reCTBN::tools::RandomGraphGenerator; +/// # use reCTBN::tools::UniformGraphGenerator; +/// use reCTBN::tools::RandomParametersGenerator; +/// use reCTBN::tools::UniformParametersGenerator; +/// # let mut net = CtbnNetwork::new(); +/// # let nodes_cardinality = 8; +/// # let domain_cardinality = 4; +/// # for node in 0..nodes_cardinality { +/// # // Create the domain for a discrete node +/// # let mut domain = BTreeSet::new(); +/// # for dvalue in 0..domain_cardinality { +/// # domain.insert(dvalue.to_string()); +/// # } +/// # // Create the parameters for a discrete node using the domain +/// # let param = params::DiscreteStatesContinousTimeParams::new( +/// # node.to_string(), +/// # domain +/// # ); +/// # //Create the node using the parameters +/// # let node = DiscreteStatesContinousTime(param); +/// # // Add the node to the 
network +/// # net.add_node(node).unwrap(); +/// # } +/// # +/// # // Initialize the Graph Generator using the one with an +/// # // uniform distribution +/// # let mut structure_generator = UniformGraphGenerator::new( +/// # 1.0/3.0, +/// # Some(7641630759785120) +/// # ); +/// # +/// # // Generate the graph directly on the network +/// # structure_generator.generate_graph(&mut net); +/// +/// // Initialize the parameters generator with uniform distribution +/// let mut cim_generator = UniformParametersGenerator::new( +/// 0.0..7.0, +/// Some(7641630759785120) +/// ); +/// +/// // Generate CIMs with uniformly distributed parameters. +/// cim_generator.generate_parameters(&mut net); +/// # +/// # for node in net.get_node_indices() { +/// # assert_eq!( +/// # Ok(()), +/// # net.get_node(node).validate_params() +/// # ); +/// # } +/// ``` +pub struct UniformParametersGenerator { + interval: Range, + rng: ChaCha8Rng, +} + +impl RandomParametersGenerator for UniformParametersGenerator { + /// Creates the generator; a `None` seed initializes the RNG from entropy. + /// + /// # Panics + /// + /// Panics if either bound of `interval` is negative. + fn new(interval: Range, seed: Option) -> UniformParametersGenerator { + if interval.start < 0.0 || interval.end < 0.0 { + panic!( + "Interval must be entirely greater or equal than 0, got {}..{}.", + interval.start, interval.end + ); + } + let rng: ChaCha8Rng = match seed { + Some(seed) => SeedableRng::seed_from_u64(seed), + None => SeedableRng::from_entropy(), + }; + UniformParametersGenerator { interval, rng } + } + + /// Generate CIMs with uniformly distributed parameters. 
+ fn generate_parameters(&mut self, net: &mut T) { + for node in net.get_node_indices() { + // Number of matrices in the CIM: product of `get_reserved_space_as_parent()` + // over the parents of `node`. + let parent_set_state_space_cardinality: usize = net + .get_parent_set(node) + .iter() + .map(|x| net.get_node(*x).get_reserved_space_as_parent()) + .product(); + match &mut net.get_node_mut(node) { + params::Params::DiscreteStatesContinousTime(param) => { + let node_domain_cardinality = param.get_reserved_space_as_parent(); + // Start from random entries for every matrix of the CIM. + let mut cim = Array3::::from_shape_fn( + ( + parent_set_state_space_cardinality, + node_domain_cardinality, + node_domain_cardinality, + ), + |_| self.rng.gen(), + ); + cim.axis_iter_mut(Axis(0)).for_each(|mut x| { + // Zero the diagonal, then divide every row by its sum. + x.diag_mut().fill(0.0); + x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); + // One value per row drawn from `interval`, used to scale that row. + let diag = Array1::::from_shape_fn(node_domain_cardinality, |_| { + self.rng.gen_range(self.interval.clone()) + }); + // `diag` is not used afterwards, so it is moved into `insert_axis`. + x.mul_assign(&diag.insert_axis(Axis(1))); + // Recomputing the diagonal in order to reduce the issues caused by the + // loss of precision when validating the parameters. 
+ let diag_sum = -x.sum_axis(Axis(1)); + x.diag_mut().assign(&diag_sum) + }); + param.set_cim_unchecked(cim); + } + } + } + } +} diff --git a/reCTBN/tests/ctbn.rs b/reCTBN/tests/ctbn.rs new file mode 100644 index 0000000..3eb40d7 --- /dev/null +++ b/reCTBN/tests/ctbn.rs @@ -0,0 +1,376 @@ +mod utils; +use std::collections::BTreeSet; + + +use approx::AbsDiffEq; +use ndarray::arr3; +use reCTBN::params::{self, ParamsTrait}; +use reCTBN::process::NetworkProcess; +use reCTBN::process::{ctbn::*}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn define_simpe_ctbn() { + let _ = CtbnNetwork::new(); + assert!(true); +} + +#[test] +fn add_node_to_ctbn() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); +} + +#[test] +fn add_edge_to_ctbn() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + let cs = net.get_children_set(n1); + assert_eq!(&n2, cs.iter().next().unwrap()); +} + +#[test] +fn children_and_parents() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + let cs = net.get_children_set(n1); + assert_eq!(&n2, cs.iter().next().unwrap()); + let ps = net.get_parent_set(n2); + assert_eq!(&n1, ps.iter().next().unwrap()); +} + +#[test] +fn compute_index_ctbn() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + 
.unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + net.add_edge(n1, n2); + net.add_edge(n3, n2); + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(1), + params::StateType::Discrete(1), + ], + ); + assert_eq!(3, idx); + + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(1), + params::StateType::Discrete(1), + ], + ); + assert_eq!(2, idx); + + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(1), + params::StateType::Discrete(0), + ], + ); + assert_eq!(1, idx); +} + +#[test] +fn compute_index_from_custom_parent_set() { + let mut net = CtbnNetwork::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let _n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let _n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + let idx = net.get_param_index_from_custom_parent_set( + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(0), + params::StateType::Discrete(1), + ], + &BTreeSet::from([1]), + ); + assert_eq!(0, idx); + + let idx = net.get_param_index_from_custom_parent_set( + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(0), + params::StateType::Discrete(1), + ], + &BTreeSet::from([1, 2]), + ); + assert_eq!(2, idx); +} + +#[test] +fn simple_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + net.initialize_adj_matrix(); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); + } + } + + let ctmp = 
net.amalgamation(); + let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0); + let p_ctbn = p_ctbn.get_cim().as_ref().unwrap(); + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); + + assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); +} + +#[test] +fn chain_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + let ctmp = net.amalgamation(); + + + + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); + + let p_ctmp_handmade = arr3(&[[ + [ + -1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 
0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} + +#[test] +fn chainfork_amalgamation() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + let n4 = net + .add_node(generate_discrete_time_continous_node(String::from("n4"), 2)) + .unwrap(); + + net.add_edge(n1, n3); + net.add_edge(n2, n3); + net.add_edge(n3, n4); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + match &mut net.get_node_mut(n4) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]] + ])) + ); + } + } + + + let ctmp = net.amalgamation(); + + let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); + + let p_ctmp = 
p_ctmp.get_cim().as_ref().unwrap(); + + let p_ctmp_handmade = arr3(&[[ + [ + -2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + ], + [ + 
0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, + 0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, + 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, + 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, + 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, + ], + [ + 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, + 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00, + ], + ]]); + + assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); +} diff --git a/reCTBN/tests/ctmp.rs b/reCTBN/tests/ctmp.rs new file mode 100644 index 0000000..830bfe0 --- /dev/null +++ b/reCTBN/tests/ctmp.rs @@ -0,0 +1,127 @@ +mod utils; + +use std::collections::BTreeSet; + +use reCTBN::{ + params, + params::ParamsTrait, + process::{ctmp::*, NetworkProcess}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn define_simple_ctmp() { + let _ = CtmpProcess::new(); + assert!(true); +} + +#[test] +fn add_node_to_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); +} + +#[test] +fn add_two_nodes_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + match n2 { + Ok(_) => assert!(false), + Err(_) => assert!(true), + }; +} + +#[test] +#[should_panic] +fn 
add_edge_to_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); + + net.add_edge(0, 1) +} + +#[test] +fn childen_and_parents() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + assert_eq!(0, net.get_parent_set(0).len()); + assert_eq!(0, net.get_children_set(0).len()); +} + +#[test] +#[should_panic] +fn get_childen_panic() { + let net = CtmpProcess::new(); + net.get_children_set(0); +} + +#[test] +#[should_panic] +fn get_childen_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_children_set(1); +} + +#[test] +#[should_panic] +fn get_parent_panic() { + let net = CtmpProcess::new(); + net.get_parent_set(0); +} + +#[test] +#[should_panic] +fn get_parent_panic2() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + net.get_parent_set(1); +} + +#[test] +fn compute_index_ctmp() { + let mut net = CtmpProcess::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); + assert_eq!(6, idx); +} + +#[test] +#[should_panic] +fn compute_index_from_custom_parent_set_ctmp() { + let mut net = CtmpProcess::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node( + String::from("n1"), + 10, + )) + .unwrap(); + + let _idx = net.get_param_index_from_custom_parent_set( + &vec![params::StateType::Discrete(6)], + &BTreeSet::from([0]) + ); +} diff --git a/reCTBN/tests/parameter_learning.rs b/reCTBN/tests/parameter_learning.rs new file mode 100644 index 
0000000..0a09a2a --- /dev/null +++ b/reCTBN/tests/parameter_learning.rs @@ -0,0 +1,648 @@ +#![allow(non_snake_case)] + +mod utils; +use ndarray::arr3; +use reCTBN::process::ctbn::*; +use reCTBN::process::NetworkProcess; +use reCTBN::parameter_learning::*; +use reCTBN::params; +use reCTBN::params::Params::DiscreteStatesContinousTime; +use reCTBN::tools::*; +use utils::*; + +extern crate approx; +use crate::approx::AbsDiffEq; + +fn learn_binary_cim(pl: T) { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); + let p = match pl.fit(&net, &data, 1, None) { + params::Params::DiscreteStatesContinousTime(p) => p, + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( + &arr3(&[ + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], + ]), + 0.1 + )); +} + +fn generate_nodes( + net: &mut CtbnNetwork, + nodes_cardinality: usize, + nodes_domain_cardinality: usize +) { + for node_label in 0..nodes_cardinality { + net.add_node( + generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } +} + +fn learn_binary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 2); + + net.add_edge(0, 1); + + let mut 
cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + +#[test] +fn learn_binary_cim_MLE() { + let mle = MLE {}; + learn_binary_cim(mle); +} + +#[test] +fn learn_binary_cim_MLE_gen() { + let mle = MLE {}; + learn_binary_cim_gen(mle); +} + +#[test] +fn learn_binary_cim_BA() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_binary_cim(ba); +} + +#[test] +fn learn_binary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_binary_cim_gen(ba); +} + +fn learn_ternary_cim(pl: T) { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + let data = 
trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p = match pl.fit(&net, &data, 1, None) { + params::Params::DiscreteStatesContinousTime(p) => p, + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( + &arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ]), + 0.1 + )); +} + +fn learn_ternary_cim_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 4.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(1) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 1, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + +#[test] +fn learn_ternary_cim_MLE() { + let mle = MLE {}; + learn_ternary_cim(mle); +} + +#[test] +fn learn_ternary_cim_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_gen(mle); +} + +#[test] +fn learn_ternary_cim_BA() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim(ba); +} + +#[test] +fn learn_ternary_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_gen(ba); +} + +fn learn_ternary_cim_no_parents(pl: T) { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + 
.add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ] + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p = match pl.fit(&net, &data, 0, None) { + params::Params::DiscreteStatesContinousTime(p) => p, + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( + &arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ]), + 0.1 + )); +} + +fn learn_ternary_cim_no_parents_gen(pl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..6.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(0) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 0, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.1 + ) + ); +} + +#[test] +fn learn_ternary_cim_no_parents_MLE() { + let mle = MLE {}; + 
learn_ternary_cim_no_parents(mle); +} + +#[test] +fn learn_ternary_cim_no_parents_MLE_gen() { + let mle = MLE {}; + learn_ternary_cim_no_parents_gen(mle); +} + +#[test] +fn learn_ternary_cim_no_parents_BA() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_no_parents(ba); +} + +#[test] +fn learn_ternary_cim_no_parents_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_ternary_cim_no_parents_gen(ba); +} + +fn learn_mixed_discrete_cim(pl: T) { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 4)) + .unwrap(); + net.add_edge(n1, n2); + net.add_edge(n1, n3); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, -0.8] + ], + [ + [-2.0, 1.0, 0.7, 
0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); + let p = match pl.fit(&net, &data, 2, None) { + params::Params::DiscreteStatesContinousTime(p) => p, + }; + assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); + assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( + &arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, -0.8] + ], + [ + [-2.0, 1.0, 0.7, 0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ]), + 0.2 + )); +} + +fn learn_mixed_discrete_cim_gen(pl: T) { + 
let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 1.0..8.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let p_gen = match net.get_node(2) { + DiscreteStatesContinousTime(p_gen) => p_gen, + }; + + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); + let p_tj = match pl.fit(&net, &data, 2, None) { + DiscreteStatesContinousTime(p_tj) => p_tj, + }; + + assert_eq!( + p_tj.get_cim().as_ref().unwrap().shape(), + p_gen.get_cim().as_ref().unwrap().shape() + ); + assert!( + p_tj.get_cim().as_ref().unwrap().abs_diff_eq( + &p_gen.get_cim().as_ref().unwrap(), + 0.2 + ) + ); +} + +#[test] +fn learn_mixed_discrete_cim_MLE() { + let mle = MLE {}; + learn_mixed_discrete_cim(mle); +} + +#[test] +fn learn_mixed_discrete_cim_MLE_gen() { + let mle = MLE {}; + learn_mixed_discrete_cim_gen(mle); +} + +#[test] +fn learn_mixed_discrete_cim_BA() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_mixed_discrete_cim(ba); +} + +#[test] +fn learn_mixed_discrete_cim_BA_gen() { + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; + learn_mixed_discrete_cim_gen(ba); +} diff --git a/reCTBN/tests/params.rs b/reCTBN/tests/params.rs new file mode 100644 index 0000000..7f16f12 --- /dev/null +++ b/reCTBN/tests/params.rs @@ -0,0 +1,148 @@ +use ndarray::prelude::*; +use rand_chacha::rand_core::SeedableRng; +use rand_chacha::ChaCha8Rng; +use reCTBN::params::{ParamsTrait, *}; + +mod utils; + +#[macro_use] +extern crate approx; + +fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { + #![allow(unused_must_use)] + let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3); + + let cim = array![[[-3.0, 2.0, 
1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + + params.set_cim(cim); + params +} + +#[test] +fn test_get_label() { + let param = create_ternary_discrete_time_continous_param(); + assert_eq!(&String::from("A"), param.get_label()) +} + +#[test] +fn test_uniform_generation() { + #![allow(irrefutable_let_patterns)] + let param = create_ternary_discrete_time_continous_param(); + let mut states = Array1::::zeros(10000); + + let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); + + states.mapv_inplace(|_| { + if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) { + val + } else { + panic!() + } + }); + let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; + + assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01); +} + +#[test] +fn test_random_generation_state() { + #![allow(irrefutable_let_patterns)] + let param = create_ternary_discrete_time_continous_param(); + let mut states = Array1::::zeros(10000); + + let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); + + states.mapv_inplace(|_| { + if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() { + val + } else { + panic!() + } + }); + let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; + let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; + + assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01); + assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01); +} + +#[test] +fn test_random_generation_residence_time() { + let param = create_ternary_discrete_time_continous_param(); + let mut states = Array1::::zeros(10000); + + let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); + + states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap()); + + assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); +} + +#[test] +fn test_validate_params_valid_cim() { + let param = create_ternary_discrete_time_continous_param(); + + assert_eq!(Ok(()), 
param.validate_params()); +} + +#[test] +fn test_validate_params_valid_cim_with_huge_values() { + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); + let cim = array![[ + [-2e10, 1e10, 1e10], + [1.5e10, -3e10, 1.5e10], + [1e10, 1e10, -2e10] + ]]; + let result = param.set_cim(cim); + assert_eq!(Ok(()), result); +} + +#[test] +fn test_validate_params_cim_not_initialized() { + let param = utils::generate_discrete_time_continous_params("A".to_string(), 3); + assert_eq!( + Err(ParamsError::ParametersNotInitialized(String::from( + "CIM not initialized", + ))), + param.validate_params() + ); +} + +#[test] +fn test_validate_params_wrong_shape() { + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 4); + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + let result = param.set_cim(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "Incompatible shape [1, 3, 3] with domain 4" + ))), + result + ); +} + +#[test] +fn test_validate_params_positive_diag() { + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); + let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; + let result = param.set_cim(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "The diagonal of each cim must be non-positive", + ))), + result + ); +} + +#[test] +fn test_validate_params_row_not_sum_to_zero() { + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); + let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; + let result = param.set_cim(cim); + assert_eq!( + Err(ParamsError::InvalidCIM(String::from( + "The sum of each row must be 0" + ))), + result + ); +} diff --git a/reCTBN/tests/reward_evaluation.rs b/reCTBN/tests/reward_evaluation.rs new file mode 100644 index 0000000..355341c --- /dev/null +++ b/reCTBN/tests/reward_evaluation.rs @@ -0,0 +1,122 @@ +mod utils; + +use 
approx::assert_abs_diff_eq; +use ndarray::*; +use reCTBN::{ + params, + process::{ctbn::*, NetworkProcess, NetworkProcessState}, + reward::{reward_evaluation::*, reward_function::*, *}, +}; +use utils::generate_discrete_time_continous_node; + +#[test] +fn simple_factored_reward_function_binary_node_mc() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1) + .assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1) + .assign(&arr1(&[3.0, 3.0])); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); + } + } + + net.initialize_adj_matrix(); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + + let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); + assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); + assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); + + let rst = mc.evaluate_state_space(&net, &rf); + assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); + assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); + + + let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215)); + assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); + assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); + + +} + +#[test] +fn simple_factored_reward_function_chain_mc() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let n2 = net + 
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + net.add_edge(n1, n2); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + param + .set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]], + ])) + .unwrap(); + } + } + + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + param + .set_cim(arr3(&[ + [[-0.01, 0.01], [5.0, -5.0]], + [[-5.0, 5.0], [0.01, -0.01]], + ])) + .unwrap(); + } + } + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + rf.get_transition_reward_mut(n2) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + rf.get_transition_reward_mut(n3) + .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); + + let s000: NetworkProcessState = vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(0), + params::StateType::Discrete(0), + ]; + + let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); + assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); + + let rst = mc.evaluate_state_space(&net, &rf); + assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1); + +} diff --git a/reCTBN/tests/reward_function.rs b/reCTBN/tests/reward_function.rs new file mode 100644 index 0000000..853efc9 --- /dev/null +++ b/reCTBN/tests/reward_function.rs @@ -0,0 +1,117 @@ +mod utils; + +use ndarray::*; +use utils::generate_discrete_time_continous_node; +use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, 
reward_function::*}, params}; + + +#[test] +fn simple_factored_reward_function_binary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + + assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); +} + + +#[test] +fn simple_factored_reward_function_ternary_node() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; + let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; + let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; + + + assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); + assert_eq!(rf.call(&s0, Some(&s2)), 
Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); + + + assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); + assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); + + + assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); + assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); +} + +#[test] +fn factored_reward_function_two_nodes() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + net.add_edge(n1, n2); + + + let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); + rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); + rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); + + + rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); + rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); + let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; + let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; + let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; + + + let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; + let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; + let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; + + assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 
6.0}); + assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); + + + assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); + assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); + + + assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); + assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); + + + assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); + assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); + + + assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); + assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); +} diff --git a/reCTBN/tests/structure_learning.rs b/reCTBN/tests/structure_learning.rs new file mode 100644 index 0000000..3d7e230 --- /dev/null +++ b/reCTBN/tests/structure_learning.rs @@ -0,0 +1,692 @@ +#![allow(non_snake_case)] + +mod utils; +use std::collections::BTreeSet; + +use ndarray::{arr1, arr2, arr3}; +use reCTBN::process::ctbn::*; +use 
reCTBN::process::NetworkProcess; +use reCTBN::parameter_learning::BayesianApproach; +use reCTBN::params; +use reCTBN::structure_learning::hypothesis_test::*; +use reCTBN::structure_learning::constraint_based_algorithm::*; +use reCTBN::structure_learning::score_based_algorithm::*; +use reCTBN::structure_learning::score_function::*; +use reCTBN::structure_learning::StructureLearningAlgorithm; +use reCTBN::tools::*; +use utils::*; + +#[macro_use] +extern crate approx; + +#[test] +fn simple_score_test() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]])); + + let dataset = Dataset::new(vec![trj]); + + let ll = LogLikelihood::new(1, 1.0); + + assert_abs_diff_eq!( + 0.04257, + ll.call(&net, n1, &BTreeSet::new(), &dataset), + epsilon = 1e-3 + ); +} + +#[test] +fn simple_bic() { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + + let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]])); + + let dataset = Dataset::new(vec![trj]); + let bic = BIC::new(1, 1.0); + + assert_abs_diff_eq!( + -0.65058, + bic.call(&net, n1, &BTreeSet::new(), &dataset), + epsilon = 1e-3 + ); +} + +fn check_compatibility_between_dataset_and_network(sl: T) { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) 
=> { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259)); + + let mut net = CtbnNetwork::new(); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let _net = sl.fit_transform(net, &data); +} + +fn generate_nodes( + net: &mut CtbnNetwork, + nodes_cardinality: usize, + nodes_domain_cardinality: usize +) { + for node_label in 0..nodes_cardinality { + net.add_node( + generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } +} + +fn check_compatibility_between_dataset_and_network_gen(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259)); + + let mut net = CtbnNetwork::new(); + let _n1 = net + .add_node( + generate_discrete_time_continous_node(String::from("0"), + 3) + ).unwrap(); + let _net = sl.fit_transform(net, &data); +} + +#[test] +#[should_panic] +pub fn check_compatibility_between_dataset_and_network_hill_climbing() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + check_compatibility_between_dataset_and_network(hl); +} + +#[test] +#[should_panic] +pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + 
check_compatibility_between_dataset_and_network_gen(hl); +} + +fn learn_ternary_net_2_nodes(sl: T) { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); + assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); +} + +fn learn_ternary_net_2_nodes_gen(sl: T) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + + net.add_edge(0, 1); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); + + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); +} + +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_ternary_net_2_nodes(hl); +} + +#[test] +pub fn 
learn_ternary_net_2_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_ternary_net_2_nodes_gen(hl); +} + +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_ternary_net_2_nodes(hl); +} + +#[test] +pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_ternary_net_2_nodes_gen(hl); +} + +fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) + .unwrap(); + + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 4)) + .unwrap(); + net.add_edge(n1, n2); + net.add_edge(n1, n3); + net.add_edge(n2, n3); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.5], + [3.0, -4.0, 1.0], + [0.9, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 4.0], + [1.5, -2.0, 0.5], + [3.0, 1.0, -4.0] + ], + [ + [-1.0, 0.1, 0.9], + [2.0, -2.5, 0.5], + [0.9, 0.1, -1.0] + ], + ])) + ); + } + } + + match &mut net.get_node_mut(n3) { + params::Params::DiscreteStatesContinousTime(param) => { + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 
0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, -0.8] + ], + [ + [-2.0, 1.0, 0.7, 0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ])) + ); + } + } + + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); + return (net, data); +} + +fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) { + let mut net = CtbnNetwork::new(); + generate_nodes(&mut net, 2, 3); + net.add_node( + generate_discrete_time_continous_node( + String::from("3"), + 4 + ) + ).unwrap(); + + net.add_edge(0, 1); + net.add_edge(0, 2); + net.add_edge(1, 2); + + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + 0.0..7.0, + Some(6813071588535822) + ); + cim_generator.generate_parameters(&mut net); + + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); + return (net, data); +} + +fn learn_mixed_discrete_net_3_nodes(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +} + +fn learn_mixed_discrete_net_3_nodes_gen(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), 
net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); +} + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_mixed_discrete_net_3_nodes(hl); +} + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_mixed_discrete_net_3_nodes(hl); +} + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, None); + learn_mixed_discrete_net_3_nodes_gen(hl); +} + +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); +} + +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(sl: T) { + let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); + let net = sl.fit_transform(net, &data); + assert_eq!(BTreeSet::new(), net.get_parent_set(0)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); + assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); +} + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); +} + +#[test] +pub fn 
learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() { + let ll = LogLikelihood::new(1, 1.0); + let hl = HillClimbing::new(ll, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); +} + +#[test] +pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() { + let bic = BIC::new(1, 1.0); + let hl = HillClimbing::new(bic, Some(1)); + learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); +} + +#[test] +pub fn chi_square_compare_matrices() { + let i: usize = 1; + let M1 = arr3(&[ + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 12, 90], + [ 3, 0, 40], + [ 6, 40, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], + ]); + let j: usize = 0; + let M2 = arr3(&[ + [ + [ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0] + ], + ]); + let chi_sq = ChiSquare::new(1e-4); + assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); +} + +#[test] +pub fn chi_square_compare_matrices_2() { + let i: usize = 1; + let M1 = arr3(&[ + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 20, 30], + [ 40, 0, 60], + [ 70, 80, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], + ]); + let j: usize = 0; + let M2 = arr3(&[ + [[ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0]] + ]); + let chi_sq = ChiSquare::new(1e-4); + assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); +} + +#[test] +pub fn chi_square_compare_matrices_3() { + let i: usize = 1; + let M1 = arr3(&[ + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 7, 8, 0] + ], + [ + [0, 21, 31], + [ 41, 0, 59], + [ 71, 79, 0] + ], + [ + [ 0, 2, 3], + [ 4, 0, 6], + [ 44, 66, 0] + ], + ]); + let j: usize = 0; + let M2 = arr3(&[ + [ + [ 0, 200, 300], + [ 400, 0, 600], + [ 700, 800, 0] + ], + ]); + let chi_sq = ChiSquare::new(1e-4); + 
assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); +} + + +#[test] +pub fn chi_square_call() { + + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let N3: usize = 2; + let N2: usize = 1; + let N1: usize = 0; + let mut separation_set = BTreeSet::new(); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let mut cache = Cache::new(¶meter_learning); + let chi_sq = ChiSquare::new(1e-4); + + assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache)); + let mut cache = Cache::new(¶meter_learning); + assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache)); + assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache)); + separation_set.insert(N1); + let mut cache = Cache::new(¶meter_learning); + assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache)); +} + +#[test] +pub fn f_call() { + + let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); + let N3: usize = 2; + let N2: usize = 1; + let N1: usize = 0; + let mut separation_set = BTreeSet::new(); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let mut cache = Cache::new(¶meter_learning); + let f = F::new(1e-6); + + + assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache)); + let mut cache = Cache::new(¶meter_learning); + assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache)); + assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache)); + separation_set.insert(N1); + let mut cache = Cache::new(¶meter_learning); + assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache)); +} + +#[test] +pub fn learn_ternary_net_2_nodes_ctpc() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes(ctpc); +} + +#[test] +pub fn learn_ternary_net_2_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let 
parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_ternary_net_2_nodes_gen(ctpc); +} + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_mixed_discrete_net_3_nodes(ctpc); +} + +#[test] +fn learn_mixed_discrete_net_3_nodes_ctpc_gen() { + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + learn_mixed_discrete_net_3_nodes_gen(ctpc); +} diff --git a/reCTBN/tests/tools.rs b/reCTBN/tests/tools.rs new file mode 100644 index 0000000..59d8f27 --- /dev/null +++ b/reCTBN/tests/tools.rs @@ -0,0 +1,251 @@ +use std::ops::Range; + +use ndarray::{arr1, arr2, arr3}; +use reCTBN::params::ParamsTrait; +use reCTBN::process::ctbn::*; +use reCTBN::process::ctmp::*; +use reCTBN::process::NetworkProcess; +use reCTBN::params; +use reCTBN::tools::*; + +use utils::*; + +#[macro_use] +extern crate approx; + +mod utils; + +#[test] +fn run_sampling() { + #![allow(unused_must_use)] + let mut net = CtbnNetwork::new(); + let n1 = net + .add_node(utils::generate_discrete_time_continous_node( + String::from("n1"), + 2, + )) + .unwrap(); + let n2 = net + .add_node(utils::generate_discrete_time_continous_node( + String::from("n2"), + 2, + )) + .unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[ + [ + [-3.0, 3.0], + [2.0, -2.0] + ], + ])); + } + } + + match &mut net.get_node_mut(n2) { + params::Params::DiscreteStatesContinousTime(param) => { + param.set_cim(arr3(&[ + [ + [-1.0, 1.0], + [4.0, -4.0] + ], + [ + [-6.0, 6.0], + [2.0, -2.0] + ], + ])); + } + } + + let data = trajectory_generator(&net, 4, 1.0, 
Some(6347747169756259)); + + assert_eq!(4, data.get_trajectories().len()); + assert_relative_eq!( + 1.0, + data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len() - 1] + ); +} + +#[test] +#[should_panic] +fn trajectory_wrong_shape() { + let time = arr1(&[0.0, 0.2]); + let events = arr2(&[[0, 3]]); + Trajectory::new(time, events); +} + +#[test] +#[should_panic] +fn dataset_wrong_shape() { + let time = arr1(&[0.0, 0.2]); + let events = arr2(&[[0, 3], [1, 2]]); + let t1 = Trajectory::new(time, events); + + let time = arr1(&[0.0, 0.2]); + let events = arr2(&[[0, 3, 3], [1, 2, 3]]); + let t2 = Trajectory::new(time, events); + Dataset::new(vec![t1, t2]); +} + +#[test] +#[should_panic] +fn uniform_graph_generator_wrong_density_1() { + let density = 2.1; + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); +} + +#[test] +#[should_panic] +fn uniform_graph_generator_wrong_density_2() { + let density = -0.5; + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); +} + +#[test] +fn uniform_graph_generator_right_densities() { + for density in [1.0, 0.75, 0.5, 0.25, 0.0] { + let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + None + ); + } +} + +#[test] +fn uniform_graph_generator_generate_graph_ctbn() { + let mut net = CtbnNetwork::new(); + let nodes_cardinality = 0..=100; + let nodes_domain_cardinality = 2; + for node_label in nodes_cardinality { + net.add_node( + utils::generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + Some(7641630759785120) + ); + structure_generator.generate_graph(&mut net); + let mut edges = 0; + for node in net.get_node_indices(){ + edges += net.get_children_set(node).len() + } + let nodes = 
net.get_node_indices().len() as f64; + let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; + let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance + // As the way `generate_graph()` is implemented we can only reasonably + // expect the number of edges to be somewhere around the expected value. + assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); +} + +#[test] +#[should_panic] +fn uniform_graph_generator_generate_graph_ctmp() { + let mut net = CtmpProcess::new(); + let node_label = String::from("0"); + let node_domain_cardinality = 4; + net.add_node( + generate_discrete_time_continous_node( + node_label, + node_domain_cardinality + ) + ).unwrap(); + let density = 1.0/3.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + Some(7641630759785120) + ); + structure_generator.generate_graph(&mut net); +} + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_1() { + let interval: Range = -2.0..-5.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + None + ); +} + +#[test] +#[should_panic] +fn uniform_parameters_generator_wrong_density_2() { + let interval: Range = -1.0..0.0; + let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + None + ); +} + +#[test] +fn uniform_parameters_generator_right_densities_ctbn() { + let mut net = CtbnNetwork::new(); + let nodes_cardinality = 0..=3; + let nodes_domain_cardinality = 9; + for node_label in nodes_cardinality { + net.add_node( + generate_discrete_time_continous_node( + node_label.to_string(), + nodes_domain_cardinality, + ) + ).unwrap(); + } + let density = 1.0/3.0; + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( + density, + seed + ); + structure_generator.generate_graph(&mut net); + let mut 
cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + seed + ); + cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + assert_eq!( + Ok(()), + net.get_node(node).validate_params() + ); + } +} + +#[test] +fn uniform_parameters_generator_right_densities_ctmp() { + let mut net = CtmpProcess::new(); + let node_label = String::from("0"); + let node_domain_cardinality = 4; + net.add_node( + generate_discrete_time_continous_node( + node_label, + node_domain_cardinality + ) + ).unwrap(); + let seed = Some(7641630759785120); + let interval = 0.0..7.0; + let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( + interval, + seed + ); + cim_generator.generate_parameters(&mut net); + for node in net.get_node_indices() { + assert_eq!( + Ok(()), + net.get_node(node).validate_params() + ); + } +} diff --git a/reCTBN/tests/utils.rs b/reCTBN/tests/utils.rs new file mode 100644 index 0000000..ed43215 --- /dev/null +++ b/reCTBN/tests/utils.rs @@ -0,0 +1,19 @@ +use std::collections::BTreeSet; + +use reCTBN::params; + +#[allow(dead_code)] +pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params { + params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params( + label, + cardinality, + )) +} + +pub fn generate_discrete_time_continous_params( + label: String, + cardinality: usize, +) -> params::DiscreteStatesContinousTimeParams { + let domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); + params::DiscreteStatesContinousTimeParams::new(label, domain) +} diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..367bc0b --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,7 @@ +# This file defines the Rust toolchain to use when a command is executed. 
+# See also https://rust-lang.github.io/rustup/overrides.html + +[toolchain] +channel = "stable" +components = [ "clippy", "rustfmt" ] +profile = "minimal" diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..b6f1257 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,39 @@ +# This file defines the Rust style for automatic reformatting. +# See also https://rust-lang.github.io/rustfmt + +# NOTE: the unstable options will be uncommented when stabilized. + +# Version of the formatting rules to use. +#version = "One" + +# Number of spaces per tab. +tab_spaces = 4 + +max_width = 100 +#comment_width = 80 + +# Prevent carriage returns, admitted only \n. +newline_style = "Unix" + +# The "Default" setting has a heuristic which can split lines too aggressively. +#use_small_heuristics = "Max" + +# How imports should be grouped into `use` statements. +#imports_granularity = "Module" + +# How consecutive imports are grouped together. +#group_imports = "StdExternalCrate" + +# Error if unable to get all lines within max_width, except for comments and +# string literals. +#error_on_line_overflow = true + +# Error if unable to get comments or string literals within max_width, or they +# are left with trailing whitespaces. +#error_on_unformatted = true + +# Files to ignore like third party code which is formatted upstream. +# Ignoring tests is a temporary measure due to some issues regarding rank-3 tensors +ignore = [ + "tests/" +] diff --git a/src/ctbn.rs b/src/ctbn.rs deleted file mode 100644 index 9cabe20..0000000 --- a/src/ctbn.rs +++ /dev/null @@ -1,164 +0,0 @@ -use ndarray::prelude::*; -use crate::node; -use crate::params::{StateType, ParamsTrait}; -use crate::network; -use std::collections::BTreeSet; - - - - -///CTBN network. It represents both the structure and the parameters of a CTBN.
CtbnNetwork is -///composed by the following elements: -///- **adj_metrix**: a 2d ndarray representing the adjacency matrix -///- **nodes**: a vector containing all the nodes and their parameters. -///The index of a node inside the vector is also used as index for the adj_matrix. -/// -///# Examples -/// -///``` -/// -/// use std::collections::BTreeSet; -/// use rustyCTBN::network::Network; -/// use rustyCTBN::node; -/// use rustyCTBN::params; -/// use rustyCTBN::ctbn::*; -/// -/// //Create the domain for a discrete node -/// let mut domain = BTreeSet::new(); -/// domain.insert(String::from("A")); -/// domain.insert(String::from("B")); -/// -/// //Create the parameters for a discrete node using the domain -/// let param = params::DiscreteStatesContinousTimeParams::init(domain); -/// -/// //Create the node using the parameters -/// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1")); -/// -/// let mut domain = BTreeSet::new(); -/// domain.insert(String::from("A")); -/// domain.insert(String::from("B")); -/// let param = params::DiscreteStatesContinousTimeParams::init(domain); -/// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2")); -/// -/// //Initialize a ctbn -/// let mut net = CtbnNetwork::init(); -/// -/// //Add nodes -/// let X1 = net.add_node(X1).unwrap(); -/// let X2 = net.add_node(X2).unwrap(); -/// -/// //Add an edge -/// net.add_edge(X1, X2); -/// -/// //Get all the children of node X1 -/// let cs = net.get_children_set(X1); -/// assert_eq!(&X2, cs.iter().next().unwrap()); -/// ``` -pub struct CtbnNetwork { - adj_matrix: Option>, - nodes: Vec -} - - -impl CtbnNetwork { - pub fn init() -> CtbnNetwork { - CtbnNetwork { - adj_matrix: None, - nodes: Vec::new() - } - } -} - -impl network::Network for CtbnNetwork { - fn initialize_adj_matrix(&mut self) { - self.adj_matrix = Some(Array2::::zeros((self.nodes.len(), self.nodes.len()).f())); - - } - - fn add_node(&mut self, 
mut n: node::Node) -> Result { - n.params.reset_params(); - self.adj_matrix = Option::None; - self.nodes.push(n); - Ok(self.nodes.len() -1) - } - - fn add_edge(&mut self, parent: usize, child: usize) { - if let None = self.adj_matrix { - self.initialize_adj_matrix(); - } - - if let Some(network) = &mut self.adj_matrix { - network[[parent, child]] = 1; - self.nodes[child].params.reset_params(); - } - } - - fn get_node_indices(&self) -> std::ops::Range{ - 0..self.nodes.len() - } - - fn get_number_of_nodes(&self) -> usize { - self.nodes.len() - } - - fn get_node(&self, node_idx: usize) -> &node::Node{ - &self.nodes[node_idx] - } - - - fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{ - &mut self.nodes[node_idx] - } - - - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize{ - self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { - if x.1 > &0 { - acc.0 += self.nodes[x.0].params.state_to_index(¤t_state[x.0]) * acc.1; - acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent(); - } - acc - }).0 - } - - - fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize { - parent_set.iter().fold((0, 1), |mut acc, x| { - acc.0 += self.nodes[*x].params.state_to_index(¤t_state[*x]) * acc.1; - acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent(); - acc - }).0 - } - - fn get_parent_set(&self, node: usize) -> BTreeSet { - self.adj_matrix.as_ref() - .unwrap() - .column(node) - .iter() - .enumerate() - .filter_map(|(idx, x)| { - if x > &0 { - Some(idx) - } else { - None - } - }).collect() - } - - fn get_children_set(&self, node: usize) -> BTreeSet{ - self.adj_matrix.as_ref() - .unwrap() - .row(node) - .iter() - .enumerate() - .filter_map(|(idx, x)| { - if x > &0 { - Some(idx) - } else { - None - } - }).collect() - } - -} - diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index 65e4b11..0000000 --- a/src/lib.rs +++ /dev/null @@ -1,11 +0,0 
@@ -#[cfg(test)] -#[macro_use] -extern crate approx; - -pub mod node; -pub mod params; -pub mod network; -pub mod ctbn; -pub mod tools; -pub mod parameter_learning; - diff --git a/src/network.rs b/src/network.rs deleted file mode 100644 index 3b6ce06..0000000 --- a/src/network.rs +++ /dev/null @@ -1,39 +0,0 @@ -use thiserror::Error; -use crate::params; -use crate::node; -use std::collections::BTreeSet; - -/// Error types for trait Network -#[derive(Error, Debug)] -pub enum NetworkError { - #[error("Error during node insertion")] - NodeInsertionError(String) -} - - -///Network -///The Network trait define the required methods for a structure used as pgm (such as ctbn). -pub trait Network { - fn initialize_adj_matrix(&mut self); - fn add_node(&mut self, n: node::Node) -> Result; - fn add_edge(&mut self, parent: usize, child: usize); - - ///Get all the indices of the nodes contained inside the network - fn get_node_indices(&self) -> std::ops::Range; - fn get_number_of_nodes(&self) -> usize; - fn get_node(&self, node_idx: usize) -> &node::Node; - fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node; - - ///Compute the index that must be used to access the parameters of a node given a specific - ///configuration of the network. Usually, the only values really used in *current_state* are - ///the ones in the parent set of the *node*. - fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize; - - - ///Compute the index that must be used to access the parameters of a node given a specific - ///configuration of the network and a generic parent_set. Usually, the only values really used - ///in *current_state* are the ones in the parent set of the *node*. 
- fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize; - fn get_parent_set(&self, node: usize) -> BTreeSet; - fn get_children_set(&self, node: usize) -> BTreeSet; -} diff --git a/src/node.rs b/src/node.rs deleted file mode 100644 index 7ed21ba..0000000 --- a/src/node.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::params::*; - - -pub struct Node { - pub params: Params, - pub label: String -} - -impl Node { - pub fn init(params: Params, label: String) -> Node { - Node{ - params: params, - label:label - } - } - -} - -impl PartialEq for Node { - fn eq(&self, other: &Node) -> bool{ - self.label == other.label - } -} - - diff --git a/src/params.rs b/src/params.rs deleted file mode 100644 index c5a9acf..0000000 --- a/src/params.rs +++ /dev/null @@ -1,161 +0,0 @@ -use ndarray::prelude::*; -use rand::Rng; -use std::collections::{BTreeSet, HashMap}; -use thiserror::Error; -use enum_dispatch::enum_dispatch; - -/// Error types for trait Params -#[derive(Error, Debug)] -pub enum ParamsError { - #[error("Unsupported method")] - UnsupportedMethod(String), - #[error("Paramiters not initialized")] - ParametersNotInitialized(String), -} - -/// Allowed type of states -#[derive(Clone)] -pub enum StateType { - Discrete(usize), -} - -/// Parameters -/// The Params trait is the core element for building different types of nodes. The goal is to -/// define the set of method required to describes a generic node. -#[enum_dispatch(Params)] -pub trait ParamsTrait { - fn reset_params(&mut self); - - /// Randomly generate a possible state of the node disregarding the state of the node and it's - /// parents. - fn get_random_state_uniform(&self) -> StateType; - - /// Randomly generate a residence time for the given node taking into account the node state - /// and its parent set. 
- fn get_random_residence_time(&self, state: usize, u: usize) -> Result; - - /// Randomly generate a possible state for the given node taking into account the node state - /// and its parent set. - fn get_random_state(&self, state: usize, u: usize) -> Result; - - /// Used by childern of the node described by this parameters to reserve spaces in their CIMs. - fn get_reserved_space_as_parent(&self) -> usize; - - /// Index used by discrete node to represents their states as usize. - fn state_to_index(&self, state: &StateType) -> usize; -} - -/// The Params enum is the core element for building different types of nodes. The goal is to -/// define all the supported type of parameters. -#[enum_dispatch] -pub enum Params { - DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), -} - - -/// DiscreteStatesContinousTime. -/// This represents the parameters of a classical discrete node for ctbn and it's composed by the -/// following elements: -/// - **domain**: an ordered and exhaustive set of possible states -/// - **cim**: Conditional Intensity Matrix -/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter -/// learning task and are composed by: -/// - **transitions**: number of transitions from one state to another given a specific -/// realization of the parent set -/// - **residence_time**: permanence time in each possible states given a specific -/// realization of the parent set -pub struct DiscreteStatesContinousTimeParams { - pub domain: BTreeSet, - pub cim: Option>, - pub transitions: Option>, - pub residence_time: Option>, -} - -impl DiscreteStatesContinousTimeParams { - pub fn init(domain: BTreeSet) -> DiscreteStatesContinousTimeParams { - DiscreteStatesContinousTimeParams { - domain: domain, - cim: Option::None, - transitions: Option::None, - residence_time: Option::None, - } - } -} - -impl ParamsTrait for DiscreteStatesContinousTimeParams { - fn reset_params(&mut self) { - self.cim = Option::None; - 
self.transitions = Option::None; - self.residence_time = Option::None; - } - - fn get_random_state_uniform(&self) -> StateType { - let mut rng = rand::thread_rng(); - StateType::Discrete(rng.gen_range(0..(self.domain.len()))) - } - - fn get_random_residence_time(&self, state: usize, u: usize) -> Result { - // Generate a random residence time given the current state of the node and its parent set. - // The method used is described in: - // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates - match &self.cim { - Option::Some(cim) => { - let mut rng = rand::thread_rng(); - let lambda = cim[[u, state, state]] * -1.0; - let x: f64 = rng.gen_range(0.0..=1.0); - Ok(-x.ln() / lambda) - } - Option::None => Err(ParamsError::ParametersNotInitialized(String::from( - "CIM not initialized", - ))), - } - } - - fn get_random_state(&self, state: usize, u: usize) -> Result { - // Generate a random transition given the current state of the node and its parent set. - // The method used is described in: - // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution - match &self.cim { - Option::Some(cim) => { - let mut rng = rand::thread_rng(); - let lambda = cim[[u, state, state]] * -1.0; - let urand: f64 = rng.gen_range(0.0..=1.0); - - let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold( - (0, 0.0), - |mut acc, ele| { - if &acc.1 + ele < urand && ele > &0.0 { - acc.0 += 1; - } - if ele > &0.0 { - acc.1 += ele; - } - acc - }, - ); - - let next_state = if next_state.0 < state { - next_state.0 - } else { - next_state.0 + 1 - }; - - Ok(StateType::Discrete(next_state)) - } - Option::None => Err(ParamsError::ParametersNotInitialized(String::from( - "CIM not initialized", - ))), - } - } - - fn get_reserved_space_as_parent(&self) -> usize { - self.domain.len() - } - - fn state_to_index(&self, state: &StateType) -> usize { - match state { - StateType::Discrete(val) => val.clone() as usize, - } - 
} -} - diff --git a/src/tools.rs b/src/tools.rs deleted file mode 100644 index 27438f9..0000000 --- a/src/tools.rs +++ /dev/null @@ -1,119 +0,0 @@ -use crate::network; -use crate::node; -use crate::params; -use crate::params::ParamsTrait; -use ndarray::prelude::*; - -pub struct Trajectory { - pub time: Array1, - pub events: Array2, -} - -pub struct Dataset { - pub trajectories: Vec, -} - -pub fn trajectory_generator( - net: &T, - n_trajectories: u64, - t_end: f64, -) -> Dataset { - let mut dataset = Dataset { - trajectories: Vec::new(), - }; - - let node_idx: Vec<_> = net.get_node_indices().collect(); - for _ in 0..n_trajectories { - let mut t = 0.0; - let mut time: Vec = Vec::new(); - let mut events: Vec> = Vec::new(); - let mut current_state: Vec = node_idx - .iter() - .map(|x| net.get_node(*x).params.get_random_state_uniform()) - .collect(); - let mut next_transitions: Vec> = - (0..node_idx.len()).map(|_| Option::None).collect(); - events.push( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - ); - time.push(t.clone()); - while t < t_end { - for (idx, val) in next_transitions.iter_mut().enumerate() { - if let None = val { - *val = Some( - net.get_node(idx) - .params - .get_random_residence_time( - net.get_node(idx).params.state_to_index(¤t_state[idx]), - net.get_param_index_network(idx, ¤t_state), - ) - .unwrap() - + t, - ); - } - } - - let next_node_transition = next_transitions - .iter() - .enumerate() - .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) - .unwrap() - .0; - if next_transitions[next_node_transition].unwrap() > t_end { - break; - } - t = next_transitions[next_node_transition].unwrap().clone(); - time.push(t.clone()); - - current_state[next_node_transition] = net - .get_node(next_node_transition) - .params - .get_random_state( - net.get_node(next_node_transition) - .params - .state_to_index(¤t_state[next_node_transition]), - 
net.get_param_index_network(next_node_transition, ¤t_state), - ) - .unwrap(); - - events.push(Array::from_vec( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - )); - next_transitions[next_node_transition] = None; - - for child in net.get_children_set(next_node_transition) { - next_transitions[child] = None - } - } - - events.push( - current_state - .iter() - .map(|x| match x { - params::StateType::Discrete(state) => state.clone(), - }) - .collect(), - ); - time.push(t_end.clone()); - - dataset.trajectories.push(Trajectory { - time: Array::from_vec(time), - events: Array2::from_shape_vec( - (events.len(), current_state.len()), - events.iter().flatten().cloned().collect(), - ) - .unwrap(), - }); - } - dataset -} diff --git a/tests/ctbn.rs b/tests/ctbn.rs deleted file mode 100644 index 2d54f5f..0000000 --- a/tests/ctbn.rs +++ /dev/null @@ -1,99 +0,0 @@ -mod utils; -use utils::generate_discrete_time_continous_node; -use rustyCTBN::network::Network; -use rustyCTBN::node; -use rustyCTBN::params; -use std::collections::BTreeSet; -use rustyCTBN::ctbn::*; - -#[test] -fn define_simpe_ctbn() { - let _ = CtbnNetwork::init(); - assert!(true); -} - -#[test] -fn add_node_to_ctbn() { - let mut net = CtbnNetwork::init(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - assert_eq!(String::from("n1"), net.get_node(n1).label); -} - -#[test] -fn add_edge_to_ctbn() { - let mut net = CtbnNetwork::init(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - net.add_edge(n1, n2); - let cs = net.get_children_set(n1); - assert_eq!(&n2, cs.iter().next().unwrap()); -} - -#[test] -fn children_and_parents() { - let mut net = CtbnNetwork::init(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); 
- let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - net.add_edge(n1, n2); - let cs = net.get_children_set(n1); - assert_eq!(&n2, cs.iter().next().unwrap()); - let ps = net.get_parent_set(n2); - assert_eq!(&n1, ps.iter().next().unwrap()); -} - - -#[test] -fn compute_index_ctbn() { - let mut net = CtbnNetwork::init(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); - net.add_edge(n1, n2); - net.add_edge(n3, n2); - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(1), - params::StateType::Discrete(1), - params::StateType::Discrete(1)]); - assert_eq!(3, idx); - - - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(1), - params::StateType::Discrete(1)]); - assert_eq!(2, idx); - - - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(1), - params::StateType::Discrete(1), - params::StateType::Discrete(0)]); - assert_eq!(1, idx); - -} - - - -#[test] -fn compute_index_from_custom_parent_set() { - let mut net = CtbnNetwork::init(); - let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); - - - let idx = net.get_param_index_from_custom_parent_set(&vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(0), - params::StateType::Discrete(1)], - &BTreeSet::from([1])); - assert_eq!(0, idx); - - - let idx = net.get_param_index_from_custom_parent_set(&vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(0), - 
params::StateType::Discrete(1)], - &BTreeSet::from([1,2])); - assert_eq!(2, idx); -} diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs deleted file mode 100644 index a5cca51..0000000 --- a/tests/parameter_learning.rs +++ /dev/null @@ -1,263 +0,0 @@ -mod utils; -use utils::*; - -use rustyCTBN::parameter_learning::*; -use rustyCTBN::ctbn::*; -use rustyCTBN::network::Network; -use rustyCTBN::node; -use rustyCTBN::params; -use rustyCTBN::tools::*; -use ndarray::arr3; -use std::collections::BTreeSet; - - -#[macro_use] -extern crate approx; - - -fn learn_binary_cim (pl: T) { - let mut net = CtbnNetwork::init(); - let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) - .unwrap(); - let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),2)) - .unwrap(); - net.add_edge(n1, n2); - - match &mut net.get_node_mut(n1).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); - } - } - - match &mut net.get_node_mut(n2).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], - ])); - } - } - - let data = trajectory_generator(&net, 100, 100.0); - let (CIM, M, T) = pl.fit(&net, &data, 1, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [2, 2, 2]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], - ]), 0.2)); -} - -#[test] -fn learn_binary_cim_MLE() { - let mle = MLE{}; - learn_binary_cim(mle); -} - - -#[test] -fn learn_binary_cim_BA() { - let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; - learn_binary_cim(ba); -} - -fn learn_ternary_cim (pl: T) { - let mut net = CtbnNetwork::init(); - let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) - .unwrap(); - let n2 = net - 
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) - .unwrap(); - net.add_edge(n1, n2); - - match &mut net.get_node_mut(n1).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); - } - } - - match &mut net.get_node_mut(n2).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); - } - } - - let data = trajectory_generator(&net, 100, 200.0); - let (CIM, M, T) = pl.fit(&net, &data, 1, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [3, 3, 3]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]), 0.2)); -} - - -#[test] -fn learn_ternary_cim_MLE() { - let mle = MLE{}; - learn_ternary_cim(mle); -} - - -#[test] -fn learn_ternary_cim_BA() { - let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; - learn_ternary_cim(ba); -} - -fn learn_ternary_cim_no_parents (pl: T) { - let mut net = CtbnNetwork::init(); - let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) - .unwrap(); - let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) - .unwrap(); - net.add_edge(n1, n2); - - match &mut net.get_node_mut(n1).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); - } - } - - match &mut net.get_node_mut(n2).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], 
[1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); - } - } - - let data = trajectory_generator(&net, 100, 200.0); - let (CIM, M, T) = pl.fit(&net, &data, 0, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [1, 3, 3]); - assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]), 0.2)); -} - - -#[test] -fn learn_ternary_cim_no_parents_MLE() { - let mle = MLE{}; - learn_ternary_cim_no_parents(mle); -} - - -#[test] -fn learn_ternary_cim_no_parents_BA() { - let ba = BayesianApproach{ - default_alpha: 1, - default_tau: 1.0}; - learn_ternary_cim_no_parents(ba); -} - - -fn learn_mixed_discrete_cim (pl: T) { - let mut net = CtbnNetwork::init(); - let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) - .unwrap(); - let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) - .unwrap(); - - let n3 = net - .add_node(generate_discrete_time_continous_node(String::from("n3"),4)) - .unwrap(); - net.add_edge(n1, n2); - net.add_edge(n1, n3); - net.add_edge(n2, n3); - - match &mut net.get_node_mut(n1).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]])); - } - } - - match &mut net.get_node_mut(n2).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ])); - } - } - - - match &mut net.get_node_mut(n3).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some(arr3(&[ - [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], - [[-1.3, 
0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], - - [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], - [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], - [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], - - [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], - [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], - [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ])); - } - } - - - let data = trajectory_generator(&net, 300, 200.0); - let (CIM, M, T) = pl.fit(&net, &data, 2, None); - print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); - assert_eq!(CIM.shape(), [9, 4, 4]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], - [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], - - [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], - [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], - [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], - - [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], - [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], - [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]), 0.2)); -} - -#[test] -fn learn_mixed_discrete_cim_MLE() { - let mle = MLE{}; - learn_mixed_discrete_cim(mle); -} - - -#[test] -fn learn_mixed_discrete_cim_BA() { - let ba = BayesianApproach{ - default_alpha: 
1, - default_tau: 1.0}; - learn_mixed_discrete_cim(ba); -} diff --git a/tests/params.rs b/tests/params.rs deleted file mode 100644 index ed601b2..0000000 --- a/tests/params.rs +++ /dev/null @@ -1,64 +0,0 @@ -use rustyCTBN::params::*; -use ndarray::prelude::*; -use std::collections::BTreeSet; - -mod utils; - -#[macro_use] -extern crate approx; - - -fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { - let mut params = utils::generate_discrete_time_continous_param(3); - - let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; - - params.cim = Some(cim); - params -} - -#[test] -fn test_uniform_generation() { - let param = create_ternary_discrete_time_continous_param(); - let mut states = Array1::::zeros(10000); - - states.mapv_inplace(|_| { - if let StateType::Discrete(val) = param.get_random_state_uniform() { - val - } else { - panic!() - } - }); - let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; - - assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01); -} - -#[test] -fn test_random_generation_state() { - let param = create_ternary_discrete_time_continous_param(); - let mut states = Array1::::zeros(10000); - - states.mapv_inplace(|_| { - if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { - val - } else { - panic!() - } - }); - let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; - let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; - - assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01); - assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01); -} - -#[test] -fn test_random_generation_residence_time() { - let param = create_ternary_discrete_time_continous_param(); - let mut states = Array1::::zeros(10000); - - states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap()); - - assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); -} diff --git a/tests/tools.rs b/tests/tools.rs deleted 
file mode 100644 index 802e2fe..0000000 --- a/tests/tools.rs +++ /dev/null @@ -1,45 +0,0 @@ - -use rustyCTBN::tools::*; -use rustyCTBN::network::Network; -use rustyCTBN::ctbn::*; -use rustyCTBN::node; -use rustyCTBN::params; -use std::collections::BTreeSet; -use ndarray::arr3; - - - -#[macro_use] -extern crate approx; - -mod utils; - -#[test] -fn run_sampling() { - let mut net = CtbnNetwork::init(); - let n1 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - net.add_edge(n1, n2); - - match &mut net.get_node_mut(n1).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); - } - } - - - match &mut net.get_node_mut(n2).params { - params::Params::DiscreteStatesContinousTime(param) => { - param.cim = Some (arr3(&[ - [[-1.0,1.0],[4.0,-4.0]], - [[-6.0,6.0],[2.0,-2.0]]])); - } - } - - let data = trajectory_generator(&net, 4, 1.0); - - assert_eq!(4, data.trajectories.len()); - assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); -} - - diff --git a/tests/utils.rs b/tests/utils.rs deleted file mode 100644 index a973926..0000000 --- a/tests/utils.rs +++ /dev/null @@ -1,16 +0,0 @@ -use rustyCTBN::params; -use rustyCTBN::node; -use std::collections::BTreeSet; - -pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node { - node::Node::init(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) -} - - -pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ - let mut domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); - params::DiscreteStatesContinousTimeParams::init(domain) -} - - -