Merge branch 'dev' into 'meta-license'

Syncing with dev branch
pull/22/head
Meliurwen 2 years ago
commit 57ae851470
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 26
      .github/ISSUE_TEMPLATE/meta_request.md
  2. 26
      .github/ISSUE_TEMPLATE/refactor_request.md
  3. 36
      .github/labels.yml
  4. 2
      .github/pull_request_template.md
  5. 53
      .github/workflows/build.yml
  6. 23
      .github/workflows/labels.yml
  7. 1
      .gitignore
  8. 20
      Cargo.toml
  9. 28
      README.md
  10. 20
      reCTBN/Cargo.toml
  11. 12
      reCTBN/src/lib.rs
  12. 123
      reCTBN/src/parameter_learning.rs
  13. 287
      reCTBN/src/params.rs
  14. 120
      reCTBN/src/process.rs
  15. 247
      reCTBN/src/process/ctbn.rs
  16. 114
      reCTBN/src/process/ctmp.rs
  17. 59
      reCTBN/src/reward.rs
  18. 205
      reCTBN/src/reward/reward_evaluation.rs
  19. 106
      reCTBN/src/reward/reward_function.rs
  20. 133
      reCTBN/src/sampling.rs
  21. 13
      reCTBN/src/structure_learning.rs
  22. 348
      reCTBN/src/structure_learning/constraint_based_algorithm.rs
  23. 261
      reCTBN/src/structure_learning/hypothesis_test.rs
  24. 93
      reCTBN/src/structure_learning/score_based_algorithm.rs
  25. 146
      reCTBN/src/structure_learning/score_function.rs
  26. 355
      reCTBN/src/tools.rs
  27. 376
      reCTBN/tests/ctbn.rs
  28. 127
      reCTBN/tests/ctmp.rs
  29. 648
      reCTBN/tests/parameter_learning.rs
  30. 148
      reCTBN/tests/params.rs
  31. 122
      reCTBN/tests/reward_evaluation.rs
  32. 117
      reCTBN/tests/reward_function.rs
  33. 692
      reCTBN/tests/structure_learning.rs
  34. 251
      reCTBN/tests/tools.rs
  35. 19
      reCTBN/tests/utils.rs
  36. 7
      rust-toolchain.toml
  37. 39
      rustfmt.toml
  38. 164
      src/ctbn.rs
  39. 11
      src/lib.rs
  40. 39
      src/network.rs
  41. 25
      src/node.rs
  42. 161
      src/params.rs
  43. 119
      src/tools.rs
  44. 99
      tests/ctbn.rs
  45. 263
      tests/parameter_learning.rs
  46. 64
      tests/params.rs
  47. 45
      tests/tools.rs
  48. 16
      tests/utils.rs

@ -0,0 +1,26 @@
---
name: 📑 Meta request
about: Suggest an idea or a change for this repository itself
title: '[Meta] '
labels: 'meta'
assignees: ''
---
## Description
As a X, I want to Y, so Z.
## Acceptance Criteria
* Criteria 1
* Criteria 2
## Checklist
* [ ] Element 1
* [ ] Element 2
## (Optional) Extra info
None

@ -0,0 +1,26 @@
---
name: ⚙ Refactor request
about: Suggest a refactor for this project
title: '[Refactor] '
labels: enhancement, refactor
assignees: ''
---
## Description
As a X, I want to Y, so Z.
## Acceptance Criteria
* Criteria 1
* Criteria 2
## Checklist
* [ ] Element 1
* [ ] Element 2
## (Optional) Extra info
None

@ -0,0 +1,36 @@
- name: "bug"
color: "d73a4a"
description: "Something isn't working"
- name: "enhancement"
color: "a2eeef"
description: "New feature or request"
- name: "refactor"
color: "B06E16"
description: "Change in the structure"
- name: "documentation"
color: "0075ca"
description: "Improvements or additions to documentation"
- name: "meta"
color: "1D76DB"
description: "Something related to the project itself"
- name: "duplicate"
color: "cfd3d7"
description: "This issue or pull request already exists"
- name: "help wanted"
color: "008672"
description: "Extra help is needed"
- name: "urgent"
color: "D93F0B"
description: ""
- name: "wontfix"
color: "ffffff"
description: "This will not be worked on"
- name: "invalid"
color: "e4e669"
description: "This doesn't seem right"
- name: "question"
color: "d876e3"
description: "Further information is requested"

@ -1,4 +1,4 @@
# Pull/Merge Request into master dev <!-- # Pull/Merge Request into dev -->
## Description ## Description

@ -0,0 +1,53 @@
name: build
on:
push:
branches: [ main, dev ]
pull_request:
branches: [ dev ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Rust stable (default)
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
default: true
components: clippy, rustfmt, rust-docs
- name: Setup Rust nightly
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly
default: false
components: rustfmt
- name: Docs (doc)
uses: actions-rs/cargo@v1
with:
command: rustdoc
args: --package reCTBN -- --default-theme=ayu
- name: Linting (clippy)
uses: actions-rs/clippy-check@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
args: --all-targets -- -D warnings -A clippy::all -W clippy::correctness
- name: Formatting (rustfmt)
uses: actions-rs/cargo@v1
with:
toolchain: nightly
command: fmt
args: --all -- --check --verbose
- name: Tests (test)
uses: actions-rs/cargo@v1
with:
command: test
args: --tests

@ -0,0 +1,23 @@
name: meta-github
on:
push:
branches:
- dev
jobs:
labeler:
runs-on: ubuntu-latest
steps:
-
name: Checkout
uses: actions/checkout@v2
-
name: Run Labeler
if: success()
uses: crazy-max/ghaction-github-labeler@v3
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
yaml-file: .github/labels.yml
skip-delete: false
dry-run: false

1
.gitignore vendored

@ -1,2 +1,3 @@
/target /target
Cargo.lock Cargo.lock
.vscode

@ -1,17 +1,5 @@
[package] [workspace]
name = "rustyCTBN"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html members = [
"reCTBN",
[dependencies] ]
ndarray = {version="*", features=["approx"]}
thiserror = "*"
rand = "*"
bimap = "*"
enum_dispatch = "*"
[dev-dependencies]
approx = "*"

@ -1,6 +1,6 @@
<div align="center"> <div align="center">
# rustyCTBN # reCTBN
</div> </div>
@ -37,8 +37,30 @@ To launch **tests**:
cargo test cargo test
``` ```
To **lint**: To **lint** with `cargo check`:
```sh ```sh
cargo check cargo check --all-targets
```
Or with `clippy`:
```sh
cargo clippy --all-targets -- -A clippy::all -W clippy::correctness
```
To check the **formatting**:
> **NOTE:** remove `--check` to apply the changes to the file(s).
```sh
cargo fmt --all -- --check
```
## Documentation
To generate the **documentation**:
```sh
cargo rustdoc --package reCTBN --open -- --default-theme=ayu
``` ```

@ -0,0 +1,20 @@
[package]
name = "reCTBN"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ndarray = {version="~0.15", features=["approx-0_5"]}
thiserror = "1.0.37"
rand = "~0.8"
bimap = "~0.6"
enum_dispatch = "~0.3"
statrs = "~0.16"
rand_chacha = "~0.3"
itertools = "~0.10"
rayon = "~1.6"
[dev-dependencies]
approx = { package = "approx", version = "~0.5" }

@ -0,0 +1,12 @@
#![doc = include_str!("../../README.md")]
#![allow(non_snake_case)]
#[cfg(test)]
extern crate approx;
pub mod parameter_learning;
pub mod params;
pub mod process;
pub mod reward;
pub mod sampling;
pub mod structure_learning;
pub mod tools;

@ -1,40 +1,35 @@
use crate::network; //! Module containing methods used to learn the parameters.
use crate::params::*;
use crate::tools;
use ndarray::prelude::*;
use ndarray::{concatenate, Slice};
use std::collections::BTreeSet; use std::collections::BTreeSet;
pub trait ParameterLearning{ use ndarray::prelude::*;
fn fit<T:network::Network>(
use crate::params::*;
use crate::{process, tools::Dataset};
pub trait ParameterLearning: Sync {
fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &Dataset,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>); ) -> Params;
} }
pub fn sufficient_statistics<T:network::Network>( pub fn sufficient_statistics<T: process::NetworkProcess>(
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &Dataset,
node: usize, node: usize,
parent_set: &BTreeSet<usize> parent_set: &BTreeSet<usize>,
) -> (Array3<usize>, Array2<f64>) { ) -> (Array3<usize>, Array2<f64>) {
//Get the number of values assumable by the node //Get the number of values assumable by the node
let node_domain = net let node_domain = net.get_node(node.clone()).get_reserved_space_as_parent();
.get_node(node.clone())
.params
.get_reserved_space_as_parent();
//Get the number of values assumable by each parent of the node //Get the number of values assumable by each parent of the node
let parentset_domain: Vec<usize> = parent_set let parentset_domain: Vec<usize> = parent_set
.iter() .iter()
.map(|x| { .map(|x| net.get_node(x.clone()).get_reserved_space_as_parent())
net.get_node(x.clone())
.params
.get_reserved_space_as_parent()
})
.collect(); .collect();
//Vector used to convert a specific configuration of the parent_set to the corresponding index //Vector used to convert a specific configuration of the parent_set to the corresponding index
@ -48,7 +43,7 @@ pub fn sufficient_statistics<T:network::Network>(
vector_to_idx[*idx] = acc; vector_to_idx[*idx] = acc;
acc * x acc * x
}); });
//Number of transition given a specific configuration of the parent set //Number of transition given a specific configuration of the parent set
let mut M: Array3<usize> = let mut M: Array3<usize> =
Array::zeros((parentset_domain.iter().product(), node_domain, node_domain)); Array::zeros((parentset_domain.iter().product(), node_domain, node_domain));
@ -57,12 +52,12 @@ pub fn sufficient_statistics<T:network::Network>(
let mut T: Array2<f64> = Array::zeros((parentset_domain.iter().product(), node_domain)); let mut T: Array2<f64> = Array::zeros((parentset_domain.iter().product(), node_domain));
//Compute the sufficient statistics //Compute the sufficient statistics
for trj in dataset.trajectories.iter() { for trj in dataset.get_trajectories().iter() {
for idx in 0..(trj.time.len() - 1) { for idx in 0..(trj.get_time().len() - 1) {
let t1 = trj.time[idx]; let t1 = trj.get_time()[idx];
let t2 = trj.time[idx + 1]; let t2 = trj.get_time()[idx + 1];
let ev1 = trj.events.row(idx); let ev1 = trj.get_events().row(idx);
let ev2 = trj.events.row(idx + 1); let ev2 = trj.get_events().row(idx + 1);
let idx1 = vector_to_idx.dot(&ev1); let idx1 = vector_to_idx.dot(&ev1);
T[[idx1, ev1[node]]] += t2 - t1; T[[idx1, ev1[node]]] += t2 - t1;
@ -73,34 +68,30 @@ pub fn sufficient_statistics<T:network::Network>(
} }
return (M, T); return (M, T);
} }
pub struct MLE {} pub struct MLE {}
impl ParameterLearning for MLE { impl ParameterLearning for MLE {
fn fit<T: process::NetworkProcess>(
fn fit<T: network::Network>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &Dataset,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { ) -> Params {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network. //Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set { let parent_set = match parent_set {
Some(p) => p, Some(p) => p,
None => net.get_parent_set(node), None => net.get_parent_set(node),
}; };
let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
//Compute the CIM as M[i,x,y]/T[i,x] //Compute the CIM as M[i,x,y]/T[i,x]
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
CIM.axis_iter_mut(Axis(2)) CIM.axis_iter_mut(Axis(2))
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
.for_each(|(mut C, m)| C.assign(&(&m/&T))); .for_each(|(mut C, m)| C.assign(&(&m / &T)));
//Set the diagonal of the inner matrices to the the row sum multiplied by -1 //Set the diagonal of the inner matrices to the the row sum multiplied by -1
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@ -109,39 +100,53 @@ impl ParameterLearning for MLE {
.for_each(|(mut C, diag)| { .for_each(|(mut C, diag)| {
C.diag_mut().assign(&diag); C.diag_mut().assign(&diag);
}); });
return (CIM, M, T);
let mut n: Params = net.get_node(node).clone();
match n {
Params::DiscreteStatesContinousTime(ref mut dsct) => {
dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M);
dsct.set_residence_time(T);
}
};
return n;
} }
} }
pub struct BayesianApproach { pub struct BayesianApproach {
pub default_alpha: usize, pub alpha: usize,
pub default_tau: f64 pub tau: f64,
} }
impl ParameterLearning for BayesianApproach { impl ParameterLearning for BayesianApproach {
fn fit<T: network::Network>( fn fit<T: process::NetworkProcess>(
&self, &self,
net: &T, net: &T,
dataset: &tools::Dataset, dataset: &Dataset,
node: usize, node: usize,
parent_set: Option<BTreeSet<usize>>, parent_set: Option<BTreeSet<usize>>,
) -> (Array3<f64>, Array3<usize>, Array2<f64>) { ) -> Params {
//TODO: make this function general. Now it works only on ContinousTimeDiscreteState nodes
//Use parent_set from parameter if present. Otherwise use parent_set from network. //Use parent_set from parameter if present. Otherwise use parent_set from network.
let parent_set = match parent_set { let parent_set = match parent_set {
Some(p) => p, Some(p) => p,
None => net.get_parent_set(node), None => net.get_parent_set(node),
}; };
let (mut M, mut T) = sufficient_statistics(net, dataset, node.clone(), &parent_set); let (M, T) = sufficient_statistics(net, dataset, node.clone(), &parent_set);
M.mapv_inplace(|x|{x + self.default_alpha});
T.mapv_inplace(|x|{x + self.default_tau}); let alpha: f64 = self.alpha as f64 / M.shape()[0] as f64;
//Compute the CIM as M[i,x,y]/T[i,x] let tau: f64 = self.tau as f64 / M.shape()[0] as f64;
//Compute the CIM as M[i,x,y]/T[i,x]
let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2])); let mut CIM: Array3<f64> = Array::zeros((M.shape()[0], M.shape()[1], M.shape()[2]));
CIM.axis_iter_mut(Axis(2)) CIM.axis_iter_mut(Axis(2))
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2))) .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
.for_each(|(mut C, m)| C.assign(&(&m/&T))); .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau))));
CIM.outer_iter_mut().for_each(|mut C| {
C.diag_mut().fill(0.0);
});
//Set the diagonal of the inner matrices to the the row sum multiplied by -1 //Set the diagonal of the inner matrices to the the row sum multiplied by -1
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0); let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
@ -150,6 +155,16 @@ impl ParameterLearning for BayesianApproach {
.for_each(|(mut C, diag)| { .for_each(|(mut C, diag)| {
C.diag_mut().assign(&diag); C.diag_mut().assign(&diag);
}); });
return (CIM, M, T);
let mut n: Params = net.get_node(node).clone();
match n {
Params::DiscreteStatesContinousTime(ref mut dsct) => {
dsct.set_cim_unchecked(CIM);
dsct.set_transitions(M);
dsct.set_residence_time(T);
}
};
return n;
} }
} }

@ -0,0 +1,287 @@
//! Module containing methods to define different types of nodes.
use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch;
use ndarray::prelude::*;
use rand::Rng;
use rand_chacha::ChaCha8Rng;
use thiserror::Error;
/// Error types for trait Params
#[derive(Error, Debug, PartialEq)]
pub enum ParamsError {
#[error("Unsupported method")]
UnsupportedMethod(String),
#[error("Paramiters not initialized")]
ParametersNotInitialized(String),
#[error("Invalid cim for parameter")]
InvalidCIM(String),
}
/// Allowed type of states
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub enum StateType {
    /// A discrete state, represented as the index of the realized value inside the node's domain.
    Discrete(usize),
}
/// This is a core element for building different types of nodes; the goal is to define the set of
/// methods required to describe a generic node.
#[enum_dispatch(Params)]
pub trait ParamsTrait {
    /// Reset the learned parameters of the node (see the concrete implementation for exactly
    /// which fields are cleared).
    fn reset_params(&mut self);

    /// Randomly generate a possible state of the node disregarding the state of the node and its
    /// parents.
    fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType;

    /// Randomly generate a residence time for the given node taking into account the node state
    /// and its parent set.
    fn get_random_residence_time(
        &self,
        state: usize,
        u: usize,
        rng: &mut ChaCha8Rng,
    ) -> Result<f64, ParamsError>;

    /// Randomly generate a possible state for the given node taking into account the node state
    /// and its parent set.
    fn get_random_state(
        &self,
        state: usize,
        u: usize,
        rng: &mut ChaCha8Rng,
    ) -> Result<StateType, ParamsError>;

    /// Used by children of the node described by these parameters to reserve spaces in their CIMs.
    fn get_reserved_space_as_parent(&self) -> usize;

    /// Index used by a discrete node to represent its states as `usize`.
    fn state_to_index(&self, state: &StateType) -> usize;

    /// Validate parameters against the domain.
    fn validate_params(&self) -> Result<(), ParamsError>;

    /// Return a reference to the associated label.
    fn get_label(&self) -> &String;
}
/// Is a core element for building different types of nodes; the goal is to define all the
/// supported types of Parameters.
#[derive(Clone)]
#[enum_dispatch]
pub enum Params {
    /// Parameters of a node with a discrete state space evolving in continuous time.
    DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
/// This represents the parameters of a classical discrete node for ctbn and is composed of the
/// following elements.
///
/// # Arguments
///
/// * `label` - node's variable name.
/// * `domain` - an ordered and exhaustive set of possible states.
/// * `cim` - Conditional Intensity Matrix.
/// * `transitions` - number of transitions from one state to another given a specific realization
///   of the parent set; a sufficient statistic mainly used during the parameter-learning task.
/// * `residence_time` - residence time in each possible state, given a specific realization of the
///   parent set; a sufficient statistic mainly used during the parameter-learning task.
#[derive(Clone)]
pub struct DiscreteStatesContinousTimeParams {
    // Variable name of the node.
    label: String,
    // Ordered set of the states the node can assume.
    domain: BTreeSet<String>,
    // `None` until a CIM is learned or set explicitly.
    cim: Option<Array3<f64>>,
    // Sufficient statistic: transition counts (set during parameter learning).
    transitions: Option<Array3<usize>>,
    // Sufficient statistic: residence times (set during parameter learning).
    residence_time: Option<Array2<f64>>,
}
impl DiscreteStatesContinousTimeParams {
    /// Create a fresh parameter set for a discrete-state continuous-time node.
    ///
    /// The CIM and both sufficient statistics (`transitions`, `residence_time`) start unset.
    pub fn new(label: String, domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
        DiscreteStatesContinousTimeParams {
            label,
            domain,
            cim: None,
            transitions: None,
            residence_time: None,
        }
    }

    /// Borrow the CIM, if one has been set.
    pub fn get_cim(&self) -> &Option<Array3<f64>> {
        &self.cim
    }

    /// Install a new CIM after checking it with
    /// [`validate_params`](self::ParamsTrait::validate_params).
    ///
    /// * If the candidate CIM is valid it is stored and `Ok(())` is returned.
    /// * Otherwise `self.cim` is cleared back to `None` and the validation `ParamsError` is
    ///   returned.
    pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
        self.cim = Some(cim);
        if let Err(e) = self.validate_params() {
            self.cim = None;
            return Err(e);
        }
        Ok(())
    }

    /// Install a new CIM without any validation.
    pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
        self.cim = Some(cim);
    }

    /// Borrow the transition counts, if they have been set.
    pub fn get_transitions(&self) -> &Option<Array3<usize>> {
        &self.transitions
    }

    /// Store the transition-count sufficient statistic.
    pub fn set_transitions(&mut self, transitions: Array3<usize>) {
        self.transitions = Some(transitions);
    }

    /// Borrow the residence times, if they have been set.
    pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
        &self.residence_time
    }

    /// Store the residence-time sufficient statistic.
    pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
        self.residence_time = Some(residence_time);
    }
}
impl ParamsTrait for DiscreteStatesContinousTimeParams {
    /// Drop the CIM and both sufficient statistics; `label` and `domain` are kept.
    fn reset_params(&mut self) {
        self.cim = Option::None;
        self.transitions = Option::None;
        self.residence_time = Option::None;
    }

    /// Draw a state uniformly at random from `0..domain.len()`, ignoring the CIM entirely.
    fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType {
        StateType::Discrete(rng.gen_range(0..(self.domain.len())))
    }

    fn get_random_residence_time(
        &self,
        state: usize,
        u: usize,
        rng: &mut ChaCha8Rng,
    ) -> Result<f64, ParamsError> {
        // Generate a random residence time given the current state of the node and its parent set.
        // The method used is described in:
        // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
        match &self.cim {
            Option::Some(cim) => {
                // The exit rate is minus the CIM diagonal entry for the current state,
                // conditioned on parent configuration `u`.
                let lambda = cim[[u, state, state]] * -1.0;
                // NOTE(review): `gen_range(0.0..=1.0)` can return exactly 0.0, making
                // `-ln(0)/lambda` infinite — consider sampling from the half-open (0, 1].
                let x: f64 = rng.gen_range(0.0..=1.0);
                Ok(-x.ln() / lambda)
            }
            Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
                "CIM not initialized",
            ))),
        }
    }

    fn get_random_state(
        &self,
        state: usize,
        u: usize,
        rng: &mut ChaCha8Rng,
    ) -> Result<StateType, ParamsError> {
        // Generate a random transition given the current state of the node and its parent set.
        // The method used is described in:
        // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
        match &self.cim {
            Option::Some(cim) => {
                let lambda = cim[[u, state, state]] * -1.0;
                let urand: f64 = rng.gen_range(0.0..=1.0);

                // Walk the CIM row (rates normalized by lambda), accumulating
                // (candidate index, cumulative probability). Only strictly positive entries
                // contribute, which skips the (negative) diagonal entry of the current state.
                let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold(
                    (0, 0.0),
                    |mut acc, ele| {
                        if &acc.1 + ele < urand && ele > &0.0 {
                            acc.0 += 1;
                        }

                        if ele > &0.0 {
                            acc.1 += ele;
                        }
                        acc
                    },
                );

                // The diagonal (current state) was excluded from the count above, so shift the
                // selected index by one when it lands at or past the current state.
                let next_state = if next_state.0 < state {
                    next_state.0
                } else {
                    next_state.0 + 1
                };

                Ok(StateType::Discrete(next_state))
            }
            Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
                "CIM not initialized",
            ))),
        }
    }

    /// A parent with this domain contributes `domain.len()` slots to its children's CIMs.
    fn get_reserved_space_as_parent(&self) -> usize {
        self.domain.len()
    }

    fn state_to_index(&self, state: &StateType) -> usize {
        match state {
            StateType::Discrete(val) => val.clone() as usize,
        }
    }

    /// Check that the CIM is present, that its inner matrices are square with the domain's
    /// cardinality, that every diagonal entry is negative, and that every row sums to
    /// (approximately) zero.
    fn validate_params(&self) -> Result<(), ParamsError> {
        let domain_size = self.domain.len();

        // Check if the cim is initialized
        if let None = self.cim {
            return Err(ParamsError::ParametersNotInitialized(String::from(
                "CIM not initialized",
            )));
        }
        let cim = self.cim.as_ref().unwrap();
        // Check if the inner dimensions of the cim are equal to the cardinality of the variable
        if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size {
            return Err(ParamsError::InvalidCIM(format!(
                "Incompatible shape {:?} with domain {:?}",
                cim.shape(),
                domain_size
            )));
        }

        // Check if the diagonal of each cim is non-positive
        if cim
            .axis_iter(Axis(0))
            .any(|x| x.diag().iter().any(|x| x >= &0.0))
        {
            return Err(ParamsError::InvalidCIM(String::from(
                "The diagonal of each cim must be non-positive",
            )));
        }

        // Check if each row sum up to 0 (within sqrt(machine epsilon) tolerance)
        if cim
            .sum_axis(Axis(2))
            .iter()
            .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
        {
            return Err(ParamsError::InvalidCIM(String::from(
                "The sum of each row must be 0",
            )));
        }

        return Ok(());
    }

    fn get_label(&self) -> &String {
        &self.label
    }
}

