|
|
|
@ -32,13 +32,13 @@ use crate::network; |
|
|
|
|
/// let params = params::DiscreteStatesContinousTimeParams::init(domain);
|
|
|
|
|
///
|
|
|
|
|
/// //Create the node using the parameters
|
|
|
|
|
/// let X1 = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),String::from("X1"));
|
|
|
|
|
/// let X1 = node::Node::init(Box::from(params),String::from("X1"));
|
|
|
|
|
///
|
|
|
|
|
/// let mut domain = BTreeSet::new();
|
|
|
|
|
/// domain.insert(String::from("A"));
|
|
|
|
|
/// domain.insert(String::from("B"));
|
|
|
|
|
/// let params = params::DiscreteStatesContinousTimeParams::init(domain);
|
|
|
|
|
/// let X2 = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),String::from("X2"));
|
|
|
|
|
/// let X2 = node::Node::init(Box::from(params), String::from("X2"));
|
|
|
|
|
///
|
|
|
|
|
/// //Initialize a ctbn
|
|
|
|
|
/// let mut net = CtbnNetwork::init();
|
|
|
|
@ -76,7 +76,7 @@ impl network::Network for CtbnNetwork { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> { |
|
|
|
|
n.reset_params(); |
|
|
|
|
n.params.reset_params(); |
|
|
|
|
self.adj_matrix = Option::None; |
|
|
|
|
self.nodes.push(n); |
|
|
|
|
Ok(self.nodes.len() -1)
|
|
|
|
@ -89,7 +89,7 @@ impl network::Network for CtbnNetwork { |
|
|
|
|
|
|
|
|
|
if let Some(network) = &mut self.adj_matrix { |
|
|
|
|
network[[parent, child]] = 1; |
|
|
|
|
self.nodes[child].reset_params(); |
|
|
|
|
self.nodes[child].params.reset_params(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -105,8 +105,8 @@ impl network::Network for CtbnNetwork { |
|
|
|
|
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{ |
|
|
|
|
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { |
|
|
|
|
if x.1 > &0 { |
|
|
|
|
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
|
|
|
|
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
|
|
|
|
acc.0 += self.nodes[x.0].params.state_to_index(¤t_state[x.0]) * acc.1; |
|
|
|
|
acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent(); |
|
|
|
|
} |
|
|
|
|
acc |
|
|
|
|
}).0 |
|
|
|
@ -157,7 +157,7 @@ mod tests { |
|
|
|
|
domain.insert(String::from("A")); |
|
|
|
|
domain.insert(String::from("B")); |
|
|
|
|
let params = params::DiscreteStatesContinousTimeParams::init(domain); |
|
|
|
|
let n = node::Node::init(node::NodeType::DiscreteStatesContinousTime(params),name); |
|
|
|
|
let n = node::Node::init(Box::from(params), name); |
|
|
|
|
return n; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|