From 5cf1daccb7ed33e66e1ace9635550561d921ba5e Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Fri, 18 Feb 2022 16:18:27 +0100 Subject: [PATCH] Added an inelegant method for acces a parameter given it's parentset --- src/ctbn.rs | 14 ++++++++++++-- src/network.rs | 4 +++- src/node.rs | 21 +++++++++++++-------- src/params.rs | 14 ++++++++++++++ src/tools.rs | 9 +++++---- 5 files changed, 47 insertions(+), 15 deletions(-) diff --git a/src/ctbn.rs b/src/ctbn.rs index 05b2ac6..8d8dd4c 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, BTreeSet}; use petgraph::prelude::*; use crate::node; -use crate::params; +use crate::params::StateType; use crate::network; @@ -28,8 +28,18 @@ impl network::Network for CtbnNetwork { self.network.node_indices() } - fn get_node_weight(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node{ + fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node{ self.network.node_weight(node_idx.clone()).unwrap() } + fn get_param_index(&self, node: &petgraph::stable_graph::NodeIndex, u: Vec) -> usize{ + self.network.neighbors_directed(node.clone(), Direction::Incoming).zip(u).fold((0, 1), |mut acc, x| { + let n = self.get_node(node); + acc.0 += n.state_to_index(&x.1) * acc.1; + acc.1 *= n.get_reserved_space_as_parent(); + acc + + }).0 + } + } diff --git a/src/network.rs b/src/network.rs index 6ac2bcb..b7a03d8 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,5 +1,6 @@ use petgraph::prelude::*; use crate::node; use thiserror::Error; +use crate::params; #[derive(Error, Debug)] pub enum NetworkError { @@ -11,5 +12,6 @@ pub trait Network { fn add_node(&mut self, n: node::Node) -> Result; fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex); fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices; - fn get_node_weight(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node; + fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node; + fn get_param_index(&self, node: &petgraph::stable_graph::NodeIndex, u: Vec) -> usize; } diff --git a/src/node.rs b/src/node.rs index 4d76a72..2707d69 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,5 +1,4 @@ -use std::collections::BTreeSet; -use petgraph::prelude::*; +use std::collections::BTreeSet; use petgraph::prelude::*; use crate::params::*; pub enum NodeType { @@ -12,12 +11,6 @@ pub struct Node { } impl Node { - pub fn add_parent(&mut self, parent: &petgraph::stable_graph::NodeIndex) { - match &mut self.params { - NodeType::DiscreteStatesContinousTime(params) => {params.add_parent(parent);} - } - } - pub fn reset_params(&mut self) { match &mut self.params { NodeType::DiscreteStatesContinousTime(params) => {params.reset_params();} @@ -27,6 +20,18 @@ impl Node { pub fn get_params(&self) -> &NodeType { &self.params } + + pub fn get_reserved_space_as_parent(&self) -> usize { + match &self.params { + NodeType::DiscreteStatesContinousTime(params) => params.get_reserved_space_as_parent() + } + } + + pub fn state_to_index(&self,state: &StateType) -> usize{ + match &self.params { + NodeType::DiscreteStatesContinousTime(params) => params.state_to_index(state) + } + } } diff --git a/src/params.rs b/src/params.rs index 2d27a85..2cd24db 100644 --- a/src/params.rs +++ b/src/params.rs @@ -17,8 +17,11 @@ pub trait Params { fn get_random_state_uniform(&self) -> StateType; fn get_random_residence_time(&self, state: usize, u: usize) -> Result; fn get_random_state(&self, state: usize, u:usize) -> Result; + fn get_reserved_space_as_parent(&self) -> usize; + fn state_to_index(&self, state: &StateType) -> usize; } +#[derive(Clone)] pub enum StateType { Discrete(u32) } @@ -91,4 +94,15 @@ impl Params for DiscreteStatesContinousTimeParams { Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized"))) } } + + + fn get_reserved_space_as_parent(&self) -> usize { + self.domain.len() + } + + fn state_to_index(&self, state: &StateType) -> usize { + match state { + StateType::Discrete(val) => val.clone() as usize + } + } } diff --git a/src/tools.rs b/src/tools.rs index 97499a0..c56bf1c 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -2,6 +2,7 @@ use ndarray::prelude::*; use petgraph::prelude::*; use crate::network; use crate::node; +use crate::params; use crate::params::Params; pub struct Trajectory { @@ -14,7 +15,7 @@ pub struct Dataset { } -fn trajectory_generator(net: &Box, n_trajectories: u64, t_end: f64) -> Dataset { +pub fn trajectory_generator(net: &Box, n_trajectories: u64, t_end: f64) -> Dataset { let mut dataset = Dataset{ trajectories: Vec::new() }; @@ -23,9 +24,9 @@ fn trajectory_generator(net: &Box, n_trajectories: u64, t_ for _ in 0..n_trajectories { let t = 0.0; let mut time: Vec = Vec::new(); - let mut events: Vec> = Vec::new(); - let mut current_state: Vec = node_idx.iter().map(|x| { - match net.get_node_weight(&x).get_params() { + let mut events: Vec> = Vec::new(); + let mut current_state: Vec = node_idx.iter().map(|x| { + match net.get_node(&x).get_params() { node::NodeType::DiscreteStatesContinousTime(params) => params.get_random_state_uniform() }