@ -0,0 +1,120 @@
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
pub mod ctbn;
pub mod ctmp;
use std::collections::BTreeSet;
use thiserror::Error;
use crate::params;
/// Error types for trait Network
#[derive(Error, Debug)]
pub enum NetworkError {
    /// Raised when a node cannot be added to the process (e.g. the process already holds its
    /// maximum number of nodes).
    #[error("Error during node insertion")]
    NodeInsertionError(String),
}
/// This type is used to represent a specific realization of a generic NetworkProcess: one
/// `StateType` entry per node, indexed like the nodes themselves.
pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Model (such
/// as a CTBN).
pub trait NetworkProcess: Sync {
    /// Allocate the adjacency matrix for the current set of nodes.
    fn initialize_adj_matrix(&mut self);

    /// Insert a node into the process and return its index, or a `NetworkError` if the insertion
    /// is rejected.
    fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;

    /// Add a **directed edge** between two nodes of the network.
    ///
    /// # Arguments
    ///
    /// * `parent` - parent node.
    /// * `child` - child node.
    fn add_edge(&mut self, parent: usize, child: usize);

    /// Get all the indices of the nodes contained inside the network.
    fn get_node_indices(&self) -> std::ops::Range<usize>;

    /// Get the number of nodes contained in the network.
    fn get_number_of_nodes(&self) -> usize;

    /// Get the **node param**.
    ///
    /// # Arguments
    ///
    /// * `node_idx` - node index value.
    ///
    /// # Return
    ///
    /// * The selected **node param**.
    fn get_node(&self, node_idx: usize) -> &params::Params;

    /// Get the **node param**, mutably.
    ///
    /// # Arguments
    ///
    /// * `node_idx` - node index value.
    ///
    /// # Return
    ///
    /// * The selected **node mutable param**.
    fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;

    /// Compute the index that must be used to access the parameters of a `node`, given a specific
    /// configuration of the network.
    ///
    /// Usually, the only values really used in `current_state` are the ones in the parent set of
    /// the `node`.
    ///
    /// # Arguments
    ///
    /// * `node` - selected node.
    /// * `current_state` - current configuration of the network.
    ///
    /// # Return
    ///
    /// * Index of the `node` relative to the network.
    fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize;

    /// Compute the index that must be used to access the parameters of a `node`, given a specific
    /// configuration of the network and a generic `parent_set`.
    ///
    /// Usually, the only values really used in `current_state` are the ones in the parent set of
    /// the `node`.
    ///
    /// # Arguments
    ///
    /// * `current_state` - current configuration of the network (same layout as
    ///   [`NetworkProcessState`]).
    /// * `parent_set` - parent set of the selected `node`.
    ///
    /// # Return
    ///
    /// * Index of the `node` relative to the network.
    fn get_param_index_from_custom_parent_set(
        &self,
        current_state: &Vec<params::StateType>,
        parent_set: &BTreeSet<usize>,
    ) -> usize;

    /// Get the **parent set** of a given **node**.
    ///
    /// # Arguments
    ///
    /// * `node` - node index value.
    ///
    /// # Return
    ///
    /// * The **parent set** of the selected node.
    fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;

    /// Get the **children set** of a given **node**.
    ///
    /// # Arguments
    ///
    /// * `node` - node index value.
    ///
    /// # Return
    ///
    /// * The **children set** of the selected node.
    fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@ -0,0 +1,247 @@
//! Continuous Time Bayesian Network
use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
use crate::process;
use super::ctmp::CtmpProcess;
use super::{NetworkProcess, NetworkProcessState};
/// It represents both the structure and the parameters of a CTBN.
///
/// # Arguments
///
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
/// * `nodes` - A vector containing all the nodes and their parameters.
///
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
///
/// # Example
///
/// ```rust
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Get all the children of node X1
/// let cs = net.get_children_set(X1);
/// assert_eq!(&X2, cs.iter().next().unwrap());
/// ```
pub struct CtbnNetwork {
    // Adjacency matrix; `None` until initialized. Entry [parent, child] > 0 marks an edge.
    adj_matrix: Option<Array2<u16>>,
    // Node parameters; a node's position in this vector is its index in `adj_matrix`.
    nodes: Vec<Params>,
}
impl CtbnNetwork {
    /// Create an empty CTBN with no nodes and no adjacency matrix.
    pub fn new() -> CtbnNetwork {
        CtbnNetwork {
            adj_matrix: None,
            nodes: Vec::new(),
        }
    }

    ///Transform the **CTBN** into a **CTMP**
    ///
    /// # Return
    ///
    /// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
    pub fn amalgamation(&self) -> CtmpProcess {
        // Cardinality of every variable's domain.
        let variables_domain =
            Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));

        // Size of the joint state space: product of all domain cardinalities.
        let state_space = variables_domain.product();
        let variables_set = BTreeSet::from_iter(self.get_node_indices());
        // Single CIM over the joint state space (the resulting CTMP has exactly one node).
        let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space));

        for idx_current_state in 0..state_space {
            // Decode the flat index into one state per variable.
            let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
            let current_state_statetype: NetworkProcessState = current_state
                .iter()
                .map(|x| StateType::Discrete(*x))
                .collect();
            // Enumerate, for each node, every state of that node while the others stay fixed.
            for idx_node in 0..self.nodes.len() {
                let p = match self.get_node(idx_node) {
                    Params::DiscreteStatesContinousTime(p) => p,
                };
                for next_node_state in 0..variables_domain[idx_node] {
                    let mut next_state = current_state.clone();
                    next_state[idx_node] = next_node_state;

                    let next_state_statetype: NetworkProcessState =
                        next_state.iter().map(|x| StateType::Discrete(*x)).collect();
                    let idx_next_state = self.get_param_index_from_custom_parent_set(
                        &next_state_statetype,
                        &variables_set,
                    );
                    // Add the node's rate, conditioned on its parents' current configuration.
                    amalgamated_cim[[0, idx_current_state, idx_next_state]] +=
                        p.get_cim().as_ref().unwrap()[[
                            self.get_param_index_network(idx_node, &current_state_statetype),
                            current_state[idx_node],
                            next_node_state,
                        ]];
                }
            }
        }

        let mut amalgamated_param = DiscreteStatesContinousTimeParams::new(
            "ctmp".to_string(),
            BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
        );

        amalgamated_param.set_cim(amalgamated_cim).unwrap();
        let mut ctmp = CtmpProcess::new();
        ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param))
            .unwrap();
        return ctmp;
    }

    /// Decode a flat joint-state index into per-variable states (mixed-radix decomposition,
    /// least-significant variable first).
    pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> {
        let mut state = state;
        let mut array_state = Array1::zeros(variables_domain.shape()[0]);
        for (idx, var) in variables_domain.indexed_iter() {
            array_state[idx] = state % var;
            state = state / var;
        }
        return array_state;
    }

    /// Get the Adjacency Matrix.
    pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> {
        self.adj_matrix.as_ref()
    }
}
impl process::NetworkProcess for CtbnNetwork {
    /// Initialize an Adjacency matrix (all zeros, sized on the current node count).
    fn initialize_adj_matrix(&mut self) {
        self.adj_matrix = Some(Array2::<u16>::zeros(
            (self.nodes.len(), self.nodes.len()).f(),
        ));
    }

    /// Add a new node.
    ///
    /// The node's learned parameters are reset and any existing adjacency matrix is invalidated
    /// (it will be rebuilt when the next edge is added).
    fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
        n.reset_params();
        self.adj_matrix = Option::None;
        self.nodes.push(n);
        Ok(self.nodes.len() - 1)
    }

    /// Connect two nodes with a new edge.
    fn add_edge(&mut self, parent: usize, child: usize) {
        // Lazily (re)create the adjacency matrix the first time an edge is added.
        if let None = self.adj_matrix {
            self.initialize_adj_matrix();
        }

        if let Some(network) = &mut self.adj_matrix {
            network[[parent, child]] = 1;
            // The child's parent set changed, so its learned parameters are no longer valid.
            self.nodes[child].reset_params();
        }
    }

    fn get_node_indices(&self) -> std::ops::Range<usize> {
        0..self.nodes.len()
    }

    /// Get the number of nodes of the network.
    fn get_number_of_nodes(&self) -> usize {
        self.nodes.len()
    }

    fn get_node(&self, node_idx: usize) -> &Params {
        &self.nodes[node_idx]
    }

    fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
        &mut self.nodes[node_idx]
    }

    /// Compute the parameter index for `node` as a mixed-radix number over the states of its
    /// parents (rows with a non-zero entry in `node`'s column of the adjacency matrix).
    fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
        self.adj_matrix
            .as_ref()
            .unwrap()
            .column(node)
            .iter()
            .enumerate()
            // Accumulator: (index so far, place value for the next parent).
            .fold((0, 1), |mut acc, x| {
                if x.1 > &0 {
                    acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
                    acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
                }
                acc
            })
            .0
    }

    /// Same mixed-radix computation as `get_param_index_network`, but over an arbitrary
    /// `parent_set` instead of the parents recorded in the adjacency matrix.
    fn get_param_index_from_custom_parent_set(
        &self,
        current_state: &NetworkProcessState,
        parent_set: &BTreeSet<usize>,
    ) -> usize {
        parent_set
            .iter()
            .fold((0, 1), |mut acc, x| {
                acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
                acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
                acc
            })
            .0
    }

    /// Get all the parents of the given node.
    fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
        // Parents are rows with a non-zero entry in this node's column.
        self.adj_matrix
            .as_ref()
            .unwrap()
            .column(node)
            .iter()
            .enumerate()
            .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
            .collect()
    }

    /// Get all the children of the given node.
    fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
        // Children are columns with a non-zero entry in this node's row.
        self.adj_matrix
            .as_ref()
            .unwrap()
            .row(node)
            .iter()
            .enumerate()
            .filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
            .collect()
    }
}

@ -0,0 +1,114 @@
use std::collections::BTreeSet;
use crate::{
params::{Params, StateType},
process,
};
use super::{NetworkProcess, NetworkProcessState};
/// A continuous-time Markov process: a `NetworkProcess` with exactly one node.
pub struct CtmpProcess {
    // The single node of the process; `None` until `add_node` is called.
    param: Option<Params>,
}
impl CtmpProcess {
    /// Build a `CtmpProcess` that does not yet contain its (single) node.
    pub fn new() -> CtmpProcess {
        let param = None;
        CtmpProcess { param }
    }
}
impl NetworkProcess for CtmpProcess {
    /// Not supported: a `CtmpProcess` has no adjacency matrix.
    fn initialize_adj_matrix(&mut self) {
        unimplemented!("CtmpProcess has only one node")
    }

    /// Set the single node of the process (index 0).
    ///
    /// Returns a `NodeInsertionError` when a node is already present.
    fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> {
        match self.param {
            None => {
                self.param = Some(n);
                Ok(0)
            }
            Some(_) => Err(process::NetworkError::NodeInsertionError(
                "CtmpProcess has only one node".to_string(),
            )),
        }
    }

    /// Not supported: a single-node process cannot have edges.
    fn add_edge(&mut self, _parent: usize, _child: usize) {
        unimplemented!("CtmpProcess has only one node")
    }

    /// Range of valid node indices: empty before `add_node`, `0..1` after.
    fn get_node_indices(&self) -> std::ops::Range<usize> {
        match self.param {
            None => 0..0,
            Some(_) => 0..1,
        }
    }

    /// Number of nodes: 0 before `add_node`, 1 after.
    fn get_number_of_nodes(&self) -> usize {
        match self.param {
            None => 0,
            Some(_) => 1,
        }
    }

    /// Get the single node; only index 0 is valid.
    fn get_node(&self, node_idx: usize) -> &crate::params::Params {
        if node_idx == 0 {
            self.param.as_ref().unwrap()
        } else {
            unimplemented!("CtmpProcess has only one node")
        }
    }

    /// Get the single node mutably; only index 0 is valid.
    fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params {
        if node_idx == 0 {
            self.param.as_mut().unwrap()
        } else {
            unimplemented!("CtmpProcess has only one node")
        }
    }

    /// With no parents, the parameter index is the state index itself.
    fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
        if node == 0 {
            match current_state[0] {
                StateType::Discrete(x) => x,
            }
        } else {
            unimplemented!("CtmpProcess has only one node")
        }
    }

    /// Not supported: there are no parent sets in a single-node process.
    fn get_param_index_from_custom_parent_set(
        &self,
        _current_state: &NetworkProcessState,
        _parent_set: &BTreeSet<usize>,
    ) -> usize {
        unimplemented!("CtmpProcess has only one node")
    }

    /// The single node has no parents; panics if the node was never set.
    fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
        match self.param {
            Some(_) => {
                if node == 0 {
                    BTreeSet::new()
                } else {
                    unimplemented!("CtmpProcess has only one node")
                }
            }
            None => panic!("Uninitialized CtmpProcess"),
        }
    }

    /// The single node has no children; panics if the node was never set.
    fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
        match self.param {
            Some(_) => {
                if node == 0 {
                    BTreeSet::new()
                } else {
                    unimplemented!("CtmpProcess has only one node")
                }
            }
            None => panic!("Uninitialized CtmpProcess"),
        }
    }
}

@ -0,0 +1,59 @@
pub mod reward_evaluation;
pub mod reward_function;
use std::collections::HashMap;
use crate::process;
/// Instantiation of reward function and instantaneous reward
///
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)]
pub struct Reward {
    /// Reward obtained transitioning from one state to another.
    pub transition_reward: f64,
    /// Reward per unit of time obtained staying in a specific state.
    pub instantaneous_reward: f64,
}
/// The trait RewardFunction describes the methods that all the reward functions must satisfy
pub trait RewardFunction: Sync {
    /// Given the current state and the previous state, it computes the reward.
    ///
    /// # Arguments
    ///
    /// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
    /// * `previous_state`: an optional argument representing the previous state of the network
    ///
    /// When `previous_state` is `None`, no transition reward can be computed.
    fn call(
        &self,
        current_state: &process::NetworkProcessState,
        previous_state: Option<&process::NetworkProcessState>,
    ) -> Reward;

    /// Initialize the RewardFunction internals accordingly to the structure of a NetworkProcess
    ///
    /// # Arguments
    ///
    /// * `p`: any structure that implements the trait `process::NetworkProcess`
    fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}
/// Evaluation of the expected reward of a `NetworkProcess`.
pub trait RewardEvaluation {
    /// Compute the expected reward for every state of the process state space.
    fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64>;

    /// Compute the expected reward starting from a specific `state`.
    fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
        state: &process::NetworkProcessState,
    ) -> f64;
}

@ -0,0 +1,205 @@
use std::collections::HashMap;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use statrs::distribution::ContinuousCDF;
use crate::params::{self, ParamsTrait};
use crate::process;
use crate::{
process::NetworkProcessState,
reward::RewardEvaluation,
sampling::{ForwardSampler, Sampler},
};
/// Criterion used to accumulate the reward over time.
pub enum RewardCriteria {
    /// Undiscounted reward accumulated up to the time horizon.
    FiniteHorizon,
    /// Reward exponentially discounted at rate `discount_factor`.
    InfiniteHorizon { discount_factor: f64 },
}
/// Monte Carlo evaluator of the expected reward of a `NetworkProcess`.
pub struct MonteCarloReward {
    // Maximum number of simulated trajectories per evaluated state.
    max_iterations: usize,
    // Error threshold used by the early-stop criterion.
    max_err_stop: f64,
    // Confidence level used by the early-stop criterion.
    alpha_stop: f64,
    // Time horizon of each simulated trajectory.
    end_time: f64,
    // Finite-horizon or discounted infinite-horizon accumulation.
    reward_criteria: RewardCriteria,
    // Optional RNG seed for reproducible simulations.
    seed: Option<u64>,
}
impl MonteCarloReward {
    /// Build a new `MonteCarloReward` evaluator.
    ///
    /// # Arguments
    ///
    /// * `max_iterations`: maximum number of sampled trajectories per state.
    /// * `max_err_stop`: error threshold for the early-stop criterion.
    /// * `alpha_stop`: confidence level for the early-stop criterion.
    /// * `end_time`: time horizon of each trajectory.
    /// * `reward_criteria`: how the reward is accumulated over time.
    /// * `seed`: optional seed for reproducible sampling.
    pub fn new(
        max_iterations: usize,
        max_err_stop: f64,
        alpha_stop: f64,
        end_time: f64,
        reward_criteria: RewardCriteria,
        seed: Option<u64>,
    ) -> MonteCarloReward {
        MonteCarloReward {
            max_iterations,
            max_err_stop,
            alpha_stop,
            end_time,
            reward_criteria,
            seed,
        }
    }
}
impl RewardEvaluation for MonteCarloReward {
    /// Evaluate the expected reward for every state of the state space.
    ///
    /// The state space is enumerated through a mixed-radix decoding over the
    /// variable domains; each state is evaluated in parallel with `rayon`.
    fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64> {
        // Domain of every variable, expressed as the list of its states.
        let variables_domain: Vec<Vec<params::StateType>> = network_process
            .get_node_indices()
            .map(|x| match network_process.get_node(x) {
                params::Params::DiscreteStatesContinousTime(x) => (0..x
                    .get_reserved_space_as_parent())
                    .map(params::StateType::Discrete)
                    .collect(),
            })
            .collect();

        let n_states: usize = variables_domain.iter().map(|x| x.len()).product();

        (0..n_states)
            .into_par_iter()
            .map(|s| {
                // Decode the flat index `s` into a full network state
                // (mixed-radix decomposition over the variable domains).
                let state: process::NetworkProcessState = variables_domain
                    .iter()
                    .fold((s, vec![]), |mut acc, x| {
                        let idx_s = acc.0 % x.len();
                        acc.1.push(x[idx_s].clone());
                        acc.0 /= x.len();
                        acc
                    })
                    .1;

                let r = self.evaluate_state(network_process, reward_function, &state);
                (state, r)
            })
            .collect()
    }

    /// Estimate the expected reward starting from `state` via Monte Carlo
    /// forward sampling, with an early-stop test based on a normal
    /// approximation of the estimator's confidence interval.
    fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
        state: &NetworkProcessState,
    ) -> f64 {
        // `Option<u64>` is `Copy`: no need to clone the seed.
        let mut sampler = ForwardSampler::new(network_process, self.seed, Some(state.clone()));
        let mut expected_value = 0.0;
        let mut squared_expected_value = 0.0;
        let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap();

        for i in 0..self.max_iterations {
            sampler.reset();
            let mut ret = 0.0;
            let mut previous = sampler.next().unwrap();
            while previous.t < self.end_time {
                let current = sampler.next().unwrap();
                if current.t > self.end_time {
                    // Last (truncated) sojourn: only the instantaneous reward
                    // up to `end_time` counts; no transition happens in range.
                    let r = reward_function.call(&previous.state, None);
                    let discount = match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => self.end_time - previous.t,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            (-discount_factor * previous.t).exp()
                                - (-discount_factor * self.end_time).exp()
                        }
                    };
                    ret += discount * r.instantaneous_reward;
                } else {
                    let r = reward_function.call(&previous.state, Some(&current.state));
                    // Integral of the discount over the sojourn interval.
                    let discount = match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => current.t - previous.t,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            (-discount_factor * previous.t).exp()
                                - (-discount_factor * current.t).exp()
                        }
                    };
                    ret += discount * r.instantaneous_reward;
                    // Transition reward is discounted at the transition time.
                    ret += match self.reward_criteria {
                        RewardCriteria::FiniteHorizon => 1.0,
                        RewardCriteria::InfiniteHorizon { discount_factor } => {
                            (-discount_factor * current.t).exp()
                        }
                    } * r.transition_reward;
                }
                previous = current;
            }

            // Online (running) mean of the return and of its square.
            let float_i = i as f64;
            expected_value =
                expected_value * float_i / (float_i + 1.0) + ret / (float_i + 1.0);
            squared_expected_value = squared_expected_value * float_i / (float_i + 1.0)
                + ret.powi(2) / (float_i + 1.0);

            if i > 2 {
                // Sample variance of the return.
                let var =
                    (float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2));
                // Stop once the two-sided error bound at confidence
                // `alpha_stop` falls within `max_err_stop`.
                if self.alpha_stop
                    - 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt())
                    > 0.0
                {
                    return expected_value;
                }
            }
        }
        expected_value
    }
}
/// Reward evaluator that rescales the rewards computed by an inner evaluator
/// relative to each state's immediate neighborhood.
pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> {
    // Evaluator producing the absolute rewards to be rescaled.
    inner_reward: RE,
}
impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> {
    /// Wrap `inner_reward`, whose absolute rewards will be rescaled.
    pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> {
        NeighborhoodRelativeReward { inner_reward }
    }
}
impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> {
    /// For every state, return the maximum ratio between its absolute reward
    /// and the absolute reward of any state differing in at most one
    /// variable, floored at 1.0.
    fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        network_process: &N,
        reward_function: &R,
    ) -> HashMap<process::NetworkProcessState, f64> {
        let absolute_reward = self
            .inner_reward
            .evaluate_state_space(network_process, reward_function);

        //This approach optimize memory. Maybe optimizing execution time can be better.
        absolute_reward
            .iter()
            .map(|(k1, v1)| {
                let mut max_val: f64 = 1.0;
                absolute_reward.iter().for_each(|(k2, v2)| {
                    // Hamming distance between the two states.
                    let count_diff: usize = k1
                        .iter()
                        .zip(k2.iter())
                        .map(|(s1, s2)| if s1 == s2 { 0 } else { 1 })
                        .sum();

                    // `count_diff < 2` keeps the state itself and its
                    // one-variable neighbors.
                    // NOTE(review): `v1 / v2` is not guarded against
                    // `v2 == 0.0` (yields inf/NaN) — confirm rewards are
                    // non-zero for all states.
                    if count_diff < 2 {
                        max_val = max_val.max(v1 / v2);
                    }
                });

                (k1.clone(), max_val)
            })
            .collect()
    }

    /// Not supported for this evaluator: the relative reward needs the whole
    /// state space.
    fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>(
        &self,
        _network_process: &N,
        _reward_function: &R,
        _state: &process::NetworkProcessState,
    ) -> f64 {
        unimplemented!();
    }
}

@ -0,0 +1,106 @@
//! Module for dealing with reward functions
use crate::{
params::{self, ParamsTrait},
process,
reward::{Reward, RewardFunction},
};
use ndarray;
/// Reward function over a factored state space
///
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
/// of the underling `NetworkProcess`
///
/// # Arguments
///
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
/// reward of a node
pub struct FactoredRewardFunction {
transition_reward: Vec<ndarray::Array2<f64>>,
instantaneous_reward: Vec<ndarray::Array1<f64>>,
}
impl FactoredRewardFunction {
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> {
&self.transition_reward[node_idx]
}
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> {
&mut self.transition_reward[node_idx]
}
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> {
&self.instantaneous_reward[node_idx]
}
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
&mut self.instantaneous_reward[node_idx]
}
}
impl RewardFunction for FactoredRewardFunction {
fn call(
&self,
current_state: &process::NetworkProcessState,
previous_state: Option<&process::NetworkProcessState>,
) -> Reward {
let instantaneous_reward: f64 = current_state
.iter()
.enumerate()
.map(|(idx, x)| {
let x = match x {
params::StateType::Discrete(x) => x,
};
self.instantaneous_reward[idx][*x]
})
.sum();
if let Some(previous_state) = previous_state {
let transition_reward = previous_state
.iter()
.zip(current_state.iter())
.enumerate()
.find_map(|(idx, (p, c))| -> Option<f64> {
let p = match p {
params::StateType::Discrete(p) => p,
};
let c = match c {
params::StateType::Discrete(c) => c,
};
if p != c {
Some(self.transition_reward[idx][[*p, *c]])
} else {
None
}
})
.unwrap_or(0.0);
Reward {
transition_reward,
instantaneous_reward,
}
} else {
Reward {
transition_reward: 0.0,
instantaneous_reward,
}
}
}
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self {
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![];
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![];
for i in p.get_node_indices() {
//This works only for discrete nodes!
let size: usize = p.get_node(i).get_reserved_space_as_parent();
instantaneous_reward.push(ndarray::Array1::zeros(size));
transition_reward.push(ndarray::Array2::zeros((size, size)));
}
FactoredRewardFunction {
transition_reward,
instantaneous_reward,
}
}
}

