From 1aaa252653d4394f0021eb72c61629fd9755cc37 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Wed, 2 Mar 2022 15:25:00 +0100 Subject: [PATCH] Added one test for tool --- src/ctbn.rs | 5 +++++ src/network.rs | 1 + src/params.rs | 8 ++++---- src/tools.rs | 54 ++++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/src/ctbn.rs b/src/ctbn.rs index b57180d..9153cb5 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -102,6 +102,11 @@ impl network::Network for CtbnNetwork { } + fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{ + &mut self.nodes[node_idx] + } + + fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize{ self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { if x.1 > &0 { diff --git a/src/network.rs b/src/network.rs index 17cc62b..4fb7e63 100644 --- a/src/network.rs +++ b/src/network.rs @@ -20,6 +20,7 @@ pub trait Network { ///Get all the indices of the nodes contained inside the network fn get_node_indices(&self) -> std::ops::Range; fn get_node(&self, node_idx: usize) -> &node::Node; + fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node; ///Compute the index that must be used to access the parameter of a node given a specific ///configuration of the network. Usually, the only values really used in *current_state* are diff --git a/src/params.rs b/src/params.rs index 964351d..1c7c7af 100644 --- a/src/params.rs +++ b/src/params.rs @@ -65,10 +65,10 @@ pub enum Params { /// - **residence_time**: permanence time in each possible states given a specific /// realization of the parent set pub struct DiscreteStatesContinousTimeParams { - domain: BTreeSet, - cim: Option>, - transitions: Option>, - residence_time: Option>, + pub domain: BTreeSet, + pub cim: Option>, + pub transitions: Option>, + pub residence_time: Option>, } impl DiscreteStatesContinousTimeParams { diff --git a/src/tools.rs b/src/tools.rs index 349f88e..eaf2ef6 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -14,7 +14,7 @@ pub struct Dataset { } -pub fn trajectory_generator(net: &Box, n_trajectories: u64, t_end: f64) -> Dataset { +pub fn trajectory_generator(net: Box, n_trajectories: u64, t_end: f64) -> Dataset { let mut dataset = Dataset{ trajectories: Vec::new() }; @@ -78,7 +78,7 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 events.push(current_state.iter().map(|x| match x { params::StateType::Discrete(state) => state.clone() }).collect()); - time.push(t.clone()); + time.push(t_end.clone()); dataset.trajectories.push(Trajectory { @@ -91,3 +91,53 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 dataset } + + +#[cfg(test)] +mod tests { + use super::*; + use crate::network::Network; + use crate::ctbn::*; + use crate::node; + use crate::params; + use std::collections::BTreeSet; + use ndarray::arr3; + + fn define_binary_node(name: String) -> node::Node { + let mut domain = BTreeSet::new(); + domain.insert(String::from("A")); + domain.insert(String::from("B")); + let param = params::DiscreteStatesContinousTimeParams::init(domain) ; + let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name); + return n; + } + + + #[test] + fn run_sampling() { + let mut net = CtbnNetwork::init(); + let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap(); + let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); + net.add_edge(n1, n2); + + match &mut net.get_node_mut(n1).params { + params::Params::DiscreteStatesContinousTime(param) => { + param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); + } + } + + + match &mut net.get_node_mut(n2).params { + params::Params::DiscreteStatesContinousTime(param) => { + param.cim = Some (arr3(&[ + [[-1.0,1.0],[4.0,-4.0]], + [[-6.0,6.0],[2.0,-2.0]]])); + } + } + + let data = trajectory_generator(Box::from(net), 4, 1.0); + + assert_eq!(4, data.trajectories.len()); + assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); + } +}