|
|
|
@ -14,7 +14,7 @@ pub struct Dataset { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset { |
|
|
|
|
pub fn trajectory_generator(net: Box<dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset { |
|
|
|
|
let mut dataset = Dataset{ |
|
|
|
|
trajectories: Vec::new() |
|
|
|
|
}; |
|
|
|
@ -78,7 +78,7 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64 |
|
|
|
|
events.push(current_state.iter().map(|x| match x { |
|
|
|
|
params::StateType::Discrete(state) => state.clone() |
|
|
|
|
}).collect()); |
|
|
|
|
time.push(t.clone()); |
|
|
|
|
time.push(t_end.clone()); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset.trajectories.push(Trajectory { |
|
|
|
@ -91,3 +91,53 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64 |
|
|
|
|
|
|
|
|
|
dataset |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[cfg(test)] |
|
|
|
|
mod tests { |
|
|
|
|
use super::*; |
|
|
|
|
use crate::network::Network; |
|
|
|
|
use crate::ctbn::*; |
|
|
|
|
use crate::node; |
|
|
|
|
use crate::params; |
|
|
|
|
use std::collections::BTreeSet; |
|
|
|
|
use ndarray::arr3; |
|
|
|
|
|
|
|
|
|
fn define_binary_node(name: String) -> node::Node { |
|
|
|
|
let mut domain = BTreeSet::new(); |
|
|
|
|
domain.insert(String::from("A")); |
|
|
|
|
domain.insert(String::from("B")); |
|
|
|
|
let param = params::DiscreteStatesContinousTimeParams::init(domain) ; |
|
|
|
|
let n = node::Node::init(params::Params::DiscreteStatesContinousTime(param), name); |
|
|
|
|
return n; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn run_sampling() { |
|
|
|
|
let mut net = CtbnNetwork::init(); |
|
|
|
|
let n1 = net.add_node(define_binary_node(String::from("n1"))).unwrap(); |
|
|
|
|
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); |
|
|
|
|
net.add_edge(n1, n2); |
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n1).params { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
match &mut net.get_node_mut(n2).params { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
param.cim = Some (arr3(&[ |
|
|
|
|
[[-1.0,1.0],[4.0,-4.0]], |
|
|
|
|
[[-6.0,6.0],[2.0,-2.0]]])); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
let data = trajectory_generator(Box::from(net), 4, 1.0); |
|
|
|
|
|
|
|
|
|
assert_eq!(4, data.trajectories.len()); |
|
|
|
|
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|