@ -0,0 +1,133 @@
//! Module containing methods for the sampling.
use crate::{
params::ParamsTrait,
process::{NetworkProcess, NetworkProcessState},
};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
/// One point of a sampled trajectory.
#[derive(Clone)]
pub struct Sample {
    /// Time of the sample.
    pub t: f64,
    /// State of the network process at time `t`.
    pub state: NetworkProcessState,
}
/// A `Sampler` is an iterator over `Sample`s that can be restarted.
pub trait Sampler: Iterator<Item = Sample> {
    /// Restore the sampler to its initial configuration (time 0).
    fn reset(&mut self);
}
/// Trajectory generator for a `NetworkProcess` based on forward sampling.
pub struct ForwardSampler<'a, T>
where
    T: NetworkProcess,
{
    // The process being sampled.
    net: &'a T,
    // Seedable RNG used for all random draws.
    rng: ChaCha8Rng,
    // Time of the current sample.
    current_time: f64,
    // Current state of every variable of the process.
    current_state: NetworkProcessState,
    // Candidate next-transition time per node; `None` means "to be re-drawn".
    next_transitions: Vec<Option<f64>>,
    // State restored by `reset`; a uniformly random state is drawn when `None`.
    initial_state: Option<NetworkProcessState>,
}
impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
    /// Create a new `ForwardSampler` over the network process `net`.
    ///
    /// # Arguments
    ///
    /// * `net`: the network process to sample from.
    /// * `seed`: optional RNG seed; when `None` the generator is seeded from
    ///   system entropy.
    /// * `initial_state`: optional fixed starting state; when `None` each
    ///   reset draws a uniformly random state.
    pub fn new(
        net: &'a T,
        seed: Option<u64>,
        initial_state: Option<NetworkProcessState>,
    ) -> ForwardSampler<'a, T> {
        let rng: ChaCha8Rng = match seed {
            //If a seed is present use it to initialize the random generator.
            Some(seed) => SeedableRng::seed_from_u64(seed),
            //Otherwise create a new random generator using the method `from_entropy`
            None => SeedableRng::from_entropy(),
        };
        let mut fs = ForwardSampler {
            net,
            rng,
            current_time: 0.0,
            current_state: vec![],
            next_transitions: vec![],
            initial_state,
        };
        // Populate `current_state`/`next_transitions` before the first sample.
        fs.reset();
        fs
    }
}
impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
    type Item = Sample;

    /// Return the current sample and advance the process to its next
    /// transition. Never returns `None`: the chain can be sampled forever.
    fn next(&mut self) -> Option<Self::Item> {
        // The returned sample is the state *before* this step's transition.
        // `f64` is `Copy`, so no clone is needed for the time.
        let ret_time = self.current_time;
        let ret_state = self.current_state.clone();

        // Lazily (re)draw a candidate transition time for every node whose
        // previous candidate has been invalidated.
        for (idx, val) in self.next_transitions.iter_mut().enumerate() {
            if val.is_none() {
                *val = Some(
                    self.net
                        .get_node(idx)
                        .get_random_residence_time(
                            self.net
                                .get_node(idx)
                                .state_to_index(&self.current_state[idx]),
                            self.net.get_param_index_network(idx, &self.current_state),
                            &mut self.rng,
                        )
                        .unwrap()
                        + self.current_time,
                );
            }
        }

        // The node with the earliest candidate time is the one that fires.
        let next_node_transition = self
            .next_transitions
            .iter()
            .enumerate()
            .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
            .unwrap()
            .0;

        self.current_time = self.next_transitions[next_node_transition].unwrap();

        // Draw the new state of the transitioning node.
        self.current_state[next_node_transition] = self
            .net
            .get_node(next_node_transition)
            .get_random_state(
                self.net
                    .get_node(next_node_transition)
                    .state_to_index(&self.current_state[next_node_transition]),
                self.net
                    .get_param_index_network(next_node_transition, &self.current_state),
                &mut self.rng,
            )
            .unwrap();

        // Invalidate the candidate times of the transitioned node and of its
        // children, whose dynamics depend on the changed state.
        self.next_transitions[next_node_transition] = None;
        for child in self.net.get_children_set(next_node_transition) {
            self.next_transitions[child] = None;
        }

        Some(Sample {
            t: ret_time,
            state: ret_state,
        })
    }
}
impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> {
    /// Restart the sampler at time 0, restoring the configured initial state
    /// or drawing a uniformly random one when none was given.
    fn reset(&mut self) {
        self.current_time = 0.0;
        match &self.initial_state {
            None => {
                self.current_state = self
                    .net
                    .get_node_indices()
                    .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng))
                    .collect()
            }
            Some(is) => self.current_state = is.clone(),
        };
        // Force every node's transition time to be re-sampled on the next step.
        self.next_transitions = self.net.get_node_indices().map(|_| None).collect();
    }
}

@ -0,0 +1,13 @@
//! Learn the structure of the network.
pub mod constraint_based_algorithm;
pub mod hypothesis_test;
pub mod score_based_algorithm;
pub mod score_function;
use crate::{process, tools::Dataset};
/// Learn the structure (edges) of a `NetworkProcess` from a `Dataset`.
pub trait StructureLearningAlgorithm {
    /// Consume `net`, learn its structure from `dataset`, and return it.
    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
    where
        T: process::NetworkProcess;
}

@ -0,0 +1,348 @@
//! Module containing constraint based algorithms like CTPC and Hiton.
use crate::params::Params;
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
use std::collections::{BTreeSet, HashMap};
use std::mem;
use std::usize;
use super::hypothesis_test::*;
use crate::parameter_learning::ParameterLearning;
use crate::process;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::tools::Dataset;
/// Two-level cache of learned parameters, keyed by parent set.
///
/// One level holds results for parent sets of size `parent_set_size_small`,
/// the other for the next size up; the levels rotate as the parent-set size
/// explored by the caller grows.
pub struct Cache<'a, P: ParameterLearning> {
    // Method used to learn parameters on a cache miss.
    parameter_learning: &'a P,
    // Cached results for parent sets of size `parent_set_size_small`.
    cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>,
    // Cached results for parent sets of size `parent_set_size_small + 1`.
    cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>,
    parent_set_size_small: usize,
}
impl<'a, P: ParameterLearning> Cache<'a, P> {
    /// Create an empty cache wrapping the given parameter-learning method.
    pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
        Cache {
            parameter_learning,
            cache_persistent_small: HashMap::new(),
            cache_persistent_big: HashMap::new(),
            parent_set_size_small: 0,
        }
    }

    /// Learn (or fetch from cache) the parameters of `node` given `parent_set`.
    ///
    /// When a parent set larger than `parent_set_size_small + 1` is requested,
    /// the cache levels rotate: the big level becomes the small one and a
    /// fresh big level is started.
    ///
    /// # Panics
    ///
    /// Panics if `parent_set` is `None` (its size is read with `unwrap`).
    pub fn fit<T: process::NetworkProcess>(
        &mut self,
        net: &T,
        dataset: &Dataset,
        node: usize,
        parent_set: Option<BTreeSet<usize>>,
    ) -> Params {
        let parent_set_len = parent_set.as_ref().unwrap().len();
        if parent_set_len > self.parent_set_size_small + 1 {
            //self.cache_persistent_small = self.cache_persistent_big;
            // Rotate the levels in place; swap avoids moving out of `self`.
            mem::swap(
                &mut self.cache_persistent_small,
                &mut self.cache_persistent_big,
            );
            self.cache_persistent_big = HashMap::new();
            self.parent_set_size_small += 1;
        }

        if parent_set_len > self.parent_set_size_small {
            // Sized for the "big" level.
            match self.cache_persistent_big.get(&parent_set) {
                // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
                // not cloning requires a minor and reasoned refactoring across the library
                Some(params) => params.clone(),
                None => {
                    let params =
                        self.parameter_learning
                            .fit(net, dataset, node, parent_set.clone());
                    self.cache_persistent_big.insert(parent_set, params.clone());
                    params
                }
            }
        } else {
            // Sized for the "small" level.
            match self.cache_persistent_small.get(&parent_set) {
                // TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
                // not cloning requires a minor and reasoned refactoring across the library
                Some(params) => params.clone(),
                None => {
                    let params =
                        self.parameter_learning
                            .fit(net, dataset, node, parent_set.clone());
                    self.cache_persistent_small
                        .insert(parent_set, params.clone());
                    params
                }
            }
        }
    }
}
/// Continuous-Time Peter Clark algorithm.
///
/// A method to learn the structure of the network.
///
/// # Arguments
///
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test.
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test.
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// use reCTBN::parameter_learning::BayesianApproach;
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
/// #
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("A"));
/// # domain.insert(String::from("B"));
/// # domain.insert(String::from("C"));
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
/// # //Create the node n1 using the parameters
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("D"));
/// # domain.insert(String::from("E"));
/// # domain.insert(String::from("F"));
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("G"));
/// # domain.insert(String::from("H"));
/// # domain.insert(String::from("I"));
/// # domain.insert(String::from("F"));
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # // Initialize a ctbn
/// # let mut net = CtbnNetwork::new();
/// #
/// # // Add the nodes and their edges
/// # let n1 = net.add_node(n1).unwrap();
/// # let n2 = net.add_node(n2).unwrap();
/// # let n3 = net.add_node(n3).unwrap();
/// # net.add_edge(n1, n2);
/// # net.add_edge(n1, n3);
/// # net.add_edge(n2, n3);
/// #
/// # match &mut net.get_node_mut(n1) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-3.0, 2.0, 1.0],
/// # [1.5, -2.0, 0.5],
/// # [0.4, 0.6, -1.0]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # match &mut net.get_node_mut(n2) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-1.0, 0.5, 0.5],
/// # [3.0, -4.0, 1.0],
/// # [0.9, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 2.0, 4.0],
/// # [1.5, -2.0, 0.5],
/// # [3.0, 1.0, -4.0]
/// # ],
/// # [
/// # [-1.0, 0.1, 0.9],
/// # [2.0, -2.5, 0.5],
/// # [0.9, 0.1, -1.0]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # match &mut net.get_node_mut(n3) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-1.0, 0.5, 0.3, 0.2],
/// # [0.5, -4.0, 2.5, 1.0],
/// # [2.5, 0.5, -4.0, 1.0],
/// # [0.7, 0.2, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 2.0, 3.0, 1.0],
/// # [1.5, -3.0, 0.5, 1.0],
/// # [2.0, 1.3, -5.0, 1.7],
/// # [2.5, 0.5, 1.0, -4.0]
/// # ],
/// # [
/// # [-1.3, 0.3, 0.1, 0.9],
/// # [1.4, -4.0, 0.5, 2.1],
/// # [1.0, 1.5, -3.0, 0.5],
/// # [0.4, 0.3, 0.1, -0.8]
/// # ],
/// # [
/// # [-2.0, 1.0, 0.7, 0.3],
/// # [1.3, -5.9, 2.7, 1.9],
/// # [2.0, 1.5, -4.0, 0.5],
/// # [0.2, 0.7, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 1.0, 2.0, 3.0],
/// # [0.5, -3.0, 1.0, 1.5],
/// # [1.4, 2.1, -4.3, 0.8],
/// # [0.5, 1.0, 2.5, -4.0]
/// # ],
/// # [
/// # [-1.3, 0.9, 0.3, 0.1],
/// # [0.1, -1.3, 0.2, 1.0],
/// # [0.5, 1.0, -3.0, 1.5],
/// # [0.1, 0.4, 0.3, -0.8]
/// # ],
/// # [
/// # [-2.0, 1.0, 0.6, 0.4],
/// # [2.6, -7.1, 1.4, 3.1],
/// # [5.0, 1.0, -8.0, 2.0],
/// # [1.4, 0.4, 0.2, -2.0]
/// # ],
/// # [
/// # [-3.0, 1.0, 1.5, 0.5],
/// # [3.0, -6.0, 1.0, 2.0],
/// # [0.3, 0.5, -1.9, 1.1],
/// # [5.0, 1.0, 2.0, -8.0]
/// # ],
/// # [
/// # [-2.6, 0.6, 0.2, 1.8],
/// # [2.0, -6.0, 3.0, 1.0],
/// # [0.1, 0.5, -1.3, 0.7],
/// # [0.8, 0.6, 0.2, -1.6]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # // Generate the trajectory
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
///
/// // Initialize the hypothesis tests to pass to the CTPC with their
/// // respective significance level `alpha`
/// let f = F::new(1e-6);
/// let chi_sq = ChiSquare::new(1e-4);
/// // Use the bayesian approach to learn the parameters
/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
///
/// //Initialize CTPC
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
///
/// // Learn the structure of the network from the generated trajectory
/// let net = ctpc.fit_transform(net, &data);
/// #
/// # // Compare the generated network with the original one
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
/// ```
pub struct CTPC<P: ParameterLearning> {
    // Method used to learn the parameters of each candidate parent set.
    parameter_learning: P,
    // F-test used to prune candidate parents.
    Ftest: F,
    // Chi-squared test used to prune candidate parents.
    Chi2test: ChiSquare,
}
impl<P: ParameterLearning> CTPC<P> {
    /// Build a CTPC learner from a parameter-learning method and the two
    /// independence tests (F-test and chi-squared test) used for pruning.
    pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> {
        CTPC {
            parameter_learning,
            Ftest,
            Chi2test,
        }
    }
}
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
    /// Learn the structure of `net` from `dataset` with the CTPC algorithm.
    ///
    /// Each node's parent set is learned independently (in parallel): starting
    /// from all other nodes as candidates, a candidate parent is removed as
    /// soon as both the F-test and the chi-squared test declare it independent
    /// of the child given some separation set of increasing size.
    ///
    /// # Panics
    ///
    /// Panics when the number of variables in `dataset` differs from the
    /// number of nodes in `net`.
    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
    where
        T: process::NetworkProcess,
    {
        //Check the coherence between dataset and network
        if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
            panic!("Dataset and Network must have the same number of variables.")
        }

        //Make the network mutable.
        let mut net = net;

        net.initialize_adj_matrix();

        let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
        learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| {
            // Each parallel worker keeps its own parameter cache.
            let mut cache = Cache::new(&self.parameter_learning);
            // Start from the complete candidate set: every node but the child.
            let mut candidate_parent_set: BTreeSet<usize> = net
                .get_node_indices()
                .into_iter()
                .filter(|x| x != &child_node)
                .collect();
            let mut separation_set_size = 0;
            while separation_set_size < candidate_parent_set.len() {
                // Work on a copy so removals don't invalidate the iteration.
                let mut candidate_parent_set_TMP = candidate_parent_set.clone();
                for parent_node in candidate_parent_set.iter() {
                    // Try every separation set of the current size drawn from
                    // the remaining candidates (excluding the tested parent).
                    for separation_set in candidate_parent_set
                        .iter()
                        .filter(|x| x != &parent_node)
                        .map(|x| *x)
                        .combinations(separation_set_size)
                    {
                        let separation_set = separation_set.into_iter().collect();
                        // Both tests must agree on independence to drop the
                        // candidate parent.
                        if self.Ftest.call(
                            &net,
                            child_node,
                            *parent_node,
                            &separation_set,
                            dataset,
                            &mut cache,
                        ) && self.Chi2test.call(
                            &net,
                            child_node,
                            *parent_node,
                            &separation_set,
                            dataset,
                            &mut cache,
                        ) {
                            candidate_parent_set_TMP.remove(parent_node);
                            break;
                        }
                    }
                }
                candidate_parent_set = candidate_parent_set_TMP;
                separation_set_size += 1;
            }
            (child_node, candidate_parent_set)
        }));
        // Materialize the surviving candidate parents as edges.
        for (child_node, candidate_parent_set) in learned_parent_sets {
            for parent_node in candidate_parent_set.iter() {
                net.add_edge(*parent_node, child_node);
            }
        }
        net
    }
}

@ -0,0 +1,261 @@
//! Module for constraint based algorithms containing hypothesis test algorithms like chi-squared test, F test, etc...
use std::collections::BTreeSet;
use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
use crate::params::*;
use crate::structure_learning::constraint_based_algorithm::Cache;
use crate::{parameter_learning, process, tools::Dataset};
/// Hypothesis test for conditional independence between two nodes.
pub trait HypothesisTest {
    /// Return `true` when `child_node` is judged independent of `parent_node`
    /// given `separation_set`, according to parameters learned from `dataset`
    /// (fitted through `cache`).
    fn call<T, P>(
        &self,
        net: &T,
        child_node: usize,
        parent_node: usize,
        separation_set: &BTreeSet<usize>,
        dataset: &Dataset,
        cache: &mut Cache<P>,
    ) -> bool
    where
        T: process::NetworkProcess,
        P: parameter_learning::ParameterLearning;
}
/// Does the chi-squared test (χ2 test).
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
/// in other words is the risk of concluding that an association between the variables exists
/// when there is no actual association.
pub struct ChiSquare {
    // Significance level of the test.
    alpha: f64,
}
/// Does the F-test.
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
/// in other words is the risk of concluding that an association between the variables exists
/// when there is no actual association.
pub struct F {
    // Significance level of the test.
    alpha: f64,
}
impl F {
    /// Build an F-test with significance level `alpha`.
    pub fn new(alpha: f64) -> F {
        F { alpha }
    }

    /// Compare two matrices extracted from two 3rd-order tensors.
    ///
    /// # Arguments
    ///
    /// * `i` - Position of the matrix of `M1` to compare with `M2`.
    /// * `M1` - 3rd-order tensor 1.
    /// * `j` - Position of the matrix of `M2` to compare with `M1`.
    /// * `M2` - 3rd-order tensor 2.
    ///
    /// # Returns
    ///
    /// * `true` - when the matrices `M1` and `M2` are very similar, then **independendent**.
    /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
    pub fn compare_matrices(
        &self,
        i: usize,
        M1: &Array3<usize>,
        cim_1: &Array3<f64>,
        j: usize,
        M2: &Array3<usize>,
        cim_2: &Array3<f64>,
    ) -> bool {
        let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
        let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
        let cim_1 = cim_1.index_axis(Axis(0), i);
        let cim_2 = cim_2.index_axis(Axis(0), j);
        // Per-state totals of the transition counts (degrees of freedom).
        let r1 = M1.sum_axis(Axis(1));
        let r2 = M2.sum_axis(Axis(1));
        // Diagonals of the two CIMs.
        let q1 = cim_1.diag();
        let q2 = cim_2.diag();
        for idx in 0..r1.shape()[0] {
            // Ratio of the two rates, tested two-sided against an F
            // distribution with (r1, r2) degrees of freedom.
            let s = q2[idx] / q1[idx];
            // NOTE(review): `FisherSnedecor::new` errors (and this unwraps)
            // when a degrees-of-freedom value is not positive — confirm the
            // counts are always > 0 here.
            let F = FisherSnedecor::new(r1[idx], r2[idx]).unwrap();
            let s = F.cdf(s);
            let lim_sx = self.alpha / 2.0;
            let lim_dx = 1.0 - (self.alpha / 2.0);
            if s < lim_sx || s > lim_dx {
                return false;
            }
        }
        true
    }
}
impl HypothesisTest for F {
    /// Test whether `child_node` is independent of `parent_node` given
    /// `separation_set`, comparing the CIMs learned with and without the
    /// candidate parent through the F-test.
    ///
    /// Returns `true` when every pair of matching parameter matrices is
    /// statistically indistinguishable (independence), `false` otherwise.
    fn call<T, P>(
        &self,
        net: &T,
        child_node: usize,
        parent_node: usize,
        separation_set: &BTreeSet<usize>,
        dataset: &Dataset,
        cache: &mut Cache<P>,
    ) -> bool
    where
        T: process::NetworkProcess,
        P: parameter_learning::ParameterLearning,
    {
        // Parameters of the child conditioned on the separation set only.
        let P_small = match cache.fit(net, dataset, child_node, Some(separation_set.clone())) {
            Params::DiscreteStatesContinousTime(node) => node,
        };
        // Parameters conditioned on separation set + candidate parent.
        let mut extended_separation_set = separation_set.clone();
        extended_separation_set.insert(parent_node);

        let P_big = match cache.fit(
            net,
            dataset,
            child_node,
            Some(extended_separation_set.clone()),
        ) {
            Params::DiscreteStatesContinousTime(node) => node,
        };
        // Cardinality product of the parents preceding `parent_node`, used to
        // project each extended parent configuration onto the small one.
        let partial_cardinality_product: usize = extended_separation_set
            .iter()
            .take_while(|x| **x != parent_node)
            .map(|x| net.get_node(*x).get_reserved_space_as_parent())
            .product();
        for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
            // Matching configuration index in the smaller parent set.
            let idx_M_small: usize = idx_M_big % partial_cardinality_product
                + (idx_M_big
                    / (partial_cardinality_product
                        * net.get_node(parent_node).get_reserved_space_as_parent()))
                    * partial_cardinality_product;
            if !self.compare_matrices(
                idx_M_small,
                P_small.get_transitions().as_ref().unwrap(),
                P_small.get_cim().as_ref().unwrap(),
                idx_M_big,
                P_big.get_transitions().as_ref().unwrap(),
                P_big.get_cim().as_ref().unwrap(),
            ) {
                return false;
            }
        }
        true
    }
}
impl ChiSquare {
    /// Build a chi-squared test with significance level `alpha`.
    pub fn new(alpha: f64) -> ChiSquare {
        ChiSquare { alpha }
    }

    /// Compare two matrices extracted from two 3rd-order tensors.
    ///
    /// # Arguments
    ///
    /// * `i` - Position of the matrix of `M1` to compare with `M2`.
    /// * `M1` - 3rd-order tensor 1.
    /// * `j` - Position of the matrix of `M2` to compare with `M1`.
    /// * `M2` - 3rd-order tensor 2.
    ///
    /// # Returns
    ///
    /// * `true` - when the matrices `M1` and `M2` are very similar, then **independendent**.
    /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
    pub fn compare_matrices(
        &self,
        i: usize,
        M1: &Array3<usize>,
        j: usize,
        M2: &Array3<usize>,
    ) -> bool {
        // Bregoli, A., Scutari, M. and Stella, F., 2021.
        // A constraint-based algorithm for the structural learning of
        // continuous-time Bayesian networks.
        // International Journal of Approximate Reasoning, 138, pp.105-122.
        // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
        let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
        let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
        // Per-row scaling factors sqrt(N1/N2) of the two-sample chi2 statistic.
        let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
        let K = K.mapv(f64::sqrt);
        // Reshape to column vector.
        let K = {
            let n = K.len();
            K.into_shape((n, 1)).unwrap()
        };
        let L = 1.0 / &K;
        // Two-sample chi-squared statistic, computed cell-wise.
        let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
        // Diagonal cells (self-transitions) do not contribute.
        X_2.diag_mut().fill(0.0);
        let X_2 = X_2.sum_axis(Axis(1));
        let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
        // Independent when no row's statistic falls in the rejection region.
        let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
        ret
    }
}
impl HypothesisTest for ChiSquare {
    /// Test whether `child_node` is independent of `parent_node` given
    /// `separation_set`, comparing the transition counts learned with and
    /// without the candidate parent through the chi-squared test.
    ///
    /// Returns `true` when every pair of matching count matrices is
    /// statistically indistinguishable (independence), `false` otherwise.
    fn call<T, P>(
        &self,
        net: &T,
        child_node: usize,
        parent_node: usize,
        separation_set: &BTreeSet<usize>,
        dataset: &Dataset,
        cache: &mut Cache<P>,
    ) -> bool
    where
        T: process::NetworkProcess,
        P: parameter_learning::ParameterLearning,
    {
        // Parameters of the child conditioned on the separation set only.
        let P_small = match cache.fit(net, dataset, child_node, Some(separation_set.clone())) {
            Params::DiscreteStatesContinousTime(node) => node,
        };
        // Parameters conditioned on separation set + candidate parent.
        let mut extended_separation_set = separation_set.clone();
        extended_separation_set.insert(parent_node);

        let P_big = match cache.fit(
            net,
            dataset,
            child_node,
            Some(extended_separation_set.clone()),
        ) {
            Params::DiscreteStatesContinousTime(node) => node,
        };
        // Cardinality product of the parents preceding `parent_node`, used to
        // project each extended parent configuration onto the small one.
        let partial_cardinality_product: usize = extended_separation_set
            .iter()
            .take_while(|x| **x != parent_node)
            .map(|x| net.get_node(*x).get_reserved_space_as_parent())
            .product();
        for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
            // Matching configuration index in the smaller parent set.
            let idx_M_small: usize = idx_M_big % partial_cardinality_product
                + (idx_M_big
                    / (partial_cardinality_product
                        * net.get_node(parent_node).get_reserved_space_as_parent()))
                    * partial_cardinality_product;
            if !self.compare_matrices(
                idx_M_small,
                P_small.get_transitions().as_ref().unwrap(),
                idx_M_big,
                P_big.get_transitions().as_ref().unwrap(),
            ) {
                return false;
            }
        }
        true
    }
}

