diff --git a/src/ctbn.rs b/src/ctbn.rs index da71f29..8997b29 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -32,13 +32,13 @@ use crate::network; /// let params = params::DiscreteStatesContinousTimeParams::init(domain); /// /// //Create the node using the parameters -/// let X1 = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),String::from("X1")); +/// let X1 = node::Node::init(Box::from(params),String::from("X1")); /// /// let mut domain = BTreeSet::new(); /// domain.insert(String::from("A")); /// domain.insert(String::from("B")); /// let params = params::DiscreteStatesContinousTimeParams::init(domain); -/// let X2 = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),String::from("X2")); +/// let X2 = node::Node::init(Box::from(params), String::from("X2")); /// /// //Initialize a ctbn /// let mut net = CtbnNetwork::init(); @@ -76,7 +76,7 @@ impl network::Network for CtbnNetwork { } fn add_node(&mut self, mut n: node::Node) -> Result { - n.reset_params(); + n.params.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); Ok(self.nodes.len() -1) @@ -89,7 +89,7 @@ impl network::Network for CtbnNetwork { if let Some(network) = &mut self.adj_matrix { network[[parent, child]] = 1; - self.nodes[child].reset_params(); + self.nodes[child].params.reset_params(); } } @@ -105,8 +105,8 @@ impl network::Network for CtbnNetwork { fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize{ self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { if x.1 > &0 { - acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; - acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); + acc.0 += self.nodes[x.0].params.state_to_index(¤t_state[x.0]) * acc.1; + acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent(); } acc }).0 @@ -157,7 +157,7 @@ mod tests { domain.insert(String::from("A")); domain.insert(String::from("B")); let params = params::DiscreteStatesContinousTimeParams::init(domain); - let n = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),name); + let n = node::Node::init(Box::from(params), name); return n; } diff --git a/src/node.rs b/src/node.rs index 2aecb27..61987f1 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,68 +1,19 @@ -use std::collections::BTreeSet; use crate::params::*; -/// Enumerator representing the parameters supported in this library. -pub enum NodeType { - DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams) -} pub struct Node { - pub params: NodeType, + pub params: Box, pub label: String } impl Node { - pub fn init(params: NodeType, label: String) -> Node { + pub fn init(params: Box, label: String) -> Node { Node{ params: params, label:label } } - pub fn reset_params(&mut self) { - match &mut self.params { - NodeType::DiscreteStatesContinousTime(params) => {params.reset_params();} - } - } - - pub fn get_params(&self) -> &NodeType { - &self.params - } - - pub fn get_reserved_space_as_parent(&self) -> usize { - match &self.params { - NodeType::DiscreteStatesContinousTime(params) => params.get_reserved_space_as_parent() - } - } - - pub fn state_to_index(&self,state: &StateType) -> usize{ - match &self.params { - NodeType::DiscreteStatesContinousTime(params) => params.state_to_index(state) - } - } - - - pub fn get_random_residence_time(&self, state: usize, u:usize) -> Result { - match &self.params { - NodeType::DiscreteStatesContinousTime(params) => params.get_random_residence_time(state, u) - } - } - - - pub fn get_random_state_uniform(&self) -> StateType { - match &self.params { - NodeType::DiscreteStatesContinousTime(params) => params.get_random_state_uniform() - } - } - - - pub fn get_random_state(&self, state: usize, u:usize) -> Result{ - match &self.params { - NodeType::DiscreteStatesContinousTime(params) => params.get_random_state(state, u) - } - } - - } impl PartialEq for Node { diff --git a/src/tools.rs b/src/tools.rs index 8a73a7a..921d558 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -25,7 +25,7 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 let mut time: Vec = Vec::new(); let mut events: Vec> = Vec::new(); let mut current_state: Vec = node_idx.iter().map(|x| { - net.get_node(*x).get_random_state_uniform() + net.get_node(*x).params.get_random_state_uniform() }).collect(); let mut next_transitions: Vec> = (0..node_idx.len()).map(|_| Option::None).collect(); events.push(current_state.iter().map(|x| match x { @@ -35,8 +35,8 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 while t < t_end { for (idx, val) in next_transitions.iter_mut().enumerate(){ if let None = val { - *val = Some(net.get_node(idx) - .get_random_residence_time(net.get_node(idx).state_to_index(¤t_state[idx]), + *val = Some(net.get_node(idx).params + .get_random_residence_time(net.get_node(idx).params.state_to_index(¤t_state[idx]), net.get_param_index_network(idx, ¤t_state)).unwrap() + t); } }; @@ -55,9 +55,9 @@ pub fn trajectory_generator(net: &Box, n_trajectories: u64 t = next_transitions[next_node_transition].unwrap().clone(); time.push(t.clone()); - current_state[next_node_transition] = net.get_node(next_node_transition) + current_state[next_node_transition] = net.get_node(next_node_transition).params .get_random_state( - net.get_node(next_node_transition). + net.get_node(next_node_transition).params. state_to_index( ¤t_state[next_node_transition]), net.get_param_index_network(next_node_transition, ¤t_state))