Compare commits
91 Commits
62-feature
...
dev
Author | SHA1 | Date |
---|---|---|
Meliurwen | f6873c52c2 | 2 years ago |
Meliurwen | d45c96a38d | 2 years ago |
Meliurwen | 57ae851470 | 2 years ago |
Meliurwen | 4a9445385c | 2 years ago |
Meliurwen | 0eb427e5cf | 2 years ago |
Meliurwen | 7a3ac6c9ab | 2 years ago |
Meliurwen | 4fd0ee0407 | 2 years ago |
Meliurwen | c4da4ceadd | 2 years ago |
AlessandroBregoli | e638a627bb | 2 years ago |
AlessandroBregoli | 776b9aa030 | 2 years ago |
AlessandroBregoli | adb0f99419 | 2 years ago |
Meliurwen | b7fc23ed8a | 2 years ago |
Meliurwen | 0639a755d0 | 2 years ago |
Meliurwen | 4884010ea9 | 2 years ago |
Meliurwen | 430033afdb | 2 years ago |
Meliurwen | e08d12ac1f | 2 years ago |
Meliurwen | a01a9ef201 | 2 years ago |
Meliurwen | 097dc25030 | 2 years ago |
Meliurwen | 0f61cbee4c | 2 years ago |
Meliurwen | f4e3c98c79 | 2 years ago |
Meliurwen | d6f0fb9623 | 2 years ago |
Meliurwen | 434e671f0a | 2 years ago |
Meliurwen | 4b994d8a19 | 2 years ago |
Meliurwen | a077f738ee | 2 years ago |
Meliurwen | 49c2c55f61 | 2 years ago |
Meliurwen | c2df26c3e6 | 2 years ago |
Meliurwen | 7ec56914d9 | 2 years ago |
Alessandro Bregoli | 5d676be180 | 2 years ago |
Alessandro Bregoli | ff235b4b77 | 2 years ago |
Alessandro Bregoli | a104d1fbf9 | 2 years ago |
Meliurwen | 0bd325d349 | 2 years ago |
Meliurwen | b8938a934f | 2 years ago |
Meliurwen | 5632833963 | 2 years ago |
Meliurwen | 0e1cca0456 | 2 years ago |
Meliurwen | 4d3f9518e4 | 2 years ago |
Meliurwen | 867bf02934 | 2 years ago |
Meliurwen | a0da3e2fe8 | 2 years ago |
Meliurwen | ea3e406bf1 | 2 years ago |
Meliurwen | 19856195c3 | 2 years ago |
Meliurwen | ea5df7cad6 | 2 years ago |
Meliurwen | 468ebf09cc | 2 years ago |
Meliurwen | 8d0f9db289 | 2 years ago |
Meliurwen | 6d42d8a805 | 2 years ago |
Meliurwen | 6d952f8c07 | 2 years ago |
AlessandroBregoli | 9284ca5dd2 | 2 years ago |
AlessandroBregoli | 414aa31867 | 2 years ago |
AlessandroBregoli | 687f19ff1f | 2 years ago |
AlessandroBregoli | bb239aaa0c | 2 years ago |
AlessandroBregoli | cecf16a771 | 2 years ago |
AlessandroBregoli | 4fc5c1d4b5 | 2 years ago |
Meliurwen | 6e90458418 | 2 years ago |
Meliurwen | cac19b1756 | 2 years ago |
Meliurwen | 2e49df0266 | 2 years ago |
Meliurwen | 3f80f07e9f | 2 years ago |
AlessandroBregoli | bcb64a161a | 2 years ago |
AlessandroBregoli | 68ef7ea7c3 | 2 years ago |
AlessandroBregoli | f6015acce9 | 2 years ago |
AlessandroBregoli | 055eb7088e | 2 years ago |
AlessandroBregoli | 1878f687d6 | 2 years ago |
AlessandroBregoli | fd3b1ecfea | 2 years ago |
AlessandroBregoli | 38e744e034 | 2 years ago |
AlessandroBregoli | 44eaf8713f | 2 years ago |
AlessandroBregoli | 28ed1a40b3 | 2 years ago |
AlessandroBregoli | 4a7a8c5fba | 2 years ago |
Meliurwen | 3a0151a9f6 | 2 years ago |
AlessandroBregoli | 7c3cba50d4 | 2 years ago |
Meliurwen | a2c5800891 | 2 years ago |
Meliurwen | 9fbdf25149 | 2 years ago |
AlessandroBregoli | ed5471c7cf | 2 years ago |
Meliurwen | c08f4e1985 | 2 years ago |
Meliurwen | ec72a6a2f9 | 2 years ago |
Meliurwen | 713b8a8013 | 2 years ago |
Meliurwen | a92b605daa | 2 years ago |
Meliurwen | ce139afdb6 | 2 years ago |
Meliurwen | 0ae2168a94 | 2 years ago |
Meliurwen | 832922922a | 2 years ago |
Meliurwen | 3522e1b6f6 | 2 years ago |
Meliurwen | 245b3b5d45 | 2 years ago |
Meliurwen | 08623e28d4 | 2 years ago |
Meliurwen | f7165d0345 | 2 years ago |
Meliurwen | 9ca8973550 | 2 years ago |
Meliurwen | 064b582833 | 2 years ago |
Meliurwen | 2153f46758 | 2 years ago |
Meliurwen | ccced92149 | 2 years ago |
Meliurwen | 616d5ec3d5 | 2 years ago |
Meliurwen | 672de56c31 | 2 years ago |
Meliurwen | 1cad41f7f4 | 2 years ago |
Meliurwen | 174e85734e | 2 years ago |
Meliurwen | c247da6bc0 | 2 years ago |
Meliurwen | 761e9da436 | 2 years ago |
Meliurwen | 8953471570 | 2 years ago |
@ -0,0 +1,176 @@ |
||||
Apache License |
||||
Version 2.0, January 2004 |
||||
http://www.apache.org/licenses/ |
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION |
||||
|
||||
1. Definitions. |
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction, |
||||
and distribution as defined by Sections 1 through 9 of this document. |
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by |
||||
the copyright owner that is granting the License. |
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all |
||||
other entities that control, are controlled by, or are under common |
||||
control with that entity. For the purposes of this definition, |
||||
"control" means (i) the power, direct or indirect, to cause the |
||||
direction or management of such entity, whether by contract or |
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the |
||||
outstanding shares, or (iii) beneficial ownership of such entity. |
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity |
||||
exercising permissions granted by this License. |
||||
|
||||
"Source" form shall mean the preferred form for making modifications, |
||||
including but not limited to software source code, documentation |
||||
source, and configuration files. |
||||
|
||||
"Object" form shall mean any form resulting from mechanical |
||||
transformation or translation of a Source form, including but |
||||
not limited to compiled object code, generated documentation, |
||||
and conversions to other media types. |
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or |
||||
Object form, made available under the License, as indicated by a |
||||
copyright notice that is included in or attached to the work |
||||
(an example is provided in the Appendix below). |
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object |
||||
form, that is based on (or derived from) the Work and for which the |
||||
editorial revisions, annotations, elaborations, or other modifications |
||||
represent, as a whole, an original work of authorship. For the purposes |
||||
of this License, Derivative Works shall not include works that remain |
||||
separable from, or merely link (or bind by name) to the interfaces of, |
||||
the Work and Derivative Works thereof. |
||||
|
||||
"Contribution" shall mean any work of authorship, including |
||||
the original version of the Work and any modifications or additions |
||||
to that Work or Derivative Works thereof, that is intentionally |
||||
submitted to Licensor for inclusion in the Work by the copyright owner |
||||
or by an individual or Legal Entity authorized to submit on behalf of |
||||
the copyright owner. For the purposes of this definition, "submitted" |
||||
means any form of electronic, verbal, or written communication sent |
||||
to the Licensor or its representatives, including but not limited to |
||||
communication on electronic mailing lists, source code control systems, |
||||
and issue tracking systems that are managed by, or on behalf of, the |
||||
Licensor for the purpose of discussing and improving the Work, but |
||||
excluding communication that is conspicuously marked or otherwise |
||||
designated in writing by the copyright owner as "Not a Contribution." |
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity |
||||
on behalf of whom a Contribution has been received by Licensor and |
||||
subsequently incorporated within the Work. |
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of |
||||
this License, each Contributor hereby grants to You a perpetual, |
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
||||
copyright license to reproduce, prepare Derivative Works of, |
||||
publicly display, publicly perform, sublicense, and distribute the |
||||
Work and such Derivative Works in Source or Object form. |
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of |
||||
this License, each Contributor hereby grants to You a perpetual, |
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
||||
(except as stated in this section) patent license to make, have made, |
||||
use, offer to sell, sell, import, and otherwise transfer the Work, |
||||
where such license applies only to those patent claims licensable |
||||
by such Contributor that are necessarily infringed by their |
||||
Contribution(s) alone or by combination of their Contribution(s) |
||||
with the Work to which such Contribution(s) was submitted. If You |
||||
institute patent litigation against any entity (including a |
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work |
||||
or a Contribution incorporated within the Work constitutes direct |
||||
or contributory patent infringement, then any patent licenses |
||||
granted to You under this License for that Work shall terminate |
||||
as of the date such litigation is filed. |
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the |
||||
Work or Derivative Works thereof in any medium, with or without |
||||
modifications, and in Source or Object form, provided that You |
||||
meet the following conditions: |
||||
|
||||
(a) You must give any other recipients of the Work or |
||||
Derivative Works a copy of this License; and |
||||
|
||||
(b) You must cause any modified files to carry prominent notices |
||||
stating that You changed the files; and |
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works |
||||
that You distribute, all copyright, patent, trademark, and |
||||
attribution notices from the Source form of the Work, |
||||
excluding those notices that do not pertain to any part of |
||||
the Derivative Works; and |
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its |
||||
distribution, then any Derivative Works that You distribute must |
||||
include a readable copy of the attribution notices contained |
||||
within such NOTICE file, excluding those notices that do not |
||||
pertain to any part of the Derivative Works, in at least one |
||||
of the following places: within a NOTICE text file distributed |
||||
as part of the Derivative Works; within the Source form or |
||||
documentation, if provided along with the Derivative Works; or, |
||||
within a display generated by the Derivative Works, if and |
||||
wherever such third-party notices normally appear. The contents |
||||
of the NOTICE file are for informational purposes only and |
||||
do not modify the License. You may add Your own attribution |
||||
notices within Derivative Works that You distribute, alongside |
||||
or as an addendum to the NOTICE text from the Work, provided |
||||
that such additional attribution notices cannot be construed |
||||
as modifying the License. |
||||
|
||||
You may add Your own copyright statement to Your modifications and |
||||
may provide additional or different license terms and conditions |
||||
for use, reproduction, or distribution of Your modifications, or |
||||
for any such Derivative Works as a whole, provided Your use, |
||||
reproduction, and distribution of the Work otherwise complies with |
||||
the conditions stated in this License. |
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, |
||||
any Contribution intentionally submitted for inclusion in the Work |
||||
by You to the Licensor shall be under the terms and conditions of |
||||
this License, without any additional terms or conditions. |
||||
Notwithstanding the above, nothing herein shall supersede or modify |
||||
the terms of any separate license agreement you may have executed |
||||
with Licensor regarding such Contributions. |
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade |
||||
names, trademarks, service marks, or product names of the Licensor, |
||||
except as required for reasonable and customary use in describing the |
||||
origin of the Work and reproducing the content of the NOTICE file. |
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or |
||||
agreed to in writing, Licensor provides the Work (and each |
||||
Contributor provides its Contributions) on an "AS IS" BASIS, |
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
||||
implied, including, without limitation, any warranties or conditions |
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A |
||||
PARTICULAR PURPOSE. You are solely responsible for determining the |
||||
appropriateness of using or redistributing the Work and assume any |
||||
risks associated with Your exercise of permissions under this License. |
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory, |
||||
whether in tort (including negligence), contract, or otherwise, |
||||
unless required by applicable law (such as deliberate and grossly |
||||
negligent acts) or agreed to in writing, shall any Contributor be |
||||
liable to You for damages, including any direct, indirect, special, |
||||
incidental, or consequential damages of any character arising as a |
||||
result of this License or out of the use or inability to use the |
||||
Work (including but not limited to damages for loss of goodwill, |
||||
work stoppage, computer failure or malfunction, or any and all |
||||
other commercial damages or losses), even if such Contributor |
||||
has been advised of the possibility of such damages. |
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing |
||||
the Work or Derivative Works thereof, You may choose to offer, |
||||
and charge a fee for, acceptance of support, warranty, indemnity, |
||||
or other liability obligations and/or rights consistent with this |
||||
License. However, in accepting such obligations, You may act only |
||||
on Your own behalf and on Your sole responsibility, not on behalf |
||||
of any other Contributor, and only if You agree to indemnify, |
||||
defend, and hold each Contributor harmless for any liability |
||||
incurred by, or claims asserted against, such Contributor by reason |
||||
of your accepting any such warranty or additional liability. |
||||
|
||||
END OF TERMS AND CONDITIONS |
@ -0,0 +1,23 @@ |
||||
Permission is hereby granted, free of charge, to any |
||||
person obtaining a copy of this software and associated |
||||
documentation files (the "Software"), to deal in the |
||||
Software without restriction, including without |
||||
limitation the rights to use, copy, modify, merge, |
||||
publish, distribute, sublicense, and/or sell copies of |
||||
the Software, and to permit persons to whom the Software |
||||
is furnished to do so, subject to the following |
||||
conditions: |
||||
|
||||
The above copyright notice and this permission notice |
||||
shall be included in all copies or substantial portions |
||||
of the Software. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF |
||||
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED |
||||
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A |
||||
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT |
||||
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
||||
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION |
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR |
||||
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
||||
DEALINGS IN THE SOFTWARE. |
@ -1,162 +0,0 @@ |
||||
use std::collections::BTreeSet; |
||||
|
||||
use ndarray::prelude::*; |
||||
|
||||
use crate::network; |
||||
use crate::params::{Params, ParamsTrait, StateType}; |
||||
|
||||
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
|
||||
///composed by the following elements:
|
||||
///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
|
||||
///- **nodes**: a vector containing all the nodes and their parameters.
|
||||
///The index of a node inside the vector is also used as index for the adj_matrix.
|
||||
///
|
||||
///# Examples
|
||||
///
|
||||
///```
|
||||
///
|
||||
/// use std::collections::BTreeSet;
|
||||
/// use reCTBN::network::Network;
|
||||
/// use reCTBN::params;
|
||||
/// use reCTBN::ctbn::*;
|
||||
///
|
||||
/// //Create the domain for a discrete node
|
||||
/// let mut domain = BTreeSet::new();
|
||||
/// domain.insert(String::from("A"));
|
||||
/// domain.insert(String::from("B"));
|
||||
///
|
||||
/// //Create the parameters for a discrete node using the domain
|
||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
||||
///
|
||||
/// //Create the node using the parameters
|
||||
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
||||
///
|
||||
/// let mut domain = BTreeSet::new();
|
||||
/// domain.insert(String::from("A"));
|
||||
/// domain.insert(String::from("B"));
|
||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
||||
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
||||
///
|
||||
/// //Initialize a ctbn
|
||||
/// let mut net = CtbnNetwork::new();
|
||||
///
|
||||
/// //Add nodes
|
||||
/// let X1 = net.add_node(X1).unwrap();
|
||||
/// let X2 = net.add_node(X2).unwrap();
|
||||
///
|
||||
/// //Add an edge
|
||||
/// net.add_edge(X1, X2);
|
||||
///
|
||||
/// //Get all the children of node X1
|
||||
/// let cs = net.get_children_set(X1);
|
||||
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
||||
/// ```
|
||||
pub struct CtbnNetwork { |
||||
adj_matrix: Option<Array2<u16>>, |
||||
nodes: Vec<Params>, |
||||
} |
||||
|
||||
impl CtbnNetwork { |
||||
pub fn new() -> CtbnNetwork { |
||||
CtbnNetwork { |
||||
adj_matrix: None, |
||||
nodes: Vec::new(), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl network::Network for CtbnNetwork { |
||||
fn initialize_adj_matrix(&mut self) { |
||||
self.adj_matrix = Some(Array2::<u16>::zeros( |
||||
(self.nodes.len(), self.nodes.len()).f(), |
||||
)); |
||||
} |
||||
|
||||
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> { |
||||
n.reset_params(); |
||||
self.adj_matrix = Option::None; |
||||
self.nodes.push(n); |
||||
Ok(self.nodes.len() - 1) |
||||
} |
||||
|
||||
fn add_edge(&mut self, parent: usize, child: usize) { |
||||
if let None = self.adj_matrix { |
||||
self.initialize_adj_matrix(); |
||||
} |
||||
|
||||
if let Some(network) = &mut self.adj_matrix { |
||||
network[[parent, child]] = 1; |
||||
self.nodes[child].reset_params(); |
||||
} |
||||
} |
||||
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||
0..self.nodes.len() |
||||
} |
||||
|
||||
fn get_number_of_nodes(&self) -> usize { |
||||
self.nodes.len() |
||||
} |
||||
|
||||
fn get_node(&self, node_idx: usize) -> &Params { |
||||
&self.nodes[node_idx] |
||||
} |
||||
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { |
||||
&mut self.nodes[node_idx] |
||||
} |
||||
|
||||
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.column(node) |
||||
.iter() |
||||
.enumerate() |
||||
.fold((0, 1), |mut acc, x| { |
||||
if x.1 > &0 { |
||||
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
||||
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
||||
} |
||||
acc |
||||
}) |
||||
.0 |
||||
} |
||||
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
current_state: &Vec<StateType>, |
||||
parent_set: &BTreeSet<usize>, |
||||
) -> usize { |
||||
parent_set |
||||
.iter() |
||||
.fold((0, 1), |mut acc, x| { |
||||
acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; |
||||
acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); |
||||
acc |
||||
}) |
||||
.0 |
||||
} |
||||
|
||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.column(node) |
||||
.iter() |
||||
.enumerate() |
||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||
.collect() |
||||
} |
||||
|
||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.row(node) |
||||
.iter() |
||||
.enumerate() |
||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||
.collect() |
||||
} |
||||
} |
@ -1,11 +1,12 @@ |
||||
#![doc = include_str!("../../README.md")] |
||||
#![allow(non_snake_case)] |
||||
#[cfg(test)] |
||||
extern crate approx; |
||||
|
||||
pub mod ctbn; |
||||
pub mod network; |
||||
pub mod parameter_learning; |
||||
pub mod params; |
||||
pub mod process; |
||||
pub mod reward; |
||||
pub mod sampling; |
||||
pub mod structure_learning; |
||||
pub mod tools; |
||||
|
@ -1,43 +0,0 @@ |
||||
use std::collections::BTreeSet; |
||||
|
||||
use thiserror::Error; |
||||
|
||||
use crate::params; |
||||
|
||||
/// Error types for trait Network
|
||||
#[derive(Error, Debug)] |
||||
pub enum NetworkError { |
||||
#[error("Error during node insertion")] |
||||
NodeInsertionError(String), |
||||
} |
||||
|
||||
///Network
|
||||
///The Network trait define the required methods for a structure used as pgm (such as ctbn).
|
||||
pub trait Network { |
||||
fn initialize_adj_matrix(&mut self); |
||||
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; |
||||
fn add_edge(&mut self, parent: usize, child: usize); |
||||
|
||||
///Get all the indices of the nodes contained inside the network
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize>; |
||||
fn get_number_of_nodes(&self) -> usize; |
||||
fn get_node(&self, node_idx: usize) -> ¶ms::Params; |
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; |
||||
|
||||
///Compute the index that must be used to access the parameters of a node given a specific
|
||||
///configuration of the network. Usually, the only values really used in *current_state* are
|
||||
///the ones in the parent set of the *node*.
|
||||
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) |
||||
-> usize; |
||||
|
||||
///Compute the index that must be used to access the parameters of a node given a specific
|
||||
///configuration of the network and a generic parent_set. Usually, the only values really used
|
||||
///in *current_state* are the ones in the parent set of the *node*.
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
current_state: &Vec<params::StateType>, |
||||
parent_set: &BTreeSet<usize>, |
||||
) -> usize; |
||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; |
||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; |
||||
} |
@ -0,0 +1,120 @@ |
||||
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
|
||||
|
||||
pub mod ctbn; |
||||
pub mod ctmp; |
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use thiserror::Error; |
||||
|
||||
use crate::params; |
||||
|
||||
/// Error types for trait Network
|
||||
#[derive(Error, Debug)] |
||||
pub enum NetworkError { |
||||
#[error("Error during node insertion")] |
||||
NodeInsertionError(String), |
||||
} |
||||
|
||||
/// This type is used to represent a specific realization of a generic NetworkProcess
|
||||
pub type NetworkProcessState = Vec<params::StateType>; |
||||
|
||||
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
|
||||
/// as a CTBN).
|
||||
pub trait NetworkProcess: Sync { |
||||
fn initialize_adj_matrix(&mut self); |
||||
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; |
||||
/// Add an **directed edge** between a two nodes of the network.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `parent` - parent node.
|
||||
/// * `child` - child node.
|
||||
fn add_edge(&mut self, parent: usize, child: usize); |
||||
|
||||
/// Get all the indices of the nodes contained inside the network.
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize>; |
||||
|
||||
/// Get the numbers of nodes contained in the network.
|
||||
fn get_number_of_nodes(&self) -> usize; |
||||
|
||||
/// Get the **node param**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node_idx` - node index value.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The selected **node param**.
|
||||
fn get_node(&self, node_idx: usize) -> ¶ms::Params; |
||||
|
||||
/// Get the **node param**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node_idx` - node index value.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The selected **node mutable param**.
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; |
||||
|
||||
/// Compute the index that must be used to access the parameters of a `node`, given a specific
|
||||
/// configuration of the network.
|
||||
///
|
||||
/// Usually, the only values really used in `current_state` are the ones in the parent set of
|
||||
/// the `node`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node` - selected node.
|
||||
/// * `current_state` - current configuration of the network.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * Index of the `node` relative to the network.
|
||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; |
||||
|
||||
/// Compute the index that must be used to access the parameters of a `node`, given a specific
|
||||
/// configuration of the network and a generic `parent_set`.
|
||||
///
|
||||
/// Usually, the only values really used in `current_state` are the ones in the parent set of
|
||||
/// the `node`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `current_state` - current configuration of the network.
|
||||
/// * `parent_set` - parent set of the selected `node`.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * Index of the `node` relative to the network.
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
current_state: &Vec<params::StateType>, |
||||
parent_set: &BTreeSet<usize>, |
||||
) -> usize; |
||||
|
||||
/// Get the **parent set** of a given **node**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node` - node index value.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The **parent set** of the selected node.
|
||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; |
||||
|
||||
/// Get the **children set** of a given **node**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node` - node index value.
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The **children set** of the selected node.
|
||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; |
||||
} |
@ -0,0 +1,247 @@ |
||||
//! Continuous Time Bayesian Network
|
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use ndarray::prelude::*; |
||||
|
||||
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; |
||||
use crate::process; |
||||
|
||||
use super::ctmp::CtmpProcess; |
||||
use super::{NetworkProcess, NetworkProcessState}; |
||||
|
||||
/// It represents both the structure and the parameters of a CTBN.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
|
||||
/// * `nodes` - A vector containing all the nodes and their parameters.
|
||||
///
|
||||
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::collections::BTreeSet;
|
||||
/// use reCTBN::process::NetworkProcess;
|
||||
/// use reCTBN::params;
|
||||
/// use reCTBN::process::ctbn::*;
|
||||
///
|
||||
/// //Create the domain for a discrete node
|
||||
/// let mut domain = BTreeSet::new();
|
||||
/// domain.insert(String::from("A"));
|
||||
/// domain.insert(String::from("B"));
|
||||
///
|
||||
/// //Create the parameters for a discrete node using the domain
|
||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
||||
///
|
||||
/// //Create the node using the parameters
|
||||
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
||||
///
|
||||
/// let mut domain = BTreeSet::new();
|
||||
/// domain.insert(String::from("A"));
|
||||
/// domain.insert(String::from("B"));
|
||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
||||
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
||||
///
|
||||
/// //Initialize a ctbn
|
||||
/// let mut net = CtbnNetwork::new();
|
||||
///
|
||||
/// //Add nodes
|
||||
/// let X1 = net.add_node(X1).unwrap();
|
||||
/// let X2 = net.add_node(X2).unwrap();
|
||||
///
|
||||
/// //Add an edge
|
||||
/// net.add_edge(X1, X2);
|
||||
///
|
||||
/// //Get all the children of node X1
|
||||
/// let cs = net.get_children_set(X1);
|
||||
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
||||
/// ```
|
||||
pub struct CtbnNetwork { |
||||
adj_matrix: Option<Array2<u16>>, |
||||
nodes: Vec<Params>, |
||||
} |
||||
|
||||
impl CtbnNetwork { |
||||
pub fn new() -> CtbnNetwork { |
||||
CtbnNetwork { |
||||
adj_matrix: None, |
||||
nodes: Vec::new(), |
||||
} |
||||
} |
||||
|
||||
///Transform the **CTBN** into a **CTMP**
|
||||
///
|
||||
/// # Return
|
||||
///
|
||||
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
|
||||
pub fn amalgamation(&self) -> CtmpProcess { |
||||
let variables_domain = |
||||
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); |
||||
|
||||
let state_space = variables_domain.product(); |
||||
let variables_set = BTreeSet::from_iter(self.get_node_indices()); |
||||
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space)); |
||||
|
||||
for idx_current_state in 0..state_space { |
||||
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); |
||||
let current_state_statetype: NetworkProcessState = current_state |
||||
.iter() |
||||
.map(|x| StateType::Discrete(*x)) |
||||
.collect(); |
||||
for idx_node in 0..self.nodes.len() { |
||||
let p = match self.get_node(idx_node) { |
||||
Params::DiscreteStatesContinousTime(p) => p, |
||||
}; |
||||
for next_node_state in 0..variables_domain[idx_node] { |
||||
let mut next_state = current_state.clone(); |
||||
next_state[idx_node] = next_node_state; |
||||
|
||||
let next_state_statetype: NetworkProcessState = |
||||
next_state.iter().map(|x| StateType::Discrete(*x)).collect(); |
||||
let idx_next_state = self.get_param_index_from_custom_parent_set( |
||||
&next_state_statetype, |
||||
&variables_set, |
||||
); |
||||
amalgamated_cim[[0, idx_current_state, idx_next_state]] += |
||||
p.get_cim().as_ref().unwrap()[[ |
||||
self.get_param_index_network(idx_node, ¤t_state_statetype), |
||||
current_state[idx_node], |
||||
next_node_state, |
||||
]]; |
||||
} |
||||
} |
||||
} |
||||
|
||||
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( |
||||
"ctmp".to_string(), |
||||
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), |
||||
); |
||||
|
||||
amalgamated_param.set_cim(amalgamated_cim).unwrap(); |
||||
|
||||
let mut ctmp = CtmpProcess::new(); |
||||
|
||||
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) |
||||
.unwrap(); |
||||
return ctmp; |
||||
} |
||||
|
||||
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> { |
||||
let mut state = state; |
||||
let mut array_state = Array1::zeros(variables_domain.shape()[0]); |
||||
for (idx, var) in variables_domain.indexed_iter() { |
||||
array_state[idx] = state % var; |
||||
state = state / var; |
||||
} |
||||
|
||||
return array_state; |
||||
} |
||||
/// Get the Adjacency Matrix.
|
||||
pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> { |
||||
self.adj_matrix.as_ref() |
||||
} |
||||
} |
||||
|
||||
impl process::NetworkProcess for CtbnNetwork { |
||||
/// Initialize an Adjacency matrix.
|
||||
fn initialize_adj_matrix(&mut self) { |
||||
self.adj_matrix = Some(Array2::<u16>::zeros( |
||||
(self.nodes.len(), self.nodes.len()).f(), |
||||
)); |
||||
} |
||||
|
||||
/// Add a new node.
|
||||
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> { |
||||
n.reset_params(); |
||||
self.adj_matrix = Option::None; |
||||
self.nodes.push(n); |
||||
Ok(self.nodes.len() - 1) |
||||
} |
||||
|
||||
/// Connect two nodes with a new edge.
|
||||
fn add_edge(&mut self, parent: usize, child: usize) { |
||||
if let None = self.adj_matrix { |
||||
self.initialize_adj_matrix(); |
||||
} |
||||
|
||||
if let Some(network) = &mut self.adj_matrix { |
||||
network[[parent, child]] = 1; |
||||
self.nodes[child].reset_params(); |
||||
} |
||||
} |
||||
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||
0..self.nodes.len() |
||||
} |
||||
|
||||
/// Get the number of nodes of the network.
|
||||
fn get_number_of_nodes(&self) -> usize { |
||||
self.nodes.len() |
||||
} |
||||
|
||||
fn get_node(&self, node_idx: usize) -> &Params { |
||||
&self.nodes[node_idx] |
||||
} |
||||
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { |
||||
&mut self.nodes[node_idx] |
||||
} |
||||
|
||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.column(node) |
||||
.iter() |
||||
.enumerate() |
||||
.fold((0, 1), |mut acc, x| { |
||||
if x.1 > &0 { |
||||
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
||||
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
||||
} |
||||
acc |
||||
}) |
||||
.0 |
||||
} |
||||
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
current_state: &NetworkProcessState, |
||||
parent_set: &BTreeSet<usize>, |
||||
) -> usize { |
||||
parent_set |
||||
.iter() |
||||
.fold((0, 1), |mut acc, x| { |
||||
acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; |
||||
acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); |
||||
acc |
||||
}) |
||||
.0 |
||||
} |
||||
|
||||
/// Get all the parents of the given node.
|
||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.column(node) |
||||
.iter() |
||||
.enumerate() |
||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||
.collect() |
||||
} |
||||
|
||||
/// Get all the children of the given node.
|
||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.adj_matrix |
||||
.as_ref() |
||||
.unwrap() |
||||
.row(node) |
||||
.iter() |
||||
.enumerate() |
||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||
.collect() |
||||
} |
||||
} |
@ -0,0 +1,114 @@ |
||||
use std::collections::BTreeSet; |
||||
|
||||
use crate::{ |
||||
params::{Params, StateType}, |
||||
process, |
||||
}; |
||||
|
||||
use super::{NetworkProcess, NetworkProcessState}; |
||||
|
||||
pub struct CtmpProcess { |
||||
param: Option<Params>, |
||||
} |
||||
|
||||
impl CtmpProcess { |
||||
pub fn new() -> CtmpProcess { |
||||
CtmpProcess { param: None } |
||||
} |
||||
} |
||||
|
||||
impl NetworkProcess for CtmpProcess { |
||||
fn initialize_adj_matrix(&mut self) { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> { |
||||
match self.param { |
||||
None => { |
||||
self.param = Some(n); |
||||
Ok(0) |
||||
} |
||||
Some(_) => Err(process::NetworkError::NodeInsertionError( |
||||
"CtmpProcess has only one node".to_string(), |
||||
)), |
||||
} |
||||
} |
||||
|
||||
fn add_edge(&mut self, _parent: usize, _child: usize) { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||
match self.param { |
||||
None => 0..0, |
||||
Some(_) => 0..1, |
||||
} |
||||
} |
||||
|
||||
fn get_number_of_nodes(&self) -> usize { |
||||
match self.param { |
||||
None => 0, |
||||
Some(_) => 1, |
||||
} |
||||
} |
||||
|
||||
fn get_node(&self, node_idx: usize) -> &crate::params::Params { |
||||
if node_idx == 0 { |
||||
self.param.as_ref().unwrap() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { |
||||
if node_idx == 0 { |
||||
self.param.as_mut().unwrap() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
||||
if node == 0 { |
||||
match current_state[0] { |
||||
StateType::Discrete(x) => x, |
||||
} |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
|
||||
fn get_param_index_from_custom_parent_set( |
||||
&self, |
||||
_current_state: &NetworkProcessState, |
||||
_parent_set: &BTreeSet<usize>, |
||||
) -> usize { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
|
||||
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
||||
match self.param { |
||||
Some(_) => { |
||||
if node == 0 { |
||||
BTreeSet::new() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
None => panic!("Uninitialized CtmpProcess"), |
||||
} |
||||
} |
||||
|
||||
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
||||
match self.param { |
||||
Some(_) => { |
||||
if node == 0 { |
||||
BTreeSet::new() |
||||
} else { |
||||
unimplemented!("CtmpProcess has only one node") |
||||
} |
||||
} |
||||
None => panic!("Uninitialized CtmpProcess"), |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,59 @@ |
||||
pub mod reward_evaluation; |
||||
pub mod reward_function; |
||||
|
||||
use std::collections::HashMap; |
||||
|
||||
use crate::process; |
||||
|
||||
/// Instantiation of reward function and instantaneous reward
|
||||
///
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transition_reward`: reward obtained transitioning from one state to another
|
||||
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
|
||||
|
||||
#[derive(Debug, PartialEq)] |
||||
pub struct Reward { |
||||
pub transition_reward: f64, |
||||
pub instantaneous_reward: f64, |
||||
} |
||||
|
||||
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
|
||||
|
||||
pub trait RewardFunction: Sync { |
||||
/// Given the current state and the previous state, it compute the reward.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
|
||||
/// * `previous_state`: an optional argument representing the previous state of the network
|
||||
|
||||
fn call( |
||||
&self, |
||||
current_state: &process::NetworkProcessState, |
||||
previous_state: Option<&process::NetworkProcessState>, |
||||
) -> Reward; |
||||
|
||||
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p`: any structure that implements the trait `process::NetworkProcess`
|
||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; |
||||
} |
||||
|
||||
pub trait RewardEvaluation { |
||||
fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
) -> HashMap<process::NetworkProcessState, f64>; |
||||
|
||||
fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
state: &process::NetworkProcessState, |
||||
) -> f64; |
||||
} |
@ -0,0 +1,205 @@ |
||||
use std::collections::HashMap; |
||||
|
||||
use rayon::prelude::{IntoParallelIterator, ParallelIterator}; |
||||
use statrs::distribution::ContinuousCDF; |
||||
|
||||
use crate::params::{self, ParamsTrait}; |
||||
use crate::process; |
||||
|
||||
use crate::{ |
||||
process::NetworkProcessState, |
||||
reward::RewardEvaluation, |
||||
sampling::{ForwardSampler, Sampler}, |
||||
}; |
||||
|
||||
pub enum RewardCriteria { |
||||
FiniteHorizon, |
||||
InfiniteHorizon { discount_factor: f64 }, |
||||
} |
||||
|
||||
pub struct MonteCarloReward { |
||||
max_iterations: usize, |
||||
max_err_stop: f64, |
||||
alpha_stop: f64, |
||||
end_time: f64, |
||||
reward_criteria: RewardCriteria, |
||||
seed: Option<u64>, |
||||
} |
||||
|
||||
impl MonteCarloReward { |
||||
pub fn new( |
||||
max_iterations: usize, |
||||
max_err_stop: f64, |
||||
alpha_stop: f64, |
||||
end_time: f64, |
||||
reward_criteria: RewardCriteria, |
||||
seed: Option<u64>, |
||||
) -> MonteCarloReward { |
||||
MonteCarloReward { |
||||
max_iterations, |
||||
max_err_stop, |
||||
alpha_stop, |
||||
end_time, |
||||
reward_criteria, |
||||
seed, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl RewardEvaluation for MonteCarloReward { |
||||
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
) -> HashMap<process::NetworkProcessState, f64> { |
||||
let variables_domain: Vec<Vec<params::StateType>> = network_process |
||||
.get_node_indices() |
||||
.map(|x| match network_process.get_node(x) { |
||||
params::Params::DiscreteStatesContinousTime(x) => (0..x |
||||
.get_reserved_space_as_parent()) |
||||
.map(|s| params::StateType::Discrete(s)) |
||||
.collect(), |
||||
}) |
||||
.collect(); |
||||
|
||||
let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); |
||||
|
||||
(0..n_states) |
||||
.into_par_iter() |
||||
.map(|s| { |
||||
let state: process::NetworkProcessState = variables_domain |
||||
.iter() |
||||
.fold((s, vec![]), |acc, x| { |
||||
let mut acc = acc; |
||||
let idx_s = acc.0 % x.len(); |
||||
acc.1.push(x[idx_s].clone()); |
||||
acc.0 = acc.0 / x.len(); |
||||
acc |
||||
}) |
||||
.1; |
||||
|
||||
let r = self.evaluate_state(network_process, reward_function, &state); |
||||
(state, r) |
||||
}) |
||||
.collect() |
||||
} |
||||
|
||||
fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
state: &NetworkProcessState, |
||||
) -> f64 { |
||||
let mut sampler = |
||||
ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); |
||||
let mut expected_value = 0.0; |
||||
let mut squared_expected_value = 0.0; |
||||
let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); |
||||
|
||||
for i in 0..self.max_iterations { |
||||
sampler.reset(); |
||||
let mut ret = 0.0; |
||||
let mut previous = sampler.next().unwrap(); |
||||
while previous.t < self.end_time { |
||||
let current = sampler.next().unwrap(); |
||||
if current.t > self.end_time { |
||||
let r = reward_function.call(&previous.state, None); |
||||
let discount = match self.reward_criteria { |
||||
RewardCriteria::FiniteHorizon => self.end_time - previous.t, |
||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||
std::f64::consts::E.powf(-discount_factor * previous.t) |
||||
- std::f64::consts::E.powf(-discount_factor * self.end_time) |
||||
} |
||||
}; |
||||
ret += discount * r.instantaneous_reward; |
||||
} else { |
||||
let r = reward_function.call(&previous.state, Some(¤t.state)); |
||||
let discount = match self.reward_criteria { |
||||
RewardCriteria::FiniteHorizon => current.t - previous.t, |
||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||
std::f64::consts::E.powf(-discount_factor * previous.t) |
||||
- std::f64::consts::E.powf(-discount_factor * current.t) |
||||
} |
||||
}; |
||||
ret += discount * r.instantaneous_reward; |
||||
ret += match self.reward_criteria { |
||||
RewardCriteria::FiniteHorizon => 1.0, |
||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
||||
std::f64::consts::E.powf(-discount_factor * current.t) |
||||
} |
||||
} * r.transition_reward; |
||||
} |
||||
previous = current; |
||||
} |
||||
|
||||
let float_i = i as f64; |
||||
expected_value = |
||||
expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0); |
||||
squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0) |
||||
+ ret.powi(2) / (float_i + 1.0); |
||||
|
||||
if i > 2 { |
||||
let var = |
||||
(float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2)); |
||||
if self.alpha_stop |
||||
- 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt()) |
||||
> 0.0 |
||||
{ |
||||
return expected_value; |
||||
} |
||||
} |
||||
} |
||||
|
||||
expected_value |
||||
} |
||||
} |
||||
|
||||
pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> { |
||||
inner_reward: RE, |
||||
} |
||||
|
||||
impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> { |
||||
pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> { |
||||
NeighborhoodRelativeReward { inner_reward } |
||||
} |
||||
} |
||||
|
||||
impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> { |
||||
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( |
||||
&self, |
||||
network_process: &N, |
||||
reward_function: &R, |
||||
) -> HashMap<process::NetworkProcessState, f64> { |
||||
let absolute_reward = self |
||||
.inner_reward |
||||
.evaluate_state_space(network_process, reward_function); |
||||
|
||||
//This approach optimize memory. Maybe optimizing execution time can be better.
|
||||
absolute_reward |
||||
.iter() |
||||
.map(|(k1, v1)| { |
||||
let mut max_val: f64 = 1.0; |
||||
absolute_reward.iter().for_each(|(k2, v2)| { |
||||
let count_diff: usize = k1 |
||||
.iter() |
||||
.zip(k2.iter()) |
||||
.map(|(s1, s2)| if s1 == s2 { 0 } else { 1 }) |
||||
.sum(); |
||||
if count_diff < 2 { |
||||
max_val = max_val.max(v1 / v2); |
||||
} |
||||
}); |
||||
(k1.clone(), max_val) |
||||
}) |
||||
.collect() |
||||
} |
||||
|
||||
fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>( |
||||
&self, |
||||
_network_process: &N, |
||||
_reward_function: &R, |
||||
_state: &process::NetworkProcessState, |
||||
) -> f64 { |
||||
unimplemented!(); |
||||
} |
||||
} |
@ -0,0 +1,106 @@ |
||||
//! Module for dealing with reward functions
|
||||
|
||||
use crate::{ |
||||
params::{self, ParamsTrait}, |
||||
process, |
||||
reward::{Reward, RewardFunction}, |
||||
}; |
||||
|
||||
use ndarray; |
||||
|
||||
/// Reward function over a factored state space
|
||||
///
|
||||
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
|
||||
/// of the underling `NetworkProcess`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
|
||||
/// reward of a node
|
||||
|
||||
pub struct FactoredRewardFunction { |
||||
transition_reward: Vec<ndarray::Array2<f64>>, |
||||
instantaneous_reward: Vec<ndarray::Array1<f64>>, |
||||
} |
||||
|
||||
impl FactoredRewardFunction { |
||||
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { |
||||
&self.transition_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { |
||||
&mut self.transition_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { |
||||
&self.instantaneous_reward[node_idx] |
||||
} |
||||
|
||||
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { |
||||
&mut self.instantaneous_reward[node_idx] |
||||
} |
||||
} |
||||
|
||||
impl RewardFunction for FactoredRewardFunction { |
||||
fn call( |
||||
&self, |
||||
current_state: &process::NetworkProcessState, |
||||
previous_state: Option<&process::NetworkProcessState>, |
||||
) -> Reward { |
||||
let instantaneous_reward: f64 = current_state |
||||
.iter() |
||||
.enumerate() |
||||
.map(|(idx, x)| { |
||||
let x = match x { |
||||
params::StateType::Discrete(x) => x, |
||||
}; |
||||
self.instantaneous_reward[idx][*x] |
||||
}) |
||||
.sum(); |
||||
if let Some(previous_state) = previous_state { |
||||
let transition_reward = previous_state |
||||
.iter() |
||||
.zip(current_state.iter()) |
||||
.enumerate() |
||||
.find_map(|(idx, (p, c))| -> Option<f64> { |
||||
let p = match p { |
||||
params::StateType::Discrete(p) => p, |
||||
}; |
||||
let c = match c { |
||||
params::StateType::Discrete(c) => c, |
||||
}; |
||||
if p != c { |
||||
Some(self.transition_reward[idx][[*p, *c]]) |
||||
} else { |
||||
None |
||||
} |
||||
}) |
||||
.unwrap_or(0.0); |
||||
Reward { |
||||
transition_reward, |
||||
instantaneous_reward, |
||||
} |
||||
} else { |
||||
Reward { |
||||
transition_reward: 0.0, |
||||
instantaneous_reward, |
||||
} |
||||
} |
||||
} |
||||
|
||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { |
||||
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; |
||||
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; |
||||
for i in p.get_node_indices() { |
||||
//This works only for discrete nodes!
|
||||
let size: usize = p.get_node(i).get_reserved_space_as_parent(); |
||||
instantaneous_reward.push(ndarray::Array1::zeros(size)); |
||||
transition_reward.push(ndarray::Array2::zeros((size, size))); |
||||
} |
||||
|
||||
FactoredRewardFunction { |
||||
transition_reward, |
||||
instantaneous_reward, |
||||
} |
||||
} |
||||
} |
@ -1,11 +1,13 @@ |
||||
//! Learn the structure of the network.
|
||||
|
||||
pub mod constraint_based_algorithm; |
||||
pub mod hypothesis_test; |
||||
pub mod score_based_algorithm; |
||||
pub mod score_function; |
||||
use crate::{network, tools}; |
||||
use crate::{process, tools::Dataset}; |
||||
|
||||
pub trait StructureLearningAlgorithm { |
||||
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T |
||||
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
||||
where |
||||
T: network::Network; |
||||
T: process::NetworkProcess; |
||||
} |
||||
|
@ -1,3 +1,348 @@ |
||||
//pub struct CTPC {
|
||||
//
|
||||
//}
|
||||
//! Module containing constraint based algorithms like CTPC and Hiton.
|
||||
|
||||
use crate::params::Params; |
||||
use itertools::Itertools; |
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator}; |
||||
use rayon::prelude::ParallelExtend; |
||||
use std::collections::{BTreeSet, HashMap}; |
||||
use std::mem; |
||||
use std::usize; |
||||
|
||||
use super::hypothesis_test::*; |
||||
use crate::parameter_learning::ParameterLearning; |
||||
use crate::process; |
||||
use crate::structure_learning::StructureLearningAlgorithm; |
||||
use crate::tools::Dataset; |
||||
|
||||
pub struct Cache<'a, P: ParameterLearning> { |
||||
parameter_learning: &'a P, |
||||
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>, |
||||
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>, |
||||
parent_set_size_small: usize, |
||||
} |
||||
|
||||
impl<'a, P: ParameterLearning> Cache<'a, P> { |
||||
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { |
||||
Cache { |
||||
parameter_learning, |
||||
cache_persistent_small: HashMap::new(), |
||||
cache_persistent_big: HashMap::new(), |
||||
parent_set_size_small: 0, |
||||
} |
||||
} |
||||
pub fn fit<T: process::NetworkProcess>( |
||||
&mut self, |
||||
net: &T, |
||||
dataset: &Dataset, |
||||
node: usize, |
||||
parent_set: Option<BTreeSet<usize>>, |
||||
) -> Params { |
||||
let parent_set_len = parent_set.as_ref().unwrap().len(); |
||||
if parent_set_len > self.parent_set_size_small + 1 { |
||||
//self.cache_persistent_small = self.cache_persistent_big;
|
||||
mem::swap( |
||||
&mut self.cache_persistent_small, |
||||
&mut self.cache_persistent_big, |
||||
); |
||||
self.cache_persistent_big = HashMap::new(); |
||||
self.parent_set_size_small += 1; |
||||
} |
||||
|
||||
if parent_set_len > self.parent_set_size_small { |
||||
match self.cache_persistent_big.get(&parent_set) { |
||||
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
||||
// not cloning requires a minor and reasoned refactoring across the library
|
||||
Some(params) => params.clone(), |
||||
None => { |
||||
let params = |
||||
self.parameter_learning |
||||
.fit(net, dataset, node, parent_set.clone()); |
||||
self.cache_persistent_big.insert(parent_set, params.clone()); |
||||
params |
||||
} |
||||
} |
||||
} else { |
||||
match self.cache_persistent_small.get(&parent_set) { |
||||
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
||||
// not cloning requires a minor and reasoned refactoring across the library
|
||||
Some(params) => params.clone(), |
||||
None => { |
||||
let params = |
||||
self.parameter_learning |
||||
.fit(net, dataset, node, parent_set.clone()); |
||||
self.cache_persistent_small |
||||
.insert(parent_set, params.clone()); |
||||
params |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Continuous-Time Peter Clark algorithm.
|
||||
///
|
||||
/// A method to learn the structure of the network.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
|
||||
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test.
|
||||
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test.
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use std::collections::BTreeSet;
|
||||
/// # use ndarray::{arr1, arr2, arr3};
|
||||
/// # use reCTBN::params;
|
||||
/// # use reCTBN::tools::trajectory_generator;
|
||||
/// # use reCTBN::process::NetworkProcess;
|
||||
/// # use reCTBN::process::ctbn::CtbnNetwork;
|
||||
/// use reCTBN::parameter_learning::BayesianApproach;
|
||||
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
|
||||
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
|
||||
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
|
||||
/// #
|
||||
/// # // Create the domain for a discrete node
|
||||
/// # let mut domain = BTreeSet::new();
|
||||
/// # domain.insert(String::from("A"));
|
||||
/// # domain.insert(String::from("B"));
|
||||
/// # domain.insert(String::from("C"));
|
||||
/// # // Create the parameters for a discrete node using the domain
|
||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
|
||||
/// # //Create the node n1 using the parameters
|
||||
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
|
||||
/// #
|
||||
/// # let mut domain = BTreeSet::new();
|
||||
/// # domain.insert(String::from("D"));
|
||||
/// # domain.insert(String::from("E"));
|
||||
/// # domain.insert(String::from("F"));
|
||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
|
||||
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
|
||||
/// #
|
||||
/// # let mut domain = BTreeSet::new();
|
||||
/// # domain.insert(String::from("G"));
|
||||
/// # domain.insert(String::from("H"));
|
||||
/// # domain.insert(String::from("I"));
|
||||
/// # domain.insert(String::from("F"));
|
||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
|
||||
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
|
||||
/// #
|
||||
/// # // Initialize a ctbn
|
||||
/// # let mut net = CtbnNetwork::new();
|
||||
/// #
|
||||
/// # // Add the nodes and their edges
|
||||
/// # let n1 = net.add_node(n1).unwrap();
|
||||
/// # let n2 = net.add_node(n2).unwrap();
|
||||
/// # let n3 = net.add_node(n3).unwrap();
|
||||
/// # net.add_edge(n1, n2);
|
||||
/// # net.add_edge(n1, n3);
|
||||
/// # net.add_edge(n2, n3);
|
||||
/// #
|
||||
/// # match &mut net.get_node_mut(n1) {
|
||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||
/// # assert_eq!(
|
||||
/// # Ok(()),
|
||||
/// # param.set_cim(arr3(&[
|
||||
/// # [
|
||||
/// # [-3.0, 2.0, 1.0],
|
||||
/// # [1.5, -2.0, 0.5],
|
||||
/// # [0.4, 0.6, -1.0]
|
||||
/// # ],
|
||||
/// # ]))
|
||||
/// # );
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # match &mut net.get_node_mut(n2) {
|
||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||
/// # assert_eq!(
|
||||
/// # Ok(()),
|
||||
/// # param.set_cim(arr3(&[
|
||||
/// # [
|
||||
/// # [-1.0, 0.5, 0.5],
|
||||
/// # [3.0, -4.0, 1.0],
|
||||
/// # [0.9, 0.1, -1.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-6.0, 2.0, 4.0],
|
||||
/// # [1.5, -2.0, 0.5],
|
||||
/// # [3.0, 1.0, -4.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-1.0, 0.1, 0.9],
|
||||
/// # [2.0, -2.5, 0.5],
|
||||
/// # [0.9, 0.1, -1.0]
|
||||
/// # ],
|
||||
/// # ]))
|
||||
/// # );
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # match &mut net.get_node_mut(n3) {
|
||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
||||
/// # assert_eq!(
|
||||
/// # Ok(()),
|
||||
/// # param.set_cim(arr3(&[
|
||||
/// # [
|
||||
/// # [-1.0, 0.5, 0.3, 0.2],
|
||||
/// # [0.5, -4.0, 2.5, 1.0],
|
||||
/// # [2.5, 0.5, -4.0, 1.0],
|
||||
/// # [0.7, 0.2, 0.1, -1.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-6.0, 2.0, 3.0, 1.0],
|
||||
/// # [1.5, -3.0, 0.5, 1.0],
|
||||
/// # [2.0, 1.3, -5.0, 1.7],
|
||||
/// # [2.5, 0.5, 1.0, -4.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-1.3, 0.3, 0.1, 0.9],
|
||||
/// # [1.4, -4.0, 0.5, 2.1],
|
||||
/// # [1.0, 1.5, -3.0, 0.5],
|
||||
/// # [0.4, 0.3, 0.1, -0.8]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-2.0, 1.0, 0.7, 0.3],
|
||||
/// # [1.3, -5.9, 2.7, 1.9],
|
||||
/// # [2.0, 1.5, -4.0, 0.5],
|
||||
/// # [0.2, 0.7, 0.1, -1.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-6.0, 1.0, 2.0, 3.0],
|
||||
/// # [0.5, -3.0, 1.0, 1.5],
|
||||
/// # [1.4, 2.1, -4.3, 0.8],
|
||||
/// # [0.5, 1.0, 2.5, -4.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-1.3, 0.9, 0.3, 0.1],
|
||||
/// # [0.1, -1.3, 0.2, 1.0],
|
||||
/// # [0.5, 1.0, -3.0, 1.5],
|
||||
/// # [0.1, 0.4, 0.3, -0.8]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-2.0, 1.0, 0.6, 0.4],
|
||||
/// # [2.6, -7.1, 1.4, 3.1],
|
||||
/// # [5.0, 1.0, -8.0, 2.0],
|
||||
/// # [1.4, 0.4, 0.2, -2.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-3.0, 1.0, 1.5, 0.5],
|
||||
/// # [3.0, -6.0, 1.0, 2.0],
|
||||
/// # [0.3, 0.5, -1.9, 1.1],
|
||||
/// # [5.0, 1.0, 2.0, -8.0]
|
||||
/// # ],
|
||||
/// # [
|
||||
/// # [-2.6, 0.6, 0.2, 1.8],
|
||||
/// # [2.0, -6.0, 3.0, 1.0],
|
||||
/// # [0.1, 0.5, -1.3, 0.7],
|
||||
/// # [0.8, 0.6, 0.2, -1.6]
|
||||
/// # ],
|
||||
/// # ]))
|
||||
/// # );
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # // Generate the trajectory
|
||||
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
|
||||
///
|
||||
/// // Initialize the hypothesis tests to pass to the CTPC with their
|
||||
/// // respective significance level `alpha`
|
||||
/// let f = F::new(1e-6);
|
||||
/// let chi_sq = ChiSquare::new(1e-4);
|
||||
/// // Use the bayesian approach to learn the parameters
|
||||
/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
|
||||
///
|
||||
/// //Initialize CTPC
|
||||
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
|
||||
///
|
||||
/// // Learn the structure of the network from the generated trajectory
|
||||
/// let net = ctpc.fit_transform(net, &data);
|
||||
/// #
|
||||
/// # // Compare the generated network with the original one
|
||||
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
|
||||
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
|
||||
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
|
||||
/// ```
|
||||
pub struct CTPC<P: ParameterLearning> { |
||||
parameter_learning: P, |
||||
Ftest: F, |
||||
Chi2test: ChiSquare, |
||||
} |
||||
|
||||
impl<P: ParameterLearning> CTPC<P> { |
||||
pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> { |
||||
CTPC { |
||||
parameter_learning, |
||||
Ftest, |
||||
Chi2test, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> { |
||||
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
||||
where |
||||
T: process::NetworkProcess, |
||||
{ |
||||
//Check the coherence between dataset and network
|
||||
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { |
||||
panic!("Dataset and Network must have the same number of variables.") |
||||
} |
||||
|
||||
//Make the network mutable.
|
||||
let mut net = net; |
||||
|
||||
net.initialize_adj_matrix(); |
||||
|
||||
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![]; |
||||
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| { |
||||
let mut cache = Cache::new(&self.parameter_learning); |
||||
let mut candidate_parent_set: BTreeSet<usize> = net |
||||
.get_node_indices() |
||||
.into_iter() |
||||
.filter(|x| x != &child_node) |
||||
.collect(); |
||||
let mut separation_set_size = 0; |
||||
while separation_set_size < candidate_parent_set.len() { |
||||
let mut candidate_parent_set_TMP = candidate_parent_set.clone(); |
||||
for parent_node in candidate_parent_set.iter() { |
||||
for separation_set in candidate_parent_set |
||||
.iter() |
||||
.filter(|x| x != &parent_node) |
||||
.map(|x| *x) |
||||
.combinations(separation_set_size) |
||||
{ |
||||
let separation_set = separation_set.into_iter().collect(); |
||||
if self.Ftest.call( |
||||
&net, |
||||
child_node, |
||||
*parent_node, |
||||
&separation_set, |
||||
dataset, |
||||
&mut cache, |
||||
) && self.Chi2test.call( |
||||
&net, |
||||
child_node, |
||||
*parent_node, |
||||
&separation_set, |
||||
dataset, |
||||
&mut cache, |
||||
) { |
||||
candidate_parent_set_TMP.remove(parent_node); |
||||
break; |
||||
} |
||||
} |
||||
} |
||||
candidate_parent_set = candidate_parent_set_TMP; |
||||
separation_set_size += 1; |
||||
} |
||||
(child_node, candidate_parent_set) |
||||
})); |
||||
for (child_node, candidate_parent_set) in learned_parent_sets { |
||||
for parent_node in candidate_parent_set.iter() { |
||||
net.add_edge(*parent_node, child_node); |
||||
} |
||||
} |
||||
net |
||||
} |
||||
} |
||||
|
@ -0,0 +1,127 @@ |
||||
mod utils; |
||||
|
||||
use std::collections::BTreeSet; |
||||
|
||||
use reCTBN::{ |
||||
params, |
||||
params::ParamsTrait, |
||||
process::{ctmp::*, NetworkProcess}, |
||||
}; |
||||
use utils::generate_discrete_time_continous_node; |
||||
|
||||
#[test] |
||||
fn define_simple_ctmp() { |
||||
let _ = CtmpProcess::new(); |
||||
assert!(true); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_node_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_two_nodes_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
||||
|
||||
match n2 { |
||||
Ok(_) => assert!(false), |
||||
Err(_) => assert!(true), |
||||
}; |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn add_edge_to_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
||||
|
||||
net.add_edge(0, 1) |
||||
} |
||||
|
||||
#[test] |
||||
fn childen_and_parents() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
assert_eq!(0, net.get_parent_set(0).len()); |
||||
assert_eq!(0, net.get_children_set(0).len()); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_childen_panic() { |
||||
let net = CtmpProcess::new(); |
||||
net.get_children_set(0); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_childen_panic2() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
net.get_children_set(1); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_parent_panic() { |
||||
let net = CtmpProcess::new(); |
||||
net.get_parent_set(0); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn get_parent_panic2() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
net.get_parent_set(1); |
||||
} |
||||
|
||||
#[test] |
||||
fn compute_index_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node( |
||||
String::from("n1"), |
||||
10, |
||||
)) |
||||
.unwrap(); |
||||
|
||||
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); |
||||
assert_eq!(6, idx); |
||||
} |
||||
|
||||
#[test] |
||||
#[should_panic] |
||||
fn compute_index_from_custom_parent_set_ctmp() { |
||||
let mut net = CtmpProcess::new(); |
||||
let _n1 = net |
||||
.add_node(generate_discrete_time_continous_node( |
||||
String::from("n1"), |
||||
10, |
||||
)) |
||||
.unwrap(); |
||||
|
||||
let _idx = net.get_param_index_from_custom_parent_set( |
||||
&vec![params::StateType::Discrete(6)], |
||||
&BTreeSet::from([0]) |
||||
); |
||||
} |
@ -0,0 +1,122 @@ |
||||
mod utils; |
||||
|
||||
use approx::assert_abs_diff_eq; |
||||
use ndarray::*; |
||||
use reCTBN::{ |
||||
params, |
||||
process::{ctbn::*, NetworkProcess, NetworkProcessState}, |
||||
reward::{reward_evaluation::*, reward_function::*, *}, |
||||
}; |
||||
use utils::generate_discrete_time_continous_node; |
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_binary_node_mc() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1) |
||||
.assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1) |
||||
.assign(&arr1(&[3.0, 3.0])); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); |
||||
} |
||||
} |
||||
|
||||
net.initialize_adj_matrix(); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
|
||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); |
||||
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); |
||||
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); |
||||
|
||||
let rst = mc.evaluate_state_space(&net, &rf); |
||||
assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); |
||||
assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); |
||||
|
||||
|
||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215)); |
||||
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); |
||||
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); |
||||
|
||||
|
||||
} |
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_chain_mc() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
|
||||
let n3 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
||||
.unwrap(); |
||||
|
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n2, n3); |
||||
|
||||
match &mut net.get_node_mut(n1) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param |
||||
.set_cim(arr3(&[ |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-5.0, 5.0], [0.01, -0.01]], |
||||
])) |
||||
.unwrap(); |
||||
} |
||||
} |
||||
|
||||
|
||||
match &mut net.get_node_mut(n3) { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param |
||||
.set_cim(arr3(&[ |
||||
[[-0.01, 0.01], [5.0, -5.0]], |
||||
[[-5.0, 5.0], [0.01, -0.01]], |
||||
])) |
||||
.unwrap(); |
||||
} |
||||
} |
||||
|
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1) |
||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||
|
||||
rf.get_transition_reward_mut(n2) |
||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||
|
||||
rf.get_transition_reward_mut(n3) |
||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
||||
|
||||
let s000: NetworkProcessState = vec![ |
||||
params::StateType::Discrete(1), |
||||
params::StateType::Discrete(0), |
||||
params::StateType::Discrete(0), |
||||
]; |
||||
|
||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); |
||||
assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); |
||||
|
||||
let rst = mc.evaluate_state_space(&net, &rf); |
||||
assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1); |
||||
|
||||
} |
@ -0,0 +1,117 @@ |
||||
mod utils; |
||||
|
||||
use ndarray::*; |
||||
use utils::generate_discrete_time_continous_node; |
||||
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params}; |
||||
|
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_binary_node() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||
|
||||
assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn simple_factored_reward_function_ternary_node() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||
|
||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
||||
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; |
||||
|
||||
|
||||
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
||||
assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
||||
assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); |
||||
assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); |
||||
} |
||||
|
||||
#[test] |
||||
fn factored_reward_function_two_nodes() { |
||||
let mut net = CtbnNetwork::new(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
|
||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
||||
|
||||
|
||||
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
||||
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); |
||||
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; |
||||
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; |
||||
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; |
||||
|
||||
|
||||
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; |
||||
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; |
||||
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; |
||||
|
||||
assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||
assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); |
||||
assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); |
||||
assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); |
||||
assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); |
||||
assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||
assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); |
||||
assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
||||
|
||||
|
||||
assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); |
||||
assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); |
||||
assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); |
||||
} |
Loading…
Reference in new issue