@ -0,0 +1,93 @@
//! Module containing score based algorithms like Hill Climbing and Tabu Search.
use std::collections::BTreeSet;
use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::{process, tools::Dataset};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
/// Hill climbing structure learning algorithm: greedily adds/removes parents
/// of each node to maximise the given score function.
pub struct HillClimbing<S: ScoreFunction> {
    /// Score function used to evaluate candidate parent sets.
    score_function: S,
    /// Optional upper bound on the size of each node's parent set.
    max_parent_set: Option<usize>,
}
impl<S: ScoreFunction> HillClimbing<S> {
    /// Build a hill-climbing learner from a score function and an optional
    /// bound on the number of parents per node (`None` means unbounded).
    pub fn new(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
        Self { score_function, max_parent_set }
    }
}
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
    /// Learn the structure of `net` from `dataset`: for every node, greedily
    /// toggle single parents (in parallel across nodes) until the score stops
    /// improving, then write the learned parent sets into the network.
    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
    where
        T: process::NetworkProcess,
    {
        //Check the coherence between dataset and network
        if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
            panic!("Dataset and Network must have the same number of variables.")
        }
        //Make the network mutable.
        let mut net = net;
        //Check if the max_parent_set constraint is present.
        let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
        //Reset the adj matrix
        net.initialize_adj_matrix();
        let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
        //Iterate over each node to learn their parent set.
        //Nodes are processed in parallel with rayon; each closure only reads
        //`net`, so the learned sets are collected and applied afterwards.
        learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| {
            //Initialize an empty parent set.
            let mut parent_set: BTreeSet<usize> = BTreeSet::new();
            //Compute the score for the empty parent set
            let mut current_score = self.score_function.call(&net, node, &parent_set, dataset);
            //Set the old score to -\infty.
            let mut old_score = f64::NEG_INFINITY;
            //Iterate until convergence
            while current_score > old_score {
                //Save the current_score.
                old_score = current_score;
                //Iterate over each node.
                for parent in net.get_node_indices() {
                    //Continue if the parent and the node are the same.
                    if parent == node {
                        continue;
                    }
                    //Try to remove parent from the parent_set.
                    let is_removed = parent_set.remove(&parent);
                    //If parent was not in the parent_set add it.
                    //(The cardinality bound only restricts additions.)
                    if !is_removed && parent_set.len() < max_parent_set {
                        parent_set.insert(parent);
                    }
                    //Compute the score with the modified parent_set.
                    let tmp_score = self.score_function.call(&net, node, &parent_set, dataset);
                    //If tmp_score is worst than current_score revert the change to the parent set
                    if tmp_score < current_score {
                        if is_removed {
                            parent_set.insert(parent);
                        } else {
                            parent_set.remove(&parent);
                        }
                    }
                    //Otherwise save the computed score as current_score
                    else {
                        current_score = tmp_score;
                    }
                }
            }
            (node, parent_set)
        }));
        //Materialise the learned parent sets as edges of the network.
        for (child_node, candidate_parent_set) in learned_parent_sets {
            for parent_node in candidate_parent_set.iter() {
                net.add_edge(*parent_node, child_node);
            }
        }
        return net;
    }
}

@ -0,0 +1,146 @@
//! Module for score based algorithms containing score functions algorithms like Log Likelihood, BIC, etc...
use std::collections::BTreeSet;
use ndarray::prelude::*;
use statrs::function::gamma;
use crate::{parameter_learning, params, process, tools};
/// Scoring criterion evaluating how well a candidate parent set of a node fits
/// a dataset. `Sync` is required so scores can be computed in parallel.
pub trait ScoreFunction: Sync {
    /// Return the score (higher is better) of assigning `parent_set` as the
    /// parents of `node` in `net`, evaluated against `dataset`.
    fn call<T>(
        &self,
        net: &T,
        node: usize,
        parent_set: &BTreeSet<usize>,
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: process::NetworkProcess;
}
/// Marginal log-likelihood score for a node's CIM given a candidate parent set.
pub struct LogLikelihood {
    /// Prior pseudo-count (imaginary number of observed transitions).
    alpha: usize,
    /// Prior pseudo residence time; must be non-negative.
    tau: f64,
}
impl LogLikelihood {
    /// Create a log-likelihood score with prior pseudo-count `alpha` and
    /// prior pseudo residence time `tau`.
    ///
    /// # Panics
    ///
    /// Panics when `tau` is negative.
    pub fn new(alpha: usize, tau: f64) -> LogLikelihood {
        //Tau must be >=0.0
        if tau < 0.0 {
            panic!("tau must be >=0.0");
        }
        LogLikelihood { alpha, tau }
    }

    /// Compute the marginal log-likelihood of `node` having `parent_set` as
    /// parents, returning the score together with the transition counts `M`
    /// (reused by BIC to count free parameters).
    fn compute_score<T>(
        &self,
        net: &T,
        node: usize,
        parent_set: &BTreeSet<usize>,
        dataset: &tools::Dataset,
    ) -> (f64, Array3<usize>)
    where
        T: process::NetworkProcess,
    {
        //Identify the type of node used
        match &net.get_node(node) {
            params::Params::DiscreteStatesContinousTime(_params) => {
                //Compute the sufficient statistics M (number of transistions) and T (residence
                //time)
                let (M, T) =
                    parameter_learning::sufficient_statistics(net, dataset, node, parent_set);
                //Scale alpha accordingly to the size of the parent set
                let alpha = self.alpha as f64 / M.shape()[0] as f64;
                //Scale tau accordingly to the size of the parent set
                let tau = self.tau / M.shape()[0] as f64;
                //Compute the log likelihood for q
                //(marginal term for the exit rates, one summand per
                //parent-configuration/state pair)
                let log_ll_q: f64 = M
                    .sum_axis(Axis(2))
                    .iter()
                    .zip(T.iter())
                    .map(|(m, t)| {
                        gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau)
                            - gamma::ln_gamma(alpha + 1.0)
                            - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
                    })
                    .sum();
                //Compute the log likelihood for theta
                //(marginal term for the transition probabilities, summed over
                //parent configurations and starting states)
                let log_ll_theta: f64 = M
                    .outer_iter()
                    .map(|x| {
                        x.outer_iter()
                            .map(|y| {
                                gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64)
                                    + y.iter()
                                        .map(|z| {
                                            gamma::ln_gamma(alpha + *z as f64)
                                                - gamma::ln_gamma(alpha)
                                        })
                                        .sum::<f64>()
                            })
                            .sum::<f64>()
                    })
                    .sum();
                (log_ll_theta + log_ll_q, M)
            }
        }
    }
}
impl ScoreFunction for LogLikelihood {
    /// Evaluate the marginal log-likelihood of `node` having `parent_set` as
    /// parents, given `dataset`.
    fn call<T>(
        &self,
        net: &T,
        node: usize,
        parent_set: &BTreeSet<usize>,
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: process::NetworkProcess,
    {
        // Keep only the score; the sufficient statistics are not needed here.
        let (score, _transitions) = self.compute_score(net, node, parent_set, dataset);
        score
    }
}
/// Bayesian Information Criterion score: log-likelihood penalised by model
/// complexity.
pub struct BIC {
    /// Underlying log-likelihood score used for the fit term.
    ll: LogLikelihood,
}
impl BIC {
    /// Create a BIC score backed by a `LogLikelihood` with pseudo-count
    /// `alpha` and pseudo residence time `tau`.
    pub fn new(alpha: usize, tau: f64) -> BIC {
        let ll = LogLikelihood::new(alpha, tau);
        Self { ll }
    }
}
impl ScoreFunction for BIC {
    /// Evaluate the BIC of `node` with `parent_set`: the marginal
    /// log-likelihood minus `ln(N)/2` times the number of free parameters.
    fn call<T>(
        &self,
        net: &T,
        node: usize,
        parent_set: &BTreeSet<usize>,
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: process::NetworkProcess,
    {
        // Log-likelihood together with the transition counts M.
        let (log_likelihood, M) = self.ll.compute_score(net, node, parent_set, dataset);
        // Free parameters of the CIM: for every parent configuration and every
        // starting state, one row minus the (dependent) diagonal entry.
        let shape = M.shape();
        let n_parameters = shape[0] * shape[1] * (shape[2] - 1);
        //TODO: Optimize this
        // Sample size: total number of transitions across all trajectories.
        let sample_size: usize = dataset
            .get_trajectories()
            .iter()
            .map(|trajectory| trajectory.get_time().len() - 1)
            .sum();
        // BIC = ll - ln(N)/2 * #parameters.
        log_likelihood - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64
    }
}

@ -0,0 +1,355 @@
//! Contains commonly used methods used across the crate.
use std::ops::{DivAssign, MulAssign, Range};
use ndarray::{Array, Array1, Array2, Array3, Axis};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::params::ParamsTrait;
use crate::process::NetworkProcess;
use crate::sampling::{ForwardSampler, Sampler};
use crate::{params, process};
/// One sampled realisation of the process: the instants at which something
/// changed and the corresponding variable configurations.
#[derive(Clone)]
pub struct Trajectory {
    /// Time instant of each recorded sample.
    time: Array1<f64>,
    /// One row per sample: the state of every variable at that instant.
    events: Array2<usize>,
}
impl Trajectory {
    /// Build a trajectory from sample times and the matching event matrix.
    ///
    /// # Panics
    ///
    /// Panics when the number of time points differs from the number of rows
    /// in `events`, since both describe the same samples.
    pub fn new(time: Array1<f64>, events: Array2<usize>) -> Trajectory {
        let n_samples = time.shape()[0];
        if n_samples != events.shape()[0] {
            panic!("time.shape[0] must be equal to events.shape[0]");
        }
        Trajectory { time, events }
    }

    /// Sample times of this trajectory.
    pub fn get_time(&self) -> &Array1<f64> {
        &self.time
    }

    /// Variable configurations, one row per sample.
    pub fn get_events(&self) -> &Array2<usize> {
        &self.events
    }
}
/// A collection of trajectories sampled from the same process.
#[derive(Clone)]
pub struct Dataset {
    /// Trajectories; all track the same number of variables.
    trajectories: Vec<Trajectory>,
}
impl Dataset {
    /// Build a dataset from a collection of trajectories.
    ///
    /// # Panics
    ///
    /// Panics when the trajectories do not all track the same number of
    /// variables, since they must all describe the same process.
    pub fn new(trajectories: Vec<Trajectory>) -> Dataset {
        // All the trajectories in the same dataset must represent the same
        // process, hence the same number of variables. Hoisting the expected
        // width avoids re-reading the first trajectory on every comparison and
        // makes the empty-vector case explicitly safe.
        if let Some(first) = trajectories.first() {
            let n_variables = first.get_events().shape()[1];
            if trajectories
                .iter()
                .any(|x| x.get_events().shape()[1] != n_variables)
            {
                // Message fixed: previously read "mus represents".
                panic!("All the trajectories must represent the same number of variables");
            }
        }
        Dataset { trajectories }
    }

    /// All the trajectories contained in this dataset.
    pub fn get_trajectories(&self) -> &Vec<Trajectory> {
        &self.trajectories
    }
}
/// Sample `n_trajectories` trajectories of duration `t_end` from `net`.
///
/// # Arguments
///
/// * `net` - the network process to sample from.
/// * `n_trajectories` - number of trajectories to generate.
/// * `t_end` - ending time of every trajectory.
/// * `seed` - optional seed for reproducible sampling.
pub fn trajectory_generator<T: process::NetworkProcess>(
    net: &T,
    n_trajectories: u64,
    t_end: f64,
    seed: Option<u64>,
) -> Dataset {
    //Tmp growing vector containing generated trajectories.
    let mut trajectories: Vec<Trajectory> = Vec::new();
    //Random Generator object
    let mut sampler = ForwardSampler::new(net, seed, None);
    //Each iteration generate one trajectory
    for _ in 0..n_trajectories {
        //History of all the moments in which something changed
        let mut time: Vec<f64> = Vec::new();
        //Configuration of the process variables at time t initialized with an uniform
        //distribution.
        let mut events: Vec<process::NetworkProcessState> = Vec::new();
        //Current Time and Current State
        let mut sample = sampler.next().unwrap();
        //Generate new samples until ending time is reached.
        while sample.t < t_end {
            time.push(sample.t);
            events.push(sample.state);
            sample = sampler.next().unwrap();
        }
        //Duplicate the last state so the trajectory is defined up to t_end.
        let current_state = events.last().unwrap().clone();
        events.push(current_state);
        //Add t_end as last time (`f64` is `Copy`: no `.clone()` needed).
        time.push(t_end);
        //Add the sampled trajectory to trajectories.
        trajectories.push(Trajectory::new(
            Array::from_vec(time),
            Array2::from_shape_vec(
                (events.len(), events.last().unwrap().len()),
                events
                    .iter()
                    .flatten()
                    .map(|x| match x {
                        //`usize` is `Copy`: dereference instead of cloning.
                        params::StateType::Discrete(x) => *x,
                    })
                    .collect(),
            )
            .unwrap(),
        ));
        sampler.reset();
    }
    //Return a dataset object with the sampled trajectories.
    Dataset::new(trajectories)
}
/// Behaviour of generators that fill a network's structure at random.
pub trait RandomGraphGenerator {
    /// Build a generator with the given edge `density` and optional RNG `seed`.
    fn new(density: f64, seed: Option<u64>) -> Self;
    /// Generate a random graph directly on `net`.
    fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T);
}
/// Graph Generator using an uniform distribution.
///
/// A method to generate a random graph with edges uniformly distributed.
///
/// # Arguments
///
/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`.
/// * `rng` - is the random numbers generator.
///
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// use reCTBN::tools::UniformGraphGenerator;
/// use reCTBN::tools::RandomGraphGenerator;
/// # let mut net = CtbnNetwork::new();
/// # let nodes_cardinality = 8;
/// # let domain_cardinality = 4;
/// # for node in 0..nodes_cardinality {
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # for dvalue in 0..domain_cardinality {
/// # domain.insert(dvalue.to_string());
/// # }
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new(
/// # node.to_string(),
/// # domain
/// # );
/// # //Create the node using the parameters
/// # let node = DiscreteStatesContinousTime(param);
/// # // Add the node to the network
/// # net.add_node(node).unwrap();
/// # }
///
/// // Initialize the Graph Generator using the one with an
/// // uniform distribution
/// let density = 1.0/3.0;
/// let seed = Some(7641630759785120);
/// let mut structure_generator = UniformGraphGenerator::new(
/// density,
/// seed
/// );
///
/// // Generate the graph directly on the network
/// structure_generator.generate_graph(&mut net);
/// # // Count all the edges generated in the network
/// # let mut edges = 0;
/// # for node in net.get_node_indices(){
/// # edges += net.get_children_set(node).len()
/// # }
/// # // Number of all the nodes in the network
/// # let nodes = net.get_node_indices().len() as f64;
/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
/// # // ±10% of tolerance
/// # let tolerance = ((expected_edges as f64)*0.10) as usize;
/// # // As the way `generate_graph()` is implemented we can only reasonably
/// # // expect the number of edges to be somewhere around the expected value.
/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
/// ```
pub struct UniformGraphGenerator {
    /// Probability with which each possible directed edge is generated.
    density: f64,
    /// Seedable random number generator used to draw the edges.
    rng: ChaCha8Rng,
}
impl RandomGraphGenerator for UniformGraphGenerator {
    /// Create a generator whose edges are sampled independently with
    /// probability `density`.
    ///
    /// # Panics
    ///
    /// Panics when `density` is outside `[0.0, 1.0]`.
    fn new(density: f64, seed: Option<u64>) -> UniformGraphGenerator {
        if density < 0.0 || density > 1.0 {
            panic!(
                "Density value must be between 1.0 and 0.0, got {}.",
                density
            );
        }
        // Seeded generator for reproducibility, entropy-based otherwise.
        let rng: ChaCha8Rng = match seed {
            Some(seed) => SeedableRng::seed_from_u64(seed),
            None => SeedableRng::from_entropy(),
        };
        UniformGraphGenerator { density, rng }
    }

    /// Generate an uniformly distributed graph.
    fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T) {
        net.initialize_adj_matrix();
        let last_node_idx = net.get_node_indices().len();
        for parent in 0..last_node_idx {
            for child in 0..last_node_idx {
                // Skip self-loops. The collapsed `&&` short-circuits, so the
                // RNG is consumed exactly as in the former nested-`if` form.
                if parent != child && self.rng.gen_bool(self.density) {
                    net.add_edge(parent, child);
                }
            }
        }
    }
}
/// Behaviour of generators that fill a network's parameters (CIMs) at random.
pub trait RandomParametersGenerator {
    /// Build a generator drawing rates from `interval`, with optional RNG `seed`.
    fn new(interval: Range<f64>, seed: Option<u64>) -> Self;
    /// Generate random parameters directly on `net`.
    fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T);
}
/// Parameters Generator using an uniform distribution.
///
/// A method to generate random parameters uniformly distributed.
///
/// # Arguments
///
/// * `interval` - is the interval of the random values of the CIM's diagonal; domain: `≥ 0.0`.
/// * `rng` - is the random numbers generator.
///
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::params::ParamsTrait;
/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::tools::RandomGraphGenerator;
/// # use reCTBN::tools::UniformGraphGenerator;
/// use reCTBN::tools::RandomParametersGenerator;
/// use reCTBN::tools::UniformParametersGenerator;
/// # let mut net = CtbnNetwork::new();
/// # let nodes_cardinality = 8;
/// # let domain_cardinality = 4;
/// # for node in 0..nodes_cardinality {
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # for dvalue in 0..domain_cardinality {
/// # domain.insert(dvalue.to_string());
/// # }
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new(
/// # node.to_string(),
/// # domain
/// # );
/// # //Create the node using the parameters
/// # let node = DiscreteStatesContinousTime(param);
/// # // Add the node to the network
/// # net.add_node(node).unwrap();
/// # }
/// #
/// # // Initialize the Graph Generator using the one with an
/// # // uniform distribution
/// # let mut structure_generator = UniformGraphGenerator::new(
/// # 1.0/3.0,
/// # Some(7641630759785120)
/// # );
/// #
/// # // Generate the graph directly on the network
/// # structure_generator.generate_graph(&mut net);
///
/// // Initialize the parameters generator with uniform distribution
/// let mut cim_generator = UniformParametersGenerator::new(
/// 0.0..7.0,
/// Some(7641630759785120)
/// );
///
/// // Generate CIMs with uniformly distributed parameters.
/// cim_generator.generate_parameters(&mut net);
/// #
/// # for node in net.get_node_indices() {
/// # assert_eq!(
/// # Ok(()),
/// # net.get_node(node).validate_params()
/// # );
/// # }
/// ```
pub struct UniformParametersGenerator {
    /// Range the CIM diagonal rates are drawn from (bounds must be >= 0.0).
    interval: Range<f64>,
    /// Seedable random number generator used to draw the parameters.
    rng: ChaCha8Rng,
}
impl RandomParametersGenerator for UniformParametersGenerator {
    /// Create a generator that draws CIM diagonal rates uniformly from
    /// `interval`.
    ///
    /// # Panics
    ///
    /// Panics when either bound of `interval` is negative: rates must be
    /// non-negative.
    fn new(interval: Range<f64>, seed: Option<u64>) -> UniformParametersGenerator {
        if interval.start < 0.0 || interval.end < 0.0 {
            // Message fixed: it previously stated the condition backwards
            // ("less or equal than 0"); the requirement is non-negativity.
            panic!(
                "Interval must be entirely greater or equal than 0, got {}..{}.",
                interval.start, interval.end
            );
        }
        // Seeded generator for reproducibility, entropy-based otherwise.
        let rng: ChaCha8Rng = match seed {
            Some(seed) => SeedableRng::seed_from_u64(seed),
            None => SeedableRng::from_entropy(),
        };
        UniformParametersGenerator { interval, rng }
    }

    /// Generate CIMs with uniformly distributed parameters.
    fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) {
        for node in net.get_node_indices() {
            // Number of joint configurations of the node's parents.
            let parent_set_state_space_cardinality: usize = net
                .get_parent_set(node)
                .iter()
                .map(|x| net.get_node(*x).get_reserved_space_as_parent())
                .product();
            match &mut net.get_node_mut(node) {
                params::Params::DiscreteStatesContinousTime(param) => {
                    let node_domain_cardinality = param.get_reserved_space_as_parent();
                    // Start from raw uniform samples in [0, 1).
                    let mut cim = Array3::<f64>::from_shape_fn(
                        (
                            parent_set_state_space_cardinality,
                            node_domain_cardinality,
                            node_domain_cardinality,
                        ),
                        |_| self.rng.gen(),
                    );
                    cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
                        // Normalise each row of off-diagonal entries to sum to
                        // one, then scale it by a rate drawn from `interval`.
                        x.diag_mut().fill(0.0);
                        x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1)));
                        let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
                            self.rng.gen_range(self.interval.clone())
                        });
                        x.mul_assign(&diag.clone().insert_axis(Axis(1)));
                        // Recomputing the diagonal in order to reduce the issues caused by the
                        // loss of precision when validating the parameters.
                        let diag_sum = -x.sum_axis(Axis(1));
                        x.diag_mut().assign(&diag_sum)
                    });
                    param.set_cim_unchecked(cim);
                }
            }
        }
    }
}

@ -0,0 +1,376 @@
mod utils;
use std::collections::BTreeSet;
use approx::AbsDiffEq;
use ndarray::arr3;
use reCTBN::params::{self, ParamsTrait};
use reCTBN::process::NetworkProcess;
use reCTBN::process::{ctbn::*};
use utils::generate_discrete_time_continous_node;
#[test]
// NOTE(review): "simpe" is a typo for "simple"; the name is kept unchanged.
fn define_simpe_ctbn() {
    // Constructing an empty network must simply not panic; the former
    // `assert!(true)` was a no-op (clippy: assertions_on_constants).
    let _ = CtbnNetwork::new();
}
#[test]
fn add_node_to_ctbn() {
    // Adding a node must return its index and preserve its label.
    let mut net = CtbnNetwork::new();
    let node = generate_discrete_time_continous_node(String::from("n1"), 2);
    let n1 = net.add_node(node).unwrap();
    assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
}
#[test]
fn add_edge_to_ctbn() {
    // An added edge must show up in the parent's children set.
    let mut net = CtbnNetwork::new();
    let parent = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let child = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    net.add_edge(parent, child);
    let children = net.get_children_set(parent);
    assert_eq!(&child, children.iter().next().unwrap());
}
#[test]
fn children_and_parents() {
    // After adding n1 -> n2, n2 must be among n1's children and n1 among n2's
    // parents.
    let mut net = CtbnNetwork::new();
    let parent = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let child = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    net.add_edge(parent, child);
    let children = net.get_children_set(parent);
    assert_eq!(&child, children.iter().next().unwrap());
    let parents = net.get_parent_set(child);
    assert_eq!(&parent, parents.iter().next().unwrap());
}
#[test]
fn compute_index_ctbn() {
    // Build a v-structure n1 -> n2 <- n3 and check how the parents' states are
    // folded into n2's parameter (parent-configuration) index. The state
    // vector lists the states of n1, n2, n3 in node order; n2's own state does
    // not contribute to its index.
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    net.add_edge(n3, n2);
    // Both parents at state 1 -> index 3.
    let idx = net.get_param_index_network(
        n2,
        &vec![
            params::StateType::Discrete(1),
            params::StateType::Discrete(1),
            params::StateType::Discrete(1),
        ],
    );
    assert_eq!(3, idx);
    // n1 at 0, n3 at 1 -> index 2.
    let idx = net.get_param_index_network(
        n2,
        &vec![
            params::StateType::Discrete(0),
            params::StateType::Discrete(1),
            params::StateType::Discrete(1),
        ],
    );
    assert_eq!(2, idx);
    // n1 at 1, n3 at 0 -> index 1.
    let idx = net.get_param_index_network(
        n2,
        &vec![
            params::StateType::Discrete(1),
            params::StateType::Discrete(1),
            params::StateType::Discrete(0),
        ],
    );
    assert_eq!(1, idx);
}
#[test]
fn compute_index_from_custom_parent_set() {
    // The index must be computed from an explicitly supplied parent set rather
    // than the network structure (no edges are added here).
    let mut net = CtbnNetwork::new();
    let _n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let _n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    let _n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
        .unwrap();
    // Parent set {1}: only node 1's state (0) contributes -> index 0.
    let idx = net.get_param_index_from_custom_parent_set(
        &vec![
            params::StateType::Discrete(0),
            params::StateType::Discrete(0),
            params::StateType::Discrete(1),
        ],
        &BTreeSet::from([1]),
    );
    assert_eq!(0, idx);
    // Parent set {1, 2}: node 1 at 0 and node 2 at 1 -> index 2.
    let idx = net.get_param_index_from_custom_parent_set(
        &vec![
            params::StateType::Discrete(0),
            params::StateType::Discrete(0),
            params::StateType::Discrete(1),
        ],
        &BTreeSet::from([1, 2]),
    );
    assert_eq!(2, idx);
}
#[test]
fn simple_amalgamation() {
    // Amalgamating a single-node CTBN must yield a CTMP whose CIM equals the
    // node's own CIM.
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    net.initialize_adj_matrix();
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
        }
    }
    let ctmp = net.amalgamation();
    let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0);
    let p_ctbn = p_ctbn.get_cim().as_ref().unwrap();
    let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
    let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
    assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON));
}
#[test]
fn chain_amalgamation() {
    // Amalgamate a chain n1 -> n2 -> n3 of binary nodes and compare the
    // resulting 8x8 CTMP intensity matrix against a hand-computed one.
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    net.add_edge(n2, n3);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
        }
    }
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]]
                ]))
            );
        }
    }
    match &mut net.get_node_mut(n3) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]]
                ]))
            );
        }
    }
    let ctmp = net.amalgamation();
    let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
    let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
    // Expected amalgamated CIM over the 2^3 joint states.
    let p_ctmp_handmade = arr3(&[[
        [
            -1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
        ],
        [
            1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
        ],
        [
            5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
        ],
        [
            0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
        ],
        [
            5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00,
        ],
        [
            0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00,
        ],
        [
            0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01,
        ],
        [
            0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00,
        ],
    ]]);
    assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}
