Compare commits
91 Commits
62-feature
...
dev
Author | SHA1 | Date |
---|---|---|
Meliurwen | f6873c52c2 | 2 years ago |
Meliurwen | d45c96a38d | 2 years ago |
Meliurwen | 57ae851470 | 2 years ago |
Meliurwen | 4a9445385c | 2 years ago |
Meliurwen | 0eb427e5cf | 2 years ago |
Meliurwen | 7a3ac6c9ab | 2 years ago |
Meliurwen | 4fd0ee0407 | 2 years ago |
Meliurwen | c4da4ceadd | 2 years ago |
AlessandroBregoli | e638a627bb | 2 years ago |
AlessandroBregoli | 776b9aa030 | 2 years ago |
AlessandroBregoli | adb0f99419 | 2 years ago |
Meliurwen | b7fc23ed8a | 2 years ago |
Meliurwen | 0639a755d0 | 2 years ago |
Meliurwen | 4884010ea9 | 2 years ago |
Meliurwen | 430033afdb | 2 years ago |
Meliurwen | e08d12ac1f | 2 years ago |
Meliurwen | a01a9ef201 | 2 years ago |
Meliurwen | 097dc25030 | 2 years ago |
Meliurwen | 0f61cbee4c | 2 years ago |
Meliurwen | f4e3c98c79 | 2 years ago |
Meliurwen | d6f0fb9623 | 2 years ago |
Meliurwen | 434e671f0a | 2 years ago |
Meliurwen | 4b994d8a19 | 2 years ago |
Meliurwen | a077f738ee | 2 years ago |
Meliurwen | 49c2c55f61 | 2 years ago |
Meliurwen | c2df26c3e6 | 2 years ago |
Meliurwen | 7ec56914d9 | 2 years ago |
Alessandro Bregoli | 5d676be180 | 2 years ago |
Alessandro Bregoli | ff235b4b77 | 2 years ago |
Alessandro Bregoli | a104d1fbf9 | 2 years ago |
Meliurwen | 0bd325d349 | 2 years ago |
Meliurwen | b8938a934f | 2 years ago |
Meliurwen | 5632833963 | 2 years ago |
Meliurwen | 0e1cca0456 | 2 years ago |
Meliurwen | 4d3f9518e4 | 2 years ago |
Meliurwen | 867bf02934 | 2 years ago |
Meliurwen | a0da3e2fe8 | 2 years ago |
Meliurwen | ea3e406bf1 | 2 years ago |
Meliurwen | 19856195c3 | 2 years ago |
Meliurwen | ea5df7cad6 | 2 years ago |
Meliurwen | 468ebf09cc | 2 years ago |
Meliurwen | 8d0f9db289 | 2 years ago |
Meliurwen | 6d42d8a805 | 2 years ago |
Meliurwen | 6d952f8c07 | 2 years ago |
AlessandroBregoli | 9284ca5dd2 | 2 years ago |
AlessandroBregoli | 414aa31867 | 2 years ago |
AlessandroBregoli | 687f19ff1f | 2 years ago |
AlessandroBregoli | bb239aaa0c | 2 years ago |
AlessandroBregoli | cecf16a771 | 2 years ago |
AlessandroBregoli | 4fc5c1d4b5 | 2 years ago |
Meliurwen | 6e90458418 | 2 years ago |
Meliurwen | cac19b1756 | 2 years ago |
Meliurwen | 2e49df0266 | 2 years ago |
Meliurwen | 3f80f07e9f | 2 years ago |
AlessandroBregoli | bcb64a161a | 2 years ago |
AlessandroBregoli | 68ef7ea7c3 | 2 years ago |
AlessandroBregoli | f6015acce9 | 2 years ago |
AlessandroBregoli | 055eb7088e | 2 years ago |
AlessandroBregoli | 1878f687d6 | 2 years ago |
AlessandroBregoli | fd3b1ecfea | 2 years ago |
AlessandroBregoli | 38e744e034 | 2 years ago |
AlessandroBregoli | 44eaf8713f | 2 years ago |
AlessandroBregoli | 28ed1a40b3 | 2 years ago |
AlessandroBregoli | 4a7a8c5fba | 2 years ago |
Meliurwen | 3a0151a9f6 | 2 years ago |
AlessandroBregoli | 7c3cba50d4 | 2 years ago |
Meliurwen | a2c5800891 | 2 years ago |
Meliurwen | 9fbdf25149 | 2 years ago |
AlessandroBregoli | ed5471c7cf | 2 years ago |
Meliurwen | c08f4e1985 | 2 years ago |
Meliurwen | ec72a6a2f9 | 2 years ago |
Meliurwen | 713b8a8013 | 2 years ago |
Meliurwen | a92b605daa | 2 years ago |
Meliurwen | ce139afdb6 | 2 years ago |
Meliurwen | 0ae2168a94 | 2 years ago |
Meliurwen | 832922922a | 2 years ago |
Meliurwen | 3522e1b6f6 | 2 years ago |
Meliurwen | 245b3b5d45 | 2 years ago |
Meliurwen | 08623e28d4 | 2 years ago |
Meliurwen | f7165d0345 | 2 years ago |
Meliurwen | 9ca8973550 | 2 years ago |
Meliurwen | 064b582833 | 2 years ago |
Meliurwen | 2153f46758 | 2 years ago |
Meliurwen | ccced92149 | 2 years ago |
Meliurwen | 616d5ec3d5 | 2 years ago |
Meliurwen | 672de56c31 | 2 years ago |
Meliurwen | 1cad41f7f4 | 2 years ago |
Meliurwen | 174e85734e | 2 years ago |
Meliurwen | c247da6bc0 | 2 years ago |
Meliurwen | 761e9da436 | 2 years ago |
Meliurwen | 8953471570 | 2 years ago |
@ -0,0 +1,176 @@ |
|||||||
|
Apache License |
||||||
|
Version 2.0, January 2004 |
||||||
|
http://www.apache.org/licenses/ |
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION |
||||||
|
|
||||||
|
1. Definitions. |
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction, |
||||||
|
and distribution as defined by Sections 1 through 9 of this document. |
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by |
||||||
|
the copyright owner that is granting the License. |
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all |
||||||
|
other entities that control, are controlled by, or are under common |
||||||
|
control with that entity. For the purposes of this definition, |
||||||
|
"control" means (i) the power, direct or indirect, to cause the |
||||||
|
direction or management of such entity, whether by contract or |
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the |
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity. |
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity |
||||||
|
exercising permissions granted by this License. |
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications, |
||||||
|
including but not limited to software source code, documentation |
||||||
|
source, and configuration files. |
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical |
||||||
|
transformation or translation of a Source form, including but |
||||||
|
not limited to compiled object code, generated documentation, |
||||||
|
and conversions to other media types. |
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or |
||||||
|
Object form, made available under the License, as indicated by a |
||||||
|
copyright notice that is included in or attached to the work |
||||||
|
(an example is provided in the Appendix below). |
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object |
||||||
|
form, that is based on (or derived from) the Work and for which the |
||||||
|
editorial revisions, annotations, elaborations, or other modifications |
||||||
|
represent, as a whole, an original work of authorship. For the purposes |
||||||
|
of this License, Derivative Works shall not include works that remain |
||||||
|
separable from, or merely link (or bind by name) to the interfaces of, |
||||||
|
the Work and Derivative Works thereof. |
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including |
||||||
|
the original version of the Work and any modifications or additions |
||||||
|
to that Work or Derivative Works thereof, that is intentionally |
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner |
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of |
||||||
|
the copyright owner. For the purposes of this definition, "submitted" |
||||||
|
means any form of electronic, verbal, or written communication sent |
||||||
|
to the Licensor or its representatives, including but not limited to |
||||||
|
communication on electronic mailing lists, source code control systems, |
||||||
|
and issue tracking systems that are managed by, or on behalf of, the |
||||||
|
Licensor for the purpose of discussing and improving the Work, but |
||||||
|
excluding communication that is conspicuously marked or otherwise |
||||||
|
designated in writing by the copyright owner as "Not a Contribution." |
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity |
||||||
|
on behalf of whom a Contribution has been received by Licensor and |
||||||
|
subsequently incorporated within the Work. |
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of |
||||||
|
this License, each Contributor hereby grants to You a perpetual, |
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
||||||
|
copyright license to reproduce, prepare Derivative Works of, |
||||||
|
publicly display, publicly perform, sublicense, and distribute the |
||||||
|
Work and such Derivative Works in Source or Object form. |
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of |
||||||
|
this License, each Contributor hereby grants to You a perpetual, |
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
||||||
|
(except as stated in this section) patent license to make, have made, |
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work, |
||||||
|
where such license applies only to those patent claims licensable |
||||||
|
by such Contributor that are necessarily infringed by their |
||||||
|
Contribution(s) alone or by combination of their Contribution(s) |
||||||
|
with the Work to which such Contribution(s) was submitted. If You |
||||||
|
institute patent litigation against any entity (including a |
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work |
||||||
|
or a Contribution incorporated within the Work constitutes direct |
||||||
|
or contributory patent infringement, then any patent licenses |
||||||
|
granted to You under this License for that Work shall terminate |
||||||
|
as of the date such litigation is filed. |
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the |
||||||
|
Work or Derivative Works thereof in any medium, with or without |
||||||
|
modifications, and in Source or Object form, provided that You |
||||||
|
meet the following conditions: |
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or |
||||||
|
Derivative Works a copy of this License; and |
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices |
||||||
|
stating that You changed the files; and |
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works |
||||||
|
that You distribute, all copyright, patent, trademark, and |
||||||
|
attribution notices from the Source form of the Work, |
||||||
|
excluding those notices that do not pertain to any part of |
||||||
|
the Derivative Works; and |
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its |
||||||
|
distribution, then any Derivative Works that You distribute must |
||||||
|
include a readable copy of the attribution notices contained |
||||||
|
within such NOTICE file, excluding those notices that do not |
||||||
|
pertain to any part of the Derivative Works, in at least one |
||||||
|
of the following places: within a NOTICE text file distributed |
||||||
|
as part of the Derivative Works; within the Source form or |
||||||
|
documentation, if provided along with the Derivative Works; or, |
||||||
|
within a display generated by the Derivative Works, if and |
||||||
|
wherever such third-party notices normally appear. The contents |
||||||
|
of the NOTICE file are for informational purposes only and |
||||||
|
do not modify the License. You may add Your own attribution |
||||||
|
notices within Derivative Works that You distribute, alongside |
||||||
|
or as an addendum to the NOTICE text from the Work, provided |
||||||
|
that such additional attribution notices cannot be construed |
||||||
|
as modifying the License. |
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and |
||||||
|
may provide additional or different license terms and conditions |
||||||
|
for use, reproduction, or distribution of Your modifications, or |
||||||
|
for any such Derivative Works as a whole, provided Your use, |
||||||
|
reproduction, and distribution of the Work otherwise complies with |
||||||
|
the conditions stated in this License. |
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise, |
||||||
|
any Contribution intentionally submitted for inclusion in the Work |
||||||
|
by You to the Licensor shall be under the terms and conditions of |
||||||
|
this License, without any additional terms or conditions. |
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify |
||||||
|
the terms of any separate license agreement you may have executed |
||||||
|
with Licensor regarding such Contributions. |
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade |
||||||
|
names, trademarks, service marks, or product names of the Licensor, |
||||||
|
except as required for reasonable and customary use in describing the |
||||||
|
origin of the Work and reproducing the content of the NOTICE file. |
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or |
||||||
|
agreed to in writing, Licensor provides the Work (and each |
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS, |
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
||||||
|
implied, including, without limitation, any warranties or conditions |
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A |
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the |
||||||
|
appropriateness of using or redistributing the Work and assume any |
||||||
|
risks associated with Your exercise of permissions under this License. |
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory, |
||||||
|
whether in tort (including negligence), contract, or otherwise, |
||||||
|
unless required by applicable law (such as deliberate and grossly |
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be |
||||||
|
liable to You for damages, including any direct, indirect, special, |
||||||
|
incidental, or consequential damages of any character arising as a |
||||||
|
result of this License or out of the use or inability to use the |
||||||
|
Work (including but not limited to damages for loss of goodwill, |
||||||
|
work stoppage, computer failure or malfunction, or any and all |
||||||
|
other commercial damages or losses), even if such Contributor |
||||||
|
has been advised of the possibility of such damages. |
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing |
||||||
|
the Work or Derivative Works thereof, You may choose to offer, |
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity, |
||||||
|
or other liability obligations and/or rights consistent with this |
||||||
|
License. However, in accepting such obligations, You may act only |
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf |
||||||
|
of any other Contributor, and only if You agree to indemnify, |
||||||
|
defend, and hold each Contributor harmless for any liability |
||||||
|
incurred by, or claims asserted against, such Contributor by reason |
||||||
|
of your accepting any such warranty or additional liability. |
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS |
@ -0,0 +1,23 @@ |
|||||||
|
Permission is hereby granted, free of charge, to any |
||||||
|
person obtaining a copy of this software and associated |
||||||
|
documentation files (the "Software"), to deal in the |
||||||
|
Software without restriction, including without |
||||||
|
limitation the rights to use, copy, modify, merge, |
||||||
|
publish, distribute, sublicense, and/or sell copies of |
||||||
|
the Software, and to permit persons to whom the Software |
||||||
|
is furnished to do so, subject to the following |
||||||
|
conditions: |
||||||
|
|
||||||
|
The above copyright notice and this permission notice |
||||||
|
shall be included in all copies or substantial portions |
||||||
|
of the Software. |
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF |
||||||
|
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED |
||||||
|
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A |
||||||
|
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT |
||||||
|
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
||||||
|
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION |
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR |
||||||
|
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
||||||
|
DEALINGS IN THE SOFTWARE. |
@ -1,162 +0,0 @@ |
|||||||
use std::collections::BTreeSet; |
|
||||||
|
|
||||||
use ndarray::prelude::*; |
|
||||||
|
|
||||||
use crate::network; |
|
||||||
use crate::params::{Params, ParamsTrait, StateType}; |
|
||||||
|
|
||||||
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
|
|
||||||
///composed by the following elements:
|
|
||||||
///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
|
|
||||||
///- **nodes**: a vector containing all the nodes and their parameters.
|
|
||||||
///The index of a node inside the vector is also used as index for the adj_matrix.
|
|
||||||
///
|
|
||||||
///# Examples
|
|
||||||
///
|
|
||||||
///```
|
|
||||||
///
|
|
||||||
/// use std::collections::BTreeSet;
|
|
||||||
/// use reCTBN::network::Network;
|
|
||||||
/// use reCTBN::params;
|
|
||||||
/// use reCTBN::ctbn::*;
|
|
||||||
///
|
|
||||||
/// //Create the domain for a discrete node
|
|
||||||
/// let mut domain = BTreeSet::new();
|
|
||||||
/// domain.insert(String::from("A"));
|
|
||||||
/// domain.insert(String::from("B"));
|
|
||||||
///
|
|
||||||
/// //Create the parameters for a discrete node using the domain
|
|
||||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
|
||||||
///
|
|
||||||
/// //Create the node using the parameters
|
|
||||||
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
|
||||||
///
|
|
||||||
/// let mut domain = BTreeSet::new();
|
|
||||||
/// domain.insert(String::from("A"));
|
|
||||||
/// domain.insert(String::from("B"));
|
|
||||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
|
||||||
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
|
||||||
///
|
|
||||||
/// //Initialize a ctbn
|
|
||||||
/// let mut net = CtbnNetwork::new();
|
|
||||||
///
|
|
||||||
/// //Add nodes
|
|
||||||
/// let X1 = net.add_node(X1).unwrap();
|
|
||||||
/// let X2 = net.add_node(X2).unwrap();
|
|
||||||
///
|
|
||||||
/// //Add an edge
|
|
||||||
/// net.add_edge(X1, X2);
|
|
||||||
///
|
|
||||||
/// //Get all the children of node X1
|
|
||||||
/// let cs = net.get_children_set(X1);
|
|
||||||
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
|
||||||
/// ```
|
|
||||||
pub struct CtbnNetwork { |
|
||||||
adj_matrix: Option<Array2<u16>>, |
|
||||||
nodes: Vec<Params>, |
|
||||||
} |
|
||||||
|
|
||||||
impl CtbnNetwork { |
|
||||||
pub fn new() -> CtbnNetwork { |
|
||||||
CtbnNetwork { |
|
||||||
adj_matrix: None, |
|
||||||
nodes: Vec::new(), |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
impl network::Network for CtbnNetwork { |
|
||||||
fn initialize_adj_matrix(&mut self) { |
|
||||||
self.adj_matrix = Some(Array2::<u16>::zeros( |
|
||||||
(self.nodes.len(), self.nodes.len()).f(), |
|
||||||
)); |
|
||||||
} |
|
||||||
|
|
||||||
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> { |
|
||||||
n.reset_params(); |
|
||||||
self.adj_matrix = Option::None; |
|
||||||
self.nodes.push(n); |
|
||||||
Ok(self.nodes.len() - 1) |
|
||||||
} |
|
||||||
|
|
||||||
fn add_edge(&mut self, parent: usize, child: usize) { |
|
||||||
if let None = self.adj_matrix { |
|
||||||
self.initialize_adj_matrix(); |
|
||||||
} |
|
||||||
|
|
||||||
if let Some(network) = &mut self.adj_matrix { |
|
||||||
network[[parent, child]] = 1; |
|
||||||
self.nodes[child].reset_params(); |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
|
||||||
0..self.nodes.len() |
|
||||||
} |
|
||||||
|
|
||||||
fn get_number_of_nodes(&self) -> usize { |
|
||||||
self.nodes.len() |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node(&self, node_idx: usize) -> &Params { |
|
||||||
&self.nodes[node_idx] |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { |
|
||||||
&mut self.nodes[node_idx] |
|
||||||
} |
|
||||||
|
|
||||||
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize { |
|
||||||
self.adj_matrix |
|
||||||
.as_ref() |
|
||||||
.unwrap() |
|
||||||
.column(node) |
|
||||||
.iter() |
|
||||||
.enumerate() |
|
||||||
.fold((0, 1), |mut acc, x| { |
|
||||||
if x.1 > &0 { |
|
||||||
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
|
||||||
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
|
||||||
} |
|
||||||
acc |
|
||||||
}) |
|
||||||
.0 |
|
||||||
} |
|
||||||
|
|
||||||
fn get_param_index_from_custom_parent_set( |
|
||||||
&self, |
|
||||||
current_state: &Vec<StateType>, |
|
||||||
parent_set: &BTreeSet<usize>, |
|
||||||
) -> usize { |
|
||||||
parent_set |
|
||||||
.iter() |
|
||||||
.fold((0, 1), |mut acc, x| { |
|
||||||
acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; |
|
||||||
acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); |
|
||||||
acc |
|
||||||
}) |
|
||||||
.0 |
|
||||||
} |
|
||||||
|
|
||||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
|
||||||
self.adj_matrix |
|
||||||
.as_ref() |
|
||||||
.unwrap() |
|
||||||
.column(node) |
|
||||||
.iter() |
|
||||||
.enumerate() |
|
||||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
|
||||||
.collect() |
|
||||||
} |
|
||||||
|
|
||||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
|
||||||
self.adj_matrix |
|
||||||
.as_ref() |
|
||||||
.unwrap() |
|
||||||
.row(node) |
|
||||||
.iter() |
|
||||||
.enumerate() |
|
||||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
|
||||||
.collect() |
|
||||||
} |
|
||||||
} |
|
@ -1,11 +1,12 @@ |
|||||||
|
#![doc = include_str!("../../README.md")] |
||||||
#![allow(non_snake_case)] |
#![allow(non_snake_case)] |
||||||
#[cfg(test)] |
#[cfg(test)] |
||||||
extern crate approx; |
extern crate approx; |
||||||
|
|
||||||
pub mod ctbn; |
|
||||||
pub mod network; |
|
||||||
pub mod parameter_learning; |
pub mod parameter_learning; |
||||||
pub mod params; |
pub mod params; |
||||||
|
pub mod process; |
||||||
|
pub mod reward; |
||||||
pub mod sampling; |
pub mod sampling; |
||||||
pub mod structure_learning; |
pub mod structure_learning; |
||||||
pub mod tools; |
pub mod tools; |
||||||
|
@ -1,43 +0,0 @@ |
|||||||
use std::collections::BTreeSet; |
|
||||||
|
|
||||||
use thiserror::Error; |
|
||||||
|
|
||||||
use crate::params; |
|
||||||
|
|
||||||
/// Error types for trait Network
|
|
||||||
#[derive(Error, Debug)] |
|
||||||
pub enum NetworkError { |
|
||||||
#[error("Error during node insertion")] |
|
||||||
NodeInsertionError(String), |
|
||||||
} |
|
||||||
|
|
||||||
///Network
|
|
||||||
///The Network trait define the required methods for a structure used as pgm (such as ctbn).
|
|
||||||
pub trait Network { |
|
||||||
fn initialize_adj_matrix(&mut self); |
|
||||||
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; |
|
||||||
fn add_edge(&mut self, parent: usize, child: usize); |
|
||||||
|
|
||||||
///Get all the indices of the nodes contained inside the network
|
|
||||||
fn get_node_indices(&self) -> std::ops::Range<usize>; |
|
||||||
fn get_number_of_nodes(&self) -> usize; |
|
||||||
fn get_node(&self, node_idx: usize) -> ¶ms::Params; |
|
||||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; |
|
||||||
|
|
||||||
///Compute the index that must be used to access the parameters of a node given a specific
|
|
||||||
///configuration of the network. Usually, the only values really used in *current_state* are
|
|
||||||
///the ones in the parent set of the *node*.
|
|
||||||
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) |
|
||||||
-> usize; |
|
||||||
|
|
||||||
///Compute the index that must be used to access the parameters of a node given a specific
|
|
||||||
///configuration of the network and a generic parent_set. Usually, the only values really used
|
|
||||||
///in *current_state* are the ones in the parent set of the *node*.
|
|
||||||
fn get_param_index_from_custom_parent_set( |
|
||||||
&self, |
|
||||||
current_state: &Vec<params::StateType>, |
|
||||||
parent_set: &BTreeSet<usize>, |
|
||||||
) -> usize; |
|
||||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; |
|
||||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; |
|
||||||
} |
|
@ -0,0 +1,120 @@ |
|||||||
|
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
|
||||||
|
|
||||||
|
pub mod ctbn; |
||||||
|
pub mod ctmp; |
||||||
|
|
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use thiserror::Error; |
||||||
|
|
||||||
|
use crate::params; |
||||||
|
|
||||||
|
/// Error types for trait Network
|
||||||
|
#[derive(Error, Debug)] |
||||||
|
pub enum NetworkError { |
||||||
|
#[error("Error during node insertion")] |
||||||
|
NodeInsertionError(String), |
||||||
|
} |
||||||
|
|
||||||
|
/// This type is used to represent a specific realization of a generic NetworkProcess
|
||||||
|
pub type NetworkProcessState = Vec<params::StateType>; |
||||||
|
|
||||||
|
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
|
||||||
|
/// as a CTBN).
|
||||||
|
pub trait NetworkProcess: Sync { |
||||||
|
fn initialize_adj_matrix(&mut self); |
||||||
|
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; |
||||||
|
/// Add an **directed edge** between a two nodes of the network.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `parent` - parent node.
|
||||||
|
/// * `child` - child node.
|
||||||
|
fn add_edge(&mut self, parent: usize, child: usize); |
||||||
|
|
||||||
|
/// Get all the indices of the nodes contained inside the network.
|
||||||
|
fn get_node_indices(&self) -> std::ops::Range<usize>; |
||||||
|
|
||||||
|
/// Get the numbers of nodes contained in the network.
|
||||||
|
fn get_number_of_nodes(&self) -> usize; |
||||||
|
|
||||||
|
/// Get the **node param**.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `node_idx` - node index value.
|
||||||
|
///
|
||||||
|
/// # Return
|
||||||
|
///
|
||||||
|
/// * The selected **node param**.
|
||||||
|
fn get_node(&self, node_idx: usize) -> ¶ms::Params; |
||||||
|
|
||||||
|
/// Get the **node param**.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `node_idx` - node index value.
|
||||||
|
///
|
||||||
|
/// # Return
|
||||||
|
///
|
||||||
|
/// * The selected **node mutable param**.
|
||||||
|
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; |
||||||
|
|
||||||
|
/// Compute the index that must be used to access the parameters of a `node`, given a specific
|
||||||
|
/// configuration of the network.
|
||||||
|
///
|
||||||
|
/// Usually, the only values really used in `current_state` are the ones in the parent set of
|
||||||
|
/// the `node`.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `node` - selected node.
|
||||||
|
/// * `current_state` - current configuration of the network.
|
||||||
|
///
|
||||||
|
/// # Return
|
||||||
|
///
|
||||||
|
/// * Index of the `node` relative to the network.
|
||||||
|
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; |
||||||
|
|
||||||
|
/// Compute the index that must be used to access the parameters of a `node`, given a specific
|
||||||
|
/// configuration of the network and a generic `parent_set`.
|
||||||
|
///
|
||||||
|
/// Usually, the only values really used in `current_state` are the ones in the parent set of
|
||||||
|
/// the `node`.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `current_state` - current configuration of the network.
|
||||||
|
/// * `parent_set` - parent set of the selected `node`.
|
||||||
|
///
|
||||||
|
/// # Return
|
||||||
|
///
|
||||||
|
/// * Index of the `node` relative to the network.
|
||||||
|
fn get_param_index_from_custom_parent_set( |
||||||
|
&self, |
||||||
|
current_state: &Vec<params::StateType>, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
) -> usize; |
||||||
|
|
||||||
|
/// Get the **parent set** of a given **node**.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `node` - node index value.
|
||||||
|
///
|
||||||
|
/// # Return
|
||||||
|
///
|
||||||
|
/// * The **parent set** of the selected node.
|
||||||
|
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; |
||||||
|
|
||||||
|
/// Get the **children set** of a given **node**.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `node` - node index value.
|
||||||
|
///
|
||||||
|
/// # Return
|
||||||
|
///
|
||||||
|
/// * The **children set** of the selected node.
|
||||||
|
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; |
||||||
|
} |
@ -0,0 +1,247 @@ |
|||||||
|
//! Continuous Time Bayesian Network
|
||||||
|
|
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use ndarray::prelude::*; |
||||||
|
|
||||||
|
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; |
||||||
|
use crate::process; |
||||||
|
|
||||||
|
use super::ctmp::CtmpProcess; |
||||||
|
use super::{NetworkProcess, NetworkProcessState}; |
||||||
|
|
||||||
|
/// It represents both the structure and the parameters of a CTBN.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
|
||||||
|
/// * `nodes` - A vector containing all the nodes and their parameters.
|
||||||
|
///
|
||||||
|
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use std::collections::BTreeSet;
|
||||||
|
/// use reCTBN::process::NetworkProcess;
|
||||||
|
/// use reCTBN::params;
|
||||||
|
/// use reCTBN::process::ctbn::*;
|
||||||
|
///
|
||||||
|
/// //Create the domain for a discrete node
|
||||||
|
/// let mut domain = BTreeSet::new();
|
||||||
|
/// domain.insert(String::from("A"));
|
||||||
|
/// domain.insert(String::from("B"));
|
||||||
|
///
|
||||||
|
/// //Create the parameters for a discrete node using the domain
|
||||||
|
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
||||||
|
///
|
||||||
|
/// //Create the node using the parameters
|
||||||
|
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
||||||
|
///
|
||||||
|
/// let mut domain = BTreeSet::new();
|
||||||
|
/// domain.insert(String::from("A"));
|
||||||
|
/// domain.insert(String::from("B"));
|
||||||
|
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
||||||
|
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
||||||
|
///
|
||||||
|
/// //Initialize a ctbn
|
||||||
|
/// let mut net = CtbnNetwork::new();
|
||||||
|
///
|
||||||
|
/// //Add nodes
|
||||||
|
/// let X1 = net.add_node(X1).unwrap();
|
||||||
|
/// let X2 = net.add_node(X2).unwrap();
|
||||||
|
///
|
||||||
|
/// //Add an edge
|
||||||
|
/// net.add_edge(X1, X2);
|
||||||
|
///
|
||||||
|
/// //Get all the children of node X1
|
||||||
|
/// let cs = net.get_children_set(X1);
|
||||||
|
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
||||||
|
/// ```
|
||||||
|
pub struct CtbnNetwork { |
||||||
|
adj_matrix: Option<Array2<u16>>, |
||||||
|
nodes: Vec<Params>, |
||||||
|
} |
||||||
|
|
||||||
|
impl CtbnNetwork { |
||||||
|
pub fn new() -> CtbnNetwork { |
||||||
|
CtbnNetwork { |
||||||
|
adj_matrix: None, |
||||||
|
nodes: Vec::new(), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
///Transform the **CTBN** into a **CTMP**
|
||||||
|
///
|
||||||
|
/// # Return
|
||||||
|
///
|
||||||
|
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
|
||||||
|
pub fn amalgamation(&self) -> CtmpProcess { |
||||||
|
let variables_domain = |
||||||
|
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); |
||||||
|
|
||||||
|
let state_space = variables_domain.product(); |
||||||
|
let variables_set = BTreeSet::from_iter(self.get_node_indices()); |
||||||
|
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space)); |
||||||
|
|
||||||
|
for idx_current_state in 0..state_space { |
||||||
|
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); |
||||||
|
let current_state_statetype: NetworkProcessState = current_state |
||||||
|
.iter() |
||||||
|
.map(|x| StateType::Discrete(*x)) |
||||||
|
.collect(); |
||||||
|
for idx_node in 0..self.nodes.len() { |
||||||
|
let p = match self.get_node(idx_node) { |
||||||
|
Params::DiscreteStatesContinousTime(p) => p, |
||||||
|
}; |
||||||
|
for next_node_state in 0..variables_domain[idx_node] { |
||||||
|
let mut next_state = current_state.clone(); |
||||||
|
next_state[idx_node] = next_node_state; |
||||||
|
|
||||||
|
let next_state_statetype: NetworkProcessState = |
||||||
|
next_state.iter().map(|x| StateType::Discrete(*x)).collect(); |
||||||
|
let idx_next_state = self.get_param_index_from_custom_parent_set( |
||||||
|
&next_state_statetype, |
||||||
|
&variables_set, |
||||||
|
); |
||||||
|
amalgamated_cim[[0, idx_current_state, idx_next_state]] += |
||||||
|
p.get_cim().as_ref().unwrap()[[ |
||||||
|
self.get_param_index_network(idx_node, ¤t_state_statetype), |
||||||
|
current_state[idx_node], |
||||||
|
next_node_state, |
||||||
|
]]; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( |
||||||
|
"ctmp".to_string(), |
||||||
|
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), |
||||||
|
); |
||||||
|
|
||||||
|
amalgamated_param.set_cim(amalgamated_cim).unwrap(); |
||||||
|
|
||||||
|
let mut ctmp = CtmpProcess::new(); |
||||||
|
|
||||||
|
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) |
||||||
|
.unwrap(); |
||||||
|
return ctmp; |
||||||
|
} |
||||||
|
|
||||||
|
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> { |
||||||
|
let mut state = state; |
||||||
|
let mut array_state = Array1::zeros(variables_domain.shape()[0]); |
||||||
|
for (idx, var) in variables_domain.indexed_iter() { |
||||||
|
array_state[idx] = state % var; |
||||||
|
state = state / var; |
||||||
|
} |
||||||
|
|
||||||
|
return array_state; |
||||||
|
} |
||||||
|
/// Get the Adjacency Matrix.
|
||||||
|
pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> { |
||||||
|
self.adj_matrix.as_ref() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl process::NetworkProcess for CtbnNetwork { |
||||||
|
/// Initialize an Adjacency matrix.
|
||||||
|
fn initialize_adj_matrix(&mut self) { |
||||||
|
self.adj_matrix = Some(Array2::<u16>::zeros( |
||||||
|
(self.nodes.len(), self.nodes.len()).f(), |
||||||
|
)); |
||||||
|
} |
||||||
|
|
||||||
|
/// Add a new node.
|
||||||
|
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> { |
||||||
|
n.reset_params(); |
||||||
|
self.adj_matrix = Option::None; |
||||||
|
self.nodes.push(n); |
||||||
|
Ok(self.nodes.len() - 1) |
||||||
|
} |
||||||
|
|
||||||
|
/// Connect two nodes with a new edge.
|
||||||
|
fn add_edge(&mut self, parent: usize, child: usize) { |
||||||
|
if let None = self.adj_matrix { |
||||||
|
self.initialize_adj_matrix(); |
||||||
|
} |
||||||
|
|
||||||
|
if let Some(network) = &mut self.adj_matrix { |
||||||
|
network[[parent, child]] = 1; |
||||||
|
self.nodes[child].reset_params(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||||
|
0..self.nodes.len() |
||||||
|
} |
||||||
|
|
||||||
|
/// Get the number of nodes of the network.
|
||||||
|
fn get_number_of_nodes(&self) -> usize { |
||||||
|
self.nodes.len() |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node(&self, node_idx: usize) -> &Params { |
||||||
|
&self.nodes[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { |
||||||
|
&mut self.nodes[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
||||||
|
self.adj_matrix |
||||||
|
.as_ref() |
||||||
|
.unwrap() |
||||||
|
.column(node) |
||||||
|
.iter() |
||||||
|
.enumerate() |
||||||
|
.fold((0, 1), |mut acc, x| { |
||||||
|
if x.1 > &0 { |
||||||
|
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
||||||
|
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
||||||
|
} |
||||||
|
acc |
||||||
|
}) |
||||||
|
.0 |
||||||
|
} |
||||||
|
|
||||||
|
fn get_param_index_from_custom_parent_set( |
||||||
|
&self, |
||||||
|
current_state: &NetworkProcessState, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
) -> usize { |
||||||
|
parent_set |
||||||
|
.iter() |
||||||
|
.fold((0, 1), |mut acc, x| { |
||||||
|
acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; |
||||||
|
acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); |
||||||
|
acc |
||||||
|
}) |
||||||
|
.0 |
||||||
|
} |
||||||
|
|
||||||
|
/// Get all the parents of the given node.
|
||||||
|
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
|
self.adj_matrix |
||||||
|
.as_ref() |
||||||
|
.unwrap() |
||||||
|
.column(node) |
||||||
|
.iter() |
||||||
|
.enumerate() |
||||||
|
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||||
|
.collect() |
||||||
|
} |
||||||
|
|
||||||
|
/// Get all the children of the given node.
|
||||||
|
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
|
self.adj_matrix |
||||||
|
.as_ref() |
||||||
|
.unwrap() |
||||||
|
.row(node) |
||||||
|
.iter() |
||||||
|
.enumerate() |
||||||
|
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||||
|
.collect() |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,114 @@ |
|||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use crate::{ |
||||||
|
params::{Params, StateType}, |
||||||
|
process, |
||||||
|
}; |
||||||
|
|
||||||
|
use super::{NetworkProcess, NetworkProcessState}; |
||||||
|
|
||||||
|
pub struct CtmpProcess { |
||||||
|
param: Option<Params>, |
||||||
|
} |
||||||
|
|
||||||
|
impl CtmpProcess { |
||||||
|
pub fn new() -> CtmpProcess { |
||||||
|
CtmpProcess { param: None } |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl NetworkProcess for CtmpProcess { |
||||||
|
fn initialize_adj_matrix(&mut self) { |
||||||
|
unimplemented!("CtmpProcess has only one node") |
||||||
|
} |
||||||
|
|
||||||
|
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> { |
||||||
|
match self.param { |
||||||
|
None => { |
||||||
|
self.param = Some(n); |
||||||
|
Ok(0) |
||||||
|
} |
||||||
|
Some(_) => Err(process::NetworkError::NodeInsertionError( |
||||||
|
"CtmpProcess has only one node".to_string(), |
||||||
|
)), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn add_edge(&mut self, _parent: usize, _child: usize) { |
||||||
|
unimplemented!("CtmpProcess has only one node") |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||||
|
match self.param { |
||||||
|
None => 0..0, |
||||||
|
Some(_) => 0..1, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn get_number_of_nodes(&self) -> usize { |
||||||
|
match self.param { |
||||||
|
None => 0, |
||||||
|
Some(_) => 1, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node(&self, node_idx: usize) -> &crate::params::Params { |
||||||
|
if node_idx == 0 { |
||||||
|
self.param.as_ref().unwrap() |
||||||
|
} else { |
||||||
|
unimplemented!("CtmpProcess has only one node") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { |
||||||
|
if node_idx == 0 { |
||||||
|
self.param.as_mut().unwrap() |
||||||
|
} else { |
||||||
|
unimplemented!("CtmpProcess has only one node") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
||||||
|
if node == 0 { |
||||||
|
match current_state[0] { |
||||||
|
StateType::Discrete(x) => x, |
||||||
|
} |
||||||
|
} else { |
||||||
|
unimplemented!("CtmpProcess has only one node") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn get_param_index_from_custom_parent_set( |
||||||
|
&self, |
||||||
|
_current_state: &NetworkProcessState, |
||||||
|
_parent_set: &BTreeSet<usize>, |
||||||
|
) -> usize { |
||||||
|
unimplemented!("CtmpProcess has only one node") |
||||||
|
} |
||||||
|
|
||||||
|
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
||||||
|
match self.param { |
||||||
|
Some(_) => { |
||||||
|
if node == 0 { |
||||||
|
BTreeSet::new() |
||||||
|
} else { |
||||||
|
unimplemented!("CtmpProcess has only one node") |
||||||
|
} |
||||||
|
} |
||||||
|
None => panic!("Uninitialized CtmpProcess"), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
||||||
|
match self.param { |
||||||
|
Some(_) => { |
||||||
|
if node == 0 { |
||||||
|
BTreeSet::new() |
||||||
|
} else { |
||||||
|
unimplemented!("CtmpProcess has only one node") |
||||||
|
} |
||||||
|
} |
||||||
|
None => panic!("Uninitialized CtmpProcess"), |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,59 @@ |
|||||||
|
pub mod reward_evaluation; |
||||||
|
pub mod reward_function; |
||||||
|
|
||||||
|
use std::collections::HashMap; |
||||||
|
|
||||||
|
use crate::process; |
||||||
|
|
||||||
|
/// Instantiation of reward function and instantaneous reward
|
||||||
|
///
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `transition_reward`: reward obtained transitioning from one state to another
|
||||||
|
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)] |
||||||
|
pub struct Reward { |
||||||
|
pub transition_reward: f64, |
||||||
|
pub instantaneous_reward: f64, |
||||||
|
} |
||||||
|
|
||||||
|
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
|
||||||
|
|
||||||
|
pub trait RewardFunction: Sync { |
||||||
|
/// Given the current state and the previous state, it compute the reward.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
|
||||||
|
/// * `previous_state`: an optional argument representing the previous state of the network
|
||||||
|
|
||||||
|
fn call( |
||||||
|
&self, |
||||||
|
current_state: &process::NetworkProcessState, |
||||||
|
previous_state: Option<&process::NetworkProcessState>, |
||||||
|
) -> Reward; |
||||||
|
|
||||||
|
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `p`: any structure that implements the trait `process::NetworkProcess`
|
||||||
|
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; |
||||||
|
} |
||||||
|
|
||||||
|
pub trait RewardEvaluation { |
||||||
|
fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>( |
||||||
|
&self, |
||||||
|
network_process: &N, |
||||||
|
reward_function: &R, |
||||||
|
) -> HashMap<process::NetworkProcessState, f64>; |
||||||
|
|
||||||
|
fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>( |
||||||
|
&self, |
||||||
|
network_process: &N, |
||||||
|
reward_function: &R, |
||||||
|
state: &process::NetworkProcessState, |
||||||
|
) -> f64; |
||||||
|
} |
@ -0,0 +1,205 @@ |
|||||||
|
use std::collections::HashMap; |
||||||
|
|
||||||
|
use rayon::prelude::{IntoParallelIterator, ParallelIterator}; |
||||||
|
use statrs::distribution::ContinuousCDF; |
||||||
|
|
||||||
|
use crate::params::{self, ParamsTrait}; |
||||||
|
use crate::process; |
||||||
|
|
||||||
|
use crate::{ |
||||||
|
process::NetworkProcessState, |
||||||
|
reward::RewardEvaluation, |
||||||
|
sampling::{ForwardSampler, Sampler}, |
||||||
|
}; |
||||||
|
|
||||||
|
pub enum RewardCriteria { |
||||||
|
FiniteHorizon, |
||||||
|
InfiniteHorizon { discount_factor: f64 }, |
||||||
|
} |
||||||
|
|
||||||
|
pub struct MonteCarloReward { |
||||||
|
max_iterations: usize, |
||||||
|
max_err_stop: f64, |
||||||
|
alpha_stop: f64, |
||||||
|
end_time: f64, |
||||||
|
reward_criteria: RewardCriteria, |
||||||
|
seed: Option<u64>, |
||||||
|
} |
||||||
|
|
||||||
|
impl MonteCarloReward { |
||||||
|
pub fn new( |
||||||
|
max_iterations: usize, |
||||||
|
max_err_stop: f64, |
||||||
|
alpha_stop: f64, |
||||||
|
end_time: f64, |
||||||
|
reward_criteria: RewardCriteria, |
||||||
|
seed: Option<u64>, |
||||||
|
) -> MonteCarloReward { |
||||||
|
MonteCarloReward { |
||||||
|
max_iterations, |
||||||
|
max_err_stop, |
||||||
|
alpha_stop, |
||||||
|
end_time, |
||||||
|
reward_criteria, |
||||||
|
seed, |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl RewardEvaluation for MonteCarloReward { |
||||||
|
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( |
||||||
|
&self, |
||||||
|
network_process: &N, |
||||||
|
reward_function: &R, |
||||||
|
) -> HashMap<process::NetworkProcessState, f64> { |
||||||
|
let variables_domain: Vec<Vec<params::StateType>> = network_process |
||||||
|
.get_node_indices() |
||||||
|
.map(|x| match network_process.get_node(x) { |
||||||
|
params::Params::DiscreteStatesContinousTime(x) => (0..x |
||||||
|
.get_reserved_space_as_parent()) |
||||||
|
.map(|s| params::StateType::Discrete(s)) |
||||||
|
.collect(), |
||||||
|
}) |
||||||
|
.collect(); |
||||||
|
|
||||||
|
let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); |
||||||
|
|
||||||
|
(0..n_states) |
||||||
|
.into_par_iter() |
||||||
|
.map(|s| { |
||||||
|
let state: process::NetworkProcessState = variables_domain |
||||||
|
.iter() |
||||||
|
.fold((s, vec![]), |acc, x| { |
||||||
|
let mut acc = acc; |
||||||
|
let idx_s = acc.0 % x.len(); |
||||||
|
acc.1.push(x[idx_s].clone()); |
||||||
|
acc.0 = acc.0 / x.len(); |
||||||
|
acc |
||||||
|
}) |
||||||
|
.1; |
||||||
|
|
||||||
|
let r = self.evaluate_state(network_process, reward_function, &state); |
||||||
|
(state, r) |
||||||
|
}) |
||||||
|
.collect() |
||||||
|
} |
||||||
|
|
||||||
|
fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>( |
||||||
|
&self, |
||||||
|
network_process: &N, |
||||||
|
reward_function: &R, |
||||||
|
state: &NetworkProcessState, |
||||||
|
) -> f64 { |
||||||
|
let mut sampler = |
||||||
|
ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); |
||||||
|
let mut expected_value = 0.0; |
||||||
|
let mut squared_expected_value = 0.0; |
||||||
|
let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); |
||||||
|
|
||||||
|
for i in 0..self.max_iterations { |
||||||
|
sampler.reset(); |
||||||
|
let mut ret = 0.0; |
||||||
|
let mut previous = sampler.next().unwrap(); |
||||||
|
while previous.t < self.end_time { |
||||||
|
let current = sampler.next().unwrap(); |
||||||
|
if current.t > self.end_time { |
||||||
|
let r = reward_function.call(&previous.state, None); |
||||||
|
let discount = match self.reward_criteria { |
||||||
|
RewardCriteria::FiniteHorizon => self.end_time - previous.t, |
||||||
|
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||||
|
std::f64::consts::E.powf(-discount_factor * previous.t) |
||||||
|
- std::f64::consts::E.powf(-discount_factor * self.end_time) |
||||||
|
} |
||||||
|
}; |
||||||
|
ret += discount * r.instantaneous_reward; |
||||||
|
} else { |
||||||
|
let r = reward_function.call(&previous.state, Some(¤t.state)); |
||||||
|
let discount = match self.reward_criteria { |
||||||
|
RewardCriteria::FiniteHorizon => current.t - previous.t, |
||||||
|
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||||
|
std::f64::consts::E.powf(-discount_factor * previous.t) |
||||||
|
- std::f64::consts::E.powf(-discount_factor * current.t) |
||||||
|
} |
||||||
|
}; |
||||||
|
ret += discount * r.instantaneous_reward; |
||||||
|
ret += match self.reward_criteria { |
||||||
|
RewardCriteria::FiniteHorizon => 1.0, |
||||||
|
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||||
|
std::f64::consts::E.powf(-discount_factor * current.t) |
||||||
|
} |
||||||
|
} * r.transition_reward; |
||||||
|
} |
||||||
|
previous = current; |
||||||
|
} |
||||||
|
|
||||||
|
let float_i = i as f64; |
||||||
|
expected_value = |
||||||
|
expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0); |
||||||
|
squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0) |
||||||
|
+ ret.powi(2) / (float_i + 1.0); |
||||||
|
|
||||||
|
if i > 2 { |
||||||
|
let var = |
||||||
|
(float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2)); |
||||||
|
if self.alpha_stop |
||||||
|
- 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt()) |
||||||
|
> 0.0 |
||||||
|
{ |
||||||
|
return expected_value; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
expected_value |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> { |
||||||
|
inner_reward: RE, |
||||||
|
} |
||||||
|
|
||||||
|
impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> { |
||||||
|
pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> { |
||||||
|
NeighborhoodRelativeReward { inner_reward } |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> { |
||||||
|
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( |
||||||
|
&self, |
||||||
|
network_process: &N, |
||||||
|
reward_function: &R, |
||||||
|
) -> HashMap<process::NetworkProcessState, f64> { |
||||||
|
let absolute_reward = self |
||||||
|
.inner_reward |
||||||
|
.evaluate_state_space(network_process, reward_function); |
||||||
|
|
||||||
|
//This approach optimize memory. Maybe optimizing execution time can be better.
|
||||||
|
absolute_reward |
||||||
|
.iter() |
||||||
|
.map(|(k1, v1)| { |
||||||
|
let mut max_val: f64 = 1.0; |
||||||
|
absolute_reward.iter().for_each(|(k2, v2)| { |
||||||
|
let count_diff: usize = k1 |
||||||
|
.iter() |
||||||
|
.zip(k2.iter()) |
||||||
|
.map(|(s1, s2)| if s1 == s2 { 0 } else { 1 }) |
||||||
|
.sum(); |
||||||
|
if count_diff < 2 { |
||||||
|
max_val = max_val.max(v1 / v2); |
||||||
|
} |
||||||
|
}); |
||||||
|
(k1.clone(), max_val) |
||||||
|
}) |
||||||
|
.collect() |
||||||
|
} |
||||||
|
|
||||||
|
fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>( |
||||||
|
&self, |
||||||
|
_network_process: &N, |
||||||
|
_reward_function: &R, |
||||||
|
_state: &process::NetworkProcessState, |
||||||
|
) -> f64 { |
||||||
|
unimplemented!(); |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,106 @@ |
|||||||
|
//! Module for dealing with reward functions
|
||||||
|
|
||||||
|
use crate::{ |
||||||
|
params::{self, ParamsTrait}, |
||||||
|
process, |
||||||
|
reward::{Reward, RewardFunction}, |
||||||
|
}; |
||||||
|
|
||||||
|
use ndarray; |
||||||
|
|
||||||
|
/// Reward function over a factored state space
|
||||||
|
///
|
||||||
|
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
|
||||||
|
/// of the underling `NetworkProcess`
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
|
||||||
|
/// reward of a node
|
||||||
|
|
||||||
|
pub struct FactoredRewardFunction { |
||||||
|
transition_reward: Vec<ndarray::Array2<f64>>, |
||||||
|
instantaneous_reward: Vec<ndarray::Array1<f64>>, |
||||||
|
} |
||||||
|
|
||||||
|
impl FactoredRewardFunction { |
||||||
|
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { |
||||||
|
&self.transition_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { |
||||||
|
&mut self.transition_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { |
||||||
|
&self.instantaneous_reward[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { |
||||||
|
&mut self.instantaneous_reward[node_idx] |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl RewardFunction for FactoredRewardFunction { |
||||||
|
fn call( |
||||||
|
&self, |
||||||
|
current_state: &process::NetworkProcessState, |
||||||
|
previous_state: Option<&process::NetworkProcessState>, |
||||||
|
) -> Reward { |
||||||
|
let instantaneous_reward: f64 = current_state |
||||||
|
.iter() |
||||||
|
.enumerate() |
||||||
|
.map(|(idx, x)| { |
||||||
|
let x = match x { |
||||||
|
params::StateType::Discrete(x) => x, |
||||||
|
}; |
||||||
|
self.instantaneous_reward[idx][*x] |
||||||
|
}) |
||||||
|
.sum(); |
||||||
|
if let Some(previous_state) = previous_state { |
||||||
|
let transition_reward = previous_state |
||||||
|
.iter() |
||||||
|
.zip(current_state.iter()) |
||||||
|
.enumerate() |
||||||
|
.find_map(|(idx, (p, c))| -> Option<f64> { |
||||||
|
let p = match p { |
||||||
|
params::StateType::Discrete(p) => p, |
||||||
|
}; |
||||||
|
let c = match c { |
||||||
|
params::StateType::Discrete(c) => c, |
||||||
|
}; |
||||||
|
if p != c { |
||||||
|
Some(self.transition_reward[idx][[*p, *c]]) |
||||||
|
} else { |
||||||
|
None |
||||||
|
} |
||||||
|
}) |
||||||
|
.unwrap_or(0.0); |
||||||
|
Reward { |
||||||
|
transition_reward, |
||||||
|
instantaneous_reward, |
||||||
|
} |
||||||
|
} else { |
||||||
|
Reward { |
||||||
|
transition_reward: 0.0, |
||||||
|
instantaneous_reward, |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { |
||||||
|
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; |
||||||
|
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; |
||||||
|
for i in p.get_node_indices() { |
||||||
|
//This works only for discrete nodes!
|
||||||
|
let size: usize = p.get_node(i).get_reserved_space_as_parent(); |
||||||
|
instantaneous_reward.push(ndarray::Array1::zeros(size)); |
||||||
|
transition_reward.push(ndarray::Array2::zeros((size, size))); |
||||||
|
} |
||||||
|
|
||||||
|
FactoredRewardFunction { |
||||||
|
transition_reward, |
||||||
|
instantaneous_reward, |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -1,11 +1,13 @@ |
|||||||
|
//! Learn the structure of the network.
|
||||||
|
|
||||||
pub mod constraint_based_algorithm; |
pub mod constraint_based_algorithm; |
||||||
pub mod hypothesis_test; |
pub mod hypothesis_test; |
||||||
pub mod score_based_algorithm; |
pub mod score_based_algorithm; |
||||||
pub mod score_function; |
pub mod score_function; |
||||||
use crate::{network, tools}; |
use crate::{process, tools::Dataset}; |
||||||
|
|
||||||
pub trait StructureLearningAlgorithm { |
pub trait StructureLearningAlgorithm { |
||||||
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T |
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
||||||
where |
where |
||||||
T: network::Network; |
T: process::NetworkProcess; |
||||||
} |
} |
||||||
|
@ -1,3 +1,348 @@ |
|||||||
//pub struct CTPC {
|
//! Module containing constraint based algorithms like CTPC and Hiton.
|
||||||
//
|
|
||||||
//}
|
use crate::params::Params; |
||||||
|
use itertools::Itertools; |
||||||
|
use rayon::iter::{IntoParallelIterator, ParallelIterator}; |
||||||
|
use rayon::prelude::ParallelExtend; |
||||||
|
use std::collections::{BTreeSet, HashMap}; |
||||||
|
use std::mem; |
||||||
|
use std::usize; |
||||||
|
|
||||||
|
use super::hypothesis_test::*; |
||||||
|
use crate::parameter_learning::ParameterLearning; |
||||||
|
use crate::process; |
||||||
|
use crate::structure_learning::StructureLearningAlgorithm; |
||||||
|
use crate::tools::Dataset; |
||||||
|
|
||||||
|
pub struct Cache<'a, P: ParameterLearning> { |
||||||
|
parameter_learning: &'a P, |
||||||
|
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>, |
||||||
|
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>, |
||||||
|
parent_set_size_small: usize, |
||||||
|
} |
||||||
|
|
||||||
|
impl<'a, P: ParameterLearning> Cache<'a, P> { |
||||||
|
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { |
||||||
|
Cache { |
||||||
|
parameter_learning, |
||||||
|
cache_persistent_small: HashMap::new(), |
||||||
|
cache_persistent_big: HashMap::new(), |
||||||
|
parent_set_size_small: 0, |
||||||
|
} |
||||||
|
} |
||||||
|
pub fn fit<T: process::NetworkProcess>( |
||||||
|
&mut self, |
||||||
|
net: &T, |
||||||
|
dataset: &Dataset, |
||||||
|
node: usize, |
||||||
|
parent_set: Option<BTreeSet<usize>>, |
||||||
|
) -> Params { |
||||||
|
let parent_set_len = parent_set.as_ref().unwrap().len(); |
||||||
|
if parent_set_len > self.parent_set_size_small + 1 { |
||||||
|
//self.cache_persistent_small = self.cache_persistent_big;
|
||||||
|
mem::swap( |
||||||
|
&mut self.cache_persistent_small, |
||||||
|
&mut self.cache_persistent_big, |
||||||
|
); |
||||||
|
self.cache_persistent_big = HashMap::new(); |
||||||
|
self.parent_set_size_small += 1; |
||||||
|
} |
||||||
|
|
||||||
|
if parent_set_len > self.parent_set_size_small { |
||||||
|
match self.cache_persistent_big.get(&parent_set) { |
||||||
|
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
||||||
|
// not cloning requires a minor and reasoned refactoring across the library
|
||||||
|
Some(params) => params.clone(), |
||||||
|
None => { |
||||||
|
let params = |
||||||
|
self.parameter_learning |
||||||
|
.fit(net, dataset, node, parent_set.clone()); |
||||||
|
self.cache_persistent_big.insert(parent_set, params.clone()); |
||||||
|
params |
||||||
|
} |
||||||
|
} |
||||||
|
} else { |
||||||
|
match self.cache_persistent_small.get(&parent_set) { |
||||||
|
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
||||||
|
// not cloning requires a minor and reasoned refactoring across the library
|
||||||
|
Some(params) => params.clone(), |
||||||
|
None => { |
||||||
|
let params = |
||||||
|
self.parameter_learning |
||||||
|
.fit(net, dataset, node, parent_set.clone()); |
||||||
|
self.cache_persistent_small |
||||||
|
.insert(parent_set, params.clone()); |
||||||
|
params |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/// Continuous-Time Peter Clark algorithm.
|
||||||
|
///
|
||||||
|
/// A method to learn the structure of the network.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
|
||||||
|
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test.
|
||||||
|
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test.
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use std::collections::BTreeSet;
|
||||||
|
/// # use ndarray::{arr1, arr2, arr3};
|
||||||
|
/// # use reCTBN::params;
|
||||||
|
/// # use reCTBN::tools::trajectory_generator;
|
||||||
|
/// # use reCTBN::process::NetworkProcess;
|
||||||
|
/// # use reCTBN::process::ctbn::CtbnNetwork;
|
||||||
|
/// use reCTBN::parameter_learning::BayesianApproach;
|
||||||
|
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
|
||||||
|
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
|
||||||
|
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
|
||||||
|
/// #
|
||||||
|
/// # // Create the domain for a discrete node
|
||||||
|
/// # let mut domain = BTreeSet::new();
|
||||||
|
/// # domain.insert(String::from("A"));
|
||||||
|
/// # domain.insert(String::from("B"));
|
||||||
|
/// # domain.insert(String::from("C"));
|
||||||
|
/// # // Create the parameters for a discrete node using the domain
|
||||||
|
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
|
||||||
|
/// # //Create the node n1 using the parameters
|
||||||
|
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
|
||||||
|
/// #
|
||||||
|
/// # let mut domain = BTreeSet::new();
|
||||||
|
/// # domain.insert(String::from("D"));
|
||||||
|
/// # domain.insert(String::from("E"));
|
||||||
|
/// # domain.insert(String::from("F"));
|
||||||
|
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
|
||||||
|
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
|
||||||
|
/// #
|
||||||
|
/// # let mut domain = BTreeSet::new();
|
||||||
|
/// # domain.insert(String::from("G"));
|
||||||
|
/// # domain.insert(String::from("H"));
|
||||||
|
/// # domain.insert(String::from("I"));
|
||||||
|
/// # domain.insert(String::from("F"));
|
||||||
|
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
|
||||||
|
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
|
||||||
|
/// #
|
||||||
|
/// # // Initialize a ctbn
|
||||||
|
/// # let mut net = CtbnNetwork::new();
|
||||||
|
/// #
|
||||||
|
/// # // Add the nodes and their edges
|
||||||
|
/// # let n1 = net.add_node(n1).unwrap();
|
||||||
|
/// # let n2 = net.add_node(n2).unwrap();
|
||||||
|
/// # let n3 = net.add_node(n3).unwrap();
|
||||||
|
/// # net.add_edge(n1, n2);
|
||||||
|
/// # net.add_edge(n1, n3);
|
||||||
|
/// # net.add_edge(n2, n3);
|
||||||
|
/// #
|
||||||
|
/// # match &mut net.get_node_mut(n1) {
|
||||||
|
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||||
|
/// # assert_eq!(
|
||||||
|
/// # Ok(()),
|
||||||
|
/// # param.set_cim(arr3(&[
|
||||||
|
/// # [
|
||||||
|
/// # [-3.0, 2.0, 1.0],
|
||||||
|
/// # [1.5, -2.0, 0.5],
|
||||||
|
/// # [0.4, 0.6, -1.0]
|
||||||
|
/// # ],
|
||||||
|
/// # ]))
|
||||||
|
/// # );
|
||||||
|
/// # }
|
||||||
|
/// # }
|
||||||
|
/// #
|
||||||
|
/// # match &mut net.get_node_mut(n2) {
|
||||||
|
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||||
|
/// # assert_eq!(
|
||||||
|
/// # Ok(()),
|
||||||
|
/// # param.set_cim(arr3(&[
|
||||||
|
/// # [
|
||||||
|
/// # [-1.0, 0.5, 0.5],
|
||||||
|
/// # [3.0, -4.0, 1.0],
|
||||||
|
/// # [0.9, 0.1, -1.0]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-6.0, 2.0, 4.0],
|
||||||
|
/// # [1.5, -2.0, 0.5],
|
||||||
|
/// # [3.0, 1.0, -4.0]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-1.0, 0.1, 0.9],
|
||||||
|
/// # [2.0, -2.5, 0.5],
|
||||||
|
/// # [0.9, 0.1, -1.0]
|
||||||
|
/// # ],
|
||||||
|
/// # ]))
|
||||||
|
/// # );
|
||||||
|
/// # }
|
||||||
|
/// # }
|
||||||
|
/// #
|
||||||
|
/// # match &mut net.get_node_mut(n3) {
|
||||||
|
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||||
|
/// # assert_eq!(
|
||||||
|
/// # Ok(()),
|
||||||
|
/// # param.set_cim(arr3(&[
|
||||||
|
/// # [
|
||||||
|
/// # [-1.0, 0.5, 0.3, 0.2],
|
||||||
|
/// # [0.5, -4.0, 2.5, 1.0],
|
||||||
|
/// # [2.5, 0.5, -4.0, 1.0],
|
||||||
|
/// # [0.7, 0.2, 0.1, -1.0]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-6.0, 2.0, 3.0, 1.0],
|
||||||
|
/// # [1.5, -3.0, 0.5, 1.0],
|
||||||
|
/// # [2.0, 1.3, -5.0, 1.7],
|
||||||
|
/// # [2.5, 0.5, 1.0, -4.0]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-1.3, 0.3, 0.1, 0.9],
|
||||||
|
/// # [1.4, -4.0, 0.5, 2.1],
|
||||||
|
/// # [1.0, 1.5, -3.0, 0.5],
|
||||||
|
/// # [0.4, 0.3, 0.1, -0.8]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-2.0, 1.0, 0.7, 0.3],
|
||||||
|
/// # [1.3, -5.9, 2.7, 1.9],
|
||||||
|
/// # [2.0, 1.5, -4.0, 0.5],
|
||||||
|
/// # [0.2, 0.7, 0.1, -1.0]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-6.0, 1.0, 2.0, 3.0],
|
||||||
|
/// # [0.5, -3.0, 1.0, 1.5],
|
||||||
|
/// # [1.4, 2.1, -4.3, 0.8],
|
||||||
|
/// # [0.5, 1.0, 2.5, -4.0]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-1.3, 0.9, 0.3, 0.1],
|
||||||
|
/// # [0.1, -1.3, 0.2, 1.0],
|
||||||
|
/// # [0.5, 1.0, -3.0, 1.5],
|
||||||
|
/// # [0.1, 0.4, 0.3, -0.8]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-2.0, 1.0, 0.6, 0.4],
|
||||||
|
/// # [2.6, -7.1, 1.4, 3.1],
|
||||||
|
/// # [5.0, 1.0, -8.0, 2.0],
|
||||||
|
/// # [1.4, 0.4, 0.2, -2.0]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-3.0, 1.0, 1.5, 0.5],
|
||||||
|
/// # [3.0, -6.0, 1.0, 2.0],
|
||||||
|
/// # [0.3, 0.5, -1.9, 1.1],
|
||||||
|
/// # [5.0, 1.0, 2.0, -8.0]
|
||||||
|
/// # ],
|
||||||
|
/// # [
|
||||||
|
/// # [-2.6, 0.6, 0.2, 1.8],
|
||||||
|
/// # [2.0, -6.0, 3.0, 1.0],
|
||||||
|
/// # [0.1, 0.5, -1.3, 0.7],
|
||||||
|
/// # [0.8, 0.6, 0.2, -1.6]
|
||||||
|
/// # ],
|
||||||
|
/// # ]))
|
||||||
|
/// # );
|
||||||
|
/// # }
|
||||||
|
/// # }
|
||||||
|
/// #
|
||||||
|
/// # // Generate the trajectory
|
||||||
|
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
|
||||||
|
///
|
||||||
|
/// // Initialize the hypothesis tests to pass to the CTPC with their
|
||||||
|
/// // respective significance level `alpha`
|
||||||
|
/// let f = F::new(1e-6);
|
||||||
|
/// let chi_sq = ChiSquare::new(1e-4);
|
||||||
|
/// // Use the bayesian approach to learn the parameters
|
||||||
|
/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
|
||||||
|
///
|
||||||
|
/// //Initialize CTPC
|
||||||
|
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
|
||||||
|
///
|
||||||
|
/// // Learn the structure of the network from the generated trajectory
|
||||||
|
/// let net = ctpc.fit_transform(net, &data);
|
||||||
|
/// #
|
||||||
|
/// # // Compare the generated network with the original one
|
||||||
|
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
|
||||||
|
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
|
||||||
|
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
|
||||||
|
/// ```
|
||||||
|
pub struct CTPC<P: ParameterLearning> { |
||||||
|
parameter_learning: P, |
||||||
|
Ftest: F, |
||||||
|
Chi2test: ChiSquare, |
||||||
|
} |
||||||
|
|
||||||
|
impl<P: ParameterLearning> CTPC<P> { |
||||||
|
pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> { |
||||||
|
CTPC { |
||||||
|
parameter_learning, |
||||||
|
Ftest, |
||||||
|
Chi2test, |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> { |
||||||
|
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
||||||
|
where |
||||||
|
T: process::NetworkProcess, |
||||||
|
{ |
||||||
|
//Check the coherence between dataset and network
|
||||||
|
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { |
||||||
|
panic!("Dataset and Network must have the same number of variables.") |
||||||
|
} |
||||||
|
|
||||||
|
//Make the network mutable.
|
||||||
|
let mut net = net; |
||||||
|
|
||||||
|
net.initialize_adj_matrix(); |
||||||
|
|
||||||
|
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![]; |
||||||
|
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| { |
||||||
|
let mut cache = Cache::new(&self.parameter_learning); |
||||||
|
let mut candidate_parent_set: BTreeSet<usize> = net |
||||||
|
.get_node_indices() |
||||||
|
.into_iter() |
||||||
|
.filter(|x| x != &child_node) |
||||||
|
.collect(); |
||||||
|
let mut separation_set_size = 0; |
||||||
|
while separation_set_size < candidate_parent_set.len() { |
||||||
|
let mut candidate_parent_set_TMP = candidate_parent_set.clone(); |
||||||
|
for parent_node in candidate_parent_set.iter() { |
||||||
|
for separation_set in candidate_parent_set |
||||||
|
.iter() |
||||||
|
.filter(|x| x != &parent_node) |
||||||
|
.map(|x| *x) |
||||||
|
.combinations(separation_set_size) |
||||||
|
{ |
||||||
|
let separation_set = separation_set.into_iter().collect(); |
||||||
|
if self.Ftest.call( |
||||||
|
&net, |
||||||
|
child_node, |
||||||
|
*parent_node, |
||||||
|
&separation_set, |
||||||
|
dataset, |
||||||
|
&mut cache, |
||||||
|
) && self.Chi2test.call( |
||||||
|
&net, |
||||||
|
child_node, |
||||||
|
*parent_node, |
||||||
|
&separation_set, |
||||||
|
dataset, |
||||||
|
&mut cache, |
||||||
|
) { |
||||||
|
candidate_parent_set_TMP.remove(parent_node); |
||||||
|
break; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
candidate_parent_set = candidate_parent_set_TMP; |
||||||
|
separation_set_size += 1; |
||||||
|
} |
||||||
|
(child_node, candidate_parent_set) |
||||||
|
})); |
||||||
|
for (child_node, candidate_parent_set) in learned_parent_sets { |
||||||
|
for parent_node in candidate_parent_set.iter() { |
||||||
|
net.add_edge(*parent_node, child_node); |
||||||
|
} |
||||||
|
} |
||||||
|
net |
||||||
|
} |
||||||
|
} |
||||||
|
@ -0,0 +1,127 @@ |
|||||||
|
mod utils; |
||||||
|
|
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use reCTBN::{ |
||||||
|
params, |
||||||
|
params::ParamsTrait, |
||||||
|
process::{ctmp::*, NetworkProcess}, |
||||||
|
}; |
||||||
|
use utils::generate_discrete_time_continous_node; |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn define_simple_ctmp() { |
||||||
|
let _ = CtmpProcess::new(); |
||||||
|
assert!(true); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn add_node_to_ctmp() { |
||||||
|
let mut net = CtmpProcess::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn add_two_nodes_to_ctmp() { |
||||||
|
let mut net = CtmpProcess::new(); |
||||||
|
let _n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
||||||
|
|
||||||
|
match n2 { |
||||||
|
Ok(_) => assert!(false), |
||||||
|
Err(_) => assert!(true), |
||||||
|
}; |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
#[should_panic] |
||||||
|
fn add_edge_to_ctmp() { |
||||||
|
let mut net = CtmpProcess::new(); |
||||||
|
let _n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
||||||
|
|
||||||
|
net.add_edge(0, 1) |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn childen_and_parents() { |
||||||
|
let mut net = CtmpProcess::new(); |
||||||
|
let _n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
assert_eq!(0, net.get_parent_set(0).len()); |
||||||
|
assert_eq!(0, net.get_children_set(0).len()); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
#[should_panic] |
||||||
|
fn get_childen_panic() { |
||||||
|
let net = CtmpProcess::new(); |
||||||
|
net.get_children_set(0); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
#[should_panic] |
||||||
|
fn get_childen_panic2() { |
||||||
|
let mut net = CtmpProcess::new(); |
||||||
|
let _n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
net.get_children_set(1); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
#[should_panic] |
||||||
|
fn get_parent_panic() { |
||||||
|
let net = CtmpProcess::new(); |
||||||
|
net.get_parent_set(0); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
#[should_panic] |
||||||
|
fn get_parent_panic2() { |
||||||
|
let mut net = CtmpProcess::new(); |
||||||
|
let _n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
net.get_parent_set(1); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn compute_index_ctmp() { |
||||||
|
let mut net = CtmpProcess::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node( |
||||||
|
String::from("n1"), |
||||||
|
10, |
||||||
|
)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); |
||||||
|
assert_eq!(6, idx); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
#[should_panic] |
||||||
|
fn compute_index_from_custom_parent_set_ctmp() { |
||||||
|
let mut net = CtmpProcess::new(); |
||||||
|
let _n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node( |
||||||
|
String::from("n1"), |
||||||
|
10, |
||||||
|
)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let _idx = net.get_param_index_from_custom_parent_set( |
||||||
|
&vec![params::StateType::Discrete(6)], |
||||||
|
&BTreeSet::from([0]) |
||||||
|
); |
||||||
|
} |
@ -0,0 +1,122 @@ |
|||||||
|
mod utils; |
||||||
|
|
||||||
|
use approx::assert_abs_diff_eq; |
||||||
|
use ndarray::*; |
||||||
|
use reCTBN::{ |
||||||
|
params, |
||||||
|
process::{ctbn::*, NetworkProcess, NetworkProcessState}, |
||||||
|
reward::{reward_evaluation::*, reward_function::*, *}, |
||||||
|
}; |
||||||
|
use utils::generate_discrete_time_continous_node; |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function_binary_node_mc() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1) |
||||||
|
.assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1) |
||||||
|
.assign(&arr1(&[3.0, 3.0])); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1) { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
net.initialize_adj_matrix(); |
||||||
|
|
||||||
|
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||||
|
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||||
|
|
||||||
|
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); |
||||||
|
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); |
||||||
|
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); |
||||||
|
|
||||||
|
let rst = mc.evaluate_state_space(&net, &rf); |
||||||
|
assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); |
||||||
|
assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); |
||||||
|
|
||||||
|
|
||||||
|
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215)); |
||||||
|
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); |
||||||
|
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); |
||||||
|
|
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function_chain_mc() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let n3 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
net.add_edge(n1, n2); |
||||||
|
net.add_edge(n2, n3); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1) { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2) { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param |
||||||
|
.set_cim(arr3(&[ |
||||||
|
[[-0.01, 0.01], [5.0, -5.0]], |
||||||
|
[[-5.0, 5.0], [0.01, -0.01]], |
||||||
|
])) |
||||||
|
.unwrap(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
match &mut net.get_node_mut(n3) { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param |
||||||
|
.set_cim(arr3(&[ |
||||||
|
[[-0.01, 0.01], [5.0, -5.0]], |
||||||
|
[[-5.0, 5.0], [0.01, -0.01]], |
||||||
|
])) |
||||||
|
.unwrap(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1) |
||||||
|
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||||
|
|
||||||
|
rf.get_transition_reward_mut(n2) |
||||||
|
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||||
|
|
||||||
|
rf.get_transition_reward_mut(n3) |
||||||
|
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||||
|
|
||||||
|
let s000: NetworkProcessState = vec![ |
||||||
|
params::StateType::Discrete(1), |
||||||
|
params::StateType::Discrete(0), |
||||||
|
params::StateType::Discrete(0), |
||||||
|
]; |
||||||
|
|
||||||
|
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); |
||||||
|
assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); |
||||||
|
|
||||||
|
let rst = mc.evaluate_state_space(&net, &rf); |
||||||
|
assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1); |
||||||
|
|
||||||
|
} |
@ -0,0 +1,117 @@ |
|||||||
|
mod utils; |
||||||
|
|
||||||
|
use ndarray::*; |
||||||
|
use utils::generate_discrete_time_continous_node; |
||||||
|
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params}; |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function_binary_node() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); |
||||||
|
|
||||||
|
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||||
|
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||||
|
assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||||
|
|
||||||
|
assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn simple_factored_reward_function_ternary_node() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||||
|
.unwrap(); |
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||||
|
|
||||||
|
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||||
|
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||||
|
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||||
|
assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||||
|
assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); |
||||||
|
assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn factored_reward_function_two_nodes() { |
||||||
|
let mut net = CtbnNetwork::new(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
|
||||||
|
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||||
|
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||||
|
|
||||||
|
|
||||||
|
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||||
|
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); |
||||||
|
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; |
||||||
|
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; |
||||||
|
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; |
||||||
|
|
||||||
|
|
||||||
|
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; |
||||||
|
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; |
||||||
|
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; |
||||||
|
|
||||||
|
assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||||
|
assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); |
||||||
|
assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||||
|
assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); |
||||||
|
assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); |
||||||
|
assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); |
||||||
|
assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||||
|
assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); |
||||||
|
assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||||
|
assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); |
||||||
|
assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||||
|
|
||||||
|
|
||||||
|
assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); |
||||||
|
assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); |
||||||
|
assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); |
||||||
|
} |
Loading…
Reference in new issue