Sampling almost done

main
Alessandro Bregoli 3 years ago
parent 9d98c7a6ee
commit d8dc28facb
  1. 1
      Cargo.toml
  2. 89
      src/ctbn.rs
  3. 17
      src/network.rs
  4. 24
      src/node.rs
  5. 1
      src/params.rs
  6. 53
      src/tools.rs

@ -7,7 +7,6 @@ edition = "2021"
[dependencies] [dependencies]
petgraph = "*"
ndarray = "*" ndarray = "*"
thiserror = "*" thiserror = "*"
rand = "*" rand = "*"

@ -1,6 +1,5 @@
use std::collections::{HashMap, BTreeSet}; use std::collections::{HashMap, BTreeSet};
use petgraph::prelude::*; use ndarray::prelude::*;
use crate::node; use crate::node;
use crate::params::StateType; use crate::params::StateType;
use crate::network; use crate::network;
@ -9,50 +8,80 @@ use crate::network;
pub struct CtbnNetwork { pub struct CtbnNetwork {
network: petgraph::stable_graph::StableGraph<node::Node, ()>, adj_matrix: Option<Array2<u16>>,
nodes: Vec<node::Node>
} }
impl network::Network for CtbnNetwork { impl network::Network for CtbnNetwork {
fn add_node(&mut self, mut n: node::Node) -> Result<petgraph::graph::NodeIndex, network::NetworkError> { fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f()));
}
fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> {
n.reset_params(); n.reset_params();
Ok(self.network.add_node(n)) self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() -1)
} }
fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex) { fn add_edge(&mut self, parent: usize, child: usize) {
self.network.add_edge(parent.clone(), child.clone(), {}); if let None = self.adj_matrix {
let mut p = self.network.node_weight_mut(child.clone()).unwrap(); self.initialize_adj_matrix();
p.reset_params(); }
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].reset_params();
}
} }
fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices<node::Node>{ fn get_node_indices(&self) -> std::ops::Range<usize>{
self.network.node_indices() 0..self.nodes.len()
} }
fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node{ fn get_node(&self, node_idx: usize) -> &node::Node{
self.network.node_weight(node_idx.clone()).unwrap() &self.nodes[node_idx]
} }
fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec<StateType>) -> usize{
self.network.neighbors_directed(node.clone(), Direction::Incoming).zip(u).fold((0, 1), |mut acc, x| {
let n = self.get_node(node);
acc.0 += n.state_to_index(&x.1) * acc.1;
acc.1 *= n.get_reserved_space_as_parent();
acc
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
}).0 }).0
} }
fn get_parent_set(&self, node: usize) -> Vec<usize> {
fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec<StateType>) -> usize{ self.adj_matrix.as_ref()
self.get_param_index_parents(node, &current_state.iter() .unwrap()
.zip(self.get_node_indices()) .column(node)
.filter_map(|x| { .iter()
match self.network.find_edge(x.1, node.clone()) { .enumerate()
Some(_) => Some(x.0.clone()), .filter_map(|(idx, x)| {
None => None if x > &0 {
} Some(idx)
}).collect() } else {
) None
}
}).collect()
}
fn get_children_set(&self, node: usize) -> Vec<usize>{
self.adj_matrix.as_ref()
.unwrap()
.row(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| {
if x > &0 {
Some(idx)
} else {
None
}
}).collect()
} }
} }

@ -1,6 +1,7 @@
use petgraph::prelude::*; use crate::node;
use thiserror::Error; use thiserror::Error;
use crate::params; use crate::params;
use ndarray::prelude::*;
use crate::node;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum NetworkError { pub enum NetworkError {
@ -9,10 +10,12 @@ pub enum NetworkError {
} }
pub trait Network { pub trait Network {
fn add_node(&mut self, n: node::Node) -> Result<petgraph::graph::NodeIndex, NetworkError>; fn initialize_adj_matrix(&mut self);
fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex); fn add_node(&mut self, n: node::Node) -> Result<usize, NetworkError>;
fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices<node::Node>; fn add_edge(&mut self, parent: usize, child: usize);
fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node; fn get_node_indices(&self) -> std::ops::Range<usize>;
fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec<params::StateType>) -> usize; fn get_node(&self, node_idx: usize) -> &node::Node;
fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec<params::StateType>) -> usize; fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
fn get_parent_set(&self, node: usize) -> Vec<usize>;
fn get_children_set(&self, node: usize) -> Vec<usize>;
} }

