Added random generation of DiscreteStateContinousTime node

main
AlessandroBregoli 3 years ago
parent 67f77c3361
commit 97a2d16e6d
  1. 20
      src/ctbn.rs
  2. 3
      src/network.rs
  3. 27
      src/node.rs
  4. 71
      src/params.rs
  5. 23
      src/tools.rs

@ -13,25 +13,15 @@ pub struct CtbnNetwork {
} }
impl network::Network for CtbnNetwork { impl network::Network for CtbnNetwork {
fn add_node(&mut self, n: node::Node) -> Result<petgraph::graph::NodeIndex, network::NetworkError> { fn add_node(&mut self, mut n: node::Node) -> Result<petgraph::graph::NodeIndex, network::NetworkError> {
match &n.params { n.reset_params();
node::ParamsType::DiscreteStatesContinousTime(_) => { Ok(self.network.add_node(n))
if self.network.node_weights().any(|x| x.label == n.label) {
//TODO: Insert a better error description
return Err(network::NetworkError::NodeInsertionError(String::from("Label already used")));
}
let idx = self.network.add_node(n);
Ok(idx)
},
//TODO: Insert a better error description
_ => Err(network::NetworkError::NodeInsertionError(String::from("unsupported node")))
}
} }
fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex) { fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex) {
self.network.add_edge(parent.clone(), child.clone(), {}); self.network.add_edge(parent.clone(), child.clone(), {});
let mut p = self.network.node_weight(child.clone()); let mut p = self.network.node_weight_mut(child.clone()).unwrap();
match p. p.reset_params();
} }
fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices<node::Node>{ fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices<node::Node>{

@ -1,5 +1,4 @@
use petgraph::prelude::*; use petgraph::prelude::*; use crate::node;
use crate::node;
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug)] #[derive(Error, Debug)]

@ -1,12 +1,35 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use crate::params; use petgraph::prelude::*;
use crate::params::*;
pub enum NodeType {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams)
}
pub struct Node { pub struct Node {
pub params: Box<dyn params::Params>, pub params: NodeType,
pub label: String pub label: String
} }
impl Node {
pub fn add_parent(&mut self, parent: &petgraph::stable_graph::NodeIndex) {
match &mut self.params {
NodeType::DiscreteStatesContinousTime(params) => {params.add_parent(parent);}
}
}
pub fn reset_params(&mut self) {
match &mut self.params {
NodeType::DiscreteStatesContinousTime(params) => {params.reset_params();}
}
}
pub fn get_params(&self) -> &NodeType {
&self.params
}
}
impl PartialEq for Node { impl PartialEq for Node {
fn eq(&self, other: &Node) -> bool{ fn eq(&self, other: &Node) -> bool{
self.label == other.label self.label == other.label

@ -1,39 +1,94 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use std::collections::{HashMap, BTreeSet}; use std::collections::{HashMap, BTreeSet};
use petgraph::prelude::*; use petgraph::prelude::*;
use rand::Rng;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum ParamsError {
#[error("Unsupported method")]
UnsupportedMethod(String),
#[error("Paramiters not initialized")]
ParametersNotInitialized(String)
}
pub trait Params { pub trait Params {
fn reset_params(&mut self);
fn get_random_state_uniform(&self) -> StateType;
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>;
fn get_random_state(&self, state: usize, u:usize) -> Result<StateType, ParamsError>;
}
fn add_parent(&mut self, p: &petgraph::stable_graph::NodeIndex); pub enum StateType {
Discrete(u32)
} }
pub struct DiscreteStatesContinousTimeParams { pub struct DiscreteStatesContinousTimeParams {
domain: BTreeSet<String>, domain: BTreeSet<String>,
parents: BTreeSet<petgraph::stable_graph::NodeIndex>,
cim: Option<Array3<f64>>, cim: Option<Array3<f64>>,
transitions: Option<Array3<u64>>, transitions: Option<Array3<u64>>,
residence_time: Option<Array2<f64>> residence_time: Option<Array2<f64>>
} }
impl DiscreteStatesContinousTimeParams { impl DiscreteStatesContinousTimeParams {
fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams { pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams {
domain: domain, domain: domain,
parents: BTreeSet::new(),
cim: Option::None, cim: Option::None,
transitions: Option::None, transitions: Option::None,
residence_time: Option::None residence_time: Option::None
} }
} }
} }
impl Params for DiscreteStatesContinousTimeParams { impl Params for DiscreteStatesContinousTimeParams {
fn add_parent(&mut self, p: &petgraph::stable_graph::NodeIndex) {
self.parents.insert(p.clone()); fn reset_params(&mut self) {
self.cim = Option::None; self.cim = Option::None;
self.transitions = Option::None; self.transitions = Option::None;
self.residence_time = Option::None; self.residence_time = Option::None;
} }
fn get_random_state_uniform(&self) -> StateType {
let mut rng = rand::thread_rng();
StateType::Discrete(rng.gen_range(0..(self.domain.len() as u32)))
}
fn get_random_residence_time(&self, state: usize, u:usize) -> Result<f64, ParamsError> {
match &self.cim {
Option::Some(cim) => {
let mut rng = rand::thread_rng();
let lambda = cim[[u, state, state]] * -1.0;
let x:f64 = rng.gen_range(0.0..1.0);
Ok(-x.ln()/lambda)
},
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized")))
}
}
fn get_random_state(&self, state: usize, u:usize) -> Result<StateType, ParamsError>{
match &self.cim {
Option::Some(cim) => {
let mut rng = rand::thread_rng();
let lambda = cim[[u, state, state]] * -1.0;
let x = rng.gen_range(0.0..1.0);
let state = (cim.slice(s![u,state,..])).iter().scan((0, 0.0), |acc, &x| {
if x > 0.0 && acc.1 < x {
acc.0 += 1;
acc.1 += x;
return Some(*acc);
} else if acc.1 < x {
acc.0 += 1;
return Some(*acc);
}
None
}).last();
Ok(StateType::Discrete(state.unwrap().0))
},
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized")))
}
}
} }

@ -1,7 +1,8 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::network;
use petgraph::prelude::*; use petgraph::prelude::*;
use rand::Rng; use crate::network;
use crate::node;
use crate::params::Params;
pub struct Trajectory { pub struct Trajectory {
time: Array1<f64>, time: Array1<f64>,
@ -17,12 +18,26 @@ fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64, t_
let mut dataset = Dataset{ let mut dataset = Dataset{
trajectories: Vec::new() trajectories: Vec::new()
}; };
let node_idx: Vec<_> = net.get_node_indices().collect();
for _ in 0..n_trajectories { for _ in 0..n_trajectories {
let mut rng = rand::thread_rng();
let t = 0.0; let t = 0.0;
let mut time: Vec<f64> = Vec::new(); let mut time: Vec<f64> = Vec::new();
let mut events: Vec<Vec<u32>> = Vec::new(); let mut events: Vec<Vec<u32>> = Vec::new();
let current_state: Vec<u32> = net.get_node_indices().map(|x| rng.gen_range(0..2)).collect(); let mut current_state: Vec<u32> = node_idx.iter().map(|x| {
match net.get_node_weight(&x).get_params() {
node::NodeType::DiscreteStatesContinousTime(params) =>
params.get_random_state_uniform()
}
}).collect();
let next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect();
events.push(current_state.clone());
time.push(t.clone());
while t < t_end {
}
} }

Loading…
Cancel
Save