Compare commits

3 Commits

Author SHA1 Message Date
AlessandroBregoli e1af14b620 folder fix 2 years ago
AlessandroBregoli aa630bf9e9 Generate trajectories from python 2 years ago
AlessandroBregoli 3dae67a80c Init 2 years ago
  1. .github/workflows/build.yml (7 changed lines)
  2. .gitignore (1 changed line)
  3. LICENSE-APACHE (176 changed lines)
  4. LICENSE-MIT (23 changed lines)
  5. README.md (16 changed lines)
  6. pyproject.toml (16 changed lines)
  7. reCTBN/Cargo.toml (18 changed lines)
  8. reCTBN/src/ctbn.rs (162 changed lines)
  9. reCTBN/src/lib.rs (5 changed lines)
  10. reCTBN/src/network.rs (43 changed lines)
  11. reCTBN/src/parameter_learning.rs (43 changed lines)
  12. reCTBN/src/params.rs (59 changed lines)
  13. reCTBN/src/process.rs (120 changed lines)
  14. reCTBN/src/process/ctbn.rs (247 changed lines)
  15. reCTBN/src/process/ctmp.rs (114 changed lines)
  16. reCTBN/src/reward.rs (59 changed lines)
  17. reCTBN/src/reward/reward_evaluation.rs (205 changed lines)
  18. reCTBN/src/reward/reward_function.rs (106 changed lines)
  19. reCTBN/src/sampling.rs (50 changed lines)
  20. reCTBN/src/structure_learning.rs (8 changed lines)
  21. reCTBN/src/structure_learning/constraint_based_algorithm.rs (351 changed lines)
  22. reCTBN/src/structure_learning/hypothesis_test.rs (212 changed lines)
  23. reCTBN/src/structure_learning/score_based_algorithm.rs (24 changed lines)
  24. reCTBN/src/structure_learning/score_function.rs (14 changed lines)
  25. reCTBN/src/tools.rs (272 changed lines)
  26. reCTBN/tests/ctbn.rs (249 changed lines)
  27. reCTBN/tests/ctmp.rs (127 changed lines)
  28. reCTBN/tests/parameter_learning.rs (207 changed lines)
  29. reCTBN/tests/reward_evaluation.rs (122 changed lines)
  30. reCTBN/tests/reward_function.rs (117 changed lines)
  31. reCTBN/tests/structure_learning.rs (247 changed lines)
  32. reCTBN/tests/tools.rs (171 changed lines)
  33. reCTBNpy/.github/workflows/CI.yml (69 changed lines)
  34. reCTBNpy/.gitignore (72 changed lines)
  35. reCTBNpy/Cargo.toml (15 changed lines)
  36. reCTBNpy/pyproject.toml (14 changed lines)
  37. reCTBNpy/src/lib.rs (21 changed lines)
  38. reCTBNpy/src/pyctbn.rs (68 changed lines)
  39. reCTBNpy/src/pyparams.rs (102 changed lines)
  40. reCTBNpy/src/pytools.rs (50 changed lines)

@@ -22,7 +22,7 @@ jobs:
profile: minimal
toolchain: stable
default: true
components: clippy, rustfmt, rust-docs
components: clippy, rustfmt
- name: Setup Rust nightly
uses: actions-rs/toolchain@v1
with:
@@ -30,11 +30,6 @@ jobs:
toolchain: nightly
default: false
components: rustfmt
- name: Docs (doc)
uses: actions-rs/cargo@v1
with:
command: rustdoc
args: --package reCTBN -- --default-theme=ayu
- name: Linting (clippy)
uses: actions-rs/clippy-check@v1
with:

.gitignore (vendored, 1 changed line)

@@ -1,3 +1,4 @@
/target
Cargo.lock
.vscode
poetry.lock

@@ -1,176 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS

@@ -1,23 +0,0 @@
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

@@ -56,19 +56,3 @@ To check the **formatting**:
```sh
cargo fmt --all -- --check
```
## Documentation
To generate the **documentation**:
```sh
cargo rustdoc --package reCTBN --open -- --default-theme=ayu
```
## License
This software is distributed under the terms of both the Apache License
(Version 2.0) and the MIT license.
See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for
details.

@@ -0,0 +1,16 @@
[tool.poetry]
name = "rectbnpy"
version = "0.1.0"
description = ""
authors = ["AlessandroBregoli <alessandroxciv@gmail.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
maturin = "^0.13.3"
numpy = "^1.23.3"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

@@ -6,15 +6,13 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ndarray = {version="~0.15", features=["approx-0_5"]}
thiserror = "1.0.37"
rand = "~0.8"
bimap = "~0.6"
enum_dispatch = "~0.3"
statrs = "~0.16"
rand_chacha = "~0.3"
itertools = "~0.10"
rayon = "~1.6"
ndarray = {version="*", features=["approx-0_5"]}
thiserror = "*"
rand = "*"
bimap = "*"
enum_dispatch = "*"
statrs = "*"
rand_chacha = "*"
[dev-dependencies]
approx = { package = "approx", version = "~0.5" }
approx = { package = "approx", version = "0.5" }

@@ -0,0 +1,162 @@
use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::network;
use crate::params::{Params, ParamsTrait, StateType};
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
///composed of the following elements:
///- **adj_matrix**: a 2D ndarray representing the adjacency matrix
///- **nodes**: a vector containing all the nodes and their parameters.
///The index of a node inside the vector is also used as the index for the adj_matrix.
///
///# Examples
///
///```
///
/// use std::collections::BTreeSet;
/// use reCTBN::network::Network;
/// use reCTBN::params;
/// use reCTBN::ctbn::*;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Get all the children of node X1
/// let cs = net.get_children_set(X1);
/// assert_eq!(&X2, cs.iter().next().unwrap());
/// ```
pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>,
nodes: Vec<Params>,
}
impl CtbnNetwork {
pub fn new() -> CtbnNetwork {
CtbnNetwork {
adj_matrix: None,
nodes: Vec::new(),
}
}
}
impl network::Network for CtbnNetwork {
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
}
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() - 1)
}
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
}
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].reset_params();
}
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
0..self.nodes.len()
}
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
fn get_node(&self, node_idx: usize) -> &Params {
&self.nodes[node_idx]
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
&mut self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
})
.0
}
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<StateType>,
parent_set: &BTreeSet<usize>,
) -> usize {
parent_set
.iter()
.fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc
})
.0
}
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.row(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
}
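A quick aside on the indexing above: the fold in `get_param_index_network` and `get_param_index_from_custom_parent_set` computes a mixed-radix index over the states of the parent set. A minimal standalone sketch of that arithmetic (the domain sizes here are hypothetical, not taken from the diff):

```rust
// Mixed-radix indexing as performed by the fold above: each parent state is
// weighted by the product of the domain sizes of the parents before it.
fn param_index(states: &[usize], domains: &[usize]) -> usize {
    states
        .iter()
        .zip(domains)
        .fold((0, 1), |mut acc, (s, d)| {
            acc.0 += s * acc.1; // add this parent's contribution
            acc.1 *= d;         // grow the radix for the next parent
            acc
        })
        .0
}

fn main() {
    // Two parents: state 1 of a 2-state domain, state 2 of a 3-state domain;
    // the flat index is 1 + 2 * 2 = 5.
    assert_eq!(param_index(&[1, 2], &[2, 3]), 5);
}
```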

@@ -1,12 +1,11 @@
#![doc = include_str!("../../README.md")]
#![allow(non_snake_case)]
#[cfg(test)]
extern crate approx;
pub mod ctbn;
pub mod network;
pub mod parameter_learning;
pub mod params;
pub mod process;
pub mod reward;
pub mod sampling;
pub mod structure_learning;
pub mod tools;

@@ -0,0 +1,43 @@
use std::collections::BTreeSet;
use thiserror::Error;
use crate::params;
/// Error types for trait Network
#[derive(Error, Debug)]
pub enum NetworkError {
#[error("Error during node insertion")]
NodeInsertionError(String),
}
///Network
///The Network trait defines the required methods for a structure used as a PGM (such as a CTBN).
pub trait Network {
fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
fn add_edge(&mut self, parent: usize, child: usize);
///Get all the indices of the nodes contained inside the network
fn get_node_indices(&self) -> std::ops::Range<usize>;
fn get_number_of_nodes(&self) -> usize;
fn get_node(&self, node_idx: usize) -> &params::Params;
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are
///the ones in the parent set of the *node*.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
-> usize;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network and a generic parent_set. Usually, the only values really used
///in *current_state* are the ones in the parent set of the *node*.
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<params::StateType>,
parent_set: &BTreeSet<usize>,
) -> usize;
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@@ -1,25 +1,23 @@
//! Module containing methods used to learn the parameters.
use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::params::*;
use crate::{process, tools::Dataset};
use crate::{network, tools};
pub trait ParameterLearning: Sync {
fn fit<T: process::NetworkProcess>(
pub trait ParameterLearning {
fn fit<T: network::Network>(
&self,
net: &T,
dataset: &Dataset,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params;
}
pub fn sufficient_statistics<T: process::NetworkProcess>(
pub fn sufficient_statistics<T: network::Network>(
net: &T,
dataset: &Dataset,
dataset: &tools::Dataset,
node: usize,
parent_set: &BTreeSet<usize>,
) -> (Array3<usize>, Array2<f64>) {
@@ -73,10 +71,10 @@ pub fn sufficient_statistics<T: process::NetworkProcess>(
pub struct MLE {}
impl ParameterLearning for MLE {
fn fit<T: process::NetworkProcess>(
fn fit<T: network::Network>(
&self,
net: &T,
dataset: &Dataset,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params {
@@ -120,10 +118,10 @@ pub struct BayesianApproach {
}
impl ParameterLearning for BayesianApproach {
fn fit<T: process::NetworkProcess>(
fn fit<T: network::Network>(
&self,
net: &T,
dataset: &Dataset,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params {
@@ -144,10 +142,6 @@ impl ParameterLearning for BayesianApproach {
.zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
.for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau))));
CIM.outer_iter_mut().for_each(|mut C| {
C.diag_mut().fill(0.0);
});
//Set the diagonal of the inner matrices to the row sum multiplied by -1
let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
CIM.outer_iter_mut()
@@ -168,3 +162,20 @@ impl ParameterLearning for BayesianApproach {
return n;
}
}
pub struct Cache<P: ParameterLearning> {
parameter_learning: P,
dataset: tools::Dataset,
}
impl<P: ParameterLearning> Cache<P> {
pub fn fit<T: network::Network>(
&mut self,
net: &T,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params {
self.parameter_learning
.fit(net, &self.dataset, node, parent_set)
}
}
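For reference, `BayesianApproach::fit` above smooths the sufficient statistics with prior pseudo-counts: each off-diagonal CIM entry becomes `(M + alpha) / (T + tau)`, and the diagonal is then set to the negative row sum. A numeric sketch of that estimate (counts, times, and priors are hypothetical):

```rust
// Bayesian CIM estimate: q[s, s'] = (M[s, s'] + alpha) / (T[s] + tau),
// with the diagonal chosen so every row sums to zero.
fn main() {
    let m = [[0u32, 3, 1], [2, 0, 4], [1, 1, 0]]; // transition counts M
    let t = [2.0f64, 1.5, 3.0];                   // residence times T
    let (alpha, tau) = (1.0, 0.1);                // prior pseudo-counts
    let mut cim = [[0.0f64; 3]; 3];
    for s in 0..3 {
        for sp in 0..3 {
            if s != sp {
                cim[s][sp] = (f64::from(m[s][sp]) + alpha) / (t[s] + tau);
            }
        }
        cim[s][s] = -cim[s].iter().sum::<f64>(); // negative row sum
    }
    println!("{:?}", cim);
}
```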

@@ -1,5 +1,3 @@
//! Module containing methods to define different types of nodes.
use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch;
@@ -20,13 +18,14 @@ pub enum ParamsError {
}
/// Allowed type of states
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
#[derive(Clone)]
pub enum StateType {
Discrete(usize),
}
/// This is a core element for building different types of nodes; the goal is to define the set of
/// methods required to describe a generic node.
/// Parameters
/// The Params trait is the core element for building different types of nodes. The goal is to
/// define the set of methods required to describe a generic node.
#[enum_dispatch(Params)]
pub trait ParamsTrait {
fn reset_params(&mut self);
@@ -66,27 +65,25 @@ pub trait ParamsTrait {
fn get_label(&self) -> &String;
}
/// This is a core element for building different types of nodes; the goal is to define all the
/// supported types of Parameters
/// The Params enum is the core element for building different types of nodes. The goal is to
/// define all the supported types of Parameters
#[derive(Clone)]
#[enum_dispatch]
pub enum Params {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
/// DiscreteStatesContinousTime.
/// This represents the parameters of a classical discrete node for a CTBN, and is composed of the
/// following elements.
///
/// # Arguments
///
/// * `label` - node's variable name.
/// * `domain` - an ordered and exhaustive set of possible states.
/// * `cim` - Conditional Intensity Matrix.
/// * `transitions` - number of transitions from one state to another given a specific realization
/// of the parent set; it is a sufficient statistic mainly used during the parameter learning
/// task.
/// * `residence_time` - residence time in each possible state, given a specific realization of the
/// parent set; it is a sufficient statistic mainly used during the parameter learning task.
/// following elements:
/// - **domain**: an ordered and exhaustive set of possible states
/// - **cim**: Conditional Intensity Matrix
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
/// learning task and are composed of:
/// - **transitions**: number of transitions from one state to another given a specific
/// realization of the parent set
/// - **residence_time**: permanence time in each possible state given a specific
/// realization of the parent set
#[derive(Clone)]
pub struct DiscreteStatesContinousTimeParams {
label: String,
@@ -107,17 +104,15 @@ impl DiscreteStatesContinousTimeParams {
}
}
/// Getter function for CIM
///Getter function for CIM
pub fn get_cim(&self) -> &Option<Array3<f64>> {
&self.cim
}
/// Setter function for CIM.
///
/// This function checks if the CIM is valid using the [`validate_params`](self::ParamsTrait::validate_params) method:
/// * **Valid CIM inserted** - it substitutes the CIM in `self.cim` and returns `Ok(())`.
/// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns
/// `ParamsError`.
///Setter function for CIM.\\
///This function checks if the CIM is valid using the validate_params method.
///- **Valid cim inserted**: it substitutes the CIM in self.cim and returns Ok(())
///- **Invalid cim inserted**: it replaces the self.cim value with None and returns ParamsError
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
self.cim = Some(cim);
match self.validate_params() {
@@ -129,27 +124,27 @@ impl DiscreteStatesContinousTimeParams {
}
}
/// Unchecked version of the setter function for CIM.
///Unchecked version of the setter function for CIM.
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
self.cim = Some(cim);
}
/// Getter function for transitions.
///Getter function for transitions
pub fn get_transitions(&self) -> &Option<Array3<usize>> {
&self.transitions
}
/// Setter function for transitions.
///Setter function for transitions
pub fn set_transitions(&mut self, transitions: Array3<usize>) {
self.transitions = Some(transitions);
}
/// Getter function for residence_time.
///Getter function for residence_time
pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
&self.residence_time
}
/// Setter function for residence_time.
///Setter function for residence_time
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
self.residence_time = Some(residence_time);
}
@@ -271,7 +266,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
if cim
.sum_axis(Axis(2))
.iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
.any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
{
return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0",

@@ -1,120 +0,0 @@
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
pub mod ctbn;
pub mod ctmp;
use std::collections::BTreeSet;
use thiserror::Error;
use crate::params;
/// Error types for trait Network
#[derive(Error, Debug)]
pub enum NetworkError {
#[error("Error during node insertion")]
NodeInsertionError(String),
}
/// This type is used to represent a specific realization of a generic NetworkProcess
pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Model (such
/// as a CTBN).
pub trait NetworkProcess: Sync {
fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add a **directed edge** between two nodes of the network.
///
/// # Arguments
///
/// * `parent` - parent node.
/// * `child` - child node.
fn add_edge(&mut self, parent: usize, child: usize);
/// Get all the indices of the nodes contained inside the network.
fn get_node_indices(&self) -> std::ops::Range<usize>;
/// Get the number of nodes contained in the network.
fn get_number_of_nodes(&self) -> usize;
/// Get the **node param**.
///
/// # Arguments
///
/// * `node_idx` - node index value.
///
/// # Return
///
/// * The selected **node param**.
fn get_node(&self, node_idx: usize) -> &params::Params;
/// Get the **node param**.
///
/// # Arguments
///
/// * `node_idx` - node index value.
///
/// # Return
///
/// * The selected **node mutable param**.
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
/// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network.
///
/// Usually, the only values really used in `current_state` are the ones in the parent set of
/// the `node`.
///
/// # Arguments
///
/// * `node` - selected node.
/// * `current_state` - current configuration of the network.
///
/// # Return
///
/// * Index of the `node` relative to the network.
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize;
/// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network and a generic `parent_set`.
///
/// Usually, the only values really used in `current_state` are the ones in the parent set of
/// the `node`.
///
/// # Arguments
///
/// * `current_state` - current configuration of the network.
/// * `parent_set` - parent set of the selected `node`.
///
/// # Return
///
/// * Index of the `node` relative to the network.
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<params::StateType>,
parent_set: &BTreeSet<usize>,
) -> usize;
/// Get the **parent set** of a given **node**.
///
/// # Arguments
///
/// * `node` - node index value.
///
/// # Return
///
/// * The **parent set** of the selected node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
/// Get the **children set** of a given **node**.
///
/// # Arguments
///
/// * `node` - node index value.
///
/// # Return
///
/// * The **children set** of the selected node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@@ -1,247 +0,0 @@
//! Continuous Time Bayesian Network
use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
use crate::process;
use super::ctmp::CtmpProcess;
use super::{NetworkProcess, NetworkProcessState};
/// It represents both the structure and the parameters of a CTBN.
///
/// # Arguments
///
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
/// * `nodes` - A vector containing all the nodes and their parameters.
///
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
///
/// # Example
///
/// ```rust
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Get all the children of node X1
/// let cs = net.get_children_set(X1);
/// assert_eq!(&X2, cs.iter().next().unwrap());
/// ```
pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>,
nodes: Vec<Params>,
}
impl CtbnNetwork {
pub fn new() -> CtbnNetwork {
CtbnNetwork {
adj_matrix: None,
nodes: Vec::new(),
}
}
///Transform the **CTBN** into a **CTMP**
///
/// # Return
///
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
pub fn amalgamation(&self) -> CtmpProcess {
let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
let state_space = variables_domain.product();
let variables_set = BTreeSet::from_iter(self.get_node_indices());
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space));
for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: NetworkProcessState = current_state
.iter()
.map(|x| StateType::Discrete(*x))
.collect();
for idx_node in 0..self.nodes.len() {
let p = match self.get_node(idx_node) {
Params::DiscreteStatesContinousTime(p) => p,
};
for next_node_state in 0..variables_domain[idx_node] {
let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state;
let next_state_statetype: NetworkProcessState =
next_state.iter().map(|x| StateType::Discrete(*x)).collect();
let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype,
&variables_set,
);
amalgamated_cim[[0, idx_current_state, idx_next_state]] +=
p.get_cim().as_ref().unwrap()[[
self.get_param_index_network(idx_node, &current_state_statetype),
current_state[idx_node],
next_node_state,
]];
}
}
}
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new(
"ctmp".to_string(),
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
);
amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new();
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param))
.unwrap();
return ctmp;
}
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> {
let mut state = state;
let mut array_state = Array1::zeros(variables_domain.shape()[0]);
for (idx, var) in variables_domain.indexed_iter() {
array_state[idx] = state % var;
state = state / var;
}
return array_state;
}
/// Get the Adjacency Matrix.
pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> {
self.adj_matrix.as_ref()
}
}
impl process::NetworkProcess for CtbnNetwork {
/// Initialize an Adjacency matrix.
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
}
/// Add a new node.
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() - 1)
}
/// Connect two nodes with a new edge.
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
}
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].reset_params();
}
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
0..self.nodes.len()
}
/// Get the number of nodes of the network.
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
fn get_node(&self, node_idx: usize) -> &Params {
&self.nodes[node_idx]
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
&mut self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
})
.0
}
fn get_param_index_from_custom_parent_set(
&self,
current_state: &NetworkProcessState,
parent_set: &BTreeSet<usize>,
) -> usize {
parent_set
.iter()
.fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc
})
.0
}
/// Get all the parents of the given node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
/// Get all the children of the given node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.row(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
}
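`idx_to_state` above is the inverse of the mixed-radix indexing used by the `get_param_index_*` methods: it decodes a flat amalgamated-state index into one state per variable. A standalone sketch with hypothetical domains:

```rust
// Decode a flat state index by repeated modulo/division over the domains.
fn idx_to_state(domains: &[usize], mut state: usize) -> Vec<usize> {
    domains
        .iter()
        .map(|d| {
            let s = state % d;
            state /= d;
            s
        })
        .collect()
}

fn main() {
    // With domains [2, 3], flat index 5 decodes to [1, 2], since 5 = 1 + 2 * 2.
    assert_eq!(idx_to_state(&[2, 3], 5), vec![1, 2]);
}
```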

@@ -1,114 +0,0 @@
use std::collections::BTreeSet;
use crate::{
params::{Params, StateType},
process,
};
use super::{NetworkProcess, NetworkProcessState};
pub struct CtmpProcess {
param: Option<Params>,
}
impl CtmpProcess {
pub fn new() -> CtmpProcess {
CtmpProcess { param: None }
}
}
impl NetworkProcess for CtmpProcess {
fn initialize_adj_matrix(&mut self) {
unimplemented!("CtmpProcess has only one node")
}
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> {
match self.param {
None => {
self.param = Some(n);
Ok(0)
}
Some(_) => Err(process::NetworkError::NodeInsertionError(
"CtmpProcess has only one node".to_string(),
)),
}
}
fn add_edge(&mut self, _parent: usize, _child: usize) {
unimplemented!("CtmpProcess has only one node")
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
match self.param {
None => 0..0,
Some(_) => 0..1,
}
}
fn get_number_of_nodes(&self) -> usize {
match self.param {
None => 0,
Some(_) => 1,
}
}
fn get_node(&self, node_idx: usize) -> &crate::params::Params {
if node_idx == 0 {
self.param.as_ref().unwrap()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params {
if node_idx == 0 {
self.param.as_mut().unwrap()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
if node == 0 {
match current_state[0] {
StateType::Discrete(x) => x,
}
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_param_index_from_custom_parent_set(
&self,
_current_state: &NetworkProcessState,
_parent_set: &BTreeSet<usize>,
) -> usize {
unimplemented!("CtmpProcess has only one node")
}
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
match self.param {
Some(_) => {
if node == 0 {
BTreeSet::new()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
None => panic!("Uninitialized CtmpProcess"),
}
}
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
match self.param {
Some(_) => {
if node == 0 {
BTreeSet::new()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
None => panic!("Uninitialized CtmpProcess"),
}
}
}

@@ -1,59 +0,0 @@
pub mod reward_evaluation;
pub mod reward_function;
use std::collections::HashMap;
use crate::process;
/// Instantiation of reward function and instantaneous reward
///
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)]
pub struct Reward {
pub transition_reward: f64,
pub instantaneous_reward: f64,
}
/// The trait RewardFunction describes the methods that all the reward functions must satisfy
pub trait RewardFunction: Sync {
/// Given the current state and the previous state, it computes the reward.
///
/// # Arguments
///
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
/// * `previous_state`: an optional argument representing the previous state of the network
fn call(
&self,
current_state: &process::NetworkProcessState,
previous_state: Option<&process::NetworkProcessState>,
) -> Reward;
/// Initialize the RewardFunction internals according to the structure of a NetworkProcess
///
/// # Arguments
///
/// * `p`: any structure that implements the trait `process::NetworkProcess`
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}
pub trait RewardEvaluation {
fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
) -> HashMap<process::NetworkProcessState, f64>;
fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
state: &process::NetworkProcessState,
) -> f64;
}
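The two fields of `Reward` enter a trajectory's return differently: the instantaneous reward is weighted by the time spent in a state, while the transition reward is collected once per jump (this is how the Monte Carlo evaluator below combines them in the finite-horizon case). A minimal sketch with hypothetical numbers:

```rust
// Accumulate a return over trajectory segments: instantaneous reward times
// sojourn duration, plus the transition reward granted at each jump.
struct Reward {
    transition_reward: f64,
    instantaneous_reward: f64,
}

fn main() {
    // (sojourn duration, reward observed over that segment)
    let segments = [
        (0.5, Reward { transition_reward: 1.0, instantaneous_reward: 2.0 }),
        (1.5, Reward { transition_reward: 0.0, instantaneous_reward: 4.0 }),
    ];
    let total: f64 = segments
        .iter()
        .map(|(dt, r)| r.transition_reward + r.instantaneous_reward * dt)
        .sum();
    assert_eq!(total, 1.0 + 2.0 * 0.5 + 4.0 * 1.5); // 8.0
}
```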

@@ -1,205 +0,0 @@
use std::collections::HashMap;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use statrs::distribution::ContinuousCDF;
use crate::params::{self, ParamsTrait};
use crate::process;
use crate::{
process::NetworkProcessState,
reward::RewardEvaluation,
sampling::{ForwardSampler, Sampler},
};
pub enum RewardCriteria {
FiniteHorizon,
InfiniteHorizon { discount_factor: f64 },
}
pub struct MonteCarloReward {
max_iterations: usize,
max_err_stop: f64,
alpha_stop: f64,
end_time: f64,
reward_criteria: RewardCriteria,
seed: Option<u64>,
}
impl MonteCarloReward {
pub fn new(
max_iterations: usize,
max_err_stop: f64,
alpha_stop: f64,
end_time: f64,
reward_criteria: RewardCriteria,
seed: Option<u64>,
) -> MonteCarloReward {
MonteCarloReward {
max_iterations,
max_err_stop,
alpha_stop,
end_time,
reward_criteria,
seed,
}
}
}
impl RewardEvaluation for MonteCarloReward {
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
) -> HashMap<process::NetworkProcessState, f64> {
let variables_domain: Vec<Vec<params::StateType>> = network_process
.get_node_indices()
.map(|x| match network_process.get_node(x) {
params::Params::DiscreteStatesContinousTime(x) => (0..x
.get_reserved_space_as_parent())
.map(|s| params::StateType::Discrete(s))
.collect(),
})
.collect();
let n_states: usize = variables_domain.iter().map(|x| x.len()).product();
(0..n_states)
.into_par_iter()
.map(|s| {
let state: process::NetworkProcessState = variables_domain
.iter()
.fold((s, vec![]), |acc, x| {
let mut acc = acc;
let idx_s = acc.0 % x.len();
acc.1.push(x[idx_s].clone());
acc.0 = acc.0 / x.len();
acc
})
.1;
let r = self.evaluate_state(network_process, reward_function, &state);
(state, r)
})
.collect()
}
fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
state: &NetworkProcessState,
) -> f64 {
let mut sampler =
ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
let mut expected_value = 0.0;
let mut squared_expected_value = 0.0;
let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap();
for i in 0..self.max_iterations {
sampler.reset();
let mut ret = 0.0;
let mut previous = sampler.next().unwrap();
while previous.t < self.end_time {
let current = sampler.next().unwrap();
if current.t > self.end_time {
let r = reward_function.call(&previous.state, None);
let discount = match self.reward_criteria {
RewardCriteria::FiniteHorizon => self.end_time - previous.t,
RewardCriteria::InfiniteHorizon { discount_factor } => {
std::f64::consts::E.powf(-discount_factor * previous.t)
- std::f64::consts::E.powf(-discount_factor * self.end_time)
}
};
ret += discount * r.instantaneous_reward;
} else {
let r = reward_function.call(&previous.state, Some(&current.state));
let discount = match self.reward_criteria {
RewardCriteria::FiniteHorizon => current.t - previous.t,
RewardCriteria::InfiniteHorizon { discount_factor } => {
std::f64::consts::E.powf(-discount_factor * previous.t)
- std::f64::consts::E.powf(-discount_factor * current.t)
}
};
ret += discount * r.instantaneous_reward;
ret += match self.reward_criteria {
RewardCriteria::FiniteHorizon => 1.0,
RewardCriteria::InfiniteHorizon { discount_factor } => {
std::f64::consts::E.powf(-discount_factor * current.t)
}
} * r.transition_reward;
}
previous = current;
}
let float_i = i as f64;
expected_value =
expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0);
squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0)
+ ret.powi(2) / (float_i + 1.0);
if i > 2 {
let var =
(float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2));
if self.alpha_stop
- 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt())
> 0.0
{
return expected_value;
}
}
}
expected_value
}
}
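In the infinite-horizon branch above, the weight applied to an interval's instantaneous reward is the difference of exponentials `exp(-discount_factor * t0) - exp(-discount_factor * t1)`. A numeric sketch with a hypothetical discount factor and interval:

```rust
// Discount weight over a sojourn interval [t0, t1], as in the code above.
fn main() {
    let discount_factor = 0.5f64;
    let (t0, t1) = (1.0f64, 2.5);
    let weight = (-discount_factor * t0).exp() - (-discount_factor * t1).exp();
    println!("discount weight on [{t0}, {t1}]: {weight:.4}");
}
```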
pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> {
inner_reward: RE,
}
impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> {
pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> {
NeighborhoodRelativeReward { inner_reward }
}
}
impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> {
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
) -> HashMap<process::NetworkProcessState, f64> {
let absolute_reward = self
.inner_reward
.evaluate_state_space(network_process, reward_function);
//This approach optimizes memory usage. Optimizing execution time might be better.
absolute_reward
.iter()
.map(|(k1, v1)| {
let mut max_val: f64 = 1.0;
absolute_reward.iter().for_each(|(k2, v2)| {
let count_diff: usize = k1
.iter()
.zip(k2.iter())
.map(|(s1, s2)| if s1 == s2 { 0 } else { 1 })
.sum();
if count_diff < 2 {
max_val = max_val.max(v1 / v2);
}
});
(k1.clone(), max_val)
})
.collect()
}
fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>(
&self,
_network_process: &N,
_reward_function: &R,
_state: &process::NetworkProcessState,
) -> f64 {
unimplemented!();
}
}

@@ -1,106 +0,0 @@
//! Module for dealing with reward functions
use crate::{
params::{self, ParamsTrait},
process,
reward::{Reward, RewardFunction},
};
use ndarray;
/// Reward function over a factored state space
///
/// The `FactoredRewardFunction` assumes the reward function is the sum of the rewards of each node
/// of the underlying `NetworkProcess`
///
/// # Arguments
///
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
/// reward of a node
pub struct FactoredRewardFunction {
transition_reward: Vec<ndarray::Array2<f64>>,
instantaneous_reward: Vec<ndarray::Array1<f64>>,
}
impl FactoredRewardFunction {
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> {
&self.transition_reward[node_idx]
}
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> {
&mut self.transition_reward[node_idx]
}
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> {
&self.instantaneous_reward[node_idx]
}
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
&mut self.instantaneous_reward[node_idx]
}
}
impl RewardFunction for FactoredRewardFunction {
fn call(
&self,
current_state: &process::NetworkProcessState,
previous_state: Option<&process::NetworkProcessState>,
) -> Reward {
let instantaneous_reward: f64 = current_state
.iter()
.enumerate()
.map(|(idx, x)| {
let x = match x {
params::StateType::Discrete(x) => x,
};
self.instantaneous_reward[idx][*x]
})
.sum();
if let Some(previous_state) = previous_state {
let transition_reward = previous_state
.iter()
.zip(current_state.iter())
.enumerate()
.find_map(|(idx, (p, c))| -> Option<f64> {
let p = match p {
params::StateType::Discrete(p) => p,
};
let c = match c {
params::StateType::Discrete(c) => c,
};
if p != c {
Some(self.transition_reward[idx][[*p, *c]])
} else {
None
}
})
.unwrap_or(0.0);
Reward {
transition_reward,
instantaneous_reward,
}
} else {
Reward {
transition_reward: 0.0,
instantaneous_reward,
}
}
}
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self {
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![];
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![];
for i in p.get_node_indices() {
//This works only for discrete nodes!
let size: usize = p.get_node(i).get_reserved_space_as_parent();
instantaneous_reward.push(ndarray::Array1::zeros(size));
transition_reward.push(ndarray::Array2::zeros((size, size)));
}
FactoredRewardFunction {
transition_reward,
instantaneous_reward,
}
}
}
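`FactoredRewardFunction::call` above computes the instantaneous reward as a sum of one per-node lookup. A standalone sketch of that factored sum (domains and reward values are hypothetical):

```rust
// Factored instantaneous reward: each node contributes the entry of its own
// reward vector indexed by its current state.
fn main() {
    let instantaneous = [vec![0.0, 1.0], vec![0.5, 0.0, 2.0]]; // per-node rewards
    let state = [1usize, 2]; // node 0 in state 1, node 1 in state 2
    let r: f64 = state.iter().enumerate().map(|(i, &s)| instantaneous[i][s]).sum();
    assert_eq!(r, 1.0 + 2.0);
}
```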

@@ -1,40 +1,27 @@
//! Module containing methods for the sampling.
use crate::{
params::ParamsTrait,
process::{NetworkProcess, NetworkProcessState},
network::Network,
params::{self, ParamsTrait},
};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
#[derive(Clone)]
pub struct Sample {
pub t: f64,
pub state: NetworkProcessState,
}
pub trait Sampler: Iterator<Item = Sample> {
pub trait Sampler: Iterator {
fn reset(&mut self);
}
pub struct ForwardSampler<'a, T>
where
T: NetworkProcess,
T: Network,
{
net: &'a T,
rng: ChaCha8Rng,
current_time: f64,
current_state: NetworkProcessState,
current_state: Vec<params::StateType>,
next_transitions: Vec<Option<f64>>,
initial_state: Option<NetworkProcessState>,
}
impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
pub fn new(
net: &'a T,
seed: Option<u64>,
initial_state: Option<NetworkProcessState>,
) -> ForwardSampler<'a, T> {
impl<'a, T: Network> ForwardSampler<'a, T> {
pub fn new(net: &'a T, seed: Option<u64>) -> ForwardSampler<'a, T> {
let rng: ChaCha8Rng = match seed {
//If a seed is present use it to initialize the random generator.
Some(seed) => SeedableRng::seed_from_u64(seed),
@@ -42,20 +29,19 @@ impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
None => SeedableRng::from_entropy(),
};
let mut fs = ForwardSampler {
net,
rng,
net: net,
rng: rng,
current_time: 0.0,
current_state: vec![],
next_transitions: vec![],
initial_state,
};
fs.reset();
return fs;
}
}
impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
type Item = Sample;
impl<'a, T: Network> Iterator for ForwardSampler<'a, T> {
type Item = (f64, Vec<params::StateType>);
fn next(&mut self) -> Option<Self::Item> {
let ret_time = self.current_time.clone();
@@ -108,26 +94,18 @@ impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
self.next_transitions[child] = None;
}
Some(Sample {
t: ret_time,
state: ret_state,
})
Some((ret_time, ret_state))
}
}
impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> {
impl<'a, T: Network> Sampler for ForwardSampler<'a, T> {
fn reset(&mut self) {
self.current_time = 0.0;
match &self.initial_state {
None => {
self.current_state = self
.net
.get_node_indices()
.map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng))
.collect()
}
Some(is) => self.current_state = is.clone(),
};
.collect();
self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect();
}
}
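The sampler above draws sojourn times from exponential distributions parameterized by the CIM. A minimal sketch of inverse-transform sampling with the `rand`/`rand_chacha` crates it depends on (the exit rate `q` is a hypothetical value):

```rust
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;

fn main() {
    let mut rng = ChaCha8Rng::seed_from_u64(42);
    let q = 2.0f64; // exit rate of the current state
    let u: f64 = 1.0 - rng.gen::<f64>(); // uniform in (0, 1]
    let t = -u.ln() / q; // Exp(q)-distributed sojourn time
    println!("sampled sojourn time: {t:.4}");
}
```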

@@ -1,13 +1,11 @@
//! Learn the structure of the network.
pub mod constraint_based_algorithm;
pub mod hypothesis_test;
pub mod score_based_algorithm;
pub mod score_function;
use crate::{process, tools::Dataset};
use crate::{network, tools};
pub trait StructureLearningAlgorithm {
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where
T: process::NetworkProcess;
T: network::Network;
}

@@ -1,348 +1,3 @@
//! Module containing constraint based algorithms like CTPC and Hiton.
use crate::params::Params;
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
use std::collections::{BTreeSet, HashMap};
use std::mem;
use std::usize;
use super::hypothesis_test::*;
use crate::parameter_learning::ParameterLearning;
use crate::process;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::tools::Dataset;
pub struct Cache<'a, P: ParameterLearning> {
parameter_learning: &'a P,
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>,
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>,
parent_set_size_small: usize,
}
impl<'a, P: ParameterLearning> Cache<'a, P> {
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
Cache {
parameter_learning,
cache_persistent_small: HashMap::new(),
cache_persistent_big: HashMap::new(),
parent_set_size_small: 0,
}
}
pub fn fit<T: process::NetworkProcess>(
&mut self,
net: &T,
dataset: &Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params {
let parent_set_len = parent_set.as_ref().unwrap().len();
if parent_set_len > self.parent_set_size_small + 1 {
//self.cache_persistent_small = self.cache_persistent_big;
mem::swap(
&mut self.cache_persistent_small,
&mut self.cache_persistent_big,
);
self.cache_persistent_big = HashMap::new();
self.parent_set_size_small += 1;
}
if parent_set_len > self.parent_set_size_small {
match self.cache_persistent_big.get(&parent_set) {
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
// not cloning requires a minor and reasoned refactoring across the library
Some(params) => params.clone(),
None => {
let params =
self.parameter_learning
.fit(net, dataset, node, parent_set.clone());
self.cache_persistent_big.insert(parent_set, params.clone());
params
}
}
} else {
match self.cache_persistent_small.get(&parent_set) {
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
// not cloning requires a minor and reasoned refactoring across the library
Some(params) => params.clone(),
None => {
let params =
self.parameter_learning
.fit(net, dataset, node, parent_set.clone());
self.cache_persistent_small
.insert(parent_set, params.clone());
params
}
}
}
}
}
/// Continuous-Time Peter Clark algorithm.
///
/// A method to learn the structure of the network.
///
/// # Arguments
///
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hypothesis test.
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test.
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// use reCTBN::parameter_learning::BayesianApproach;
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
/// #
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("A"));
/// # domain.insert(String::from("B"));
/// # domain.insert(String::from("C"));
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
/// # //Create the node n1 using the parameters
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("D"));
/// # domain.insert(String::from("E"));
/// # domain.insert(String::from("F"));
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("G"));
/// # domain.insert(String::from("H"));
/// # domain.insert(String::from("I"));
/// # domain.insert(String::from("F"));
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # // Initialize a ctbn
/// # let mut net = CtbnNetwork::new();
/// #
/// # // Add the nodes and their edges
/// # let n1 = net.add_node(n1).unwrap();
/// # let n2 = net.add_node(n2).unwrap();
/// # let n3 = net.add_node(n3).unwrap();
/// # net.add_edge(n1, n2);
/// # net.add_edge(n1, n3);
/// # net.add_edge(n2, n3);
/// #
/// # match &mut net.get_node_mut(n1) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-3.0, 2.0, 1.0],
/// # [1.5, -2.0, 0.5],
/// # [0.4, 0.6, -1.0]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # match &mut net.get_node_mut(n2) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-1.0, 0.5, 0.5],
/// # [3.0, -4.0, 1.0],
/// # [0.9, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 2.0, 4.0],
/// # [1.5, -2.0, 0.5],
/// # [3.0, 1.0, -4.0]
/// # ],
/// # [
/// # [-1.0, 0.1, 0.9],
/// # [2.0, -2.5, 0.5],
/// # [0.9, 0.1, -1.0]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # match &mut net.get_node_mut(n3) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-1.0, 0.5, 0.3, 0.2],
/// # [0.5, -4.0, 2.5, 1.0],
/// # [2.5, 0.5, -4.0, 1.0],
/// # [0.7, 0.2, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 2.0, 3.0, 1.0],
/// # [1.5, -3.0, 0.5, 1.0],
/// # [2.0, 1.3, -5.0, 1.7],
/// # [2.5, 0.5, 1.0, -4.0]
/// # ],
/// # [
/// # [-1.3, 0.3, 0.1, 0.9],
/// # [1.4, -4.0, 0.5, 2.1],
/// # [1.0, 1.5, -3.0, 0.5],
/// # [0.4, 0.3, 0.1, -0.8]
/// # ],
/// # [
/// # [-2.0, 1.0, 0.7, 0.3],
/// # [1.3, -5.9, 2.7, 1.9],
/// # [2.0, 1.5, -4.0, 0.5],
/// # [0.2, 0.7, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 1.0, 2.0, 3.0],
/// # [0.5, -3.0, 1.0, 1.5],
/// # [1.4, 2.1, -4.3, 0.8],
/// # [0.5, 1.0, 2.5, -4.0]
/// # ],
/// # [
/// # [-1.3, 0.9, 0.3, 0.1],
/// # [0.1, -1.3, 0.2, 1.0],
/// # [0.5, 1.0, -3.0, 1.5],
/// # [0.1, 0.4, 0.3, -0.8]
/// # ],
/// # [
/// # [-2.0, 1.0, 0.6, 0.4],
/// # [2.6, -7.1, 1.4, 3.1],
/// # [5.0, 1.0, -8.0, 2.0],
/// # [1.4, 0.4, 0.2, -2.0]
/// # ],
/// # [
/// # [-3.0, 1.0, 1.5, 0.5],
/// # [3.0, -6.0, 1.0, 2.0],
/// # [0.3, 0.5, -1.9, 1.1],
/// # [5.0, 1.0, 2.0, -8.0]
/// # ],
/// # [
/// # [-2.6, 0.6, 0.2, 1.8],
/// # [2.0, -6.0, 3.0, 1.0],
/// # [0.1, 0.5, -1.3, 0.7],
/// # [0.8, 0.6, 0.2, -1.6]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # // Generate the trajectory
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
///
/// // Initialize the hypothesis tests to pass to the CTPC with their
/// // respective significance level `alpha`
/// let f = F::new(1e-6);
/// let chi_sq = ChiSquare::new(1e-4);
/// // Use the Bayesian approach to learn the parameters
/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
///
/// //Initialize CTPC
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
///
/// // Learn the structure of the network from the generated trajectory
/// let net = ctpc.fit_transform(net, &data);
/// #
/// # // Compare the generated network with the original one
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
/// ```
pub struct CTPC<P: ParameterLearning> {
parameter_learning: P,
Ftest: F,
Chi2test: ChiSquare,
}
impl<P: ParameterLearning> CTPC<P> {
pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> {
CTPC {
parameter_learning,
Ftest,
Chi2test,
}
}
}
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
where
T: process::NetworkProcess,
{
//Check the coherence between dataset and network
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
panic!("Dataset and Network must have the same number of variables.")
}
//Make the network mutable.
let mut net = net;
net.initialize_adj_matrix();
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
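// Learn each node's parent set in parallel with rayon; the learned edges are
// applied to the network sequentially afterwards, since the network cannot
// be mutated from multiple threads.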
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| {
let mut cache = Cache::new(&self.parameter_learning);
let mut candidate_parent_set: BTreeSet<usize> = net
.get_node_indices()
.into_iter()
.filter(|x| x != &child_node)
.collect();
let mut separation_set_size = 0;
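// Grow separation sets from size 0 upwards: a candidate parent is dropped as
// soon as both the F-test and the chi-squared test accept independence of
// the child from that parent given some separation set of the current size.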
while separation_set_size < candidate_parent_set.len() {
let mut candidate_parent_set_TMP = candidate_parent_set.clone();
for parent_node in candidate_parent_set.iter() {
for separation_set in candidate_parent_set
.iter()
.filter(|x| x != &parent_node)
.map(|x| *x)
.combinations(separation_set_size)
{
let separation_set = separation_set.into_iter().collect();
if self.Ftest.call(
&net,
child_node,
*parent_node,
&separation_set,
dataset,
&mut cache,
) && self.Chi2test.call(
&net,
child_node,
*parent_node,
&separation_set,
dataset,
&mut cache,
) {
candidate_parent_set_TMP.remove(parent_node);
break;
}
}
}
candidate_parent_set = candidate_parent_set_TMP;
separation_set_size += 1;
}
(child_node, candidate_parent_set)
}));
for (child_node, candidate_parent_set) in learned_parent_sets {
for parent_node in candidate_parent_set.iter() {
net.add_edge(*parent_node, child_node);
}
}
net
}
}
//pub struct CTPC {
//
//}

@ -1,13 +1,10 @@
//! Module containing the hypothesis tests (chi-squared test, F-test, etc.) used by constraint-based algorithms.
use std::collections::BTreeSet;
use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
use statrs::distribution::{ChiSquared, ContinuousCDF};
use crate::params::*;
use crate::structure_learning::constraint_based_algorithm::Cache;
use crate::{parameter_learning, process, tools::Dataset};
use crate::{network, parameter_learning};
pub trait HypothesisTest {
fn call<T, P>(
@ -16,167 +13,23 @@ pub trait HypothesisTest {
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
dataset: &Dataset,
cache: &mut Cache<P>,
cache: &mut parameter_learning::Cache<P>,
) -> bool
where
T: process::NetworkProcess,
T: network::Network,
P: parameter_learning::ParameterLearning;
}
/// Performs the chi-squared (χ²) test.
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - the significance level, i.e. the probability of rejecting a
/// true null hypothesis; in other words, the risk of concluding that an
/// association between the variables exists when there is none.
pub struct ChiSquare {
alpha: f64,
}
/// Performs the F-test.
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - the significance level, i.e. the probability of rejecting a
/// true null hypothesis; in other words, the risk of concluding that an
/// association between the variables exists when there is none.
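///
/// # Example
///
/// ```rust,ignore
/// // A minimal sketch (assuming the `net` and `data` built as in the CTPC
/// // example above): both hypothesis tests expose the same `call` interface
/// // and return `true` when the child is judged independent of the parent
/// // given the separation set.
/// use std::collections::BTreeSet;
/// use reCTBN::parameter_learning::BayesianApproach;
/// use reCTBN::structure_learning::constraint_based_algorithm::Cache;
/// use reCTBN::structure_learning::hypothesis_test::{F, HypothesisTest};
///
/// let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
/// let mut cache = Cache::new(&parameter_learning);
/// let f = F::new(1e-6);
/// let independent = f.call(&net, 1, 0, &BTreeSet::new(), &data, &mut cache);
/// ```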
pub struct F {
alpha: f64,
}
impl F {
pub fn new(alpha: f64) -> F {
F { alpha }
}
/// Compares two matrices extracted from two 3rd-order tensors.
///
/// # Arguments
///
/// * `i` - Index of the matrix in `M1` (and in `cim_1`) to compare with `M2`.
/// * `M1` - 3rd-order tensor of transition counts 1.
/// * `cim_1` - 3rd-order tensor of CIMs 1.
/// * `j` - Index of the matrix in `M2` (and in `cim_2`) to compare with `M1`.
/// * `M2` - 3rd-order tensor of transition counts 2.
/// * `cim_2` - 3rd-order tensor of CIMs 2.
///
/// # Returns
///
/// * `true` - when the matrices `M1` and `M2` are very similar, hence **independent**.
/// * `false` - when the matrices `M1` and `M2` are too different, hence **dependent**.
pub fn compare_matrices(
&self,
i: usize,
M1: &Array3<usize>,
cim_1: &Array3<f64>,
j: usize,
M2: &Array3<usize>,
cim_2: &Array3<f64>,
) -> bool {
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
let cim_1 = cim_1.index_axis(Axis(0), i);
let cim_2 = cim_2.index_axis(Axis(0), j);
let r1 = M1.sum_axis(Axis(1));
let r2 = M2.sum_axis(Axis(1));
let q1 = cim_1.diag();
let q2 = cim_2.diag();
for idx in 0..r1.shape()[0] {
let s = q2[idx] / q1[idx];
let F = FisherSnedecor::new(r1[idx], r2[idx]).unwrap();
let s = F.cdf(s);
let lim_sx = self.alpha / 2.0;
let lim_dx = 1.0 - (self.alpha / 2.0);
if s < lim_sx || s > lim_dx {
return false;
}
}
true
}
}
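// A toy sketch of `F::compare_matrices` (hypothetical counts and CIMs over
// the same binary state space; index 0 selects the only matrix inside each
// 3rd-order tensor):
//
// let m1 = ndarray::arr3(&[[[0usize, 3], [2, 0]]]);
// let cim1 = ndarray::arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]);
// let m2 = ndarray::arr3(&[[[0usize, 30], [20, 0]]]);
// let cim2 = ndarray::arr3(&[[[-3.1, 3.1], [2.1, -2.1]]]);
// let f = F::new(1e-6);
// // Expected `true`: the exit rates of the two CIMs are nearly identical.
// let _independent = f.compare_matrices(0, &m1, &cim1, 0, &m2, &cim2);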
impl HypothesisTest for F {
fn call<T, P>(
&self,
net: &T,
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
dataset: &Dataset,
cache: &mut Cache<P>,
) -> bool
where
T: process::NetworkProcess,
P: parameter_learning::ParameterLearning,
{
let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node,
};
let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node);
let P_big = match cache.fit(
net,
&dataset,
child_node,
Some(extended_separation_set.clone()),
) {
Params::DiscreteStatesContinousTime(node) => node,
};
let partial_cardinality_product: usize = extended_separation_set
.iter()
.take_while(|x| **x != parent_node)
.map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product();
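// Map each parent configuration of the extended separation set onto the
// corresponding configuration of the separation set alone, by stripping the
// mixed-radix digit contributed by `parent_node` from the index.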
for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
let idx_M_small: usize = idx_M_big % partial_cardinality_product
+ (idx_M_big
/ (partial_cardinality_product
* net.get_node(parent_node).get_reserved_space_as_parent()))
* partial_cardinality_product;
if !self.compare_matrices(
idx_M_small,
P_small.get_transitions().as_ref().unwrap(),
P_small.get_cim().as_ref().unwrap(),
idx_M_big,
P_big.get_transitions().as_ref().unwrap(),
P_big.get_cim().as_ref().unwrap(),
) {
return false;
}
}
return true;
}
}
pub struct F {}
impl ChiSquare {
pub fn new(alpha: f64) -> ChiSquare {
ChiSquare { alpha }
}
/// Compares two matrices extracted from two 3rd-order tensors.
///
/// # Arguments
///
/// * `i` - Index of the matrix in `M1` to compare with the one in `M2`.
/// * `M1` - 3rd-order tensor 1.
/// * `j` - Index of the matrix in `M2` to compare with the one in `M1`.
/// * `M2` - 3rd-order tensor 2.
///
/// # Returns
///
/// * `true` - when the matrices `M1` and `M2` are very similar, hence **independent**.
/// * `false` - when the matrices `M1` and `M2` are too different, hence **dependent**.
pub fn compare_matrices(
&self,
i: usize,
@ -189,8 +42,26 @@ impl ChiSquare {
// continuous-time Bayesian networks.
// International Journal of Approximate Reasoning, 138, pp.105-122.
// Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
//
// M1 = M_{xx'|s}     M2 = M_{xx'|y,s}
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
// K = sqrt( Σ_{x'∈Val(X_i)} M_{xx'|s}  /  Σ_{x'∈Val(X_i)} M_{xx'|y,s} )
// L = 1 / K
let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
let K = K.mapv(f64::sqrt);
// Reshape to column vector.
@ -199,12 +70,26 @@ impl ChiSquare {
K.into_shape((n, 1)).unwrap()
};
let L = 1.0 / &K;
// X² = Σ_{x'∈Val(X_i)} (K·M2 − L·M1)² / (M2 + M1)
let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
println!("M1: {:?}", M1);
println!("M2: {:?}", M2);
println!("L*M1: {:?}", (L * &M1));
println!("K*M2: {:?}", (K * &M2));
println!("X_2: {:?}", X_2);
X_2.diag_mut().fill(0.0);
let X_2 = X_2.sum_axis(Axis(1));
let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
ret
println!("CHI^2: {:?}", n);
println!("CHI^2 CDF: {:?}", X_2.mapv(|x| n.cdf(x)));
X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha))
}
}
@ -215,27 +100,26 @@ impl HypothesisTest for ChiSquare {
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
dataset: &Dataset,
cache: &mut Cache<P>,
cache: &mut parameter_learning::Cache<P>,
) -> bool
where
T: process::NetworkProcess,
T: network::Network,
P: parameter_learning::ParameterLearning,
{
let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
// Fetch the learned parameters from the cache, i.e. a CIM
// of size n×n
// (CIM, M, T)
let P_small = match cache.fit(net, child_node, Some(separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node,
};
//
let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node);
let P_big = match cache.fit(
net,
&dataset,
child_node,
Some(extended_separation_set.clone()),
) {
let P_big = match cache.fit(net, child_node, Some(extended_separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node,
};
// TODO: comment here
let partial_cardinality_product: usize = extended_separation_set
.iter()
.take_while(|x| **x != parent_node)

@ -1,13 +1,8 @@
//! Module containing score-based algorithms like Hill Climbing and Tabu Search.
use std::collections::BTreeSet;
use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::{process, tools::Dataset};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
use crate::{network, tools};
pub struct HillClimbing<S: ScoreFunction> {
score_function: S,
@ -24,9 +19,9 @@ impl<S: ScoreFunction> HillClimbing<S> {
}
impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T
where
T: process::NetworkProcess,
T: network::Network,
{
//Check the coherence between dataset and network
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
@ -39,9 +34,8 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
//Reset the adj matrix
net.initialize_adj_matrix();
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
//Iterate over each node to learn their parent set.
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| {
for node in net.get_node_indices() {
//Initialize an empty parent set.
let mut parent_set: BTreeSet<usize> = BTreeSet::new();
//Compute the score for the empty parent set
@ -80,14 +74,10 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
}
}
}
(node, parent_set)
}));
for (child_node, candidate_parent_set) in learned_parent_sets {
for parent_node in candidate_parent_set.iter() {
net.add_edge(*parent_node, child_node);
}
//Apply the learned parent_set to the network struct.
parent_set.iter().for_each(|p| net.add_edge(*p, node));
}
return net;
}
}

@ -1,13 +1,11 @@
//! Module containing score functions for score-based algorithms, like log-likelihood, BIC, etc.
use std::collections::BTreeSet;
use ndarray::prelude::*;
use statrs::function::gamma;
use crate::{parameter_learning, params, process, tools};
use crate::{network, parameter_learning, params, tools};
pub trait ScoreFunction: Sync {
pub trait ScoreFunction {
fn call<T>(
&self,
net: &T,
@ -16,7 +14,7 @@ pub trait ScoreFunction: Sync {
dataset: &tools::Dataset,
) -> f64
where
T: process::NetworkProcess;
T: network::Network;
}
pub struct LogLikelihood {
@ -41,7 +39,7 @@ impl LogLikelihood {
dataset: &tools::Dataset,
) -> (f64, Array3<usize>)
where
T: process::NetworkProcess,
T: network::Network,
{
//Identify the type of node used
match &net.get_node(node) {
@ -100,7 +98,7 @@ impl ScoreFunction for LogLikelihood {
dataset: &tools::Dataset,
) -> f64
where
T: process::NetworkProcess,
T: network::Network,
{
self.compute_score(net, node, parent_set, dataset).0
}
@ -127,7 +125,7 @@ impl ScoreFunction for BIC {
dataset: &tools::Dataset,
) -> f64
where
T: process::NetworkProcess,
T: network::Network,
{
//Compute the log-likelihood
let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);

@ -1,15 +1,7 @@
//! Contains commonly used methods shared across the crate.
use ndarray::prelude::*;
use std::ops::{DivAssign, MulAssign, Range};
use ndarray::{Array, Array1, Array2, Array3, Axis};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::params::ParamsTrait;
use crate::process::NetworkProcess;
use crate::sampling::{ForwardSampler, Sampler};
use crate::{params, process};
use crate::{network, params};
#[derive(Clone)]
pub struct Trajectory {
@ -36,7 +28,6 @@ impl Trajectory {
}
}
#[derive(Clone)]
pub struct Dataset {
trajectories: Vec<Trajectory>,
}
@ -59,7 +50,7 @@ impl Dataset {
}
}
pub fn trajectory_generator<T: process::NetworkProcess>(
pub fn trajectory_generator<T: network::Network>(
net: &T,
n_trajectories: u64,
t_end: f64,
@ -69,25 +60,26 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
let mut trajectories: Vec<Trajectory> = Vec::new();
//Random Generator object
let mut sampler = ForwardSampler::new(net, seed, None);
let mut sampler = ForwardSampler::new(net, seed);
//Each iteration generate one trajectory
for _ in 0..n_trajectories {
//History of all the moments in which something changed
let mut time: Vec<f64> = Vec::new();
//Configuration of the process variables at time t, initialized with a
//uniform distribution.
let mut events: Vec<process::NetworkProcessState> = Vec::new();
let mut events: Vec<Vec<params::StateType>> = Vec::new();
//Current Time and Current State
let mut sample = sampler.next().unwrap();
let (mut t, mut current_state) = sampler.next().unwrap();
//Generate new samples until ending time is reached.
while sample.t < t_end {
time.push(sample.t);
events.push(sample.state);
sample = sampler.next().unwrap();
while t < t_end {
time.push(t);
events.push(current_state);
(t, current_state) = sampler.next().unwrap();
}
let current_state = events.last().unwrap().clone();
current_state = events.last().unwrap().clone();
events.push(current_state);
//Add t_end as last time.
@ -113,243 +105,3 @@ pub fn trajectory_generator<T: process::NetworkProcess>(
//Return a dataset object with the sampled trajectories.
Dataset::new(trajectories)
}
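// A minimal usage sketch (assuming `net` is an already-parameterized network,
// e.g. built with the generators below): sample 10 trajectories up to
// t = 20.0 with a fixed seed for reproducibility.
//
// let data = trajectory_generator(&net, 10, 20.0, Some(42));
// assert_eq!(10, data.get_trajectories().len());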
pub trait RandomGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> Self;
fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T);
}
/// Graph Generator using a uniform distribution.
///
/// A method to generate a random graph with uniformly distributed edges.
///
/// # Arguments
///
/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`.
/// * `rng` - is the random numbers generator.
///
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// use reCTBN::tools::UniformGraphGenerator;
/// use reCTBN::tools::RandomGraphGenerator;
/// # let mut net = CtbnNetwork::new();
/// # let nodes_cardinality = 8;
/// # let domain_cardinality = 4;
/// # for node in 0..nodes_cardinality {
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # for dvalue in 0..domain_cardinality {
/// # domain.insert(dvalue.to_string());
/// # }
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new(
/// # node.to_string(),
/// # domain
/// # );
/// # //Create the node using the parameters
/// # let node = DiscreteStatesContinousTime(param);
/// # // Add the node to the network
/// # net.add_node(node).unwrap();
/// # }
///
/// // Initialize the Graph Generator using a uniform
/// // distribution
/// let density = 1.0/3.0;
/// let seed = Some(7641630759785120);
/// let mut structure_generator = UniformGraphGenerator::new(
/// density,
/// seed
/// );
///
/// // Generate the graph directly on the network
/// structure_generator.generate_graph(&mut net);
/// # // Count all the edges generated in the network
/// # let mut edges = 0;
/// # for node in net.get_node_indices(){
/// # edges += net.get_children_set(node).len()
/// # }
/// # // Number of all the nodes in the network
/// # let nodes = net.get_node_indices().len() as f64;
/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
/// # // ±10% of tolerance
/// # let tolerance = ((expected_edges as f64)*0.10) as usize;
/// # // Given how `generate_graph()` is implemented, we can only reasonably
/// # // expect the number of edges to be somewhere around the expected value.
/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
/// ```
pub struct UniformGraphGenerator {
density: f64,
rng: ChaCha8Rng,
}
impl RandomGraphGenerator for UniformGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> UniformGraphGenerator {
if density < 0.0 || density > 1.0 {
panic!(
"Density value must be between 1.0 and 0.0, got {}.",
density
);
}
let rng: ChaCha8Rng = match seed {
Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(),
};
UniformGraphGenerator { density, rng }
}
/// Generate a uniformly distributed graph.
fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T) {
net.initialize_adj_matrix();
let last_node_idx = net.get_node_indices().len();
for parent in 0..last_node_idx {
for child in 0..last_node_idx {
if parent != child && self.rng.gen_bool(self.density) {
net.add_edge(parent, child);
}
}
}
}
}
pub trait RandomParametersGenerator {
fn new(interval: Range<f64>, seed: Option<u64>) -> Self;
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T);
}
/// Parameters Generator using a uniform distribution.
///
/// A method to generate uniformly distributed random parameters.
///
/// # Arguments
///
/// * `interval` - is the interval of the random values on the CIM's diagonal; domain: `≥ 0.0`.
/// * `rng` - is the random numbers generator.
///
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::params::ParamsTrait;
/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::tools::RandomGraphGenerator;
/// # use reCTBN::tools::UniformGraphGenerator;
/// use reCTBN::tools::RandomParametersGenerator;
/// use reCTBN::tools::UniformParametersGenerator;
/// # let mut net = CtbnNetwork::new();
/// # let nodes_cardinality = 8;
/// # let domain_cardinality = 4;
/// # for node in 0..nodes_cardinality {
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # for dvalue in 0..domain_cardinality {
/// # domain.insert(dvalue.to_string());
/// # }
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new(
/// # node.to_string(),
/// # domain
/// # );
/// # //Create the node using the parameters
/// # let node = DiscreteStatesContinousTime(param);
/// # // Add the node to the network
/// # net.add_node(node).unwrap();
/// # }
/// #
/// # // Initialize the Graph Generator using a uniform
/// # // distribution
/// # let mut structure_generator = UniformGraphGenerator::new(
/// # 1.0/3.0,
/// # Some(7641630759785120)
/// # );
/// #
/// # // Generate the graph directly on the network
/// # structure_generator.generate_graph(&mut net);
///
/// // Initialize the parameters generator with a uniform distribution
/// let mut cim_generator = UniformParametersGenerator::new(
/// 0.0..7.0,
/// Some(7641630759785120)
/// );
///
/// // Generate CIMs with uniformly distributed parameters.
/// cim_generator.generate_parameters(&mut net);
/// #
/// # for node in net.get_node_indices() {
/// # assert_eq!(
/// # Ok(()),
/// # net.get_node(node).validate_params()
/// # );
/// # }
/// ```
pub struct UniformParametersGenerator {
interval: Range<f64>,
rng: ChaCha8Rng,
}
impl RandomParametersGenerator for UniformParametersGenerator {
fn new(interval: Range<f64>, seed: Option<u64>) -> UniformParametersGenerator {
if interval.start < 0.0 || interval.end < 0.0 {
panic!(
"Interval must be entirely less or equal than 0, got {}..{}.",
interval.start, interval.end
);
}
let rng: ChaCha8Rng = match seed {
Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(),
};
UniformParametersGenerator { interval, rng }
}
/// Generate CIMs with uniformly distributed parameters.
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) {
for node in net.get_node_indices() {
let parent_set_state_space_cardinality: usize = net
.get_parent_set(node)
.iter()
.map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product();
match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => {
let node_domain_cardinality = param.get_reserved_space_as_parent();
let mut cim = Array3::<f64>::from_shape_fn(
(
parent_set_state_space_cardinality,
node_domain_cardinality,
node_domain_cardinality,
),
|_| self.rng.gen(),
);
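// For each parent configuration: zero the diagonal, normalize every row to
// sum to 1, scale the rows by rates drawn uniformly from `interval`, and
// finally set the diagonal to minus the row sum so that each row of the
// CIM sums to 0.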
cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
x.diag_mut().fill(0.0);
x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1)));
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
self.rng.gen_range(self.interval.clone())
});
x.mul_assign(&diag.clone().insert_axis(Axis(1)));
// Recomputing the diagonal in order to reduce the issues caused by the
// loss of precision when validating the parameters.
let diag_sum = -x.sum_axis(Axis(1));
x.diag_mut().assign(&diag_sum)
});
param.set_cim_unchecked(cim);
}
}
}
}
}

@ -1,12 +1,9 @@
mod utils;
use std::collections::BTreeSet;
use approx::AbsDiffEq;
use ndarray::arr3;
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::params::{self, ParamsTrait};
use reCTBN::process::NetworkProcess;
use reCTBN::process::{ctbn::*};
use utils::generate_discrete_time_continous_node;
#[test]
@ -132,245 +129,3 @@ fn compute_index_from_custom_parent_set() {
);
assert_eq!(2, idx);
}
#[test]
fn simple_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.initialize_adj_matrix();
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0);
let p_ctbn = p_ctbn.get_cim().as_ref().unwrap();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON));
}
#[test]
fn chain_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
let p_ctmp_handmade = arr3(&[[
[
-1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}
#[test]
fn chainfork_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
let n4 = net
.add_node(generate_discrete_time_continous_node(String::from("n4"), 2))
.unwrap();
net.add_edge(n1, n3);
net.add_edge(n2, n3);
net.add_edge(n3, n4);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n4) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
let p_ctmp_handmade = arr3(&[[
[
-2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
-5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}

@ -1,127 +0,0 @@
mod utils;
use std::collections::BTreeSet;
use reCTBN::{
params,
params::ParamsTrait,
process::{ctmp::*, NetworkProcess},
};
use utils::generate_discrete_time_continous_node;
#[test]
fn define_simple_ctmp() {
let _ = CtmpProcess::new();
assert!(true);
}
#[test]
fn add_node_to_ctmp() {
let mut net = CtmpProcess::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
}
#[test]
fn add_two_nodes_to_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
match n2 {
Ok(_) => assert!(false),
Err(_) => assert!(true),
};
}
#[test]
#[should_panic]
fn add_edge_to_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
net.add_edge(0, 1)
}
#[test]
fn childen_and_parents() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(0, net.get_parent_set(0).len());
assert_eq!(0, net.get_children_set(0).len());
}
#[test]
#[should_panic]
fn get_childen_panic() {
let net = CtmpProcess::new();
net.get_children_set(0);
}
#[test]
#[should_panic]
fn get_childen_panic2() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.get_children_set(1);
}
#[test]
#[should_panic]
fn get_parent_panic() {
let net = CtmpProcess::new();
net.get_parent_set(0);
}
#[test]
#[should_panic]
fn get_parent_panic2() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.get_parent_set(1);
}
#[test]
fn compute_index_ctmp() {
let mut net = CtmpProcess::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(
String::from("n1"),
10,
))
.unwrap();
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]);
assert_eq!(6, idx);
}
#[test]
#[should_panic]
fn compute_index_from_custom_parent_set_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(
String::from("n1"),
10,
))
.unwrap();
let _idx = net.get_param_index_from_custom_parent_set(
&vec![params::StateType::Discrete(6)],
&BTreeSet::from([0])
);
}

@ -2,11 +2,10 @@
mod utils;
use ndarray::arr3;
use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess;
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::parameter_learning::*;
use reCTBN::params;
use reCTBN::params::Params::DiscreteStatesContinousTime;
use reCTBN::tools::*;
use utils::*;
@ -67,78 +66,18 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
));
}
fn generate_nodes(
net: &mut CtbnNetwork,
nodes_cardinality: usize,
nodes_domain_cardinality: usize
) {
for node_label in 0..nodes_cardinality {
net.add_node(
generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
}
fn learn_binary_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 2);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(1) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 1, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test]
fn learn_binary_cim_MLE() {
let mle = MLE {};
learn_binary_cim(mle);
}
#[test]
fn learn_binary_cim_MLE_gen() {
let mle = MLE {};
learn_binary_cim_gen(mle);
}
#[test]
fn learn_binary_cim_BA() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_binary_cim(ba);
}
#[test]
fn learn_binary_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_binary_cim_gen(ba);
}
fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -216,63 +155,18 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
));
}
fn learn_ternary_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
4.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(1) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 1, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test]
fn learn_ternary_cim_MLE() {
let mle = MLE {};
learn_ternary_cim(mle);
}
#[test]
fn learn_ternary_cim_MLE_gen() {
let mle = MLE {};
learn_ternary_cim_gen(mle);
}
#[test]
fn learn_ternary_cim_BA() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim(ba);
}
#[test]
fn learn_ternary_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_gen(ba);
}
fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -340,63 +234,18 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
));
}
fn learn_ternary_cim_no_parents_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(0) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 0, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test]
fn learn_ternary_cim_no_parents_MLE() {
let mle = MLE {};
learn_ternary_cim_no_parents(mle);
}
#[test]
fn learn_ternary_cim_no_parents_MLE_gen() {
let mle = MLE {};
learn_ternary_cim_no_parents_gen(mle);
}
#[test]
fn learn_ternary_cim_no_parents_BA() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_no_parents(ba);
}
#[test]
fn learn_ternary_cim_no_parents_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_no_parents_gen(ba);
}
fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -583,66 +432,14 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
));
}
fn learn_mixed_discrete_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_node(
generate_discrete_time_continous_node(
String::from("3"),
4
)
).unwrap();
net.add_edge(0, 1);
net.add_edge(0, 2);
net.add_edge(1, 2);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..8.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(2) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 2, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.2
)
);
}
#[test]
fn learn_mixed_discrete_cim_MLE() {
let mle = MLE {};
learn_mixed_discrete_cim(mle);
}
#[test]
fn learn_mixed_discrete_cim_MLE_gen() {
let mle = MLE {};
learn_mixed_discrete_cim_gen(mle);
}
#[test]
fn learn_mixed_discrete_cim_BA() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_mixed_discrete_cim(ba);
}
#[test]
fn learn_mixed_discrete_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_mixed_discrete_cim_gen(ba);
}

@ -1,122 +0,0 @@
mod utils;
use approx::assert_abs_diff_eq;
use ndarray::*;
use reCTBN::{
params,
process::{ctbn::*, NetworkProcess, NetworkProcessState},
reward::{reward_evaluation::*, reward_function::*, *},
};
use utils::generate_discrete_time_continous_node;
#[test]
fn simple_factored_reward_function_binary_node_mc() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1)
.assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1)
.assign(&arr1(&[3.0, 3.0]));
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
}
}
net.initialize_adj_matrix();
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
let rst = mc.evaluate_state_space(&net, &rf);
assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215));
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
}
#[test]
fn simple_factored_reward_function_chain_mc() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap();
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
param
.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]],
]))
.unwrap();
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
param
.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]],
]))
.unwrap();
}
}
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
rf.get_transition_reward_mut(n2)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
rf.get_transition_reward_mut(n3)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
let s000: NetworkProcessState = vec![
params::StateType::Discrete(1),
params::StateType::Discrete(0),
params::StateType::Discrete(0),
];
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
let rst = mc.evaluate_state_space(&net, &rf);
assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1);
}

@ -1,117 +0,0 @@
mod utils;
use ndarray::*;
use utils::generate_discrete_time_continous_node;
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params};
#[test]
fn simple_factored_reward_function_binary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
}
#[test]
fn simple_factored_reward_function_ternary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
}
#[test]
fn factored_reward_function_two_nodes() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2);
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)];
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
}

@ -4,12 +4,10 @@ mod utils;
use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3};
use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::BayesianApproach;
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::params;
use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::constraint_based_algorithm::*;
use reCTBN::structure_learning::score_based_algorithm::*;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::StructureLearningAlgorithm;
@ -108,7 +106,7 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
}
}
let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259));
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
let mut net = CtbnNetwork::new();
let _n1 = net
@ -117,50 +115,6 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
let _net = sl.fit_transform(net, &data);
}
fn generate_nodes(
net: &mut CtbnNetwork,
nodes_cardinality: usize,
nodes_domain_cardinality: usize
) {
for node_label in 0..nodes_cardinality {
net.add_node(
generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
}
fn check_compatibility_between_dataset_and_network_gen<T: StructureLearningAlgorithm>(sl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_node(
generate_discrete_time_continous_node(
String::from("3"),
4
)
).unwrap();
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
0.0..7.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259));
let mut net = CtbnNetwork::new();
let _n1 = net
.add_node(
generate_discrete_time_continous_node(String::from("0"),
3)
).unwrap();
let _net = sl.fit_transform(net, &data);
}
#[test]
#[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
@ -169,14 +123,6 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
check_compatibility_between_dataset_and_network(hl);
}
#[test]
#[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
check_compatibility_between_dataset_and_network_gen(hl);
}
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -234,25 +180,6 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
}
fn learn_ternary_net_2_nodes_gen<T: StructureLearningAlgorithm>(sl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
0.0..7.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
assert_eq!(BTreeSet::new(), net.get_parent_set(0));
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
let ll = LogLikelihood::new(1, 1.0);
@ -260,13 +187,6 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
learn_ternary_net_2_nodes(hl);
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
learn_ternary_net_2_nodes_gen(hl);
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
let bic = BIC::new(1, 1.0);
@ -274,13 +194,6 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
learn_ternary_net_2_nodes(hl);
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() {
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, None);
learn_ternary_net_2_nodes_gen(hl);
}
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -405,30 +318,6 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
return (net, data);
}
fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_node(
generate_discrete_time_continous_node(
String::from("3"),
4
)
).unwrap();
net.add_edge(0, 1);
net.add_edge(0, 2);
net.add_edge(1, 2);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
0.0..7.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259));
return (net, data);
}
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data);
@ -437,14 +326,6 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
}
fn learn_mixed_discrete_net_3_nodes_gen<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen();
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(0));
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
let ll = LogLikelihood::new(1, 1.0);
@ -452,13 +333,6 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
learn_mixed_discrete_net_3_nodes(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
learn_mixed_discrete_net_3_nodes_gen(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
let bic = BIC::new(1, 1.0);
@ -466,13 +340,6 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
learn_mixed_discrete_net_3_nodes(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() {
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, None);
learn_mixed_discrete_net_3_nodes_gen(hl);
}
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data);
@ -481,14 +348,6 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgo
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
}
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen();
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(0));
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
let ll = LogLikelihood::new(1, 1.0);
@ -496,13 +355,6 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, Some(1));
learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() {
let bic = BIC::new(1, 1.0);
@ -510,13 +362,6 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() {
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, Some(1));
learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl);
}
#[test]
pub fn chi_square_compare_matrices() {
let i: usize = 1;
@ -545,7 +390,7 @@ pub fn chi_square_compare_matrices() {
[ 700, 800, 0]
],
]);
let chi_sq = ChiSquare::new(1e-4);
let chi_sq = ChiSquare::new(0.1);
assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
}
@ -575,7 +420,7 @@ pub fn chi_square_compare_matrices_2() {
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(1e-4);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
}
@ -607,86 +452,6 @@ pub fn chi_square_compare_matrices_3() {
[ 700, 800, 0]
],
]);
let chi_sq = ChiSquare::new(1e-4);
let chi_sq = ChiSquare::new(0.1);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
}
#[test]
pub fn chi_square_call() {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let N3: usize = 2;
let N2: usize = 1;
let N1: usize = 0;
let mut separation_set = BTreeSet::new();
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let mut cache = Cache::new(&parameter_learning);
let chi_sq = ChiSquare::new(1e-4);
assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache));
let mut cache = Cache::new(&parameter_learning);
assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache));
separation_set.insert(N1);
let mut cache = Cache::new(&parameter_learning);
assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache));
}
#[test]
pub fn f_call() {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let N3: usize = 2;
let N2: usize = 1;
let N1: usize = 0;
let mut separation_set = BTreeSet::new();
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let mut cache = Cache::new(&parameter_learning);
let f = F::new(1e-6);
assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache));
let mut cache = Cache::new(&parameter_learning);
assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache));
separation_set.insert(N1);
let mut cache = Cache::new(&parameter_learning);
assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache));
}
#[test]
pub fn learn_ternary_net_2_nodes_ctpc() {
let f = F::new(1e-6);
let chi_sq = ChiSquare::new(1e-4);
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
learn_ternary_net_2_nodes(ctpc);
}
#[test]
pub fn learn_ternary_net_2_nodes_ctpc_gen() {
let f = F::new(1e-6);
let chi_sq = ChiSquare::new(1e-4);
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
learn_ternary_net_2_nodes_gen(ctpc);
}
#[test]
fn learn_mixed_discrete_net_3_nodes_ctpc() {
let f = F::new(1e-6);
let chi_sq = ChiSquare::new(1e-4);
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
learn_mixed_discrete_net_3_nodes(ctpc);
}
#[test]
fn learn_mixed_discrete_net_3_nodes_ctpc_gen() {
let f = F::new(1e-6);
let chi_sq = ChiSquare::new(1e-4);
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
learn_mixed_discrete_net_3_nodes_gen(ctpc);
}

@ -1,15 +1,9 @@
use std::ops::Range;
use ndarray::{arr1, arr2, arr3};
use reCTBN::params::ParamsTrait;
use reCTBN::process::ctbn::*;
use reCTBN::process::ctmp::*;
use reCTBN::process::NetworkProcess;
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::params;
use reCTBN::tools::*;
use utils::*;
#[macro_use]
extern crate approx;
@ -88,164 +82,3 @@ fn dataset_wrong_shape() {
let t2 = Trajectory::new(time, events);
Dataset::new(vec![t1, t2]);
}
#[test]
#[should_panic]
fn uniform_graph_generator_wrong_density_1() {
let density = 2.1;
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
None
);
}
#[test]
#[should_panic]
fn uniform_graph_generator_wrong_density_2() {
let density = -0.5;
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
None
);
}
#[test]
fn uniform_graph_generator_right_densities() {
for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
None
);
}
}
#[test]
fn uniform_graph_generator_generate_graph_ctbn() {
let mut net = CtbnNetwork::new();
let nodes_cardinality = 0..=100;
let nodes_domain_cardinality = 2;
for node_label in nodes_cardinality {
net.add_node(
utils::generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
let density = 1.0/3.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
Some(7641630759785120)
);
structure_generator.generate_graph(&mut net);
let mut edges = 0;
for node in net.get_node_indices(){
edges += net.get_children_set(node).len()
}
let nodes = net.get_node_indices().len() as f64;
let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% tolerance
// Given how `generate_graph()` is implemented, we can only reasonably
// expect the number of edges to land near the expected value.
assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
}
#[test]
#[should_panic]
fn uniform_graph_generator_generate_graph_ctmp() {
let mut net = CtmpProcess::new();
let node_label = String::from("0");
let node_domain_cardinality = 4;
net.add_node(
generate_discrete_time_continous_node(
node_label,
node_domain_cardinality
)
).unwrap();
let density = 1.0/3.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
Some(7641630759785120)
);
structure_generator.generate_graph(&mut net);
}
#[test]
#[should_panic]
fn uniform_parameters_generator_wrong_interval_1() {
let interval: Range<f64> = -2.0..-5.0;
let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
interval,
None
);
}
#[test]
#[should_panic]
fn uniform_parameters_generator_wrong_interval_2() {
let interval: Range<f64> = -1.0..0.0;
let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
interval,
None
);
}
#[test]
fn uniform_parameters_generator_right_interval_ctbn() {
let mut net = CtbnNetwork::new();
let nodes_cardinality = 0..=3;
let nodes_domain_cardinality = 9;
for node_label in nodes_cardinality {
net.add_node(
generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
let density = 1.0/3.0;
let seed = Some(7641630759785120);
let interval = 0.0..7.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
seed
);
structure_generator.generate_graph(&mut net);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
interval,
seed
);
cim_generator.generate_parameters(&mut net);
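// Every generated CIM must pass the library's validity checks (assumption:
// `validate_params()` verifies square CIMs with non-negative off-diagonal
// entries and rows summing to zero).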
for node in net.get_node_indices() {
assert_eq!(
Ok(()),
net.get_node(node).validate_params()
);
}
}
#[test]
fn uniform_parameters_generator_right_interval_ctmp() {
let mut net = CtmpProcess::new();
let node_label = String::from("0");
let node_domain_cardinality = 4;
net.add_node(
generate_discrete_time_continous_node(
node_label,
node_domain_cardinality
)
).unwrap();
let seed = Some(7641630759785120);
let interval = 0.0..7.0;
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
interval,
seed
);
cim_generator.generate_parameters(&mut net);
for node in net.get_node_indices() {
assert_eq!(
Ok(()),
net.get_node(node).validate_params()
);
}
}

@ -0,0 +1,69 @@
name: CI
on:
push:
branches:
- main
- master
pull_request:
jobs:
linux:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: messense/maturin-action@v1
with:
manylinux: auto
command: build
args: --release --sdist -o dist --find-interpreter
- name: Upload wheels
uses: actions/upload-artifact@v2
with:
name: wheels
path: dist
windows:
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- uses: messense/maturin-action@v1
with:
command: build
args: --release -o dist --find-interpreter
- name: Upload wheels
uses: actions/upload-artifact@v2
with:
name: wheels
path: dist
macos:
runs-on: macos-latest
steps:
- uses: actions/checkout@v3
- uses: messense/maturin-action@v1
with:
command: build
args: --release -o dist --universal2 --find-interpreter
- name: Upload wheels
uses: actions/upload-artifact@v2
with:
name: wheels
path: dist
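# Runs only on tag pushes: collects the wheels uploaded by the three build
# jobs above and publishes them to PyPI via the PYPI_API_TOKEN secret.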
release:
name: Release
runs-on: ubuntu-latest
if: "startsWith(github.ref, 'refs/tags/')"
needs: [ macos, windows, linux ]
steps:
- uses: actions/download-artifact@v2
with:
name: wheels
- name: Publish to PyPI
uses: messense/maturin-action@v1
env:
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
with:
command: upload
args: --skip-existing *

@ -0,0 +1,72 @@
/target
# Byte-compiled / optimized / DLL files
__pycache__/
.pytest_cache/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.Python
.venv/
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
include/
man/
venv/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
pip-selfcheck.json
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# Rope
.ropeproject
# Django stuff:
*.log
*.pot
.DS_Store
# Sphinx documentation
docs/_build/
# PyCharm
.idea/
# VSCode
.vscode/
# Pyenv
.python-version

@ -0,0 +1,15 @@
[package]
name = "reCTBNpy"
version = "0.1.0"
edition = "2021"
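# cdylib: build a shared library that Python can load as an extension module.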
[lib]
crate-type = ["cdylib"]
[workspace]
[dependencies]
pyo3 = { version = "0.17.1", features = ["extension-module"] }
numpy = "*"
reCTBN = { path="../reCTBN" }

@ -0,0 +1,14 @@
[build-system]
requires = ["maturin>=0.13,<0.14"]
build-backend = "maturin"
[project]
name = "reCTBNpy"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]

@ -0,0 +1,21 @@
use pyo3::prelude::*;
pub mod pyctbn;
pub mod pyparams;
pub mod pytools;
/// A Python module implemented in Rust.
#[pymodule]
fn reCTBNpy(py: Python, m: &PyModule) -> PyResult<()> {
let network_module = PyModule::new(py, "network")?;
network_module.add_class::<pyctbn::PyCtbnNetwork>()?;
m.add_submodule(network_module)?;
let params_module = PyModule::new(py, "params")?;
params_module.add_class::<pyparams::PyDiscreteStateContinousTime>()?;
params_module.add_class::<pyparams::PyStateType>()?;
params_module.add_class::<pyparams::PyParams>()?;
m.add_submodule(params_module)?;
Ok(())
}
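// A hypothetical Python-side view of the module registered above (note:
// submodules added with `add_submodule` are reachable as attributes of the
// top-level module, e.g. `reCTBNpy.network`, rather than as independently
// importable modules):
//
//     import reCTBNpy
//     net_cls = reCTBNpy.network.PyCtbnNetwork
//     params_cls = reCTBNpy.params.PyParams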

@ -0,0 +1,68 @@
use std::collections::BTreeSet;
use crate::{pyparams, pytools};
use pyo3::prelude::*;
use reCTBN::{ctbn, network::Network, params, tools, params::Params};
#[pyclass]
pub struct PyCtbnNetwork(pub ctbn::CtbnNetwork);
#[pymethods]
impl PyCtbnNetwork {
#[new]
pub fn new() -> Self {
PyCtbnNetwork(ctbn::CtbnNetwork::new())
}
pub fn add_node(&mut self, n: pyparams::PyParams) {
// Surface insertion failures instead of silently discarding the Result.
self.0.add_node(n.0).unwrap();
}
pub fn get_number_of_nodes(&self) -> usize {
self.0.get_number_of_nodes()
}
pub fn add_edge(&mut self, parent: usize, child: usize) {
self.0.add_edge(parent, child);
}
pub fn get_node_indices(&self) -> BTreeSet<usize> {
self.0.get_node_indices().collect()
}
pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.0.get_parent_set(node)
}
pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.0.get_children_set(node)
}
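/// Replace the CIM of the node at `node_idx` with the one carried by `n`;
/// panics if the node types do not match or the new CIM is invalid.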
pub fn set_node(&mut self, node_idx: usize, n: pyparams::PyParams) {
match &n.0 {
Params::DiscreteStatesContinousTime(new_p) => {
if let Params::DiscreteStatesContinousTime(p) = self.0.get_node_mut(node_idx){
p.set_cim(new_p.get_cim().as_ref().unwrap().clone()).unwrap();
}
else {
panic!("Node type mismatch")
}
}
}
}
pub fn trajectory_generator(
&self,
n_trajectories: u64,
t_end: f64,
seed: Option<u64>,
) -> pytools::PyDataset {
pytools::PyDataset(tools::trajectory_generator(
&self.0,
n_trajectories,
t_end,
seed,
))
}
}
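// A hypothetical end-to-end sketch from Python, assuming the module layout
// registered in lib.rs (the CIM rates are illustrative only):
//
//     import numpy as np
//     from reCTBNpy import network, params
//
//     node = params.PyDiscreteStateContinousTime("X", {"A", "B"})
//     node.set_cim(np.array([[[-1.0, 1.0], [4.0, -4.0]]]))
//     net = network.PyCtbnNetwork()
//     net.add_node(params.PyParams.new_discrete_state_continous_time(node))
//     dataset = net.trajectory_generator(n_trajectories=10, t_end=20.0, seed=42)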

@ -0,0 +1,102 @@
use numpy::{self, ToPyArray};
use pyo3::{exceptions::PyValueError, prelude::*};
use reCTBN::params::{self, ParamsTrait};
use std::collections::BTreeSet;
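// Newtype around reCTBN's `ParamsError`: both `ParamsError` and `PyErr` are
// foreign to this crate, so the orphan rule forbids a direct
// `From<params::ParamsError> for PyErr`. The wrapper makes the two
// conversions below legal, letting `?` surface Rust errors as Python
// `ValueError`s.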
pub struct PyParamsError(params::ParamsError);
impl From<PyParamsError> for PyErr {
fn from(error: PyParamsError) -> Self {
PyValueError::new_err(error.0.to_string())
}
}
impl From<params::ParamsError> for PyParamsError {
fn from(other: params::ParamsError) -> Self {
Self(other)
}
}
#[pyclass]
pub struct PyStateType(pub params::StateType);
#[pyclass]
#[derive(Clone)]
pub struct PyParams(pub params::Params);
#[pymethods]
impl PyParams {
#[staticmethod]
pub fn new_discrete_state_continous_time(p: PyDiscreteStateContinousTime) -> Self{
PyParams(params::Params::DiscreteStatesContinousTime(p.0))
}
pub fn get_reserved_space_as_parent(&self) -> usize {
self.0.get_reserved_space_as_parent()
}
pub fn get_label(&self) -> String {
self.0.get_label().to_string()
}
}
/// DiscreteStatesContinousTime.
/// This represents the parameters of a classical discrete node for CTBNs and is composed of the
/// following elements:
/// - **domain**: an ordered and exhaustive set of possible states
/// - **cim**: Conditional Intensity Matrix
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
///   learning task and are composed of:
///   - **transitions**: number of transitions from one state to another given a specific
///     realization of the parent set
///   - **residence_time**: residence time in each possible state given a specific
///     realization of the parent set
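///
/// A minimal, hypothetical usage sketch from Python (the CIM shape
/// `(parent configurations, |domain|, |domain|)` is inferred from the
/// `PyArray3` accessors below; the rates are illustrative only):
///
/// ```python
/// import numpy as np
/// node = PyDiscreteStateContinousTime("X", {"A", "B"})
/// # One parent configuration; each row of an intensity matrix sums to zero.
/// node.set_cim(np.array([[[-3.0, 3.0], [2.0, -2.0]]]))
/// ```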
#[derive(Clone)]
#[pyclass]
pub struct PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams);
#[pymethods]
impl PyDiscreteStateContinousTime {
#[new]
pub fn new(label: String, domain: BTreeSet<String>) -> Self {
PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams::new(label, domain))
}
pub fn get_cim<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<f64>> {
match self.0.get_cim() {
Some(x) => Some(x.to_pyarray(py)),
None => None,
}
}
pub fn set_cim<'py>(&mut self, _py: Python<'py>, cim: numpy::PyReadonlyArray3<f64>) -> Result<(), PyParamsError> {
self.0.set_cim(cim.as_array().to_owned())?;
Ok(())
}
pub fn get_transitions<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<usize>> {
match self.0.get_transitions() {
Some(x) => Some(x.to_pyarray(py)),
None => None,
}
}
pub fn set_transitions<'py>(&mut self, _py: Python<'py>, transitions: numpy::PyReadonlyArray3<usize>) {
self.0.set_transitions(transitions.as_array().to_owned());
}
pub fn get_residence_time<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray2<f64>> {
match self.0.get_residence_time() {
Some(x) => Some(x.to_pyarray(py)),
None => None,
}
}
pub fn set_residence_time<'py>(&mut self, _py: Python<'py>, residence_time: numpy::PyReadonlyArray2<f64>) {
self.0.set_residence_time(residence_time.as_array().to_owned());
}
}

@ -0,0 +1,50 @@
use numpy::{self, ToPyArray};
use pyo3::{exceptions::PyValueError, prelude::*};
use reCTBN::{tools, network};
#[pyclass]
#[derive(Clone)]
pub struct PyTrajectory(pub tools::Trajectory);
#[pymethods]
impl PyTrajectory {
#[new]
pub fn new(
time: numpy::PyReadonlyArray1<f64>,
events: numpy::PyReadonlyArray2<usize>,
) -> PyTrajectory {
PyTrajectory(tools::Trajectory::new(
time.as_array().to_owned(),
events.as_array().to_owned(),
))
}
pub fn get_time<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray1<f64> {
self.0.get_time().to_pyarray(py)
}
pub fn get_events<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray2<usize> {
self.0.get_events().to_pyarray(py)
}
}
#[pyclass]
pub struct PyDataset(pub tools::Dataset);
#[pymethods]
impl PyDataset {
#[new]
pub fn new(trajectories: Vec<PyTrajectory>) -> PyDataset {
PyDataset(tools::Dataset::new(trajectories.into_iter().map(|x| x.0).collect()))
}
pub fn get_number_of_trajectories(&self) -> usize {
self.0.get_trajectories().len()
}
pub fn get_trajectory(&self, idx: usize) -> PyTrajectory {
PyTrajectory(self.0.get_trajectories().get(idx).unwrap().clone())
}
}
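// A hypothetical round-trip from Python (array shapes follow the constructor
// above: `time` is 1-D float64, `events` is 2-D unsigned integer):
//
//     import numpy as np
//     t = PyTrajectory(np.array([0.0, 0.5, 1.2]),
//                      np.array([[0, 0], [1, 0], [1, 1]], dtype=np.uintp))
//     d = PyDataset([t])
//     assert d.get_number_of_trajectories() == 1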