@ -1,4 +1,4 @@
use std::collections::BTreeSet; use petgraph::prelude::*; use std::collections::BTreeSet;
use crate::params::*; use crate::params::*;
pub enum NodeType { pub enum NodeType {
@ -33,6 +33,28 @@ impl Node {
} }
} }
pub fn get_random_residence_time(&self, state: usize, u:usize) -> Result<f64, ParamsError> {
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_residence_time(state, u)
}
}
pub fn get_random_state_uniform(&self) -> StateType {
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_state_uniform()
}
}
pub fn get_random_state(&self, state: usize, u:usize) -> Result<StateType, ParamsError>{
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_state(state, u)
}
}
} }
impl PartialEq for Node { impl PartialEq for Node {

@ -1,6 +1,5 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use std::collections::{HashMap, BTreeSet}; use std::collections::{HashMap, BTreeSet};
use petgraph::prelude::*;
use rand::Rng; use rand::Rng;
use thiserror::Error; use thiserror::Error;

@ -1,5 +1,4 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use petgraph::prelude::*;
use crate::network; use crate::network;
use crate::node; use crate::node;
use crate::params; use crate::params;
@ -22,28 +21,54 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
let node_idx: Vec<_> = net.get_node_indices().collect(); let node_idx: Vec<_> = net.get_node_indices().collect();
for _ in 0..n_trajectories { for _ in 0..n_trajectories {
let t = 0.0; let mut t = 0.0;
let mut time: Vec<f64> = Vec::new(); let mut time: Vec<f64> = Vec::new();
let mut events: Vec<Vec<params::StateType>> = Vec::new(); let mut events: Vec<Vec<params::StateType>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| { let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
match net.get_node(&x).get_params() { net.get_node(*x).get_random_state_uniform()
node::NodeType::DiscreteStatesContinousTime(params) =>
params.get_random_state_uniform()
}
}).collect(); }).collect();
let next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect(); let mut next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect();
events.push(current_state.clone()); events.push(current_state.clone());
time.push(t.clone()); time.push(t.clone());
while t < t_end { while t < t_end {
next_transitions.iter_mut().zip(net.get_node_indices()).map(|x| { next_transitions.iter_mut().enumerate().map(|(idx, val)| {
if let None = x.0 { if let None = val {
*(x.0) = Some(match net.get_node(&x.1).get_params(){ *val = Some(net.get_node(idx)
node::NodeType::DiscreteStatesContinousTime(params) => .get_random_residence_time(net.get_node(idx).state_to_index(&current_state[idx]),
params.get_random_residence_time(x.1, net.get_param_index_network(&x.1, &current_state)).unwrap() net.get_param_index_network(idx, &current_state)).unwrap() + t);
}); }
}}); });
let next_node_transition = next_transitions
.iter()
.enumerate()
.min_by(|x, y|
x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
.unwrap().0;
if next_transitions[next_node_transition].unwrap() > t_end {
break
}
t = next_transitions[next_node_transition].unwrap().clone();
time.push(t.clone());
current_state[next_node_transition] = net.get_node(next_node_transition)
.get_random_state(
net.get_node(next_node_transition).
state_to_index(
&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state))
.unwrap();
events.push(current_state.clone());
next_transitions[next_node_transition] = None;
for child in net.get_children_set(next_node_transition){
next_transitions[child] = None
}
} }

Loading…
Cancel
Save