diff --git a/Cargo.toml b/Cargo.toml index f8d3b03..b5d70a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" [dependencies] -petgraph = "*" ndarray = "*" thiserror = "*" rand = "*" diff --git a/src/ctbn.rs b/src/ctbn.rs index 6e5f0b4..0bdd575 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -1,6 +1,5 @@ use std::collections::{HashMap, BTreeSet}; -use petgraph::prelude::*; - +use ndarray::prelude::*; use crate::node; use crate::params::StateType; use crate::network; @@ -9,50 +8,80 @@ use crate::network; pub struct CtbnNetwork { - network: petgraph::stable_graph::StableGraph, + adj_matrix: Option>, + nodes: Vec } impl network::Network for CtbnNetwork { - fn add_node(&mut self, mut n: node::Node) -> Result { + fn initialize_adj_matrix(&mut self) { + self.adj_matrix = Some(Array2::::zeros((self.nodes.len(), self.nodes.len()).f())); + + } + + fn add_node(&mut self, mut n: node::Node) -> Result { n.reset_params(); - Ok(self.network.add_node(n)) + self.adj_matrix = Option::None; + self.nodes.push(n); + Ok(self.nodes.len() -1) } - fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex) { - self.network.add_edge(parent.clone(), child.clone(), {}); - let mut p = self.network.node_weight_mut(child.clone()).unwrap(); - p.reset_params(); + fn add_edge(&mut self, parent: usize, child: usize) { + if let None = self.adj_matrix { + self.initialize_adj_matrix(); + } + + if let Some(network) = &mut self.adj_matrix { + network[[parent, child]] = 1; + self.nodes[child].reset_params(); + } } - fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices{ - self.network.node_indices() + fn get_node_indices(&self) -> std::ops::Range{ + 0..self.nodes.len() } - fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node{ - self.network.node_weight(node_idx.clone()).unwrap() + fn get_node(&self, node_idx: usize) -> &node::Node{ + &self.nodes[node_idx] } - fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec) -> usize{ - self.network.neighbors_directed(node.clone(), Direction::Incoming).zip(u).fold((0, 1), |mut acc, x| { - let n = self.get_node(node); - acc.0 += n.state_to_index(&x.1) * acc.1; - acc.1 *= n.get_reserved_space_as_parent(); - acc + fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize{ + self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { + if x.1 > &0 { + acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; + acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); + } + acc }).0 } - - fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec) -> usize{ - self.get_param_index_parents(node, ¤t_state.iter() - .zip(self.get_node_indices()) - .filter_map(|x| { - match self.network.find_edge(x.1, node.clone()) { - Some(_) => Some(x.0.clone()), - None => None - } - }).collect() - ) + fn get_parent_set(&self, node: usize) -> Vec { + self.adj_matrix.as_ref() + .unwrap() + .column(node) + .iter() + .enumerate() + .filter_map(|(idx, x)| { + if x > &0 { + Some(idx) + } else { + None + } + }).collect() + } + fn get_children_set(&self, node: usize) -> Vec{ + self.adj_matrix.as_ref() + .unwrap() + .row(node) + .iter() + .enumerate() + .filter_map(|(idx, x)| { + if x > &0 { + Some(idx) + } else { + None + } + }).collect() } } diff --git a/src/network.rs b/src/network.rs index cd1f8a9..03d132e 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,6 +1,7 @@ -use petgraph::prelude::*; use crate::node; use thiserror::Error; use crate::params; +use ndarray::prelude::*; +use crate::node; #[derive(Error, Debug)] pub enum NetworkError { @@ -9,10 +10,12 @@ pub enum NetworkError { } pub trait Network { - fn add_node(&mut self, n: node::Node) -> Result; - fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex); - fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices; - fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node; - fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec) -> usize; - fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec) -> usize; + fn initialize_adj_matrix(&mut self); + fn add_node(&mut self, n: node::Node) -> Result; + fn add_edge(&mut self, parent: usize, child: usize); + fn get_node_indices(&self) -> std::ops::Range; + fn get_node(&self, node_idx: usize) -> &node::Node; + fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize; + fn get_parent_set(&self, node: usize) -> Vec; + fn get_children_set(&self, node: usize) -> Vec; } diff --git a/src/node.rs b/src/node.rs index 2707d69..8c922c8 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeSet; use petgraph::prelude::*; +use std::collections::BTreeSet; use crate::params::*; pub enum NodeType { @@ -33,6 +33,28 @@ impl Node { } } + + pub fn get_random_residence_time(&self, state: usize, u:usize) -> Result { + match &self.params { + NodeType::DiscreteStatesContinousTime(params) => params.get_random_residence_time(state, u) + } + } + + + pub fn get_random_state_uniform(&self) -> StateType { + match &self.params { + NodeType::DiscreteStatesContinousTime(params) => params.get_random_state_uniform() + } + } + + + pub fn get_random_state(&self, state: usize, u:usize) -> Result{ + match &self.params { + NodeType::DiscreteStatesContinousTime(params) => params.get_random_state(state, u) + } + } + + } impl PartialEq for Node { diff --git a/src/params.rs b/src/params.rs index 2cd24db..f3793a3 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,6 +1,5 @@ use ndarray::prelude::*; use std::collections::{HashMap, BTreeSet}; -use petgraph::prelude::*; use rand::Rng; use thiserror::Error; diff --git a/src/tools.rs b/src/tools.rs index 6258a19..048a77b 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,5 +1,4 @@ use ndarray::prelude::*; -use petgraph::prelude::*; use crate::network; use crate::node; use crate::params; @@ -22,28 +21,54 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 let node_idx: Vec<_> = net.get_node_indices().collect(); for _ in 0..n_trajectories { - let t = 0.0; + let mut t = 0.0; let mut time: Vec = Vec::new(); let mut events: Vec> = Vec::new(); let mut current_state: Vec = node_idx.iter().map(|x| { - match net.get_node(&x).get_params() { - node::NodeType::DiscreteStatesContinousTime(params) => - params.get_random_state_uniform() -} + net.get_node(*x).get_random_state_uniform() }).collect(); - let next_transitions: Vec> = (0..node_idx.len()).map(|_| Option::None).collect(); + let mut next_transitions: Vec> = (0..node_idx.len()).map(|_| Option::None).collect(); events.push(current_state.clone()); time.push(t.clone()); while t < t_end { - next_transitions.iter_mut().zip(net.get_node_indices()).map(|x| { - if let None = x.0 { - *(x.0) = Some(match net.get_node(&x.1).get_params(){ - node::NodeType::DiscreteStatesContinousTime(params) => - params.get_random_residence_time(x.1, net.get_param_index_network(&x.1, ¤t_state)).unwrap() - }); - }}); + next_transitions.iter_mut().enumerate().map(|(idx, val)| { + if let None = val { + *val = Some(net.get_node(idx) + .get_random_residence_time(net.get_node(idx).state_to_index(¤t_state[idx]), + net.get_param_index_network(idx, ¤t_state)).unwrap() + t); + } + }); + let next_node_transition = next_transitions + .iter() + .enumerate() + .min_by(|x, y| + x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) + .unwrap().0; + + if next_transitions[next_node_transition].unwrap() > t_end { + break + } + + t = next_transitions[next_node_transition].unwrap().clone(); + time.push(t.clone()); + + current_state[next_node_transition] = net.get_node(next_node_transition) + .get_random_state( + net.get_node(next_node_transition). + state_to_index( + ¤t_state[next_node_transition]), + net.get_param_index_network(next_node_transition, ¤t_state)) + .unwrap(); + events.push(current_state.clone()); + next_transitions[next_node_transition] = None; + for child in net.get_children_set(next_node_transition){ + next_transitions[child] = None + } + + + }