Added one test for tool

pull/19/head
AlessandroBregoli 3 years ago
parent b1978b70dd
commit 1aaa252653
  1. 5
      src/ctbn.rs
  2. 1
      src/network.rs
  3. 8
      src/params.rs
  4. 54
      src/tools.rs

@ -102,6 +102,11 @@ impl network::Network for CtbnNetwork {
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{
&mut self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
if x.1 > &0 {

@ -20,6 +20,7 @@ pub trait Network {
///Get all the indices of the nodes contained inside the network
fn get_node_indices(&self) -> std::ops::Range<usize>;
fn get_node(&self, node_idx: usize) -> &node::Node;
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node;
///Compute the index that must be used to access the parameter of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are

@ -65,10 +65,10 @@ pub enum Params {
/// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set
pub struct DiscreteStatesContinousTimeParams {
domain: BTreeSet<String>,
cim: Option<Array3<f64>>,
transitions: Option<Array3<u64>>,
residence_time: Option<Array2<f64>>,
pub domain: BTreeSet<String>,
pub cim: Option<Array3<f64>>,
pub transitions: Option<Array3<u64>>,
pub residence_time: Option<Array2<f64>>,
}
impl DiscreteStatesContinousTimeParams {

@ -14,7 +14,7 @@ pub struct Dataset {
}
pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
pub fn trajectory_generator(net: Box<dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
let mut dataset = Dataset{
trajectories: Vec::new()
};
@ -78,7 +78,7 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
events.push(current_state.iter().map(|x| match x {
params::StateType::Discrete(state) => state.clone()
}).collect());
time.push(t.clone());
time.push(t_end.clone());
dataset.trajectories.push(Trajectory {
@ -91,3 +91,53 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
dataset
}
#[cfg(test)]
mod tests {
use super::*;
use crate::network::Network;
use crate::ctbn::*;
use crate::node;
use crate::params;
use std::collections::BTreeSet;
use ndarray::arr3;
fn define_binary_node(name: String) -> node::Node {
let mut domain = BTreeSet::new();
domain.insert(String::from("A"));
domain.insert(String::from("B"));
let param = params::DiscreteStatesContinousTimeParams::init(domain) ;
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name);
return n;
}
#[test]
fn run_sampling() {
let mut net = CtbnNetwork::init();
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap();
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[
[[-1.0,1.0],[4.0,-4.0]],
[[-6.0,6.0],[2.0,-2.0]]]));
}
}
let data = trajectory_generator(Box::from(net), 4, 1.0);
assert_eq!(4, data.trajectories.len());
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]);
}
}

Loading…
Cancel
Save