Compare commits

...

91 Commits

Author SHA1 Message Date
Meliurwen f6873c52c2
Merge branch 'meta-license' into 'dev' 2 years ago
Meliurwen d45c96a38d
Added License section on the README 2 years ago
Meliurwen 57ae851470
Merge branch 'dev' into 'meta-license' 2 years ago
Meliurwen 4a9445385c
Double licensing this project with APACHE and MIT 2 years ago
Meliurwen 0eb427e5cf
Merge branch '90-feature-method-to-get-the-adjacency-matrix' into 'dev' 2 years ago
Meliurwen 7a3ac6c9ab
Using `as_ref()` instead of `clone()` in `get_adj_matrix()` 2 years ago
Meliurwen 4fd0ee0407
Function `get_adj_matrix()` is now specific to `CtbnNetwork` only 2 years ago
Meliurwen c4da4ceadd
Added `get_adj_matrix()` 2 years ago
AlessandroBregoli e638a627bb
Merge pull request #87 from AlessandroBregoli/76-feature-reward-evaluation 2 years ago
AlessandroBregoli 776b9aa030 clippy error solved 2 years ago
AlessandroBregoli adb0f99419 Added automatic stopping for MonteCarloReward 2 years ago
Meliurwen b7fc23ed8a
Merge branch '84-feature-synthetic-network-generator' into 'dev' 2 years ago
Meliurwen 0639a755d0
Refactored `generate_parameters` moving some code inside `match` statement 2 years ago
Meliurwen 4884010ea9
Added doctests for `UniformParametersGenerator` and `UniformGraphGenerator` 2 years ago
Meliurwen 430033afdb
Added tests for the learning of parameters using uniform graph and parameters generators as complementary to their handcrafted version 2 years ago
Meliurwen e08d12ac1f
Added tests for structure learning algorithms using uniform graph and parameters generators as complementary to their handcrafted version 2 years ago
Meliurwen a01a9ef201
Recomputing the diagonal when generating parameters to counter precision loss, and relaxing the tolerance to the square root of `f64::EPSILON` instead of multiplying it by the node's `domain_size` 2 years ago
Meliurwen 097dc25030
Added tests for `UniformParametersGenerator` and `UniformGraphGenerator` against `CTMP`, plus some small refactoring to the other tests 2 years ago
Meliurwen 0f61cbee4c
Refactored CIM validation for `UniformParametersGenerator` test 2 years ago
Meliurwen f4e3c98c79
Implemented `UniformParametersGenerator` and its test 2 years ago
Meliurwen d6f0fb9623
WIP: implementing `UniformParametersGenerator` 2 years ago
Meliurwen 434e671f0a
Renamed `UniformRandomGenerator` to `UniformGraphGenerator`, replaced `CtbnNetwork` requirement with `NetworkProcess` in `RandomGraphGenerator`, some related tweaks 2 years ago
Meliurwen 4b994d8a19
Renamed `StructureGen` with `UniformRandomGenerator` and defining the new trait `RandomGraphGenerator` 2 years ago
Meliurwen a077f738ee
Added `StructureGen` struct for generating the structure of a `CtbnNetwork` 2 years ago
Meliurwen 49c2c55f61
Merge branch '82-feature-create-doctest-for-ctpc' into 'dev' 2 years ago
Meliurwen c2df26c3e6
Added docstrings for the F-test and removed some comments 2 years ago
Meliurwen 7ec56914d9
Added doctest for CTPC 2 years ago
Alessandro Bregoli 5d676be180 parallelized re 2 years ago
Alessandro Bregoli ff235b4b77 Parallelize score based structure learning 2 years ago
Alessandro Bregoli a104d1fbf9 Merge branch 'dev' into 76-feature-reward-evaluation 2 years ago
Meliurwen 0bd325d349
Merge branch '8-feature-constraint-based-structure-learning-algorithm-for-ctbn' into 'dev' 2 years ago
Meliurwen b8938a934f
Merge branch 'move-cache' into '8-feature-constraint-based-structure-learning-algorithm-for-ctbn' 2 years ago
Meliurwen 5632833963
Moved `Cache` to `constraint_based_algorithm.rs` 2 years ago
Meliurwen 0e1cca0456
Merge branch 'ctpc-parallelization' into '8-feature-constraint-based-structure-learning-algorithm-for-ctbn' 2 years ago
Meliurwen 4d3f9518e4
CTPC parallelization at the nodes level with `rayon` 2 years ago
Meliurwen 867bf02934
Greatly improved memory consumption in cache, stale data no longer pile up 2 years ago
Meliurwen a0da3e2fe8
Fixed formatting issue 2 years ago
Meliurwen ea3e406bf1
Implemented basic cache 2 years ago
Meliurwen 19856195c3
Refactored the cache, laying the grounds for its node-centered implementation and changing its signature; propagated this change and refactored the CTPC tests 2 years ago
Meliurwen ea5df7cad6
Solved another issue with `candidate_parent_set` variable in CTPC 2 years ago
Meliurwen 468ebf09cc
WIP: Added `#[derive(Clone)]` to `Dataset` and `Trajectory` 2 years ago
Meliurwen 8d0f9db289
WIP: Added tests for CTPC 2 years ago
Meliurwen 6d42d8a805
Solved issue with `candidate_parent_set` variable in CTPC and added loop to fill the adjacency matrix 2 years ago
Meliurwen 6d952f8c07
Added `itertools`, a WIP version of CTPC, and some hacky and temporary modifications 2 years ago
AlessandroBregoli 9284ca5dd2 Implemented NeighborhoodRelativeReward 2 years ago
AlessandroBregoli 414aa31867 Bugfix 2 years ago
AlessandroBregoli 687f19ff1f Added FiniteHorizon 2 years ago
AlessandroBregoli bb239aaa0c Implemented reward_evaluation for an entire process. 2 years ago
AlessandroBregoli cecf16a771 Added single state evaluation 2 years ago
AlessandroBregoli 4fc5c1d4b5 Refactored reward module 2 years ago
Meliurwen 6e90458418
Laying grounds for CTPC 2 years ago
Meliurwen cac19b1756
Aligned F-test to the new changes 2 years ago
Meliurwen 2e49df0266
Merge branch 'dev' into '8-feature-constraint-based-structure-learning-algorithm-for-ctbn' 2 years ago
Meliurwen 3f80f07e9f
Merge branch '74-feature-reward-function' into 'dev' 2 years ago
AlessandroBregoli bcb64a161a Mini refactor. Introduced the type alias NetworkProcessState. 2 years ago
AlessandroBregoli 68ef7ea7c3 Added comments 2 years ago
AlessandroBregoli f6015acce9 Added tests 2 years ago
AlessandroBregoli 055eb7088e Implemented FactoredRewardFunction 2 years ago
AlessandroBregoli 1878f687d6 Refactor of sampling 2 years ago
AlessandroBregoli fd3b1ecfea
Merge pull request #73 from AlessandroBregoli/13-feature-add-ctmp 2 years ago
AlessandroBregoli 38e744e034 Fix fmt 2 years ago
AlessandroBregoli 44eaf8713f Fix for clippy 2 years ago
AlessandroBregoli 28ed1a40b3 Fix for clippy 2 years ago
AlessandroBregoli 4a7a8c5fba Added more tests 2 years ago
Meliurwen 3a0151a9f6
Added test for F-test call function 2 years ago
AlessandroBregoli 7c3cba50d4 Implemented amalgamation 2 years ago
Meliurwen a2c5800891
Slight optimization of `F::compare_matrices` 2 years ago
Meliurwen 9fbdf25149
Fixed `chi_square_call` test, the test was passing, but only for pure chance 2 years ago
AlessandroBregoli ed5471c7cf Added ctmp 2 years ago
Meliurwen c08f4e1985
Added F call function 2 years ago
Meliurwen ec72a6a2f9
Defined the `compare_matrices` function for the F-test 2 years ago
Meliurwen 713b8a8013
Merge branch '69-meta-set-a-proper-strategy-for-dependency-requirements' into 'dev' 2 years ago
Meliurwen a92b605daa
In `Cargo.toml` set the exact version for `thiserror` crate and the tilde requirement for all the others 2 years ago
Meliurwen ce139afdb6
Merge branch '66-feature-add-a-first-wave-of-docstrings-to-core-parts-of-the-library' into 'dev' 2 years ago
Meliurwen 0ae2168a94
Commented out some prints in chi2 compare matrices 2 years ago
Meliurwen 832922922a
Fixed error in chi2 compare matrices docstring 2 years ago
Meliurwen 3522e1b6f6
Small formatting fix and rewording of a docstring 2 years ago
Meliurwen 245b3b5d45
Replaced the static docstrings at the crate level in `lib.rs` with the external file `README.md` 2 years ago
Meliurwen 08623e28d4
Added various docstrings notes to all rust modules 2 years ago
Meliurwen f7165d0345
Added docstrings to `networks.rs` 2 years ago
Meliurwen 9ca8973550
Harmonized some docstrings in `params.rs` 2 years ago
Meliurwen 064b582833
Added docstrings for the chi-squared test 2 years ago
Meliurwen 2153f46758
Fixed small lint problem and added a paragraph to the README 2 years ago
Meliurwen ccced92149
Fixed some docstrings in `params.rs` 2 years ago
Meliurwen 616d5ec3d5
Fixed and prettified some imprecisions to old docstrings and added new ones in `ctbn.rs` 2 years ago
Meliurwen 672de56c31
Added a brief description of the project in form of docstrings at the crate level 2 years ago
Meliurwen 1cad41f7f4
Merge pull request #67 from AlessandroBregoli/65-meta-add-documentation-build-in-github-workflows 2 years ago
Meliurwen 174e85734e
Added docs generation in GitHub Workflows 2 years ago
Meliurwen c247da6bc0
Merge branch '8-feature-constraint-based-structure-learning-algorithm-for-ctbn' into 'dev' 2 years ago
Meliurwen 761e9da436
Merge branch 'dev' into '8-feature-constraint-based-structure-learning-algorithm-for-ctbn' 2 years ago
Meliurwen 8953471570 Added tests to chi square call function and added a constructor to cache 2 years ago
1. .github/workflows/build.yml (7)
2. LICENSE-APACHE (176)
3. LICENSE-MIT (23)
4. README.md (16)
5. reCTBN/Cargo.toml (18)
6. reCTBN/src/ctbn.rs (162)
7. reCTBN/src/lib.rs (5)
8. reCTBN/src/network.rs (43)
9. reCTBN/src/parameter_learning.rs (43)
10. reCTBN/src/params.rs (59)
11. reCTBN/src/process.rs (120)
12. reCTBN/src/process/ctbn.rs (247)
13. reCTBN/src/process/ctmp.rs (114)
14. reCTBN/src/reward.rs (59)
15. reCTBN/src/reward/reward_evaluation.rs (205)
16. reCTBN/src/reward/reward_function.rs (106)
17. reCTBN/src/sampling.rs (58)
18. reCTBN/src/structure_learning.rs (8)
19. reCTBN/src/structure_learning/constraint_based_algorithm.rs (351)
20. reCTBN/src/structure_learning/hypothesis_test.rs (212)
21. reCTBN/src/structure_learning/score_based_algorithm.rs (24)
22. reCTBN/src/structure_learning/score_function.rs (14)
23. reCTBN/src/tools.rs (273)
24. reCTBN/tests/ctbn.rs (249)
25. reCTBN/tests/ctmp.rs (127)
26. reCTBN/tests/parameter_learning.rs (207)
27. reCTBN/tests/reward_evaluation.rs (122)
28. reCTBN/tests/reward_function.rs (117)
29. reCTBN/tests/structure_learning.rs (247)
30. reCTBN/tests/tools.rs (171)

@@ -22,7 +22,7 @@ jobs:
           profile: minimal
           toolchain: stable
           default: true
-          components: clippy, rustfmt
+          components: clippy, rustfmt, rust-docs
       - name: Setup Rust nightly
         uses: actions-rs/toolchain@v1
         with:
@@ -30,6 +30,11 @@ jobs:
           toolchain: nightly
           default: false
           components: rustfmt
+      - name: Docs (doc)
+        uses: actions-rs/cargo@v1
+        with:
+          command: rustdoc
+          args: --package reCTBN -- --default-theme=ayu
       - name: Linting (clippy)
         uses: actions-rs/clippy-check@v1
         with:

@@ -0,0 +1,176 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS

@@ -0,0 +1,23 @@
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

@@ -56,3 +56,19 @@ To check the **formatting**:
 ```sh
 cargo fmt --all -- --check
 ```
+
+## Documentation
+
+To generate the **documentation**:
+
+```sh
+cargo rustdoc --package reCTBN --open -- --default-theme=ayu
+```
+
+## License
+
+This software is distributed under the terms of both the Apache License
+(Version 2.0) and the MIT license.
+
+See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for
+details.

@@ -6,13 +6,15 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-ndarray = {version="*", features=["approx-0_5"]}
-thiserror = "*"
-rand = "*"
-bimap = "*"
-enum_dispatch = "*"
-statrs = "*"
-rand_chacha = "*"
+ndarray = {version="~0.15", features=["approx-0_5"]}
+thiserror = "1.0.37"
+rand = "~0.8"
+bimap = "~0.6"
+enum_dispatch = "~0.3"
+statrs = "~0.16"
+rand_chacha = "~0.3"
+itertools = "~0.10"
+rayon = "~1.6"
 
 [dev-dependencies]
-approx = { package = "approx", version = "0.5" }
+approx = { package = "approx", version = "~0.5" }

@@ -1,162 +0,0 @@
use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::network;
use crate::params::{Params, ParamsTrait, StateType};
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
///composed by the following elements:
///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
///- **nodes**: a vector containing all the nodes and their parameters.
///The index of a node inside the vector is also used as index for the adj_matrix.
///
///# Examples
///
///```
///
/// use std::collections::BTreeSet;
/// use reCTBN::network::Network;
/// use reCTBN::params;
/// use reCTBN::ctbn::*;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Get all the children of node X1
/// let cs = net.get_children_set(X1);
/// assert_eq!(&X2, cs.iter().next().unwrap());
/// ```
pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>,
nodes: Vec<Params>,
}
impl CtbnNetwork {
pub fn new() -> CtbnNetwork {
CtbnNetwork {
adj_matrix: None,
nodes: Vec::new(),
}
}
}
impl network::Network for CtbnNetwork {
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
}
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() - 1)
}
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
}
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].reset_params();
}
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
0..self.nodes.len()
}
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
fn get_node(&self, node_idx: usize) -> &Params {
&self.nodes[node_idx]
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
&mut self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
})
.0
}
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<StateType>,
parent_set: &BTreeSet<usize>,
) -> usize {
parent_set
.iter()
.fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc
})
.0
}
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.row(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
}

@@ -1,11 +1,12 @@
+#![doc = include_str!("../../README.md")]
 #![allow(non_snake_case)]
 
 #[cfg(test)]
 extern crate approx;
 
-pub mod ctbn;
-pub mod network;
 pub mod parameter_learning;
 pub mod params;
+pub mod process;
+pub mod reward;
 pub mod sampling;
 pub mod structure_learning;
 pub mod tools;

@@ -1,43 +0,0 @@
use std::collections::BTreeSet;
use thiserror::Error;
use crate::params;
/// Error types for trait Network
#[derive(Error, Debug)]
pub enum NetworkError {
#[error("Error during node insertion")]
NodeInsertionError(String),
}
///Network
///The Network trait define the required methods for a structure used as pgm (such as ctbn).
pub trait Network {
fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
fn add_edge(&mut self, parent: usize, child: usize);
///Get all the indices of the nodes contained inside the network
fn get_node_indices(&self) -> std::ops::Range<usize>;
fn get_number_of_nodes(&self) -> usize;
fn get_node(&self, node_idx: usize) -> &params::Params;
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are
///the ones in the parent set of the *node*.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>)
-> usize;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network and a generic parent_set. Usually, the only values really used
///in *current_state* are the ones in the parent set of the *node*.
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<params::StateType>,
parent_set: &BTreeSet<usize>,
) -> usize;
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@@ -1,23 +1,25 @@
+//! Module containing methods used to learn the parameters.
+
 use std::collections::BTreeSet;
 
 use ndarray::prelude::*;
 
 use crate::params::*;
-use crate::{network, tools};
+use crate::{process, tools::Dataset};
 
-pub trait ParameterLearning {
-    fn fit<T: network::Network>(
+pub trait ParameterLearning: Sync {
+    fn fit<T: process::NetworkProcess>(
         &self,
         net: &T,
-        dataset: &tools::Dataset,
+        dataset: &Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params;
 }
 
-pub fn sufficient_statistics<T: network::Network>(
+pub fn sufficient_statistics<T: process::NetworkProcess>(
     net: &T,
-    dataset: &tools::Dataset,
+    dataset: &Dataset,
     node: usize,
     parent_set: &BTreeSet<usize>,
 ) -> (Array3<usize>, Array2<f64>) {
@@ -71,10 +73,10 @@ pub fn sufficient_statistics<T: network::Network>(
 pub struct MLE {}
 
 impl ParameterLearning for MLE {
-    fn fit<T: network::Network>(
+    fn fit<T: process::NetworkProcess>(
         &self,
         net: &T,
-        dataset: &tools::Dataset,
+        dataset: &Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
@@ -118,10 +120,10 @@ pub struct BayesianApproach {
 }
 
 impl ParameterLearning for BayesianApproach {
-    fn fit<T: network::Network>(
+    fn fit<T: process::NetworkProcess>(
         &self,
         net: &T,
-        dataset: &tools::Dataset,
+        dataset: &Dataset,
         node: usize,
         parent_set: Option<BTreeSet<usize>>,
     ) -> Params {
@@ -142,6 +144,10 @@ impl ParameterLearning for BayesianApproach {
         .zip(M.mapv(|x| x as f64).axis_iter(Axis(2)))
         .for_each(|(mut C, m)| C.assign(&(&m.mapv(|y| y + alpha) / &T.mapv(|y| y + tau))));
 
+    CIM.outer_iter_mut().for_each(|mut C| {
+        C.diag_mut().fill(0.0);
+    });
+
     //Set the diagonal of the inner matrices to the row sum multiplied by -1
     let tmp_diag_sum: Array2<f64> = CIM.sum_axis(Axis(2)).mapv(|x| x * -1.0);
     CIM.outer_iter_mut()
@@ -162,20 +168,3 @@ impl ParameterLearning for BayesianApproach {
         return n;
     }
 }
-
-pub struct Cache<P: ParameterLearning> {
-    parameter_learning: P,
-    dataset: tools::Dataset,
-}
-
-impl<P: ParameterLearning> Cache<P> {
-    pub fn fit<T: network::Network>(
-        &mut self,
-        net: &T,
-        node: usize,
-        parent_set: Option<BTreeSet<usize>>,
-    ) -> Params {
-        self.parameter_learning
-            .fit(net, &self.dataset, node, parent_set)
-    }
-}
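The four lines added in the `@@ -142,6 +144,10 @@` hunk zero each CIM diagonal before it is recomputed from the row sums. A minimal standalone sketch of that two-step normalization, assuming only the `ndarray` crate (illustrative values, not code from this diff):

```rust
use ndarray::prelude::*;

fn main() {
    // One conditional intensity matrix with stale diagonal entries.
    let mut cim: Array3<f64> = array![[[0.5, 0.3, 0.2], [0.1, 0.6, 0.3], [0.2, 0.2, 0.4]]];
    // Step 1: zero the diagonal so old values cannot leak into the row sums.
    cim.outer_iter_mut().for_each(|mut c| c.diag_mut().fill(0.0));
    // Step 2: set each diagonal entry to minus the sum of its row,
    // so every row sums to (numerically) zero.
    let neg_row_sums: Array2<f64> = cim.sum_axis(Axis(2)).mapv(|x| -x);
    cim.outer_iter_mut()
        .zip(neg_row_sums.outer_iter())
        .for_each(|(mut c, r)| c.diag_mut().assign(&r));
    println!("{:?}", cim); // first row becomes [-0.5, 0.3, 0.2]
}
```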

@@ -1,3 +1,5 @@
+//! Module containing methods to define different types of nodes.
+
 use std::collections::BTreeSet;
 
 use enum_dispatch::enum_dispatch;
@@ -18,14 +20,13 @@ pub enum ParamsError {
 }
 
 /// Allowed type of states
-#[derive(Clone)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum StateType {
     Discrete(usize),
 }
 
-/// Parameters
-/// The Params trait is the core element for building different types of nodes. The goal is to
-/// define the set of method required to describes a generic node.
+/// This is a core element for building different types of nodes; the goal is to define the set of
+/// methods required to describe a generic node.
 #[enum_dispatch(Params)]
 pub trait ParamsTrait {
     fn reset_params(&mut self);
@@ -65,25 +66,27 @@ pub trait ParamsTrait {
     fn get_label(&self) -> &String;
 }
 
-/// The Params enum is the core element for building different types of nodes. The goal is to
-/// define all the supported type of Parameters
+/// Is a core element for building different types of nodes; the goal is to define all the
+/// supported types of Parameters
 #[derive(Clone)]
 #[enum_dispatch]
 pub enum Params {
     DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
 }
 
-/// DiscreteStatesContinousTime.
 /// This represents the parameters of a classical discrete node for ctbn and it's composed by the
-/// following elements:
-/// - **domain**: an ordered and exhaustive set of possible states
-/// - **cim**: Conditional Intensity Matrix
-/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
-/// learning task and are composed by:
-/// - **transitions**: number of transitions from one state to another given a specific
-/// realization of the parent set
-/// - **residence_time**: permanence time in each possible states given a specific
-/// realization of the parent set
+/// following elements.
+///
+/// # Arguments
+///
+/// * `label` - node's variable name.
+/// * `domain` - an ordered and exhaustive set of possible states.
+/// * `cim` - Conditional Intensity Matrix.
+/// * `transitions` - number of transitions from one state to another given a specific realization
+///   of the parent set; a sufficient statistic used mainly during the parameter learning task.
+/// * `residence_time` - residence time in each possible state, given a specific realization of the
+///   parent set; a sufficient statistic used mainly during the parameter learning task.
 #[derive(Clone)]
 pub struct DiscreteStatesContinousTimeParams {
     label: String,
@@ -104,15 +107,17 @@ impl DiscreteStatesContinousTimeParams {
         }
     }
 
-    ///Getter function for CIM
+    /// Getter function for CIM
     pub fn get_cim(&self) -> &Option<Array3<f64>> {
         &self.cim
     }
 
-    ///Setter function for CIM.\\
-    ///This function check if the cim is valid using the validate_params method.
-    ///- **Valid cim inserted**: it substitute the CIM in self.cim and return Ok(())
-    ///- **Invalid cim inserted**: it replace the self.cim value with None and it retu ParamsError
+    /// Setter function for CIM.
+    ///
+    /// This function checks if the CIM is valid using the
+    /// [`validate_params`](self::ParamsTrait::validate_params) method:
+    /// * **Valid CIM inserted** - it substitutes the CIM in `self.cim` and returns `Ok(())`.
+    /// * **Invalid CIM inserted** - it replaces the `self.cim` value with `None` and it returns
+    ///   `ParamsError`.
     pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
         self.cim = Some(cim);
         match self.validate_params() {
@@ -124,27 +129,27 @@ impl DiscreteStatesContinousTimeParams {
         }
     }
 
-    ///Unchecked version of the setter function for CIM.
+    /// Unchecked version of the setter function for CIM.
     pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
         self.cim = Some(cim);
     }
 
-    ///Getter function for transitions
+    /// Getter function for transitions.
     pub fn get_transitions(&self) -> &Option<Array3<usize>> {
         &self.transitions
     }
 
-    ///Setter function for transitions
+    /// Setter function for transitions.
     pub fn set_transitions(&mut self, transitions: Array3<usize>) {
         self.transitions = Some(transitions);
     }
 
-    ///Getter function for residence_time
+    /// Getter function for residence_time.
    pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
        &self.residence_time
    }
 
-    ///Setter function for residence_time
+    /// Setter function for residence_time.
    pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
        self.residence_time = Some(residence_time);
    }
@@ -266,7 +271,7 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
         if cim
             .sum_axis(Axis(2))
             .iter()
-            .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
+            .any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
         {
             return Err(ParamsError::InvalidCIM(String::from(
                 "The sum of each row must be 0",

@@ -0,0 +1,120 @@
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
pub mod ctbn;
pub mod ctmp;
use std::collections::BTreeSet;
use thiserror::Error;
use crate::params;
/// Error types for trait Network
#[derive(Error, Debug)]
pub enum NetworkError {
#[error("Error during node insertion")]
NodeInsertionError(String),
}
/// This type is used to represent a specific realization of a generic NetworkProcess
pub type NetworkProcessState = Vec<params::StateType>;
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
/// as a CTBN).
pub trait NetworkProcess: Sync {
fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
/// Add a **directed edge** between two nodes of the network.
///
/// # Arguments
///
/// * `parent` - parent node.
/// * `child` - child node.
fn add_edge(&mut self, parent: usize, child: usize);
/// Get all the indices of the nodes contained inside the network.
fn get_node_indices(&self) -> std::ops::Range<usize>;
/// Get the number of nodes contained in the network.
fn get_number_of_nodes(&self) -> usize;
/// Get the **node param**.
///
/// # Arguments
///
/// * `node_idx` - node index value.
///
/// # Return
///
/// * The selected **node param**.
fn get_node(&self, node_idx: usize) -> &params::Params;
/// Get the **node param**.
///
/// # Arguments
///
/// * `node_idx` - node index value.
///
/// # Return
///
/// * The selected **node mutable param**.
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
/// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network.
///
/// Usually, the only values really used in `current_state` are the ones in the parent set of
/// the `node`.
///
/// # Arguments
///
/// * `node` - selected node.
/// * `current_state` - current configuration of the network.
///
/// # Return
///
/// * Index of the `node` relative to the network.
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize;
/// Compute the index that must be used to access the parameters of a `node`, given a specific
/// configuration of the network and a generic `parent_set`.
///
/// Usually, the only values really used in `current_state` are the ones in the parent set of
/// the `node`.
///
/// # Arguments
///
/// * `current_state` - current configuration of the network.
/// * `parent_set` - parent set of the selected `node`.
///
/// # Return
///
/// * Index of the `node` relative to the network.
fn get_param_index_from_custom_parent_set(
&self,
current_state: &Vec<params::StateType>,
parent_set: &BTreeSet<usize>,
) -> usize;
/// Get the **parent set** of a given **node**.
///
/// # Arguments
///
/// * `node` - node index value.
///
/// # Return
///
/// * The **parent set** of the selected node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
/// Get the **children set** of a given **node**.
///
/// # Arguments
///
/// * `node` - node index value.
///
/// # Return
///
/// * The **children set** of the selected node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@@ -0,0 +1,247 @@
//! Continuous Time Bayesian Network
use std::collections::BTreeSet;
use ndarray::prelude::*;
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
use crate::process;
use super::ctmp::CtmpProcess;
use super::{NetworkProcess, NetworkProcessState};
/// It represents both the structure and the parameters of a CTBN.
///
/// # Arguments
///
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
/// * `nodes` - A vector containing all the nodes and their parameters.
///
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
///
/// # Example
///
/// ```rust
/// use std::collections::BTreeSet;
/// use reCTBN::process::NetworkProcess;
/// use reCTBN::params;
/// use reCTBN::process::ctbn::*;
///
/// //Create the domain for a discrete node
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
/// let X2 = net.add_node(X2).unwrap();
///
/// //Add an edge
/// net.add_edge(X1, X2);
///
/// //Get all the children of node X1
/// let cs = net.get_children_set(X1);
/// assert_eq!(&X2, cs.iter().next().unwrap());
/// ```
pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>,
nodes: Vec<Params>,
}
impl CtbnNetwork {
pub fn new() -> CtbnNetwork {
CtbnNetwork {
adj_matrix: None,
nodes: Vec::new(),
}
}
/// Transform the **CTBN** into a **CTMP**.
///
/// # Return
///
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
pub fn amalgamation(&self) -> CtmpProcess {
let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
let state_space = variables_domain.product();
let variables_set = BTreeSet::from_iter(self.get_node_indices());
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space));
for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: NetworkProcessState = current_state
.iter()
.map(|x| StateType::Discrete(*x))
.collect();
for idx_node in 0..self.nodes.len() {
let p = match self.get_node(idx_node) {
Params::DiscreteStatesContinousTime(p) => p,
};
for next_node_state in 0..variables_domain[idx_node] {
let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state;
let next_state_statetype: NetworkProcessState =
next_state.iter().map(|x| StateType::Discrete(*x)).collect();
let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype,
&variables_set,
);
amalgamated_cim[[0, idx_current_state, idx_next_state]] +=
p.get_cim().as_ref().unwrap()[[
self.get_param_index_network(idx_node, &current_state_statetype),
current_state[idx_node],
next_node_state,
]];
}
}
}
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new(
"ctmp".to_string(),
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
);
amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new();
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param))
.unwrap();
return ctmp;
}
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> {
let mut state = state;
let mut array_state = Array1::zeros(variables_domain.shape()[0]);
for (idx, var) in variables_domain.indexed_iter() {
array_state[idx] = state % var;
state = state / var;
}
return array_state;
}
/// Get the Adjacency Matrix.
pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> {
self.adj_matrix.as_ref()
}
}
impl process::NetworkProcess for CtbnNetwork {
/// Initialize an Adjacency matrix.
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
}
/// Add a new node.
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() - 1)
}
/// Connect two nodes with a new edge.
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
}
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].reset_params();
}
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
0..self.nodes.len()
}
/// Get the number of nodes of the network.
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
fn get_node(&self, node_idx: usize) -> &Params {
&self.nodes[node_idx]
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
&mut self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
})
.0
}
fn get_param_index_from_custom_parent_set(
&self,
current_state: &NetworkProcessState,
parent_set: &BTreeSet<usize>,
) -> usize {
parent_set
.iter()
.fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc
})
.0
}
/// Get all the parents of the given node.
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
/// Get all the children of the given node.
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.row(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
}
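Both `idx_to_state` and the `fold`s inside `get_param_index_network` and `get_param_index_from_custom_parent_set` rely on the same mixed-radix encoding of joint states. A standalone round-trip sketch of that encoding, with hypothetical helper names (not part of the diff):

```rust
// Decode a flat index into one state per variable, least significant first.
fn idx_to_state(domains: &[usize], mut idx: usize) -> Vec<usize> {
    domains
        .iter()
        .map(|d| {
            let s = idx % d; // this variable's state
            idx /= d;        // shift to the next "digit"
            s
        })
        .collect()
}

// Encode a joint state back into a flat index; the running `stride` plays
// the role of `get_reserved_space_as_parent` in the fold above.
fn state_to_idx(domains: &[usize], state: &[usize]) -> usize {
    domains
        .iter()
        .zip(state)
        .fold((0, 1), |(idx, stride), (d, s)| (idx + s * stride, stride * d))
        .0
}

fn main() {
    let domains = [2, 3]; // e.g. X1 with 2 states, X2 with 3 states
    for idx in 0..6 {
        let state = idx_to_state(&domains, idx);
        assert_eq!(state_to_idx(&domains, &state), idx);
        println!("{idx} <-> {state:?}");
    }
}
```

Enumerating `0..state_space` with this decoding is exactly how `amalgamation` walks the joint state space.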

@@ -0,0 +1,114 @@
use std::collections::BTreeSet;
use crate::{
params::{Params, StateType},
process,
};
use super::{NetworkProcess, NetworkProcessState};
pub struct CtmpProcess {
param: Option<Params>,
}
impl CtmpProcess {
pub fn new() -> CtmpProcess {
CtmpProcess { param: None }
}
}
impl NetworkProcess for CtmpProcess {
fn initialize_adj_matrix(&mut self) {
unimplemented!("CtmpProcess has only one node")
}
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> {
match self.param {
None => {
self.param = Some(n);
Ok(0)
}
Some(_) => Err(process::NetworkError::NodeInsertionError(
"CtmpProcess has only one node".to_string(),
)),
}
}
fn add_edge(&mut self, _parent: usize, _child: usize) {
unimplemented!("CtmpProcess has only one node")
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
match self.param {
None => 0..0,
Some(_) => 0..1,
}
}
fn get_number_of_nodes(&self) -> usize {
match self.param {
None => 0,
Some(_) => 1,
}
}
fn get_node(&self, node_idx: usize) -> &crate::params::Params {
if node_idx == 0 {
self.param.as_ref().unwrap()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params {
if node_idx == 0 {
self.param.as_mut().unwrap()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
if node == 0 {
match current_state[0] {
StateType::Discrete(x) => x,
}
} else {
unimplemented!("CtmpProcess has only one node")
}
}
fn get_param_index_from_custom_parent_set(
&self,
_current_state: &NetworkProcessState,
_parent_set: &BTreeSet<usize>,
) -> usize {
unimplemented!("CtmpProcess has only one node")
}
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
match self.param {
Some(_) => {
if node == 0 {
BTreeSet::new()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
None => panic!("Uninitialized CtmpProcess"),
}
}
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> {
match self.param {
Some(_) => {
if node == 0 {
BTreeSet::new()
} else {
unimplemented!("CtmpProcess has only one node")
}
}
None => panic!("Uninitialized CtmpProcess"),
}
}
}
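A `CtmpProcess` holds the entire joint system in its single node and, in this codebase, is typically produced via `CtbnNetwork::amalgamation` rather than built by hand. A hedged usage sketch under that assumption (the single-node net and its CIM values are illustrative):

```rust
use std::collections::BTreeSet;
use ndarray::array;
use reCTBN::params;
use reCTBN::process::{ctbn::CtbnNetwork, NetworkProcess};

fn main() {
    // One binary node with a valid CIM (each row sums to zero).
    let mut domain = BTreeSet::new();
    domain.insert(String::from("A"));
    domain.insert(String::from("B"));
    let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
    let mut net = CtbnNetwork::new();
    let x1 = net
        .add_node(params::Params::DiscreteStatesContinousTime(param))
        .unwrap();
    match net.get_node_mut(x1) {
        params::Params::DiscreteStatesContinousTime(p) => {
            p.set_cim(array![[[-3.0, 3.0], [2.0, -2.0]]]).unwrap();
        }
    }
    // Amalgamation collapses the CTBN into a single-node CTMP whose domain
    // is the joint state space.
    let ctmp = net.amalgamation();
    assert_eq!(ctmp.get_number_of_nodes(), 1);
}
```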

@@ -0,0 +1,59 @@
pub mod reward_evaluation;
pub mod reward_function;
use std::collections::HashMap;
use crate::process;
/// Instantiation of reward function and instantaneous reward
///
///
/// # Arguments
///
/// * `transition_reward`: reward obtained transitioning from one state to another
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
#[derive(Debug, PartialEq)]
pub struct Reward {
pub transition_reward: f64,
pub instantaneous_reward: f64,
}
/// The trait RewardFunction describes the methods that all the reward functions must satisfy
pub trait RewardFunction: Sync {
/// Given the current state and the previous state, it computes the reward.
///
/// # Arguments
///
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
/// * `previous_state`: an optional argument representing the previous state of the network
fn call(
&self,
current_state: &process::NetworkProcessState,
previous_state: Option<&process::NetworkProcessState>,
) -> Reward;
/// Initialize the RewardFunction internals according to the structure of a NetworkProcess
///
/// # Arguments
///
/// * `p`: any structure that implements the trait `process::NetworkProcess`
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self;
}
pub trait RewardEvaluation {
fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
) -> HashMap<process::NetworkProcessState, f64>;
fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
state: &process::NetworkProcessState,
) -> f64;
}
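The split between `transition_reward` and `instantaneous_reward` fixes how a trajectory return is accumulated (see `reward_evaluation.rs` below): the instantaneous part is weighted by the time spent in a state, while the transition part is collected at each jump. A numeric sketch for a single trajectory segment under an infinite-horizon criterion with discount factor `gamma` (illustrative values, not from the diff):

```rust
fn main() {
    let (gamma, t0, t1) = (0.9_f64, 0.5_f64, 1.2_f64); // discount, segment bounds
    let (inst, trans) = (2.0_f64, 1.0_f64); // instantaneous and transition reward
    // Reward per unit time, discounted over the residence interval [t0, t1):
    let discounted_inst = ((-gamma * t0).exp() - (-gamma * t1).exp()) * inst;
    // Transition reward discounted at the jump time t1:
    let discounted_trans = (-gamma * t1).exp() * trans;
    println!("segment return = {}", discounted_inst + discounted_trans);
}
```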

@@ -0,0 +1,205 @@
use std::collections::HashMap;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use statrs::distribution::ContinuousCDF;
use crate::params::{self, ParamsTrait};
use crate::process;
use crate::{
process::NetworkProcessState,
reward::RewardEvaluation,
sampling::{ForwardSampler, Sampler},
};
pub enum RewardCriteria {
FiniteHorizon,
InfiniteHorizon { discount_factor: f64 },
}
pub struct MonteCarloReward {
max_iterations: usize,
max_err_stop: f64,
alpha_stop: f64,
end_time: f64,
reward_criteria: RewardCriteria,
seed: Option<u64>,
}
impl MonteCarloReward {
pub fn new(
max_iterations: usize,
max_err_stop: f64,
alpha_stop: f64,
end_time: f64,
reward_criteria: RewardCriteria,
seed: Option<u64>,
) -> MonteCarloReward {
MonteCarloReward {
max_iterations,
max_err_stop,
alpha_stop,
end_time,
reward_criteria,
seed,
}
}
}
impl RewardEvaluation for MonteCarloReward {
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
) -> HashMap<process::NetworkProcessState, f64> {
let variables_domain: Vec<Vec<params::StateType>> = network_process
.get_node_indices()
.map(|x| match network_process.get_node(x) {
params::Params::DiscreteStatesContinousTime(x) => (0..x
.get_reserved_space_as_parent())
.map(|s| params::StateType::Discrete(s))
.collect(),
})
.collect();
let n_states: usize = variables_domain.iter().map(|x| x.len()).product();
(0..n_states)
.into_par_iter()
.map(|s| {
let state: process::NetworkProcessState = variables_domain
.iter()
.fold((s, vec![]), |acc, x| {
let mut acc = acc;
let idx_s = acc.0 % x.len();
acc.1.push(x[idx_s].clone());
acc.0 = acc.0 / x.len();
acc
})
.1;
let r = self.evaluate_state(network_process, reward_function, &state);
(state, r)
})
.collect()
}
fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
state: &NetworkProcessState,
) -> f64 {
let mut sampler =
ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone()));
let mut expected_value = 0.0;
let mut squared_expected_value = 0.0;
let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap();
for i in 0..self.max_iterations {
sampler.reset();
let mut ret = 0.0;
let mut previous = sampler.next().unwrap();
while previous.t < self.end_time {
let current = sampler.next().unwrap();
if current.t > self.end_time {
let r = reward_function.call(&previous.state, None);
let discount = match self.reward_criteria {
RewardCriteria::FiniteHorizon => self.end_time - previous.t,
RewardCriteria::InfiniteHorizon { discount_factor } => {
std::f64::consts::E.powf(-discount_factor * previous.t)
- std::f64::consts::E.powf(-discount_factor * self.end_time)
}
};
ret += discount * r.instantaneous_reward;
} else {
let r = reward_function.call(&previous.state, Some(&current.state));
let discount = match self.reward_criteria {
RewardCriteria::FiniteHorizon => current.t - previous.t,
RewardCriteria::InfiniteHorizon { discount_factor } => {
std::f64::consts::E.powf(-discount_factor * previous.t)
- std::f64::consts::E.powf(-discount_factor * current.t)
}
};
ret += discount * r.instantaneous_reward;
ret += match self.reward_criteria {
RewardCriteria::FiniteHorizon => 1.0,
RewardCriteria::InfiniteHorizon { discount_factor } => {
std::f64::consts::E.powf(-discount_factor * current.t)
}
} * r.transition_reward;
}
previous = current;
}
let float_i = i as f64;
expected_value =
expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0);
squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0)
+ ret.powi(2) / (float_i + 1.0);
if i > 2 {
let var =
(float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2));
if self.alpha_stop
- 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt())
> 0.0
{
return expected_value;
}
}
}
expected_value
}
}
pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> {
inner_reward: RE,
}
impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> {
pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> {
NeighborhoodRelativeReward { inner_reward }
}
}
impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> {
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>(
&self,
network_process: &N,
reward_function: &R,
) -> HashMap<process::NetworkProcessState, f64> {
let absolute_reward = self
.inner_reward
.evaluate_state_space(network_process, reward_function);
//This approach optimizes memory; optimizing for execution time might be better.
absolute_reward
.iter()
.map(|(k1, v1)| {
let mut max_val: f64 = 1.0;
absolute_reward.iter().for_each(|(k2, v2)| {
let count_diff: usize = k1
.iter()
.zip(k2.iter())
.map(|(s1, s2)| if s1 == s2 { 0 } else { 1 })
.sum();
if count_diff < 2 {
max_val = max_val.max(v1 / v2);
}
});
(k1.clone(), max_val)
})
.collect()
}
fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>(
&self,
_network_process: &N,
_reward_function: &R,
_state: &process::NetworkProcessState,
) -> f64 {
unimplemented!();
}
}
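The heart of `MonteCarloReward::evaluate_state` is an incremental mean and second-moment update plus an early stop once a normal-approximation bound says the estimate is within `max_err_stop` of the truth with probability at least `1 - alpha_stop`. A dependency-free sketch of just that stopping logic; the `simulate` closure is a hypothetical stand-in for one sampled trajectory return, and a logistic curve replaces the exact normal CDF that `statrs` provides:

```rust
fn mc_estimate(
    max_iterations: usize,
    max_err_stop: f64,
    alpha_stop: f64,
    mut simulate: impl FnMut() -> f64,
) -> f64 {
    // Logistic approximation to the standard normal CDF (std has no erf).
    fn approx_std_normal_cdf(x: f64) -> f64 {
        1.0 / (1.0 + (-1.702 * x).exp())
    }
    let (mut mean, mut mean_sq) = (0.0_f64, 0.0_f64);
    for i in 0..max_iterations {
        let ret = simulate();
        let n = i as f64 + 1.0;
        mean += (ret - mean) / n;             // incremental mean
        mean_sq += (ret * ret - mean_sq) / n; // incremental second moment
        if i > 2 {
            let var = n / (n - 1.0) * (mean_sq - mean * mean); // sample variance
            // P(|error| > max_err_stop) ~= 2 * Phi(-sqrt(n) * eps / sigma)
            let p_exceed =
                2.0 * approx_std_normal_cdf(-n.sqrt() * max_err_stop / var.sqrt());
            if alpha_stop - p_exceed > 0.0 {
                return mean; // confident enough: stop early
            }
        }
    }
    mean
}

fn main() {
    // Demo with a deterministic pseudo-random return around 5.0.
    let mut x: u64 = 42;
    let estimate = mc_estimate(10_000, 0.01, 0.05, || {
        x = x.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        5.0 + (((x >> 33) as f64) / ((1u64 << 31) as f64) - 0.5) * 0.1
    });
    println!("estimate = {estimate}");
}
```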

@@ -0,0 +1,106 @@
//! Module for dealing with reward functions
use crate::{
params::{self, ParamsTrait},
process,
reward::{Reward, RewardFunction},
};
use ndarray;
/// Reward function over a factored state space
///
/// The `FactoredRewardFunction` assumes the reward function is the sum of the rewards of each
/// node of the underlying `NetworkProcess`
///
/// # Arguments
///
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
/// reward of a node
pub struct FactoredRewardFunction {
transition_reward: Vec<ndarray::Array2<f64>>,
instantaneous_reward: Vec<ndarray::Array1<f64>>,
}
impl FactoredRewardFunction {
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> {
&self.transition_reward[node_idx]
}
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> {
&mut self.transition_reward[node_idx]
}
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> {
&self.instantaneous_reward[node_idx]
}
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> {
&mut self.instantaneous_reward[node_idx]
}
}
impl RewardFunction for FactoredRewardFunction {
fn call(
&self,
current_state: &process::NetworkProcessState,
previous_state: Option<&process::NetworkProcessState>,
) -> Reward {
let instantaneous_reward: f64 = current_state
.iter()
.enumerate()
.map(|(idx, x)| {
let x = match x {
params::StateType::Discrete(x) => x,
};
self.instantaneous_reward[idx][*x]
})
.sum();
if let Some(previous_state) = previous_state {
let transition_reward = previous_state
.iter()
.zip(current_state.iter())
.enumerate()
.find_map(|(idx, (p, c))| -> Option<f64> {
let p = match p {
params::StateType::Discrete(p) => p,
};
let c = match c {
params::StateType::Discrete(c) => c,
};
if p != c {
Some(self.transition_reward[idx][[*p, *c]])
} else {
None
}
})
.unwrap_or(0.0);
Reward {
transition_reward,
instantaneous_reward,
}
} else {
Reward {
transition_reward: 0.0,
instantaneous_reward,
}
}
}
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self {
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![];
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![];
for i in p.get_node_indices() {
//This works only for discrete nodes!
let size: usize = p.get_node(i).get_reserved_space_as_parent();
instantaneous_reward.push(ndarray::Array1::zeros(size));
transition_reward.push(ndarray::Array2::zeros((size, size)));
}
FactoredRewardFunction {
transition_reward,
instantaneous_reward,
}
}
}
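A minimal sketch of how `FactoredRewardFunction` is used (it mirrors the tests later in this diff; assumes a `CtbnNetwork` `net` with a single binary node `n1`, and that `Reward`'s fields are public as they are constructed above):

let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1)
    .assign(&ndarray::arr2(&[[0.0, 1.0], [1.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1)
    .assign(&ndarray::arr1(&[3.0, 3.0]));

let s0: process::NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: process::NetworkProcessState = vec![params::StateType::Discrete(1)];

// Instantaneous reward of the current state plus the transition reward of
// the (single) variable that changed.
let r = rf.call(&s1, Some(&s0));
assert_eq!(r.instantaneous_reward, 3.0);
assert_eq!(r.transition_reward, 1.0);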

@ -1,27 +1,40 @@
//! Module containing methods for the sampling.

use crate::{
    params::ParamsTrait,
    process::{NetworkProcess, NetworkProcessState},
};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;

#[derive(Clone)]
pub struct Sample {
    pub t: f64,
    pub state: NetworkProcessState,
}

pub trait Sampler: Iterator<Item = Sample> {
    fn reset(&mut self);
}

pub struct ForwardSampler<'a, T>
where
    T: NetworkProcess,
{
    net: &'a T,
    rng: ChaCha8Rng,
    current_time: f64,
    current_state: NetworkProcessState,
    next_transitions: Vec<Option<f64>>,
    initial_state: Option<NetworkProcessState>,
}

impl<'a, T: NetworkProcess> ForwardSampler<'a, T> {
    pub fn new(
        net: &'a T,
        seed: Option<u64>,
        initial_state: Option<NetworkProcessState>,
    ) -> ForwardSampler<'a, T> {
        let rng: ChaCha8Rng = match seed {
            //If a seed is present use it to initialize the random generator.
            Some(seed) => SeedableRng::seed_from_u64(seed),
@ -29,19 +42,20 @@ impl<'a, T: Network> ForwardSampler<'a, T> {
            None => SeedableRng::from_entropy(),
        };
        let mut fs = ForwardSampler {
            net,
            rng,
            current_time: 0.0,
            current_state: vec![],
            next_transitions: vec![],
            initial_state,
        };
        fs.reset();
        return fs;
    }
}

impl<'a, T: NetworkProcess> Iterator for ForwardSampler<'a, T> {
    type Item = Sample;

    fn next(&mut self) -> Option<Self::Item> {
        let ret_time = self.current_time.clone();
@ -94,18 +108,26 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> {
            self.next_transitions[child] = None;
        }

        Some(Sample {
            t: ret_time,
            state: ret_state,
        })
    }
}

impl<'a, T: NetworkProcess> Sampler for ForwardSampler<'a, T> {
    fn reset(&mut self) {
        self.current_time = 0.0;
        match &self.initial_state {
            None => {
                self.current_state = self
                    .net
                    .get_node_indices()
                    .map(|x| self.net.get_node(x).get_random_state_uniform(&mut self.rng))
                    .collect()
            }
            Some(is) => self.current_state = is.clone(),
        };
        self.next_transitions = self.net.get_node_indices().map(|_| Option::None).collect();
    }
}
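A short usage sketch of the new sampler API (assumes the `sampling` module is public and a `CtbnNetwork` `net` is already built; the `initial_state` argument added in this hunk pins the first sample instead of drawing it uniformly):

use reCTBN::sampling::{ForwardSampler, Sampler};

let s0 = vec![params::StateType::Discrete(0)];
let mut sampler = ForwardSampler::new(&net, Some(42), Some(s0));
let first = sampler.next().unwrap(); // a `Sample { t, state }`
assert_eq!(first.t, 0.0);            // sampling always starts at t = 0
sampler.reset();                     // back to t = 0 and the same initial state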

@ -1,11 +1,13 @@
//! Learn the structure of the network.

pub mod constraint_based_algorithm;
pub mod hypothesis_test;
pub mod score_based_algorithm;
pub mod score_function;

use crate::{process, tools::Dataset};

pub trait StructureLearningAlgorithm {
    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
    where
        T: process::NetworkProcess;
}

@ -1,3 +1,348 @@
//! Module containing constraint based algorithms like CTPC and Hiton.

use crate::params::Params;
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;
use std::collections::{BTreeSet, HashMap};
use std::mem;
use std::usize;
use super::hypothesis_test::*;
use crate::parameter_learning::ParameterLearning;
use crate::process;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::tools::Dataset;
pub struct Cache<'a, P: ParameterLearning> {
parameter_learning: &'a P,
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>,
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>,
parent_set_size_small: usize,
}
impl<'a, P: ParameterLearning> Cache<'a, P> {
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> {
Cache {
parameter_learning,
cache_persistent_small: HashMap::new(),
cache_persistent_big: HashMap::new(),
parent_set_size_small: 0,
}
}
pub fn fit<T: process::NetworkProcess>(
&mut self,
net: &T,
dataset: &Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>,
) -> Params {
let parent_set_len = parent_set.as_ref().unwrap().len();
if parent_set_len > self.parent_set_size_small + 1 {
//self.cache_persistent_small = self.cache_persistent_big;
mem::swap(
&mut self.cache_persistent_small,
&mut self.cache_persistent_big,
);
self.cache_persistent_big = HashMap::new();
self.parent_set_size_small += 1;
}
if parent_set_len > self.parent_set_size_small {
match self.cache_persistent_big.get(&parent_set) {
            // TODO: avoid cloning `params`; it wastes clock cycles, RAM and I/O.
            // Avoiding the clone requires a minor but careful refactoring across the library.
Some(params) => params.clone(),
None => {
let params =
self.parameter_learning
.fit(net, dataset, node, parent_set.clone());
self.cache_persistent_big.insert(parent_set, params.clone());
params
}
}
} else {
match self.cache_persistent_small.get(&parent_set) {
            // TODO: avoid cloning `params`; it wastes clock cycles, RAM and I/O.
            // Avoiding the clone requires a minor but careful refactoring across the library.
Some(params) => params.clone(),
None => {
let params =
self.parameter_learning
.fit(net, dataset, node, parent_set.clone());
self.cache_persistent_small
.insert(parent_set, params.clone());
params
}
}
}
}
}
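The two persistent maps implement a sliding window over parent-set sizes: CTPC grows the separation set one element at a time, so only results for sizes k and k + 1 are ever re-queried. A stripped-down, hypothetical sketch of just that eviction policy:

use std::collections::HashMap;
use std::mem;

struct SlidingCache<V> {
    small: HashMap<u64, V>, // entries for parent sets of size k
    big: HashMap<u64, V>,   // entries for parent sets of size k + 1
    k: usize,
}

impl<V> SlidingCache<V> {
    // Once a query for size k + 2 arrives, size-k entries can never be
    // requested again: promote `big` to `small` and start a fresh `big`.
    fn advance(&mut self) {
        mem::swap(&mut self.small, &mut self.big);
        self.big = HashMap::new();
        self.k += 1;
    }
}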
/// Continuous-Time Peter Clark algorithm.
///
/// A method to learn the structure of the network.
///
/// # Arguments
///
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hypothesis test.
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared (χ²) hypothesis test.
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// use reCTBN::parameter_learning::BayesianApproach;
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
/// #
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("A"));
/// # domain.insert(String::from("B"));
/// # domain.insert(String::from("C"));
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
/// # //Create the node n1 using the parameters
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("D"));
/// # domain.insert(String::from("E"));
/// # domain.insert(String::from("F"));
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # let mut domain = BTreeSet::new();
/// # domain.insert(String::from("G"));
/// # domain.insert(String::from("H"));
/// # domain.insert(String::from("I"));
/// # domain.insert(String::from("F"));
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
/// #
/// # // Initialize a ctbn
/// # let mut net = CtbnNetwork::new();
/// #
/// # // Add the nodes and their edges
/// # let n1 = net.add_node(n1).unwrap();
/// # let n2 = net.add_node(n2).unwrap();
/// # let n3 = net.add_node(n3).unwrap();
/// # net.add_edge(n1, n2);
/// # net.add_edge(n1, n3);
/// # net.add_edge(n2, n3);
/// #
/// # match &mut net.get_node_mut(n1) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-3.0, 2.0, 1.0],
/// # [1.5, -2.0, 0.5],
/// # [0.4, 0.6, -1.0]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # match &mut net.get_node_mut(n2) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-1.0, 0.5, 0.5],
/// # [3.0, -4.0, 1.0],
/// # [0.9, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 2.0, 4.0],
/// # [1.5, -2.0, 0.5],
/// # [3.0, 1.0, -4.0]
/// # ],
/// # [
/// # [-1.0, 0.1, 0.9],
/// # [2.0, -2.5, 0.5],
/// # [0.9, 0.1, -1.0]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # match &mut net.get_node_mut(n3) {
/// # params::Params::DiscreteStatesContinousTime(param) => {
/// # assert_eq!(
/// # Ok(()),
/// # param.set_cim(arr3(&[
/// # [
/// # [-1.0, 0.5, 0.3, 0.2],
/// # [0.5, -4.0, 2.5, 1.0],
/// # [2.5, 0.5, -4.0, 1.0],
/// # [0.7, 0.2, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 2.0, 3.0, 1.0],
/// # [1.5, -3.0, 0.5, 1.0],
/// # [2.0, 1.3, -5.0, 1.7],
/// # [2.5, 0.5, 1.0, -4.0]
/// # ],
/// # [
/// # [-1.3, 0.3, 0.1, 0.9],
/// # [1.4, -4.0, 0.5, 2.1],
/// # [1.0, 1.5, -3.0, 0.5],
/// # [0.4, 0.3, 0.1, -0.8]
/// # ],
/// # [
/// # [-2.0, 1.0, 0.7, 0.3],
/// # [1.3, -5.9, 2.7, 1.9],
/// # [2.0, 1.5, -4.0, 0.5],
/// # [0.2, 0.7, 0.1, -1.0]
/// # ],
/// # [
/// # [-6.0, 1.0, 2.0, 3.0],
/// # [0.5, -3.0, 1.0, 1.5],
/// # [1.4, 2.1, -4.3, 0.8],
/// # [0.5, 1.0, 2.5, -4.0]
/// # ],
/// # [
/// # [-1.3, 0.9, 0.3, 0.1],
/// # [0.1, -1.3, 0.2, 1.0],
/// # [0.5, 1.0, -3.0, 1.5],
/// # [0.1, 0.4, 0.3, -0.8]
/// # ],
/// # [
/// # [-2.0, 1.0, 0.6, 0.4],
/// # [2.6, -7.1, 1.4, 3.1],
/// # [5.0, 1.0, -8.0, 2.0],
/// # [1.4, 0.4, 0.2, -2.0]
/// # ],
/// # [
/// # [-3.0, 1.0, 1.5, 0.5],
/// # [3.0, -6.0, 1.0, 2.0],
/// # [0.3, 0.5, -1.9, 1.1],
/// # [5.0, 1.0, 2.0, -8.0]
/// # ],
/// # [
/// # [-2.6, 0.6, 0.2, 1.8],
/// # [2.0, -6.0, 3.0, 1.0],
/// # [0.1, 0.5, -1.3, 0.7],
/// # [0.8, 0.6, 0.2, -1.6]
/// # ],
/// # ]))
/// # );
/// # }
/// # }
/// #
/// # // Generate the trajectory
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
///
/// // Initialize the hypothesis tests to pass to the CTPC with their
/// // respective significance level `alpha`
/// let f = F::new(1e-6);
/// let chi_sq = ChiSquare::new(1e-4);
/// // Use the Bayesian approach to learn the parameters
/// let parameter_learning = BayesianApproach { alpha: 1, tau: 1.0 };
///
/// //Initialize CTPC
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
///
/// // Learn the structure of the network from the generated trajectory
/// let net = ctpc.fit_transform(net, &data);
/// #
/// # // Compare the generated network with the original one
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
/// ```
pub struct CTPC<P: ParameterLearning> {
parameter_learning: P,
Ftest: F,
Chi2test: ChiSquare,
}
impl<P: ParameterLearning> CTPC<P> {
pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> {
CTPC {
parameter_learning,
Ftest,
Chi2test,
}
}
}
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> {
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
where
T: process::NetworkProcess,
{
//Check the coherence between dataset and network
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
panic!("Dataset and Network must have the same number of variables.")
}
//Make the network mutable.
let mut net = net;
net.initialize_adj_matrix();
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| {
let mut cache = Cache::new(&self.parameter_learning);
let mut candidate_parent_set: BTreeSet<usize> = net
.get_node_indices()
.into_iter()
.filter(|x| x != &child_node)
.collect();
let mut separation_set_size = 0;
while separation_set_size < candidate_parent_set.len() {
let mut candidate_parent_set_TMP = candidate_parent_set.clone();
for parent_node in candidate_parent_set.iter() {
for separation_set in candidate_parent_set
.iter()
.filter(|x| x != &parent_node)
.map(|x| *x)
.combinations(separation_set_size)
{
let separation_set = separation_set.into_iter().collect();
if self.Ftest.call(
&net,
child_node,
*parent_node,
&separation_set,
dataset,
&mut cache,
) && self.Chi2test.call(
&net,
child_node,
*parent_node,
&separation_set,
dataset,
&mut cache,
) {
candidate_parent_set_TMP.remove(parent_node);
break;
}
}
}
candidate_parent_set = candidate_parent_set_TMP;
separation_set_size += 1;
}
(child_node, candidate_parent_set)
}));
for (child_node, candidate_parent_set) in learned_parent_sets {
for parent_node in candidate_parent_set.iter() {
net.add_edge(*parent_node, child_node);
}
}
net
}
}
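In words: each node starts with every other node as a candidate parent; a candidate is dropped as soon as one separation set of the current size makes both the F-test and the chi-squared test accept conditional independence, and the separation-set size grows until it reaches the size of the surviving candidate set. Because each child's search is independent, `par_extend` over `into_par_iter` lets rayon learn all parent sets in parallel before the edges are applied to the shared network.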

@ -1,10 +1,13 @@
//! Module for constraint based algorithms containing hypothesis test algorithms like the chi-squared test, the F-test, etc.

use std::collections::BTreeSet;

use ndarray::{Array3, Axis};
use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};

use crate::params::*;
use crate::structure_learning::constraint_based_algorithm::Cache;
use crate::{parameter_learning, process, tools::Dataset};

pub trait HypothesisTest {
    fn call<T, P>(
@ -13,23 +16,167 @@ pub trait HypothesisTest {
        child_node: usize,
        parent_node: usize,
        separation_set: &BTreeSet<usize>,
        dataset: &Dataset,
        cache: &mut Cache<P>,
    ) -> bool
    where
        T: process::NetworkProcess,
        P: parameter_learning::ParameterLearning;
}
/// Does the chi-squared test (χ2 test).
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
/// in other words is the risk of concluding that an association between the variables exists
/// when there is no actual association.
pub struct ChiSquare {
    alpha: f64,
}

/// Does the F-test.
///
/// Used to determine if a difference between two sets of data is due to chance, or if it is due to
/// a relationship (dependence) between the variables.
///
/// # Arguments
///
/// * `alpha` - is the significance level, the probability to reject a true null hypothesis;
/// in other words is the risk of concluding that an association between the variables exists
/// when there is no actual association.
pub struct F {
alpha: f64,
}
impl F {
pub fn new(alpha: f64) -> F {
F { alpha }
}
    /// Compare two matrices extracted from two 3rd-order tensors.
    ///
    /// # Arguments
    ///
    /// * `i` - Index of the matrix in `M1` to compare with `M2`.
    /// * `M1` - 3rd-order tensor 1.
    /// * `j` - Index of the matrix in `M2` to compare with `M1`.
    /// * `M2` - 3rd-order tensor 2.
    ///
    /// # Returns
    ///
    /// * `true` - when the matrices `M1` and `M2` are very similar, then **independent**.
    /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
pub fn compare_matrices(
&self,
i: usize,
M1: &Array3<usize>,
cim_1: &Array3<f64>,
j: usize,
M2: &Array3<usize>,
cim_2: &Array3<f64>,
) -> bool {
let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
let cim_1 = cim_1.index_axis(Axis(0), i);
let cim_2 = cim_2.index_axis(Axis(0), j);
let r1 = M1.sum_axis(Axis(1));
let r2 = M2.sum_axis(Axis(1));
let q1 = cim_1.diag();
let q2 = cim_2.diag();
for idx in 0..r1.shape()[0] {
let s = q2[idx] / q1[idx];
let F = FisherSnedecor::new(r1[idx], r2[idx]).unwrap();
let s = F.cdf(s);
let lim_sx = self.alpha / 2.0;
let lim_dx = 1.0 - (self.alpha / 2.0);
if s < lim_sx || s > lim_dx {
return false;
}
}
true
}
}
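The statistic tested above, restated in LaTeX (a reading of the code: the q are the diagonal CIM entries, i.e. the exit rates of the two conditional models, and the r are the corresponding row sums of observed transition counts):

\[
  s_i = \frac{q^{(2)}_i}{q^{(1)}_i} \sim F\big(r^{(1)}_i,\, r^{(2)}_i\big)
  \quad \text{under } H_0,
  \qquad
  \text{independent} \iff \frac{\alpha}{2} \le F_{\mathrm{cdf}}(s_i) \le 1 - \frac{\alpha}{2}
  \;\; \forall i .
\]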
impl HypothesisTest for F {
fn call<T, P>(
&self,
net: &T,
child_node: usize,
parent_node: usize,
separation_set: &BTreeSet<usize>,
dataset: &Dataset,
cache: &mut Cache<P>,
) -> bool
where
T: process::NetworkProcess,
P: parameter_learning::ParameterLearning,
{
let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
Params::DiscreteStatesContinousTime(node) => node,
};
let mut extended_separation_set = separation_set.clone();
extended_separation_set.insert(parent_node);
let P_big = match cache.fit(
net,
&dataset,
child_node,
Some(extended_separation_set.clone()),
) {
Params::DiscreteStatesContinousTime(node) => node,
};
let partial_cardinality_product: usize = extended_separation_set
.iter()
.take_while(|x| **x != parent_node)
.map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product();
for idx_M_big in 0..P_big.get_transitions().as_ref().unwrap().shape()[0] {
let idx_M_small: usize = idx_M_big % partial_cardinality_product
+ (idx_M_big
/ (partial_cardinality_product
* net.get_node(parent_node).get_reserved_space_as_parent()))
* partial_cardinality_product;
if !self.compare_matrices(
idx_M_small,
P_small.get_transitions().as_ref().unwrap(),
P_small.get_cim().as_ref().unwrap(),
idx_M_big,
P_big.get_transitions().as_ref().unwrap(),
P_big.get_cim().as_ref().unwrap(),
) {
return false;
}
}
return true;
}
}
impl ChiSquare {
    pub fn new(alpha: f64) -> ChiSquare {
        ChiSquare { alpha }
    }
    /// Compare two matrices extracted from two 3rd-order tensors.
    ///
    /// # Arguments
    ///
    /// * `i` - Index of the matrix in `M1` to compare with `M2`.
    /// * `M1` - 3rd-order tensor 1.
    /// * `j` - Index of the matrix in `M2` to compare with `M1`.
    /// * `M2` - 3rd-order tensor 2.
    ///
    /// # Returns
    ///
    /// * `true` - when the matrices `M1` and `M2` are very similar, then **independent**.
    /// * `false` - when the matrices `M1` and `M2` are too different, then **dependent**.
    pub fn compare_matrices(
        &self,
        i: usize,
@ -42,26 +189,8 @@ impl ChiSquare {
        // continuous-time Bayesian networks.
        // International Journal of Approximate Reasoning, 138, pp.105-122.
        // Also: https://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm
        let M1 = M1.index_axis(Axis(0), i).mapv(|x| x as f64);
        let M2 = M2.index_axis(Axis(0), j).mapv(|x| x as f64);
        let K = M1.sum_axis(Axis(1)) / M2.sum_axis(Axis(1));
        let K = K.mapv(f64::sqrt);
        // Reshape to column vector.
@ -70,26 +199,12 @@ impl ChiSquare {
            K.into_shape((n, 1)).unwrap()
        };
        let L = 1.0 / &K;
        let mut X_2 = (&K * &M2 - &L * &M1).mapv(|a| a.powi(2)) / (&M2 + &M1);
        X_2.diag_mut().fill(0.0);
        let X_2 = X_2.sum_axis(Axis(1));
        let n = ChiSquared::new((X_2.dim() - 1) as f64).unwrap();
        let ret = X_2.into_iter().all(|x| n.cdf(x) < (1.0 - self.alpha));
        ret
    }
}
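The ASCII-art derivation removed by this hunk defined the quantities now computed directly in code; restated in LaTeX, with M_1 = M_{xx'|s} and M_2 = M_{xx'|y,s}:

\[
  K_x = \sqrt{\frac{\sum_{x'} M_1[x,x']}{\sum_{x'} M_2[x,x']}},
  \qquad
  L_x = \frac{1}{K_x},
  \qquad
  \chi^2_x = \sum_{x' \neq x}
  \frac{\big(K_x M_2[x,x'] - L_x M_1[x,x']\big)^2}{M_1[x,x'] + M_2[x,x']} .
\]

Each per-row statistic is then compared against the χ² CDF with dim − 1 degrees of freedom; independence is accepted only if every row stays below the 1 − α quantile.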
@ -100,26 +215,27 @@ impl HypothesisTest for ChiSquare {
        child_node: usize,
        parent_node: usize,
        separation_set: &BTreeSet<usize>,
        dataset: &Dataset,
        cache: &mut Cache<P>,
    ) -> bool
    where
        T: process::NetworkProcess,
        P: parameter_learning::ParameterLearning,
    {
        let P_small = match cache.fit(net, &dataset, child_node, Some(separation_set.clone())) {
            Params::DiscreteStatesContinousTime(node) => node,
        };
        let mut extended_separation_set = separation_set.clone();
        extended_separation_set.insert(parent_node);
        let P_big = match cache.fit(
            net,
            &dataset,
            child_node,
            Some(extended_separation_set.clone()),
        ) {
            Params::DiscreteStatesContinousTime(node) => node,
        };
        let partial_cardinality_product: usize = extended_separation_set
            .iter()
            .take_while(|x| **x != parent_node)

@ -1,8 +1,13 @@
//! Module containing score based algorithms like Hill Climbing and Tabu Search.

use std::collections::BTreeSet;

use crate::structure_learning::score_function::ScoreFunction;
use crate::structure_learning::StructureLearningAlgorithm;
use crate::{process, tools::Dataset};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::ParallelExtend;

pub struct HillClimbing<S: ScoreFunction> {
    score_function: S,
@ -19,9 +24,9 @@ impl<S: ScoreFunction> HillClimbing<S> {
}

impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
    fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T
    where
        T: process::NetworkProcess,
    {
        //Check the coherence between dataset and network
        if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] {
@ -34,8 +39,9 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
        let max_parent_set = self.max_parent_set.unwrap_or(net.get_number_of_nodes());
        //Reset the adj matrix
        net.initialize_adj_matrix();
        let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![];
        //Iterate over each node to learn their parent set.
        learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|node| {
            //Initialize an empty parent set.
            let mut parent_set: BTreeSet<usize> = BTreeSet::new();
            //Compute the score for the empty parent set
@ -74,10 +80,14 @@ impl<S: ScoreFunction> StructureLearningAlgorithm for HillClimbing<S> {
                    }
                }
            }
            (node, parent_set)
        }));

        for (child_node, candidate_parent_set) in learned_parent_sets {
            for parent_node in candidate_parent_set.iter() {
                net.add_edge(*parent_node, child_node);
            }
        }
        return net;
    }
}

@ -1,11 +1,13 @@
//! Module for score based algorithms containing score functions like Log Likelihood, BIC, etc.

use std::collections::BTreeSet;

use ndarray::prelude::*;
use statrs::function::gamma;

use crate::{parameter_learning, params, process, tools};

pub trait ScoreFunction: Sync {
    fn call<T>(
        &self,
        net: &T,
@ -14,7 +16,7 @@ pub trait ScoreFunction {
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: process::NetworkProcess;
}

pub struct LogLikelihood {
@ -39,7 +41,7 @@ impl LogLikelihood {
        dataset: &tools::Dataset,
    ) -> (f64, Array3<usize>)
    where
        T: process::NetworkProcess,
    {
        //Identify the type of node used
        match &net.get_node(node) {
@ -98,7 +100,7 @@ impl ScoreFunction for LogLikelihood {
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: process::NetworkProcess,
    {
        self.compute_score(net, node, parent_set, dataset).0
    }
@ -125,7 +127,7 @@ impl ScoreFunction for BIC {
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: process::NetworkProcess,
    {
        //Compute the log-likelihood
        let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);

@ -1,8 +1,17 @@
//! Contains commonly used methods used across the crate.

use std::ops::{DivAssign, MulAssign, Range};

use ndarray::{Array, Array1, Array2, Array3, Axis};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;

use crate::params::ParamsTrait;
use crate::process::NetworkProcess;
use crate::sampling::{ForwardSampler, Sampler};
use crate::{params, process};

#[derive(Clone)]
pub struct Trajectory {
    time: Array1<f64>,
    events: Array2<usize>,
@ -27,6 +36,7 @@ impl Trajectory {
    }
}

#[derive(Clone)]
pub struct Dataset {
    trajectories: Vec<Trajectory>,
}
@ -49,7 +59,7 @@ impl Dataset {
    }
}

pub fn trajectory_generator<T: process::NetworkProcess>(
    net: &T,
    n_trajectories: u64,
    t_end: f64,
@ -59,26 +69,25 @@ pub fn trajectory_generator<T: network::Network>(
    let mut trajectories: Vec<Trajectory> = Vec::new();
    //Random Generator object
    let mut sampler = ForwardSampler::new(net, seed, None);
    //Each iteration generate one trajectory
    for _ in 0..n_trajectories {
        //History of all the moments in which something changed
        let mut time: Vec<f64> = Vec::new();
        //Configuration of the process variables at time t initialized with a uniform
        //distribution.
        let mut events: Vec<process::NetworkProcessState> = Vec::new();
        //Current Time and Current State
        let mut sample = sampler.next().unwrap();
        //Generate new samples until ending time is reached.
        while sample.t < t_end {
            time.push(sample.t);
            events.push(sample.state);
            sample = sampler.next().unwrap();
        }
        let current_state = events.last().unwrap().clone();
        events.push(current_state);
        //Add t_end as last time.
@ -104,3 +113,243 @@ pub fn trajectory_generator<T: network::Network>(
    //Return a dataset object with the sampled trajectories.
    Dataset::new(trajectories)
}
pub trait RandomGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> Self;
fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T);
}
/// Graph Generator using a uniform distribution.
///
/// A method to generate a random graph with uniformly distributed edges.
///
/// # Arguments
///
/// * `density` - is the density of the graph in terms of edges; domain: `0.0 ≤ density ≤ 1.0`.
/// * `rng` - is the random numbers generator.
///
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// use reCTBN::tools::UniformGraphGenerator;
/// use reCTBN::tools::RandomGraphGenerator;
/// # let mut net = CtbnNetwork::new();
/// # let nodes_cardinality = 8;
/// # let domain_cardinality = 4;
/// # for node in 0..nodes_cardinality {
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # for dvalue in 0..domain_cardinality {
/// # domain.insert(dvalue.to_string());
/// # }
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new(
/// # node.to_string(),
/// # domain
/// # );
/// # //Create the node using the parameters
/// # let node = DiscreteStatesContinousTime(param);
/// # // Add the node to the network
/// # net.add_node(node).unwrap();
/// # }
///
/// // Initialize the Graph Generator using the one with a
/// // uniform distribution
/// let density = 1.0/3.0;
/// let seed = Some(7641630759785120);
/// let mut structure_generator = UniformGraphGenerator::new(
/// density,
/// seed
/// );
///
/// // Generate the graph directly on the network
/// structure_generator.generate_graph(&mut net);
/// # // Count all the edges generated in the network
/// # let mut edges = 0;
/// # for node in net.get_node_indices(){
/// # edges += net.get_children_set(node).len()
/// # }
/// # // Number of all the nodes in the network
/// # let nodes = net.get_node_indices().len() as f64;
/// # let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
/// # // ±10% of tolerance
/// # let tolerance = ((expected_edges as f64)*0.10) as usize;
/// # // Given how `generate_graph()` is implemented, we can only reasonably
/// # // expect the number of edges to be somewhere around the expected value.
/// # assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
/// ```
pub struct UniformGraphGenerator {
density: f64,
rng: ChaCha8Rng,
}
impl RandomGraphGenerator for UniformGraphGenerator {
fn new(density: f64, seed: Option<u64>) -> UniformGraphGenerator {
if density < 0.0 || density > 1.0 {
panic!(
"Density value must be between 1.0 and 0.0, got {}.",
density
);
}
let rng: ChaCha8Rng = match seed {
Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(),
};
UniformGraphGenerator { density, rng }
}
    /// Generate a uniformly distributed graph.
fn generate_graph<T: NetworkProcess>(&mut self, net: &mut T) {
net.initialize_adj_matrix();
let last_node_idx = net.get_node_indices().len();
for parent in 0..last_node_idx {
for child in 0..last_node_idx {
if parent != child {
if self.rng.gen_bool(self.density) {
net.add_edge(parent, child);
}
}
}
}
}
}
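Since every ordered pair (parent, child) with parent ≠ child independently receives an edge with probability `density`, the expected edge count that the doctest's ±10% tolerance is built around is simply

\[
  \mathbb{E}[\#\text{edges}] = \rho \, n (n - 1), \qquad \rho = \text{density} .
\]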
pub trait RandomParametersGenerator {
fn new(interval: Range<f64>, seed: Option<u64>) -> Self;
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T);
}
/// Parameters Generator using a uniform distribution.
///
/// A method to generate random, uniformly distributed parameters.
///
/// # Arguments
///
/// * `interval` - is the interval of the random values of the CIM's diagonal; domain: `≥ 0.0`.
/// * `rng` - is the random numbers generator.
///
/// # Example
///
/// ```rust
/// # use std::collections::BTreeSet;
/// # use ndarray::{arr1, arr2, arr3};
/// # use reCTBN::params;
/// # use reCTBN::params::ParamsTrait;
/// # use reCTBN::params::Params::DiscreteStatesContinousTime;
/// # use reCTBN::process::NetworkProcess;
/// # use reCTBN::process::ctbn::CtbnNetwork;
/// # use reCTBN::tools::trajectory_generator;
/// # use reCTBN::tools::RandomGraphGenerator;
/// # use reCTBN::tools::UniformGraphGenerator;
/// use reCTBN::tools::RandomParametersGenerator;
/// use reCTBN::tools::UniformParametersGenerator;
/// # let mut net = CtbnNetwork::new();
/// # let nodes_cardinality = 8;
/// # let domain_cardinality = 4;
/// # for node in 0..nodes_cardinality {
/// # // Create the domain for a discrete node
/// # let mut domain = BTreeSet::new();
/// # for dvalue in 0..domain_cardinality {
/// # domain.insert(dvalue.to_string());
/// # }
/// # // Create the parameters for a discrete node using the domain
/// # let param = params::DiscreteStatesContinousTimeParams::new(
/// # node.to_string(),
/// # domain
/// # );
/// # //Create the node using the parameters
/// # let node = DiscreteStatesContinousTime(param);
/// # // Add the node to the network
/// # net.add_node(node).unwrap();
/// # }
/// #
/// # // Initialize the Graph Generator using the one with a
/// # // uniform distribution
/// # let mut structure_generator = UniformGraphGenerator::new(
/// # 1.0/3.0,
/// # Some(7641630759785120)
/// # );
/// #
/// # // Generate the graph directly on the network
/// # structure_generator.generate_graph(&mut net);
///
/// // Initialize the parameters generator with a uniform distribution
/// let mut cim_generator = UniformParametersGenerator::new(
/// 0.0..7.0,
/// Some(7641630759785120)
/// );
///
/// // Generate CIMs with uniformly distributed parameters.
/// cim_generator.generate_parameters(&mut net);
/// #
/// # for node in net.get_node_indices() {
/// # assert_eq!(
/// # Ok(()),
/// # net.get_node(node).validate_params()
/// # );
/// # }
/// ```
pub struct UniformParametersGenerator {
interval: Range<f64>,
rng: ChaCha8Rng,
}
impl RandomParametersGenerator for UniformParametersGenerator {
fn new(interval: Range<f64>, seed: Option<u64>) -> UniformParametersGenerator {
if interval.start < 0.0 || interval.end < 0.0 {
panic!(
"Interval must be entirely less or equal than 0, got {}..{}.",
interval.start, interval.end
);
}
let rng: ChaCha8Rng = match seed {
Some(seed) => SeedableRng::seed_from_u64(seed),
None => SeedableRng::from_entropy(),
};
UniformParametersGenerator { interval, rng }
}
/// Generate CIMs with uniformly distributed parameters.
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) {
for node in net.get_node_indices() {
let parent_set_state_space_cardinality: usize = net
.get_parent_set(node)
.iter()
.map(|x| net.get_node(*x).get_reserved_space_as_parent())
.product();
match &mut net.get_node_mut(node) {
params::Params::DiscreteStatesContinousTime(param) => {
let node_domain_cardinality = param.get_reserved_space_as_parent();
let mut cim = Array3::<f64>::from_shape_fn(
(
parent_set_state_space_cardinality,
node_domain_cardinality,
node_domain_cardinality,
),
|_| self.rng.gen(),
);
cim.axis_iter_mut(Axis(0)).for_each(|mut x| {
x.diag_mut().fill(0.0);
x.div_assign(&x.sum_axis(Axis(1)).insert_axis(Axis(1)));
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| {
self.rng.gen_range(self.interval.clone())
});
x.mul_assign(&diag.clone().insert_axis(Axis(1)));
// Recomputing the diagonal in order to reduce the issues caused by the
// loss of precision when validating the parameters.
let diag_sum = -x.sum_axis(Axis(1));
x.diag_mut().assign(&diag_sum)
});
param.set_cim_unchecked(cim);
}
}
}
}
}
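Per parent configuration, each CIM row is built as uniform off-diagonal weights normalized to sum to one, scaled by a uniformly drawn exit rate, with the diagonal recomputed so that every row sums to zero (a reading of the code above):

\[
  q_{ij} = d_i \, \frac{u_{ij}}{\sum_{k \neq i} u_{ik}} \quad (j \neq i),
  \qquad
  q_{ii} = -\sum_{j \neq i} q_{ij},
  \qquad
  u_{ij} \sim U(0,1), \; d_i \sim U(\text{interval}) .
\]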

@ -1,9 +1,12 @@
mod utils;

use std::collections::BTreeSet;

use approx::AbsDiffEq;
use ndarray::arr3;
use reCTBN::params::{self, ParamsTrait};
use reCTBN::process::NetworkProcess;
use reCTBN::process::{ctbn::*};
use utils::generate_discrete_time_continous_node;

#[test]
@ -129,3 +132,245 @@ fn compute_index_from_custom_parent_set() {
    );
    assert_eq!(2, idx);
}
#[test]
fn simple_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.initialize_adj_matrix();
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctbn) = &net.get_node(0);
let p_ctbn = p_ctbn.get_cim().as_ref().unwrap();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
assert!(p_ctmp.abs_diff_eq(p_ctbn, std::f64::EPSILON));
}
#[test]
fn chain_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
let p_ctmp_handmade = arr3(&[[
[
-1.20e-01, 1.00e-01, 1.00e-02, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -6.01e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, -1.01e+01, 1.00e-01, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e-02, 1.00e+00, -6.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.11e+00, 1.00e-01, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.10e+01, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 5.00e+00, 0.00e+00, -5.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e-02, 1.00e+00, -1.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}
#[test]
fn chainfork_amalgamation() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
let n4 = net
.add_node(generate_discrete_time_continous_node(String::from("n4"), 2))
.unwrap();
net.add_edge(n1, n3);
net.add_edge(n2, n3);
net.add_edge(n3, n4);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])));
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
match &mut net.get_node_mut(n4) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]]
]))
);
}
}
let ctmp = net.amalgamation();
let params::Params::DiscreteStatesContinousTime(p_ctmp) = &ctmp.get_node(0);
let p_ctmp = p_ctmp.get_cim().as_ref().unwrap();
let p_ctmp_handmade = arr3(&[[
[
-2.20e-01, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, -1.12e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
1.00e+00, 0.00e+00, -1.12e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -1.02e+01, 1.00e-01, 1.00e-01, 0.00e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -1.11e+01, 0.00e+00, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -1.11e+01, 1.00e-01,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -7.01e+00,
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
-5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00,
0.00e+00, 1.00e+00, 1.00e+00, -1.20e+01, 0.00e+00, 0.00e+00, 0.00e+00, 5.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00, 0.00e+00,
5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -5.21e+00, 1.00e-01, 1.00e-01, 0.00e+00,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 0.00e+00,
0.00e+00, 5.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, -6.11e+00, 0.00e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00,
0.00e+00, 0.00e+00, 5.00e+00, 0.00e+00, 1.00e+00, 0.00e+00, -6.11e+00, 1.00e-01,
],
[
0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02,
0.00e+00, 0.00e+00, 0.00e+00, 1.00e-02, 0.00e+00, 1.00e+00, 1.00e+00, -2.02e+00,
],
]]);
assert!(p_ctmp.abs_diff_eq(&p_ctmp_handmade, 1e-8));
}

@ -0,0 +1,127 @@
mod utils;
use std::collections::BTreeSet;
use reCTBN::{
params,
params::ParamsTrait,
process::{ctmp::*, NetworkProcess},
};
use utils::generate_discrete_time_continous_node;
#[test]
fn define_simple_ctmp() {
let _ = CtmpProcess::new();
assert!(true);
}
#[test]
fn add_node_to_ctmp() {
let mut net = CtmpProcess::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
}
#[test]
fn add_two_nodes_to_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
match n2 {
Ok(_) => assert!(false),
Err(_) => assert!(true),
};
}
#[test]
#[should_panic]
fn add_edge_to_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2));
net.add_edge(0, 1)
}
#[test]
fn childen_and_parents() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(0, net.get_parent_set(0).len());
assert_eq!(0, net.get_children_set(0).len());
}
#[test]
#[should_panic]
fn get_childen_panic() {
let net = CtmpProcess::new();
net.get_children_set(0);
}
#[test]
#[should_panic]
fn get_childen_panic2() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.get_children_set(1);
}
#[test]
#[should_panic]
fn get_parent_panic() {
let net = CtmpProcess::new();
net.get_parent_set(0);
}
#[test]
#[should_panic]
fn get_parent_panic2() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
net.get_parent_set(1);
}
#[test]
fn compute_index_ctmp() {
let mut net = CtmpProcess::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(
String::from("n1"),
10,
))
.unwrap();
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]);
assert_eq!(6, idx);
}
#[test]
#[should_panic]
fn compute_index_from_custom_parent_set_ctmp() {
let mut net = CtmpProcess::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(
String::from("n1"),
10,
))
.unwrap();
let _idx = net.get_param_index_from_custom_parent_set(
&vec![params::StateType::Discrete(6)],
&BTreeSet::from([0])
);
}

@ -2,10 +2,11 @@
mod utils;

use ndarray::arr3;
use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::*;
use reCTBN::params;
use reCTBN::params::Params::DiscreteStatesContinousTime;
use reCTBN::tools::*;
use utils::*;
@ -66,18 +67,78 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) {
    ));
}
fn generate_nodes(
net: &mut CtbnNetwork,
nodes_cardinality: usize,
nodes_domain_cardinality: usize
) {
for node_label in 0..nodes_cardinality {
net.add_node(
generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
}
fn learn_binary_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 2);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(1) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 1, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test]
fn learn_binary_cim_MLE() {
    let mle = MLE {};
    learn_binary_cim(mle);
}
#[test]
fn learn_binary_cim_MLE_gen() {
let mle = MLE {};
learn_binary_cim_gen(mle);
}
#[test]
fn learn_binary_cim_BA() {
    let ba = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_binary_cim(ba);
}
#[test]
fn learn_binary_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_binary_cim_gen(ba);
}
fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
@ -155,18 +216,63 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) {
    ));
}
fn learn_ternary_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
4.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(1) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 1, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test]
fn learn_ternary_cim_MLE() {
    let mle = MLE {};
    learn_ternary_cim(mle);
}
#[test]
fn learn_ternary_cim_MLE_gen() {
let mle = MLE {};
learn_ternary_cim_gen(mle);
}
#[test]
fn learn_ternary_cim_BA() {
    let ba = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_ternary_cim(ba);
}
#[test]
fn learn_ternary_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_gen(ba);
}
fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
@ -234,18 +340,63 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) {
    ));
}
fn learn_ternary_cim_no_parents_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..6.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(0) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 0, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.1
)
);
}
#[test]
fn learn_ternary_cim_no_parents_MLE() {
    let mle = MLE {};
    learn_ternary_cim_no_parents(mle);
}
#[test]
fn learn_ternary_cim_no_parents_MLE_gen() {
let mle = MLE {};
learn_ternary_cim_no_parents_gen(mle);
}
#[test]
fn learn_ternary_cim_no_parents_BA() {
    let ba = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_ternary_cim_no_parents(ba);
}
#[test]
fn learn_ternary_cim_no_parents_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_no_parents_gen(ba);
}
fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
    let mut net = CtbnNetwork::new();
    let n1 = net
@ -432,14 +583,66 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
    ));
}
fn learn_mixed_discrete_cim_gen<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_node(
generate_discrete_time_continous_node(
String::from("3"),
4
)
).unwrap();
net.add_edge(0, 1);
net.add_edge(0, 2);
net.add_edge(1, 2);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
1.0..8.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let p_gen = match net.get_node(2) {
DiscreteStatesContinousTime(p_gen) => p_gen,
};
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let p_tj = match pl.fit(&net, &data, 2, None) {
DiscreteStatesContinousTime(p_tj) => p_tj,
};
assert_eq!(
p_tj.get_cim().as_ref().unwrap().shape(),
p_gen.get_cim().as_ref().unwrap().shape()
);
assert!(
p_tj.get_cim().as_ref().unwrap().abs_diff_eq(
&p_gen.get_cim().as_ref().unwrap(),
0.2
)
);
}
#[test]
fn learn_mixed_discrete_cim_MLE() {
    let mle = MLE {};
    learn_mixed_discrete_cim(mle);
}
#[test]
fn learn_mixed_discrete_cim_MLE_gen() {
let mle = MLE {};
learn_mixed_discrete_cim_gen(mle);
}
#[test]
fn learn_mixed_discrete_cim_BA() {
    let ba = BayesianApproach { alpha: 1, tau: 1.0 };
    learn_mixed_discrete_cim(ba);
}
#[test]
fn learn_mixed_discrete_cim_BA_gen() {
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_mixed_discrete_cim_gen(ba);
}

@ -0,0 +1,122 @@
mod utils;
use approx::assert_abs_diff_eq;
use ndarray::*;
use reCTBN::{
params,
process::{ctbn::*, NetworkProcess, NetworkProcessState},
reward::{reward_evaluation::*, reward_function::*, *},
};
use utils::generate_discrete_time_continous_node;
#[test]
fn simple_factored_reward_function_binary_node_mc() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1)
.assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1)
.assign(&arr1(&[3.0, 3.0]));
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap();
}
}
net.initialize_adj_matrix();
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
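// With zero transition rewards and a constant instantaneous reward of 3.0,
// the infinite-horizon estimate (discount_factor = 1.0) is expected to be
// 3.0 for both states.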
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
let rst = mc.evaluate_state_space(&net, &rf);
assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2);
assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2);
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215));
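// Under the finite-horizon criterion the same constant 3.0 rate accrues
// over what is presumably the 10.0 end time, giving 3.0 * 10.0 = 30.0.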
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2);
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2);
}
#[test]
fn simple_factored_reward_function_chain_mc() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap();
}
}
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
param
.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]],
]))
.unwrap();
}
}
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
param
.set_cim(arr3(&[
[[-0.01, 0.01], [5.0, -5.0]],
[[-5.0, 5.0], [0.01, -0.01]],
]))
.unwrap();
}
}
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
rf.get_transition_reward_mut(n2)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
rf.get_transition_reward_mut(n3)
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]]));
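// Every actual state change of any node earns a transition reward of 1 and
// instantaneous rewards are presumably left at zero by initialization, so a
// state's value is the expected (discounted) number of transitions.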
let s000: NetworkProcessState = vec![
params::StateType::Discrete(1),
params::StateType::Discrete(0),
params::StateType::Discrete(0),
];
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215));
assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1);
let rst = mc.evaluate_state_space(&net, &rf);
assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1);
}

@@ -0,0 +1,117 @@
mod utils;
use ndarray::*;
use utils::generate_discrete_time_continous_node;
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params};
#[test]
fn simple_factored_reward_function_binary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
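// `call(current, Some(previous))` adds the transition reward indexed by
// [previous, current] to the current state's instantaneous reward; with no
// previous state, or no actual change, the transition part is zero.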
assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0});
}
#[test]
fn simple_factored_reward_function_ternary_node() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)];
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)];
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)];
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0});
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0});
assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0});
assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0});
}
#[test]
fn factored_reward_function_two_nodes() {
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2);
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net);
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]]));
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0]));
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]]));
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0]));
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)];
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)];
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)];
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)];
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)];
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)];
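// Instantaneous rewards add up across nodes (e.g. in s01 node "n1" is in
// state 1 and "n2" in state 0, giving 5.0 + 3.0 = 8.0), while the transition
// reward comes only from the single node that changed.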
assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0});
assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0});
assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0});
assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0});
assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0});
assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0});
}

@@ -4,10 +4,12 @@ mod utils;
use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3};
use reCTBN::process::ctbn::*;
use reCTBN::process::NetworkProcess;
use reCTBN::parameter_learning::BayesianApproach;
use reCTBN::params;
use reCTBN::structure_learning::hypothesis_test::*;
use reCTBN::structure_learning::constraint_based_algorithm::*;
use reCTBN::structure_learning::score_based_algorithm::*;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::StructureLearningAlgorithm;
@@ -106,7 +108,7 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
}
}
let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259));
let mut net = CtbnNetwork::new();
let _n1 = net
@@ -115,6 +117,50 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
let _net = sl.fit_transform(net, &data);
}
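// Helper: adds `nodes_cardinality` nodes to `net`, each with
// `nodes_domain_cardinality` states, labeled "0", "1", and so on.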
fn generate_nodes(
net: &mut CtbnNetwork,
nodes_cardinality: usize,
nodes_domain_cardinality: usize
) {
for node_label in 0..nodes_cardinality {
net.add_node(
generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
}
fn check_compatibility_between_dataset_and_network_gen<T: StructureLearningAlgorithm>(sl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_node(
generate_discrete_time_continous_node(
String::from("3"),
4
)
).unwrap();
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
0.0..7.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let data = trajectory_generator(&net, 100, 30.0, Some(6347747169756259));
let mut net = CtbnNetwork::new();
let _n1 = net
.add_node(
generate_discrete_time_continous_node(String::from("0"),
3)
).unwrap();
let _net = sl.fit_transform(net, &data);
}
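// The dataset above comes from a three-node network, so fitting it against a
// freshly built single-node network is incompatible and must panic.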
#[test]
#[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
@@ -123,6 +169,14 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
check_compatibility_between_dataset_and_network(hl);
}
#[test]
#[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
check_compatibility_between_dataset_and_network_gen(hl);
}
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
@@ -180,6 +234,25 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm>(sl: T) {
assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
}
fn learn_ternary_net_2_nodes_gen<T: StructureLearningAlgorithm>(sl: T) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_edge(0, 1);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
0.0..7.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
assert_eq!(BTreeSet::new(), net.get_parent_set(0));
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
let ll = LogLikelihood::new(1, 1.0);
@@ -187,6 +260,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
learn_ternary_net_2_nodes(hl);
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
learn_ternary_net_2_nodes_gen(hl);
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
let bic = BIC::new(1, 1.0);
@@ -194,6 +274,13 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
learn_ternary_net_2_nodes(hl);
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic_gen() {
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, None);
learn_ternary_net_2_nodes_gen(hl);
}
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
let mut net = CtbnNetwork::new();
let n1 = net
@@ -318,6 +405,30 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
return (net, data);
}
fn get_mixed_discrete_net_3_nodes_with_data_gen() -> (CtbnNetwork, Dataset) {
let mut net = CtbnNetwork::new();
generate_nodes(&mut net, 2, 3);
net.add_node(
generate_discrete_time_continous_node(
String::from("3"),
4
)
).unwrap();
net.add_edge(0, 1);
net.add_edge(0, 2);
net.add_edge(1, 2);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
0.0..7.0,
Some(6813071588535822)
);
cim_generator.generate_parameters(&mut net);
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259));
return (net, data);
}
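// The generated network fixes the structure 0 -> 1, 0 -> 2, 1 -> 2, which
// the structure-learning tests below are expected to recover (up to the
// single-parent constraint, where one is imposed).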
fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data);
@@ -326,6 +437,14 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm>(sl: T) {
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
}
fn learn_mixed_discrete_net_3_nodes_gen<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen();
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(0));
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
let ll = LogLikelihood::new(1, 1.0);
@@ -333,6 +452,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
learn_mixed_discrete_net_3_nodes(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
learn_mixed_discrete_net_3_nodes_gen(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
let bic = BIC::new(1, 1.0);
@@ -340,6 +466,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
learn_mixed_discrete_net_3_nodes(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_gen() {
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, None);
learn_mixed_discrete_net_3_nodes_gen(hl);
}
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data);
@@ -348,6 +481,14 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgo
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
}
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data_gen();
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::new(), net.get_parent_set(0));
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
let ll = LogLikelihood::new(1, 1.0);
@@ -355,6 +496,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint_gen() {
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, Some(1));
learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() {
let bic = BIC::new(1, 1.0);
@@ -362,6 +510,13 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint()
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint_gen() {
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, Some(1));
learn_mixed_discrete_net_3_nodes_1_parent_constraint_gen(hl);
}
#[test]
pub fn chi_square_compare_matrices() {
let i: usize = 1;
@@ -390,7 +545,7 @@ pub fn chi_square_compare_matrices() {
[ 700, 800, 0]
],
]);
let chi_sq = ChiSquare::new(1e-4);
assert!(!chi_sq.compare_matrices(i, &M1, j, &M2));
}
@@ -420,7 +575,7 @@ pub fn chi_square_compare_matrices_2() {
[ 400, 0, 600],
[ 700, 800, 0]]
]);
let chi_sq = ChiSquare::new(1e-4);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
}
@@ -452,6 +607,86 @@ pub fn chi_square_compare_matrices_3() {
[ 700, 800, 0]
],
]);
let chi_sq = ChiSquare::new(1e-4);
assert!(chi_sq.compare_matrices(i, &M1, j, &M2));
}
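// The hypothesis tests below run against the handcrafted three-node network,
// in which node 0 has no parents, node 1 depends on 0, and node 2 depends on
// both 0 and 1.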
#[test]
pub fn chi_square_call() {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let N3: usize = 2;
let N2: usize = 1;
let N1: usize = 0;
let mut separation_set = BTreeSet::new();
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let mut cache = Cache::new(&parameter_learning);
let chi_sq = ChiSquare::new(1e-4);
assert!(chi_sq.call(&net, N1, N3, &separation_set, &data, &mut cache));
let mut cache = Cache::new(&parameter_learning);
assert!(!chi_sq.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!chi_sq.call(&net, N3, N2, &separation_set, &data, &mut cache));
separation_set.insert(N1);
let mut cache = Cache::new(&parameter_learning);
assert!(chi_sq.call(&net, N2, N3, &separation_set, &data, &mut cache));
}
#[test]
pub fn f_call() {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let N3: usize = 2;
let N2: usize = 1;
let N1: usize = 0;
let mut separation_set = BTreeSet::new();
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let mut cache = Cache::new(&parameter_learning);
let f = F::new(1e-6);
assert!(f.call(&net, N1, N3, &separation_set, &data, &mut cache));
let mut cache = Cache::new(&parameter_learning);
assert!(!f.call(&net, N3, N1, &separation_set, &data, &mut cache));
assert!(!f.call(&net, N3, N2, &separation_set, &data, &mut cache));
separation_set.insert(N1);
let mut cache = Cache::new(&parameter_learning);
assert!(f.call(&net, N2, N3, &separation_set, &data, &mut cache));
}
#[test]
pub fn learn_ternary_net_2_nodes_ctpc() {
let f = F::new(1e-6);
let chi_sq = ChiSquare::new(1e-4);
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
learn_ternary_net_2_nodes(ctpc);
}
#[test]
pub fn learn_ternary_net_2_nodes_ctpc_gen() {
let f = F::new(1e-6);
let chi_sq = ChiSquare::new(1e-4);
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
learn_ternary_net_2_nodes_gen(ctpc);
}
#[test]
fn learn_mixed_discrete_net_3_nodes_ctpc() {
let f = F::new(1e-6);
let chi_sq = ChiSquare::new(1e-4);
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
learn_mixed_discrete_net_3_nodes(ctpc);
}
#[test]
fn learn_mixed_discrete_net_3_nodes_ctpc_gen() {
let f = F::new(1e-6);
let chi_sq = ChiSquare::new(1e-4);
let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
learn_mixed_discrete_net_3_nodes_gen(ctpc);
}

@@ -1,9 +1,15 @@
use std::ops::Range;
use ndarray::{arr1, arr2, arr3};
use reCTBN::params::ParamsTrait;
use reCTBN::process::ctbn::*;
use reCTBN::process::ctmp::*;
use reCTBN::process::NetworkProcess;
use reCTBN::params;
use reCTBN::tools::*;
use utils::*;
#[macro_use]
extern crate approx;
@@ -82,3 +88,164 @@ fn dataset_wrong_shape() {
let t2 = Trajectory::new(time, events);
Dataset::new(vec![t1, t2]);
}
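// `UniformGraphGenerator` expects a density in [0.0, 1.0]: the two tests
// below check that out-of-range values panic, the third that the whole valid
// range is accepted.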
#[test]
#[should_panic]
fn uniform_graph_generator_wrong_density_1() {
let density = 2.1;
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
None
);
}
#[test]
#[should_panic]
fn uniform_graph_generator_wrong_density_2() {
let density = -0.5;
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
None
);
}
#[test]
fn uniform_graph_generator_right_densities() {
for density in [1.0, 0.75, 0.5, 0.25, 0.0] {
let _structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
None
);
}
}
#[test]
fn uniform_graph_generator_generate_graph_ctbn() {
let mut net = CtbnNetwork::new();
let nodes_cardinality = 0..=100;
let nodes_domain_cardinality = 2;
for node_label in nodes_cardinality {
net.add_node(
utils::generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
let density = 1.0/3.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
Some(7641630759785120)
);
structure_generator.generate_graph(&mut net);
let mut edges = 0;
for node in net.get_node_indices(){
edges += net.get_children_set(node).len()
}
let nodes = net.get_node_indices().len() as f64;
let expected_edges = (density * nodes * (nodes - 1.0)).round() as usize;
let tolerance = ((expected_edges as f64)*0.05) as usize; // ±5% tolerance
// Given how `generate_graph()` is implemented, we can only reasonably expect
// the number of edges to land somewhere around the expected value: here 101
// nodes at density 1/3 give expected_edges = 3367, so the accepted band is
// [3199, 3535].
assert!((expected_edges - tolerance) <= edges && edges <= (expected_edges + tolerance));
}
#[test]
#[should_panic]
fn uniform_graph_generator_generate_graph_ctmp() {
let mut net = CtmpProcess::new();
let node_label = String::from("0");
let node_domain_cardinality = 4;
net.add_node(
generate_discrete_time_continous_node(
node_label,
node_domain_cardinality
)
).unwrap();
let density = 1.0/3.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
Some(7641630759785120)
);
structure_generator.generate_graph(&mut net);
}
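// `RandomParametersGenerator` expects a well-formed, non-negative rate
// interval: the two tests below check that reversed or negative ranges
// panic.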
#[test]
#[should_panic]
fn uniform_parameters_generator_wrong_density_1() {
let interval: Range<f64> = -2.0..-5.0;
let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
interval,
None
);
}
#[test]
#[should_panic]
fn uniform_parameters_generator_wrong_density_2() {
let interval: Range<f64> = -1.0..0.0;
let _cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
interval,
None
);
}
#[test]
fn uniform_parameters_generator_right_densities_ctbn() {
let mut net = CtbnNetwork::new();
let nodes_cardinality = 0..=3;
let nodes_domain_cardinality = 9;
for node_label in nodes_cardinality {
net.add_node(
generate_discrete_time_continous_node(
node_label.to_string(),
nodes_domain_cardinality,
)
).unwrap();
}
let density = 1.0/3.0;
let seed = Some(7641630759785120);
let interval = 0.0..7.0;
let mut structure_generator: UniformGraphGenerator = RandomGraphGenerator::new(
density,
seed
);
structure_generator.generate_graph(&mut net);
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
interval,
seed
);
cim_generator.generate_parameters(&mut net);
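// Every generated CIM should pass validation, which for a CIM means, among
// other things, non-negative off-diagonal rates and rows summing to zero.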
for node in net.get_node_indices() {
assert_eq!(
Ok(()),
net.get_node(node).validate_params()
);
}
}
#[test]
fn uniform_parameters_generator_right_densities_ctmp() {
let mut net = CtmpProcess::new();
let node_label = String::from("0");
let node_domain_cardinality = 4;
net.add_node(
generate_discrete_time_continous_node(
node_label,
node_domain_cardinality
)
).unwrap();
let seed = Some(7641630759785120);
let interval = 0.0..7.0;
let mut cim_generator: UniformParametersGenerator = RandomParametersGenerator::new(
interval,
seed
);
cim_generator.generate_parameters(&mut net);
for node in net.get_node_indices() {
assert_eq!(
Ok(()),
net.get_node(node).validate_params()
);
}
}