#[test]
fn chainfork_amalgamation() {
    // Amalgamate a chain-fork (n1, n2) -> n3 -> n4 of binary nodes and compare
    // the resulting 16x16 CTMP intensity matrix against a hand-computed one.
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
        .unwrap();
    let n4 = net
        .add_node(generate_discrete_time_continous_node(String::from("n4"), 2))
        .unwrap();
    net.add_edge(n1, n3);
    net.add_edge(n2, n3);
    net.add_edge(n3, n4);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
        }
    }
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
        }
    }
    // n3 has four parent configurations (joint states of n1, n2).
    match &mut net.get_node_mut(n3) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]]
                ]))
            );
        }
    }
    match &mut net.get_node_mut(n4) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]]
                ]))
            );
        }
    }
    let ctmp = net.amalgamation();
    let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
    let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
    // Expected amalgamated CIM over the 2^4 joint states.
    let p_ctmp_handmade = arr3(&[[
        [
            -2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
            1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
        ],
        [
            1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
            0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
        ],
        [
            1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
            0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
        ],
        [
            0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
            0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
        ],
        [
            5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00,
            0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
        ],
        [
            0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01,
            0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00,
        ],
        [
            0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01,
            0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
        ],
        [
            0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00,
            0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
        ],
        [
            5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
            -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
        ],
        [
            0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
            1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
        ],
        [
            0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
            1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
        ],
        [
            0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
            0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
        ],
        [
            0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
            5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00,
        ],
        [
            0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
            0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01,
        ],
        [
            0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
            0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01,
        ],
        [
            0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02,
            0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00,
        ],
    ]]);
    assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}

@ -0,0 +1,127 @@
mod utils;
use std::collections::BTreeSet;
use reCTBN::{
params,
params::ParamsTrait,
process::{ctmp::*, NetworkProcess},
};
use utils::generate_discrete_time_continous_node;
#[test]
fn define_simple_ctmp() {
    // Constructing an empty CTMP must simply not panic; the former
    // `assert!(true)` was a no-op (clippy: assertions_on_constants).
    let _ = CtmpProcess::new();
}
#[test]
fn add_node_to_ctmp() {
    // A freshly added node must keep its label.
    let mut net = CtmpProcess::new();
    let node = generate_discrete_time_continous_node(String::from("n1"), 2);
    let n1 = net.add_node(node).unwrap();
    let expected = String::from("n1");
    assert_eq!(&expected, net.get_node(n1).get_label());
}
#[test]
fn add_two_nodes_to_ctmp() {
    // A CTMP models a single variable: adding a second node must fail.
    let mut net = CtmpProcess::new();
    let _n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
    // Replaces the former `match { Ok => assert!(false), Err => assert!(true) }`,
    // which tripped clippy's `assertions_on_constants` lint.
    assert!(n2.is_err());
}
#[test]
#[should_panic]
fn add_edge_to_ctmp() {
    // A CTMP has no graph structure: adding an edge must panic.
    let mut net = CtmpProcess::new();
    let _n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
    net.add_edge(0, 1)
}
#[test]
// NOTE(review): "childen" is a typo for "children"; the name is kept unchanged.
fn childen_and_parents() {
    // The single node of a CTMP has neither parents nor children.
    let mut net = CtmpProcess::new();
    let _n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    assert_eq!(0, net.get_parent_set(0).len());
    assert_eq!(0, net.get_children_set(0).len());
}
#[test]
#[should_panic]
fn get_childen_panic() {
    // Querying the children set of an empty CTMP must panic.
    let net = CtmpProcess::new();
    net.get_children_set(0);
}
#[test]
#[should_panic]
fn get_childen_panic2() {
    // Querying the children set with an out-of-range index must panic.
    let mut net = CtmpProcess::new();
    let _n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    net.get_children_set(1);
}
#[test]
#[should_panic]
fn get_parent_panic() {
    // Querying the parent set of an empty CTMP must panic.
    let net = CtmpProcess::new();
    net.get_parent_set(0);
}
#[test]
#[should_panic]
fn get_parent_panic2() {
    // Querying the parent set with an out-of-range index must panic.
    let mut net = CtmpProcess::new();
    let _n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    net.get_parent_set(1);
}
#[test]
fn compute_index_ctmp() {
    // For a single-variable process the parameter index is the state itself.
    let mut net = CtmpProcess::new();
    let node = generate_discrete_time_continous_node(String::from("n1"), 10);
    let n1 = net.add_node(node).unwrap();
    let state = vec![params::StateType::Discrete(6)];
    let idx = net.get_param_index_network(n1, &state);
    assert_eq!(6, idx);
}
#[test]
#[should_panic]
fn compute_index_from_custom_parent_set_ctmp() {
    // Custom parent sets are meaningless for a CTMP: the call must panic.
    let mut net = CtmpProcess::new();
    let _n1 = net
        .add_node(generate_discrete_time_continous_node(
            String::from("n1"),
            10,
        ))
        .unwrap();
    let _idx = net.get_param_index_from_custom_parent_set(
        &vec![params::StateType::Discrete(6)],
        &BTreeSet::from([0])
    );
}

@ -0,0 +1,648 @@
#![allow(non_snake_case)]
mod utils;
use ndarray::arr3;
use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::*;
use reCTBN::params;
use reCTBN::params::Params::DiscreteStatesContinousTime;
use reCTBN::tools::*;
use utils::*;
extern crate approx;
use crate::approx::AbsDiffEq;
/// Shared scenario: a two-node binary CTBN n1 -> n2 with known CIMs; sample
/// trajectories and check that `pl` recovers n2's CIM within tolerance 0.1.
fn learn_binary_cim<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
        }
    }
    // n2 has one CIM per state of its parent n1.
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 1.0],
                        [4.0, -4.0]
                    ],
                    [
                        [-6.0, 6.0],
                        [2.0, -2.0]
                    ],
                ]))
            );
        }
    }
    // Fixed seed keeps the learned estimate deterministic.
    let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
    let p = match pl.fit(&net, &data, 1, None) {
        params::Params::DiscreteStatesContinousTime(p) => p,
    };
    assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]);
    // The learned CIM must match the generating one within tolerance.
    assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
        &arr3(&[
            [
                [-1.0, 1.0],
                [4.0, -4.0]
            ],
            [
                [-6.0, 6.0],
                [2.0, -2.0]
            ],
        ]),
        0.1
    ));
}
/// Populate `net` with `nodes_cardinality` discrete nodes, labelled by their
/// index, each with a domain of `nodes_domain_cardinality` values.
fn generate_nodes(
    net: &mut CtbnNetwork,
    nodes_cardinality: usize,
    nodes_domain_cardinality: usize
) {
    for node_label in 0..nodes_cardinality {
        let node = generate_discrete_time_continous_node(
            node_label.to_string(),
            nodes_domain_cardinality,
        );
        net.add_node(node).unwrap();
    }
}
/// Like `learn_binary_cim`, but the ground-truth CIMs are produced by the
/// random parameter generator instead of being hand-written.
fn learn_binary_cim_gen<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    generate_nodes(&mut net, 2, 2);
    net.add_edge(0, 1);
    // Seeded generator: the ground truth is reproducible.
    let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
        1.0..6.0,
        Some(6813071588535822)
    );
    cim_generator.generate_parameters(&mut net);
    let p_gen = match net.get_node(1) {
        DiscreteStatesContinousTime(p_gen) => p_gen,
    };
    let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
    let p_tj = match pl.fit(&net, &data, 1, None) {
        DiscreteStatesContinousTime(p_tj) => p_tj,
    };
    // The learned CIM must match the generated ground truth in shape and,
    // within tolerance, in values.
    assert_eq!(
        p_tj.get_cim().as_ref().unwrap().shape(),
        p_gen.get_cim().as_ref().unwrap().shape()
    );
    assert!(
        p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
            &p_gen.get_cim().as_ref().unwrap(),
            0.1
        )
    );
}
#[test]
fn learn_binary_cim_MLE() {
    // Maximum-likelihood estimation must recover the binary CIMs.
    learn_binary_cim(MLE {});
}
#[test]
fn learn_binary_cim_MLE_gen() {
    // MLE must also recover randomly generated binary CIMs.
    learn_binary_cim_gen(MLE {});
}
#[test]
fn learn_binary_cim_BA() {
    // The Bayesian approach must recover the binary CIMs.
    learn_binary_cim(BayesianApproach { alpha: 1, tau: 1.0 });
}
#[test]
fn learn_binary_cim_BA_gen() {
    // The Bayesian approach must also recover randomly generated binary CIMs.
    learn_binary_cim_gen(BayesianApproach { alpha: 1, tau: 1.0 });
}
/// Shared scenario: a two-node ternary CTBN n1 -> n2 with known CIMs; sample
/// trajectories and check that `pl` recovers n2's CIM within tolerance 0.1.
fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-3.0, 2.0, 1.0],
                        [1.5, -2.0, 0.5],
                        [0.4, 0.6, -1.0]
                    ],
                ]))
            );
        }
    }
    // n2 has one 3x3 CIM per state of its parent n1.
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 0.5, 0.5],
                        [3.0, -4.0, 1.0],
                        [0.9, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 2.0, 4.0],
                        [1.5, -2.0, 0.5],
                        [3.0, 1.0, -4.0]
                    ],
                    [
                        [-1.0, 0.1, 0.9],
                        [2.0, -2.5, 0.5],
                        [0.9, 0.1, -1.0]
                    ],
                ]))
            );
        }
    }
    // Fixed seed keeps the learned estimate deterministic.
    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
    let p = match pl.fit(&net, &data, 1, None) {
        params::Params::DiscreteStatesContinousTime(p) => p,
    };
    assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]);
    // The learned CIM must match the generating one within tolerance.
    assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
        &arr3(&[
            [
                [-1.0, 0.5, 0.5],
                [3.0, -4.0, 1.0],
                [0.9, 0.1, -1.0]
            ],
            [
                [-6.0, 2.0, 4.0],
                [1.5, -2.0, 0.5],
                [3.0, 1.0, -4.0]
            ],
            [
                [-1.0, 0.1, 0.9],
                [2.0, -2.5, 0.5],
                [0.9, 0.1, -1.0]
            ],
        ]),
        0.1
    ));
}
/// Learn the CIM of node 1 in a randomly-parametrized two-node ternary network
/// and compare it against the generator's ground truth.
fn learn_ternary_cim_gen<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    generate_nodes(&mut net, 2, 3);
    net.add_edge(0, 1);

    // Ground-truth CIMs are drawn uniformly from [4, 6) with a fixed seed.
    let mut generator: UniformParametersGenerator =
        RandomParametersGenerator::new(4.0..6.0, Some(6813071588535822));
    generator.generate_parameters(&mut net);

    let expected = match net.get_node(1) {
        DiscreteStatesContinousTime(p) => p,
    };

    let dataset = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
    let learned = match pl.fit(&net, &dataset, 1, None) {
        DiscreteStatesContinousTime(p) => p,
    };

    let expected_cim = expected.get_cim().as_ref().unwrap();
    let learned_cim = learned.get_cim().as_ref().unwrap();
    assert_eq!(learned_cim.shape(), expected_cim.shape());
    assert!(learned_cim.abs_diff_eq(expected_cim, 0.1));
}
/// MLE must recover the hand-written ternary CIM.
#[test]
fn learn_ternary_cim_MLE() {
    learn_ternary_cim(MLE {});
}
/// MLE must recover the randomly-generated ternary CIM.
#[test]
fn learn_ternary_cim_MLE_gen() {
    learn_ternary_cim_gen(MLE {});
}
/// Bayesian estimation must recover the hand-written ternary CIM.
#[test]
fn learn_ternary_cim_BA() {
    let estimator = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_ternary_cim(estimator);
}
/// Bayesian estimation must recover the randomly-generated ternary CIM.
#[test]
fn learn_ternary_cim_BA_gen() {
    let estimator = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_ternary_cim_gen(estimator);
}
/// Fit the root node (index 0, no parents) and check that a single 3x3 CIM
/// matching the generating parameters is recovered.
fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-3.0, 2.0, 1.0],
                        [1.5, -2.0, 0.5],
                        [0.4, 0.6, -1.0]
                    ]
                ]))
            );
        }
    }
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 0.5, 0.5],
                        [3.0, -4.0, 1.0],
                        [0.9, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 2.0, 4.0],
                        [1.5, -2.0, 0.5],
                        [3.0, 1.0, -4.0]
                    ],
                    [
                        [-1.0, 0.1, 0.9],
                        [2.0, -2.5, 0.5],
                        [0.9, 0.1, -1.0]
                    ],
                ]))
            );
        }
    }
    let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
    // Node 0 has no parents, hence exactly one parent configuration ([1, 3, 3]).
    let p = match pl.fit(&net, &data, 0, None) {
        params::Params::DiscreteStatesContinousTime(p) => p,
    };
    assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]);
    assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
        &arr3(&[
            [
                [-3.0, 2.0, 1.0],
                [1.5, -2.0, 0.5],
                [0.4, 0.6, -1.0]
            ],
        ]),
        0.1
    ));
}
/// Learn the CIM of the parentless node 0 in a randomly-parametrized two-node
/// ternary network and compare it against the generator's ground truth.
fn learn_ternary_cim_no_parents_gen<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    generate_nodes(&mut net, 2, 3);
    net.add_edge(0, 1);

    // Ground-truth CIMs are drawn uniformly from [1, 6) with a fixed seed.
    let mut generator: UniformParametersGenerator =
        RandomParametersGenerator::new(1.0..6.0, Some(6813071588535822));
    generator.generate_parameters(&mut net);

    let expected = match net.get_node(0) {
        DiscreteStatesContinousTime(p) => p,
    };

    let dataset = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
    let learned = match pl.fit(&net, &dataset, 0, None) {
        DiscreteStatesContinousTime(p) => p,
    };

    let expected_cim = expected.get_cim().as_ref().unwrap();
    let learned_cim = learned.get_cim().as_ref().unwrap();
    assert_eq!(learned_cim.shape(), expected_cim.shape());
    assert!(learned_cim.abs_diff_eq(expected_cim, 0.1));
}
/// MLE must recover the hand-written root-node CIM.
#[test]
fn learn_ternary_cim_no_parents_MLE() {
    learn_ternary_cim_no_parents(MLE {});
}
/// MLE must recover the randomly-generated root-node CIM.
#[test]
fn learn_ternary_cim_no_parents_MLE_gen() {
    learn_ternary_cim_no_parents_gen(MLE {});
}
/// Bayesian estimation must recover the hand-written root-node CIM.
#[test]
fn learn_ternary_cim_no_parents_BA() {
    let estimator = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_ternary_cim_no_parents(estimator);
}
/// Bayesian estimation must recover the randomly-generated root-node CIM.
#[test]
fn learn_ternary_cim_no_parents_BA_gen() {
    let estimator = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_ternary_cim_no_parents_gen(estimator);
}
/// Fit the CIM of node `n3` (4 states, ternary parents `n1` and `n2`, hence
/// 9 parent configurations) and check it against the generating parameters.
fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
        .unwrap();
    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 4))
        .unwrap();
    net.add_edge(n1, n2);
    net.add_edge(n1, n3);
    net.add_edge(n2, n3);
    // Root node: a single 3x3 CIM.
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-3.0, 2.0, 1.0],
                        [1.5, -2.0, 0.5],
                        [0.4, 0.6, -1.0]
                    ],
                ]))
            );
        }
    }
    // n2: one 3x3 CIM per state of its parent n1.
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 0.5, 0.5],
                        [3.0, -4.0, 1.0],
                        [0.9, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 2.0, 4.0],
                        [1.5, -2.0, 0.5],
                        [3.0, 1.0, -4.0]
                    ],
                    [
                        [-1.0, 0.1, 0.9],
                        [2.0, -2.5, 0.5],
                        [0.9, 0.1, -1.0]
                    ],
                ]))
            );
        }
    }
    // n3: one 4x4 CIM per joint configuration of (n1, n2) -> 9 matrices.
    match &mut net.get_node_mut(n3) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 0.5, 0.3, 0.2],
                        [0.5, -4.0, 2.5, 1.0],
                        [2.5, 0.5, -4.0, 1.0],
                        [0.7, 0.2, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 2.0, 3.0, 1.0],
                        [1.5, -3.0, 0.5, 1.0],
                        [2.0, 1.3, -5.0, 1.7],
                        [2.5, 0.5, 1.0, -4.0]
                    ],
                    [
                        [-1.3, 0.3, 0.1, 0.9],
                        [1.4, -4.0, 0.5, 2.1],
                        [1.0, 1.5, -3.0, 0.5],
                        [0.4, 0.3, 0.1, -0.8]
                    ],
                    [
                        [-2.0, 1.0, 0.7, 0.3],
                        [1.3, -5.9, 2.7, 1.9],
                        [2.0, 1.5, -4.0, 0.5],
                        [0.2, 0.7, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 1.0, 2.0, 3.0],
                        [0.5, -3.0, 1.0, 1.5],
                        [1.4, 2.1, -4.3, 0.8],
                        [0.5, 1.0, 2.5, -4.0]
                    ],
                    [
                        [-1.3, 0.9, 0.3, 0.1],
                        [0.1, -1.3, 0.2, 1.0],
                        [0.5, 1.0, -3.0, 1.5],
                        [0.1, 0.4, 0.3, -0.8]
                    ],
                    [
                        [-2.0, 1.0, 0.6, 0.4],
                        [2.6, -7.1, 1.4, 3.1],
                        [5.0, 1.0, -8.0, 2.0],
                        [1.4, 0.4, 0.2, -2.0]
                    ],
                    [
                        [-3.0, 1.0, 1.5, 0.5],
                        [3.0, -6.0, 1.0, 2.0],
                        [0.3, 0.5, -1.9, 1.1],
                        [5.0, 1.0, 2.0, -8.0]
                    ],
                    [
                        [-2.6, 0.6, 0.2, 1.8],
                        [2.0, -6.0, 3.0, 1.0],
                        [0.1, 0.5, -1.3, 0.7],
                        [0.8, 0.6, 0.2, -1.6]
                    ],
                ]))
            );
        }
    }
    // More/longer trajectories than the smaller tests: 9 parent configurations
    // need more data for reliable estimates; fixed seed for determinism.
    let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
    let p = match pl.fit(&net, &data, 2, None) {
        params::Params::DiscreteStatesContinousTime(p) => p,
    };
    // Looser tolerance (0.2) than the simpler tests due to the larger state space.
    assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]);
    assert!(p.get_cim().as_ref().unwrap().abs_diff_eq(
        &arr3(&[
            [
                [-1.0, 0.5, 0.3, 0.2],
                [0.5, -4.0, 2.5, 1.0],
                [2.5, 0.5, -4.0, 1.0],
                [0.7, 0.2, 0.1, -1.0]
            ],
            [
                [-6.0, 2.0, 3.0, 1.0],
                [1.5, -3.0, 0.5, 1.0],
                [2.0, 1.3, -5.0, 1.7],
                [2.5, 0.5, 1.0, -4.0]
            ],
            [
                [-1.3, 0.3, 0.1, 0.9],
                [1.4, -4.0, 0.5, 2.1],
                [1.0, 1.5, -3.0, 0.5],
                [0.4, 0.3, 0.1, -0.8]
            ],
            [
                [-2.0, 1.0, 0.7, 0.3],
                [1.3, -5.9, 2.7, 1.9],
                [2.0, 1.5, -4.0, 0.5],
                [0.2, 0.7, 0.1, -1.0]
            ],
            [
                [-6.0, 1.0, 2.0, 3.0],
                [0.5, -3.0, 1.0, 1.5],
                [1.4, 2.1, -4.3, 0.8],
                [0.5, 1.0, 2.5, -4.0]
            ],
            [
                [-1.3, 0.9, 0.3, 0.1],
                [0.1, -1.3, 0.2, 1.0],
                [0.5, 1.0, -3.0, 1.5],
                [0.1, 0.4, 0.3, -0.8]
            ],
            [
                [-2.0, 1.0, 0.6, 0.4],
                [2.6, -7.1, 1.4, 3.1],
                [5.0, 1.0, -8.0, 2.0],
                [1.4, 0.4, 0.2, -2.0]
            ],
            [
                [-3.0, 1.0, 1.5, 0.5],
                [3.0, -6.0, 1.0, 2.0],
                [0.3, 0.5, -1.9, 1.1],
                [5.0, 1.0, 2.0, -8.0]
            ],
            [
                [-2.6, 0.6, 0.2, 1.8],
                [2.0, -6.0, 3.0, 1.0],
                [0.1, 0.5, -1.3, 0.7],
                [0.8, 0.6, 0.2, -1.6]
            ],
        ]),
        0.2
    ));
}
/// Learn node 2 (4 states, parents 0 and 1) of a randomly-parametrized mixed
/// network and compare it against the generator's ground truth.
fn learn_mixed_discrete_cim_gen<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    generate_nodes(&mut net, 2, 3);
    net.add_node(generate_discrete_time_continous_node(String::from("3"), 4))
        .unwrap();
    net.add_edge(0, 1);
    net.add_edge(0, 2);
    net.add_edge(1, 2);

    // Ground-truth CIMs are drawn uniformly from [1, 8) with a fixed seed.
    let mut generator: UniformParametersGenerator =
        RandomParametersGenerator::new(1.0..8.0, Some(6813071588535822));
    generator.generate_parameters(&mut net);

    let expected = match net.get_node(2) {
        DiscreteStatesContinousTime(p) => p,
    };

    let dataset = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
    let learned = match pl.fit(&net, &dataset, 2, None) {
        DiscreteStatesContinousTime(p) => p,
    };

    let expected_cim = expected.get_cim().as_ref().unwrap();
    let learned_cim = learned.get_cim().as_ref().unwrap();
    assert_eq!(learned_cim.shape(), expected_cim.shape());
    assert!(learned_cim.abs_diff_eq(expected_cim, 0.2));
}
/// MLE must recover the hand-written mixed-cardinality CIM.
#[test]
fn learn_mixed_discrete_cim_MLE() {
    learn_mixed_discrete_cim(MLE {});
}
/// MLE must recover the randomly-generated mixed-cardinality CIM.
#[test]
fn learn_mixed_discrete_cim_MLE_gen() {
    learn_mixed_discrete_cim_gen(MLE {});
}
/// Bayesian estimation must recover the hand-written mixed-cardinality CIM.
#[test]
fn learn_mixed_discrete_cim_BA() {
    let estimator = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_mixed_discrete_cim(estimator);
}
/// Bayesian estimation must recover the randomly-generated mixed-cardinality CIM.
#[test]
fn learn_mixed_discrete_cim_BA_gen() {
    let estimator = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_mixed_discrete_cim_gen(estimator);
}

