Change from enum to Box<dyn Params>

pull/19/head
AlessandroBregoli 3 years ago
parent 30d493b240
commit 212f4aef4b
  1. 14
      src/ctbn.rs
  2. 53
      src/node.rs
  3. 10
      src/tools.rs

@ -32,13 +32,13 @@ use crate::network;
/// let params = params::DiscreteStatesContinousTimeParams::init(domain);
///
/// //Create the node using the parameters
/// let X1 = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),String::from("X1"));
/// let X1 = node::Node::init(Box::from(params),String::from("X1"));
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let params = params::DiscreteStatesContinousTimeParams::init(domain);
/// let X2 = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),String::from("X2"));
/// let X2 = node::Node::init(Box::from(params), String::from("X2"));
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::init();
@ -76,7 +76,7 @@ impl network::Network for CtbnNetwork {
}
fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> {
n.reset_params();
n.params.reset_params();
self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() -1)
@ -89,7 +89,7 @@ impl network::Network for CtbnNetwork {
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].reset_params();
self.nodes[child].params.reset_params();
}
}
@ -105,8 +105,8 @@ impl network::Network for CtbnNetwork {
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
acc.0 += self.nodes[x.0].params.state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent();
}
acc
}).0
@ -157,7 +157,7 @@ mod tests {
domain.insert(String::from("A"));
domain.insert(String::from("B"));
let params = params::DiscreteStatesContinousTimeParams::init(domain);
let n = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),name);
let n = node::Node::init(Box::from(params), name);
return n;
}

@ -1,68 +1,19 @@
use std::collections::BTreeSet;
use crate::params::*;
/// Enumerator representing the parameters supported in this library.
pub enum NodeType {
DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams)
}
pub struct Node {
pub params: NodeType,
pub params: Box<dyn Params>,
pub label: String
}
impl Node {
pub fn init(params: NodeType, label: String) -> Node {
pub fn init(params: Box<dyn Params>, label: String) -> Node {
Node{
params: params,
label:label
}
}
pub fn reset_params(&mut self) {
match &mut self.params {
NodeType::DiscreteStatesContinousTime(params) => {params.reset_params();}
}
}
pub fn get_params(&self) -> &NodeType {
&self.params
}
pub fn get_reserved_space_as_parent(&self) -> usize {
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_reserved_space_as_parent()
}
}
pub fn state_to_index(&self,state: &StateType) -> usize{
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.state_to_index(state)
}
}
pub fn get_random_residence_time(&self, state: usize, u:usize) -> Result<f64, ParamsError> {
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_residence_time(state, u)
}
}
pub fn get_random_state_uniform(&self) -> StateType {
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_state_uniform()
}
}
pub fn get_random_state(&self, state: usize, u:usize) -> Result<StateType, ParamsError>{
match &self.params {
NodeType::DiscreteStatesContinousTime(params) => params.get_random_state(state, u)
}
}
}
impl PartialEq for Node {

@ -25,7 +25,7 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
let mut time: Vec<f64> = Vec::new();
let mut events: Vec<Array1<u32>> = Vec::new();
let mut current_state: Vec<params::StateType> = node_idx.iter().map(|x| {
net.get_node(*x).get_random_state_uniform()
net.get_node(*x).params.get_random_state_uniform()
}).collect();
let mut next_transitions: Vec<Option<f64>> = (0..node_idx.len()).map(|_| Option::None).collect();
events.push(current_state.iter().map(|x| match x {
@ -35,8 +35,8 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
while t < t_end {
for (idx, val) in next_transitions.iter_mut().enumerate(){
if let None = val {
*val = Some(net.get_node(idx)
.get_random_residence_time(net.get_node(idx).state_to_index(&current_state[idx]),
*val = Some(net.get_node(idx).params
.get_random_residence_time(net.get_node(idx).params.state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state)).unwrap() + t);
}
};
@ -55,9 +55,9 @@ pub fn trajectory_generator(net: &Box<dyn network::Network>, n_trajectories: u64
t = next_transitions[next_node_transition].unwrap().clone();
time.push(t.clone());
current_state[next_node_transition] = net.get_node(next_node_transition)
current_state[next_node_transition] = net.get_node(next_node_transition).params
.get_random_state(
net.get_node(next_node_transition).
net.get_node(next_node_transition).params.
state_to_index(
&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state))

Loading…
Cancel
Save