From 9d98c7a6ee91d0b76915f58c3305895f801fa252 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Fri, 18 Feb 2022 17:22:40 +0100 Subject: [PATCH] Partial commit --- src/ctbn.rs | 15 ++++++++++++++- src/network.rs | 3 ++- src/tools.rs | 8 ++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/ctbn.rs b/src/ctbn.rs index 8d8dd4c..6e5f0b4 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -32,7 +32,7 @@ impl network::Network for CtbnNetwork { self.network.node_weight(node_idx.clone()).unwrap() } - fn get_param_index(&self, node: &petgraph::stable_graph::NodeIndex, u: Vec) -> usize{ + fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec) -> usize{ self.network.neighbors_directed(node.clone(), Direction::Incoming).zip(u).fold((0, 1), |mut acc, x| { let n = self.get_node(node); acc.0 += n.state_to_index(&x.1) * acc.1; @@ -42,4 +42,17 @@ impl network::Network for CtbnNetwork { }).0 } + + fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec) -> usize{ + self.get_param_index_parents(node, ¤t_state.iter() + .zip(self.get_node_indices()) + .filter_map(|x| { + match self.network.find_edge(x.1, node.clone()) { + Some(_) => Some(x.0.clone()), + None => None + } + }).collect() + ) + } + } diff --git a/src/network.rs b/src/network.rs index b7a03d8..cd1f8a9 100644 --- a/src/network.rs +++ b/src/network.rs @@ -13,5 +13,6 @@ pub trait Network { fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex); fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices; fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node; - fn get_param_index(&self, node: &petgraph::stable_graph::NodeIndex, u: Vec) -> usize; + fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec) -> usize; + fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec) -> usize; } diff --git a/src/tools.rs b/src/tools.rs index c56bf1c..6258a19 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -35,6 +35,14 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 events.push(current_state.clone()); time.push(t.clone()); while t < t_end { + next_transitions.iter_mut().zip(net.get_node_indices()).map(|x| { + if let None = x.0 { + *(x.0) = Some(match net.get_node(&x.1).get_params(){ + node::NodeType::DiscreteStatesContinousTime(params) => + params.get_random_residence_time(x.1, net.get_param_index_network(&x.1, ¤t_state)).unwrap() + }); + }}); + }