@ -0,0 +1,148 @@
use ndarray::prelude::*;
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
use reCTBN::params::{ParamsTrait, *};
mod utils;
#[macro_use]
extern crate approx;
/// Build a single ternary-node parameter set with a fixed, valid 3x3 CIM.
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
    let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3);
    // The CIM is known to be valid; the Result is deliberately discarded.
    let _ = params.set_cim(array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]);
    params
}
/// The label given at construction time must be reported back verbatim.
#[test]
fn test_get_label() {
    let param = create_ternary_discrete_time_continous_param();
    let expected = String::from("A");
    assert_eq!(&expected, param.get_label())
}
/// Uniform state sampling over a 3-state domain: state 0 should appear with
/// frequency ~1/3 over 10_000 draws.
#[test]
fn test_uniform_generation() {
    let param = create_ternary_discrete_time_continous_param();
    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
    let mut draws = Array1::<usize>::zeros(10000);
    draws.mapv_inplace(|_| match param.get_random_state_uniform(&mut rng) {
        StateType::Discrete(state) => state,
    });
    let zero_freq = draws.iter().filter(|&&s| s == 0).count() as f64 / 10000.0;
    assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01);
}
/// From state 1 the CIM rates are 1.0 (to state 0) and 4.0 (to state 2), so
/// the sampled next state should be 2 with probability 4/5 and 0 with 1/5.
#[test]
fn test_random_generation_state() {
    let param = create_ternary_discrete_time_continous_param();
    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
    let mut draws = Array1::<usize>::zeros(10000);
    draws.mapv_inplace(|_| match param.get_random_state(1, 0, &mut rng).unwrap() {
        StateType::Discrete(state) => state,
    });
    let two_freq = draws.iter().filter(|&&s| s == 2).count() as f64 / 10000.0;
    let zero_freq = draws.iter().filter(|&&s| s == 0).count() as f64 / 10000.0;
    assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01);
    assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01);
}
/// The exit rate of state 1 is 5.0 in the fixture CIM, so the mean sampled
/// residence time should be 1/5.
#[test]
fn test_random_generation_residence_time() {
    let param = create_ternary_discrete_time_continous_param();
    let mut samples = Array1::<f64>::zeros(10000);
    let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259);
    samples.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap());
    let mean_time = samples.mean().unwrap();
    assert_relative_eq!(1.0 / 5.0, mean_time, epsilon = 0.01);
}
/// The fixture CIM is well-formed and must pass validation.
#[test]
fn test_validate_params_valid_cim() {
    let param = create_ternary_discrete_time_continous_param();
    let outcome = param.validate_params();
    assert_eq!(Ok(()), outcome);
}
/// Validation accepts very large rates as long as the CIM stays well-formed.
#[test]
fn test_validate_params_valid_cim_with_huge_values() {
    let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
    let huge_cim = array![[
        [-2e10, 1e10, 1e10],
        [1.5e10, -3e10, 1.5e10],
        [1e10, 1e10, -2e10]
    ]];
    assert_eq!(Ok(()), param.set_cim(huge_cim));
}
/// Validating before any CIM has been set must report an initialization error.
#[test]
fn test_validate_params_cim_not_initialized() {
    let param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
    let expected = Err(ParamsError::ParametersNotInitialized(String::from(
        "CIM not initialized",
    )));
    assert_eq!(expected, param.validate_params());
}
/// A 1x3x3 CIM must be rejected for a node whose domain has 4 states.
#[test]
fn test_validate_params_wrong_shape() {
    let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 4);
    let bad_cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
    let expected = Err(ParamsError::InvalidCIM(String::from(
        "Incompatible shape [1, 3, 3] with domain 4"
    )));
    assert_eq!(expected, param.set_cim(bad_cim));
}
/// A positive entry on the CIM diagonal must be rejected.
#[test]
fn test_validate_params_positive_diag() {
    let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
    let bad_cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
    let expected = Err(ParamsError::InvalidCIM(String::from(
        "The diagonal of each cim must be non-positive",
    )));
    assert_eq!(expected, param.set_cim(bad_cim));
}
/// A CIM whose rows do not sum to zero must be rejected.
#[test]
fn test_validate_params_row_not_sum_to_zero() {
    let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
    let bad_cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
    let expected = Err(ParamsError::InvalidCIM(String::from(
        "The sum of each row must be 0"
    )));
    assert_eq!(expected, param.set_cim(bad_cim));
}

@ -0,0 +1,122 @@
mod utils;
use approx::assert_abs_diff_eq;
use ndarray::*;
use reCTBN::{
params,
process::{ctbn::*, NetworkProcess, NetworkProcessState},
reward::{reward_evaluation::*, reward_function::*, *},
};
use utils::generate_discrete_time_continous_node;
/// Monte Carlo reward evaluation on an isolated binary node with a constant
/// instantaneous reward of 3.0 and no transition reward.
#[test]
fn simple_factored_reward_function_binary_node_mc() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();

    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]]));
    rf.get_instantaneous_reward_mut(n1)
        .assign(&arr1(&[3.0, 3.0]));

    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
        }
    }
    net.initialize_adj_matrix();

    let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
    let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];

    // Undiscounted infinite-horizon value: 3.0 from every state.
    let evaluator = MonteCarloReward::new(
        10000,
        1e-1,
        1e-1,
        10.0,
        RewardCriteria::InfiniteHorizon { discount_factor: 1.0 },
        Some(215),
    );
    assert_abs_diff_eq!(3.0, evaluator.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
    assert_abs_diff_eq!(3.0, evaluator.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
    let rst = evaluator.evaluate_state_space(&net, &rf);
    assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
    assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);

    // Finite horizon: expected 30.0, i.e. reward 3.0 accumulated over the
    // 10.0 time-limit argument (presumably max_time -- see MonteCarloReward).
    let evaluator = MonteCarloReward::new(
        10000,
        1e-1,
        1e-1,
        10.0,
        RewardCriteria::FiniteHorizon,
        Some(215),
    );
    assert_abs_diff_eq!(30.0, evaluator.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
    assert_abs_diff_eq!(30.0, evaluator.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
}
/// Monte Carlo reward evaluation on a three-node binary chain n1 -> n2 -> n3
/// with a unit reward on every state flip of every node.
#[test]
fn simple_factored_reward_function_chain_mc() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    net.add_edge(n2, n3);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap();
        }
    }
    // n2 and n3 move towards their parent's state at rate 5.0 and away from
    // it at rate 0.01 (one CIM per parent state).
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param
                .set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]],
                ]))
                .unwrap();
        }
    }
    match &mut net.get_node_mut(n3) {
        params::Params::DiscreteStatesContinousTime(param) => {
            param
                .set_cim(arr3(&[
                    [[-0.01, 0.01], [5.0, -5.0]],
                    [[-5.0, 5.0], [0.01, -0.01]],
                ]))
                .unwrap();
        }
    }
    // Transition reward 1.0 on every off-diagonal jump; no instantaneous reward.
    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
    rf.get_transition_reward_mut(n2)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
    rf.get_transition_reward_mut(n3)
        .assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
    // NOTE(review): despite the name, this is state (n1=1, n2=0, n3=0).
    let s000: NetworkProcessState = vec![
        params::StateType::Discrete(1),
        params::StateType::Discrete(0),
        params::StateType::Discrete(0),
    ];
    let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
    assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
    let rst = mc.evaluate_state_space(&net, &rf);
    assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1);
}

@ -0,0 +1,117 @@
mod utils;
use ndarray::*;
use utils::generate_discrete_time_continous_node;
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params};
/// Reward evaluation on an isolated binary node: the instantaneous reward is
/// that of the current state; the transition reward is read from the
/// (current, previous) entry and is zero when there is no state change.
#[test]
fn simple_factored_reward_function_binary_node() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();

    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[12.0, 1.0], [2.0, 12.0]]));
    rf.get_instantaneous_reward_mut(n1)
        .assign(&arr1(&[3.0, 5.0]));

    let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
    let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];

    // No previous state: only the instantaneous component applies.
    assert_eq!(
        rf.call(&s0, None),
        Reward { transition_reward: 0.0, instantaneous_reward: 3.0 }
    );
    assert_eq!(
        rf.call(&s1, None),
        Reward { transition_reward: 0.0, instantaneous_reward: 5.0 }
    );
    // Genuine transitions pick up the off-diagonal transition reward.
    assert_eq!(
        rf.call(&s0, Some(&s1)),
        Reward { transition_reward: 2.0, instantaneous_reward: 3.0 }
    );
    assert_eq!(
        rf.call(&s1, Some(&s0)),
        Reward { transition_reward: 1.0, instantaneous_reward: 5.0 }
    );
    // Self-transitions yield no transition reward (diagonal is ignored).
    assert_eq!(
        rf.call(&s0, Some(&s0)),
        Reward { transition_reward: 0.0, instantaneous_reward: 3.0 }
    );
    assert_eq!(
        rf.call(&s1, Some(&s1)),
        Reward { transition_reward: 0.0, instantaneous_reward: 5.0 }
    );
}
/// Reward evaluation on an isolated ternary node: each (current, previous)
/// pair maps to its transition-matrix entry plus the current state's
/// instantaneous reward.
#[test]
fn simple_factored_reward_function_ternary_node() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();

    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1)
        .assign(&arr2(&[[0.0, 1.0, 3.0], [2.0, 0.0, 4.0], [5.0, 6.0, 0.0]]));
    rf.get_instantaneous_reward_mut(n1)
        .assign(&arr1(&[3.0, 5.0, 9.0]));

    let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
    let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
    let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];

    assert_eq!(
        rf.call(&s0, Some(&s1)),
        Reward { transition_reward: 2.0, instantaneous_reward: 3.0 }
    );
    assert_eq!(
        rf.call(&s0, Some(&s2)),
        Reward { transition_reward: 5.0, instantaneous_reward: 3.0 }
    );
    assert_eq!(
        rf.call(&s1, Some(&s0)),
        Reward { transition_reward: 1.0, instantaneous_reward: 5.0 }
    );
    assert_eq!(
        rf.call(&s1, Some(&s2)),
        Reward { transition_reward: 6.0, instantaneous_reward: 5.0 }
    );
    assert_eq!(
        rf.call(&s2, Some(&s0)),
        Reward { transition_reward: 3.0, instantaneous_reward: 9.0 }
    );
    assert_eq!(
        rf.call(&s2, Some(&s1)),
        Reward { transition_reward: 4.0, instantaneous_reward: 9.0 }
    );
}
/// Factored reward on a two-node net (n1 ternary, n2 binary, edge n1 -> n2):
/// instantaneous rewards sum across the nodes, while the transition reward is
/// contributed by the single node whose state changed.
/// State names sXY encode n2 = X, n1 = Y (the state vector is [n1, n2]).
#[test]
fn factored_reward_function_two_nodes() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
    rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
    rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
    rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
    rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
    let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
    let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
    let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
    let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)];
    let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
    let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
    // n1 changed: transition reward from n1's matrix; instantaneous = n1 + n2.
    assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
    assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
    // n2 changed: transition reward from n2's matrix.
    assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
    assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
    assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
    assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
    assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
    assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
    assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
    assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
    assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
    assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
    assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
    assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
    assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
    assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
    assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
    assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
}

@ -0,0 +1,692 @@
#![allow(non_snake_case)]
mod utils;
use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3};
use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::BayesianApproach;
use reCTBN::params;
use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::constraint_based_algorithm::*;
use reCTBN::structure_learning::score_based_algorithm::*;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::StructureLearningAlgorithm;
use reCTBN::tools::*;
use utils::*;
#[macro_use]
extern crate approx;
/// Log-likelihood of a single three-sample trajectory on an isolated binary node.
#[test]
fn simple_score_test() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let times = arr1(&[0.0, 0.1, 0.3]);
    let events = arr2(&[[0], [1], [1]]);
    let dataset = Dataset::new(vec![Trajectory::new(times, events)]);
    let ll = LogLikelihood::new(1, 1.0);
    let score = ll.call(&net, n1, &BTreeSet::new(), &dataset);
    assert_abs_diff_eq!(0.04257, score, epsilon = 1e-3);
}
/// BIC score of a single three-sample trajectory on an isolated binary node.
#[test]
fn simple_bic() {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let times = arr1(&[0.0, 0.1, 0.3]);
    let events = arr2(&[[0], [1], [1]]);
    let dataset = Dataset::new(vec![Trajectory::new(times, events)]);
    let bic = BIC::new(1, 1.0);
    let score = bic.call(&net, n1, &BTreeSet::new(), &dataset);
    assert_abs_diff_eq!(-0.65058, score, epsilon = 1e-3);
}
/// Sample data from a two-node network, then run structure learning on a
/// network with a different node set: the mismatch is expected to panic
/// (callers mark their tests `#[should_panic]`).
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-3.0, 2.0, 1.0],
                        [1.5, -2.0, 0.5],
                        [0.4, 0.6, -1.0]
                    ],
                ]))
            );
        }
    }
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 0.5, 0.5],
                        [3.0, -4.0, 1.0],
                        [0.9, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 2.0, 4.0],
                        [1.5, -2.0, 0.5],
                        [3.0, 1.0, -4.0]
                    ],
                    [
                        [-1.0, 0.1, 0.9],
                        [2.0, -2.5, 0.5],
                        [0.9, 0.1, -1.0]
                    ],
                ]))
            );
        }
    }
    let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259));
    // Shadow `net` with a single-node network: incompatible with the
    // two-variable dataset sampled above.
    let mut net = CtbnNetwork::new();
    let _n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let _net = sl.fit_transform(net, &data);
}
/// Add `nodes_cardinality` discrete-state continuous-time nodes to `net`,
/// labelled "0", "1", …, each with a domain of `nodes_domain_cardinality` states.
fn generate_nodes(
    net: &mut CtbnNetwork,
    nodes_cardinality: usize,
    nodes_domain_cardinality: usize
) {
    for label in 0..nodes_cardinality {
        let node =
            generate_discrete_time_continous_node(label.to_string(), nodes_domain_cardinality);
        net.add_node(node).unwrap();
    }
}
/// Sample data from a generated three-node network, then run structure
/// learning on a single-node network: the mismatch is expected to panic
/// (callers mark their tests `#[should_panic]`).
fn check_compatibility_between_dataset_and_network_gen<T: StructureLearningAlgorithm>(sl: T) {
    let mut net = CtbnNetwork::new();
    generate_nodes(&mut net, 2, 3);
    net.add_node(generate_discrete_time_continous_node(String::from("3"), 4))
        .unwrap();
    net.add_edge(0, 1);

    let mut generator: UniformParametersGenerator =
        RandomParametersGenerator::new(0.0..7.0, Some(6813071588535822));
    generator.generate_parameters(&mut net);
    let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259));

    // A fresh single-node network cannot accommodate the 3-variable dataset.
    let mut incompatible_net = CtbnNetwork::new();
    let _n1 = incompatible_net
        .add_node(generate_discrete_time_continous_node(String::from("0"), 3))
        .unwrap();
    let _net = sl.fit_transform(incompatible_net, &data);
}
/// Hill climbing on a mismatched dataset/network pair must panic.
#[test]
#[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
    let hl = HillClimbing::new(LogLikelihood::new(1, 1.0), None);
    check_compatibility_between_dataset_and_network(hl);
}
/// Hill climbing on a mismatched generated dataset/network pair must panic.
#[test]
#[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() {
    let hl = HillClimbing::new(LogLikelihood::new(1, 1.0), None);
    check_compatibility_between_dataset_and_network_gen(hl);
}
/// Structure learning must recover the single edge n1 -> n2 of a
/// hand-parametrized two-node ternary network.
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-3.0, 2.0, 1.0],
                        [1.5, -2.0, 0.5],
                        [0.4, 0.6, -1.0]
                    ],
                ]))
            );
        }
    }
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 0.5, 0.5],
                        [3.0, -4.0, 1.0],
                        [0.9, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 2.0, 4.0],
                        [1.5, -2.0, 0.5],
                        [3.0, 1.0, -4.0]
                    ],
                    [
                        [-1.0, 0.1, 0.9],
                        [2.0, -2.5, 0.5],
                        [0.9, 0.1, -1.0]
                    ],
                ]))
            );
        }
    }
    let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
    // The learned structure must contain exactly the generating edge n1 -> n2.
    let net = sl.fit_transform(net, &data);
    assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
    assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
}
/// Structure learning must recover the single edge 0 -> 1 of a generated
/// two-node ternary network.
fn learn_ternary_net_2_nodes_gen<T: StructureLearningAlgorithm>(sl: T) {
    let mut net = CtbnNetwork::new();
    generate_nodes(&mut net, 2, 3);
    net.add_edge(0, 1);

    let mut generator: UniformParametersGenerator =
        RandomParametersGenerator::new(0.0..7.0, Some(6813071588535822));
    generator.generate_parameters(&mut net);

    let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
    let learned = sl.fit_transform(net, &data);
    assert_eq!(BTreeSet::from_iter(vec![0]), learned.get_parent_set(1));
    assert_eq!(BTreeSet::new(), learned.get_parent_set(0));
}
/// Hill climbing with log-likelihood recovers the hand-written structure.
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
    let hl = HillClimbing::new(LogLikelihood::new(1, 1.0), None);
    learn_ternary_net_2_nodes(hl);
}
/// Hill climbing with log-likelihood recovers the generated structure.
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() {
    let hl = HillClimbing::new(LogLikelihood::new(1, 1.0), None);
    learn_ternary_net_2_nodes_gen(hl);
}
/// Hill climbing with BIC recovers the hand-written structure.
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
    let hl = HillClimbing::new(BIC::new(1, 1.0), None);
    learn_ternary_net_2_nodes(hl);
}
/// Hill climbing with BIC recovers the generated structure.
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() {
    let hl = HillClimbing::new(BIC::new(1, 1.0), None);
    learn_ternary_net_2_nodes_gen(hl);
}
/// Build a hand-parametrized 3-node network (n1, n2 ternary; n3 with 4 states;
/// edges n1->n2, n1->n3, n2->n3) and sample 300 trajectories of 30 time units.
/// Returns the generating network together with the sampled dataset.
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
        .unwrap();
    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 4))
        .unwrap();
    net.add_edge(n1, n2);
    net.add_edge(n1, n3);
    net.add_edge(n2, n3);
    // Root node: a single 3x3 CIM.
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-3.0, 2.0, 1.0],
                        [1.5, -2.0, 0.5],
                        [0.4, 0.6, -1.0]
                    ],
                ]))
            );
        }
    }
    // n2: one 3x3 CIM per state of its parent n1.
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 0.5, 0.5],
                        [3.0, -4.0, 1.0],
                        [0.9, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 2.0, 4.0],
                        [1.5, -2.0, 0.5],
                        [3.0, 1.0, -4.0]
                    ],
                    [
                        [-1.0, 0.1, 0.9],
                        [2.0, -2.5, 0.5],
                        [0.9, 0.1, -1.0]
                    ],
                ]))
            );
        }
    }
    // n3: one 4x4 CIM per joint configuration of (n1, n2) -> 9 matrices.
    match &mut net.get_node_mut(n3) {
        params::Params::DiscreteStatesContinousTime(param) => {
            assert_eq!(
                Ok(()),
                param.set_cim(arr3(&[
                    [
                        [-1.0, 0.5, 0.3, 0.2],
                        [0.5, -4.0, 2.5, 1.0],
                        [2.5, 0.5, -4.0, 1.0],
                        [0.7, 0.2, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 2.0, 3.0, 1.0],
                        [1.5, -3.0, 0.5, 1.0],
                        [2.0, 1.3, -5.0, 1.7],
                        [2.5, 0.5, 1.0, -4.0]
                    ],
                    [
                        [-1.3, 0.3, 0.1, 0.9],
                        [1.4, -4.0, 0.5, 2.1],
                        [1.0, 1.5, -3.0, 0.5],
                        [0.4, 0.3, 0.1, -0.8]
                    ],
                    [
                        [-2.0, 1.0, 0.7, 0.3],
                        [1.3, -5.9, 2.7, 1.9],
                        [2.0, 1.5, -4.0, 0.5],
                        [0.2, 0.7, 0.1, -1.0]
                    ],
                    [
                        [-6.0, 1.0, 2.0, 3.0],
                        [0.5, -3.0, 1.0, 1.5],
                        [1.4, 2.1, -4.3, 0.8],
                        [0.5, 1.0, 2.5, -4.0]
                    ],
                    [
                        [-1.3, 0.9, 0.3, 0.1],
                        [0.1, -1.3, 0.2, 1.0],
                        [0.5, 1.0, -3.0, 1.5],
                        [0.1, 0.4, 0.3, -0.8]
                    ],
                    [
                        [-2.0, 1.0, 0.6, 0.4],
                        [2.6, -7.1, 1.4, 3.1],
                        [5.0, 1.0, -8.0, 2.0],
                        [1.4, 0.4, 0.2, -2.0]
                    ],
                    [
                        [-3.0, 1.0, 1.5, 0.5],
                        [3.0, -6.0, 1.0, 2.0],
                        [0.3, 0.5, -1.9, 1.1],
                        [5.0, 1.0, 2.0, -8.0]
                    ],
                    [
                        [-2.6, 0.6, 0.2, 1.8],
                        [2.0, -6.0, 3.0, 1.0],
                        [0.1, 0.5, -1.3, 0.7],
                        [0.8, 0.6, 0.2, -1.6]
                    ],
                ]))
            );
        }
    }
    // Fixed seed keeps the sampled dataset deterministic across test runs.
    let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259));
    return (net, data);
}
/// Build a randomly-parametrized 3-node network (two ternary nodes plus one
/// 4-state node; edges 0->1, 0->2, 1->2) and sample 300 trajectories of 30
/// time units from it.
///
/// Returns the generating network together with the sampled dataset. Fixed
/// seeds keep both parameter generation and sampling deterministic.
fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) {
    let mut net = CtbnNetwork::new();
    generate_nodes(&mut net, 2, 3);
    net.add_node(
        generate_discrete_time_continous_node(
            String::from("3"),
            4
        )
    ).unwrap();
    net.add_edge(0, 1);
    net.add_edge(0, 2);
    net.add_edge(1, 2);
    let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
        0.0..7.0,
        Some(6813071588535822)
    );
    cim_generator.generate_parameters(&mut net);
    let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259));
    // Tail expression instead of an explicit `return` (clippy::needless_return).
    (net, data)
}
/// The learned structure must match the hand-written generating DAG exactly.
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
    let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
    let learned = sl.fit_transform(net, &data);
    assert_eq!(BTreeSet::new(), learned.get_parent_set(0));
    assert_eq!(BTreeSet::from_iter(vec![0]), learned.get_parent_set(1));
    assert_eq!(BTreeSet::from_iter(vec![0, 1]), learned.get_parent_set(2));
}
/// The learned structure must match the randomly-generated DAG exactly.
fn learn_mixed_discrete_net_3_nodes_gen<T: StructureLearningAlgorithm>(sl: T) {
    let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen();
    let learned = sl.fit_transform(net, &data);
    assert_eq!(BTreeSet::new(), learned.get_parent_set(0));
    assert_eq!(BTreeSet::from_iter(vec![0]), learned.get_parent_set(1));
    assert_eq!(BTreeSet::from_iter(vec![0, 1]), learned.get_parent_set(2));
}
/// Hill climbing with log-likelihood recovers the 3-node hand-written DAG.
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
    let hl = HillClimbing::new(LogLikelihood::new(1, 1.0), None);
    learn_mixed_discrete_net_3_nodes(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
learn_mixed_discrete_net_3_nodes_gen(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, None);
learn_mixed_discrete_net_3_nodes(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() {
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, None);
learn_mixed_discrete_net_3_nodes_gen(hl);
}
/// With the learner capped at one parent per node, node 2 can keep only one
/// of its two true parents; the expected result is 0 -> 1 and 0 -> 2.
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
    let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
    let learned = sl.fit_transform(net, &data);
    assert_eq!(learned.get_parent_set(0), BTreeSet::new());
    assert_eq!(learned.get_parent_set(1), BTreeSet::from_iter(vec![0]));
    assert_eq!(learned.get_parent_set(2), BTreeSet::from_iter(vec![0]));
}

