Added an inelegant method for acces a parameter given it's parentset

main
AlessandroBregoli 3 years ago
parent 97a2d16e6d
commit f4fabe1923
  1. 14
      src/ctbn.rs
  2. 4
      src/network.rs
  3. 21
      src/node.rs
  4. 14
      src/params.rs
  5. 9
      src/tools.rs

@ -2,7 +2,7 @@ use std::collections::{HashMap, BTreeSet};
use petgraph::prelude::*;
use crate::node;
use crate::params;
use crate::params::StateType;
use crate::network;
@ -28,8 +28,18 @@ impl network::Network for CtbnNetwork {
self.network.node_indices()
}
fn get_node_weight(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node{
fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node{
self.network.node_weight(node_idx.clone()).unwrap()
}
fn get_param_index(&self, node: &petgraph::stable_graph::NodeIndex, u: Vec<StateType>) -> usize{
self.network.neighbors_directed(node.clone(), Direction::Incoming).zip(u).fold((0, 1), |acc, x| {
let n = self.get_node(node);
acc.0 += n.state_to_index(&x.1) * acc.1;
acc.1 *= n.get_reserved_space_as_parent();
acc
}).0
}
}

@ -1,5 +1,6 @@
use petgraph::prelude::*; use crate::node;
use thiserror::Error;
use crate::params;
#[derive(Error, Debug)]
pub enum NetworkError {
@ -11,5 +12,6 @@ pub trait Network {
fn add_node(&mut self, n: node::Node) -> Result<petgraph::graph::NodeIndex, NetworkError>;
fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex);
fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices<node::Node>;
fn get_node_weight(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node;
fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node;
fn get_param_index(&self, node: &petgraph::stable_graph::NodeIndex, u: Vec<params::StateType>) -> usize;
}

@ -1,5 +1,4 @@
use std::collections::BTreeSet;
use petgraph::prelude::*;
use std::collections::BTreeSet; use petgraph::prelude::*;
use crate::params::*;
pub enum NodeType {
@ -12,12 +11,6 @@ pub struct Node {
}
impl Node {
pub fn add_parent(&mut self, parent: &petgraph::stable_graph::NodeIndex) {
match &mut self.params {
NodeType::DiscreteStatesContinousTime(params) => {params.add_parent(parent);}
}
}
pub fn reset_params(&mut self) {
match &mut self.params {
NodeType::DiscreteStatesContinousTime(params) => {params.reset_params();}
@ -28,6 +21,18 @@ impl Node {
&self.params
}
pub fn get_reserved_space_as_parent(&self) -> usize {
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_reserved_space_as_parent()
}
}
pub fn state_to_index(&self,state: &StateType) -> usize{
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.state_to_index(state)
}
}
}
impl PartialEq for Node {

@ -17,8 +17,11 @@ pub trait Params {
fn get_random_state_uniform(&self) -> StateType;
fn get_random_residence_time(&self, state: usize, u: usize) -> Result<f64, ParamsError>;
fn get_random_state(&self, state: usize, u:usize) -> Result<StateType, ParamsError>;
fn get_reserved_space_as_parent(&self) -> usize;
fn state_to_index(&self, state: &StateType) -> usize;
}
#[derive(Clone)]
pub enum StateType {
Discrete(u32)
}
@ -91,4 +94,15 @@ impl Params for DiscreteStatesContinousTimeParams {
Option::None => Err(ParamsError::ParametersNotInitialized(String::from("CIM not initialized")))
}
}
fn get_reserved_space_as_parent(&self) -> usize {
self.domain.len()
}
fn state_to_index(&self, state: &StateType) -> usize {
match state {
StateType::Discrete(val) => val.clone() as usize
}
}
}

@ -2,6 +2,7 @@ use ndarray::prelude::*;
use petgraph::prelude::*;
use crate::network;
use crate::node;
use crate::params;
use crate::params::Params;
pub struct Trajectory {
@ -14,7 +15,7 @@ pub struct Dataset {
}
fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64, t_end: f64) -> Dataset {
let mut dataset = Dataset{
trajectories: Vec::new()
};
@ -23,9 +24,9 @@ fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64, t_
for _ in 0..n_trajectories {
let t = 0.0;
let mut time: Vec<f64> = Vec::new();
let mut events: Vec<Vec<u32>> = Vec::new();
let mut current_state: Vec<u32> = node_idx.iter().map(|x| {
match net.get_node_weight(&x).get_params() {
let mut events: Vec<Vec<params::StateType>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
match net.get_node(&x).get_params() {
node::NodeType::DiscreteStatesContinousTime(params) =>
params.get_random_state_uniform()
}

Loading…
Cancel
Save