|
|
|
@ -5,16 +5,18 @@ use ndarray::prelude::*; |
|
|
|
|
use crate::network; |
|
|
|
|
use crate::params::{Params, ParamsTrait, StateType}; |
|
|
|
|
|
|
|
|
|
///CTBN network. It represents both the structure and the parameters of a CTBN. CtbnNetwork is
|
|
|
|
|
///composed by the following elements:
|
|
|
|
|
///- **adj_metrix**: a 2d ndarray representing the adjacency matrix
|
|
|
|
|
///- **nodes**: a vector containing all the nodes and their parameters.
|
|
|
|
|
///The index of a node inside the vector is also used as index for the adj_matrix.
|
|
|
|
|
/// It represents both the structure and the parameters of a CTBN.
|
|
|
|
|
///
|
|
|
|
|
///# Examples
|
|
|
|
|
/// # Arguments
|
|
|
|
|
///
|
|
|
|
|
///```
|
|
|
|
|
/// * `adj_matrix` - A 2D ndarray representing the adjacency matrix
|
|
|
|
|
/// * `nodes` - A vector containing all the nodes and their parameters.
|
|
|
|
|
///
|
|
|
|
|
/// The index of a node inside the vector is also used as index for the `adj_matrix`.
|
|
|
|
|
///
|
|
|
|
|
/// # Example
|
|
|
|
|
///
|
|
|
|
|
/// ```rust
|
|
|
|
|
/// use std::collections::BTreeSet;
|
|
|
|
|
/// use reCTBN::network::Network;
|
|
|
|
|
/// use reCTBN::params;
|
|
|
|
@ -66,12 +68,14 @@ impl CtbnNetwork { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl network::Network for CtbnNetwork { |
|
|
|
|
/// Initialize an Adjacency matrix.
|
|
|
|
|
fn initialize_adj_matrix(&mut self) { |
|
|
|
|
self.adj_matrix = Some(Array2::<u16>::zeros( |
|
|
|
|
(self.nodes.len(), self.nodes.len()).f(), |
|
|
|
|
)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Add a new node.
|
|
|
|
|
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> { |
|
|
|
|
n.reset_params(); |
|
|
|
|
self.adj_matrix = Option::None; |
|
|
|
@ -79,6 +83,7 @@ impl network::Network for CtbnNetwork { |
|
|
|
|
Ok(self.nodes.len() - 1) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Connect two nodes with a new edge.
|
|
|
|
|
fn add_edge(&mut self, parent: usize, child: usize) { |
|
|
|
|
if let None = self.adj_matrix { |
|
|
|
|
self.initialize_adj_matrix(); |
|
|
|
@ -94,6 +99,7 @@ impl network::Network for CtbnNetwork { |
|
|
|
|
0..self.nodes.len() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Get the number of nodes of the network.
|
|
|
|
|
fn get_number_of_nodes(&self) -> usize { |
|
|
|
|
self.nodes.len() |
|
|
|
|
} |
|
|
|
@ -138,6 +144,7 @@ impl network::Network for CtbnNetwork { |
|
|
|
|
.0 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Get all the parents of the given node.
|
|
|
|
|
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
|
|
|
|
self.adj_matrix |
|
|
|
|
.as_ref() |
|
|
|
@ -149,6 +156,7 @@ impl network::Network for CtbnNetwork { |
|
|
|
|
.collect() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Get all the children of the given node.
|
|
|
|
|
fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
|
|
|
|
self.adj_matrix |
|
|
|
|
.as_ref() |
|
|
|
|