/// Same one-parent-capped check, but on the randomly generated variant of
/// the net.
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen<T: StructureLearningAlgorithm>(sl: T) {
    let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen();
    let learned = sl.fit_transform(net, &data);
    assert_eq!(learned.get_parent_set(0), BTreeSet::new());
    assert_eq!(learned.get_parent_set(1), BTreeSet::from_iter(vec![0]));
    assert_eq!(learned.get_parent_set(2), BTreeSet::from_iter(vec![0]));
}
/// Hill climbing + log-likelihood under a max-parent-set size of 1.
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
    learn_mixed_discrete_net_3_nodes_1_parent_constraint(HillClimbing::new(
        LogLikelihood::new(1, 1.0),
        Some(1),
    ));
}

/// Hill climbing + log-likelihood, one-parent cap, generated net variant.
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() {
    learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(HillClimbing::new(
        LogLikelihood::new(1, 1.0),
        Some(1),
    ));
}

/// Hill climbing + BIC under a max-parent-set size of 1.
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() {
    learn_mixed_discrete_net_3_nodes_1_parent_constraint(HillClimbing::new(
        BIC::new(1, 1.0),
        Some(1),
    ));
}

/// Hill climbing + BIC, one-parent cap, generated net variant.
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() {
    learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(HillClimbing::new(
        BIC::new(1, 1.0),
        Some(1),
    ));
}
#[test]
pub fn chi_square_compare_matrices() {
    // Parent-configuration index selecting which 3x3 count matrix of M1 to
    // compare against M2[j].
    let i: usize = 1;
    // Transition-count tensor: one 3x3 matrix per parent configuration.
    let M1 = arr3(&[
        [
            [ 0, 2, 3],
            [ 4, 0, 6],
            [ 7, 8, 0]
        ],
        [
            [0, 12, 90],
            [ 3, 0, 40],
            [ 6, 40, 0]
        ],
        [
            [ 0, 2, 3],
            [ 4, 0, 6],
            [ 44, 66, 0]
        ],
    ]);
    let j: usize = 0;
    let M2 = arr3(&[
        [
            [ 0, 200, 300],
            [ 400, 0, 600],
            [ 700, 800, 0]
        ],
    ]);
    let chi_sq = ChiSquare::new(1e-4);
    // M1[1] is not proportional to M2[0] (e.g. row 0: 12:90 vs 200:300), so
    // the test is expected to judge the two count matrices as coming from
    // different distributions at significance level 1e-4.
    assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
}

#[test]
pub fn chi_square_compare_matrices_2() {
    let i: usize = 1;
    // Here M1[1] is exactly M2[0] scaled by 1/10, so the two matrices encode
    // the same transition proportions.
    let M1 = arr3(&[
        [
            [ 0, 2, 3],
            [ 4, 0, 6],
            [ 7, 8, 0]
        ],
        [
            [0, 20, 30],
            [ 40, 0, 60],
            [ 70, 80, 0]
        ],
        [
            [ 0, 2, 3],
            [ 4, 0, 6],
            [ 44, 66, 0]
        ],
    ]);
    let j: usize = 0;
    let M2 = arr3(&[
        [[ 0, 200, 300],
        [ 400, 0, 600],
        [ 700, 800, 0]]
    ]);
    let chi_sq = ChiSquare::new(1e-4);
    // Proportional counts: expected to be accepted as the same distribution.
    assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
}

#[test]
pub fn chi_square_compare_matrices_3() {
    let i: usize = 1;
    // M1[1] is only approximately proportional to M2[0] (21 vs 20, 59 vs 60,
    // ...); the deviations are small enough that the test should still
    // accept the null hypothesis.
    let M1 = arr3(&[
        [
            [ 0, 2, 3],
            [ 4, 0, 6],
            [ 7, 8, 0]
        ],
        [
            [0, 21, 31],
            [ 41, 0, 59],
            [ 71, 79, 0]
        ],
        [
            [ 0, 2, 3],
            [ 4, 0, 6],
            [ 44, 66, 0]
        ],
    ]);
    let j: usize = 0;
    let M2 = arr3(&[
        [
            [ 0, 200, 300],
            [ 400, 0, 600],
            [ 700, 800, 0]
        ],
    ]);
    let chi_sq = ChiSquare::new(1e-4);
    assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
}
#[test]
pub fn chi_square_call() {
    // Data sampled from the net with structure 0 -> 1, 0 -> 2, 1 -> 2.
    let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
    let N3: usize = 2;
    let N2: usize = 1;
    let N1: usize = 0;
    let mut separation_set = BTreeSet::new();
    let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
    let mut cache = Cache::new(&parameter_learning);
    let chi_sq = ChiSquare::new(1e-4);
    // NOTE(review): the argument order of `call(net, a, b, S, ...)` appears
    // to be asymmetric (see the two asserts below with swapped N1/N3) —
    // confirm the exact independence direction against the hypothesis-test
    // implementation.
    assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache));
    let mut cache = Cache::new(&parameter_learning);
    assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
    assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache));
    // Conditioning on node 0 changes the outcome for the (1, 2) pair.
    separation_set.insert(N1);
    let mut cache = Cache::new(&parameter_learning);
    assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache));
}

#[test]
pub fn f_call() {
    // Same scenario as `chi_square_call`, exercised with the F-test instead
    // of the chi-square test; the expected outcomes are identical.
    let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
    let N3: usize = 2;
    let N2: usize = 1;
    let N1: usize = 0;
    let mut separation_set = BTreeSet::new();
    let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
    let mut cache = Cache::new(&parameter_learning);
    let f = F::new(1e-6);
    assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache));
    let mut cache = Cache::new(&parameter_learning);
    assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache));
    assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache));
    separation_set.insert(N1);
    let mut cache = Cache::new(&parameter_learning);
    assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache));
}
/// CTPC (constraint-based) recovers the ternary 2-node net.
#[test]
pub fn learn_ternary_net_2_nodes_ctpc() {
    let ctpc = CTPC::new(
        BayesianApproach { alpha: 1, tau: 1.0 },
        F::new(1e-6),
        ChiSquare::new(1e-4),
    );
    learn_ternary_net_2_nodes(ctpc);
}

/// CTPC on the randomly generated ternary 2-node net.
#[test]
pub fn learn_ternary_net_2_nodes_ctpc_gen() {
    let ctpc = CTPC::new(
        BayesianApproach { alpha: 1, tau: 1.0 },
        F::new(1e-6),
        ChiSquare::new(1e-4),
    );
    learn_ternary_net_2_nodes_gen(ctpc);
}

/// CTPC recovers the mixed-cardinality 3-node net.
#[test]
fn learn_mixed_discrete_net_3_nodes_ctpc() {
    let ctpc = CTPC::new(
        BayesianApproach { alpha: 1, tau: 1.0 },
        F::new(1e-6),
        ChiSquare::new(1e-4),
    );
    learn_mixed_discrete_net_3_nodes(ctpc);
}

/// CTPC on the randomly generated mixed-cardinality 3-node net.
#[test]
fn learn_mixed_discrete_net_3_nodes_ctpc_gen() {
    let ctpc = CTPC::new(
        BayesianApproach { alpha: 1, tau: 1.0 },
        F::new(1e-6),
        ChiSquare::new(1e-4),
    );
    learn_mixed_discrete_net_3_nodes_gen(ctpc);
}

@ -0,0 +1,251 @@
use std::ops::Range;
use ndarray::{arr1, arr2, arr3};
use reCTBN::params::ParamsTrait;
use reCTBN::process::ctbn::*;
use reCTBN::process::ctmp::*;
use reCTBN::process::NetworkProcess;
use reCTBN::params;
use reCTBN::tools::*;
use utils::*;
#[macro_use]
extern crate approx;
mod utils;
#[test]
fn run_sampling() {
    #![allow(unused_must_use)]
    // Two binary nodes with n1 -> n2, so n2's CIM depends on n1's state.
    let mut net = CtbnNetwork::new();
    let n1 = net
        .add_node(utils::generate_discrete_time_continous_node(
            String::from("n1"),
            2,
        ))
        .unwrap();
    let n2 = net
        .add_node(utils::generate_discrete_time_continous_node(
            String::from("n2"),
            2,
        ))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1) {
        params::Params::DiscreteStatesContinousTime(param) => {
            // Root node: a single CIM slice (no parent configurations).
            param.set_cim(arr3(&[
                [
                    [-3.0, 3.0],
                    [2.0, -2.0]
                ],
            ]));
        }
    }
    match &mut net.get_node_mut(n2) {
        params::Params::DiscreteStatesContinousTime(param) => {
            // Child node: one CIM slice per state of its parent n1.
            param.set_cim(arr3(&[
                [
                    [-1.0, 1.0],
                    [4.0, -4.0]
                ],
                [
                    [-6.0, 6.0],
                    [2.0, -2.0]
                ],
            ]));
        }
    }
    let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259));
    // The generator must produce exactly the requested number of
    // trajectories, and the first one must be closed at the requested end
    // time (1.0).
    assert_eq!(4, data.get_trajectories().len());
    assert_relative_eq!(
        1.0,
        data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len() - 1]
    );
}
/// `Trajectory::new` must panic when the number of time points does not
/// match the number of recorded events.
#[test]
#[should_panic]
fn trajectory_wrong_shape() {
    // Two timestamps but only one event row.
    Trajectory::new(arr1(&[0.0, 0.2]), arr2(&[[0, 3]]));
}

/// `Dataset::new` must panic when its trajectories do not share the same
/// number of variables (2 columns vs 3 columns here).
#[test]
#[should_panic]
fn dataset_wrong_shape() {
    let t1 = Trajectory::new(arr1(&[0.0, 0.2]), arr2(&[[0, 3], [1, 2]]));
    let t2 = Trajectory::new(arr1(&[0.0, 0.2]), arr2(&[[0, 3, 3], [1, 2, 3]]));
    Dataset::new(vec![t1, t2]);
}
/// A density above 1.0 is not a valid edge probability and must be rejected.
#[test]
#[should_panic]
fn uniform_graph_generator_wrong_density_1() {
    let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(2.1, None);
}

/// A negative density is not a valid edge probability and must be rejected.
#[test]
#[should_panic]
fn uniform_graph_generator_wrong_density_2() {
    let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(-0.5, None);
}

/// Every density in [0, 1] must be accepted without panicking.
#[test]
fn uniform_graph_generator_right_densities() {
    for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
        let _structure_generator: UniformGraphGenerator =
            RandomGraphGenerator::new(density, None);
    }
}
#[test]
fn uniform_graph_generator_generate_graph_ctbn() {
    // 101 binary nodes: large enough for the edge count to concentrate
    // around its expected value.
    let mut net = CtbnNetwork::new();
    let nodes_cardinality = 0..=100;
    let nodes_domain_cardinality = 2;
    for node_label in nodes_cardinality {
        net.add_node(
            utils::generate_discrete_time_continous_node(
                node_label.to_string(),
                nodes_domain_cardinality,
            )
        ).unwrap();
    }
    let density = 1.0/3.0;
    let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
        density,
        Some(7641630759785120)
    );
    structure_generator.generate_graph(&mut net);
    // Count all generated edges by summing the children of every node.
    let mut edges = 0;
    for node in net.get_node_indices(){
        edges += net.get_children_set(node).len()
    }
    let nodes = net.get_node_indices().len() as f64;
    // Expected edge count of a directed graph (no self-loops) where each of
    // the n*(n-1) possible edges is drawn with probability `density`.
    let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
    let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance
    // As the way `generate_graph()` is implemented we can only reasonably
    // expect the number of edges to be somewhere around the expected value.
    assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
}
/// Running the graph generator on a `CtmpProcess` is expected to panic: a
/// CTMP holds a single node, so there is no structure to generate.
#[test]
#[should_panic]
fn uniform_graph_generator_generate_graph_ctmp() {
    let mut process = CtmpProcess::new();
    process
        .add_node(generate_discrete_time_continous_node(String::from("0"), 4))
        .unwrap();
    let mut generator: UniformGraphGenerator =
        RandomGraphGenerator::new(1.0 / 3.0, Some(7641630759785120));
    generator.generate_graph(&mut process);
}
/// A reversed (empty) rate interval must be rejected by the parameters
/// generator.
#[test]
#[should_panic]
fn uniform_parameters_generator_wrong_density_1() {
    let bad_interval: Range<f64> = -2.0..-5.0;
    let _cim_generator: UniformParametersGenerator =
        RandomParametersGenerator::new(bad_interval, None);
}

/// A rate interval extending below zero must be rejected by the parameters
/// generator.
#[test]
#[should_panic]
fn uniform_parameters_generator_wrong_density_2() {
    let bad_interval: Range<f64> = -1.0..0.0;
    let _cim_generator: UniformParametersGenerator =
        RandomParametersGenerator::new(bad_interval, None);
}
#[test]
fn uniform_parameters_generator_right_densities_ctbn() {
    // Four nodes with a 9-state domain each, random structure, random CIMs.
    let mut net = CtbnNetwork::new();
    let nodes_cardinality = 0..=3;
    let nodes_domain_cardinality = 9;
    for node_label in nodes_cardinality {
        net.add_node(
            generate_discrete_time_continous_node(
                node_label.to_string(),
                nodes_domain_cardinality,
            )
        ).unwrap();
    }
    let density = 1.0/3.0;
    let seed = Some(7641630759785120);
    let interval = 0.0..7.0;
    let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
        density,
        seed
    );
    structure_generator.generate_graph(&mut net);
    let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
        interval,
        seed
    );
    cim_generator.generate_parameters(&mut net);
    // Every generated node must pass the library's parameter validity checks.
    for node in net.get_node_indices() {
        assert_eq!(
            Ok(()),
            net.get_node(node).validate_params()
        );
    }
}
#[test]
fn uniform_parameters_generator_right_densities_ctmp() {
    // Single 4-state node: the CTMP degenerate case of CIM generation.
    let mut net = CtmpProcess::new();
    let node_label = String::from("0");
    let node_domain_cardinality = 4;
    net.add_node(
        generate_discrete_time_continous_node(
            node_label,
            node_domain_cardinality
        )
    ).unwrap();
    let seed = Some(7641630759785120);
    let interval = 0.0..7.0;
    let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
        interval,
        seed
    );
    cim_generator.generate_parameters(&mut net);
    // The generated parameters must pass the library's validity checks.
    for node in net.get_node_indices() {
        assert_eq!(
            Ok(()),
            net.get_node(node).validate_params()
        );
    }
}

@ -0,0 +1,19 @@
use std::collections::BTreeSet;
use reCTBN::params;
// NOTE(review): "continous" is a typo for "continuous", but these names are
// used across the whole test suite and are kept for compatibility.

/// Builds a discrete-state, continuous-time node wrapped in the `Params`
/// enum; the domain is `{"0", "1", ..., cardinality-1}`.
#[allow(dead_code)]
pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params {
    params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(
        label,
        cardinality,
    ))
}

/// Builds the raw parameter struct for a discrete-state, continuous-time
/// node whose domain is the stringified integers `0..cardinality`.
pub fn generate_discrete_time_continous_params(
    label: String,
    cardinality: usize,
) -> params::DiscreteStatesContinousTimeParams {
    let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
    params::DiscreteStatesContinousTimeParams::new(label, domain)
}

@ -0,0 +1,7 @@
# This file defines the Rust toolchain to use when a command is executed.
# See also https://rust-lang.github.io/rustup/overrides.html
[toolchain]
channel = "stable"
components = [ "clippy", "rustfmt" ]
profile = "minimal"

@ -0,0 +1,39 @@
# This file defines the Rust style for automatic reformatting.
# See also https://rust-lang.github.io/rustfmt
# NOTE: the unstable options will be uncommented when stabilized.
# Version of the formatting rules to use.
#version = "One"
# Number of spaces per tab.
tab_spaces = 4
max_width = 100
#comment_width = 80
# Prevent carriage returns; only \n line endings are admitted.
newline_style = "Unix"
# The "Default" setting has a heuristic which can split lines too aggressively.
#use_small_heuristics = "Max"
# How imports should be grouped into `use` statements.
#imports_granularity = "Module"
# How consecutive imports are grouped together.
#group_imports = "StdExternalCrate"
# Error if unable to get all lines within max_width, except for comments and
# string literals.
#error_on_line_overflow = true
# Error if unable to get comments or string literals within max_width, or they
# are left with trailing whitespaces.
#error_on_unformatted = true
# Files to ignore like third party code which is formatted upstream.
# Ignoring tests is a temporary measure due to some issues regarding rank-3 tensors
ignore = [
"tests/"
]

@ -1,164 +0,0 @@
use ndarray::prelude::*;
use crate::node;
use crate::params::{StateType, ParamsTrait};
use crate::network;
use std::collections::BTreeSet;
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
///composed by the following elements:
///- **adj_matrix**: a 2d ndarray representing the adjacency matrix
///- **nodes**: a vector containing all the nodes and their parameters.
///The index of a node inside the vector is also used as index for the adj_matrix.
///
///# Examples
///
///```
///
/// use std::collections::BTreeSet;
/// use rustyCTBN::network::Network;
/// use rustyCTBN::node;
/// use rustyCTBN::params;
/// use rustyCTBN::ctbn::*;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
///
/// //Create the node using the parameters
/// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1"));
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
/// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2"));
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::init();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Get all the children of node X1
/// let cs = net.get_children_set(X1);
/// assert_eq!(&X2, cs.iter().next().unwrap());
/// ```
pub struct CtbnNetwork {
    // `None` until the first edge is added; invalidated whenever a node is
    // added, since its dimensions must match the node count.
    adj_matrix: Option<Array2<u16>>,
    nodes: Vec<node::Node>
}

impl CtbnNetwork {
    /// Creates an empty network: no nodes and no adjacency matrix (the
    /// matrix is lazily allocated by the first `add_edge` call).
    pub fn init() -> CtbnNetwork {
        CtbnNetwork {
            adj_matrix: None,
            nodes: Vec::new()
        }
    }
}
impl network::Network for CtbnNetwork {
    /// (Re)allocates the adjacency matrix as an all-zeros |V|x|V| matrix.
    fn initialize_adj_matrix(&mut self) {
        self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f()));
    }

    /// Appends `n` to the network and returns its index. The node's
    /// parameters are reset and the adjacency matrix is invalidated, since
    /// its dimensions no longer match the node count.
    fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> {
        n.params.reset_params();
        self.adj_matrix = Option::None;
        self.nodes.push(n);
        Ok(self.nodes.len() -1)
    }

    /// Adds the directed edge `parent` -> `child`, lazily allocating the
    /// adjacency matrix first. The child's parameters are reset because
    /// their dimensions depend on its parent set.
    fn add_edge(&mut self, parent: usize, child: usize) {
        if let None = self.adj_matrix {
            self.initialize_adj_matrix();
        }
        if let Some(network) = &mut self.adj_matrix {
            network[[parent, child]] = 1;
            self.nodes[child].params.reset_params();
        }
    }

    fn get_node_indices(&self) -> std::ops::Range<usize>{
        0..self.nodes.len()
    }

    fn get_number_of_nodes(&self) -> usize {
        self.nodes.len()
    }

    fn get_node(&self, node_idx: usize) -> &node::Node{
        &self.nodes[node_idx]
    }

    fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{
        &mut self.nodes[node_idx]
    }

    /// Maps the states of `node`'s parents to the index of the parameter
    /// slice to use, via mixed-radix encoding: each parent contributes its
    /// state index weighted by the product of the domain sizes of the
    /// parents already folded. Panics if the adjacency matrix has not been
    /// allocated yet.
    fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
        self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
            // acc.0 accumulates the index, acc.1 the running radix weight;
            // only entries with an incoming edge (> 0) participate.
            if x.1 > &0 {
                acc.0 += self.nodes[x.0].params.state_to_index(&current_state[x.0]) * acc.1;
                acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent();
            }
            acc
        }).0
    }

    /// Same mixed-radix encoding as `get_param_index_network`, but over an
    /// explicit `parent_set` instead of the adjacency matrix.
    fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize {
        parent_set.iter().fold((0, 1), |mut acc, x| {
            acc.0 += self.nodes[*x].params.state_to_index(&current_state[*x]) * acc.1;
            acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent();
            acc
        }).0
    }

    /// Indices of the nodes with an edge into `node` (non-zero entries of
    /// its adjacency-matrix column).
    fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
        self.adj_matrix.as_ref()
            .unwrap()
            .column(node)
            .iter()
            .enumerate()
            .filter_map(|(idx, x)| {
                if x > &0 {
                    Some(idx)
                } else {
                    None
                }
            }).collect()
    }

    /// Indices of the nodes `node` has an edge into (non-zero entries of
    /// its adjacency-matrix row).
    fn get_children_set(&self, node: usize) -> BTreeSet<usize>{
        self.adj_matrix.as_ref()
            .unwrap()
            .row(node)
            .iter()
            .enumerate()
            .filter_map(|(idx, x)| {
                if x > &0 {
                    Some(idx)
                } else {
                    None
                }
            }).collect()
    }
}

@ -1,11 +0,0 @@
#[cfg(test)]
#[macro_use]
extern crate approx;
pub mod node;
pub mod params;
pub mod network;
pub mod ctbn;
pub mod tools;
pub mod parameter_learning;

@ -1,39 +0,0 @@
use thiserror::Error;
use crate::params;
use crate::node;
use std::collections::BTreeSet;
/// Error types for trait Network
#[derive(Error, Debug)]
pub enum NetworkError {
    /// Raised when a node cannot be added to the network.
    #[error("Error during node insertion")]
    NodeInsertionError(String)
}

///Network
///The Network trait define the required methods for a structure used as pgm (such as ctbn).
pub trait Network {
    /// (Re)allocates the adjacency matrix to match the current node count.
    fn initialize_adj_matrix(&mut self);
    /// Adds `n` to the network, returning its index on success.
    fn add_node(&mut self, n: node::Node) -> Result<usize, NetworkError>;
    /// Adds the directed edge `parent` -> `child`.
    fn add_edge(&mut self, parent: usize, child: usize);
    ///Get all the indices of the nodes contained inside the network
    fn get_node_indices(&self) -> std::ops::Range<usize>;
    /// Number of nodes currently in the network.
    fn get_number_of_nodes(&self) -> usize;
    /// Immutable access to the node at `node_idx`.
    fn get_node(&self, node_idx: usize) -> &node::Node;
    /// Mutable access to the node at `node_idx`.
    fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node;
    ///Compute the index that must be used to access the parameters of a node given a specific
    ///configuration of the network. Usually, the only values really used in *current_state* are
    ///the ones in the parent set of the *node*.
    fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
    ///Compute the index that must be used to access the parameters of a node given a specific
    ///configuration of the network and a generic parent_set. Usually, the only values really used
    ///in *current_state* are the ones in the parent set of the *node*.
    fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<params::StateType>, parent_set: &BTreeSet<usize>) -> usize;
    /// Indices of the direct parents of `node`.
    fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
    /// Indices of the direct children of `node`.
    fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@ -1,25 +0,0 @@
use crate::params::*;
/// A network node: a label paired with the parameters governing its
/// behaviour.
pub struct Node {
    pub params: Params,
    pub label: String
}

impl Node {
    /// Wraps `params` and `label` into a new node.
    pub fn init(params: Params, label: String) -> Node {
        Node { params, label }
    }
}

impl PartialEq for Node {
    /// Nodes are compared by label only; their parameters are deliberately
    /// ignored.
    fn eq(&self, other: &Node) -> bool {
        other.label == self.label
    }
}

@ -1,161 +0,0 @@
use ndarray::prelude::*;
use rand::Rng;
use std::collections::{BTreeSet, HashMap};
use thiserror::Error;
use enum_dispatch::enum_dispatch;
/// Error types for trait Params
#[derive(Error, Debug)]
pub enum ParamsError {
#[error("Unsupported method")]
UnsupportedMethod(String),
#[error("Paramiters not initialized")]
ParametersNotInitialized(String),
}
/// Allowed type of states
#[derive(Clone)]
pub enum StateType {
    // A state identified by its index in the node's domain.
    Discrete(usize),
}

/// Parameters
/// The Params trait is the core element for building different types of nodes. The goal is to
/// define the set of method required to describes a generic node.
#[enum_dispatch(Params)]
pub trait ParamsTrait {
    /// Discards any learned parameters and sufficient statistics.
    fn reset_params(&mut self);
    /// Randomly generate a possible state of the node disregarding the state of the node and it's
    /// parents.
    fn get_random_state_uniform(&self) -> StateType;
    /// Randomly generate a residence time for the given node taking into account the node state
    /// and its parent set (`u` indexes the parent-set configuration).
    fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>;
    /// Randomly generate a possible state for the given node taking into account the node state
    /// and its parent set.
    fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError>;
    /// Used by children of the node described by this parameters to reserve spaces in their CIMs.
    fn get_reserved_space_as_parent(&self) -> usize;
    /// Index used by discrete node to represents their states as usize.
    fn state_to_index(&self, state: &StateType) -> usize;
}

