Sampling almost done

main
Alessandro Bregoli 3 years ago
parent 9d98c7a6ee
commit d8dc28facb
  1. 1
      Cargo.toml
  2. 85
      src/ctbn.rs
  3. 17
      src/network.rs
  4. 24
      src/node.rs
  5. 1
      src/params.rs
  6. 51
      src/tools.rs

@ -7,7 +7,6 @@ edition = "2021"
[dependencies]
petgraph = "*"
ndarray = "*"
thiserror = "*"
rand = "*"

@ -1,6 +1,5 @@
use std::collections::{HashMap, BTreeSet};
use petgraph::prelude::*;
use ndarray::prelude::*;
use crate::node;
use crate::params::StateType;
use crate::network;
@ -9,50 +8,80 @@ use crate::network;
pub struct CtbnNetwork {
network: petgraph::stable_graph::StableGraph<node::Node, ()>,
adj_matrix: Option<Array2<u16>>,
nodes: Vec<node::Node>
}
impl network::Network for CtbnNetwork {
fn add_node(&mut self, mut n: node::Node) -> Result<petgraph::graph::NodeIndex, network::NetworkError> {
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f()));
}
fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> {
n.reset_params();
Ok(self.network.add_node(n))
self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() -1)
}
fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex) {
self.network.add_edge(parent.clone(), child.clone(), {});
let mut p = self.network.node_weight_mut(child.clone()).unwrap();
p.reset_params();
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
}
fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices<node::Node>{
self.network.node_indices()
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].reset_params();
}
}
fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node{
self.network.node_weight(node_idx.clone()).unwrap()
fn get_node_indices(&self) -> std::ops::Range<usize>{
0..self.nodes.len()
}
fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec<StateType>) -> usize{
self.network.neighbors_directed(node.clone(), Direction::Incoming).zip(u).fold((0, 1), |mut acc, x| {
let n = self.get_node(node);
acc.0 += n.state_to_index(&x.1) * acc.1;
acc.1 *= n.get_reserved_space_as_parent();
acc
fn get_node(&self, node_idx: usize) -> &node::Node{
&self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
}).0
}
fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec<StateType>) -> usize{
self.get_param_index_parents(node, &current_state.iter()
.zip(self.get_node_indices())
.filter_map(|x| {
match self.network.find_edge(x.1, node.clone()) {
Some(_) => Some(x.0.clone()),
None => None
fn get_parent_set(&self, node: usize) -> Vec<usize> {
self.adj_matrix.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| {
if x > &0 {
Some(idx)
} else {
None
}
}).collect()
}
fn get_children_set(&self, node: usize) -> Vec<usize>{
self.adj_matrix.as_ref()
.unwrap()
.row(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| {
if x > &0 {
Some(idx)
} else {
None
}
}).collect()
)
}
}

@ -1,6 +1,7 @@
use petgraph::prelude::*; use crate::node;
use thiserror::Error;
use crate::params;
use ndarray::prelude::*;
use crate::node;
#[derive(Error, Debug)]
pub enum NetworkError {
@ -9,10 +10,12 @@ pub enum NetworkError {
}
pub trait Network {
fn add_node(&mut self, n: node::Node) -> Result<petgraph::graph::NodeIndex, NetworkError>;
fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex);
fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices<node::Node>;
fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node;
fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec<params::StateType>) -> usize;
fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec<params::StateType>) -> usize;
fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: node::Node) -> Result<usize, NetworkError>;
fn add_edge(&mut self, parent: usize, child: usize);
fn get_node_indices(&self) -> std::ops::Range<usize>;
fn get_node(&self, node_idx: usize) -> &node::Node;
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
fn get_parent_set(&self, node: usize) -> Vec<usize>;
fn get_children_set(&self, node: usize) -> Vec<usize>;
}

@ -1,4 +1,4 @@
use std::collections::BTreeSet; use petgraph::prelude::*;
use std::collections::BTreeSet;
use crate::params::*;
pub enum NodeType {
@ -33,6 +33,28 @@ impl Node {
}
}
pub fn get_random_residence_time(&self, state: usize, u:usize) -> Result<f64, ParamsError> {
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_residence_time(state, u)
}
}
pub fn get_random_state_uniform(&self) -> StateType {
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_state_uniform()
}
}
pub fn get_random_state(&self, state: usize, u:usize) -> Result<StateType, ParamsError>{
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_state(state, u)
}
}
}
impl PartialEq for Node {

@ -1,6 +1,5 @@
use ndarray::prelude::*;
use std::collections::{HashMap, BTreeSet};
use petgraph::prelude::*;
use rand::Rng;
use thiserror::Error;

@ -1,5 +1,4 @@
use ndarray::prelude::*;
use petgraph::prelude::*;
use crate::network;
use crate::node;
use crate::params;
@ -22,26 +21,52 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
let node_idx: Vec<_> = net.get_node_indices().collect();
for _ in 0..n_trajectories {
let t = 0.0;
let mut t = 0.0;
let mut time: Vec<f64> = Vec::new();
let mut events: Vec<Vec<params::StateType>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
match net.get_node(&x).get_params() {
node::NodeType::DiscreteStatesContinousTime(params) =>
params.get_random_state_uniform()
}
net.get_node(*x).get_random_state_uniform()
}).collect();
let next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect();
let mut next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect();
events.push(current_state.clone());
time.push(t.clone());
while t < t_end {
next_transitions.iter_mut().zip(net.get_node_indices()).map(|x| {
if let None = x.0 {
*(x.0) = Some(match net.get_node(&x.1).get_params(){
node::NodeType::DiscreteStatesContinousTime(params) =>
params.get_random_residence_time(x.1, net.get_param_index_network(&x.1, &current_state)).unwrap()
next_transitions.iter_mut().enumerate().map(|(idx, val)| {
if let None = val {
*val = Some(net.get_node(idx)
.get_random_residence_time(net.get_node(idx).state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state)).unwrap() + t);
}
});
}});
let next_node_transition = next_transitions
.iter()
.enumerate()
.min_by(|x, y|
x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
.unwrap().0;
if next_transitions[next_node_transition].unwrap() > t_end {
break
}
t = next_transitions[next_node_transition].unwrap().clone();
time.push(t.clone());
current_state[next_node_transition] = net.get_node(next_node_transition)
.get_random_state(
net.get_node(next_node_transition).
state_to_index(
&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state))
.unwrap();
events.push(current_state.clone());
next_transitions[next_node_transition] = None;
for child in net.get_children_set(next_node_transition){
next_transitions[child] = None
}
}

Loading…
Cancel
Save