commit
57ae851470
@ -0,0 +1,26 @@ |
||||
--- |
||||
name: 📑 Meta request |
||||
about: Suggest an idea or a change for this same repository |
||||
title: '[Meta] ' |
||||
labels: 'meta' |
||||
assignees: '' |
||||
|
||||
--- |
||||
|
||||
## Description |
||||
|
||||
As a X, I want to Y, so Z. |
||||
|
||||
## Acceptance Criteria |
||||
|
||||
* Criteria 1 |
||||
* Criteria 2 |
||||
|
||||
## Checklist |
||||
|
||||
* [ ] Element 1 |
||||
* [ ] Element 2 |
||||
|
||||
## (Optional) Extra info |
||||
|
||||
None |
@ -0,0 +1,36 @@ |
||||
- name: "bug" |
||||
color: "d73a4a" |
||||
description: "Something isn't working" |
||||
- name: "enhancement" |
||||
color: "a2eeef" |
||||
description: "New feature or request" |
||||
- name: "refactor" |
||||
color: "B06E16" |
||||
description: "Change in the structure" |
||||
- name: "documentation" |
||||
color: "0075ca" |
||||
description: "Improvements or additions to documentation" |
||||
- name: "meta" |
||||
color: "1D76DB" |
||||
description: "Something related to the project itself" |
||||
|
||||
- name: "duplicate" |
||||
color: "cfd3d7" |
||||
description: "This issue or pull request already exists" |
||||
|
||||
- name: "help wanted" |
||||
color: "008672" |
||||
description: "Extra help is needed" |
||||
- name: "urgent" |
||||
color: "D93F0B" |
||||
description: "" |
||||
- name: "wontfix" |
||||
color: "ffffff" |
||||
description: "This will not be worked on" |
||||
- name: "invalid" |
||||
color: "e4e669" |
||||
description: "This doesn't seem right" |
||||
|
||||
- name: "question" |
||||
color: "d876e3" |
||||
description: "Further information is requested" |
@ -0,0 +1,53 @@ |
||||
name: build |
||||
|
||||
on: |
||||
push: |
||||
branches: [ main, dev ] |
||||
pull_request: |
||||
branches: [ dev ] |
||||
|
||||
env: |
||||
CARGO_TERM_COLOR: always |
||||
|
||||
jobs: |
||||
build: |
||||
|
||||
runs-on: ubuntu-latest |
||||
|
||||
steps: |
||||
- uses: actions/checkout@v3 |
||||
- name: Setup Rust stable (default) |
||||
uses: actions-rs/toolchain@v1 |
||||
with: |
||||
profile: minimal |
||||
toolchain: stable |
||||
default: true |
||||
components: clippy, rustfmt, rust-docs |
||||
- name: Setup Rust nightly |
||||
uses: actions-rs/toolchain@v1 |
||||
with: |
||||
profile: minimal |
||||
toolchain: nightly |
||||
default: false |
||||
components: rustfmt |
||||
- name: Docs (doc) |
||||
uses: actions-rs/cargo@v1 |
||||
with: |
||||
command: rustdoc |
||||
args: --package reCTBN -- --default-theme=ayu |
||||
- name: Linting (clippy) |
||||
uses: actions-rs/clippy-check@v1 |
||||
with: |
||||
token: ${{ secrets.GITHUB_TOKEN }} |
||||
args: --all-targets -- -D warnings -A clippy::all -W clippy::correctness |
||||
- name: Formatting (rustfmt) |
||||
uses: actions-rs/cargo@v1 |
||||
with: |
||||
toolchain: nightly |
||||
command: fmt |
||||
args: --all -- --check --verbose |
||||
- name: Tests (test) |
||||
uses: actions-rs/cargo@v1 |
||||
with: |
||||
command: test |
||||
args: --tests |
@ -0,0 +1,23 @@ |
||||
name: meta-github |
||||
|
||||
on: |
||||
push: |
||||
branches: |
||||
- dev |
||||
|
||||
jobs: |
||||
labeler: |
||||
runs-on: ubuntu-latest |
||||
steps: |
||||
- |
||||
name: Checkout |
||||
uses: actions/checkout@v2 |
||||
- |
||||
name: Run Labeler |
||||
if: success() |
||||
uses: crazy-max/ghaction-github-labeler@v3 |
||||
with: |
||||
github-token: ${{ secrets.GITHUB_TOKEN }} |
||||
yaml-file: .github/labels.yml |
||||
skip-delete: false |
||||
dry-run: false |
@ -1,2 +1,3 @@ |
||||
/target |
||||
Cargo.lock |
||||
.vscode |
||||
|
@ -1,17 +1,5 @@ |
||||
[package] |
||||
name = "rustyCTBN" |
||||
version = "0.1.0" |
||||
edition = "2021" |
||||
[workspace] |
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html |
||||
|
||||
[dependencies] |
||||
|
||||
ndarray = {version="*", features=["approx"]} |
||||
thiserror = "*" |
||||
rand = "*" |
||||
bimap = "*" |
||||
enum_dispatch = "*" |
||||
|
||||
[dev-dependencies] |
||||
approx = "*" |
||||
members = [ |
||||
"reCTBN", |
||||
] |
||||
|
@ -0,0 +1,20 @@ |
||||
[package] |
||||
name = "reCTBN" |
||||
version = "0.1.0" |
||||
edition = "2021" |
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html |
||||
|
||||
[dependencies] |
||||
ndarray = {version="~0.15", features=["approx-0_5"]} |
||||
thiserror = "1.0.37" |
||||
rand = "~0.8" |
||||
bimap = "~0.6" |
||||
enum_dispatch = "~0.3" |
||||
statrs = "~0.16" |
||||
rand_chacha = "~0.3" |
||||
itertools = "~0.10" |
||||
rayon = "~1.6" |
||||
|
||||
[dev-dependencies] |
||||
approx = { package = "approx", version = "~0.5" } |
@ -0,0 +1,12 @@ |
||||
#![doc = include_str!("../../README.md")] |
||||
#![allow(non_snake_case)] |
||||
#[cfg(test)] |
||||
extern crate approx; |
||||
|
||||
pub mod parameter_learning; |
||||
pub mod params; |
||||
pub mod process; |
||||
pub mod reward; |
||||
pub mod sampling; |
||||
pub mod structure_learning; |
||||
pub mod tools; |
@ -0,0 +1,287 @@ |
||||
//! Module containing methods to define different types of nodes.
|
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use enum_dispatch::enum_dispatch; |
||||
use ndarray::prelude::*; |
||||
use rand::Rng; |
||||
use rand_chacha::ChaCha8Rng; |
||||
use thiserror::Error; |
||||
|
||||
/// Error types for trait Params
|
||||
#[derive(Error, Debug, PartialEq)] |
||||
pub enum ParamsError { |
||||
#[error("Unsupported method")] |
||||
UnsupportedMethod(String), |
||||
#[error("Paramiters not initialized")] |
||||
ParametersNotInitialized(String), |
||||
#[error("Invalid cim for parameter")] |
||||
InvalidCIM(String), |
||||
} |
||||
|
||||
/// Allowed type of states
|
||||
#[derive(Clone, Hash, PartialEq, Eq, Debug)] |
||||
pub enum StateType { |
||||
Discrete(usize), |
||||
} |
||||
|
||||
/// This is a core element for building different types of nodes; the goal is to define the set of
|
||||
/// methods required to describes a generic node.
|
||||
#[enum_dispatch(Params)] |
||||
pub trait ParamsTrait { |
||||
fn reset_params(&mut self); |
||||
|
||||
/// Randomly generate a possible state of the node disregarding the state of the node and it's
|
||||
/// parents.
|
||||
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType; |
||||
|
||||
/// Randomly generate a residence time for the given node taking into account the node state
|
||||
/// and its parent set.
|
||||
fn get_random_residence_time( |
||||
&self, |
||||
state: usize, |
||||
u: usize, |
||||
rng: &mut ChaCha8Rng, |
||||
) -> Result<f64, ParamsError>; |
||||
|
||||
/// Randomly generate a possible state for the given node taking into account the node state
|
||||
/// and its parent set.
|
||||
fn get_random_state( |
||||
&self, |
||||
state: usize, |
||||
u: usize, |
||||
rng: &mut ChaCha8Rng, |
||||
) -> Result<StateType, ParamsError>; |
||||
|
||||
/// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
|
||||
fn get_reserved_space_as_parent(&self) -> usize; |
||||
|
||||
/// Index used by discrete node to represents their states as usize.
|
||||
fn state_to_index(&self, state: &StateType) -> usize; |
||||
|
||||
/// Validate parameters against domain
|
||||
fn validate_params(&self) -> Result<(), ParamsError>; |
||||
|
||||
/// Return a reference to the associated label
|
||||
fn get_label(&self) -> &String; |
||||
} |
||||
|
||||
/// Is a core element for building different types of nodes; the goal is to define all the
|
||||
/// supported type of Parameters
|
||||
#[derive(Clone)] |
||||
#[enum_dispatch] |
||||
pub enum Params { |
||||
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), |
||||
} |
||||
|
||||
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
|
||||
/// following elements.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `label` - node's variable name.
|
||||
/// * `domain` - an ordered and exhaustive set of possible states.
|
||||
/// * `cim` - Conditional Intensity Matrix.
|
||||
/// * `transitions` - number of transitions from one state to another given a specific realization
|
||||
/// of the parent set; is a sufficient statistics are mainly used during the parameter learning
|
||||
/// task.
|
||||
/// * `residence_time` - residence time in each possible state, given a specific realization of the
|
||||
/// parent set; is a sufficient statistics are mainly used during the parameter learning task.
|
||||
#[derive(Clone)] |
||||
pub struct DiscreteStatesContinousTimeParams { |
||||
label: String, |
||||
domain: BTreeSet<String>, |
||||
cim: Option<Array3<f64>>, |
||||
transitions: Option<Array3<usize>>, |
||||
residence_time: Option<Array2<f64>>, |
||||
} |
||||
|
||||
impl DiscreteStatesContinousTimeParams { |
||||
pub fn new(label: String, domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams { |
||||
DiscreteStatesContinousTimeParams { |
||||
label, |
||||
domain, |
||||
cim: Option::None, |
||||
transitions: Option::None, |
||||
residence_time: Option::None, |
||||
} |
||||
} |
||||
|
||||
/// Getter function for CIM
|
||||
pub fn get_cim(&self) -> &Option<Array3<f64>> { |
||||
&self.cim |
||||
} |
||||
|
||||
/// Setter function for CIM.
|
||||
///
|
||||
/// This function checks if the CIM is valid using the [`validate_params`](self::ParamsTrait::validate_params) method:
|
||||
/// * **Valid CIM inserted** - it substitutes the CIM in `self.cim` and returns `Ok(())`.
|
||||
/// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns
|
||||
/// `ParamsError`.
|
||||
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> { |
||||
self.cim = Some(cim); |
||||
match self.validate_params() { |
||||
Ok(()) => Ok(()), |
||||
Err(e) => { |
||||
self.cim = None; |
||||
Err(e) |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Unchecked version of the setter function for CIM.
|
||||
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) { |
||||
self.cim = Some(cim); |
||||
} |
||||
|
||||
/// Getter function for transitions.
|
||||
pub fn get_transitions(&self) -> &Option<Array3<usize>> { |
||||
&self.transitions |
||||
} |
||||
|
||||
/// Setter function for transitions.
|
||||
pub fn set_transitions(&mut self, transitions: Array3<usize>) { |
||||
self.transitions = Some(transitions); |
||||
} |
||||
|
||||
/// Getter function for residence_time.
|
||||
pub fn get_residence_time(&self) -> &Option<Array2<f64>> { |
||||
&self.residence_time |
||||
} |
||||
|
||||
/// Setter function for residence_time.
|
||||
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) { |
||||
self.residence_time = Some(residence_time); |
||||
} |
||||
} |
||||
|
||||
impl ParamsTrait for DiscreteStatesContinousTimeParams { |
||||
fn reset_params(&mut self) { |
||||
self.cim = Option::None; |
||||
self.transitions = Option::None; |
||||
self.residence_time = Option::None; |
||||
} |
||||
|
||||
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType { |
||||
StateType::Discrete(rng.gen_range(0..(self.domain.len()))) |
||||
} |
||||
|
||||
fn get_random_residence_time( |
||||
&self, |
||||
state: usize, |
||||
u: usize, |
||||
rng: &mut ChaCha8Rng, |
||||
) -> Result<f64, ParamsError> { |
||||
// Generate a random residence time given the current state of the node and its parent set.
|
||||
// The method used is described in:
|
||||
// https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
|
||||
match &self.cim { |
||||
Option::Some(cim) => { |
||||
let lambda = cim[[u, state, state]] * -1.0; |
||||
let x: f64 = rng.gen_range(0.0..=1.0); |
||||
Ok(-x.ln() / lambda) |
||||
} |
||||
Option::None => Err(ParamsError::ParametersNotInitialized(String::from( |
||||
"CIM not initialized", |
||||
))), |
||||
} |
||||
} |
||||
|
||||
fn get_random_state( |
||||
&self, |
||||
state: usize, |
||||
u: usize, |
||||
rng: &mut ChaCha8Rng, |
||||
) -> Result<StateType, ParamsError> { |
||||
// Generate a random transition given the current state of the node and its parent set.
|
||||
// The method used is described in:
|
||||
// https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
|
||||
match &self.cim { |
||||
Option::Some(cim) => { |
||||
let lambda = cim[[u, state, state]] * -1.0; |
||||
let urand: f64 = rng.gen_range(0.0..=1.0); |
||||
|
||||
let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold( |
||||
(0, 0.0), |
||||
|mut acc, ele| { |
||||
if &acc.1 + ele < urand && ele > &0.0 { |
||||
acc.0 += 1; |
||||
} |
||||
if ele > &0.0 { |
||||
acc.1 += ele; |
||||
} |
||||
acc |
||||
}, |
||||
); |
||||
|
||||
let next_state = if next_state.0 < state { |
||||
next_state.0 |
||||
} else { |
||||
next_state.0 + 1 |
||||
}; |
||||
|
||||
Ok(StateType::Discrete(next_state)) |
||||
} |
||||
Option::None => Err(ParamsError::ParametersNotInitialized(String::from( |
||||
"CIM not initialized", |
||||
))), |
||||
} |
||||
} |
||||
|
||||
fn get_reserved_space_as_parent(&self) -> usize { |
||||
self.domain.len() |
||||
} |
||||
|
||||
fn state_to_index(&self, state: &StateType) -> usize { |
||||
match state { |
||||
StateType::Discrete(val) => val.clone() as usize, |
||||
} |
||||
} |
||||
|
||||
fn validate_params(&self) -> Result<(), ParamsError> { |
||||
let domain_size = self.domain.len(); |
||||
|
||||
// Check if the cim is initialized
|
||||
if let None = self.cim { |
||||
return Err(ParamsError::ParametersNotInitialized(String::from( |
||||
"CIM not initialized", |
||||
))); |
||||
} |
||||
let cim = self.cim.as_ref().unwrap(); |
||||
// Check if the inner dimensions of the cim are equal to the cardinality of the variable
|
||||
if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size { |
||||
return Err(ParamsError::InvalidCIM(format!( |
||||
"Incompatible shape {:?} with domain {:?}", |
||||
cim.shape(), |
||||
domain_size |
||||
))); |
||||
} |
||||
|
||||
// Check if the diagonal of each cim is non-positive
|
||||
if cim |
||||
.axis_iter(Axis(0)) |
||||
.any(|x| x.diag().iter().any(|x| x >= &0.0)) |
||||
{ |
||||
return Err(ParamsError::InvalidCIM(String::from( |
||||
"The diagonal of each cim must be non-positive", |
||||
))); |
||||
} |
||||
|
||||
// Check if each row sum up to 0
|
||||
if cim |
||||
.sum_axis(Axis(2)) |
||||
.iter() |
||||
.any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt()) |
||||
{ |
||||
return Err(ParamsError::InvalidCIM(String::from( |
||||
"The sum of each row must be 0", |
||||
))); |
||||
} |
||||
|
||||
return Ok(()); |
||||
} |
||||
|
||||
fn get_label(&self) -> &String { |
||||
&self.label |
||||
} |
||||
} |
@ -0,0 +1,120 @@ |
||||
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
|
||||
|
||||
pub mod ctbn; |
||||
pub mod ctmp; |
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use thiserror::Error; |
||||
|
||||
use crate::params; |
||||
|
||||
/// Error types for trait Network
|
||||
#[derive(Error, Debug)] |
||||
pub enum NetworkError { |
||||
#[error("Error during node insertion")] |
||||
NodeInsertionError(String), |
||||
} |
||||
|
||||
/// This type is used to represent a specific realization of a generic NetworkProcess
|
||||
pub type NetworkProcessState = Vec<params::StateType>; |
||||
|
||||
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
|
||||
/// as a CTBN).
|
||||
pub trait NetworkProcess: Sync { |
||||
fn initialize_adj_matrix(&mut self); |
||||
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; |
||||
/// Add an **directed edge** between a two nodes of the network.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `parent` - parent node.
|
||||
/// * `child` - child node.
|
||||
fn add_edge(&mut self, parent: usize, child: usize); |
||||
|
||||
/// Get all the indices of the nodes contained inside the network.
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize>; |
||||
|
||||
/// Get the numbers of nodes contained in the network.
|
||||
fn get_number_of_nodes(&self) -> usize; |
||||
|
||||
/// Get the **node param**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node_idx` - node index value.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The selected **node param**.
|
||||
fn get_node(&self, node_idx: usize) -> ¶ms::Params; |
||||
|
||||
/// Get the **node param**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node_idx` - node index value.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The selected **node mutable param**.
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; |
||||
|
||||
/// Compute the index that must be used to access the parameters of a `node`, given a specific
|
||||
/// configuration of the network.
|
||||
///
|
||||
/// Usually, the only values really used in `current_state` are the ones in the parent set of
|
||||
/// the `node`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node` - selected node.
|
||||
/// * `current_state` - current configuration of the network.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * Index of the `node` relative to the network.
|
||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; |
||||
|
||||
/// Compute the index that must be used to access the parameters of a `node`, given a specific
|
||||
/// configuration of the network and a generic `parent_set`.
|
||||
///
|
||||
/// Usually, the only values really used in `current_state` are the ones in the parent set of
|
||||
/// the `node`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `current_state` - current configuration of the network.
|
||||
/// * `parent_set` - parent set of the selected `node`.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * Index of the `node` relative to the network.
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
current_state: &Vec<params::StateType>, |
||||
parent_set: &BTreeSet<usize>, |
||||
) -> usize; |
||||
|
||||
/// Get the **parent set** of a given **node**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node` - node index value.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The **parent set** of the selected node.
|
||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; |
||||
|
||||
/// Get the **children set** of a given **node**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node` - node index value.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The **children set** of the selected node.
|
||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; |
||||
} |
@ -0,0 +1,247 @@ |
||||
//! Continuous Time Bayesian Network
|
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use ndarray::prelude::*; |
||||
|
||||
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; |
||||
use crate::process; |
||||
|
||||
use super::ctmp::CtmpProcess; |
||||
use super::{NetworkProcess, NetworkProcessState}; |
||||
|
||||
/// It represents both the structure and the parameters of a CTBN.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
|
||||
/// * `nodes` - A vector containing all the nodes and their parameters.
|
||||
///
|
||||
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::collections::BTreeSet;
|
||||
/// use reCTBN::process::NetworkProcess;
|
||||
/// use reCTBN::params;
|
||||
/// use reCTBN::process::ctbn::*;
|
||||
///
|
||||
/// //Create the domain for a discrete node
|
||||
/// let mut domain = BTreeSet::new();
|
||||
/// domain.insert(String::from("A"));
|
||||
/// domain.insert(String::from("B"));
|
||||
///
|
||||
/// //Create the parameters for a discrete node using the domain
|
||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
||||
///
|
||||
/// //Create the node using the parameters
|
||||
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
||||
///
|
||||
/// let mut domain = BTreeSet::new();
|
||||
/// domain.insert(String::from("A"));
|
||||
/// domain.insert(String::from("B"));
|
||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
||||
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
||||
///
|
||||
/// //Initialize a ctbn
|
||||
/// let mut net = CtbnNetwork::new();
|
||||
///
|
||||
/// //Add nodes
|
||||
/// let X1 = net.add_node(X1).unwrap();
|
||||
/// let X2 = net.add_node(X2).unwrap();
|
||||
///
|
||||
/// //Add an edge
|
||||
/// net.add_edge(X1, X2);
|
||||
///
|
||||
/// //Get all the children of node X1
|
||||
/// let cs = net.get_children_set(X1);
|
||||
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
||||
/// ```
|
||||
pub struct CtbnNetwork { |
||||
adj_matrix: Option<Array2<u16>>, |
||||
nodes: Vec<Params>, |
||||
} |
||||
|
||||
impl CtbnNetwork { |
||||
pub fn new() -> CtbnNetwork { |
||||
CtbnNetwork { |
||||
adj_matrix: None, |
||||
nodes: Vec::new(), |
||||
} |
||||
} |
||||
|
||||
///Transform the **CTBN** into a **CTMP**
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
|
||||
pub fn amalgamation(&self) -> CtmpProcess { |
||||
let variables_domain = |
||||
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); |
||||
|
||||
let state_space = variables_domain.product(); |
||||
let variables_set = BTreeSet::from_iter(self.get_node_indices()); |
||||
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space)); |
||||
|
||||
for idx_current_state in 0..state_space { |
||||
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); |
||||
let current_state_statetype: NetworkProcessState = current_state |
||||
.iter() |
||||
.map(|x| StateType::Discrete(*x)) |
||||
.collect(); |
||||
for idx_node in 0..self.nodes.len() { |
||||
let p = match self.get_node(idx_node) { |
||||
Params::DiscreteStatesContinousTime(p) => p, |
||||
}; |
||||
for next_node_state in 0..variables_domain[idx_node] { |
||||
let mut next_state = current_state.clone(); |
||||
next_state[idx_node] = next_node_state; |
||||
|
||||
let next_state_statetype: NetworkProcessState = |
||||
next_state.iter().map(|x| StateType::Discrete(*x)).collect(); |
||||
let idx_next_state = self.get_param_index_from_custom_parent_set( |
||||
&next_state_statetype, |
||||
&variables_set, |
||||
); |
||||
amalgamated_cim[[0, idx_current_state, idx_next_state]] += |
||||
p.get_cim().as_ref().unwrap()[[ |
||||
self.get_param_index_network(idx_node, ¤t_state_statetype), |
||||
current_state[idx_node], |
||||
next_node_state, |
||||
]]; |
||||
} |
||||
} |
||||
} |
||||
|
||||
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( |
||||
"ctmp".to_string(), |
||||
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), |
||||
); |
||||
|
||||
amalgamated_param.set_cim(amalgamated_cim).unwrap(); |
||||
|
||||
let mut ctmp = CtmpProcess::new(); |
||||
|
||||
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) |
||||
.unwrap(); |
||||
return ctmp; |
||||
} |
||||
|
||||
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> { |
||||
let mut state = state; |
||||
let mut array_state = Array1::zeros(variables_domain.shape()[0]); |
||||
for (idx, var) in variables_domain.indexed_iter() { |
||||
array_state[idx] = state % var; |
||||
state = state / var; |
||||
} |
||||
|
||||
return array_state; |
||||
} |
||||
/// Get the Adjacency Matrix.
|
||||
pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> { |
||||
self.adj_matrix.as_ref() |
||||
} |
||||
} |
||||
|
||||
impl process::NetworkProcess for CtbnNetwork { |
||||
/// Initialize an Adjacency matrix.
|
||||
fn initialize_adj_matrix(&mut self) { |
||||
self.adj_matrix = Some(Array2::<u16>::zeros( |
||||
(self.nodes.len(), self.nodes.len()).f(), |
||||
)); |
||||
} |
||||
|
||||
/// Add a new node.
|
||||
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> { |
||||
n.reset_params(); |
||||
self.adj_matrix = Option::None; |
||||
self.nodes.push(n); |
||||
Ok(self.nodes.len() - 1) |
||||
} |
||||
|
||||
/// Connect two nodes with a new edge.
|
||||
fn add_edge(&mut self, parent: usize, child: usize) { |
||||
if let None = self.adj_matrix { |
||||
self.initialize_adj_matrix(); |
||||
} |
||||
|
||||
if let Some(network) = &mut self.adj_matrix { |
||||
network[[parent, child]] = 1; |
||||
self.nodes[child].reset_params(); |
||||
} |
||||
} |
||||
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||
0..self.nodes.len() |
||||
} |
||||
|
||||
/// Get the number of nodes of the network.
|
||||
fn get_number_of_nodes(&self) -> usize { |
||||
self.nodes.len() |
||||
} |
||||
|
||||
fn get_node(&self, node_idx: usize) -> &Params { |
||||
&self.nodes[node_idx] |
||||
} |
||||
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { |
||||
&mut self.nodes[node_idx] |
||||
} |
||||
|
||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.column(node) |
||||
.iter() |
||||
.enumerate() |
||||
.fold((0, 1), |mut acc, x| { |
||||
if x.1 > &0 { |
||||
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
||||
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
||||
} |
||||
acc |
||||
}) |
||||
.0 |
||||
} |
||||
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
current_state: &NetworkProcessState, |
||||
parent_set: &BTreeSet<usize>, |
||||
) -> usize { |
||||
parent_set |
||||
.iter() |
||||
.fold((0, 1), |mut acc, x| { |
||||
acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; |
||||
acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); |
||||
acc |
||||
}) |
||||
.0 |
||||
} |
||||
|
||||
/// Get all the parents of the given node.
|
||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.column(node) |
||||
.iter() |
||||
.enumerate() |
||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||
.collect() |
||||
} |
||||
|
||||
/// Get all the children of the given node.
|
||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.row(node) |
||||
.iter() |
||||
.enumerate() |
||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||
.collect() |
||||
} |
||||
} |
@ -0,0 +1,114 @@ |
||||
use std::collections::BTreeSet; |
||||
|
||||
use crate::{ |
||||
params::{Params, StateType}, |
||||
process, |
||||
}; |
||||
|
||||
use super::{NetworkProcess, NetworkProcessState}; |
||||
|
||||
pub struct CtmpProcess { |
||||
param: Option<Params>, |
||||
} |
||||
|
||||
impl CtmpProcess { |
||||
pub fn new() -> CtmpProcess { |
||||
CtmpProcess { param: None } |
||||
} |
||||
} |
||||
|
||||
impl NetworkProcess for CtmpProcess { |
||||
fn initialize_adj_matrix(&mut self) { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> { |
||||
match self.param { |
||||
None => { |
||||
self.param = Some(n); |
||||
Ok(0) |
||||
} |
||||
Some(_) => Err(process::NetworkError::NodeInsertionError( |
||||
"CtmpProcess has only one node".to_string(), |
||||
)), |
||||
} |
||||
} |
||||
|
||||
fn add_edge(&mut self, _parent: usize, _child: usize) { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||
match self.param { |
||||
None => 0..0, |
||||
Some(_) => 0..1, |
||||
} |
||||
} |
||||
|
||||
fn get_number_of_nodes(&self) -> usize { |
||||
match self.param { |
||||
None => 0, |
||||
Some(_) => 1, |
||||
} |
||||
} |
||||
|
||||
fn get_node(&self, node_idx: usize) -> &crate::params::Params { |
||||
if node_idx == 0 { |
||||
self.param.as_ref().unwrap() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { |
||||
if node_idx == 0 { |
||||
self.param.as_mut().unwrap() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
||||
if node == 0 { |
||||
match current_state[0] { |
||||
StateType::Discrete(x) => x, |
||||
} |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
_current_state: &NetworkProcessState, |
||||
_parent_set: &BTreeSet<usize>, |
||||
) -> usize { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
||||
match self.param { |
||||
Some(_) => { |
||||
if node == 0 { |
||||
BTreeSet::new() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
None => panic!("Uninitialized CtmpProcess"), |
||||
} |
||||
} |
||||
|
||||
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
||||
match self.param { |
||||
Some(_) => { |
||||
if node == 0 { |
||||
BTreeSet::new() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
None => panic!("Uninitialized CtmpProcess"), |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,59 @@ |
||||
pub mod reward_evaluation; |
||||
pub mod reward_function; |
||||
|
||||
use std::collections::HashMap; |
||||
|
||||
use crate::process; |
||||
|
||||
/// Instantiation of reward function and instantaneous reward
|
||||
///
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transition_reward`: reward obtained transitioning from one state to another
|
||||
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
|
||||
|
||||
#[derive(Debug, PartialEq)] |
||||
pub struct Reward { |
||||
pub transition_reward: f64, |
||||
pub instantaneous_reward: f64, |
||||
} |
||||
|
||||
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
|
||||
|
||||
pub trait RewardFunction: Sync { |
||||
/// Given the current state and the previous state, it compute the reward.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
|
||||
/// * `previous_state`: an optional argument representing the previous state of the network
|
||||
|
||||
fn call( |
||||
&self, |
||||
current_state: &process::NetworkProcessState, |
||||
previous_state: Option<&process::NetworkProcessState>, |
||||
) -> Reward; |
||||
|
||||
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p`: any structure that implements the trait `process::NetworkProcess`
|
||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; |
||||
} |
||||
|
||||
pub trait RewardEvaluation { |
||||
fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
) -> HashMap<process::NetworkProcessState, f64>; |
||||
|
||||
fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
state: &process::NetworkProcessState, |
||||
) -> f64; |
||||
} |
@ -0,0 +1,205 @@ |
||||
use std::collections::HashMap; |
||||
|
||||
use rayon::prelude::{IntoParallelIterator, ParallelIterator}; |
||||
use statrs::distribution::ContinuousCDF; |
||||
|
||||
use crate::params::{self, ParamsTrait}; |
||||
use crate::process; |
||||
|
||||
use crate::{ |
||||
process::NetworkProcessState, |
||||
reward::RewardEvaluation, |
||||
sampling::{ForwardSampler, Sampler}, |
||||
}; |
||||
|
||||
pub enum RewardCriteria { |
||||
FiniteHorizon, |
||||
InfiniteHorizon { discount_factor: f64 }, |
||||
} |
||||
|
||||
pub struct MonteCarloReward { |
||||
max_iterations: usize, |
||||
max_err_stop: f64, |
||||
alpha_stop: f64, |
||||
end_time: f64, |
||||
reward_criteria: RewardCriteria, |
||||
seed: Option<u64>, |
||||
} |
||||
|
||||
impl MonteCarloReward { |
||||
pub fn new( |
||||
max_iterations: usize, |
||||
max_err_stop: f64, |
||||
alpha_stop: f64, |
||||
end_time: f64, |
||||
reward_criteria: RewardCriteria, |
||||
seed: Option<u64>, |
||||
) -> MonteCarloReward { |
||||
MonteCarloReward { |
||||
max_iterations, |
||||
max_err_stop, |
||||
alpha_stop, |
||||
end_time, |
||||
reward_criteria, |
||||
seed, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl RewardEvaluation for MonteCarloReward { |
||||
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
) -> HashMap<process::NetworkProcessState, f64> { |
||||
let variables_domain: Vec<Vec<params::StateType>> = network_process |
||||
.get_node_indices() |
||||
.map(|x| match network_process.get_node(x) { |
||||
params::Params::DiscreteStatesContinousTime(x) => (0..x |
||||
.get_reserved_space_as_parent()) |
||||
.map(|s| params::StateType::Discrete(s)) |
||||
.collect(), |
||||
}) |
||||
.collect(); |
||||
|
||||
let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); |
||||
|
||||
(0..n_states) |
||||
.into_par_iter() |
||||
.map(|s| { |
||||
let state: process::NetworkProcessState = variables_domain |
||||
.iter() |
||||
.fold((s, vec![]), |acc, x| { |
||||
let mut acc = acc; |
||||
let idx_s = acc.0 % x.len(); |
||||
acc.1.push(x[idx_s].clone()); |
||||
acc.0 = acc.0 / x.len(); |
||||
acc |
||||
}) |
||||
.1; |
||||
|
||||
let r = self.evaluate_state(network_process, reward_function, &state); |
||||
(state, r) |
||||
}) |
||||
.collect() |
||||
} |
||||
|
||||
fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
state: &NetworkProcessState, |
||||
) -> f64 { |
||||
let mut sampler = |
||||
ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); |
||||
let mut expected_value = 0.0; |
||||
let mut squared_expected_value = 0.0; |
||||
let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); |
||||
|
||||
for i in 0..self.max_iterations { |
||||
sampler.reset(); |
||||
let mut ret = 0.0; |
||||
let mut previous = sampler.next().unwrap(); |
||||
while previous.t < self.end_time { |
||||
let current = sampler.next().unwrap(); |
||||
if current.t > self.end_time { |
||||
let r = reward_function.call(&previous.state, None); |
||||
let discount = match self.reward_criteria { |
||||
RewardCriteria::FiniteHorizon => self.end_time - previous.t, |
||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||
std::f64::consts::E.powf(-discount_factor * previous.t) |
||||
- std::f64::consts::E.powf(-discount_factor * self.end_time) |
||||
} |
||||
}; |
||||
ret += discount * r.instantaneous_reward; |
||||
} else { |
||||
let r = reward_function.call(&previous.state, Some(¤t.state)); |
||||
let discount = match self.reward_criteria { |
||||
RewardCriteria::FiniteHorizon => current.t - previous.t, |
||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||
std::f64::consts::E.powf(-discount_factor * previous.t) |
||||
- std::f64::consts::E.powf(-discount_factor * current.t) |
||||
} |
||||
}; |
||||
ret += discount * r.instantaneous_reward; |
||||
ret += match self.reward_criteria { |
||||
RewardCriteria::FiniteHorizon => 1.0, |
||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||
std::f64::consts::E.powf(-discount_factor * current.t) |
||||
} |
||||
} * r.transition_reward; |
||||
} |
||||
previous = current; |
||||
} |
||||
|
||||
let float_i = i as f64; |
||||
expected_value = |
||||
expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0); |
||||
squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0) |
||||
+ ret.powi(2) / (float_i + 1.0); |
||||
|
||||
if i > 2 { |
||||
let var = |
||||
(float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2)); |
||||
if self.alpha_stop |
||||
- 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt()) |
||||
> 0.0 |
||||
{ |
||||
return expected_value; |
||||
} |
||||
} |
||||
} |
||||
|
||||
expected_value |
||||
} |
||||
} |
||||
|
||||
pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> { |
||||
inner_reward: RE, |
||||
} |
||||
|
||||
impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> { |
||||
pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> { |
||||
NeighborhoodRelativeReward { inner_reward } |
||||
} |
||||
} |
||||
|
||||
impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> { |
||||
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
) -> HashMap<process::NetworkProcessState, f64> { |
||||
let absolute_reward = self |
||||
.inner_reward |
||||
.evaluate_state_space(network_process, reward_function); |
||||
|
||||
//This approach optimize memory. Maybe optimizing execution time can be better.
|
||||
absolute_reward |
||||
.iter() |
||||
.map(|(k1, v1)| { |
||||
let mut max_val: f64 = 1.0; |
||||
absolute_reward.iter().for_each(|(k2, v2)| { |
||||
let count_diff: usize = k1 |
||||
.iter() |
||||
.zip(k2.iter()) |
||||
.map(|(s1, s2)| if s1 == s2 { 0 } else { 1 }) |
||||
.sum(); |
||||
if count_diff < 2 { |
||||
max_val = max_val.max(v1 / v2); |
||||
} |
||||
}); |
||||
(k1.clone(), max_val) |
||||
}) |
||||
.collect() |
||||
} |
||||
|
||||
fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>( |
||||
&self, |
||||
_network_process: &N, |
||||
_reward_function: &R, |
||||
_state: &process::NetworkProcessState, |
||||
) -> f64 { |
||||
unimplemented!(); |
||||
} |
||||
} |
@ -0,0 +1,106 @@ |
||||
//! Module for dealing with reward functions
|
||||
|
||||
use crate::{ |
||||
params::{self, ParamsTrait}, |
||||
process, |
||||
reward::{Reward, RewardFunction}, |
||||
}; |
||||
|
||||
use ndarray; |
||||
|
||||
/// Reward function over a factored state space
|
||||
///
|
||||
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
|
||||
/// of the underling `NetworkProcess`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
|
||||
/// reward of a node
|
||||
|
||||
pub struct FactoredRewardFunction { |
||||
transition_reward: Vec<ndarray::Array2<f64>>, |
||||
instantaneous_reward: Vec<ndarray::Array1<f64>>, |
||||
} |
||||
|
||||
impl FactoredRewardFunction { |
||||
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { |
||||
&self.transition_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { |
||||
&mut self.transition_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { |
||||
&self.instantaneous_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { |
||||
&mut self.instantaneous_reward[node_idx] |
||||
} |
||||
} |
||||
|
||||
impl RewardFunction for FactoredRewardFunction { |
||||
fn call( |
||||
&self, |
||||
current_state: &process::NetworkProcessState, |
||||
previous_state: Option<&process::NetworkProcessState>, |
||||
) -> Reward { |
||||
let instantaneous_reward: f64 = current_state |
||||
.iter() |
||||
.enumerate() |
||||
.map(|(idx, x)| { |
||||
let x = match x { |
||||
params::StateType::Discrete(x) => x, |
||||
}; |
||||
self.instantaneous_reward[idx][*x] |
||||
}) |
||||
.sum(); |
||||
if let Some(previous_state) = previous_state { |
||||
let transition_reward = previous_state |
||||
.iter() |
||||
.zip(current_state.iter()) |
||||
.enumerate() |
||||
.find_map(|(idx, (p, c))| -> Option<f64> { |
||||
let p = match p { |
||||
params::StateType::Discrete(p) => p, |
||||
}; |
||||
let c = match c { |
||||
params::StateType::Discrete(c) => c, |
||||
}; |
||||
if p != c { |
||||
Some(self.transition_reward[idx][[*p, *c]]) |
||||
} else { |
||||
None |
||||
} |
||||
}) |
||||
.unwrap_or(0.0); |
||||
Reward { |
||||
transition_reward, |
||||
instantaneous_reward, |
||||
} |
||||
} else { |
||||
Reward { |
||||
transition_reward: 0.0, |
||||
instantaneous_reward, |
||||
} |
||||
} |
||||
} |
||||
|
||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { |
||||
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; |
||||
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; |
||||
for i in p.get_node_indices() { |
||||
//This works only for discrete nodes!
|
||||
let size: usize = p.get_node(i).get_reserved_space_as_parent(); |
||||
instantaneous_reward.push(ndarray::Array1::zeros(size)); |
||||
transition_reward.push(ndarray::Array2::zeros((size, size))); |
||||
} |
||||
|
||||
FactoredRewardFunction { |
||||
transition_reward, |
||||
instantaneous_reward, |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,133 @@ |
||||
//! Module containing methods for the sampling.
|
||||
|
||||
use crate::{ |
||||
params::ParamsTrait, |
||||
process::{NetworkProcess, NetworkProcessState}, |
||||
}; |
||||
use rand::SeedableRng; |
||||
use rand_chacha::ChaCha8Rng; |
||||
|
||||
#[derive(Clone)] |
||||
pub struct Sample { |
||||
pub t: f64, |
||||
pub state: NetworkProcessState, |
||||
} |
||||
|
||||
pub trait Sampler: Iterator<Item = Sample> { |
||||
fn reset(&mut self); |
||||
} |
||||
|
||||
pub struct ForwardSampler<'a, T> |
||||
where |
||||
T: NetworkProcess, |
||||
{ |
||||
net: &'a T, |
||||
rng: ChaCha8Rng, |
||||
current_time: f64, |
||||
current_state: NetworkProcessState, |
||||
next_transitions: Vec<Option<f64>>, |
||||
initial_state: Option<NetworkProcessState>, |
||||
} |
||||
|
||||
impl<'a, T: NetworkProcess> ForwardSampler<'a, T> { |
||||
pub fn new( |
||||
net: &'a T, |
||||
seed: Option<u64>, |
||||
initial_state: Option<NetworkProcessState>, |
||||
) -> ForwardSampler<'a, T> { |
||||
let rng: ChaCha8Rng = match seed { |
||||
//If a seed is present use it to initialize the random generator.
|
||||
Some(seed) => SeedableRng::seed_from_u64(seed), |
||||
//Otherwise create a new random generator using the method `from_entropy`
|
||||
None => SeedableRng::from_entropy(), |
||||
}; |
||||
let mut fs = ForwardSampler { |
||||
net, |
||||
rng, |
||||
current_time: 0.0, |
||||
current_state: vec![], |
||||
next_transitions: vec![], |
||||
initial_state, |
||||
}; |
||||
fs.reset(); |
||||
return fs; |
||||
} |
||||
} |
||||
|
||||
impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> { |
||||
type Item = Sample; |
||||
|
||||
fn next(&mut self) -> Option<Self::Item> { |
||||
let ret_time = self.current_time.clone(); |
||||
let ret_state = self.current_state.clone(); |
||||
|
||||
for (idx, val) in self.next_transitions.iter_mut().enumerate() { |
||||
if let None = val { |
||||
*val = Some( |
||||
self.net |
||||
.get_node(idx) |
||||
.get_random_residence_time( |
||||
self.net |
||||
.get_node(idx) |
||||
.state_to_index(&self.current_state[idx]), |
||||
self.net.get_param_index_network(idx, &self.current_state), |
||||
&mut self.rng, |
||||
) |
||||
.unwrap() |
||||
+ self.current_time, |
||||
); |
||||
} |
||||
} |
||||
|
||||
let next_node_transition = self |
||||
.next_transitions |
||||
.iter() |
||||
.enumerate() |
||||
.min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) |
||||
.unwrap() |
||||
.0; |
||||
|
||||
self.current_time = self.next_transitions[next_node_transition].unwrap().clone(); |
||||
|
||||
self.current_state[next_node_transition] = self |
||||
.net |
||||
.get_node(next_node_transition) |
||||
.get_random_state( |
||||
self.net |
||||
.get_node(next_node_transition) |
||||
.state_to_index(&self.current_state[next_node_transition]), |
||||
self.net |
||||
.get_param_index_network(next_node_transition, &self.current_state), |
||||
&mut self.rng, |
||||
) |
||||
.unwrap(); |
||||
|
||||
self.next_transitions[next_node_transition] = None; |
||||
|
||||
for child in self.net.get_children_set(next_node_transition) { |
||||
self.next_transitions[child] = None; |
||||
} |
||||
|
||||
Some(Sample { |
||||
t: ret_time, |
||||
state: ret_state, |
||||
}) |
||||
} |
||||
} |
||||
|
||||
impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> { |
||||
fn reset(&mut self) { |
||||
self.current_time = 0.0; |
||||
match &self.initial_state { |
||||
None => { |
||||
self.current_state = self |
||||
.net |
||||
.get_node_indices() |
||||
.map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng)) |
||||
.collect() |
||||
} |
||||
Some(is) => self.current_state = is.clone(), |
||||
}; |
||||
self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect(); |
||||
} |
||||
} |
@ -0,0 +1,13 @@ |
||||
//! Learn the structure of the network.
|
||||
|
||||
pub mod constraint_based_algorithm; |
||||
pub mod hypothesis_test; |
||||
pub mod score_based_algorithm; |
||||
pub mod score_function; |
||||
use crate::{process, tools::Dataset}; |
||||
|
||||
pub trait StructureLearningAlgorithm { |
||||
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
||||
where |
||||
T: process::NetworkProcess; |
||||
} |
@ -0,0 +1,348 @@ |
||||
//! Module containing constraint based algorithms like CTPC and Hiton.
|
||||
|
||||
use crate::params::Params; |
||||
use itertools::Itertools; |
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator}; |
||||
use rayon::prelude::ParallelExtend; |
||||
use std::collections::{BTreeSet, HashMap}; |
||||
use std::mem; |
||||
use std::usize; |
||||
|
||||
use super::hypothesis_test::*; |
||||
use crate::parameter_learning::ParameterLearning; |
||||
use crate::process; |
||||
use crate::structure_learning::StructureLearningAlgorithm; |
||||
use crate::tools::Dataset; |
||||
|
||||
pub struct Cache<'a, P: ParameterLearning> { |
||||
parameter_learning: &'a P, |
||||
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>, |
||||
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>, |
||||
parent_set_size_small: usize, |
||||
} |
||||
|
||||
impl<'a, P: ParameterLearning> Cache<'a, P> { |
||||
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { |
||||
Cache { |
||||
parameter_learning, |
||||
cache_persistent_small: HashMap::new(), |
||||
cache_persistent_big: HashMap::new(), |
||||
parent_set_size_small: 0, |
||||
} |
||||
} |
||||
pub fn fit<T: process::NetworkProcess>( |
||||
&mut self, |
||||
net: &T, |
||||
dataset: &Dataset, |
||||
node: usize, |
||||
parent_set: Option<BTreeSet<usize>>, |
||||
) -> Params { |
||||
let parent_set_len = parent_set.as_ref().unwrap().len(); |
||||
if parent_set_len > self.parent_set_size_small + 1 { |
||||
//self.cache_persistent_small = self.cache_persistent_big;
|
||||
mem::swap( |
||||
&mut self.cache_persistent_small, |
||||
&mut self.cache_persistent_big, |
||||
); |
||||
self.cache_persistent_big = HashMap::new(); |
||||
self.parent_set_size_small += 1; |
||||
} |
||||
|
||||
if parent_set_len > self.parent_set_size_small { |
||||
match self.cache_persistent_big.get(&parent_set) { |
||||
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
||||
// not cloning requires a minor and reasoned refactoring across the library
|
||||
Some(params) => params.clone(), |
||||
None => { |
||||
let params = |
||||
self.parameter_learning |
||||
.fit(net, dataset, node, parent_set.clone()); |
||||
self.cache_persistent_big.insert(parent_set, params.clone()); |
||||
params |
||||
} |
||||
} |
||||
} else { |
||||
match self.cache_persistent_small.get(&parent_set) { |
||||
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
||||
// not cloning requires a minor and reasoned refactoring across the library
|
||||
Some(params) => params.clone(), |
||||
None => { |
||||
let params = |
||||
self.parameter_learning |
||||
.fit(net, dataset, node, parent_set.clone()); |
||||
self.cache_persistent_small |
||||
.insert(parent_set, params.clone()); |
||||
params |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Continuous-Time Peter Clark algorithm.
|
||||
///
|
||||
/// A method to learn the structure of the network.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
|
||||
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test.
|
||||
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test.
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use std::collections::BTreeSet;
|
||||
/// # use ndarray::{arr1, arr2, arr3};
|
||||
/// # use reCTBN::params;
|
||||
/// # use reCTBN::tools::trajectory_generator;
|
||||
/// # use reCTBN::process::NetworkProcess;
|
||||
/// # use reCTBN::process::ctbn::CtbnNetwork;
|
||||
/// use reCTBN::parameter_learning::BayesianApproach;
|
||||
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
|
||||
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
|
||||
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
|
||||
/// #
|
||||
/// # // Create the domain for a discrete node
|
||||
/// # let mut domain = BTreeSet::new();
|
||||
/// # domain.insert(String::from("A"));
|
||||
/// # domain.insert(String::from("B"));
|
||||
/// # domain.insert(String::from("C"));
|
||||
/// # // Create the parameters for a discrete node using the domain
|
||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
|
||||
/// # //Create the node n1 using the parameters
|
||||
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
|
||||
/// #
|
||||
/// # let mut domain = BTreeSet::new();
|
||||
/// # domain.insert(String::from("D"));
|
||||
/// # domain.insert(String::from("E"));
|
||||
/// # domain.insert(String::from("F"));
|
||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
|
||||
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
|
||||
/// #
|
||||
/// # let mut domain = BTreeSet::new();
|
||||
/// # domain.insert(String::from("G"));
|
||||
/// # domain.insert(String::from("H"));
|
||||
/// # domain.insert(String::from("I"));
|
||||
/// # domain.insert(String::from("F"));
|
||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
|
||||
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
|
||||
/// #
|
||||
/// # // Initialize a ctbn
|
||||
/// # let mut net = CtbnNetwork::new();
|
||||
/// #
|
||||
/// # // Add the nodes and their edges
|
||||
/// # let n1 = net.add_node(n1).unwrap();
|
||||
/// # let n2 = net.add_node(n2).unwrap();
|
||||
/// # let n3 = net.add_node(n3).unwrap();
|
||||
/// # net.add_edge(n1, n2);
|
||||
/// # net.add_edge(n1, n3);
|
||||
/// # net.add_edge(n2, n3);
|
||||
/// #
|
||||
/// # match &mut net.get_node_mut(n1) {
|
||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||
/// # assert_eq!(
|
||||
/// # Ok(()),
|
||||
/// # param.set_cim(arr3(&[
|
||||
/// # [
|
||||
/// # [-3.0, 2.0, 1.0],
|
||||
/// # [1.5, -2.0, 0.5],
|
||||
/// # [0.4, 0.6, -1.0]
|
||||
/// # ],
|
||||
/// # ]))
|
||||
/// # );
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # match &mut net.get_node_mut(n2) {
|
||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||
/// # assert_eq!(
|
||||
/// # Ok(()),
|
||||
/// # param.set_cim(arr3(&[
|
||||
/// # [
|
||||
/// # [-1.0, 0.5, 0.5],
|
||||
/// # [3.0, -4.0, 1.0],
|
||||
/// # [0.9, 0.1, -1.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-6.0, 2.0, 4.0],
|
||||
/// # [1.5, -2.0, 0.5],
|
||||
/// # [3.0, 1.0, -4.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-1.0, 0.1, 0.9],
|
||||
/// # [2.0, -2.5, 0.5],
|
||||
/// # [0.9, 0.1, -1.0]
|
||||
/// # ],
|
||||
/// # ]))
|
||||
/// # );
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # match &mut net.get_node_mut(n3) {
|
||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||
/// # assert_eq!(
|
||||
/// # Ok(()),
|
||||
/// # param.set_cim(arr3(&[
|
||||
/// # [
|
||||
/// # [-1.0, 0.5, 0.3, 0.2],
|
||||
/// # [0.5, -4.0, 2.5, 1.0],
|
||||
/// # [2.5, 0.5, -4.0, 1.0],
|
||||
/// # [0.7, 0.2, 0.1, -1.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-6.0, 2.0, 3.0, 1.0],
|
||||
/// # [1.5, -3.0, 0.5, 1.0],
|
||||
/// # [2.0, 1.3, -5.0, 1.7],
|
||||
/// # [2.5, 0.5, 1.0, -4.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-1.3, 0.3, 0.1, 0.9],
|
||||
/// # [1.4, -4.0, 0.5, 2.1],
|
||||
/// # [1.0, 1.5, -3.0, 0.5],
|
||||
/// # [0.4, 0.3, 0.1, -0.8]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-2.0, 1.0, 0.7, 0.3],
|
||||
/// # [1.3, -5.9, 2.7, 1.9],
|
||||
/// # [2.0, 1.5, -4.0, 0.5],
|
||||
/// # [0.2, 0.7, 0.1, -1.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-6.0, 1.0, 2.0, 3.0],
|
||||
/// # [0.5, -3.0, 1.0, 1.5],
|
||||
/// # [1.4, 2.1, -4.3, 0.8],
|
||||
/// # [0.5, 1.0, 2.5, -4.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-1.3, 0.9, 0.3, 0.1],
|
||||
/// # [0.1, -1.3, 0.2, 1.0],
|
||||
/// # [0.5, 1.0, -3.0, 1.5],
|
||||
/// # [0.1, 0.4, 0.3, -0.8]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-2.0, 1.0, 0.6, 0.4],
|
||||
/// # [2.6, -7.1, 1.4, 3.1],
|
||||
/// # [5.0, 1.0, -8.0, 2.0],
|
||||
/// # [1.4, 0.4, 0.2, -2.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-3.0, 1.0, 1.5, 0.5],
|
||||
/// # [3.0, -6.0, 1.0, 2.0],
|
||||
/// # [0.3, 0.5, -1.9, 1.1],
|
||||
/// # [5.0, 1.0, 2.0, -8.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-2.6, 0.6, 0.2, 1.8],
|
||||
/// # [2.0, -6.0, 3.0, 1.0],
|
||||
/// # [0.1, 0.5, -1.3, 0.7],
|
||||
/// # [0.8, 0.6, 0.2, -1.6]
|
||||
/// # ],
|
||||
/// # ]))
|
||||
/// # );
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # // Generate the trajectory
|
||||
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
|
||||
///
|
||||
/// // Initialize the hypothesis tests to pass to the CTPC with their
|
||||
/// // respective significance level `alpha`
|
||||
/// let f = F::new(1e-6);
|
||||
/// let chi_sq = ChiSquare::new(1e-4);
|
||||
/// // Use the bayesian approach to learn the parameters
|
||||
/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
|
||||
///
|
||||
/// //Initialize CTPC
|
||||
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
|
||||
///
|
||||
/// // Learn the structure of the network from the generated trajectory
|
||||
/// let net = ctpc.fit_transform(net, &data);
|
||||
/// #
|
||||
/// # // Compare the generated network with the original one
|
||||
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
|
||||
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
|
||||
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
|
||||
/// ```
|
||||
pub struct CTPC<P: ParameterLearning> { |
||||
parameter_learning: P, |
||||
Ftest: F, |
||||
Chi2test: ChiSquare, |
||||
} |
||||
|
||||
impl<P: ParameterLearning> CTPC<P> { |
||||
pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> { |
||||
CTPC { |
||||
parameter_learning, |
||||
Ftest, |
||||
Chi2test, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> { |
||||
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
||||
where |
||||
T: process::NetworkProcess, |
||||
{ |
||||
//Check the coherence between dataset and network
|
||||
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { |
||||
panic!("Dataset and Network must have the same number of variables.") |
||||
} |
||||
|
||||
//Make the network mutable.
|
||||
let mut net = net; |
||||
|
||||
net.initialize_adj_matrix(); |
||||
|
||||
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![]; |
||||
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| { |
||||
let mut cache = Cache::new(&self.parameter_learning); |
||||
let mut candidate_parent_set: BTreeSet<usize> = net |
||||
.get_node_indices() |
||||
.into_iter() |
||||
.filter(|x| x != &child_node) |
||||
.collect(); |
||||
let mut separation_set_size = 0; |
||||
while separation_set_size < candidate_parent_set.len() { |
||||
let mut candidate_parent_set_TMP = candidate_parent_set.clone(); |
||||
for parent_node in candidate_parent_set.iter() { |
||||
for separation_set in candidate_parent_set |
||||
.iter() |
||||
.filter(|x| x != &parent_node) |
||||
.map(|x| *x) |
||||
.combinations(separation_set_size) |
||||
{ |
||||
let separation_set = separation_set.into_iter().collect(); |
||||
if self.Ftest.call( |
||||
&net, |
||||
child_node, |
||||
*parent_node, |
||||
&separation_set, |
||||
dataset, |
||||
&mut cache, |
||||
) && self.Chi2test.call( |
||||
&net, |
||||
child_node, |
||||
*parent_node, |
||||
&separation_set, |
||||
dataset, |
||||
&mut cache, |
||||
) { |
||||
candidate_parent_set_TMP.remove(parent_node); |
||||
break; |
||||
} |
||||
} |
||||
} |
||||
candidate_parent_set = candidate_parent_set_TMP; |
||||
separation_set_size += 1; |
||||
} |
||||
(child_node, candidate_parent_set) |
||||
})); |
||||
for (child_node, candidate_parent_set) in learned_parent_sets { |
||||
for parent_node in candidate_parent_set.iter() { |
||||
net.add_edge(*parent_node, child_node); |
||||
} |
||||
} |
||||
net |
||||
} |
||||
} |
@ -0,0 +1,261 @@ |
||||
//! Module for constraint based algorithms containing hypothesis test algorithms like chi-squared test, F test, etc...
|
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use ndarray::{Array3, Axis}; |
||||
use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor}; |
||||
|
||||
use crate::params::*; |
||||
use crate::structure_learning::constraint_based_algorithm::Cache; |
||||
use crate::{parameter_learning, process, tools::Dataset}; |
||||
|
||||
pub trait HypothesisTest { |
||||
fn call<T, P>( |
||||
&self, |
||||
net: &T, |
||||
child_node: usize, |
||||
parent_node: usize, |
||||
separation_set: &BTreeSet<usize>, |
||||
dataset: &Dataset, |
||||
cache: &mut Cache<P>, |
||||
) -> bool |
||||
where |
||||
T: process::NetworkProcess, |
||||
P: parameter_learning::ParameterLearning; |
||||
} |
||||
|
||||
/// Does the chi-squared test (χ2 test).
|
||||
///
|
||||
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
|
||||
/// a relationship (dependence) between the variables.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
|
||||
/// in other words is the risk of concluding that an association between the variables exists
|
||||
/// when there is no actual association.
|
||||
|
||||
pub struct ChiSquare { |
||||
alpha: f64, |
||||
} |
||||
|
||||
/// Does the F-test.
|
||||
///
|
||||
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
|
||||
/// a relationship (dependence) between the variables.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
|
||||
/// in other words is the risk of concluding that an association between the variables exists
|
||||
/// when there is no actual association.
|
||||
|
||||
pub struct F { |
||||
alpha: f64, |
||||
} |
||||
|
||||
impl F { |
||||
pub fn new(alpha: f64) -> F { |
||||
F { alpha } |
||||
} |
||||
|
||||
/// Compare two matrices extracted from two 3rd-orer tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `i` - Position of the matrix of `M1` to compare with `M2`.
|
||||
/// * `M1` - 3rd-order tensor 1.
|
||||
/// * `j` - Position of the matrix of `M2` to compare with `M1`.
|
||||
/// * `M2` - 3rd-order tensor 2.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `true` - when the matrices `M1` and `M2` are very similar, then **independendent**.
|
||||
/// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
|
||||
|
||||
pub fn compare_matrices( |
||||
&self, |
||||
i: usize, |
||||
M1: &Array3<usize>, |
||||
cim_1: &Array3<f64>, |
||||
j: usize, |
||||
M2: &Array3<usize>, |
||||
cim_2: &Array3<f64>, |
||||
) -> bool { |
||||
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); |
||||
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); |
||||
let cim_1 = cim_1.index_axis(Axis(0), i); |
||||
let cim_2 = cim_2.index_axis(Axis(0), j); |
||||
let r1 = M1.sum_axis(Axis(1)); |
||||
let r2 = M2.sum_axis(Axis(1)); |
||||
let q1 = cim_1.diag(); |
||||
let q2 = cim_2.diag(); |
||||
for idx in 0..r1.shape()[0] { |
||||
let s = q2[idx] / q1[idx]; |
||||
let F = FisherSnedecor::new(r1[idx], r2[idx]).unwrap(); |
||||
let s = F.cdf(s); |
||||
let lim_sx = self.alpha / 2.0; |
||||
let lim_dx = 1.0 - (self.alpha / 2.0); |
||||
if s < lim_sx || s > lim_dx { |
||||
return false; |
||||
} |
||||
} |
||||
true |
||||
} |
||||
} |
||||
|
||||
impl HypothesisTest for F { |
||||
fn call<T, P>( |
||||
&self, |
||||
net: &T, |
||||
child_node: usize, |
||||
parent_node: usize, |
||||
separation_set: &BTreeSet<usize>, |
||||
dataset: &Dataset, |
||||
cache: &mut Cache<P>, |
||||
) -> bool |
||||
where |
||||
T: process::NetworkProcess, |
||||
P: parameter_learning::ParameterLearning, |
||||
{ |
||||
let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { |
||||
Params::DiscreteStatesContinousTime(node) => node, |
||||
}; |
||||
let mut extended_separation_set = separation_set.clone(); |
||||
extended_separation_set.insert(parent_node); |
||||
|
||||
let P_big = match cache.fit( |
||||
net, |
||||
&dataset, |
||||
child_node, |
||||
Some(extended_separation_set.clone()), |
||||
) { |
||||
Params::DiscreteStatesContinousTime(node) => node, |
||||
}; |
||||
let partial_cardinality_product: usize = extended_separation_set |
||||
.iter() |
||||
.take_while(|x| **x != parent_node) |
||||
.map(|x| net.get_node(*x).get_reserved_space_as_parent()) |
||||
.product(); |
||||
for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] { |
||||
let idx_M_small: usize = idx_M_big % partial_cardinality_product |
||||
+ (idx_M_big |
||||
/ (partial_cardinality_product |
||||
* net.get_node(parent_node).get_reserved_space_as_parent())) |
||||
* partial_cardinality_product; |
||||
if !self.compare_matrices( |
||||
idx_M_small, |
||||
P_small.get_transitions().as_ref().unwrap(), |
||||
P_small.get_cim().as_ref().unwrap(), |
||||
idx_M_big, |
||||
P_big.get_transitions().as_ref().unwrap(), |
||||
P_big.get_cim().as_ref().unwrap(), |
||||
) { |
||||
return false; |
||||
} |
||||
} |
||||
return true; |
||||
} |
||||
} |
||||
|
||||
impl ChiSquare { |
||||
pub fn new(alpha: f64) -> ChiSquare { |
||||
ChiSquare { alpha } |
||||
} |
||||
|
||||
/// Compare two matrices extracted from two 3rd-orer tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `i` - Position of the matrix of `M1` to compare with `M2`.
|
||||
/// * `M1` - 3rd-order tensor 1.
|
||||
/// * `j` - Position of the matrix of `M2` to compare with `M1`.
|
||||
/// * `M2` - 3rd-order tensor 2.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `true` - when the matrices `M1` and `M2` are very similar, then **independendent**.
|
||||
/// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
|
||||
|
||||
pub fn compare_matrices( |
||||
&self, |
||||
i: usize, |
||||
M1: &Array3<usize>, |
||||
j: usize, |
||||
M2: &Array3<usize>, |
||||
) -> bool { |
||||
// Bregoli, A., Scutari, M. and Stella, F., 2021.
|
||||
// A constraint-based algorithm for the structural learning of
|
||||
// continuous-time Bayesian networks.
|
||||
// International Journal of Approximate Reasoning, 138, pp.105-122.
|
||||
// Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
|
||||
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64); |
||||
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64); |
||||
let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1)); |
||||
let K = K.mapv(f64::sqrt); |
||||
// Reshape to column vector.
|
||||
let K = { |
||||
let n = K.len(); |
||||
K.into_shape((n, 1)).unwrap() |
||||
}; |
||||
let L = 1.0 / &K; |
||||
let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1); |
||||
X_2.diag_mut().fill(0.0); |
||||
let X_2 = X_2.sum_axis(Axis(1)); |
||||
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap(); |
||||
let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha)); |
||||
ret |
||||
} |
||||
} |
||||
|
||||
impl HypothesisTest for ChiSquare { |
||||
fn call<T, P>( |
||||
&self, |
||||
net: &T, |
||||
child_node: usize, |
||||
parent_node: usize, |
||||
separation_set: &BTreeSet<usize>, |
||||
dataset: &Dataset, |
||||
cache: &mut Cache<P>, |
||||
) -> bool |
||||
where |
||||
T: process::NetworkProcess, |
||||
P: parameter_learning::ParameterLearning, |
||||
{ |
||||
let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) { |
||||
Params::DiscreteStatesContinousTime(node) => node, |
||||
}; |
||||
let mut extended_separation_set = separation_set.clone(); |
||||
extended_separation_set.insert(parent_node); |
||||
|
||||
let P_big = match cache.fit( |
||||
net, |
||||
&dataset, |
||||
child_node, |
||||
Some(extended_separation_set.clone()), |
||||
) { |
||||
Params::DiscreteStatesContinousTime(node) => node, |
||||
}; |
||||
let partial_cardinality_product: usize = extended_separation_set |
||||
.iter() |
||||
.take_while(|x| **x != parent_node) |
||||
.map(|x| net.get_node(*x).get_reserved_space_as_parent()) |
||||
.product(); |
||||
for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] { |
||||
let idx_M_small: usize = idx_M_big % partial_cardinality_product |
||||
+ (idx_M_big |
||||
/ (partial_cardinality_product |
||||
* net.get_node(parent_node).get_reserved_space_as_parent())) |
||||
* partial_cardinality_product; |
||||
if !self.compare_matrices( |
||||
idx_M_small, |
||||
P_small.get_transitions().as_ref().unwrap(), |
||||
idx_M_big, |
||||
P_big.get_transitions().as_ref().unwrap(), |
||||
) { |
||||
return false; |
||||
} |
||||
} |
||||
return true; |
||||
} |
||||
} |
@ -0,0 +1,93 @@ |
||||
//! Module containing score based algorithms like Hill Climbing and Tabu Search.
|
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use crate::structure_learning::score_function::ScoreFunction; |
||||
use crate::structure_learning::StructureLearningAlgorithm; |
||||
use crate::{process, tools::Dataset}; |
||||
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator}; |
||||
use rayon::prelude::ParallelExtend; |
||||
|
||||
pub struct HillClimbing<S: ScoreFunction> { |
||||
score_function: S, |
||||
max_parent_set: Option<usize>, |
||||
} |
||||
|
||||
impl<S: ScoreFunction> HillClimbing<S> { |
||||
pub fn new(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> { |
||||
HillClimbing { |
||||
score_function, |
||||
max_parent_set, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> { |
||||
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
||||
where |
||||
T: process::NetworkProcess, |
||||
{ |
||||
//Check the coherence between dataset and network
|
||||
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { |
||||
panic!("Dataset and Network must have the same number of variables.") |
||||
} |
||||
|
||||
//Make the network mutable.
|
||||
let mut net = net; |
||||
//Check if the max_parent_set constraint is present.
|
||||
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes()); |
||||
//Reset the adj matrix
|
||||
net.initialize_adj_matrix(); |
||||
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![]; |
||||
//Iterate over each node to learn their parent set.
|
||||
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| { |
||||
//Initialize an empty parent set.
|
||||
let mut parent_set: BTreeSet<usize> = BTreeSet::new(); |
||||
//Compute the score for the empty parent set
|
||||
let mut current_score = self.score_function.call(&net, node, &parent_set, dataset); |
||||
//Set the old score to -\infty.
|
||||
let mut old_score = f64::NEG_INFINITY; |
||||
//Iterate until convergence
|
||||
while current_score > old_score { |
||||
//Save the current_score.
|
||||
old_score = current_score; |
||||
//Iterate over each node.
|
||||
for parent in net.get_node_indices() { |
||||
//Continue if the parent and the node are the same.
|
||||
if parent == node { |
||||
continue; |
||||
} |
||||
//Try to remove parent from the parent_set.
|
||||
let is_removed = parent_set.remove(&parent); |
||||
//If parent was not in the parent_set add it.
|
||||
if !is_removed && parent_set.len() < max_parent_set { |
||||
parent_set.insert(parent); |
||||
} |
||||
//Compute the score with the modified parent_set.
|
||||
let tmp_score = self.score_function.call(&net, node, &parent_set, dataset); |
||||
//If tmp_score is worst than current_score revert the change to the parent set
|
||||
if tmp_score < current_score { |
||||
if is_removed { |
||||
parent_set.insert(parent); |
||||
} else { |
||||
parent_set.remove(&parent); |
||||
} |
||||
} |
||||
//Otherwise save the computed score as current_score
|
||||
else { |
||||
current_score = tmp_score; |
||||
} |
||||
} |
||||
} |
||||
(node, parent_set) |
||||
})); |
||||
|
||||
for (child_node, candidate_parent_set) in learned_parent_sets { |
||||
for parent_node in candidate_parent_set.iter() { |
||||
net.add_edge(*parent_node, child_node); |
||||
} |
||||
} |
||||
return net; |
||||
} |
||||
} |
@ -0,0 +1,146 @@ |
||||
//! Module for score based algorithms containing score functions algorithms like Log Likelihood, BIC, etc...
|
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use ndarray::prelude::*; |
||||
use statrs::function::gamma; |
||||
|
||||
use crate::{parameter_learning, params, process, tools}; |
||||
|
||||
pub trait ScoreFunction: Sync { |
||||
fn call<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> f64 |
||||
where |
||||
T: process::NetworkProcess; |
||||
} |
||||
|
||||
pub struct LogLikelihood { |
||||
alpha: usize, |
||||
tau: f64, |
||||
} |
||||
|
||||
impl LogLikelihood { |
||||
pub fn new(alpha: usize, tau: f64) -> LogLikelihood { |
||||
//Tau must be >=0.0
|
||||
if tau < 0.0 { |
||||
panic!("tau must be >=0.0"); |
||||
} |
||||
LogLikelihood { alpha, tau } |
||||
} |
||||
|
||||
fn compute_score<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> (f64, Array3<usize>) |
||||
where |
||||
T: process::NetworkProcess, |
||||
{ |
||||
//Identify the type of node used
|
||||
match &net.get_node(node) { |
||||
params::Params::DiscreteStatesContinousTime(_params) => { |
||||
//Compute the sufficient statistics M (number of transistions) and T (residence
|
||||
//time)
|
||||
let (M, T) = |
||||
parameter_learning::sufficient_statistics(net, dataset, node, parent_set); |
||||
|
||||
//Scale alpha accordingly to the size of the parent set
|
||||
let alpha = self.alpha as f64 / M.shape()[0] as f64; |
||||
//Scale tau accordingly to the size of the parent set
|
||||
let tau = self.tau / M.shape()[0] as f64; |
||||
|
||||
//Compute the log likelihood for q
|
||||
let log_ll_q: f64 = M |
||||
.sum_axis(Axis(2)) |
||||
.iter() |
||||
.zip(T.iter()) |
||||
.map(|(m, t)| { |
||||
gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau) |
||||
- gamma::ln_gamma(alpha + 1.0) |
||||
- (alpha + *m as f64 + 1.0) * f64::ln(tau + t) |
||||
}) |
||||
.sum(); |
||||
|
||||
//Compute the log likelihood for theta
|
||||
let log_ll_theta: f64 = M |
||||
.outer_iter() |
||||
.map(|x| { |
||||
x.outer_iter() |
||||
.map(|y| { |
||||
gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64) |
||||
+ y.iter() |
||||
.map(|z| { |
||||
gamma::ln_gamma(alpha + *z as f64) |
||||
- gamma::ln_gamma(alpha) |
||||
}) |
||||
.sum::<f64>() |
||||
}) |
||||
.sum::<f64>() |
||||
}) |
||||
.sum(); |
||||
(log_ll_theta + log_ll_q, M) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl ScoreFunction for LogLikelihood { |
||||
fn call<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> f64 |
||||
where |
||||
T: process::NetworkProcess, |
||||
{ |
||||
self.compute_score(net, node, parent_set, dataset).0 |
||||
} |
||||
} |
||||
|
||||
pub struct BIC { |
||||
ll: LogLikelihood, |
||||
} |
||||
|
||||
impl BIC { |
||||
pub fn new(alpha: usize, tau: f64) -> BIC { |
||||
BIC { |
||||
ll: LogLikelihood::new(alpha, tau), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl ScoreFunction for BIC { |
||||
fn call<T>( |
||||
&self, |
||||
net: &T, |
||||
node: usize, |
||||
parent_set: &BTreeSet<usize>, |
||||
dataset: &tools::Dataset, |
||||
) -> f64 |
||||
where |
||||
T: process::NetworkProcess, |
||||
{ |
||||
//Compute the log-likelihood
|
||||
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset); |
||||
//Compute the number of parameters
|
||||
let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1); |
||||
//TODO: Optimize this
|
||||
//Compute the sample size
|
||||
let sample_size: usize = dataset |
||||
.get_trajectories() |
||||
.iter() |
||||
.map(|x| x.get_time().len() - 1) |
||||
.sum(); |
||||
//Compute BIC
|
||||
ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64 |
||||
} |
||||
} |
@ -0,0 +1,355 @@ |
||||
//! Contains commonly used methods used across the crate.
|
||||
|
||||
use std::ops::{DivAssign, MulAssign, Range}; |
||||
|
||||
use ndarray::{Array, Array1, Array2, Array3, Axis}; |
||||
use rand::{Rng, SeedableRng}; |
||||
use rand_chacha::ChaCha8Rng; |
||||
|
||||
use crate::params::ParamsTrait; |
||||
use crate::process::NetworkProcess; |
||||
use crate::sampling::{ForwardSampler, Sampler}; |
||||
use crate::{params, process}; |
||||
|
||||
#[derive(Clone)] |
||||
pub struct Trajectory { |
||||
time: Array1<f64>, |
||||
events: Array2<usize>, |
||||
} |
||||
|
||||
impl Trajectory { |
||||
pub fn new(time: Array1<f64>, events: Array2<usize>) -> Trajectory { |
||||
//Events and time are two part of the same trajectory. For this reason they must have the
|
||||
//same number of sample.
|
||||
if time.shape()[0] != events.shape()[0] { |
||||
panic!("time.shape[0] must be equal to events.shape[0]"); |
||||
} |
||||
Trajectory { time, events } |
||||
} |
||||
|
||||
pub fn get_time(&self) -> &Array1<f64> { |
||||
&self.time |
||||
} |
||||
|
||||
pub fn get_events(&self) -> &Array2<usize> { |
||||
&self.events |
||||
} |
||||
} |
||||
|
||||
#[derive(Clone)] |
||||
pub struct Dataset { |
||||
trajectories: Vec<Trajectory>, |
||||
} |
||||
|
||||
impl Dataset { |
||||
pub fn new(trajectories: Vec<Trajectory>) -> Dataset { |
||||
//All the trajectories in the same dataset must represent the same process. For this reason
|
||||
//each trajectory must represent the same number of variables.
|
||||
if trajectories |
||||
.iter() |
||||
.any(|x| trajectories[0].get_events().shape()[1] != x.get_events().shape()[1]) |
||||
{ |
||||
panic!("All the trajectories mus represents the same number of variables"); |
||||
} |
||||
Dataset { trajectories } |
||||
} |
||||
|
||||
pub fn get_trajectories(&self) -> &Vec<Trajectory> { |
||||
&self.trajectories |
||||
} |
||||
} |
||||
|
||||
pub fn trajectory_generator<T: process::NetworkProcess>( |
||||
net: &T, |
||||
n_trajectories: u64, |
||||
t_end: f64, |
||||
seed: Option<u64>, |
||||
) -> Dataset { |
||||
//Tmp growing vector containing generated trajectories.
|
||||
let mut trajectories: Vec<Trajectory> = Vec::new(); |
||||
|
||||
//Random Generator object
|
||||
let mut sampler = ForwardSampler::new(net, seed, None); |
||||
//Each iteration generate one trajectory
|
||||
for _ in 0..n_trajectories { |
||||
//History of all the moments in which something changed
|
||||
let mut time: Vec<f64> = Vec::new(); |
||||
//Configuration of the process variables at time t initialized with an uniform
|
||||
//distribution.
|
||||
let mut events: Vec<process::NetworkProcessState> = Vec::new(); |
||||
|
||||
//Current Time and Current State
|
||||
let mut sample = sampler.next().unwrap(); |
||||
//Generate new samples until ending time is reached.
|
||||
while sample.t < t_end { |
||||
time.push(sample.t); |
||||
events.push(sample.state); |
||||
sample = sampler.next().unwrap(); |
||||
} |
||||
|
||||
let current_state = events.last().unwrap().clone(); |
||||
events.push(current_state); |
||||
|
||||
//Add t_end as last time.
|
||||
time.push(t_end.clone()); |
||||
|
||||
//Add the sampled trajectory to trajectories.
|
||||
trajectories.push(Trajectory::new( |
||||
Array::from_vec(time), |
||||
Array2::from_shape_vec( |
||||
(events.len(), events.last().unwrap().len()), |
||||
events |
||||
.iter() |
||||
.flatten() |
||||
.map(|x| match x { |
||||
params::StateType::Discrete(x) => x.clone(), |
||||
}) |
||||
.collect(), |
||||
) |
||||
.unwrap(), |
||||
)); |
||||
sampler.reset(); |
||||
} |
||||
//Return a dataset object with the sampled trajectories.
|
||||
Dataset::new(trajectories) |
||||
} |
||||
|
||||
pub trait RandomGraphGenerator { |
||||
fn new(density: f64, seed: Option<u64>) -> Self; |
||||
fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T); |
||||
} |
||||
|
||||
/// Graph Generator using an uniform distribution.
|
||||
///
|
||||
/// A method to generate a random graph with edges uniformly distributed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`.
|
||||
/// * `rng` - is the random numbers generator.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use std::collections::BTreeSet;
|
||||
/// # use ndarray::{arr1, arr2, arr3};
|
||||
/// # use reCTBN::params;
|
||||
/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
|
||||
/// # use reCTBN::tools::trajectory_generator;
|
||||
/// # use reCTBN::process::NetworkProcess;
|
||||
/// # use reCTBN::process::ctbn::CtbnNetwork;
|
||||
/// use reCTBN::tools::UniformGraphGenerator;
|
||||
/// use reCTBN::tools::RandomGraphGenerator;
|
||||
/// # let mut net = CtbnNetwork::new();
|
||||
/// # let nodes_cardinality = 8;
|
||||
/// # let domain_cardinality = 4;
|
||||
/// # for node in 0..nodes_cardinality {
|
||||
/// # // Create the domain for a discrete node
|
||||
/// # let mut domain = BTreeSet::new();
|
||||
/// # for dvalue in 0..domain_cardinality {
|
||||
/// # domain.insert(dvalue.to_string());
|
||||
/// # }
|
||||
/// # // Create the parameters for a discrete node using the domain
|
||||
/// # let param = params::DiscreteStatesContinousTimeParams::new(
|
||||
/// # node.to_string(),
|
||||
/// # domain
|
||||
/// # );
|
||||
/// # //Create the node using the parameters
|
||||
/// # let node = DiscreteStatesContinousTime(param);
|
||||
/// # // Add the node to the network
|
||||
/// # net.add_node(node).unwrap();
|
||||
/// # }
|
||||
///
|
||||
/// // Initialize the Graph Generator using the one with an
|
||||
/// // uniform distribution
|
||||
/// let density = 1.0/3.0;
|
||||
/// let seed = Some(7641630759785120);
|
||||
/// let mut structure_generator = UniformGraphGenerator::new(
|
||||
/// density,
|
||||
/// seed
|
||||
/// );
|
||||
///
|
||||
/// // Generate the graph directly on the network
|
||||
/// structure_generator.generate_graph(&mut net);
|
||||
/// # // Count all the edges generated in the network
|
||||
/// # let mut edges = 0;
|
||||
/// # for node in net.get_node_indices(){
|
||||
/// # edges += net.get_children_set(node).len()
|
||||
/// # }
|
||||
/// # // Number of all the nodes in the network
|
||||
/// # let nodes = net.get_node_indices().len() as f64;
|
||||
/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
|
||||
/// # // ±10% of tolerance
|
||||
/// # let tolerance = ((expected_edges as f64)*0.10) as usize;
|
||||
/// # // As the way `generate_graph()` is implemented we can only reasonably
|
||||
/// # // expect the number of edges to be somewhere around the expected value.
|
||||
/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
|
||||
/// ```
|
||||
pub struct UniformGraphGenerator { |
||||
density: f64, |
||||
rng: ChaCha8Rng, |
||||
} |
||||
|
||||
impl RandomGraphGenerator for UniformGraphGenerator { |
||||
fn new(density: f64, seed: Option<u64>) -> UniformGraphGenerator { |
||||
if density < 0.0 || density > 1.0 { |
||||
panic!( |
||||
"Density value must be between 1.0 and 0.0, got {}.", |
||||
density |
||||
); |
||||
} |
||||
let rng: ChaCha8Rng = match seed { |
||||
Some(seed) => SeedableRng::seed_from_u64(seed), |
||||
None => SeedableRng::from_entropy(), |
||||
}; |
||||
UniformGraphGenerator { density, rng } |
||||
} |
||||
|
||||
/// Generate an uniformly distributed graph.
|
||||
fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T) { |
||||
net.initialize_adj_matrix(); |
||||
let last_node_idx = net.get_node_indices().len(); |
||||
for parent in 0..last_node_idx { |
||||
for child in 0..last_node_idx { |
||||
if parent != child { |
||||
if self.rng.gen_bool(self.density) { |
||||
net.add_edge(parent, child); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub trait RandomParametersGenerator { |
||||
fn new(interval: Range<f64>, seed: Option<u64>) -> Self; |
||||
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T); |
||||
} |
||||
|
||||
/// Parameters Generator using an uniform distribution.
|
||||
///
|
||||
/// A method to generate random parameters uniformly distributed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `interval` - is the interval of the random values oh the CIM's diagonal; domain: `≥ 0.0`.
|
||||
/// * `rng` - is the random numbers generator.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use std::collections::BTreeSet;
|
||||
/// # use ndarray::{arr1, arr2, arr3};
|
||||
/// # use reCTBN::params;
|
||||
/// # use reCTBN::params::ParamsTrait;
|
||||
/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
|
||||
/// # use reCTBN::process::NetworkProcess;
|
||||
/// # use reCTBN::process::ctbn::CtbnNetwork;
|
||||
/// # use reCTBN::tools::trajectory_generator;
|
||||
/// # use reCTBN::tools::RandomGraphGenerator;
|
||||
/// # use reCTBN::tools::UniformGraphGenerator;
|
||||
/// use reCTBN::tools::RandomParametersGenerator;
|
||||
/// use reCTBN::tools::UniformParametersGenerator;
|
||||
/// # let mut net = CtbnNetwork::new();
|
||||
/// # let nodes_cardinality = 8;
|
||||
/// # let domain_cardinality = 4;
|
||||
/// # for node in 0..nodes_cardinality {
|
||||
/// # // Create the domain for a discrete node
|
||||
/// # let mut domain = BTreeSet::new();
|
||||
/// # for dvalue in 0..domain_cardinality {
|
||||
/// # domain.insert(dvalue.to_string());
|
||||
/// # }
|
||||
/// # // Create the parameters for a discrete node using the domain
|
||||
/// # let param = params::DiscreteStatesContinousTimeParams::new(
|
||||
/// # node.to_string(),
|
||||
/// # domain
|
||||
/// # );
|
||||
/// # //Create the node using the parameters
|
||||
/// # let node = DiscreteStatesContinousTime(param);
|
||||
/// # // Add the node to the network
|
||||
/// # net.add_node(node).unwrap();
|
||||
/// # }
|
||||
/// #
|
||||
/// # // Initialize the Graph Generator using the one with an
|
||||
/// # // uniform distribution
|
||||
/// # let mut structure_generator = UniformGraphGenerator::new(
|
||||
/// # 1.0/3.0,
|
||||
/// # Some(7641630759785120)
|
||||
/// # );
|
||||
/// #
|
||||
/// # // Generate the graph directly on the network
|
||||
/// # structure_generator.generate_graph(&mut net);
|
||||
///
|
||||
/// // Initialize the parameters generator with uniform distributin
|
||||
/// let mut cim_generator = UniformParametersGenerator::new(
|
||||
/// 0.0..7.0,
|
||||
/// Some(7641630759785120)
|
||||
/// );
|
||||
///
|
||||
/// // Generate CIMs with uniformly distributed parameters.
|
||||
/// cim_generator.generate_parameters(&mut net);
|
||||
/// #
|
||||
/// # for node in net.get_node_indices() {
|
||||
/// # assert_eq!(
|
||||
/// # Ok(()),
|
||||
/// # net.get_node(node).validate_params()
|
||||
/// # );
|
||||
/// }
|
||||
/// ```
|
||||
pub struct UniformParametersGenerator { |
||||
interval: Range<f64>, |
||||
rng: ChaCha8Rng, |
||||
} |
||||
|
||||
impl RandomParametersGenerator for UniformParametersGenerator { |
||||
fn new(interval: Range<f64>, seed: Option<u64>) -> UniformParametersGenerator { |
||||
if interval.start < 0.0 || interval.end < 0.0 { |
||||
panic!( |
||||
"Interval must be entirely less or equal than 0, got {}..{}.", |
||||
interval.start, interval.end |
||||
); |
||||
} |
||||
let rng: ChaCha8Rng = match seed { |
||||
Some(seed) => SeedableRng::seed_from_u64(seed), |
||||
None => SeedableRng::from_entropy(), |
||||
}; |
||||
UniformParametersGenerator { interval, rng } |
||||
} |
||||
|
||||
/// Generate CIMs with uniformly distributed parameters.
|
||||
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) { |
||||
for node in net.get_node_indices() { |
||||
let parent_set_state_space_cardinality: usize = net |
||||
.get_parent_set(node) |
||||
.iter() |
||||
.map(|x| net.get_node(*x).get_reserved_space_as_parent()) |
||||
.product(); |
||||
match &mut net.get_node_mut(node) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
let node_domain_cardinality = param.get_reserved_space_as_parent(); |
||||
let mut cim = Array3::<f64>::from_shape_fn( |
||||
( |
||||
parent_set_state_space_cardinality, |
||||
node_domain_cardinality, |
||||
node_domain_cardinality, |
||||
), |
||||
|_| self.rng.gen(), |
||||
); |
||||
cim.axis_iter_mut(Axis(0)).for_each(|mut x| { |
||||
x.diag_mut().fill(0.0); |
||||
x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1))); |
||||
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| { |
||||
self.rng.gen_range(self.interval.clone()) |
||||
}); |
||||
x.mul_assign(&diag.clone().insert_axis(Axis(1))); |
||||
// Recomputing the diagonal in order to reduce the issues caused by the
|
||||
// loss of precision when validating the parameters.
|
||||
let diag_sum = -x.sum_axis(Axis(1)); |
||||
x.diag_mut().assign(&diag_sum) |
||||
}); |
||||
param.set_cim_unchecked(cim); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,376 @@ |
||||
mod utils; |
||||
use std::collections::BTreeSet; |
||||
|
||||
|
||||
use approx::AbsDiffEq; |
||||
use ndarray::arr3; |
||||
use reCTBN::params::{self, ParamsTrait}; |
||||
use reCTBN::process::NetworkProcess; |
||||
use reCTBN::process::{ctbn::*}; |
||||
use utils::generate_discrete_time_continous_node; |
||||
|
||||
#[test] |
||||
fn define_simpe_ctbn() { |
||||
let _ = CtbnNetwork::new(); |
||||
assert!(true); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_node_to_ctbn() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_edge_to_ctbn() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
let cs = net.get_children_set(n1); |
||||
assert_eq!(&n2, cs.iter().next().unwrap()); |
||||
} |
||||
|
||||
#[test] |
||||
fn children_and_parents() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
let cs = net.get_children_set(n1); |
||||
assert_eq!(&n2, cs.iter().next().unwrap()); |
||||
let ps = net.get_parent_set(n2); |
||||
assert_eq!(&n1, ps.iter().next().unwrap()); |
||||
} |
||||
|
||||
#[test] |
||||
fn compute_index_ctbn() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n3, n2); |
||||
let idx = net.get_param_index_network( |
||||
n2, |
||||
&vec![ |
||||
params::StateType::Discrete(1), |
||||
params::StateType::Discrete(1), |
||||
params::StateType::Discrete(1), |
||||
], |
||||
); |
||||
assert_eq!(3, idx); |
||||
|
||||
let idx = net.get_param_index_network( |
||||
n2, |
||||
&vec![ |
||||
params::StateType::Discrete(0), |
||||
params::StateType::Discrete(1), |
||||
params::StateType::Discrete(1), |
||||
], |
||||
); |
||||
assert_eq!(2, idx); |
||||
|
||||
let idx = net.get_param_index_network( |
||||
n2, |
||||
&vec![ |
||||
params::StateType::Discrete(1), |
||||
params::StateType::Discrete(1), |
||||
params::StateType::Discrete(0), |
||||
], |
||||
); |
||||
assert_eq!(1, idx); |
||||
} |
||||
|
||||
#[test] |
||||
fn compute_index_from_custom_parent_set() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let _n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
let _n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
||||
.unwrap(); |
||||
|
||||
let idx = net.get_param_index_from_custom_parent_set( |
||||
&vec![ |
||||
params::StateType::Discrete(0), |
||||
params::StateType::Discrete(0), |
||||
params::StateType::Discrete(1), |
||||
], |
||||
&BTreeSet::from([1]), |
||||
); |
||||
assert_eq!(0, idx); |
||||
|
||||
let idx = net.get_param_index_from_custom_parent_set( |
||||
&vec![ |
||||
params::StateType::Discrete(0), |
||||
params::StateType::Discrete(0), |
||||
params::StateType::Discrete(1), |
||||
], |
||||
&BTreeSet::from([1, 2]), |
||||
); |
||||
assert_eq!(2, idx); |
||||
} |
||||
|
||||
#[test] |
||||
fn simple_amalgamation() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
net.initialize_adj_matrix(); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); |
||||
} |
||||
} |
||||
|
||||
let ctmp = net.amalgamation(); |
||||
let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0); |
||||
let p_ctbn = p_ctbn.get_cim().as_ref().unwrap(); |
||||
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); |
||||
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); |
||||
|
||||
assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON)); |
||||
} |
||||
|
||||
#[test] |
||||
fn chain_amalgamation() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
||||
.unwrap(); |
||||
|
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n2, n3); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-5.0, 5.0], [0.01, -0.01]] |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n3) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-5.0, 5.0], [0.01, -0.01]] |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
let ctmp = net.amalgamation(); |
||||
|
||||
|
||||
|
||||
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0); |
||||
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); |
||||
|
||||
let p_ctmp_handmade = arr3(&[[ |
||||
[ |
||||
-1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, |
||||
], |
||||
[ |
||||
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00, |
||||
], |
||||
]]); |
||||
|
||||
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); |
||||
} |
||||
|
||||
#[test] |
||||
fn chainfork_amalgamation() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
||||
.unwrap(); |
||||
let n4 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n4"), 2)) |
||||
.unwrap(); |
||||
|
||||
net.add_edge(n1, n3); |
||||
net.add_edge(n2, n3); |
||||
net.add_edge(n3, n4); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]]))); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n3) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-5.0, 5.0], [0.01, -0.01]] |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n4) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-5.0, 5.0], [0.01, -0.01]] |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
|
||||
let ctmp = net.amalgamation(); |
||||
|
||||
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
|
||||
|
||||
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap(); |
||||
|
||||
let p_ctmp_handmade = arr3(&[[ |
||||
[ |
||||
-2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, |
||||
0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, |
||||
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, |
||||
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00, |
||||
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01, |
||||
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01, |
||||
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, |
||||
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, |
||||
], |
||||
[ |
||||
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
-5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, |
||||
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, |
||||
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, |
||||
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, |
||||
], |
||||
[ |
||||
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, |
||||
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00, |
||||
], |
||||
]]); |
||||
|
||||
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8)); |
||||
} |
@ -0,0 +1,127 @@ |
||||
mod utils; |
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use reCTBN::{ |
||||
params, |
||||
params::ParamsTrait, |
||||
process::{ctmp::*, NetworkProcess}, |
||||
}; |
||||
use utils::generate_discrete_time_continous_node; |
||||
|
||||
#[test] |
||||
fn define_simple_ctmp() { |
||||
let _ = CtmpProcess::new(); |
||||
assert!(true); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_node_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_two_nodes_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
||||
|
||||
match n2 { |
||||
Ok(_) => assert!(false), |
||||
Err(_) => assert!(true), |
||||
}; |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn add_edge_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
||||
|
||||
net.add_edge(0, 1) |
||||
} |
||||
|
||||
#[test] |
||||
fn childen_and_parents() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
assert_eq!(0, net.get_parent_set(0).len()); |
||||
assert_eq!(0, net.get_children_set(0).len()); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_childen_panic() { |
||||
let net = CtmpProcess::new(); |
||||
net.get_children_set(0); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_childen_panic2() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
net.get_children_set(1); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_parent_panic() { |
||||
let net = CtmpProcess::new(); |
||||
net.get_parent_set(0); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_parent_panic2() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
net.get_parent_set(1); |
||||
} |
||||
|
||||
#[test] |
||||
fn compute_index_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node( |
||||
String::from("n1"), |
||||
10, |
||||
)) |
||||
.unwrap(); |
||||
|
||||
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); |
||||
assert_eq!(6, idx); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn compute_index_from_custom_parent_set_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node( |
||||
String::from("n1"), |
||||
10, |
||||
)) |
||||
.unwrap(); |
||||
|
||||
let _idx = net.get_param_index_from_custom_parent_set( |
||||
&vec![params::StateType::Discrete(6)], |
||||
&BTreeSet::from([0]) |
||||
); |
||||
} |
@ -0,0 +1,648 @@ |
||||
#![allow(non_snake_case)] |
||||
|
||||
mod utils; |
||||
use ndarray::arr3; |
||||
use reCTBN::process::ctbn::*; |
||||
use reCTBN::process::NetworkProcess; |
||||
use reCTBN::parameter_learning::*; |
||||
use reCTBN::params; |
||||
use reCTBN::params::Params::DiscreteStatesContinousTime; |
||||
use reCTBN::tools::*; |
||||
use utils::*; |
||||
|
||||
extern crate approx; |
||||
use crate::approx::AbsDiffEq; |
||||
|
||||
fn learn_binary_cim<T: ParameterLearning>(pl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 1.0], |
||||
[4.0, -4.0] |
||||
], |
||||
[ |
||||
[-6.0, 6.0], |
||||
[2.0, -2.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); |
||||
let p = match pl.fit(&net, &data, 1, None) { |
||||
params::Params::DiscreteStatesContinousTime(p) => p, |
||||
}; |
||||
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); |
||||
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
||||
&arr3(&[ |
||||
[ |
||||
[-1.0, 1.0], |
||||
[4.0, -4.0] |
||||
], |
||||
[ |
||||
[-6.0, 6.0], |
||||
[2.0, -2.0] |
||||
], |
||||
]), |
||||
0.1 |
||||
)); |
||||
} |
||||
|
||||
fn generate_nodes( |
||||
net: &mut CtbnNetwork, |
||||
nodes_cardinality: usize, |
||||
nodes_domain_cardinality: usize |
||||
) { |
||||
for node_label in 0..nodes_cardinality { |
||||
net.add_node( |
||||
generate_discrete_time_continous_node( |
||||
node_label.to_string(), |
||||
nodes_domain_cardinality, |
||||
) |
||||
).unwrap(); |
||||
} |
||||
} |
||||
|
||||
fn learn_binary_cim_gen<T: ParameterLearning>(pl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
generate_nodes(&mut net, 2, 2); |
||||
|
||||
net.add_edge(0, 1); |
||||
|
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
1.0..6.0, |
||||
Some(6813071588535822) |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
|
||||
let p_gen = match net.get_node(1) { |
||||
DiscreteStatesContinousTime(p_gen) => p_gen, |
||||
}; |
||||
|
||||
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); |
||||
let p_tj = match pl.fit(&net, &data, 1, None) { |
||||
DiscreteStatesContinousTime(p_tj) => p_tj, |
||||
}; |
||||
|
||||
assert_eq!( |
||||
p_tj.get_cim().as_ref().unwrap().shape(), |
||||
p_gen.get_cim().as_ref().unwrap().shape() |
||||
); |
||||
assert!( |
||||
p_tj.get_cim().as_ref().unwrap().abs_diff_eq( |
||||
&p_gen.get_cim().as_ref().unwrap(), |
||||
0.1 |
||||
) |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_binary_cim_MLE() { |
||||
let mle = MLE {}; |
||||
learn_binary_cim(mle); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_binary_cim_MLE_gen() { |
||||
let mle = MLE {}; |
||||
learn_binary_cim_gen(mle); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_binary_cim_BA() { |
||||
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; |
||||
learn_binary_cim(ba); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_binary_cim_BA_gen() { |
||||
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; |
||||
learn_binary_cim_gen(ba); |
||||
} |
||||
|
||||
fn learn_ternary_cim<T: ParameterLearning>(pl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-3.0, 2.0, 1.0], |
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.5], |
||||
[3.0, -4.0, 1.0], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 4.0], |
||||
[1.5, -2.0, 0.5], |
||||
[3.0, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.0, 0.1, 0.9], |
||||
[2.0, -2.5, 0.5], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); |
||||
let p = match pl.fit(&net, &data, 1, None) { |
||||
params::Params::DiscreteStatesContinousTime(p) => p, |
||||
}; |
||||
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); |
||||
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
||||
&arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.5], |
||||
[3.0, -4.0, 1.0], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 4.0], |
||||
[1.5, -2.0, 0.5], |
||||
[3.0, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.0, 0.1, 0.9], |
||||
[2.0, -2.5, 0.5], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
]), |
||||
0.1 |
||||
)); |
||||
} |
||||
|
||||
fn learn_ternary_cim_gen<T: ParameterLearning>(pl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
generate_nodes(&mut net, 2, 3); |
||||
|
||||
net.add_edge(0, 1); |
||||
|
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
4.0..6.0, |
||||
Some(6813071588535822) |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
|
||||
let p_gen = match net.get_node(1) { |
||||
DiscreteStatesContinousTime(p_gen) => p_gen, |
||||
}; |
||||
|
||||
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); |
||||
let p_tj = match pl.fit(&net, &data, 1, None) { |
||||
DiscreteStatesContinousTime(p_tj) => p_tj, |
||||
}; |
||||
|
||||
assert_eq!( |
||||
p_tj.get_cim().as_ref().unwrap().shape(), |
||||
p_gen.get_cim().as_ref().unwrap().shape() |
||||
); |
||||
assert!( |
||||
p_tj.get_cim().as_ref().unwrap().abs_diff_eq( |
||||
&p_gen.get_cim().as_ref().unwrap(), |
||||
0.1 |
||||
) |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_MLE() { |
||||
let mle = MLE {}; |
||||
learn_ternary_cim(mle); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_MLE_gen() { |
||||
let mle = MLE {}; |
||||
learn_ternary_cim_gen(mle); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_BA() { |
||||
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; |
||||
learn_ternary_cim(ba); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_BA_gen() { |
||||
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; |
||||
learn_ternary_cim_gen(ba); |
||||
} |
||||
|
||||
fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-3.0, 2.0, 1.0], |
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0] |
||||
] |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.5], |
||||
[3.0, -4.0, 1.0], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 4.0], |
||||
[1.5, -2.0, 0.5], |
||||
[3.0, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.0, 0.1, 0.9], |
||||
[2.0, -2.5, 0.5], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); |
||||
let p = match pl.fit(&net, &data, 0, None) { |
||||
params::Params::DiscreteStatesContinousTime(p) => p, |
||||
}; |
||||
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); |
||||
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
||||
&arr3(&[ |
||||
[ |
||||
[-3.0, 2.0, 1.0], |
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0] |
||||
], |
||||
]), |
||||
0.1 |
||||
)); |
||||
} |
||||
|
||||
fn learn_ternary_cim_no_parents_gen<T: ParameterLearning>(pl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
generate_nodes(&mut net, 2, 3); |
||||
|
||||
net.add_edge(0, 1); |
||||
|
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
1.0..6.0, |
||||
Some(6813071588535822) |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
|
||||
let p_gen = match net.get_node(0) { |
||||
DiscreteStatesContinousTime(p_gen) => p_gen, |
||||
}; |
||||
|
||||
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); |
||||
let p_tj = match pl.fit(&net, &data, 0, None) { |
||||
DiscreteStatesContinousTime(p_tj) => p_tj, |
||||
}; |
||||
|
||||
assert_eq!( |
||||
p_tj.get_cim().as_ref().unwrap().shape(), |
||||
p_gen.get_cim().as_ref().unwrap().shape() |
||||
); |
||||
assert!( |
||||
p_tj.get_cim().as_ref().unwrap().abs_diff_eq( |
||||
&p_gen.get_cim().as_ref().unwrap(), |
||||
0.1 |
||||
) |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_no_parents_MLE() { |
||||
let mle = MLE {}; |
||||
learn_ternary_cim_no_parents(mle); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_no_parents_MLE_gen() { |
||||
let mle = MLE {}; |
||||
learn_ternary_cim_no_parents_gen(mle); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_no_parents_BA() { |
||||
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; |
||||
learn_ternary_cim_no_parents(ba); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_no_parents_BA_gen() { |
||||
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; |
||||
learn_ternary_cim_no_parents_gen(ba); |
||||
} |
||||
|
||||
fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) |
||||
.unwrap(); |
||||
|
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 4)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n1, n3); |
||||
net.add_edge(n2, n3); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-3.0, 2.0, 1.0], |
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.5], |
||||
[3.0, -4.0, 1.0], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 4.0], |
||||
[1.5, -2.0, 0.5], |
||||
[3.0, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.0, 0.1, 0.9], |
||||
[2.0, -2.5, 0.5], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n3) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.3, 0.2], |
||||
[0.5, -4.0, 2.5, 1.0], |
||||
[2.5, 0.5, -4.0, 1.0], |
||||
[0.7, 0.2, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 3.0, 1.0], |
||||
[1.5, -3.0, 0.5, 1.0], |
||||
[2.0, 1.3, -5.0, 1.7], |
||||
[2.5, 0.5, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.3, 0.3, 0.1, 0.9], |
||||
[1.4, -4.0, 0.5, 2.1], |
||||
[1.0, 1.5, -3.0, 0.5], |
||||
[0.4, 0.3, 0.1, -0.8] |
||||
], |
||||
[ |
||||
[-2.0, 1.0, 0.7, 0.3], |
||||
[1.3, -5.9, 2.7, 1.9], |
||||
[2.0, 1.5, -4.0, 0.5], |
||||
[0.2, 0.7, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 1.0, 2.0, 3.0], |
||||
[0.5, -3.0, 1.0, 1.5], |
||||
[1.4, 2.1, -4.3, 0.8], |
||||
[0.5, 1.0, 2.5, -4.0] |
||||
], |
||||
[ |
||||
[-1.3, 0.9, 0.3, 0.1], |
||||
[0.1, -1.3, 0.2, 1.0], |
||||
[0.5, 1.0, -3.0, 1.5], |
||||
[0.1, 0.4, 0.3, -0.8] |
||||
], |
||||
[ |
||||
[-2.0, 1.0, 0.6, 0.4], |
||||
[2.6, -7.1, 1.4, 3.1], |
||||
[5.0, 1.0, -8.0, 2.0], |
||||
[1.4, 0.4, 0.2, -2.0] |
||||
], |
||||
[ |
||||
[-3.0, 1.0, 1.5, 0.5], |
||||
[3.0, -6.0, 1.0, 2.0], |
||||
[0.3, 0.5, -1.9, 1.1], |
||||
[5.0, 1.0, 2.0, -8.0] |
||||
], |
||||
[ |
||||
[-2.6, 0.6, 0.2, 1.8], |
||||
[2.0, -6.0, 3.0, 1.0], |
||||
[0.1, 0.5, -1.3, 0.7], |
||||
[0.8, 0.6, 0.2, -1.6] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); |
||||
let p = match pl.fit(&net, &data, 2, None) { |
||||
params::Params::DiscreteStatesContinousTime(p) => p, |
||||
}; |
||||
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); |
||||
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
||||
&arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.3, 0.2], |
||||
[0.5, -4.0, 2.5, 1.0], |
||||
[2.5, 0.5, -4.0, 1.0], |
||||
[0.7, 0.2, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 3.0, 1.0], |
||||
[1.5, -3.0, 0.5, 1.0], |
||||
[2.0, 1.3, -5.0, 1.7], |
||||
[2.5, 0.5, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.3, 0.3, 0.1, 0.9], |
||||
[1.4, -4.0, 0.5, 2.1], |
||||
[1.0, 1.5, -3.0, 0.5], |
||||
[0.4, 0.3, 0.1, -0.8] |
||||
], |
||||
[ |
||||
[-2.0, 1.0, 0.7, 0.3], |
||||
[1.3, -5.9, 2.7, 1.9], |
||||
[2.0, 1.5, -4.0, 0.5], |
||||
[0.2, 0.7, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 1.0, 2.0, 3.0], |
||||
[0.5, -3.0, 1.0, 1.5], |
||||
[1.4, 2.1, -4.3, 0.8], |
||||
[0.5, 1.0, 2.5, -4.0] |
||||
], |
||||
[ |
||||
[-1.3, 0.9, 0.3, 0.1], |
||||
[0.1, -1.3, 0.2, 1.0], |
||||
[0.5, 1.0, -3.0, 1.5], |
||||
[0.1, 0.4, 0.3, -0.8] |
||||
], |
||||
[ |
||||
[-2.0, 1.0, 0.6, 0.4], |
||||
[2.6, -7.1, 1.4, 3.1], |
||||
[5.0, 1.0, -8.0, 2.0], |
||||
[1.4, 0.4, 0.2, -2.0] |
||||
], |
||||
[ |
||||
[-3.0, 1.0, 1.5, 0.5], |
||||
[3.0, -6.0, 1.0, 2.0], |
||||
[0.3, 0.5, -1.9, 1.1], |
||||
[5.0, 1.0, 2.0, -8.0] |
||||
], |
||||
[ |
||||
[-2.6, 0.6, 0.2, 1.8], |
||||
[2.0, -6.0, 3.0, 1.0], |
||||
[0.1, 0.5, -1.3, 0.7], |
||||
[0.8, 0.6, 0.2, -1.6] |
||||
], |
||||
]), |
||||
0.2 |
||||
)); |
||||
} |
||||
|
||||
fn learn_mixed_discrete_cim_gen<T: ParameterLearning>(pl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
generate_nodes(&mut net, 2, 3); |
||||
net.add_node( |
||||
generate_discrete_time_continous_node( |
||||
String::from("3"), |
||||
4 |
||||
) |
||||
).unwrap(); |
||||
net.add_edge(0, 1); |
||||
net.add_edge(0, 2); |
||||
net.add_edge(1, 2); |
||||
|
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
1.0..8.0, |
||||
Some(6813071588535822) |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
|
||||
let p_gen = match net.get_node(2) { |
||||
DiscreteStatesContinousTime(p_gen) => p_gen, |
||||
}; |
||||
|
||||
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); |
||||
let p_tj = match pl.fit(&net, &data, 2, None) { |
||||
DiscreteStatesContinousTime(p_tj) => p_tj, |
||||
}; |
||||
|
||||
assert_eq!( |
||||
p_tj.get_cim().as_ref().unwrap().shape(), |
||||
p_gen.get_cim().as_ref().unwrap().shape() |
||||
); |
||||
assert!( |
||||
p_tj.get_cim().as_ref().unwrap().abs_diff_eq( |
||||
&p_gen.get_cim().as_ref().unwrap(), |
||||
0.2 |
||||
) |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_mixed_discrete_cim_MLE() { |
||||
let mle = MLE {}; |
||||
learn_mixed_discrete_cim(mle); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_mixed_discrete_cim_MLE_gen() { |
||||
let mle = MLE {}; |
||||
learn_mixed_discrete_cim_gen(mle); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_mixed_discrete_cim_BA() { |
||||
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; |
||||
learn_mixed_discrete_cim(ba); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_mixed_discrete_cim_BA_gen() { |
||||
let ba = BayesianApproach { alpha: 1, tau: 1.0 }; |
||||
learn_mixed_discrete_cim_gen(ba); |
||||
} |
@ -0,0 +1,148 @@ |
||||
use ndarray::prelude::*; |
||||
use rand_chacha::rand_core::SeedableRng; |
||||
use rand_chacha::ChaCha8Rng; |
||||
use reCTBN::params::{ParamsTrait, *}; |
||||
|
||||
mod utils; |
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { |
||||
#![allow(unused_must_use)] |
||||
let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3); |
||||
|
||||
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; |
||||
|
||||
params.set_cim(cim); |
||||
params |
||||
} |
||||
|
||||
#[test] |
||||
fn test_get_label() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
assert_eq!(&String::from("A"), param.get_label()) |
||||
} |
||||
|
||||
#[test] |
||||
fn test_uniform_generation() { |
||||
#![allow(irrefutable_let_patterns)] |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<usize>::zeros(10000); |
||||
|
||||
let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); |
||||
|
||||
states.mapv_inplace(|_| { |
||||
if let StateType::Discrete(val) = param.get_random_state_uniform(&mut rng) { |
||||
val |
||||
} else { |
||||
panic!() |
||||
} |
||||
}); |
||||
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
||||
|
||||
assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_random_generation_state() { |
||||
#![allow(irrefutable_let_patterns)] |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<usize>::zeros(10000); |
||||
|
||||
let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); |
||||
|
||||
states.mapv_inplace(|_| { |
||||
if let StateType::Discrete(val) = param.get_random_state(1, 0, &mut rng).unwrap() { |
||||
val |
||||
} else { |
||||
panic!() |
||||
} |
||||
}); |
||||
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; |
||||
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
||||
|
||||
assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01); |
||||
assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_random_generation_residence_time() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<f64>::zeros(10000); |
||||
|
||||
let mut rng = ChaCha8Rng::seed_from_u64(6347747169756259); |
||||
|
||||
states.mapv_inplace(|_| param.get_random_residence_time(1, 0, &mut rng).unwrap()); |
||||
|
||||
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_validate_params_valid_cim() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
|
||||
assert_eq!(Ok(()), param.validate_params()); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_validate_params_valid_cim_with_huge_values() { |
||||
let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); |
||||
let cim = array![[ |
||||
[-2e10, 1e10, 1e10], |
||||
[1.5e10, -3e10, 1.5e10], |
||||
[1e10, 1e10, -2e10] |
||||
]]; |
||||
let result = param.set_cim(cim); |
||||
assert_eq!(Ok(()), result); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_validate_params_cim_not_initialized() { |
||||
let param = utils::generate_discrete_time_continous_params("A".to_string(), 3); |
||||
assert_eq!( |
||||
Err(ParamsError::ParametersNotInitialized(String::from( |
||||
"CIM not initialized", |
||||
))), |
||||
param.validate_params() |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_validate_params_wrong_shape() { |
||||
let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 4); |
||||
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; |
||||
let result = param.set_cim(cim); |
||||
assert_eq!( |
||||
Err(ParamsError::InvalidCIM(String::from( |
||||
"Incompatible shape [1, 3, 3] with domain 4" |
||||
))), |
||||
result |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_validate_params_positive_diag() { |
||||
let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); |
||||
let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; |
||||
let result = param.set_cim(cim); |
||||
assert_eq!( |
||||
Err(ParamsError::InvalidCIM(String::from( |
||||
"The diagonal of each cim must be non-positive", |
||||
))), |
||||
result |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_validate_params_row_not_sum_to_zero() { |
||||
let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); |
||||
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; |
||||
let result = param.set_cim(cim); |
||||
assert_eq!( |
||||
Err(ParamsError::InvalidCIM(String::from( |
||||
"The sum of each row must be 0" |
||||
))), |
||||
result |
||||
); |
||||
} |
@ -0,0 +1,122 @@ |
||||
mod utils; |
||||
|
||||
use approx::assert_abs_diff_eq; |
||||
use ndarray::*; |
||||
use reCTBN::{ |
||||
params, |
||||
process::{ctbn::*, NetworkProcess, NetworkProcessState}, |
||||
reward::{reward_evaluation::*, reward_function::*, *}, |
||||
}; |
||||
use utils::generate_discrete_time_continous_node; |
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_binary_node_mc() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1) |
||||
.assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1) |
||||
.assign(&arr1(&[3.0, 3.0])); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); |
||||
} |
||||
} |
||||
|
||||
net.initialize_adj_matrix(); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
|
||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); |
||||
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); |
||||
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); |
||||
|
||||
let rst = mc.evaluate_state_space(&net, &rf); |
||||
assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); |
||||
assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); |
||||
|
||||
|
||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215)); |
||||
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); |
||||
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); |
||||
|
||||
|
||||
} |
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_chain_mc() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
|
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
||||
.unwrap(); |
||||
|
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n2, n3); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param |
||||
.set_cim(arr3(&[ |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-5.0, 5.0], [0.01, -0.01]], |
||||
])) |
||||
.unwrap(); |
||||
} |
||||
} |
||||
|
||||
|
||||
match &mut net.get_node_mut(n3) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param |
||||
.set_cim(arr3(&[ |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-5.0, 5.0], [0.01, -0.01]], |
||||
])) |
||||
.unwrap(); |
||||
} |
||||
} |
||||
|
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1) |
||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||
|
||||
rf.get_transition_reward_mut(n2) |
||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||
|
||||
rf.get_transition_reward_mut(n3) |
||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||
|
||||
let s000: NetworkProcessState = vec![ |
||||
params::StateType::Discrete(1), |
||||
params::StateType::Discrete(0), |
||||
params::StateType::Discrete(0), |
||||
]; |
||||
|
||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); |
||||
assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); |
||||
|
||||
let rst = mc.evaluate_state_space(&net, &rf); |
||||
assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1); |
||||
|
||||
} |
@ -0,0 +1,117 @@ |
||||
mod utils; |
||||
|
||||
use ndarray::*; |
||||
use utils::generate_discrete_time_continous_node; |
||||
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params}; |
||||
|
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_binary_node() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||
|
||||
assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_ternary_node() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; |
||||
|
||||
|
||||
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||
assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); |
||||
assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); |
||||
} |
||||
|
||||
#[test] |
||||
fn factored_reward_function_two_nodes() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||
|
||||
|
||||
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); |
||||
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; |
||||
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; |
||||
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; |
||||
|
||||
|
||||
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; |
||||
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; |
||||
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; |
||||
|
||||
assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||
assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); |
||||
assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); |
||||
assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); |
||||
assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||
assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); |
||||
assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); |
||||
assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); |
||||
assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); |
||||
} |
@ -0,0 +1,692 @@ |
||||
#![allow(non_snake_case)] |
||||
|
||||
mod utils; |
||||
use std::collections::BTreeSet; |
||||
|
||||
use ndarray::{arr1, arr2, arr3}; |
||||
use reCTBN::process::ctbn::*; |
||||
use reCTBN::process::NetworkProcess; |
||||
use reCTBN::parameter_learning::BayesianApproach; |
||||
use reCTBN::params; |
||||
use reCTBN::structure_learning::hypothesis_test::*; |
||||
use reCTBN::structure_learning::constraint_based_algorithm::*; |
||||
use reCTBN::structure_learning::score_based_algorithm::*; |
||||
use reCTBN::structure_learning::score_function::*; |
||||
use reCTBN::structure_learning::StructureLearningAlgorithm; |
||||
use reCTBN::tools::*; |
||||
use utils::*; |
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
#[test] |
||||
fn simple_score_test() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]])); |
||||
|
||||
let dataset = Dataset::new(vec![trj]); |
||||
|
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
|
||||
assert_abs_diff_eq!( |
||||
0.04257, |
||||
ll.call(&net, n1, &BTreeSet::new(), &dataset), |
||||
epsilon = 1e-3 |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn simple_bic() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]])); |
||||
|
||||
let dataset = Dataset::new(vec![trj]); |
||||
let bic = BIC::new(1, 1.0); |
||||
|
||||
assert_abs_diff_eq!( |
||||
-0.65058, |
||||
bic.call(&net, n1, &BTreeSet::new(), &dataset), |
||||
epsilon = 1e-3 |
||||
); |
||||
} |
||||
|
||||
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-3.0, 2.0, 1.0], |
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.5], |
||||
[3.0, -4.0, 1.0], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 4.0], |
||||
[1.5, -2.0, 0.5], |
||||
[3.0, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.0, 0.1, 0.9], |
||||
[2.0, -2.5, 0.5], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259)); |
||||
|
||||
let mut net = CtbnNetwork::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let _net = sl.fit_transform(net, &data); |
||||
} |
||||
|
||||
fn generate_nodes( |
||||
net: &mut CtbnNetwork, |
||||
nodes_cardinality: usize, |
||||
nodes_domain_cardinality: usize |
||||
) { |
||||
for node_label in 0..nodes_cardinality { |
||||
net.add_node( |
||||
generate_discrete_time_continous_node( |
||||
node_label.to_string(), |
||||
nodes_domain_cardinality, |
||||
) |
||||
).unwrap(); |
||||
} |
||||
} |
||||
|
||||
fn check_compatibility_between_dataset_and_network_gen<T: StructureLearningAlgorithm>(sl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
generate_nodes(&mut net, 2, 3); |
||||
net.add_node( |
||||
generate_discrete_time_continous_node( |
||||
String::from("3"), |
||||
4 |
||||
) |
||||
).unwrap(); |
||||
|
||||
net.add_edge(0, 1); |
||||
|
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
0.0..7.0, |
||||
Some(6813071588535822) |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
|
||||
let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259)); |
||||
|
||||
let mut net = CtbnNetwork::new(); |
||||
let _n1 = net |
||||
.add_node( |
||||
generate_discrete_time_continous_node(String::from("0"), |
||||
3) |
||||
).unwrap(); |
||||
let _net = sl.fit_transform(net, &data); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
pub fn check_compatibility_between_dataset_and_network_hill_climbing() { |
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
let hl = HillClimbing::new(ll, None); |
||||
check_compatibility_between_dataset_and_network(hl); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() { |
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
let hl = HillClimbing::new(ll, None); |
||||
check_compatibility_between_dataset_and_network_gen(hl); |
||||
} |
||||
|
||||
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-3.0, 2.0, 1.0], |
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.5], |
||||
[3.0, -4.0, 1.0], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 4.0], |
||||
[1.5, -2.0, 0.5], |
||||
[3.0, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.0, 0.1, 0.9], |
||||
[2.0, -2.5, 0.5], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); |
||||
|
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); |
||||
} |
||||
|
||||
fn learn_ternary_net_2_nodes_gen<T: StructureLearningAlgorithm>(sl: T) { |
||||
let mut net = CtbnNetwork::new(); |
||||
generate_nodes(&mut net, 2, 3); |
||||
|
||||
net.add_edge(0, 1); |
||||
|
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
0.0..7.0, |
||||
Some(6813071588535822) |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
|
||||
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); |
||||
|
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { |
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
let hl = HillClimbing::new(ll, None); |
||||
learn_ternary_net_2_nodes(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() { |
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
let hl = HillClimbing::new(ll, None); |
||||
learn_ternary_net_2_nodes_gen(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { |
||||
let bic = BIC::new(1, 1.0); |
||||
let hl = HillClimbing::new(bic, None); |
||||
learn_ternary_net_2_nodes(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() { |
||||
let bic = BIC::new(1, 1.0); |
||||
let hl = HillClimbing::new(bic, None); |
||||
learn_ternary_net_2_nodes_gen(hl); |
||||
} |
||||
|
||||
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) |
||||
.unwrap(); |
||||
|
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 4)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n1, n3); |
||||
net.add_edge(n2, n3); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-3.0, 2.0, 1.0], |
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.5], |
||||
[3.0, -4.0, 1.0], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 4.0], |
||||
[1.5, -2.0, 0.5], |
||||
[3.0, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.0, 0.1, 0.9], |
||||
[2.0, -2.5, 0.5], |
||||
[0.9, 0.1, -1.0] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n3) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
assert_eq!( |
||||
Ok(()), |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 0.5, 0.3, 0.2], |
||||
[0.5, -4.0, 2.5, 1.0], |
||||
[2.5, 0.5, -4.0, 1.0], |
||||
[0.7, 0.2, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 2.0, 3.0, 1.0], |
||||
[1.5, -3.0, 0.5, 1.0], |
||||
[2.0, 1.3, -5.0, 1.7], |
||||
[2.5, 0.5, 1.0, -4.0] |
||||
], |
||||
[ |
||||
[-1.3, 0.3, 0.1, 0.9], |
||||
[1.4, -4.0, 0.5, 2.1], |
||||
[1.0, 1.5, -3.0, 0.5], |
||||
[0.4, 0.3, 0.1, -0.8] |
||||
], |
||||
[ |
||||
[-2.0, 1.0, 0.7, 0.3], |
||||
[1.3, -5.9, 2.7, 1.9], |
||||
[2.0, 1.5, -4.0, 0.5], |
||||
[0.2, 0.7, 0.1, -1.0] |
||||
], |
||||
[ |
||||
[-6.0, 1.0, 2.0, 3.0], |
||||
[0.5, -3.0, 1.0, 1.5], |
||||
[1.4, 2.1, -4.3, 0.8], |
||||
[0.5, 1.0, 2.5, -4.0] |
||||
], |
||||
[ |
||||
[-1.3, 0.9, 0.3, 0.1], |
||||
[0.1, -1.3, 0.2, 1.0], |
||||
[0.5, 1.0, -3.0, 1.5], |
||||
[0.1, 0.4, 0.3, -0.8] |
||||
], |
||||
[ |
||||
[-2.0, 1.0, 0.6, 0.4], |
||||
[2.6, -7.1, 1.4, 3.1], |
||||
[5.0, 1.0, -8.0, 2.0], |
||||
[1.4, 0.4, 0.2, -2.0] |
||||
], |
||||
[ |
||||
[-3.0, 1.0, 1.5, 0.5], |
||||
[3.0, -6.0, 1.0, 2.0], |
||||
[0.3, 0.5, -1.9, 1.1], |
||||
[5.0, 1.0, 2.0, -8.0] |
||||
], |
||||
[ |
||||
[-2.6, 0.6, 0.2, 1.8], |
||||
[2.0, -6.0, 3.0, 1.0], |
||||
[0.1, 0.5, -1.3, 0.7], |
||||
[0.8, 0.6, 0.2, -1.6] |
||||
], |
||||
])) |
||||
); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); |
||||
return (net, data); |
||||
} |
||||
|
||||
fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) { |
||||
let mut net = CtbnNetwork::new(); |
||||
generate_nodes(&mut net, 2, 3); |
||||
net.add_node( |
||||
generate_discrete_time_continous_node( |
||||
String::from("3"), |
||||
4 |
||||
) |
||||
).unwrap(); |
||||
|
||||
net.add_edge(0, 1); |
||||
net.add_edge(0, 2); |
||||
net.add_edge(1, 2); |
||||
|
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
0.0..7.0, |
||||
Some(6813071588535822) |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
|
||||
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); |
||||
return (net, data); |
||||
} |
||||
|
||||
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) { |
||||
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); |
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); |
||||
} |
||||
|
||||
fn learn_mixed_discrete_net_3_nodes_gen<T: StructureLearningAlgorithm>(sl: T) { |
||||
let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); |
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { |
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
let hl = HillClimbing::new(ll, None); |
||||
learn_mixed_discrete_net_3_nodes(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() { |
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
let hl = HillClimbing::new(ll, None); |
||||
learn_mixed_discrete_net_3_nodes_gen(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { |
||||
let bic = BIC::new(1, 1.0); |
||||
let hl = HillClimbing::new(bic, None); |
||||
learn_mixed_discrete_net_3_nodes(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() { |
||||
let bic = BIC::new(1, 1.0); |
||||
let hl = HillClimbing::new(bic, None); |
||||
learn_mixed_discrete_net_3_nodes_gen(hl); |
||||
} |
||||
|
||||
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) { |
||||
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); |
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); |
||||
} |
||||
|
||||
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen<T: StructureLearningAlgorithm>(sl: T) { |
||||
let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen(); |
||||
let net = sl.fit_transform(net, &data); |
||||
assert_eq!(BTreeSet::new(), net.get_parent_set(0)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1)); |
||||
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { |
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
let hl = HillClimbing::new(ll, Some(1)); |
||||
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() { |
||||
let ll = LogLikelihood::new(1, 1.0); |
||||
let hl = HillClimbing::new(ll, Some(1)); |
||||
learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() { |
||||
let bic = BIC::new(1, 1.0); |
||||
let hl = HillClimbing::new(bic, Some(1)); |
||||
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() { |
||||
let bic = BIC::new(1, 1.0); |
||||
let hl = HillClimbing::new(bic, Some(1)); |
||||
learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn chi_square_compare_matrices() { |
||||
let i: usize = 1; |
||||
let M1 = arr3(&[ |
||||
[ |
||||
[ 0, 2, 3], |
||||
[ 4, 0, 6], |
||||
[ 7, 8, 0] |
||||
], |
||||
[ |
||||
[0, 12, 90], |
||||
[ 3, 0, 40], |
||||
[ 6, 40, 0] |
||||
], |
||||
[ |
||||
[ 0, 2, 3], |
||||
[ 4, 0, 6], |
||||
[ 44, 66, 0] |
||||
], |
||||
]); |
||||
let j: usize = 0; |
||||
let M2 = arr3(&[ |
||||
[ |
||||
[ 0, 200, 300], |
||||
[ 400, 0, 600], |
||||
[ 700, 800, 0] |
||||
], |
||||
]); |
||||
let chi_sq = ChiSquare::new(1e-4); |
||||
assert!(!chi_sq.compare_matrices(i, &M1, j, &M2)); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn chi_square_compare_matrices_2() { |
||||
let i: usize = 1; |
||||
let M1 = arr3(&[ |
||||
[ |
||||
[ 0, 2, 3], |
||||
[ 4, 0, 6], |
||||
[ 7, 8, 0] |
||||
], |
||||
[ |
||||
[0, 20, 30], |
||||
[ 40, 0, 60], |
||||
[ 70, 80, 0] |
||||
], |
||||
[ |
||||
[ 0, 2, 3], |
||||
[ 4, 0, 6], |
||||
[ 44, 66, 0] |
||||
], |
||||
]); |
||||
let j: usize = 0; |
||||
let M2 = arr3(&[ |
||||
[[ 0, 200, 300], |
||||
[ 400, 0, 600], |
||||
[ 700, 800, 0]] |
||||
]); |
||||
let chi_sq = ChiSquare::new(1e-4); |
||||
assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn chi_square_compare_matrices_3() { |
||||
let i: usize = 1; |
||||
let M1 = arr3(&[ |
||||
[ |
||||
[ 0, 2, 3], |
||||
[ 4, 0, 6], |
||||
[ 7, 8, 0] |
||||
], |
||||
[ |
||||
[0, 21, 31], |
||||
[ 41, 0, 59], |
||||
[ 71, 79, 0] |
||||
], |
||||
[ |
||||
[ 0, 2, 3], |
||||
[ 4, 0, 6], |
||||
[ 44, 66, 0] |
||||
], |
||||
]); |
||||
let j: usize = 0; |
||||
let M2 = arr3(&[ |
||||
[ |
||||
[ 0, 200, 300], |
||||
[ 400, 0, 600], |
||||
[ 700, 800, 0] |
||||
], |
||||
]); |
||||
let chi_sq = ChiSquare::new(1e-4); |
||||
assert!(chi_sq.compare_matrices(i, &M1, j, &M2)); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
pub fn chi_square_call() { |
||||
|
||||
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); |
||||
let N3: usize = 2; |
||||
let N2: usize = 1; |
||||
let N1: usize = 0; |
||||
let mut separation_set = BTreeSet::new(); |
||||
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; |
||||
let mut cache = Cache::new(¶meter_learning); |
||||
let chi_sq = ChiSquare::new(1e-4); |
||||
|
||||
assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache)); |
||||
let mut cache = Cache::new(¶meter_learning); |
||||
assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache)); |
||||
assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache)); |
||||
separation_set.insert(N1); |
||||
let mut cache = Cache::new(¶meter_learning); |
||||
assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache)); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn f_call() { |
||||
|
||||
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); |
||||
let N3: usize = 2; |
||||
let N2: usize = 1; |
||||
let N1: usize = 0; |
||||
let mut separation_set = BTreeSet::new(); |
||||
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; |
||||
let mut cache = Cache::new(¶meter_learning); |
||||
let f = F::new(1e-6); |
||||
|
||||
|
||||
assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache)); |
||||
let mut cache = Cache::new(¶meter_learning); |
||||
assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache)); |
||||
assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache)); |
||||
separation_set.insert(N1); |
||||
let mut cache = Cache::new(¶meter_learning); |
||||
assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache)); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_ternary_net_2_nodes_ctpc() { |
||||
let f = F::new(1e-6); |
||||
let chi_sq = ChiSquare::new(1e-4); |
||||
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; |
||||
let ctpc = CTPC::new(parameter_learning, f, chi_sq); |
||||
learn_ternary_net_2_nodes(ctpc); |
||||
} |
||||
|
||||
#[test] |
||||
pub fn learn_ternary_net_2_nodes_ctpc_gen() { |
||||
let f = F::new(1e-6); |
||||
let chi_sq = ChiSquare::new(1e-4); |
||||
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; |
||||
let ctpc = CTPC::new(parameter_learning, f, chi_sq); |
||||
learn_ternary_net_2_nodes_gen(ctpc); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_mixed_discrete_net_3_nodes_ctpc() { |
||||
let f = F::new(1e-6); |
||||
let chi_sq = ChiSquare::new(1e-4); |
||||
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; |
||||
let ctpc = CTPC::new(parameter_learning, f, chi_sq); |
||||
learn_mixed_discrete_net_3_nodes(ctpc); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_mixed_discrete_net_3_nodes_ctpc_gen() { |
||||
let f = F::new(1e-6); |
||||
let chi_sq = ChiSquare::new(1e-4); |
||||
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 }; |
||||
let ctpc = CTPC::new(parameter_learning, f, chi_sq); |
||||
learn_mixed_discrete_net_3_nodes_gen(ctpc); |
||||
} |
@ -0,0 +1,251 @@ |
||||
use std::ops::Range; |
||||
|
||||
use ndarray::{arr1, arr2, arr3}; |
||||
use reCTBN::params::ParamsTrait; |
||||
use reCTBN::process::ctbn::*; |
||||
use reCTBN::process::ctmp::*; |
||||
use reCTBN::process::NetworkProcess; |
||||
use reCTBN::params; |
||||
use reCTBN::tools::*; |
||||
|
||||
use utils::*; |
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
mod utils; |
||||
|
||||
#[test] |
||||
fn run_sampling() { |
||||
#![allow(unused_must_use)] |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(utils::generate_discrete_time_continous_node( |
||||
String::from("n1"), |
||||
2, |
||||
)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(utils::generate_discrete_time_continous_node( |
||||
String::from("n2"), |
||||
2, |
||||
)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-3.0, 3.0], |
||||
[2.0, -2.0] |
||||
], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.set_cim(arr3(&[ |
||||
[ |
||||
[-1.0, 1.0], |
||||
[4.0, -4.0] |
||||
], |
||||
[ |
||||
[-6.0, 6.0], |
||||
[2.0, -2.0] |
||||
], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259)); |
||||
|
||||
assert_eq!(4, data.get_trajectories().len()); |
||||
assert_relative_eq!( |
||||
1.0, |
||||
data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len() - 1] |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn trajectory_wrong_shape() { |
||||
let time = arr1(&[0.0, 0.2]); |
||||
let events = arr2(&[[0, 3]]); |
||||
Trajectory::new(time, events); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn dataset_wrong_shape() { |
||||
let time = arr1(&[0.0, 0.2]); |
||||
let events = arr2(&[[0, 3], [1, 2]]); |
||||
let t1 = Trajectory::new(time, events); |
||||
|
||||
let time = arr1(&[0.0, 0.2]); |
||||
let events = arr2(&[[0, 3, 3], [1, 2, 3]]); |
||||
let t2 = Trajectory::new(time, events); |
||||
Dataset::new(vec![t1, t2]); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn uniform_graph_generator_wrong_density_1() { |
||||
let density = 2.1; |
||||
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( |
||||
density, |
||||
None |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn uniform_graph_generator_wrong_density_2() { |
||||
let density = -0.5; |
||||
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( |
||||
density, |
||||
None |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn uniform_graph_generator_right_densities() { |
||||
for density in [1.0, 0.75, 0.5, 0.25, 0.0] { |
||||
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( |
||||
density, |
||||
None |
||||
); |
||||
} |
||||
} |
||||
|
||||
#[test] |
||||
fn uniform_graph_generator_generate_graph_ctbn() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let nodes_cardinality = 0..=100; |
||||
let nodes_domain_cardinality = 2; |
||||
for node_label in nodes_cardinality { |
||||
net.add_node( |
||||
utils::generate_discrete_time_continous_node( |
||||
node_label.to_string(), |
||||
nodes_domain_cardinality, |
||||
) |
||||
).unwrap(); |
||||
} |
||||
let density = 1.0/3.0; |
||||
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( |
||||
density, |
||||
Some(7641630759785120) |
||||
); |
||||
structure_generator.generate_graph(&mut net); |
||||
let mut edges = 0; |
||||
for node in net.get_node_indices(){ |
||||
edges += net.get_children_set(node).len() |
||||
} |
||||
let nodes = net.get_node_indices().len() as f64; |
||||
let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize; |
||||
let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% of tolerance
|
||||
// As the way `generate_graph()` is implemented we can only reasonably
|
||||
// expect the number of edges to be somewhere around the expected value.
|
||||
assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance)); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn uniform_graph_generator_generate_graph_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let node_label = String::from("0"); |
||||
let node_domain_cardinality = 4; |
||||
net.add_node( |
||||
generate_discrete_time_continous_node( |
||||
node_label, |
||||
node_domain_cardinality |
||||
) |
||||
).unwrap(); |
||||
let density = 1.0/3.0; |
||||
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( |
||||
density, |
||||
Some(7641630759785120) |
||||
); |
||||
structure_generator.generate_graph(&mut net); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn uniform_parameters_generator_wrong_density_1() { |
||||
let interval: Range<f64> = -2.0..-5.0; |
||||
let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
interval, |
||||
None |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn uniform_parameters_generator_wrong_density_2() { |
||||
let interval: Range<f64> = -1.0..0.0; |
||||
let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
interval, |
||||
None |
||||
); |
||||
} |
||||
|
||||
#[test] |
||||
fn uniform_parameters_generator_right_densities_ctbn() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let nodes_cardinality = 0..=3; |
||||
let nodes_domain_cardinality = 9; |
||||
for node_label in nodes_cardinality { |
||||
net.add_node( |
||||
generate_discrete_time_continous_node( |
||||
node_label.to_string(), |
||||
nodes_domain_cardinality, |
||||
) |
||||
).unwrap(); |
||||
} |
||||
let density = 1.0/3.0; |
||||
let seed = Some(7641630759785120); |
||||
let interval = 0.0..7.0; |
||||
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new( |
||||
density, |
||||
seed |
||||
); |
||||
structure_generator.generate_graph(&mut net); |
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
interval, |
||||
seed |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
for node in net.get_node_indices() { |
||||
assert_eq!( |
||||
Ok(()), |
||||
net.get_node(node).validate_params() |
||||
); |
||||
} |
||||
} |
||||
|
||||
#[test] |
||||
fn uniform_parameters_generator_right_densities_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let node_label = String::from("0"); |
||||
let node_domain_cardinality = 4; |
||||
net.add_node( |
||||
generate_discrete_time_continous_node( |
||||
node_label, |
||||
node_domain_cardinality |
||||
) |
||||
).unwrap(); |
||||
let seed = Some(7641630759785120); |
||||
let interval = 0.0..7.0; |
||||
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new( |
||||
interval, |
||||
seed |
||||
); |
||||
cim_generator.generate_parameters(&mut net); |
||||
for node in net.get_node_indices() { |
||||
assert_eq!( |
||||
Ok(()), |
||||
net.get_node(node).validate_params() |
||||
); |
||||
} |
||||
} |
@ -0,0 +1,19 @@ |
||||
use std::collections::BTreeSet; |
||||
|
||||
use reCTBN::params; |
||||
|
||||
#[allow(dead_code)] |
||||
pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params { |
||||
params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params( |
||||
label, |
||||
cardinality, |
||||
)) |
||||
} |
||||
|
||||
pub fn generate_discrete_time_continous_params( |
||||
label: String, |
||||
cardinality: usize, |
||||
) -> params::DiscreteStatesContinousTimeParams { |
||||
let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect(); |
||||
params::DiscreteStatesContinousTimeParams::new(label, domain) |
||||
} |
@ -0,0 +1,7 @@ |
||||
# This file defines the Rust toolchain to use when a command is executed. |
||||
# See also https://rust-lang.github.io/rustup/overrides.html |
||||
|
||||
[toolchain] |
||||
channel = "stable" |
||||
components = [ "clippy", "rustfmt" ] |
||||
profile = "minimal" |
@ -0,0 +1,39 @@ |
||||
# This file defines the Rust style for automatic reformatting. |
||||
# See also https://rust-lang.github.io/rustfmt |
||||
|
||||
# NOTE: the unstable options will be uncommented when stabilized. |
||||
|
||||
# Version of the formatting rules to use. |
||||
#version = "One" |
||||
|
||||
# Number of spaces per tab. |
||||
tab_spaces = 4 |
||||
|
||||
max_width = 100 |
||||
#comment_width = 80 |
||||
|
||||
# Prevent carriage returns, admitted only \n. |
||||
newline_style = "Unix" |
||||
|
||||
# The "Default" setting has a heuristic which can split lines too aggresively. |
||||
#use_small_heuristics = "Max" |
||||
|
||||
# How imports should be grouped into `use` statements. |
||||
#imports_granularity = "Module" |
||||
|
||||
# How consecutive imports are grouped together. |
||||
#group_imports = "StdExternalCrate" |
||||
|
||||
# Error if unable to get all lines within max_width, except for comments and |
||||
# string literals. |
||||
#error_on_line_overflow = true |
||||
|
||||
# Error if unable to get comments or string literals within max_width, or they |
||||
# are left with trailing whitespaces. |
||||
#error_on_unformatted = true |
||||
|
||||
# Files to ignore like third party code which is formatted upstream. |
||||
# Ignoring tests is a temporary measure due some issues regarding rank-3 tensors |
||||
ignore = [ |
||||
"tests/" |
||||
] |
@ -1,164 +0,0 @@ |
||||
use ndarray::prelude::*; |
||||
use crate::node; |
||||
use crate::params::{StateType, ParamsTrait}; |
||||
use crate::network; |
||||
use std::collections::BTreeSet; |
||||
|
||||
|
||||
|
||||
|
||||
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
|
||||
///composed by the following elements:
|
||||
///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
|
||||
///- **nodes**: a vector containing all the nodes and their parameters.
|
||||
///The index of a node inside the vector is also used as index for the adj_matrix.
|
||||
///
|
||||
///# Examples
|
||||
///
|
||||
///```
|
||||
///
|
||||
/// use std::collections::BTreeSet;
|
||||
/// use rustyCTBN::network::Network;
|
||||
/// use rustyCTBN::node;
|
||||
/// use rustyCTBN::params;
|
||||
/// use rustyCTBN::ctbn::*;
|
||||
///
|
||||
/// //Create the domain for a discrete node
|
||||
/// let mut domain = BTreeSet::new();
|
||||
/// domain.insert(String::from("A"));
|
||||
/// domain.insert(String::from("B"));
|
||||
///
|
||||
/// //Create the parameters for a discrete node using the domain
|
||||
/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
|
||||
///
|
||||
/// //Create the node using the parameters
|
||||
/// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1"));
|
||||
///
|
||||
/// let mut domain = BTreeSet::new();
|
||||
/// domain.insert(String::from("A"));
|
||||
/// domain.insert(String::from("B"));
|
||||
/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
|
||||
/// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2"));
|
||||
///
|
||||
/// //Initialize a ctbn
|
||||
/// let mut net = CtbnNetwork::init();
|
||||
///
|
||||
/// //Add nodes
|
||||
/// let X1 = net.add_node(X1).unwrap();
|
||||
/// let X2 = net.add_node(X2).unwrap();
|
||||
///
|
||||
/// //Add an edge
|
||||
/// net.add_edge(X1, X2);
|
||||
///
|
||||
/// //Get all the children of node X1
|
||||
/// let cs = net.get_children_set(X1);
|
||||
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
||||
/// ```
|
||||
pub struct CtbnNetwork { |
||||
adj_matrix: Option<Array2<u16>>, |
||||
nodes: Vec<node::Node> |
||||
} |
||||
|
||||
|
||||
impl CtbnNetwork { |
||||
pub fn init() -> CtbnNetwork { |
||||
CtbnNetwork { |
||||
adj_matrix: None, |
||||
nodes: Vec::new() |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl network::Network for CtbnNetwork { |
||||
fn initialize_adj_matrix(&mut self) { |
||||
self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f())); |
||||
|
||||
} |
||||
|
||||
fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> { |
||||
n.params.reset_params(); |
||||
self.adj_matrix = Option::None; |
||||
self.nodes.push(n); |
||||
Ok(self.nodes.len() -1)
|
||||
} |
||||
|
||||
fn add_edge(&mut self, parent: usize, child: usize) { |
||||
if let None = self.adj_matrix { |
||||
self.initialize_adj_matrix(); |
||||
} |
||||
|
||||
if let Some(network) = &mut self.adj_matrix { |
||||
network[[parent, child]] = 1; |
||||
self.nodes[child].params.reset_params(); |
||||
} |
||||
} |
||||
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize>{ |
||||
0..self.nodes.len() |
||||
} |
||||
|
||||
fn get_number_of_nodes(&self) -> usize { |
||||
self.nodes.len() |
||||
} |
||||
|
||||
fn get_node(&self, node_idx: usize) -> &node::Node{ |
||||
&self.nodes[node_idx] |
||||
} |
||||
|
||||
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{ |
||||
&mut self.nodes[node_idx] |
||||
} |
||||
|
||||
|
||||
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{ |
||||
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { |
||||
if x.1 > &0 { |
||||
acc.0 += self.nodes[x.0].params.state_to_index(¤t_state[x.0]) * acc.1; |
||||
acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent(); |
||||
} |
||||
acc |
||||
}).0 |
||||
} |
||||
|
||||
|
||||
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize { |
||||
parent_set.iter().fold((0, 1), |mut acc, x| { |
||||
acc.0 += self.nodes[*x].params.state_to_index(¤t_state[*x]) * acc.1; |
||||
acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent(); |
||||
acc |
||||
}).0 |
||||
} |
||||
|
||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.adj_matrix.as_ref() |
||||
.unwrap() |
||||
.column(node) |
||||
.iter() |
||||
.enumerate() |
||||
.filter_map(|(idx, x)| { |
||||
if x > &0 { |
||||
Some(idx) |
||||
} else { |
||||
None |
||||
} |
||||
}).collect() |
||||
} |
||||
|
||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize>{ |
||||
self.adj_matrix.as_ref() |
||||
.unwrap() |
||||
.row(node) |
||||
.iter() |
||||
.enumerate() |
||||
.filter_map(|(idx, x)| { |
||||
if x > &0 { |
||||
Some(idx) |
||||
} else { |
||||
None |
||||
} |
||||
}).collect() |
||||
} |
||||
|
||||
} |
||||
|
@ -1,11 +0,0 @@ |
||||
#[cfg(test)] |
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
pub mod node; |
||||
pub mod params; |
||||
pub mod network; |
||||
pub mod ctbn; |
||||
pub mod tools; |
||||
pub mod parameter_learning; |
||||
|
@ -1,39 +0,0 @@ |
||||
use thiserror::Error; |
||||
use crate::params; |
||||
use crate::node; |
||||
use std::collections::BTreeSet; |
||||
|
||||
/// Error types for trait Network
|
||||
#[derive(Error, Debug)] |
||||
pub enum NetworkError { |
||||
#[error("Error during node insertion")] |
||||
NodeInsertionError(String) |
||||
} |
||||
|
||||
|
||||
///Network
|
||||
///The Network trait define the required methods for a structure used as pgm (such as ctbn).
|
||||
pub trait Network { |
||||
fn initialize_adj_matrix(&mut self); |
||||
fn add_node(&mut self, n: node::Node) -> Result<usize, NetworkError>; |
||||
fn add_edge(&mut self, parent: usize, child: usize); |
||||
|
||||
///Get all the indices of the nodes contained inside the network
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize>; |
||||
fn get_number_of_nodes(&self) -> usize; |
||||
fn get_node(&self, node_idx: usize) -> &node::Node; |
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node; |
||||
|
||||
///Compute the index that must be used to access the parameters of a node given a specific
|
||||
///configuration of the network. Usually, the only values really used in *current_state* are
|
||||
///the ones in the parent set of the *node*.
|
||||
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize; |
||||
|
||||
|
||||
///Compute the index that must be used to access the parameters of a node given a specific
|
||||
///configuration of the network and a generic parent_set. Usually, the only values really used
|
||||
///in *current_state* are the ones in the parent set of the *node*.
|
||||
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<params::StateType>, parent_set: &BTreeSet<usize>) -> usize; |
||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; |
||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; |
||||
} |
@ -1,25 +0,0 @@ |
||||
use crate::params::*; |
||||
|
||||
|
||||
pub struct Node { |
||||
pub params: Params, |
||||
pub label: String |
||||
} |
||||
|
||||
impl Node { |
||||
pub fn init(params: Params, label: String) -> Node { |
||||
Node{ |
||||
params: params, |
||||
label:label |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
impl PartialEq for Node { |
||||
fn eq(&self, other: &Node) -> bool{ |
||||
self.label == other.label |
||||
} |
||||
} |
||||
|
||||
|
@ -1,161 +0,0 @@ |
||||
use ndarray::prelude::*; |
||||
use rand::Rng; |
||||
use std::collections::{BTreeSet, HashMap}; |
||||
use thiserror::Error; |
||||
use enum_dispatch::enum_dispatch; |
||||
|
||||
/// Error types for trait Params
|
||||
#[derive(Error, Debug)] |
||||
pub enum ParamsError { |
||||
#[error("Unsupported method")] |
||||
UnsupportedMethod(String), |
||||
#[error("Paramiters not initialized")] |
||||
ParametersNotInitialized(String), |
||||
} |
||||
|
||||
/// Allowed type of states
|
||||
#[derive(Clone)] |
||||
pub enum StateType { |
||||
Discrete(usize), |
||||
} |
||||
|
||||
/// Parameters
|
||||
/// The Params trait is the core element for building different types of nodes. The goal is to
|
||||
/// define the set of method required to describes a generic node.
|
||||
#[enum_dispatch(Params)] |
||||
pub trait ParamsTrait { |
||||
fn reset_params(&mut self); |
||||
|
||||
/// Randomly generate a possible state of the node disregarding the state of the node and it's
|
||||
/// parents.
|
||||
fn get_random_state_uniform(&self) -> StateType; |
||||
|
||||
/// Randomly generate a residence time for the given node taking into account the node state
|
||||
/// and its parent set.
|
||||
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>; |
||||
|
||||
/// Randomly generate a possible state for the given node taking into account the node state
|
||||
/// and its parent set.
|
||||
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError>; |
||||
|
||||
/// Used by childern of the node described by this parameters to reserve spaces in their CIMs.
|
||||
fn get_reserved_space_as_parent(&self) -> usize; |
||||
|
||||
/// Index used by discrete node to represents their states as usize.
|
||||
fn state_to_index(&self, state: &StateType) -> usize; |
||||
} |
||||
|
||||
/// The Params enum is the core element for building different types of nodes. The goal is to
|
||||
/// define all the supported type of parameters.
|
||||
#[enum_dispatch] |
||||
pub enum Params { |
||||
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams), |
||||
} |
||||
|
||||
|
||||
/// DiscreteStatesContinousTime.
|
||||
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
|
||||
/// following elements:
|
||||
/// - **domain**: an ordered and exhaustive set of possible states
|
||||
/// - **cim**: Conditional Intensity Matrix
|
||||
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
|
||||
/// learning task and are composed by:
|
||||
/// - **transitions**: number of transitions from one state to another given a specific
|
||||
/// realization of the parent set
|
||||
/// - **residence_time**: permanence time in each possible states given a specific
|
||||
/// realization of the parent set
|
||||
pub struct DiscreteStatesContinousTimeParams { |
||||
pub domain: BTreeSet<String>, |
||||
pub cim: Option<Array3<f64>>, |
||||
pub transitions: Option<Array3<u64>>, |
||||
pub residence_time: Option<Array2<f64>>, |
||||
} |
||||
|
||||
impl DiscreteStatesContinousTimeParams { |
||||
pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams { |
||||
DiscreteStatesContinousTimeParams { |
||||
domain: domain, |
||||
cim: Option::None, |
||||
transitions: Option::None, |
||||
residence_time: Option::None, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl ParamsTrait for DiscreteStatesContinousTimeParams { |
||||
fn reset_params(&mut self) { |
||||
self.cim = Option::None; |
||||
self.transitions = Option::None; |
||||
self.residence_time = Option::None; |
||||
} |
||||
|
||||
fn get_random_state_uniform(&self) -> StateType { |
||||
let mut rng = rand::thread_rng(); |
||||
StateType::Discrete(rng.gen_range(0..(self.domain.len()))) |
||||
} |
||||
|
||||
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError> { |
||||
// Generate a random residence time given the current state of the node and its parent set.
|
||||
// The method used is described in:
|
||||
// https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
|
||||
match &self.cim { |
||||
Option::Some(cim) => { |
||||
let mut rng = rand::thread_rng(); |
||||
let lambda = cim[[u, state, state]] * -1.0; |
||||
let x: f64 = rng.gen_range(0.0..=1.0); |
||||
Ok(-x.ln() / lambda) |
||||
} |
||||
Option::None => Err(ParamsError::ParametersNotInitialized(String::from( |
||||
"CIM not initialized", |
||||
))), |
||||
} |
||||
} |
||||
|
||||
fn get_random_state(&self, state: usize, u: usize) -> Result<StateType, ParamsError> { |
||||
// Generate a random transition given the current state of the node and its parent set.
|
||||
// The method used is described in:
|
||||
// https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
|
||||
match &self.cim { |
||||
Option::Some(cim) => { |
||||
let mut rng = rand::thread_rng(); |
||||
let lambda = cim[[u, state, state]] * -1.0; |
||||
let urand: f64 = rng.gen_range(0.0..=1.0); |
||||
|
||||
let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold( |
||||
(0, 0.0), |
||||
|mut acc, ele| { |
||||
if &acc.1 + ele < urand && ele > &0.0 { |
||||
acc.0 += 1; |
||||
} |
||||
if ele > &0.0 { |
||||
acc.1 += ele; |
||||
} |
||||
acc |
||||
}, |
||||
); |
||||
|
||||
let next_state = if next_state.0 < state { |
||||
next_state.0 |
||||
} else { |
||||
next_state.0 + 1 |
||||
}; |
||||
|
||||
Ok(StateType::Discrete(next_state)) |
||||
} |
||||
Option::None => Err(ParamsError::ParametersNotInitialized(String::from( |
||||
"CIM not initialized", |
||||
))), |
||||
} |
||||
} |
||||
|
||||
fn get_reserved_space_as_parent(&self) -> usize { |
||||
self.domain.len() |
||||
} |
||||
|
||||
fn state_to_index(&self, state: &StateType) -> usize { |
||||
match state { |
||||
StateType::Discrete(val) => val.clone() as usize, |
||||
} |
||||
} |
||||
} |
||||
|
@ -1,119 +0,0 @@ |
||||
use crate::network; |
||||
use crate::node; |
||||
use crate::params; |
||||
use crate::params::ParamsTrait; |
||||
use ndarray::prelude::*; |
||||
|
||||
pub struct Trajectory { |
||||
pub time: Array1<f64>, |
||||
pub events: Array2<usize>, |
||||
} |
||||
|
||||
pub struct Dataset { |
||||
pub trajectories: Vec<Trajectory>, |
||||
} |
||||
|
||||
pub fn trajectory_generator<T: network::Network>( |
||||
net: &T, |
||||
n_trajectories: u64, |
||||
t_end: f64, |
||||
) -> Dataset { |
||||
let mut dataset = Dataset { |
||||
trajectories: Vec::new(), |
||||
}; |
||||
|
||||
let node_idx: Vec<_> = net.get_node_indices().collect(); |
||||
for _ in 0..n_trajectories { |
||||
let mut t = 0.0; |
||||
let mut time: Vec<f64> = Vec::new(); |
||||
let mut events: Vec<Array1<usize>> = Vec::new(); |
||||
let mut current_state: Vec<params::StateType> = node_idx |
||||
.iter() |
||||
.map(|x| net.get_node(*x).params.get_random_state_uniform()) |
||||
.collect(); |
||||
let mut next_transitions: Vec<Option<f64>> = |
||||
(0..node_idx.len()).map(|_| Option::None).collect(); |
||||
events.push( |
||||
current_state |
||||
.iter() |
||||
.map(|x| match x { |
||||
params::StateType::Discrete(state) => state.clone(), |
||||
}) |
||||
.collect(), |
||||
); |
||||
time.push(t.clone()); |
||||
while t < t_end { |
||||
for (idx, val) in next_transitions.iter_mut().enumerate() { |
||||
if let None = val { |
||||
*val = Some( |
||||
net.get_node(idx) |
||||
.params |
||||
.get_random_residence_time( |
||||
net.get_node(idx).params.state_to_index(¤t_state[idx]), |
||||
net.get_param_index_network(idx, ¤t_state), |
||||
) |
||||
.unwrap() |
||||
+ t, |
||||
); |
||||
} |
||||
} |
||||
|
||||
let next_node_transition = next_transitions |
||||
.iter() |
||||
.enumerate() |
||||
.min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) |
||||
.unwrap() |
||||
.0; |
||||
if next_transitions[next_node_transition].unwrap() > t_end { |
||||
break; |
||||
} |
||||
t = next_transitions[next_node_transition].unwrap().clone(); |
||||
time.push(t.clone()); |
||||
|
||||
current_state[next_node_transition] = net |
||||
.get_node(next_node_transition) |
||||
.params |
||||
.get_random_state( |
||||
net.get_node(next_node_transition) |
||||
.params |
||||
.state_to_index(¤t_state[next_node_transition]), |
||||
net.get_param_index_network(next_node_transition, ¤t_state), |
||||
) |
||||
.unwrap(); |
||||
|
||||
events.push(Array::from_vec( |
||||
current_state |
||||
.iter() |
||||
.map(|x| match x { |
||||
params::StateType::Discrete(state) => state.clone(), |
||||
}) |
||||
.collect(), |
||||
)); |
||||
next_transitions[next_node_transition] = None; |
||||
|
||||
for child in net.get_children_set(next_node_transition) { |
||||
next_transitions[child] = None |
||||
} |
||||
} |
||||
|
||||
events.push( |
||||
current_state |
||||
.iter() |
||||
.map(|x| match x { |
||||
params::StateType::Discrete(state) => state.clone(), |
||||
}) |
||||
.collect(), |
||||
); |
||||
time.push(t_end.clone()); |
||||
|
||||
dataset.trajectories.push(Trajectory { |
||||
time: Array::from_vec(time), |
||||
events: Array2::from_shape_vec( |
||||
(events.len(), current_state.len()), |
||||
events.iter().flatten().cloned().collect(), |
||||
) |
||||
.unwrap(), |
||||
}); |
||||
} |
||||
dataset |
||||
} |
@ -1,99 +0,0 @@ |
||||
mod utils; |
||||
use utils::generate_discrete_time_continous_node; |
||||
use rustyCTBN::network::Network; |
||||
use rustyCTBN::node; |
||||
use rustyCTBN::params; |
||||
use std::collections::BTreeSet; |
||||
use rustyCTBN::ctbn::*; |
||||
|
||||
#[test] |
||||
fn define_simpe_ctbn() { |
||||
let _ = CtbnNetwork::init(); |
||||
assert!(true); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_node_to_ctbn() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
assert_eq!(String::from("n1"), net.get_node(n1).label); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_edge_to_ctbn() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
net.add_edge(n1, n2); |
||||
let cs = net.get_children_set(n1); |
||||
assert_eq!(&n2, cs.iter().next().unwrap()); |
||||
} |
||||
|
||||
#[test] |
||||
fn children_and_parents() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
net.add_edge(n1, n2); |
||||
let cs = net.get_children_set(n1); |
||||
assert_eq!(&n2, cs.iter().next().unwrap()); |
||||
let ps = net.get_parent_set(n2); |
||||
assert_eq!(&n1, ps.iter().next().unwrap()); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn compute_index_ctbn() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); |
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n3, n2); |
||||
let idx = net.get_param_index_network(n2, &vec![ |
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(1)]); |
||||
assert_eq!(3, idx); |
||||
|
||||
|
||||
let idx = net.get_param_index_network(n2, &vec![ |
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(1)]); |
||||
assert_eq!(2, idx); |
||||
|
||||
|
||||
let idx = net.get_param_index_network(n2, &vec![ |
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(0)]); |
||||
assert_eq!(1, idx); |
||||
|
||||
} |
||||
|
||||
|
||||
|
||||
#[test] |
||||
fn compute_index_from_custom_parent_set() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); |
||||
|
||||
|
||||
let idx = net.get_param_index_from_custom_parent_set(&vec![ |
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(1)], |
||||
&BTreeSet::from([1])); |
||||
assert_eq!(0, idx); |
||||
|
||||
|
||||
let idx = net.get_param_index_from_custom_parent_set(&vec![ |
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(1)], |
||||
&BTreeSet::from([1,2])); |
||||
assert_eq!(2, idx); |
||||
} |
@ -1,263 +0,0 @@ |
||||
mod utils; |
||||
use utils::*; |
||||
|
||||
use rustyCTBN::parameter_learning::*; |
||||
use rustyCTBN::ctbn::*; |
||||
use rustyCTBN::network::Network; |
||||
use rustyCTBN::node; |
||||
use rustyCTBN::params; |
||||
use rustyCTBN::tools::*; |
||||
use ndarray::arr3; |
||||
use std::collections::BTreeSet; |
||||
|
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
|
||||
fn learn_binary_cim<T: ParameterLearning> (pl: T) { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[ |
||||
[[-1.0, 1.0], [4.0, -4.0]], |
||||
[[-6.0, 6.0], [2.0, -2.0]], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 100.0); |
||||
let (CIM, M, T) = pl.fit(&net, &data, 1, None); |
||||
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||
assert_eq!(CIM.shape(), [2, 2, 2]); |
||||
assert!(CIM.abs_diff_eq(&arr3(&[ |
||||
[[-1.0, 1.0], [4.0, -4.0]], |
||||
[[-6.0, 6.0], [2.0, -2.0]], |
||||
]), 0.2)); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_binary_cim_MLE() { |
||||
let mle = MLE{}; |
||||
learn_binary_cim(mle); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn learn_binary_cim_BA() { |
||||
let ba = BayesianApproach{ |
||||
default_alpha: 1, |
||||
default_tau: 1.0}; |
||||
learn_binary_cim(ba); |
||||
} |
||||
|
||||
fn learn_ternary_cim<T: ParameterLearning> (pl: T) { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
|
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0]]])); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[ |
||||
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 200.0); |
||||
let (CIM, M, T) = pl.fit(&net, &data, 1, None); |
||||
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||
assert_eq!(CIM.shape(), [3, 3, 3]); |
||||
assert!(CIM.abs_diff_eq(&arr3(&[ |
||||
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||
]), 0.2)); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_MLE() { |
||||
let mle = MLE{}; |
||||
learn_ternary_cim(mle); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_BA() { |
||||
let ba = BayesianApproach{ |
||||
default_alpha: 1, |
||||
default_tau: 1.0}; |
||||
learn_ternary_cim(ba); |
||||
} |
||||
|
||||
fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
|
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0]]])); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[ |
||||
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 100, 200.0); |
||||
let (CIM, M, T) = pl.fit(&net, &data, 0, None); |
||||
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||
assert_eq!(CIM.shape(), [1, 3, 3]); |
||||
assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
|
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0]]]), 0.2)); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_no_parents_MLE() { |
||||
let mle = MLE{}; |
||||
learn_ternary_cim_no_parents(mle); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_no_parents_BA() { |
||||
let ba = BayesianApproach{ |
||||
default_alpha: 1, |
||||
default_tau: 1.0}; |
||||
learn_ternary_cim_no_parents(ba); |
||||
} |
||||
|
||||
|
||||
fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||
.unwrap(); |
||||
|
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"),4)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n1, n3); |
||||
net.add_edge(n2, n3); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
|
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0]]])); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[ |
||||
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
|
||||
match &mut net.get_node_mut(n3).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[ |
||||
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], |
||||
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], |
||||
|
||||
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], |
||||
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], |
||||
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], |
||||
|
||||
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], |
||||
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], |
||||
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
|
||||
let data = trajectory_generator(&net, 300, 200.0); |
||||
let (CIM, M, T) = pl.fit(&net, &data, 2, None); |
||||
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||
assert_eq!(CIM.shape(), [9, 4, 4]); |
||||
assert!(CIM.abs_diff_eq(&arr3(&[ |
||||
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], |
||||
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], |
||||
|
||||
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], |
||||
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], |
||||
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], |
||||
|
||||
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], |
||||
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], |
||||
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], |
||||
]), 0.2)); |
||||
} |
||||
|
||||
#[test] |
||||
fn learn_mixed_discrete_cim_MLE() { |
||||
let mle = MLE{}; |
||||
learn_mixed_discrete_cim(mle); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn learn_mixed_discrete_cim_BA() { |
||||
let ba = BayesianApproach{ |
||||
default_alpha: 1, |
||||
default_tau: 1.0}; |
||||
learn_mixed_discrete_cim(ba); |
||||
} |
@ -1,64 +0,0 @@ |
||||
use rustyCTBN::params::*; |
||||
use ndarray::prelude::*; |
||||
use std::collections::BTreeSet; |
||||
|
||||
mod utils; |
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
|
||||
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { |
||||
let mut params = utils::generate_discrete_time_continous_param(3); |
||||
|
||||
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; |
||||
|
||||
params.cim = Some(cim); |
||||
params |
||||
} |
||||
|
||||
#[test] |
||||
fn test_uniform_generation() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<usize>::zeros(10000); |
||||
|
||||
states.mapv_inplace(|_| { |
||||
if let StateType::Discrete(val) = param.get_random_state_uniform() { |
||||
val |
||||
} else { |
||||
panic!() |
||||
} |
||||
}); |
||||
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
||||
|
||||
assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_random_generation_state() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<usize>::zeros(10000); |
||||
|
||||
states.mapv_inplace(|_| { |
||||
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { |
||||
val |
||||
} else { |
||||
panic!() |
||||
} |
||||
}); |
||||
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; |
||||
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
||||
|
||||
assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01); |
||||
assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_random_generation_residence_time() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<f64>::zeros(10000); |
||||
|
||||
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap()); |
||||
|
||||
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); |
||||
} |
@ -1,45 +0,0 @@ |
||||
|
||||
use rustyCTBN::tools::*; |
||||
use rustyCTBN::network::Network; |
||||
use rustyCTBN::ctbn::*; |
||||
use rustyCTBN::node; |
||||
use rustyCTBN::params; |
||||
use std::collections::BTreeSet; |
||||
use ndarray::arr3; |
||||
|
||||
|
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
mod utils; |
||||
|
||||
#[test] |
||||
fn run_sampling() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let n2 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); |
||||
} |
||||
} |
||||
|
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some (arr3(&[ |
||||
[[-1.0,1.0],[4.0,-4.0]], |
||||
[[-6.0,6.0],[2.0,-2.0]]])); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(&net, 4, 1.0); |
||||
|
||||
assert_eq!(4, data.trajectories.len()); |
||||
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); |
||||
} |
||||
|
||||
|
@ -1,16 +0,0 @@ |
||||
use rustyCTBN::params; |
||||
use rustyCTBN::node; |
||||
use std::collections::BTreeSet; |
||||
|
||||
pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node { |
||||
node::Node::init(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) |
||||
} |
||||
|
||||
|
||||
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ |
||||
let mut domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect(); |
||||
params::DiscreteStatesContinousTimeParams::init(domain) |
||||
} |
||||
|
||||
|
||||
|
Loading…
Reference in new issue