/// The Params enum is the core element for building different types of nodes. The goal is to
/// define all the supported type of parameters.
#[enum_dispatch]
pub enum Params {
    DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
/// DiscreteStatesContinousTime.
/// Parameters of a classical discrete node for CTBNs, composed of:
/// - **domain**: an ordered and exhaustive set of possible states;
/// - **cim**: the Conditional Intensity Matrix;
/// - sufficient statistics, mainly used during parameter learning:
///   - **transitions**: counts of state-to-state transitions per parent-set
///     realization;
///   - **residence_time**: permanence time in each state per parent-set
///     realization.
pub struct DiscreteStatesContinousTimeParams {
    pub domain: BTreeSet<String>,
    pub cim: Option<Array3<f64>>,
    pub transitions: Option<Array3<u64>>,
    pub residence_time: Option<Array2<f64>>,
}

impl DiscreteStatesContinousTimeParams {
    /// Builds a parameter set over `domain` with no CIM and no sufficient
    /// statistics; these are filled in later (e.g. by parameter learning).
    pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
        DiscreteStatesContinousTimeParams {
            domain,
            cim: None,
            transitions: None,
            residence_time: None,
        }
    }
}
impl ParamsTrait for DiscreteStatesContinousTimeParams {
    /// Drops the CIM and the sufficient statistics, returning the node to an
    /// unparameterized state.
    fn reset_params(&mut self) {
        self.cim = Option::None;
        self.transitions = Option::None;
        self.residence_time = Option::None;
    }

    /// Draws a state uniformly over the node's domain.
    fn get_random_state_uniform(&self) -> StateType {
        let mut rng = rand::thread_rng();
        StateType::Discrete(rng.gen_range(0..(self.domain.len())))
    }

    /// Samples an exponential residence time with rate -cim[u, state, state]
    /// (the diagonal entry is negative by construction of a CIM).
    /// Fails if the CIM has not been initialized.
    fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> {
        // Generate a random residence time given the current state of the node and its parent set.
        // The method used is described in:
        // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
        match &self.cim {
            Option::Some(cim) => {
                let mut rng = rand::thread_rng();
                let lambda = cim[[u, state, state]] * -1.0;
                let x: f64 = rng.gen_range(0.0..=1.0);
                // Inverse-transform sampling: -ln(U)/lambda ~ Exp(lambda).
                Ok(-x.ln() / lambda)
            }
            Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
                "CIM not initialized",
            ))),
        }
    }

    /// Samples the next state from the multinomial distribution obtained by
    /// normalizing row `cim[u, state, ..]` by the total exit rate.
    /// Fails if the CIM has not been initialized.
    fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> {
        // Generate a random transition given the current state of the node and its parent set.
        // The method used is described in:
        // https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
        match &self.cim {
            Option::Some(cim) => {
                let mut rng = rand::thread_rng();
                let lambda = cim[[u, state, state]] * -1.0;
                let urand: f64 = rng.gen_range(0.0..=1.0);
                // Walk the normalized row, counting how many positive-rate
                // entries the uniform draw `urand` passes; the diagonal
                // entry is negative and thus skipped by the `ele > 0` guards.
                let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold(
                    (0, 0.0),
                    |mut acc, ele| {
                        if &acc.1 + ele < urand && ele > &0.0 {
                            acc.0 += 1;
                        }
                        if ele > &0.0 {
                            acc.1 += ele;
                        }
                        acc
                    },
                );
                // `next_state.0` counts positions among the non-diagonal
                // entries only; shift by one at/after the current state to
                // map the count back to an actual state index.
                let next_state = if next_state.0 < state {
                    next_state.0
                } else {
                    next_state.0 + 1
                };
                Ok(StateType::Discrete(next_state))
            }
            Option::None => Err(ParamsError::ParametersNotInitialized(String::from(
                "CIM not initialized",
            ))),
        }
    }

    /// A parent contributes one CIM slice per state, i.e. its domain size.
    fn get_reserved_space_as_parent(&self) -> usize {
        self.domain.len()
    }

    fn state_to_index(&self, state: &StateType) -> usize {
        match state {
            StateType::Discrete(val) => val.clone() as usize,
        }
    }
}

@ -1,119 +0,0 @@
use crate::network;
use crate::node;
use crate::params;
use crate::params::ParamsTrait;
use ndarray::prelude::*;
/// A single sampled trajectory: `time[i]` is the instant at which the system
/// entered the joint state `events[i]` (one column per node).
pub struct Trajectory {
    pub time: Array1<f64>,
    pub events: Array2<usize>,
}

/// A collection of trajectories sampled from the same network.
pub struct Dataset {
    pub trajectories: Vec<Trajectory>,
}
/// Samples `n_trajectories` trajectories of duration `t_end` from `net`.
/// Each node keeps a candidate absolute transition time drawn from its
/// residence-time distribution; at every step the earliest candidate fires,
/// its node transitions, and the clocks of that node and of its children are
/// invalidated and redrawn.
pub fn trajectory_generator<T: network::Network>(
    net: &T,
    n_trajectories: u64,
    t_end: f64,
) -> Dataset {
    let mut dataset = Dataset {
        trajectories: Vec::new(),
    };
    let node_idx: Vec<_> = net.get_node_indices().collect();
    for _ in 0..n_trajectories {
        let mut t = 0.0;
        let mut time: Vec<f64> = Vec::new();
        let mut events: Vec<Array1<usize>> = Vec::new();
        // Initial joint state: each node sampled uniformly over its domain.
        let mut current_state: Vec<params::StateType> = node_idx
            .iter()
            .map(|x| net.get_node(*x).params.get_random_state_uniform())
            .collect();
        // Candidate transition times, one per node; `None` marks a clock
        // that must be (re)sampled.
        let mut next_transitions: Vec<Option<f64>> =
            (0..node_idx.len()).map(|_| Option::None).collect();
        // Record the initial state at t = 0.
        events.push(
            current_state
                .iter()
                .map(|x| match x {
                    params::StateType::Discrete(state) => state.clone(),
                })
                .collect(),
        );
        time.push(t.clone());
        while t < t_end {
            // Refresh the candidate time of every invalidated node, based on
            // its current state and its parents' configuration.
            for (idx, val) in next_transitions.iter_mut().enumerate() {
                if let None = val {
                    *val = Some(
                        net.get_node(idx)
                            .params
                            .get_random_residence_time(
                                net.get_node(idx).params.state_to_index(&current_state[idx]),
                                net.get_param_index_network(idx, &current_state),
                            )
                            .unwrap()
                            + t,
                    );
                }
            }
            // The node with the earliest candidate time transitions next.
            let next_node_transition = next_transitions
                .iter()
                .enumerate()
                .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
                .unwrap()
                .0;
            // Nothing else happens within the horizon: stop this trajectory.
            if next_transitions[next_node_transition].unwrap() > t_end {
                break;
            }
            // Advance the clock and sample the new state of the firing node.
            t = next_transitions[next_node_transition].unwrap().clone();
            time.push(t.clone());
            current_state[next_node_transition] = net
                .get_node(next_node_transition)
                .params
                .get_random_state(
                    net.get_node(next_node_transition)
                        .params
                        .state_to_index(&current_state[next_node_transition]),
                    net.get_param_index_network(next_node_transition, &current_state),
                )
                .unwrap();
            events.push(Array::from_vec(
                current_state
                    .iter()
                    .map(|x| match x {
                        params::StateType::Discrete(state) => state.clone(),
                    })
                    .collect(),
            ));
            // Invalidate the clocks of the firing node and of its children,
            // whose dynamics depend on the changed state.
            next_transitions[next_node_transition] = None;
            for child in net.get_children_set(next_node_transition) {
                next_transitions[child] = None
            }
        }
        // Close the trajectory exactly at t_end with the last reached state.
        events.push(
            current_state
                .iter()
                .map(|x| match x {
                    params::StateType::Discrete(state) => state.clone(),
                })
                .collect(),
        );
        time.push(t_end.clone());
        dataset.trajectories.push(Trajectory {
            time: Array::from_vec(time),
            events: Array2::from_shape_vec(
                (events.len(), current_state.len()),
                events.iter().flatten().cloned().collect(),
            )
            .unwrap(),
        });
    }
    dataset
}

@ -1,99 +0,0 @@
mod utils;
use utils::generate_discrete_time_continous_node;
use rustyCTBN::network::Network;
use rustyCTBN::node;
use rustyCTBN::params;
use std::collections::BTreeSet;
use rustyCTBN::ctbn::*;
/// Constructing an empty CTBN must not panic.
#[test]
fn define_simpe_ctbn() {
    let _ = CtbnNetwork::init();
    assert!(true);
}

/// A node added to the network is retrievable by the returned index.
#[test]
fn add_node_to_ctbn() {
    let mut net = CtbnNetwork::init();
    let idx = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    assert_eq!(net.get_node(idx).label, String::from("n1"));
}

/// After adding n1 -> n2, n2 appears among n1's children.
#[test]
fn add_edge_to_ctbn() {
    let mut net = CtbnNetwork::init();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    assert_eq!(net.get_children_set(n1).iter().next().unwrap(), &n2);
}

/// An edge n1 -> n2 is visible from both endpoints: n2 among n1's children
/// and n1 among n2's parents.
#[test]
fn children_and_parents() {
    let mut net = CtbnNetwork::init();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    assert_eq!(net.get_children_set(n1).iter().next().unwrap(), &n2);
    assert_eq!(net.get_parent_set(n2).iter().next().unwrap(), &n1);
}
#[test]
fn compute_index_ctbn() {
    // n2 has parents {n1, n3}. The parameter index is a mixed-radix encoding
    // of the parents' states: here n1 has weight 1 and n3 weight 2 (the
    // product of the earlier parents' domain sizes).
    let mut net = CtbnNetwork::init();
    let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
    let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
    let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap();
    net.add_edge(n1, n2);
    net.add_edge(n3, n2);
    // n1 = 1, n3 = 1 -> 1*1 + 1*2 = 3.
    let idx = net.get_param_index_network(n2, &vec![
        params::StateType::Discrete(1),
        params::StateType::Discrete(1),
        params::StateType::Discrete(1)]);
    assert_eq!(3, idx);
    // n1 = 0, n3 = 1 -> 0*1 + 1*2 = 2.
    let idx = net.get_param_index_network(n2, &vec![
        params::StateType::Discrete(0),
        params::StateType::Discrete(1),
        params::StateType::Discrete(1)]);
    assert_eq!(2, idx);
    // n1 = 1, n3 = 0 -> 1*1 + 0*2 = 1. n2's own state is always ignored.
    let idx = net.get_param_index_network(n2, &vec![
        params::StateType::Discrete(1),
        params::StateType::Discrete(1),
        params::StateType::Discrete(0)]);
    assert_eq!(1, idx);
}

#[test]
fn compute_index_from_custom_parent_set() {
    // Same mixed-radix encoding, but over an explicit parent set rather than
    // the network's edges (no edges are added here).
    let mut net = CtbnNetwork::init();
    let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
    let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
    let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap();
    // Parent set {1}: node 1 is in state 0 -> index 0.
    let idx = net.get_param_index_from_custom_parent_set(&vec![
        params::StateType::Discrete(0),
        params::StateType::Discrete(0),
        params::StateType::Discrete(1)],
        &BTreeSet::from([1]));
    assert_eq!(0, idx);
    // Parent set {1, 2}: 0*1 + 1*2 = 2.
    let idx = net.get_param_index_from_custom_parent_set(&vec![
        params::StateType::Discrete(0),
        params::StateType::Discrete(0),
        params::StateType::Discrete(1)],
        &BTreeSet::from([1,2]));
    assert_eq!(2, idx);
}

@ -1,263 +0,0 @@
mod utils;
use utils::*;
use rustyCTBN::parameter_learning::*;
use rustyCTBN::ctbn::*;
use rustyCTBN::network::Network;
use rustyCTBN::node;
use rustyCTBN::params;
use rustyCTBN::tools::*;
use ndarray::arr3;
use std::collections::BTreeSet;
#[macro_use]
extern crate approx;
/// Fits `pl` on trajectories sampled from a known 2-node binary CTBN and
/// checks that the CIM learned for node 1 (the child) is close to the
/// generating one.
fn learn_binary_cim<T: ParameterLearning> (pl: T) {
    let mut net = CtbnNetwork::init();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"),2))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"),2))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            // Root node: a single CIM slice.
            param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
        }
    }
    match &mut net.get_node_mut(n2).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            // Child node: one CIM slice per state of its parent n1.
            param.cim = Some(arr3(&[
                [[-1.0, 1.0], [4.0, -4.0]],
                [[-6.0, 6.0], [2.0, -2.0]],
            ]));
        }
    }
    let data = trajectory_generator(&net, 100, 100.0);
    // Learn node 1's CIM plus its sufficient statistics (M, T).
    let (CIM, M, T) = pl.fit(&net, &data, 1, None);
    print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
    assert_eq!(CIM.shape(), [2, 2, 2]);
    // An entrywise tolerance of 0.2 accounts for sampling noise.
    assert!(CIM.abs_diff_eq(&arr3(&[
        [[-1.0, 1.0], [4.0, -4.0]],
        [[-6.0, 6.0], [2.0, -2.0]],
    ]), 0.2));
}
/// Maximum-likelihood estimation recovers the binary CIM.
#[test]
fn learn_binary_cim_MLE() {
    learn_binary_cim(MLE {});
}

/// The Bayesian approach (alpha = 1, tau = 1.0) recovers the binary CIM.
#[test]
fn learn_binary_cim_BA() {
    learn_binary_cim(BayesianApproach {
        default_alpha: 1,
        default_tau: 1.0,
    });
}
/// Fits `pl` on trajectories sampled from a known 2-node ternary CTBN and
/// checks that the CIM learned for node 1 (the child) is close to the
/// generating one.
fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
    let mut net = CtbnNetwork::init();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"),3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"),3))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            // Root node: a single 3x3 CIM slice.
            param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
                [1.5, -2.0, 0.5],
                [0.4, 0.6, -1.0]]]));
        }
    }
    match &mut net.get_node_mut(n2).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            // Child node: one 3x3 CIM slice per state of its parent n1.
            param.cim = Some(arr3(&[
                [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
                [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
                [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
            ]));
        }
    }
    let data = trajectory_generator(&net, 100, 200.0);
    let (CIM, M, T) = pl.fit(&net, &data, 1, None);
    print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
    assert_eq!(CIM.shape(), [3, 3, 3]);
    // An entrywise tolerance of 0.2 accounts for sampling noise.
    assert!(CIM.abs_diff_eq(&arr3(&[
        [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
        [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
        [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
    ]), 0.2));
}
/// Ternary-CIM learning must work with maximum likelihood estimation.
#[test]
fn learn_ternary_cim_MLE() {
    learn_ternary_cim(MLE {});
}
/// Ternary-CIM learning must work with the Bayesian approach under a
/// weak prior (alpha = 1, tau = 1.0).
#[test]
fn learn_ternary_cim_BA() {
    let approach = BayesianApproach {
        default_alpha: 1,
        default_tau: 1.0,
    };
    learn_ternary_cim(approach);
}
/// Checks that a `ParameterLearning` implementation recovers the CIM of a
/// root node (`n1`, no parents) in a two-node ternary CTBN.
///
/// Fitting node index 0 must yield a single CIM slice (shape [1, 3, 3])
/// matching n1's ground-truth CIM within tolerance 0.2.
fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::init();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.cim = Some(arr3(&[[
                [-3.0, 2.0, 1.0],
                [1.5, -2.0, 0.5],
                [0.4, 0.6, -1.0],
            ]]));
        }
    }
    match &mut net.get_node_mut(n2).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.cim = Some(arr3(&[
                [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
                [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
                [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
            ]));
        }
    }
    let data = trajectory_generator(&net, 100, 200.0);
    // snake_case locals: the originals (`CIM`, `M`, `T`) violated naming
    // conventions, and `T` shadowed the type parameter in the value namespace.
    let (cim, m, t) = pl.fit(&net, &data, 0, None);
    println!("CIM: {:?}\nM: {:?}\nT: {:?}", cim, m, t);
    assert_eq!(cim.shape(), [1, 3, 3]);
    assert!(cim.abs_diff_eq(
        &arr3(&[[
            [-3.0, 2.0, 1.0],
            [1.5, -2.0, 0.5],
            [0.4, 0.6, -1.0],
        ]]),
        0.2
    ));
}
/// Root-node CIM learning must work with maximum likelihood estimation.
#[test]
fn learn_ternary_cim_no_parents_MLE() {
    learn_ternary_cim_no_parents(MLE {});
}
/// Root-node CIM learning must work with the Bayesian approach under a
/// weak prior (alpha = 1, tau = 1.0).
#[test]
fn learn_ternary_cim_no_parents_BA() {
    let approach = BayesianApproach {
        default_alpha: 1,
        default_tau: 1.0,
    };
    learn_ternary_cim_no_parents(approach);
}
/// Checks that a `ParameterLearning` implementation recovers the CIM of a
/// 4-state node (`n3`) with two parents of cardinality 3 each (`n1`, `n2`),
/// i.e. 3 * 3 = 9 parent configurations.
///
/// Samples 300 trajectories with horizon 200.0 and requires the CIM learned
/// for node index 2 (shape [9, 4, 4]) to match the ground truth within
/// tolerance 0.2.
fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::init();
    let n1 = net
        .add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
        .unwrap();
    let n2 = net
        .add_node(generate_discrete_time_continous_node(String::from("n2"), 3))
        .unwrap();
    let n3 = net
        .add_node(generate_discrete_time_continous_node(String::from("n3"), 4))
        .unwrap();
    net.add_edge(n1, n2);
    net.add_edge(n1, n3);
    net.add_edge(n2, n3);
    match &mut net.get_node_mut(n1).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.cim = Some(arr3(&[[
                [-3.0, 2.0, 1.0],
                [1.5, -2.0, 0.5],
                [0.4, 0.6, -1.0],
            ]]));
        }
    }
    match &mut net.get_node_mut(n2).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.cim = Some(arr3(&[
                [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
                [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
                [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
            ]));
        }
    }
    // 9 CIM slices: one per joint configuration of (n1, n2).
    match &mut net.get_node_mut(n3).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.cim = Some(arr3(&[
                [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]],
                [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]],
                [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]],
                [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]],
                [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]],
                [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]],
                [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
                [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
                [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
            ]));
        }
    }
    let data = trajectory_generator(&net, 300, 200.0);
    // snake_case locals: the originals (`CIM`, `M`, `T`) violated naming
    // conventions, and `T` shadowed the type parameter in the value namespace.
    let (cim, m, t) = pl.fit(&net, &data, 2, None);
    println!("CIM: {:?}\nM: {:?}\nT: {:?}", cim, m, t);
    assert_eq!(cim.shape(), [9, 4, 4]);
    assert!(cim.abs_diff_eq(
        &arr3(&[
            [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]],
            [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]],
            [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]],
            [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]],
            [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]],
            [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]],
            [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
            [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
            [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
        ]),
        0.2
    ));
}
/// Mixed-cardinality CIM learning must work with maximum likelihood
/// estimation.
#[test]
fn learn_mixed_discrete_cim_MLE() {
    learn_mixed_discrete_cim(MLE {});
}
/// Mixed-cardinality CIM learning must work with the Bayesian approach
/// under a weak prior (alpha = 1, tau = 1.0).
#[test]
fn learn_mixed_discrete_cim_BA() {
    let approach = BayesianApproach {
        default_alpha: 1,
        default_tau: 1.0,
    };
    learn_mixed_discrete_cim(approach);
}

@ -1,64 +0,0 @@
use rustyCTBN::params::*;
use ndarray::prelude::*;
use std::collections::BTreeSet;
mod utils;
#[macro_use]
extern crate approx;
/// Builds the ternary parameter set with a fixed CIM shared by the tests
/// below (single parent configuration, 3 states).
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
    let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]];
    let mut params = utils::generate_discrete_time_continous_param(3);
    params.cim = Some(cim);
    params
}
/// The uniform state generator should hit each of the 3 states with
/// probability ~1/3; checked empirically on 10000 draws for state 0.
#[test]
fn test_uniform_generation() {
    let param = create_ternary_discrete_time_continous_param();
    let n_samples = 10000;
    let states = Array1::from_shape_fn(n_samples, |_| {
        if let StateType::Discrete(val) = param.get_random_state_uniform() {
            val
        } else {
            panic!()
        }
    });
    let zero_freq =
        states.iter().filter(|&&s| s == 0).count() as f64 / n_samples as f64;
    assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01);
}
/// Next-state sampling from state 1 (parent config 0) should follow the
/// CIM's off-diagonal rates: P(2) = 4/5 and P(0) = 1/5; checked
/// empirically on 10000 draws.
#[test]
fn test_random_generation_state() {
    let param = create_ternary_discrete_time_continous_param();
    let n_samples = 10000;
    let states = Array1::from_shape_fn(n_samples, |_| {
        if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() {
            val
        } else {
            panic!()
        }
    });
    let freq_of =
        |target: usize| states.iter().filter(|&&s| s == target).count() as f64 / n_samples as f64;
    assert_relative_eq!(4.0 / 5.0, freq_of(2), epsilon = 0.01);
    assert_relative_eq!(1.0 / 5.0, freq_of(0), epsilon = 0.01);
}
/// The sampled residence time in state 1 (parent config 0) should average
/// 1/|q_11| = 1/5; checked empirically on 10000 draws.
#[test]
fn test_random_generation_residence_time() {
    let param = create_ternary_discrete_time_continous_param();
    let times =
        Array1::from_shape_fn(10000, |_| param.get_random_residence_time(1, 0).unwrap());
    assert_relative_eq!(1.0 / 5.0, times.mean().unwrap(), epsilon = 0.01);
}

@ -1,45 +0,0 @@
use rustyCTBN::tools::*;
use rustyCTBN::network::Network;
use rustyCTBN::ctbn::*;
use rustyCTBN::node;
use rustyCTBN::params;
use std::collections::BTreeSet;
use ndarray::arr3;
#[macro_use]
extern crate approx;
mod utils;
/// End-to-end sampling smoke test: draws 4 trajectories with horizon 1.0
/// from a two-node binary CTBN with known CIMs and checks the count and
/// the final timestamp of the first trajectory.
#[test]
fn run_sampling() {
    let mut net = CtbnNetwork::init();
    let n1 = net
        .add_node(utils::generate_discrete_time_continous_node(String::from("n1"), 2))
        .unwrap();
    let n2 = net
        .add_node(utils::generate_discrete_time_continous_node(String::from("n2"), 2))
        .unwrap();
    net.add_edge(n1, n2);
    match &mut net.get_node_mut(n1).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
        }
    }
    // One CIM slice per parent state of n1.
    match &mut net.get_node_mut(n2).params {
        params::Params::DiscreteStatesContinousTime(param) => {
            param.cim = Some(arr3(&[
                [[-1.0, 1.0], [4.0, -4.0]],
                [[-6.0, 6.0], [2.0, -2.0]],
            ]));
        }
    }
    let data = trajectory_generator(&net, 4, 1.0);
    assert_eq!(4, data.trajectories.len());
    let times = &data.trajectories[0].time;
    // Each trajectory must end exactly at the requested horizon.
    assert_relative_eq!(1.0, times[times.len() - 1]);
}

@ -1,16 +0,0 @@
use rustyCTBN::params;
use rustyCTBN::node;
use std::collections::BTreeSet;
/// Convenience constructor: wraps a discrete-states/continuous-time
/// parameter set of the given cardinality into a named `Node`.
pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node {
    let params = generate_discrete_time_continous_param(cardinality);
    node::Node::init(params::Params::DiscreteStatesContinousTime(params), name)
}
/// Builds a `DiscreteStatesContinousTimeParams` whose domain is the set of
/// string labels "0" .. "cardinality-1".
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
    // `mut` removed: the domain is never modified after `collect`, so the
    // original `let mut` triggered an `unused_mut` warning.
    let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
    params::DiscreteStatesContinousTimeParams::init(domain)
}
Loading…
Cancel
Save