Compare commits
3 Commits
dev
...
62-feature
Author | SHA1 | Date |
---|---|---|
AlessandroBregoli | e1af14b620 | 2 years ago |
AlessandroBregoli | aa630bf9e9 | 2 years ago |
AlessandroBregoli | 3dae67a80c | 2 years ago |
@ -1,3 +1,4 @@ |
|||||||
/target |
/target |
||||||
Cargo.lock |
Cargo.lock |
||||||
.vscode |
.vscode |
||||||
|
poetry.lock |
||||||
|
@ -1,176 +0,0 @@ |
|||||||
Apache License |
|
||||||
Version 2.0, January 2004 |
|
||||||
http://www.apache.org/licenses/ |
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION |
|
||||||
|
|
||||||
1. Definitions. |
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction, |
|
||||||
and distribution as defined by Sections 1 through 9 of this document. |
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by |
|
||||||
the copyright owner that is granting the License. |
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all |
|
||||||
other entities that control, are controlled by, or are under common |
|
||||||
control with that entity. For the purposes of this definition, |
|
||||||
"control" means (i) the power, direct or indirect, to cause the |
|
||||||
direction or management of such entity, whether by contract or |
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the |
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity. |
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity |
|
||||||
exercising permissions granted by this License. |
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications, |
|
||||||
including but not limited to software source code, documentation |
|
||||||
source, and configuration files. |
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical |
|
||||||
transformation or translation of a Source form, including but |
|
||||||
not limited to compiled object code, generated documentation, |
|
||||||
and conversions to other media types. |
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or |
|
||||||
Object form, made available under the License, as indicated by a |
|
||||||
copyright notice that is included in or attached to the work |
|
||||||
(an example is provided in the Appendix below). |
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object |
|
||||||
form, that is based on (or derived from) the Work and for which the |
|
||||||
editorial revisions, annotations, elaborations, or other modifications |
|
||||||
represent, as a whole, an original work of authorship. For the purposes |
|
||||||
of this License, Derivative Works shall not include works that remain |
|
||||||
separable from, or merely link (or bind by name) to the interfaces of, |
|
||||||
the Work and Derivative Works thereof. |
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including |
|
||||||
the original version of the Work and any modifications or additions |
|
||||||
to that Work or Derivative Works thereof, that is intentionally |
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner |
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of |
|
||||||
the copyright owner. For the purposes of this definition, "submitted" |
|
||||||
means any form of electronic, verbal, or written communication sent |
|
||||||
to the Licensor or its representatives, including but not limited to |
|
||||||
communication on electronic mailing lists, source code control systems, |
|
||||||
and issue tracking systems that are managed by, or on behalf of, the |
|
||||||
Licensor for the purpose of discussing and improving the Work, but |
|
||||||
excluding communication that is conspicuously marked or otherwise |
|
||||||
designated in writing by the copyright owner as "Not a Contribution." |
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity |
|
||||||
on behalf of whom a Contribution has been received by Licensor and |
|
||||||
subsequently incorporated within the Work. |
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of |
|
||||||
this License, each Contributor hereby grants to You a perpetual, |
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
|
||||||
copyright license to reproduce, prepare Derivative Works of, |
|
||||||
publicly display, publicly perform, sublicense, and distribute the |
|
||||||
Work and such Derivative Works in Source or Object form. |
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of |
|
||||||
this License, each Contributor hereby grants to You a perpetual, |
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
|
||||||
(except as stated in this section) patent license to make, have made, |
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work, |
|
||||||
where such license applies only to those patent claims licensable |
|
||||||
by such Contributor that are necessarily infringed by their |
|
||||||
Contribution(s) alone or by combination of their Contribution(s) |
|
||||||
with the Work to which such Contribution(s) was submitted. If You |
|
||||||
institute patent litigation against any entity (including a |
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work |
|
||||||
or a Contribution incorporated within the Work constitutes direct |
|
||||||
or contributory patent infringement, then any patent licenses |
|
||||||
granted to You under this License for that Work shall terminate |
|
||||||
as of the date such litigation is filed. |
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the |
|
||||||
Work or Derivative Works thereof in any medium, with or without |
|
||||||
modifications, and in Source or Object form, provided that You |
|
||||||
meet the following conditions: |
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or |
|
||||||
Derivative Works a copy of this License; and |
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices |
|
||||||
stating that You changed the files; and |
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works |
|
||||||
that You distribute, all copyright, patent, trademark, and |
|
||||||
attribution notices from the Source form of the Work, |
|
||||||
excluding those notices that do not pertain to any part of |
|
||||||
the Derivative Works; and |
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its |
|
||||||
distribution, then any Derivative Works that You distribute must |
|
||||||
include a readable copy of the attribution notices contained |
|
||||||
within such NOTICE file, excluding those notices that do not |
|
||||||
pertain to any part of the Derivative Works, in at least one |
|
||||||
of the following places: within a NOTICE text file distributed |
|
||||||
as part of the Derivative Works; within the Source form or |
|
||||||
documentation, if provided along with the Derivative Works; or, |
|
||||||
within a display generated by the Derivative Works, if and |
|
||||||
wherever such third-party notices normally appear. The contents |
|
||||||
of the NOTICE file are for informational purposes only and |
|
||||||
do not modify the License. You may add Your own attribution |
|
||||||
notices within Derivative Works that You distribute, alongside |
|
||||||
or as an addendum to the NOTICE text from the Work, provided |
|
||||||
that such additional attribution notices cannot be construed |
|
||||||
as modifying the License. |
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and |
|
||||||
may provide additional or different license terms and conditions |
|
||||||
for use, reproduction, or distribution of Your modifications, or |
|
||||||
for any such Derivative Works as a whole, provided Your use, |
|
||||||
reproduction, and distribution of the Work otherwise complies with |
|
||||||
the conditions stated in this License. |
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise, |
|
||||||
any Contribution intentionally submitted for inclusion in the Work |
|
||||||
by You to the Licensor shall be under the terms and conditions of |
|
||||||
this License, without any additional terms or conditions. |
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify |
|
||||||
the terms of any separate license agreement you may have executed |
|
||||||
with Licensor regarding such Contributions. |
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade |
|
||||||
names, trademarks, service marks, or product names of the Licensor, |
|
||||||
except as required for reasonable and customary use in describing the |
|
||||||
origin of the Work and reproducing the content of the NOTICE file. |
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or |
|
||||||
agreed to in writing, Licensor provides the Work (and each |
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS, |
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|
||||||
implied, including, without limitation, any warranties or conditions |
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A |
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the |
|
||||||
appropriateness of using or redistributing the Work and assume any |
|
||||||
risks associated with Your exercise of permissions under this License. |
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory, |
|
||||||
whether in tort (including negligence), contract, or otherwise, |
|
||||||
unless required by applicable law (such as deliberate and grossly |
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be |
|
||||||
liable to You for damages, including any direct, indirect, special, |
|
||||||
incidental, or consequential damages of any character arising as a |
|
||||||
result of this License or out of the use or inability to use the |
|
||||||
Work (including but not limited to damages for loss of goodwill, |
|
||||||
work stoppage, computer failure or malfunction, or any and all |
|
||||||
other commercial damages or losses), even if such Contributor |
|
||||||
has been advised of the possibility of such damages. |
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing |
|
||||||
the Work or Derivative Works thereof, You may choose to offer, |
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity, |
|
||||||
or other liability obligations and/or rights consistent with this |
|
||||||
License. However, in accepting such obligations, You may act only |
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf |
|
||||||
of any other Contributor, and only if You agree to indemnify, |
|
||||||
defend, and hold each Contributor harmless for any liability |
|
||||||
incurred by, or claims asserted against, such Contributor by reason |
|
||||||
of your accepting any such warranty or additional liability. |
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS |
|
@ -1,23 +0,0 @@ |
|||||||
Permission is hereby granted, free of charge, to any |
|
||||||
person obtaining a copy of this software and associated |
|
||||||
documentation files (the "Software"), to deal in the |
|
||||||
Software without restriction, including without |
|
||||||
limitation the rights to use, copy, modify, merge, |
|
||||||
publish, distribute, sublicense, and/or sell copies of |
|
||||||
the Software, and to permit persons to whom the Software |
|
||||||
is furnished to do so, subject to the following |
|
||||||
conditions: |
|
||||||
|
|
||||||
The above copyright notice and this permission notice |
|
||||||
shall be included in all copies or substantial portions |
|
||||||
of the Software. |
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF |
|
||||||
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED |
|
||||||
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A |
|
||||||
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT |
|
||||||
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
|
||||||
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION |
|
||||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR |
|
||||||
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
||||||
DEALINGS IN THE SOFTWARE. |
|
@ -0,0 +1,16 @@ |
|||||||
|
[tool.poetry] |
||||||
|
name = "rectbnpy" |
||||||
|
version = "0.1.0" |
||||||
|
description = "" |
||||||
|
authors = ["AlessandroBregoli <alessandroxciv@gmail.com>"] |
||||||
|
readme = "README.md" |
||||||
|
|
||||||
|
[tool.poetry.dependencies] |
||||||
|
python = "^3.10" |
||||||
|
maturin = "^0.13.3" |
||||||
|
numpy = "^1.23.3" |
||||||
|
|
||||||
|
|
||||||
|
[build-system] |
||||||
|
requires = ["poetry-core"] |
||||||
|
build-backend = "poetry.core.masonry.api" |
@ -0,0 +1,162 @@ |
|||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use ndarray::prelude::*; |
||||||
|
|
||||||
|
use crate::network; |
||||||
|
use crate::params::{Params, ParamsTrait, StateType}; |
||||||
|
|
||||||
|
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
|
||||||
|
///composed by the following elements:
|
||||||
|
///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
|
||||||
|
///- **nodes**: a vector containing all the nodes and their parameters.
|
||||||
|
///The index of a node inside the vector is also used as index for the adj_matrix.
|
||||||
|
///
|
||||||
|
///# Examples
|
||||||
|
///
|
||||||
|
///```
|
||||||
|
///
|
||||||
|
/// use std::collections::BTreeSet;
|
||||||
|
/// use reCTBN::network::Network;
|
||||||
|
/// use reCTBN::params;
|
||||||
|
/// use reCTBN::ctbn::*;
|
||||||
|
///
|
||||||
|
/// //Create the domain for a discrete node
|
||||||
|
/// let mut domain = BTreeSet::new();
|
||||||
|
/// domain.insert(String::from("A"));
|
||||||
|
/// domain.insert(String::from("B"));
|
||||||
|
///
|
||||||
|
/// //Create the parameters for a discrete node using the domain
|
||||||
|
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
||||||
|
///
|
||||||
|
/// //Create the node using the parameters
|
||||||
|
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
||||||
|
///
|
||||||
|
/// let mut domain = BTreeSet::new();
|
||||||
|
/// domain.insert(String::from("A"));
|
||||||
|
/// domain.insert(String::from("B"));
|
||||||
|
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
||||||
|
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
||||||
|
///
|
||||||
|
/// //Initialize a ctbn
|
||||||
|
/// let mut net = CtbnNetwork::new();
|
||||||
|
///
|
||||||
|
/// //Add nodes
|
||||||
|
/// let X1 = net.add_node(X1).unwrap();
|
||||||
|
/// let X2 = net.add_node(X2).unwrap();
|
||||||
|
///
|
||||||
|
/// //Add an edge
|
||||||
|
/// net.add_edge(X1, X2);
|
||||||
|
///
|
||||||
|
/// //Get all the children of node X1
|
||||||
|
/// let cs = net.get_children_set(X1);
|
||||||
|
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
||||||
|
/// ```
|
||||||
|
pub struct CtbnNetwork { |
||||||
|
adj_matrix: Option<Array2<u16>>, |
||||||
|
nodes: Vec<Params>, |
||||||
|
} |
||||||
|
|
||||||
|
impl CtbnNetwork { |
||||||
|
pub fn new() -> CtbnNetwork { |
||||||
|
CtbnNetwork { |
||||||
|
adj_matrix: None, |
||||||
|
nodes: Vec::new(), |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl network::Network for CtbnNetwork { |
||||||
|
fn initialize_adj_matrix(&mut self) { |
||||||
|
self.adj_matrix = Some(Array2::<u16>::zeros( |
||||||
|
(self.nodes.len(), self.nodes.len()).f(), |
||||||
|
)); |
||||||
|
} |
||||||
|
|
||||||
|
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> { |
||||||
|
n.reset_params(); |
||||||
|
self.adj_matrix = Option::None; |
||||||
|
self.nodes.push(n); |
||||||
|
Ok(self.nodes.len() - 1) |
||||||
|
} |
||||||
|
|
||||||
|
fn add_edge(&mut self, parent: usize, child: usize) { |
||||||
|
if let None = self.adj_matrix { |
||||||
|
self.initialize_adj_matrix(); |
||||||
|
} |
||||||
|
|
||||||
|
if let Some(network) = &mut self.adj_matrix { |
||||||
|
network[[parent, child]] = 1; |
||||||
|
self.nodes[child].reset_params(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node_indices(&self) -> std::ops::Range<usize> { |
||||||
|
0..self.nodes.len() |
||||||
|
} |
||||||
|
|
||||||
|
fn get_number_of_nodes(&self) -> usize { |
||||||
|
self.nodes.len() |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node(&self, node_idx: usize) -> &Params { |
||||||
|
&self.nodes[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { |
||||||
|
&mut self.nodes[node_idx] |
||||||
|
} |
||||||
|
|
||||||
|
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize { |
||||||
|
self.adj_matrix |
||||||
|
.as_ref() |
||||||
|
.unwrap() |
||||||
|
.column(node) |
||||||
|
.iter() |
||||||
|
.enumerate() |
||||||
|
.fold((0, 1), |mut acc, x| { |
||||||
|
if x.1 > &0 { |
||||||
|
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
||||||
|
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
||||||
|
} |
||||||
|
acc |
||||||
|
}) |
||||||
|
.0 |
||||||
|
} |
||||||
|
|
||||||
|
fn get_param_index_from_custom_parent_set( |
||||||
|
&self, |
||||||
|
current_state: &Vec<StateType>, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
) -> usize { |
||||||
|
parent_set |
||||||
|
.iter() |
||||||
|
.fold((0, 1), |mut acc, x| { |
||||||
|
acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; |
||||||
|
acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); |
||||||
|
acc |
||||||
|
}) |
||||||
|
.0 |
||||||
|
} |
||||||
|
|
||||||
|
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
|
self.adj_matrix |
||||||
|
.as_ref() |
||||||
|
.unwrap() |
||||||
|
.column(node) |
||||||
|
.iter() |
||||||
|
.enumerate() |
||||||
|
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||||
|
.collect() |
||||||
|
} |
||||||
|
|
||||||
|
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
|
self.adj_matrix |
||||||
|
.as_ref() |
||||||
|
.unwrap() |
||||||
|
.row(node) |
||||||
|
.iter() |
||||||
|
.enumerate() |
||||||
|
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
||||||
|
.collect() |
||||||
|
} |
||||||
|
} |
@ -1,12 +1,11 @@ |
|||||||
#![doc = include_str!("../../README.md")] |
|
||||||
#![allow(non_snake_case)] |
#![allow(non_snake_case)] |
||||||
#[cfg(test)] |
#[cfg(test)] |
||||||
extern crate approx; |
extern crate approx; |
||||||
|
|
||||||
|
pub mod ctbn; |
||||||
|
pub mod network; |
||||||
pub mod parameter_learning; |
pub mod parameter_learning; |
||||||
pub mod params; |
pub mod params; |
||||||
pub mod process; |
|
||||||
pub mod reward; |
|
||||||
pub mod sampling; |
pub mod sampling; |
||||||
pub mod structure_learning; |
pub mod structure_learning; |
||||||
pub mod tools; |
pub mod tools; |
||||||
|
@ -0,0 +1,43 @@ |
|||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use thiserror::Error; |
||||||
|
|
||||||
|
use crate::params; |
||||||
|
|
||||||
|
/// Error types for trait Network
|
||||||
|
#[derive(Error, Debug)] |
||||||
|
pub enum NetworkError { |
||||||
|
#[error("Error during node insertion")] |
||||||
|
NodeInsertionError(String), |
||||||
|
} |
||||||
|
|
||||||
|
///Network
|
||||||
|
///The Network trait define the required methods for a structure used as pgm (such as ctbn).
|
||||||
|
pub trait Network { |
||||||
|
fn initialize_adj_matrix(&mut self); |
||||||
|
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; |
||||||
|
fn add_edge(&mut self, parent: usize, child: usize); |
||||||
|
|
||||||
|
///Get all the indices of the nodes contained inside the network
|
||||||
|
fn get_node_indices(&self) -> std::ops::Range<usize>; |
||||||
|
fn get_number_of_nodes(&self) -> usize; |
||||||
|
fn get_node(&self, node_idx: usize) -> ¶ms::Params; |
||||||
|
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; |
||||||
|
|
||||||
|
///Compute the index that must be used to access the parameters of a node given a specific
|
||||||
|
///configuration of the network. Usually, the only values really used in *current_state* are
|
||||||
|
///the ones in the parent set of the *node*.
|
||||||
|
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) |
||||||
|
-> usize; |
||||||
|
|
||||||
|
///Compute the index that must be used to access the parameters of a node given a specific
|
||||||
|
///configuration of the network and a generic parent_set. Usually, the only values really used
|
||||||
|
///in *current_state* are the ones in the parent set of the *node*.
|
||||||
|
fn get_param_index_from_custom_parent_set( |
||||||
|
&self, |
||||||
|
current_state: &Vec<params::StateType>, |
||||||
|
parent_set: &BTreeSet<usize>, |
||||||
|
) -> usize; |
||||||
|
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; |
||||||
|
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; |
||||||
|
} |
@ -1,120 +0,0 @@ |
|||||||
//! Defines methods for dealing with Probabilistic Graphical Models like the CTBNs
|
|
||||||
|
|
||||||
pub mod ctbn; |
|
||||||
pub mod ctmp; |
|
||||||
|
|
||||||
use std::collections::BTreeSet; |
|
||||||
|
|
||||||
use thiserror::Error; |
|
||||||
|
|
||||||
use crate::params; |
|
||||||
|
|
||||||
/// Error types for trait Network
|
|
||||||
#[derive(Error, Debug)] |
|
||||||
pub enum NetworkError { |
|
||||||
#[error("Error during node insertion")] |
|
||||||
NodeInsertionError(String), |
|
||||||
} |
|
||||||
|
|
||||||
/// This type is used to represent a specific realization of a generic NetworkProcess
|
|
||||||
pub type NetworkProcessState = Vec<params::StateType>; |
|
||||||
|
|
||||||
/// It defines the required methods for a structure used as a Probabilistic Graphical Models (such
|
|
||||||
/// as a CTBN).
|
|
||||||
pub trait NetworkProcess: Sync { |
|
||||||
fn initialize_adj_matrix(&mut self); |
|
||||||
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>; |
|
||||||
/// Add an **directed edge** between a two nodes of the network.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `parent` - parent node.
|
|
||||||
/// * `child` - child node.
|
|
||||||
fn add_edge(&mut self, parent: usize, child: usize); |
|
||||||
|
|
||||||
/// Get all the indices of the nodes contained inside the network.
|
|
||||||
fn get_node_indices(&self) -> std::ops::Range<usize>; |
|
||||||
|
|
||||||
/// Get the numbers of nodes contained in the network.
|
|
||||||
fn get_number_of_nodes(&self) -> usize; |
|
||||||
|
|
||||||
/// Get the **node param**.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `node_idx` - node index value.
|
|
||||||
///
|
|
||||||
/// # Return
|
|
||||||
///
|
|
||||||
/// * The selected **node param**.
|
|
||||||
fn get_node(&self, node_idx: usize) -> ¶ms::Params; |
|
||||||
|
|
||||||
/// Get the **node param**.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `node_idx` - node index value.
|
|
||||||
///
|
|
||||||
/// # Return
|
|
||||||
///
|
|
||||||
/// * The selected **node mutable param**.
|
|
||||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; |
|
||||||
|
|
||||||
/// Compute the index that must be used to access the parameters of a `node`, given a specific
|
|
||||||
/// configuration of the network.
|
|
||||||
///
|
|
||||||
/// Usually, the only values really used in `current_state` are the ones in the parent set of
|
|
||||||
/// the `node`.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `node` - selected node.
|
|
||||||
/// * `current_state` - current configuration of the network.
|
|
||||||
///
|
|
||||||
/// # Return
|
|
||||||
///
|
|
||||||
/// * Index of the `node` relative to the network.
|
|
||||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize; |
|
||||||
|
|
||||||
/// Compute the index that must be used to access the parameters of a `node`, given a specific
|
|
||||||
/// configuration of the network and a generic `parent_set`.
|
|
||||||
///
|
|
||||||
/// Usually, the only values really used in `current_state` are the ones in the parent set of
|
|
||||||
/// the `node`.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `current_state` - current configuration of the network.
|
|
||||||
/// * `parent_set` - parent set of the selected `node`.
|
|
||||||
///
|
|
||||||
/// # Return
|
|
||||||
///
|
|
||||||
/// * Index of the `node` relative to the network.
|
|
||||||
fn get_param_index_from_custom_parent_set( |
|
||||||
&self, |
|
||||||
current_state: &Vec<params::StateType>, |
|
||||||
parent_set: &BTreeSet<usize>, |
|
||||||
) -> usize; |
|
||||||
|
|
||||||
/// Get the **parent set** of a given **node**.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `node` - node index value.
|
|
||||||
///
|
|
||||||
/// # Return
|
|
||||||
///
|
|
||||||
/// * The **parent set** of the selected node.
|
|
||||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>; |
|
||||||
|
|
||||||
/// Get the **children set** of a given **node**.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `node` - node index value.
|
|
||||||
///
|
|
||||||
/// # Return
|
|
||||||
///
|
|
||||||
/// * The **children set** of the selected node.
|
|
||||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize>; |
|
||||||
} |
|
@ -1,247 +0,0 @@ |
|||||||
//! Continuous Time Bayesian Network
|
|
||||||
|
|
||||||
use std::collections::BTreeSet; |
|
||||||
|
|
||||||
use ndarray::prelude::*; |
|
||||||
|
|
||||||
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType}; |
|
||||||
use crate::process; |
|
||||||
|
|
||||||
use super::ctmp::CtmpProcess; |
|
||||||
use super::{NetworkProcess, NetworkProcessState}; |
|
||||||
|
|
||||||
/// It represents both the structure and the parameters of a CTBN.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
|
|
||||||
/// * `nodes` - A vector containing all the nodes and their parameters.
|
|
||||||
///
|
|
||||||
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// use std::collections::BTreeSet;
|
|
||||||
/// use reCTBN::process::NetworkProcess;
|
|
||||||
/// use reCTBN::params;
|
|
||||||
/// use reCTBN::process::ctbn::*;
|
|
||||||
///
|
|
||||||
/// //Create the domain for a discrete node
|
|
||||||
/// let mut domain = BTreeSet::new();
|
|
||||||
/// domain.insert(String::from("A"));
|
|
||||||
/// domain.insert(String::from("B"));
|
|
||||||
///
|
|
||||||
/// //Create the parameters for a discrete node using the domain
|
|
||||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
|
|
||||||
///
|
|
||||||
/// //Create the node using the parameters
|
|
||||||
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
|
|
||||||
///
|
|
||||||
/// let mut domain = BTreeSet::new();
|
|
||||||
/// domain.insert(String::from("A"));
|
|
||||||
/// domain.insert(String::from("B"));
|
|
||||||
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
|
|
||||||
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
|
|
||||||
///
|
|
||||||
/// //Initialize a ctbn
|
|
||||||
/// let mut net = CtbnNetwork::new();
|
|
||||||
///
|
|
||||||
/// //Add nodes
|
|
||||||
/// let X1 = net.add_node(X1).unwrap();
|
|
||||||
/// let X2 = net.add_node(X2).unwrap();
|
|
||||||
///
|
|
||||||
/// //Add an edge
|
|
||||||
/// net.add_edge(X1, X2);
|
|
||||||
///
|
|
||||||
/// //Get all the children of node X1
|
|
||||||
/// let cs = net.get_children_set(X1);
|
|
||||||
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
|
||||||
/// ```
|
|
||||||
pub struct CtbnNetwork { |
|
||||||
adj_matrix: Option<Array2<u16>>, |
|
||||||
nodes: Vec<Params>, |
|
||||||
} |
|
||||||
|
|
||||||
impl CtbnNetwork { |
|
||||||
pub fn new() -> CtbnNetwork { |
|
||||||
CtbnNetwork { |
|
||||||
adj_matrix: None, |
|
||||||
nodes: Vec::new(), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
///Transform the **CTBN** into a **CTMP**
|
|
||||||
///
|
|
||||||
/// # Return
|
|
||||||
///
|
|
||||||
/// * The equivalent *CtmpProcess* computed from the current CtbnNetwork
|
|
||||||
pub fn amalgamation(&self) -> CtmpProcess { |
|
||||||
let variables_domain = |
|
||||||
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent())); |
|
||||||
|
|
||||||
let state_space = variables_domain.product(); |
|
||||||
let variables_set = BTreeSet::from_iter(self.get_node_indices()); |
|
||||||
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space)); |
|
||||||
|
|
||||||
for idx_current_state in 0..state_space { |
|
||||||
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state); |
|
||||||
let current_state_statetype: NetworkProcessState = current_state |
|
||||||
.iter() |
|
||||||
.map(|x| StateType::Discrete(*x)) |
|
||||||
.collect(); |
|
||||||
for idx_node in 0..self.nodes.len() { |
|
||||||
let p = match self.get_node(idx_node) { |
|
||||||
Params::DiscreteStatesContinousTime(p) => p, |
|
||||||
}; |
|
||||||
for next_node_state in 0..variables_domain[idx_node] { |
|
||||||
let mut next_state = current_state.clone(); |
|
||||||
next_state[idx_node] = next_node_state; |
|
||||||
|
|
||||||
let next_state_statetype: NetworkProcessState = |
|
||||||
next_state.iter().map(|x| StateType::Discrete(*x)).collect(); |
|
||||||
let idx_next_state = self.get_param_index_from_custom_parent_set( |
|
||||||
&next_state_statetype, |
|
||||||
&variables_set, |
|
||||||
); |
|
||||||
amalgamated_cim[[0, idx_current_state, idx_next_state]] += |
|
||||||
p.get_cim().as_ref().unwrap()[[ |
|
||||||
self.get_param_index_network(idx_node, ¤t_state_statetype), |
|
||||||
current_state[idx_node], |
|
||||||
next_node_state, |
|
||||||
]]; |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new( |
|
||||||
"ctmp".to_string(), |
|
||||||
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())), |
|
||||||
); |
|
||||||
|
|
||||||
amalgamated_param.set_cim(amalgamated_cim).unwrap(); |
|
||||||
|
|
||||||
let mut ctmp = CtmpProcess::new(); |
|
||||||
|
|
||||||
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param)) |
|
||||||
.unwrap(); |
|
||||||
return ctmp; |
|
||||||
} |
|
||||||
|
|
||||||
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> { |
|
||||||
let mut state = state; |
|
||||||
let mut array_state = Array1::zeros(variables_domain.shape()[0]); |
|
||||||
for (idx, var) in variables_domain.indexed_iter() { |
|
||||||
array_state[idx] = state % var; |
|
||||||
state = state / var; |
|
||||||
} |
|
||||||
|
|
||||||
return array_state; |
|
||||||
} |
|
||||||
/// Get the Adjacency Matrix.
|
|
||||||
pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> { |
|
||||||
self.adj_matrix.as_ref() |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
impl process::NetworkProcess for CtbnNetwork { |
|
||||||
/// Initialize an Adjacency matrix.
|
|
||||||
fn initialize_adj_matrix(&mut self) { |
|
||||||
self.adj_matrix = Some(Array2::<u16>::zeros( |
|
||||||
(self.nodes.len(), self.nodes.len()).f(), |
|
||||||
)); |
|
||||||
} |
|
||||||
|
|
||||||
/// Add a new node.
|
|
||||||
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> { |
|
||||||
n.reset_params(); |
|
||||||
self.adj_matrix = Option::None; |
|
||||||
self.nodes.push(n); |
|
||||||
Ok(self.nodes.len() - 1) |
|
||||||
} |
|
||||||
|
|
||||||
/// Connect two nodes with a new edge.
|
|
||||||
fn add_edge(&mut self, parent: usize, child: usize) { |
|
||||||
if let None = self.adj_matrix { |
|
||||||
self.initialize_adj_matrix(); |
|
||||||
} |
|
||||||
|
|
||||||
if let Some(network) = &mut self.adj_matrix { |
|
||||||
network[[parent, child]] = 1; |
|
||||||
self.nodes[child].reset_params(); |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
|
||||||
0..self.nodes.len() |
|
||||||
} |
|
||||||
|
|
||||||
/// Get the number of nodes of the network.
|
|
||||||
fn get_number_of_nodes(&self) -> usize { |
|
||||||
self.nodes.len() |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node(&self, node_idx: usize) -> &Params { |
|
||||||
&self.nodes[node_idx] |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params { |
|
||||||
&mut self.nodes[node_idx] |
|
||||||
} |
|
||||||
|
|
||||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
|
||||||
self.adj_matrix |
|
||||||
.as_ref() |
|
||||||
.unwrap() |
|
||||||
.column(node) |
|
||||||
.iter() |
|
||||||
.enumerate() |
|
||||||
.fold((0, 1), |mut acc, x| { |
|
||||||
if x.1 > &0 { |
|
||||||
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
|
||||||
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
|
||||||
} |
|
||||||
acc |
|
||||||
}) |
|
||||||
.0 |
|
||||||
} |
|
||||||
|
|
||||||
fn get_param_index_from_custom_parent_set( |
|
||||||
&self, |
|
||||||
current_state: &NetworkProcessState, |
|
||||||
parent_set: &BTreeSet<usize>, |
|
||||||
) -> usize { |
|
||||||
parent_set |
|
||||||
.iter() |
|
||||||
.fold((0, 1), |mut acc, x| { |
|
||||||
acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; |
|
||||||
acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); |
|
||||||
acc |
|
||||||
}) |
|
||||||
.0 |
|
||||||
} |
|
||||||
|
|
||||||
/// Get all the parents of the given node.
|
|
||||||
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
|
||||||
self.adj_matrix |
|
||||||
.as_ref() |
|
||||||
.unwrap() |
|
||||||
.column(node) |
|
||||||
.iter() |
|
||||||
.enumerate() |
|
||||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
|
||||||
.collect() |
|
||||||
} |
|
||||||
|
|
||||||
/// Get all the children of the given node.
|
|
||||||
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
|
||||||
self.adj_matrix |
|
||||||
.as_ref() |
|
||||||
.unwrap() |
|
||||||
.row(node) |
|
||||||
.iter() |
|
||||||
.enumerate() |
|
||||||
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None }) |
|
||||||
.collect() |
|
||||||
} |
|
||||||
} |
|
@ -1,114 +0,0 @@ |
|||||||
use std::collections::BTreeSet; |
|
||||||
|
|
||||||
use crate::{ |
|
||||||
params::{Params, StateType}, |
|
||||||
process, |
|
||||||
}; |
|
||||||
|
|
||||||
use super::{NetworkProcess, NetworkProcessState}; |
|
||||||
|
|
||||||
pub struct CtmpProcess { |
|
||||||
param: Option<Params>, |
|
||||||
} |
|
||||||
|
|
||||||
impl CtmpProcess { |
|
||||||
pub fn new() -> CtmpProcess { |
|
||||||
CtmpProcess { param: None } |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
impl NetworkProcess for CtmpProcess { |
|
||||||
fn initialize_adj_matrix(&mut self) { |
|
||||||
unimplemented!("CtmpProcess has only one node") |
|
||||||
} |
|
||||||
|
|
||||||
fn add_node(&mut self, n: crate::params::Params) -> Result<usize, process::NetworkError> { |
|
||||||
match self.param { |
|
||||||
None => { |
|
||||||
self.param = Some(n); |
|
||||||
Ok(0) |
|
||||||
} |
|
||||||
Some(_) => Err(process::NetworkError::NodeInsertionError( |
|
||||||
"CtmpProcess has only one node".to_string(), |
|
||||||
)), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn add_edge(&mut self, _parent: usize, _child: usize) { |
|
||||||
unimplemented!("CtmpProcess has only one node") |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node_indices(&self) -> std::ops::Range<usize> { |
|
||||||
match self.param { |
|
||||||
None => 0..0, |
|
||||||
Some(_) => 0..1, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn get_number_of_nodes(&self) -> usize { |
|
||||||
match self.param { |
|
||||||
None => 0, |
|
||||||
Some(_) => 1, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node(&self, node_idx: usize) -> &crate::params::Params { |
|
||||||
if node_idx == 0 { |
|
||||||
self.param.as_ref().unwrap() |
|
||||||
} else { |
|
||||||
unimplemented!("CtmpProcess has only one node") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn get_node_mut(&mut self, node_idx: usize) -> &mut crate::params::Params { |
|
||||||
if node_idx == 0 { |
|
||||||
self.param.as_mut().unwrap() |
|
||||||
} else { |
|
||||||
unimplemented!("CtmpProcess has only one node") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize { |
|
||||||
if node == 0 { |
|
||||||
match current_state[0] { |
|
||||||
StateType::Discrete(x) => x, |
|
||||||
} |
|
||||||
} else { |
|
||||||
unimplemented!("CtmpProcess has only one node") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn get_param_index_from_custom_parent_set( |
|
||||||
&self, |
|
||||||
_current_state: &NetworkProcessState, |
|
||||||
_parent_set: &BTreeSet<usize>, |
|
||||||
) -> usize { |
|
||||||
unimplemented!("CtmpProcess has only one node") |
|
||||||
} |
|
||||||
|
|
||||||
fn get_parent_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
|
||||||
match self.param { |
|
||||||
Some(_) => { |
|
||||||
if node == 0 { |
|
||||||
BTreeSet::new() |
|
||||||
} else { |
|
||||||
unimplemented!("CtmpProcess has only one node") |
|
||||||
} |
|
||||||
} |
|
||||||
None => panic!("Uninitialized CtmpProcess"), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn get_children_set(&self, node: usize) -> std::collections::BTreeSet<usize> { |
|
||||||
match self.param { |
|
||||||
Some(_) => { |
|
||||||
if node == 0 { |
|
||||||
BTreeSet::new() |
|
||||||
} else { |
|
||||||
unimplemented!("CtmpProcess has only one node") |
|
||||||
} |
|
||||||
} |
|
||||||
None => panic!("Uninitialized CtmpProcess"), |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
@ -1,59 +0,0 @@ |
|||||||
pub mod reward_evaluation; |
|
||||||
pub mod reward_function; |
|
||||||
|
|
||||||
use std::collections::HashMap; |
|
||||||
|
|
||||||
use crate::process; |
|
||||||
|
|
||||||
/// Instantiation of reward function and instantaneous reward
|
|
||||||
///
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `transition_reward`: reward obtained transitioning from one state to another
|
|
||||||
/// * `instantaneous_reward`: reward per unit of time obtained staying in a specific state
|
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)] |
|
||||||
pub struct Reward { |
|
||||||
pub transition_reward: f64, |
|
||||||
pub instantaneous_reward: f64, |
|
||||||
} |
|
||||||
|
|
||||||
/// The trait RewardFunction describe the methods that all the reward functions must satisfy
|
|
||||||
|
|
||||||
pub trait RewardFunction: Sync { |
|
||||||
/// Given the current state and the previous state, it compute the reward.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `current_state`: the current state of the network represented as a `process::NetworkProcessState`
|
|
||||||
/// * `previous_state`: an optional argument representing the previous state of the network
|
|
||||||
|
|
||||||
fn call( |
|
||||||
&self, |
|
||||||
current_state: &process::NetworkProcessState, |
|
||||||
previous_state: Option<&process::NetworkProcessState>, |
|
||||||
) -> Reward; |
|
||||||
|
|
||||||
/// Initialize the RewardFunction internal accordingly to the structure of a NetworkProcess
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `p`: any structure that implements the trait `process::NetworkProcess`
|
|
||||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self; |
|
||||||
} |
|
||||||
|
|
||||||
pub trait RewardEvaluation { |
|
||||||
fn evaluate_state_space<N: process::NetworkProcess, R: RewardFunction>( |
|
||||||
&self, |
|
||||||
network_process: &N, |
|
||||||
reward_function: &R, |
|
||||||
) -> HashMap<process::NetworkProcessState, f64>; |
|
||||||
|
|
||||||
fn evaluate_state<N: process::NetworkProcess, R: RewardFunction>( |
|
||||||
&self, |
|
||||||
network_process: &N, |
|
||||||
reward_function: &R, |
|
||||||
state: &process::NetworkProcessState, |
|
||||||
) -> f64; |
|
||||||
} |
|
@ -1,205 +0,0 @@ |
|||||||
use std::collections::HashMap; |
|
||||||
|
|
||||||
use rayon::prelude::{IntoParallelIterator, ParallelIterator}; |
|
||||||
use statrs::distribution::ContinuousCDF; |
|
||||||
|
|
||||||
use crate::params::{self, ParamsTrait}; |
|
||||||
use crate::process; |
|
||||||
|
|
||||||
use crate::{ |
|
||||||
process::NetworkProcessState, |
|
||||||
reward::RewardEvaluation, |
|
||||||
sampling::{ForwardSampler, Sampler}, |
|
||||||
}; |
|
||||||
|
|
||||||
pub enum RewardCriteria { |
|
||||||
FiniteHorizon, |
|
||||||
InfiniteHorizon { discount_factor: f64 }, |
|
||||||
} |
|
||||||
|
|
||||||
pub struct MonteCarloReward { |
|
||||||
max_iterations: usize, |
|
||||||
max_err_stop: f64, |
|
||||||
alpha_stop: f64, |
|
||||||
end_time: f64, |
|
||||||
reward_criteria: RewardCriteria, |
|
||||||
seed: Option<u64>, |
|
||||||
} |
|
||||||
|
|
||||||
impl MonteCarloReward { |
|
||||||
pub fn new( |
|
||||||
max_iterations: usize, |
|
||||||
max_err_stop: f64, |
|
||||||
alpha_stop: f64, |
|
||||||
end_time: f64, |
|
||||||
reward_criteria: RewardCriteria, |
|
||||||
seed: Option<u64>, |
|
||||||
) -> MonteCarloReward { |
|
||||||
MonteCarloReward { |
|
||||||
max_iterations, |
|
||||||
max_err_stop, |
|
||||||
alpha_stop, |
|
||||||
end_time, |
|
||||||
reward_criteria, |
|
||||||
seed, |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
impl RewardEvaluation for MonteCarloReward { |
|
||||||
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( |
|
||||||
&self, |
|
||||||
network_process: &N, |
|
||||||
reward_function: &R, |
|
||||||
) -> HashMap<process::NetworkProcessState, f64> { |
|
||||||
let variables_domain: Vec<Vec<params::StateType>> = network_process |
|
||||||
.get_node_indices() |
|
||||||
.map(|x| match network_process.get_node(x) { |
|
||||||
params::Params::DiscreteStatesContinousTime(x) => (0..x |
|
||||||
.get_reserved_space_as_parent()) |
|
||||||
.map(|s| params::StateType::Discrete(s)) |
|
||||||
.collect(), |
|
||||||
}) |
|
||||||
.collect(); |
|
||||||
|
|
||||||
let n_states: usize = variables_domain.iter().map(|x| x.len()).product(); |
|
||||||
|
|
||||||
(0..n_states) |
|
||||||
.into_par_iter() |
|
||||||
.map(|s| { |
|
||||||
let state: process::NetworkProcessState = variables_domain |
|
||||||
.iter() |
|
||||||
.fold((s, vec![]), |acc, x| { |
|
||||||
let mut acc = acc; |
|
||||||
let idx_s = acc.0 % x.len(); |
|
||||||
acc.1.push(x[idx_s].clone()); |
|
||||||
acc.0 = acc.0 / x.len(); |
|
||||||
acc |
|
||||||
}) |
|
||||||
.1; |
|
||||||
|
|
||||||
let r = self.evaluate_state(network_process, reward_function, &state); |
|
||||||
(state, r) |
|
||||||
}) |
|
||||||
.collect() |
|
||||||
} |
|
||||||
|
|
||||||
fn evaluate_state<N: crate::process::NetworkProcess, R: super::RewardFunction>( |
|
||||||
&self, |
|
||||||
network_process: &N, |
|
||||||
reward_function: &R, |
|
||||||
state: &NetworkProcessState, |
|
||||||
) -> f64 { |
|
||||||
let mut sampler = |
|
||||||
ForwardSampler::new(network_process, self.seed.clone(), Some(state.clone())); |
|
||||||
let mut expected_value = 0.0; |
|
||||||
let mut squared_expected_value = 0.0; |
|
||||||
let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); |
|
||||||
|
|
||||||
for i in 0..self.max_iterations { |
|
||||||
sampler.reset(); |
|
||||||
let mut ret = 0.0; |
|
||||||
let mut previous = sampler.next().unwrap(); |
|
||||||
while previous.t < self.end_time { |
|
||||||
let current = sampler.next().unwrap(); |
|
||||||
if current.t > self.end_time { |
|
||||||
let r = reward_function.call(&previous.state, None); |
|
||||||
let discount = match self.reward_criteria { |
|
||||||
RewardCriteria::FiniteHorizon => self.end_time - previous.t, |
|
||||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
|
||||||
std::f64::consts::E.powf(-discount_factor * previous.t) |
|
||||||
- std::f64::consts::E.powf(-discount_factor * self.end_time) |
|
||||||
} |
|
||||||
}; |
|
||||||
ret += discount * r.instantaneous_reward; |
|
||||||
} else { |
|
||||||
let r = reward_function.call(&previous.state, Some(¤t.state)); |
|
||||||
let discount = match self.reward_criteria { |
|
||||||
RewardCriteria::FiniteHorizon => current.t - previous.t, |
|
||||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
|
||||||
std::f64::consts::E.powf(-discount_factor * previous.t) |
|
||||||
- std::f64::consts::E.powf(-discount_factor * current.t) |
|
||||||
} |
|
||||||
}; |
|
||||||
ret += discount * r.instantaneous_reward; |
|
||||||
ret += match self.reward_criteria { |
|
||||||
RewardCriteria::FiniteHorizon => 1.0, |
|
||||||
RewardCriteria::InfiniteHorizon { discount_factor } => { |
|
||||||
std::f64::consts::E.powf(-discount_factor * current.t) |
|
||||||
} |
|
||||||
} * r.transition_reward; |
|
||||||
} |
|
||||||
previous = current; |
|
||||||
} |
|
||||||
|
|
||||||
let float_i = i as f64; |
|
||||||
expected_value = |
|
||||||
expected_value * float_i as f64 / (float_i + 1.0) + ret / (float_i + 1.0); |
|
||||||
squared_expected_value = squared_expected_value * float_i as f64 / (float_i + 1.0) |
|
||||||
+ ret.powi(2) / (float_i + 1.0); |
|
||||||
|
|
||||||
if i > 2 { |
|
||||||
let var = |
|
||||||
(float_i + 1.0) / float_i * (squared_expected_value - expected_value.powi(2)); |
|
||||||
if self.alpha_stop |
|
||||||
- 2.0 * normal.cdf(-(float_i + 1.0).sqrt() * self.max_err_stop / var.sqrt()) |
|
||||||
> 0.0 |
|
||||||
{ |
|
||||||
return expected_value; |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
expected_value |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
pub struct NeighborhoodRelativeReward<RE: RewardEvaluation> { |
|
||||||
inner_reward: RE, |
|
||||||
} |
|
||||||
|
|
||||||
impl<RE: RewardEvaluation> NeighborhoodRelativeReward<RE> { |
|
||||||
pub fn new(inner_reward: RE) -> NeighborhoodRelativeReward<RE> { |
|
||||||
NeighborhoodRelativeReward { inner_reward } |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
impl<RE: RewardEvaluation> RewardEvaluation for NeighborhoodRelativeReward<RE> { |
|
||||||
fn evaluate_state_space<N: process::NetworkProcess, R: super::RewardFunction>( |
|
||||||
&self, |
|
||||||
network_process: &N, |
|
||||||
reward_function: &R, |
|
||||||
) -> HashMap<process::NetworkProcessState, f64> { |
|
||||||
let absolute_reward = self |
|
||||||
.inner_reward |
|
||||||
.evaluate_state_space(network_process, reward_function); |
|
||||||
|
|
||||||
//This approach optimize memory. Maybe optimizing execution time can be better.
|
|
||||||
absolute_reward |
|
||||||
.iter() |
|
||||||
.map(|(k1, v1)| { |
|
||||||
let mut max_val: f64 = 1.0; |
|
||||||
absolute_reward.iter().for_each(|(k2, v2)| { |
|
||||||
let count_diff: usize = k1 |
|
||||||
.iter() |
|
||||||
.zip(k2.iter()) |
|
||||||
.map(|(s1, s2)| if s1 == s2 { 0 } else { 1 }) |
|
||||||
.sum(); |
|
||||||
if count_diff < 2 { |
|
||||||
max_val = max_val.max(v1 / v2); |
|
||||||
} |
|
||||||
}); |
|
||||||
(k1.clone(), max_val) |
|
||||||
}) |
|
||||||
.collect() |
|
||||||
} |
|
||||||
|
|
||||||
fn evaluate_state<N: process::NetworkProcess, R: super::RewardFunction>( |
|
||||||
&self, |
|
||||||
_network_process: &N, |
|
||||||
_reward_function: &R, |
|
||||||
_state: &process::NetworkProcessState, |
|
||||||
) -> f64 { |
|
||||||
unimplemented!(); |
|
||||||
} |
|
||||||
} |
|
@ -1,106 +0,0 @@ |
|||||||
//! Module for dealing with reward functions
|
|
||||||
|
|
||||||
use crate::{ |
|
||||||
params::{self, ParamsTrait}, |
|
||||||
process, |
|
||||||
reward::{Reward, RewardFunction}, |
|
||||||
}; |
|
||||||
|
|
||||||
use ndarray; |
|
||||||
|
|
||||||
/// Reward function over a factored state space
|
|
||||||
///
|
|
||||||
/// The `FactoredRewardFunction` assume the reward function is the sum of the reward of each node
|
|
||||||
/// of the underling `NetworkProcess`
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `transition_reward`: a vector of two-dimensional arrays. Each array contains the transition
|
|
||||||
/// reward of a node
|
|
||||||
|
|
||||||
pub struct FactoredRewardFunction { |
|
||||||
transition_reward: Vec<ndarray::Array2<f64>>, |
|
||||||
instantaneous_reward: Vec<ndarray::Array1<f64>>, |
|
||||||
} |
|
||||||
|
|
||||||
impl FactoredRewardFunction { |
|
||||||
pub fn get_transition_reward(&self, node_idx: usize) -> &ndarray::Array2<f64> { |
|
||||||
&self.transition_reward[node_idx] |
|
||||||
} |
|
||||||
|
|
||||||
pub fn get_transition_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array2<f64> { |
|
||||||
&mut self.transition_reward[node_idx] |
|
||||||
} |
|
||||||
|
|
||||||
pub fn get_instantaneous_reward(&self, node_idx: usize) -> &ndarray::Array1<f64> { |
|
||||||
&self.instantaneous_reward[node_idx] |
|
||||||
} |
|
||||||
|
|
||||||
pub fn get_instantaneous_reward_mut(&mut self, node_idx: usize) -> &mut ndarray::Array1<f64> { |
|
||||||
&mut self.instantaneous_reward[node_idx] |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
impl RewardFunction for FactoredRewardFunction { |
|
||||||
fn call( |
|
||||||
&self, |
|
||||||
current_state: &process::NetworkProcessState, |
|
||||||
previous_state: Option<&process::NetworkProcessState>, |
|
||||||
) -> Reward { |
|
||||||
let instantaneous_reward: f64 = current_state |
|
||||||
.iter() |
|
||||||
.enumerate() |
|
||||||
.map(|(idx, x)| { |
|
||||||
let x = match x { |
|
||||||
params::StateType::Discrete(x) => x, |
|
||||||
}; |
|
||||||
self.instantaneous_reward[idx][*x] |
|
||||||
}) |
|
||||||
.sum(); |
|
||||||
if let Some(previous_state) = previous_state { |
|
||||||
let transition_reward = previous_state |
|
||||||
.iter() |
|
||||||
.zip(current_state.iter()) |
|
||||||
.enumerate() |
|
||||||
.find_map(|(idx, (p, c))| -> Option<f64> { |
|
||||||
let p = match p { |
|
||||||
params::StateType::Discrete(p) => p, |
|
||||||
}; |
|
||||||
let c = match c { |
|
||||||
params::StateType::Discrete(c) => c, |
|
||||||
}; |
|
||||||
if p != c { |
|
||||||
Some(self.transition_reward[idx][[*p, *c]]) |
|
||||||
} else { |
|
||||||
None |
|
||||||
} |
|
||||||
}) |
|
||||||
.unwrap_or(0.0); |
|
||||||
Reward { |
|
||||||
transition_reward, |
|
||||||
instantaneous_reward, |
|
||||||
} |
|
||||||
} else { |
|
||||||
Reward { |
|
||||||
transition_reward: 0.0, |
|
||||||
instantaneous_reward, |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
fn initialize_from_network_process<T: process::NetworkProcess>(p: &T) -> Self { |
|
||||||
let mut transition_reward: Vec<ndarray::Array2<f64>> = vec![]; |
|
||||||
let mut instantaneous_reward: Vec<ndarray::Array1<f64>> = vec![]; |
|
||||||
for i in p.get_node_indices() { |
|
||||||
//This works only for discrete nodes!
|
|
||||||
let size: usize = p.get_node(i).get_reserved_space_as_parent(); |
|
||||||
instantaneous_reward.push(ndarray::Array1::zeros(size)); |
|
||||||
transition_reward.push(ndarray::Array2::zeros((size, size))); |
|
||||||
} |
|
||||||
|
|
||||||
FactoredRewardFunction { |
|
||||||
transition_reward, |
|
||||||
instantaneous_reward, |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
@ -1,13 +1,11 @@ |
|||||||
//! Learn the structure of the network.
|
|
||||||
|
|
||||||
pub mod constraint_based_algorithm; |
pub mod constraint_based_algorithm; |
||||||
pub mod hypothesis_test; |
pub mod hypothesis_test; |
||||||
pub mod score_based_algorithm; |
pub mod score_based_algorithm; |
||||||
pub mod score_function; |
pub mod score_function; |
||||||
use crate::{process, tools::Dataset}; |
use crate::{network, tools}; |
||||||
|
|
||||||
pub trait StructureLearningAlgorithm { |
pub trait StructureLearningAlgorithm { |
||||||
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
fn fit_transform<T>(&self, net: T, dataset: &tools::Dataset) -> T |
||||||
where |
where |
||||||
T: process::NetworkProcess; |
T: network::Network; |
||||||
} |
} |
||||||
|
@ -1,348 +1,3 @@ |
|||||||
//! Module containing constraint based algorithms like CTPC and Hiton.
|
//pub struct CTPC {
|
||||||
|
//
|
||||||
use crate::params::Params; |
//}
|
||||||
use itertools::Itertools; |
|
||||||
use rayon::iter::{IntoParallelIterator, ParallelIterator}; |
|
||||||
use rayon::prelude::ParallelExtend; |
|
||||||
use std::collections::{BTreeSet, HashMap}; |
|
||||||
use std::mem; |
|
||||||
use std::usize; |
|
||||||
|
|
||||||
use super::hypothesis_test::*; |
|
||||||
use crate::parameter_learning::ParameterLearning; |
|
||||||
use crate::process; |
|
||||||
use crate::structure_learning::StructureLearningAlgorithm; |
|
||||||
use crate::tools::Dataset; |
|
||||||
|
|
||||||
pub struct Cache<'a, P: ParameterLearning> { |
|
||||||
parameter_learning: &'a P, |
|
||||||
cache_persistent_small: HashMap<Option<BTreeSet<usize>>, Params>, |
|
||||||
cache_persistent_big: HashMap<Option<BTreeSet<usize>>, Params>, |
|
||||||
parent_set_size_small: usize, |
|
||||||
} |
|
||||||
|
|
||||||
impl<'a, P: ParameterLearning> Cache<'a, P> { |
|
||||||
pub fn new(parameter_learning: &'a P) -> Cache<'a, P> { |
|
||||||
Cache { |
|
||||||
parameter_learning, |
|
||||||
cache_persistent_small: HashMap::new(), |
|
||||||
cache_persistent_big: HashMap::new(), |
|
||||||
parent_set_size_small: 0, |
|
||||||
} |
|
||||||
} |
|
||||||
pub fn fit<T: process::NetworkProcess>( |
|
||||||
&mut self, |
|
||||||
net: &T, |
|
||||||
dataset: &Dataset, |
|
||||||
node: usize, |
|
||||||
parent_set: Option<BTreeSet<usize>>, |
|
||||||
) -> Params { |
|
||||||
let parent_set_len = parent_set.as_ref().unwrap().len(); |
|
||||||
if parent_set_len > self.parent_set_size_small + 1 { |
|
||||||
//self.cache_persistent_small = self.cache_persistent_big;
|
|
||||||
mem::swap( |
|
||||||
&mut self.cache_persistent_small, |
|
||||||
&mut self.cache_persistent_big, |
|
||||||
); |
|
||||||
self.cache_persistent_big = HashMap::new(); |
|
||||||
self.parent_set_size_small += 1; |
|
||||||
} |
|
||||||
|
|
||||||
if parent_set_len > self.parent_set_size_small { |
|
||||||
match self.cache_persistent_big.get(&parent_set) { |
|
||||||
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
|
||||||
// not cloning requires a minor and reasoned refactoring across the library
|
|
||||||
Some(params) => params.clone(), |
|
||||||
None => { |
|
||||||
let params = |
|
||||||
self.parameter_learning |
|
||||||
.fit(net, dataset, node, parent_set.clone()); |
|
||||||
self.cache_persistent_big.insert(parent_set, params.clone()); |
|
||||||
params |
|
||||||
} |
|
||||||
} |
|
||||||
} else { |
|
||||||
match self.cache_persistent_small.get(&parent_set) { |
|
||||||
// TODO: Better not clone `params`, useless clock cycles, RAM use and I/O
|
|
||||||
// not cloning requires a minor and reasoned refactoring across the library
|
|
||||||
Some(params) => params.clone(), |
|
||||||
None => { |
|
||||||
let params = |
|
||||||
self.parameter_learning |
|
||||||
.fit(net, dataset, node, parent_set.clone()); |
|
||||||
self.cache_persistent_small |
|
||||||
.insert(parent_set, params.clone()); |
|
||||||
params |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
/// Continuous-Time Peter Clark algorithm.
|
|
||||||
///
|
|
||||||
/// A method to learn the structure of the network.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * [`parameter_learning`](crate::parameter_learning) - is the method used to learn the parameters.
|
|
||||||
/// * [`Ftest`](crate::structure_learning::hypothesis_test::F) - is the F-test hyppothesis test.
|
|
||||||
/// * [`Chi2test`](crate::structure_learning::hypothesis_test::ChiSquare) - is the chi-squared test (χ2 test) hypothesis test.
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use std::collections::BTreeSet;
|
|
||||||
/// # use ndarray::{arr1, arr2, arr3};
|
|
||||||
/// # use reCTBN::params;
|
|
||||||
/// # use reCTBN::tools::trajectory_generator;
|
|
||||||
/// # use reCTBN::process::NetworkProcess;
|
|
||||||
/// # use reCTBN::process::ctbn::CtbnNetwork;
|
|
||||||
/// use reCTBN::parameter_learning::BayesianApproach;
|
|
||||||
/// use reCTBN::structure_learning::StructureLearningAlgorithm;
|
|
||||||
/// use reCTBN::structure_learning::hypothesis_test::{F, ChiSquare};
|
|
||||||
/// use reCTBN::structure_learning::constraint_based_algorithm::CTPC;
|
|
||||||
/// #
|
|
||||||
/// # // Create the domain for a discrete node
|
|
||||||
/// # let mut domain = BTreeSet::new();
|
|
||||||
/// # domain.insert(String::from("A"));
|
|
||||||
/// # domain.insert(String::from("B"));
|
|
||||||
/// # domain.insert(String::from("C"));
|
|
||||||
/// # // Create the parameters for a discrete node using the domain
|
|
||||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n1".to_string(), domain);
|
|
||||||
/// # //Create the node n1 using the parameters
|
|
||||||
/// # let n1 = params::Params::DiscreteStatesContinousTime(param);
|
|
||||||
/// #
|
|
||||||
/// # let mut domain = BTreeSet::new();
|
|
||||||
/// # domain.insert(String::from("D"));
|
|
||||||
/// # domain.insert(String::from("E"));
|
|
||||||
/// # domain.insert(String::from("F"));
|
|
||||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n2".to_string(), domain);
|
|
||||||
/// # let n2 = params::Params::DiscreteStatesContinousTime(param);
|
|
||||||
/// #
|
|
||||||
/// # let mut domain = BTreeSet::new();
|
|
||||||
/// # domain.insert(String::from("G"));
|
|
||||||
/// # domain.insert(String::from("H"));
|
|
||||||
/// # domain.insert(String::from("I"));
|
|
||||||
/// # domain.insert(String::from("F"));
|
|
||||||
/// # let param = params::DiscreteStatesContinousTimeParams::new("n3".to_string(), domain);
|
|
||||||
/// # let n3 = params::Params::DiscreteStatesContinousTime(param);
|
|
||||||
/// #
|
|
||||||
/// # // Initialize a ctbn
|
|
||||||
/// # let mut net = CtbnNetwork::new();
|
|
||||||
/// #
|
|
||||||
/// # // Add the nodes and their edges
|
|
||||||
/// # let n1 = net.add_node(n1).unwrap();
|
|
||||||
/// # let n2 = net.add_node(n2).unwrap();
|
|
||||||
/// # let n3 = net.add_node(n3).unwrap();
|
|
||||||
/// # net.add_edge(n1, n2);
|
|
||||||
/// # net.add_edge(n1, n3);
|
|
||||||
/// # net.add_edge(n2, n3);
|
|
||||||
/// #
|
|
||||||
/// # match &mut net.get_node_mut(n1) {
|
|
||||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
|
||||||
/// # assert_eq!(
|
|
||||||
/// # Ok(()),
|
|
||||||
/// # param.set_cim(arr3(&[
|
|
||||||
/// # [
|
|
||||||
/// # [-3.0, 2.0, 1.0],
|
|
||||||
/// # [1.5, -2.0, 0.5],
|
|
||||||
/// # [0.4, 0.6, -1.0]
|
|
||||||
/// # ],
|
|
||||||
/// # ]))
|
|
||||||
/// # );
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # match &mut net.get_node_mut(n2) {
|
|
||||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
|
||||||
/// # assert_eq!(
|
|
||||||
/// # Ok(()),
|
|
||||||
/// # param.set_cim(arr3(&[
|
|
||||||
/// # [
|
|
||||||
/// # [-1.0, 0.5, 0.5],
|
|
||||||
/// # [3.0, -4.0, 1.0],
|
|
||||||
/// # [0.9, 0.1, -1.0]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-6.0, 2.0, 4.0],
|
|
||||||
/// # [1.5, -2.0, 0.5],
|
|
||||||
/// # [3.0, 1.0, -4.0]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-1.0, 0.1, 0.9],
|
|
||||||
/// # [2.0, -2.5, 0.5],
|
|
||||||
/// # [0.9, 0.1, -1.0]
|
|
||||||
/// # ],
|
|
||||||
/// # ]))
|
|
||||||
/// # );
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # match &mut net.get_node_mut(n3) {
|
|
||||||
/// # params::Params::DiscreteStatesContinousTime(param) => {
|
|
||||||
/// # assert_eq!(
|
|
||||||
/// # Ok(()),
|
|
||||||
/// # param.set_cim(arr3(&[
|
|
||||||
/// # [
|
|
||||||
/// # [-1.0, 0.5, 0.3, 0.2],
|
|
||||||
/// # [0.5, -4.0, 2.5, 1.0],
|
|
||||||
/// # [2.5, 0.5, -4.0, 1.0],
|
|
||||||
/// # [0.7, 0.2, 0.1, -1.0]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-6.0, 2.0, 3.0, 1.0],
|
|
||||||
/// # [1.5, -3.0, 0.5, 1.0],
|
|
||||||
/// # [2.0, 1.3, -5.0, 1.7],
|
|
||||||
/// # [2.5, 0.5, 1.0, -4.0]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-1.3, 0.3, 0.1, 0.9],
|
|
||||||
/// # [1.4, -4.0, 0.5, 2.1],
|
|
||||||
/// # [1.0, 1.5, -3.0, 0.5],
|
|
||||||
/// # [0.4, 0.3, 0.1, -0.8]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-2.0, 1.0, 0.7, 0.3],
|
|
||||||
/// # [1.3, -5.9, 2.7, 1.9],
|
|
||||||
/// # [2.0, 1.5, -4.0, 0.5],
|
|
||||||
/// # [0.2, 0.7, 0.1, -1.0]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-6.0, 1.0, 2.0, 3.0],
|
|
||||||
/// # [0.5, -3.0, 1.0, 1.5],
|
|
||||||
/// # [1.4, 2.1, -4.3, 0.8],
|
|
||||||
/// # [0.5, 1.0, 2.5, -4.0]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-1.3, 0.9, 0.3, 0.1],
|
|
||||||
/// # [0.1, -1.3, 0.2, 1.0],
|
|
||||||
/// # [0.5, 1.0, -3.0, 1.5],
|
|
||||||
/// # [0.1, 0.4, 0.3, -0.8]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-2.0, 1.0, 0.6, 0.4],
|
|
||||||
/// # [2.6, -7.1, 1.4, 3.1],
|
|
||||||
/// # [5.0, 1.0, -8.0, 2.0],
|
|
||||||
/// # [1.4, 0.4, 0.2, -2.0]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-3.0, 1.0, 1.5, 0.5],
|
|
||||||
/// # [3.0, -6.0, 1.0, 2.0],
|
|
||||||
/// # [0.3, 0.5, -1.9, 1.1],
|
|
||||||
/// # [5.0, 1.0, 2.0, -8.0]
|
|
||||||
/// # ],
|
|
||||||
/// # [
|
|
||||||
/// # [-2.6, 0.6, 0.2, 1.8],
|
|
||||||
/// # [2.0, -6.0, 3.0, 1.0],
|
|
||||||
/// # [0.1, 0.5, -1.3, 0.7],
|
|
||||||
/// # [0.8, 0.6, 0.2, -1.6]
|
|
||||||
/// # ],
|
|
||||||
/// # ]))
|
|
||||||
/// # );
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # // Generate the trajectory
|
|
||||||
/// # let data = trajectory_generator(&net, 300, 30.0, Some(4164901764658873));
|
|
||||||
///
|
|
||||||
/// // Initialize the hypothesis tests to pass to the CTPC with their
|
|
||||||
/// // respective significance level `alpha`
|
|
||||||
/// let f = F::new(1e-6);
|
|
||||||
/// let chi_sq = ChiSquare::new(1e-4);
|
|
||||||
/// // Use the bayesian approach to learn the parameters
|
|
||||||
/// let parameter_learning = BayesianApproach { alpha: 1, tau:1.0 };
|
|
||||||
///
|
|
||||||
/// //Initialize CTPC
|
|
||||||
/// let ctpc = CTPC::new(parameter_learning, f, chi_sq);
|
|
||||||
///
|
|
||||||
/// // Learn the structure of the network from the generated trajectory
|
|
||||||
/// let net = ctpc.fit_transform(net, &data);
|
|
||||||
/// #
|
|
||||||
/// # // Compare the generated network with the original one
|
|
||||||
/// # assert_eq!(BTreeSet::new(), net.get_parent_set(0));
|
|
||||||
/// # assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(1));
|
|
||||||
/// # assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
|
|
||||||
/// ```
|
|
||||||
pub struct CTPC<P: ParameterLearning> { |
|
||||||
parameter_learning: P, |
|
||||||
Ftest: F, |
|
||||||
Chi2test: ChiSquare, |
|
||||||
} |
|
||||||
|
|
||||||
impl<P: ParameterLearning> CTPC<P> { |
|
||||||
pub fn new(parameter_learning: P, Ftest: F, Chi2test: ChiSquare) -> CTPC<P> { |
|
||||||
CTPC { |
|
||||||
parameter_learning, |
|
||||||
Ftest, |
|
||||||
Chi2test, |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
impl<P: ParameterLearning> StructureLearningAlgorithm for CTPC<P> { |
|
||||||
fn fit_transform<T>(&self, net: T, dataset: &Dataset) -> T |
|
||||||
where |
|
||||||
T: process::NetworkProcess, |
|
||||||
{ |
|
||||||
//Check the coherence between dataset and network
|
|
||||||
if net.get_number_of_nodes() != dataset.get_trajectories()[0].get_events().shape()[1] { |
|
||||||
panic!("Dataset and Network must have the same number of variables.") |
|
||||||
} |
|
||||||
|
|
||||||
//Make the network mutable.
|
|
||||||
let mut net = net; |
|
||||||
|
|
||||||
net.initialize_adj_matrix(); |
|
||||||
|
|
||||||
let mut learned_parent_sets: Vec<(usize, BTreeSet<usize>)> = vec![]; |
|
||||||
learned_parent_sets.par_extend(net.get_node_indices().into_par_iter().map(|child_node| { |
|
||||||
let mut cache = Cache::new(&self.parameter_learning); |
|
||||||
let mut candidate_parent_set: BTreeSet<usize> = net |
|
||||||
.get_node_indices() |
|
||||||
.into_iter() |
|
||||||
.filter(|x| x != &child_node) |
|
||||||
.collect(); |
|
||||||
let mut separation_set_size = 0; |
|
||||||
while separation_set_size < candidate_parent_set.len() { |
|
||||||
let mut candidate_parent_set_TMP = candidate_parent_set.clone(); |
|
||||||
for parent_node in candidate_parent_set.iter() { |
|
||||||
for separation_set in candidate_parent_set |
|
||||||
.iter() |
|
||||||
.filter(|x| x != &parent_node) |
|
||||||
.map(|x| *x) |
|
||||||
.combinations(separation_set_size) |
|
||||||
{ |
|
||||||
let separation_set = separation_set.into_iter().collect(); |
|
||||||
if self.Ftest.call( |
|
||||||
&net, |
|
||||||
child_node, |
|
||||||
*parent_node, |
|
||||||
&separation_set, |
|
||||||
dataset, |
|
||||||
&mut cache, |
|
||||||
) && self.Chi2test.call( |
|
||||||
&net, |
|
||||||
child_node, |
|
||||||
*parent_node, |
|
||||||
&separation_set, |
|
||||||
dataset, |
|
||||||
&mut cache, |
|
||||||
) { |
|
||||||
candidate_parent_set_TMP.remove(parent_node); |
|
||||||
break; |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
candidate_parent_set = candidate_parent_set_TMP; |
|
||||||
separation_set_size += 1; |
|
||||||
} |
|
||||||
(child_node, candidate_parent_set) |
|
||||||
})); |
|
||||||
for (child_node, candidate_parent_set) in learned_parent_sets { |
|
||||||
for parent_node in candidate_parent_set.iter() { |
|
||||||
net.add_edge(*parent_node, child_node); |
|
||||||
} |
|
||||||
} |
|
||||||
net |
|
||||||
} |
|
||||||
} |
|
||||||
|
@ -1,127 +0,0 @@ |
|||||||
mod utils; |
|
||||||
|
|
||||||
use std::collections::BTreeSet; |
|
||||||
|
|
||||||
use reCTBN::{ |
|
||||||
params, |
|
||||||
params::ParamsTrait, |
|
||||||
process::{ctmp::*, NetworkProcess}, |
|
||||||
}; |
|
||||||
use utils::generate_discrete_time_continous_node; |
|
||||||
|
|
||||||
#[test] |
|
||||||
fn define_simple_ctmp() { |
|
||||||
let _ = CtmpProcess::new(); |
|
||||||
assert!(true); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
fn add_node_to_ctmp() { |
|
||||||
let mut net = CtmpProcess::new(); |
|
||||||
let n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
fn add_two_nodes_to_ctmp() { |
|
||||||
let mut net = CtmpProcess::new(); |
|
||||||
let _n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
|
||||||
|
|
||||||
match n2 { |
|
||||||
Ok(_) => assert!(false), |
|
||||||
Err(_) => assert!(true), |
|
||||||
}; |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
#[should_panic] |
|
||||||
fn add_edge_to_ctmp() { |
|
||||||
let mut net = CtmpProcess::new(); |
|
||||||
let _n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)); |
|
||||||
|
|
||||||
net.add_edge(0, 1) |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
fn childen_and_parents() { |
|
||||||
let mut net = CtmpProcess::new(); |
|
||||||
let _n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
assert_eq!(0, net.get_parent_set(0).len()); |
|
||||||
assert_eq!(0, net.get_children_set(0).len()); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
#[should_panic] |
|
||||||
fn get_childen_panic() { |
|
||||||
let net = CtmpProcess::new(); |
|
||||||
net.get_children_set(0); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
#[should_panic] |
|
||||||
fn get_childen_panic2() { |
|
||||||
let mut net = CtmpProcess::new(); |
|
||||||
let _n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
net.get_children_set(1); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
#[should_panic] |
|
||||||
fn get_parent_panic() { |
|
||||||
let net = CtmpProcess::new(); |
|
||||||
net.get_parent_set(0); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
#[should_panic] |
|
||||||
fn get_parent_panic2() { |
|
||||||
let mut net = CtmpProcess::new(); |
|
||||||
let _n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
net.get_parent_set(1); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
fn compute_index_ctmp() { |
|
||||||
let mut net = CtmpProcess::new(); |
|
||||||
let n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node( |
|
||||||
String::from("n1"), |
|
||||||
10, |
|
||||||
)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
let idx = net.get_param_index_network(n1, &vec![params::StateType::Discrete(6)]); |
|
||||||
assert_eq!(6, idx); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
#[should_panic] |
|
||||||
fn compute_index_from_custom_parent_set_ctmp() { |
|
||||||
let mut net = CtmpProcess::new(); |
|
||||||
let _n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node( |
|
||||||
String::from("n1"), |
|
||||||
10, |
|
||||||
)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
let _idx = net.get_param_index_from_custom_parent_set( |
|
||||||
&vec![params::StateType::Discrete(6)], |
|
||||||
&BTreeSet::from([0]) |
|
||||||
); |
|
||||||
} |
|
@ -1,122 +0,0 @@ |
|||||||
mod utils; |
|
||||||
|
|
||||||
use approx::assert_abs_diff_eq; |
|
||||||
use ndarray::*; |
|
||||||
use reCTBN::{ |
|
||||||
params, |
|
||||||
process::{ctbn::*, NetworkProcess, NetworkProcessState}, |
|
||||||
reward::{reward_evaluation::*, reward_function::*, *}, |
|
||||||
}; |
|
||||||
use utils::generate_discrete_time_continous_node; |
|
||||||
|
|
||||||
#[test] |
|
||||||
fn simple_factored_reward_function_binary_node_mc() { |
|
||||||
let mut net = CtbnNetwork::new(); |
|
||||||
let n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
|
||||||
rf.get_transition_reward_mut(n1) |
|
||||||
.assign(&arr2(&[[0.0, 0.0], [0.0, 0.0]])); |
|
||||||
rf.get_instantaneous_reward_mut(n1) |
|
||||||
.assign(&arr1(&[3.0, 3.0])); |
|
||||||
|
|
||||||
match &mut net.get_node_mut(n1) { |
|
||||||
params::Params::DiscreteStatesContinousTime(param) => { |
|
||||||
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])).unwrap(); |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
net.initialize_adj_matrix(); |
|
||||||
|
|
||||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
|
||||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
|
||||||
|
|
||||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); |
|
||||||
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); |
|
||||||
assert_abs_diff_eq!(3.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); |
|
||||||
|
|
||||||
let rst = mc.evaluate_state_space(&net, &rf); |
|
||||||
assert_abs_diff_eq!(3.0, rst[&s0], epsilon = 1e-2); |
|
||||||
assert_abs_diff_eq!(3.0, rst[&s1], epsilon = 1e-2); |
|
||||||
|
|
||||||
|
|
||||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::FiniteHorizon, Some(215)); |
|
||||||
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s0), epsilon = 1e-2); |
|
||||||
assert_abs_diff_eq!(30.0, mc.evaluate_state(&net, &rf, &s1), epsilon = 1e-2); |
|
||||||
|
|
||||||
|
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
fn simple_factored_reward_function_chain_mc() { |
|
||||||
let mut net = CtbnNetwork::new(); |
|
||||||
let n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
let n2 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
let n3 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
net.add_edge(n1, n2); |
|
||||||
net.add_edge(n2, n3); |
|
||||||
|
|
||||||
match &mut net.get_node_mut(n1) { |
|
||||||
params::Params::DiscreteStatesContinousTime(param) => { |
|
||||||
param.set_cim(arr3(&[[[-0.1, 0.1], [1.0, -1.0]]])).unwrap(); |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
match &mut net.get_node_mut(n2) { |
|
||||||
params::Params::DiscreteStatesContinousTime(param) => { |
|
||||||
param |
|
||||||
.set_cim(arr3(&[ |
|
||||||
[[-0.01, 0.01], [5.0, -5.0]], |
|
||||||
[[-5.0, 5.0], [0.01, -0.01]], |
|
||||||
])) |
|
||||||
.unwrap(); |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
match &mut net.get_node_mut(n3) { |
|
||||||
params::Params::DiscreteStatesContinousTime(param) => { |
|
||||||
param |
|
||||||
.set_cim(arr3(&[ |
|
||||||
[[-0.01, 0.01], [5.0, -5.0]], |
|
||||||
[[-5.0, 5.0], [0.01, -0.01]], |
|
||||||
])) |
|
||||||
.unwrap(); |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
|
||||||
rf.get_transition_reward_mut(n1) |
|
||||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
|
||||||
|
|
||||||
rf.get_transition_reward_mut(n2) |
|
||||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
|
||||||
|
|
||||||
rf.get_transition_reward_mut(n3) |
|
||||||
.assign(&arr2(&[[0.0, 1.0], [1.0, 0.0]])); |
|
||||||
|
|
||||||
let s000: NetworkProcessState = vec![ |
|
||||||
params::StateType::Discrete(1), |
|
||||||
params::StateType::Discrete(0), |
|
||||||
params::StateType::Discrete(0), |
|
||||||
]; |
|
||||||
|
|
||||||
let mc = MonteCarloReward::new(10000, 1e-1, 1e-1, 10.0, RewardCriteria::InfiniteHorizon { discount_factor: 1.0 }, Some(215)); |
|
||||||
assert_abs_diff_eq!(2.447, mc.evaluate_state(&net, &rf, &s000), epsilon = 1e-1); |
|
||||||
|
|
||||||
let rst = mc.evaluate_state_space(&net, &rf); |
|
||||||
assert_abs_diff_eq!(2.447, rst[&s000], epsilon = 1e-1); |
|
||||||
|
|
||||||
} |
|
@ -1,117 +0,0 @@ |
|||||||
mod utils; |
|
||||||
|
|
||||||
use ndarray::*; |
|
||||||
use utils::generate_discrete_time_continous_node; |
|
||||||
use reCTBN::{process::{NetworkProcess, ctbn::*, NetworkProcessState}, reward::{*, reward_function::*}, params}; |
|
||||||
|
|
||||||
|
|
||||||
#[test] |
|
||||||
fn simple_factored_reward_function_binary_node() { |
|
||||||
let mut net = CtbnNetwork::new(); |
|
||||||
let n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
|
||||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
|
||||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0])); |
|
||||||
|
|
||||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
|
||||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
|
||||||
assert_eq!(rf.call(&s0, None), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
|
||||||
assert_eq!(rf.call(&s1, None), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
|
||||||
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
|
||||||
|
|
||||||
assert_eq!(rf.call(&s0, Some(&s0)), Reward{transition_reward: 0.0, instantaneous_reward: 3.0}); |
|
||||||
assert_eq!(rf.call(&s1, Some(&s1)), Reward{transition_reward: 0.0, instantaneous_reward: 5.0}); |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
#[test] |
|
||||||
fn simple_factored_reward_function_ternary_node() { |
|
||||||
let mut net = CtbnNetwork::new(); |
|
||||||
let n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
|
||||||
.unwrap(); |
|
||||||
|
|
||||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
|
||||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
|
||||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
|
||||||
|
|
||||||
let s0: NetworkProcessState = vec![params::StateType::Discrete(0)]; |
|
||||||
let s1: NetworkProcessState = vec![params::StateType::Discrete(1)]; |
|
||||||
let s2: NetworkProcessState = vec![params::StateType::Discrete(2)]; |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s0, Some(&s1)), Reward{transition_reward: 2.0, instantaneous_reward: 3.0}); |
|
||||||
assert_eq!(rf.call(&s0, Some(&s2)), Reward{transition_reward: 5.0, instantaneous_reward: 3.0}); |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s1, Some(&s0)), Reward{transition_reward: 1.0, instantaneous_reward: 5.0}); |
|
||||||
assert_eq!(rf.call(&s1, Some(&s2)), Reward{transition_reward: 6.0, instantaneous_reward: 5.0}); |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s2, Some(&s0)), Reward{transition_reward: 3.0, instantaneous_reward: 9.0}); |
|
||||||
assert_eq!(rf.call(&s2, Some(&s1)), Reward{transition_reward: 4.0, instantaneous_reward: 9.0}); |
|
||||||
} |
|
||||||
|
|
||||||
#[test] |
|
||||||
fn factored_reward_function_two_nodes() { |
|
||||||
let mut net = CtbnNetwork::new(); |
|
||||||
let n1 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) |
|
||||||
.unwrap(); |
|
||||||
let n2 = net |
|
||||||
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) |
|
||||||
.unwrap(); |
|
||||||
net.add_edge(n1, n2); |
|
||||||
|
|
||||||
|
|
||||||
let mut rf = FactoredRewardFunction::initialize_from_network_process(&net); |
|
||||||
rf.get_transition_reward_mut(n1).assign(&arr2(&[[0.0, 1.0, 3.0],[2.0,0.0, 4.0], [5.0, 6.0, 0.0]])); |
|
||||||
rf.get_instantaneous_reward_mut(n1).assign(&arr1(&[3.0,5.0, 9.0])); |
|
||||||
|
|
||||||
|
|
||||||
rf.get_transition_reward_mut(n2).assign(&arr2(&[[12.0, 1.0],[2.0,12.0]])); |
|
||||||
rf.get_instantaneous_reward_mut(n2).assign(&arr1(&[3.0,5.0])); |
|
||||||
let s00: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(0)]; |
|
||||||
let s01: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(0)]; |
|
||||||
let s02: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(0)]; |
|
||||||
|
|
||||||
|
|
||||||
let s10: NetworkProcessState = vec![params::StateType::Discrete(0), params::StateType::Discrete(1)]; |
|
||||||
let s11: NetworkProcessState = vec![params::StateType::Discrete(1), params::StateType::Discrete(1)]; |
|
||||||
let s12: NetworkProcessState = vec![params::StateType::Discrete(2), params::StateType::Discrete(1)]; |
|
||||||
|
|
||||||
assert_eq!(rf.call(&s00, Some(&s01)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
|
||||||
assert_eq!(rf.call(&s00, Some(&s02)), Reward{transition_reward: 5.0, instantaneous_reward: 6.0}); |
|
||||||
assert_eq!(rf.call(&s00, Some(&s10)), Reward{transition_reward: 2.0, instantaneous_reward: 6.0}); |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s01, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
|
||||||
assert_eq!(rf.call(&s01, Some(&s02)), Reward{transition_reward: 6.0, instantaneous_reward: 8.0}); |
|
||||||
assert_eq!(rf.call(&s01, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s02, Some(&s00)), Reward{transition_reward: 3.0, instantaneous_reward: 12.0}); |
|
||||||
assert_eq!(rf.call(&s02, Some(&s01)), Reward{transition_reward: 4.0, instantaneous_reward: 12.0}); |
|
||||||
assert_eq!(rf.call(&s02, Some(&s12)), Reward{transition_reward: 2.0, instantaneous_reward: 12.0}); |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s10, Some(&s11)), Reward{transition_reward: 2.0, instantaneous_reward: 8.0}); |
|
||||||
assert_eq!(rf.call(&s10, Some(&s12)), Reward{transition_reward: 5.0, instantaneous_reward: 8.0}); |
|
||||||
assert_eq!(rf.call(&s10, Some(&s00)), Reward{transition_reward: 1.0, instantaneous_reward: 8.0}); |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s11, Some(&s10)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
|
||||||
assert_eq!(rf.call(&s11, Some(&s12)), Reward{transition_reward: 6.0, instantaneous_reward: 10.0}); |
|
||||||
assert_eq!(rf.call(&s11, Some(&s01)), Reward{transition_reward: 1.0, instantaneous_reward: 10.0}); |
|
||||||
|
|
||||||
|
|
||||||
assert_eq!(rf.call(&s12, Some(&s10)), Reward{transition_reward: 3.0, instantaneous_reward: 14.0}); |
|
||||||
assert_eq!(rf.call(&s12, Some(&s11)), Reward{transition_reward: 4.0, instantaneous_reward: 14.0}); |
|
||||||
assert_eq!(rf.call(&s12, Some(&s02)), Reward{transition_reward: 1.0, instantaneous_reward: 14.0}); |
|
||||||
} |
|
@ -0,0 +1,69 @@ |
|||||||
|
name: CI |
||||||
|
|
||||||
|
on: |
||||||
|
push: |
||||||
|
branches: |
||||||
|
- main |
||||||
|
- master |
||||||
|
pull_request: |
||||||
|
|
||||||
|
jobs: |
||||||
|
linux: |
||||||
|
runs-on: ubuntu-latest |
||||||
|
steps: |
||||||
|
- uses: actions/checkout@v3 |
||||||
|
- uses: messense/maturin-action@v1 |
||||||
|
with: |
||||||
|
manylinux: auto |
||||||
|
command: build |
||||||
|
args: --release --sdist -o dist --find-interpreter |
||||||
|
- name: Upload wheels |
||||||
|
uses: actions/upload-artifact@v2 |
||||||
|
with: |
||||||
|
name: wheels |
||||||
|
path: dist |
||||||
|
|
||||||
|
windows: |
||||||
|
runs-on: windows-latest |
||||||
|
steps: |
||||||
|
- uses: actions/checkout@v3 |
||||||
|
- uses: messense/maturin-action@v1 |
||||||
|
with: |
||||||
|
command: build |
||||||
|
args: --release -o dist --find-interpreter |
||||||
|
- name: Upload wheels |
||||||
|
uses: actions/upload-artifact@v2 |
||||||
|
with: |
||||||
|
name: wheels |
||||||
|
path: dist |
||||||
|
|
||||||
|
macos: |
||||||
|
runs-on: macos-latest |
||||||
|
steps: |
||||||
|
- uses: actions/checkout@v3 |
||||||
|
- uses: messense/maturin-action@v1 |
||||||
|
with: |
||||||
|
command: build |
||||||
|
args: --release -o dist --universal2 --find-interpreter |
||||||
|
- name: Upload wheels |
||||||
|
uses: actions/upload-artifact@v2 |
||||||
|
with: |
||||||
|
name: wheels |
||||||
|
path: dist |
||||||
|
|
||||||
|
release: |
||||||
|
name: Release |
||||||
|
runs-on: ubuntu-latest |
||||||
|
if: "startsWith(github.ref, 'refs/tags/')" |
||||||
|
needs: [ macos, windows, linux ] |
||||||
|
steps: |
||||||
|
- uses: actions/download-artifact@v2 |
||||||
|
with: |
||||||
|
name: wheels |
||||||
|
- name: Publish to PyPI |
||||||
|
uses: messense/maturin-action@v1 |
||||||
|
env: |
||||||
|
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} |
||||||
|
with: |
||||||
|
command: upload |
||||||
|
args: --skip-existing * |
@ -0,0 +1,72 @@ |
|||||||
|
/target |
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files |
||||||
|
__pycache__/ |
||||||
|
.pytest_cache/ |
||||||
|
*.py[cod] |
||||||
|
|
||||||
|
# C extensions |
||||||
|
*.so |
||||||
|
|
||||||
|
# Distribution / packaging |
||||||
|
.Python |
||||||
|
.venv/ |
||||||
|
env/ |
||||||
|
bin/ |
||||||
|
build/ |
||||||
|
develop-eggs/ |
||||||
|
dist/ |
||||||
|
eggs/ |
||||||
|
lib/ |
||||||
|
lib64/ |
||||||
|
parts/ |
||||||
|
sdist/ |
||||||
|
var/ |
||||||
|
include/ |
||||||
|
man/ |
||||||
|
venv/ |
||||||
|
*.egg-info/ |
||||||
|
.installed.cfg |
||||||
|
*.egg |
||||||
|
|
||||||
|
# Installer logs |
||||||
|
pip-log.txt |
||||||
|
pip-delete-this-directory.txt |
||||||
|
pip-selfcheck.json |
||||||
|
|
||||||
|
# Unit test / coverage reports |
||||||
|
htmlcov/ |
||||||
|
.tox/ |
||||||
|
.coverage |
||||||
|
.cache |
||||||
|
nosetests.xml |
||||||
|
coverage.xml |
||||||
|
|
||||||
|
# Translations |
||||||
|
*.mo |
||||||
|
|
||||||
|
# Mr Developer |
||||||
|
.mr.developer.cfg |
||||||
|
.project |
||||||
|
.pydevproject |
||||||
|
|
||||||
|
# Rope |
||||||
|
.ropeproject |
||||||
|
|
||||||
|
# Django stuff: |
||||||
|
*.log |
||||||
|
*.pot |
||||||
|
|
||||||
|
.DS_Store |
||||||
|
|
||||||
|
# Sphinx documentation |
||||||
|
docs/_build/ |
||||||
|
|
||||||
|
# PyCharm |
||||||
|
.idea/ |
||||||
|
|
||||||
|
# VSCode |
||||||
|
.vscode/ |
||||||
|
|
||||||
|
# Pyenv |
||||||
|
.python-version |
@ -0,0 +1,15 @@ |
|||||||
|
[package] |
||||||
|
name = "reCTBNpy" |
||||||
|
version = "0.1.0" |
||||||
|
edition = "2021" |
||||||
|
|
||||||
|
|
||||||
|
[lib] |
||||||
|
crate-type = ["cdylib"] |
||||||
|
|
||||||
|
[workspace] |
||||||
|
|
||||||
|
[dependencies] |
||||||
|
pyo3 = { version = "0.17.1", features = ["extension-module"] } |
||||||
|
numpy = "*" |
||||||
|
reCTBN = { path="../reCTBN" } |
@ -0,0 +1,14 @@ |
|||||||
|
[build-system] |
||||||
|
requires = ["maturin>=0.13,<0.14"] |
||||||
|
build-backend = "maturin" |
||||||
|
|
||||||
|
[project] |
||||||
|
name = "reCTBNpy" |
||||||
|
requires-python = ">=3.7" |
||||||
|
classifiers = [ |
||||||
|
"Programming Language :: Rust", |
||||||
|
"Programming Language :: Python :: Implementation :: CPython", |
||||||
|
"Programming Language :: Python :: Implementation :: PyPy", |
||||||
|
] |
||||||
|
|
||||||
|
|
@ -0,0 +1,21 @@ |
|||||||
|
use pyo3::prelude::*; |
||||||
|
pub mod pyctbn; |
||||||
|
pub mod pyparams; |
||||||
|
pub mod pytools; |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/// A Python module implemented in Rust.
|
||||||
|
#[pymodule] |
||||||
|
fn reCTBNpy(py: Python, m: &PyModule) -> PyResult<()> { |
||||||
|
let network_module = PyModule::new(py, "network")?; |
||||||
|
network_module.add_class::<pyctbn::PyCtbnNetwork>()?; |
||||||
|
m.add_submodule(network_module)?; |
||||||
|
|
||||||
|
let params_module = PyModule::new(py, "params")?; |
||||||
|
params_module.add_class::<pyparams::PyDiscreteStateContinousTime>()?; |
||||||
|
params_module.add_class::<pyparams::PyStateType>()?; |
||||||
|
params_module.add_class::<pyparams::PyParams>()?; |
||||||
|
m.add_submodule(params_module)?; |
||||||
|
Ok(()) |
||||||
|
} |
@ -0,0 +1,68 @@ |
|||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use crate::{pyparams, pytools}; |
||||||
|
use pyo3::prelude::*; |
||||||
|
use reCTBN::{ctbn, network::Network, params, tools, params::Params}; |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
pub struct PyCtbnNetwork(pub ctbn::CtbnNetwork); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyCtbnNetwork { |
||||||
|
#[new] |
||||||
|
pub fn new() -> Self { |
||||||
|
PyCtbnNetwork(ctbn::CtbnNetwork::new()) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn add_node(&mut self, n: pyparams::PyParams) { |
||||||
|
self.0.add_node(n.0); |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_number_of_nodes(&self) -> usize { |
||||||
|
self.0.get_number_of_nodes() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn add_edge(&mut self, parent: usize, child: usize) { |
||||||
|
self.0.add_edge(parent, child); |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_node_indices(&self) -> BTreeSet<usize> { |
||||||
|
self.0.get_node_indices().collect() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
|
self.0.get_parent_set(node) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
|
self.0.get_children_set(node) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_node(&mut self, node_idx: usize, n: pyparams::PyParams) { |
||||||
|
match &n.0 { |
||||||
|
Params::DiscreteStatesContinousTime(new_p) => { |
||||||
|
if let Params::DiscreteStatesContinousTime(p) = self.0.get_node_mut(node_idx){ |
||||||
|
p.set_cim(new_p.get_cim().as_ref().unwrap().clone()).unwrap(); |
||||||
|
|
||||||
|
} |
||||||
|
else { |
||||||
|
panic!("Node type mismatch") |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn trajectory_generator( |
||||||
|
&self, |
||||||
|
n_trajectories: u64, |
||||||
|
t_end: f64, |
||||||
|
seed: Option<u64>, |
||||||
|
) -> pytools::PyDataset { |
||||||
|
pytools::PyDataset(tools::trajectory_generator( |
||||||
|
&self.0, |
||||||
|
n_trajectories, |
||||||
|
t_end, |
||||||
|
seed, |
||||||
|
)) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,102 @@ |
|||||||
|
use numpy::{self, ToPyArray}; |
||||||
|
use pyo3::{exceptions::PyValueError, prelude::*}; |
||||||
|
use reCTBN::params::{self, ParamsTrait}; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
pub struct PyParamsError(params::ParamsError); |
||||||
|
|
||||||
|
impl From<PyParamsError> for PyErr { |
||||||
|
fn from(error: PyParamsError) -> Self { |
||||||
|
PyValueError::new_err(error.0.to_string()) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl From<params::ParamsError> for PyParamsError { |
||||||
|
fn from(other: params::ParamsError) -> Self { |
||||||
|
Self(other) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
pub struct PyStateType(pub params::StateType); |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
#[derive(Clone)] |
||||||
|
pub struct PyParams(pub params::Params); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyParams { |
||||||
|
#[staticmethod] |
||||||
|
pub fn new_discrete_state_continous_time(p: PyDiscreteStateContinousTime) -> Self{ |
||||||
|
PyParams(params::Params::DiscreteStatesContinousTime(p.0)) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_reserved_space_as_parent(&self) -> usize { |
||||||
|
self.0.get_reserved_space_as_parent() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_label(&self) -> String { |
||||||
|
self.0.get_label().to_string() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/// DiscreteStatesContinousTime.
|
||||||
|
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
|
||||||
|
/// following elements:
|
||||||
|
/// - **domain**: an ordered and exhaustive set of possible states
|
||||||
|
/// - **cim**: Conditional Intensity Matrix
|
||||||
|
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
|
||||||
|
/// learning task and are composed by:
|
||||||
|
/// - **transitions**: number of transitions from one state to another given a specific
|
||||||
|
/// realization of the parent set
|
||||||
|
/// - **residence_time**: permanence time in each possible states given a specific
|
||||||
|
/// realization of the parent set
|
||||||
|
#[derive(Clone)] |
||||||
|
#[pyclass] |
||||||
|
pub struct PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams); |
||||||
|
|
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyDiscreteStateContinousTime { |
||||||
|
#[new] |
||||||
|
pub fn new(label: String, domain: BTreeSet<String>) -> Self { |
||||||
|
PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams::new(label, domain)) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_cim<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<f64>> { |
||||||
|
match self.0.get_cim() { |
||||||
|
Some(x) => Some(x.to_pyarray(py)), |
||||||
|
None => None, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_cim<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<f64>) -> Result<(), PyParamsError> { |
||||||
|
self.0.set_cim(cim.as_array().to_owned())?; |
||||||
|
Ok(()) |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
pub fn get_transitions<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<usize>> { |
||||||
|
match self.0.get_transitions() { |
||||||
|
Some(x) => Some(x.to_pyarray(py)), |
||||||
|
None => None, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_transitions<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<usize>){ |
||||||
|
self.0.set_transitions(cim.as_array().to_owned()); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
pub fn get_residence_time<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray2<f64>> { |
||||||
|
match self.0.get_residence_time() { |
||||||
|
Some(x) => Some(x.to_pyarray(py)), |
||||||
|
None => None, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_residence_time<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray2<f64>) { |
||||||
|
self.0.set_residence_time(cim.as_array().to_owned()); |
||||||
|
} |
||||||
|
|
||||||
|
} |
@ -0,0 +1,50 @@ |
|||||||
|
use numpy::{self, ToPyArray}; |
||||||
|
use pyo3::{exceptions::PyValueError, prelude::*}; |
||||||
|
use reCTBN::{tools, network}; |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
#[derive(Clone)] |
||||||
|
pub struct PyTrajectory(pub tools::Trajectory); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyTrajectory { |
||||||
|
#[new] |
||||||
|
pub fn new( |
||||||
|
time: numpy::PyReadonlyArray1<f64>, |
||||||
|
events: numpy::PyReadonlyArray2<usize>, |
||||||
|
) -> PyTrajectory { |
||||||
|
PyTrajectory(tools::Trajectory::new( |
||||||
|
time.as_array().to_owned(), |
||||||
|
events.as_array().to_owned(), |
||||||
|
)) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_time<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray1<f64> { |
||||||
|
self.0.get_time().to_pyarray(py) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_events<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray2<usize> { |
||||||
|
self.0.get_events().to_pyarray(py) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
pub struct PyDataset(pub tools::Dataset); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyDataset { |
||||||
|
#[new] |
||||||
|
pub fn new(trajectories: Vec<PyTrajectory>) -> PyDataset { |
||||||
|
PyDataset(tools::Dataset::new(trajectories.into_iter().map(|x| x.0).collect())) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_number_of_trajectories(&self) -> usize { |
||||||
|
self.0.get_trajectories().len() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_trajectory(&self, idx: usize) -> PyTrajectory { |
||||||
|
PyTrajectory(self.0.get_trajectories().get(idx).unwrap().clone()) |
||||||
|
} |
||||||
|
|
||||||
|
} |
||||||
|
|
Loading…
Reference in new issue