|
|
|
@ -1,6 +1,5 @@ |
|
|
|
|
use std::collections::{HashMap, BTreeSet}; |
|
|
|
|
use petgraph::prelude::*; |
|
|
|
|
|
|
|
|
|
use ndarray::prelude::*; |
|
|
|
|
use crate::node; |
|
|
|
|
use crate::params::StateType; |
|
|
|
|
use crate::network; |
|
|
|
@ -9,50 +8,80 @@ use crate::network; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub struct CtbnNetwork { |
|
|
|
|
network: petgraph::stable_graph::StableGraph<node::Node, ()>, |
|
|
|
|
adj_matrix: Option<Array2<u16>>, |
|
|
|
|
nodes: Vec<node::Node> |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl network::Network for CtbnNetwork { |
|
|
|
|
fn add_node(&mut self, mut n: node::Node) -> Result<petgraph::graph::NodeIndex, network::NetworkError> { |
|
|
|
|
fn initialize_adj_matrix(&mut self) { |
|
|
|
|
self.adj_matrix = Some(Array2::<u16>::zeros((self.nodes.len(), self.nodes.len()).f())); |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> { |
|
|
|
|
n.reset_params(); |
|
|
|
|
Ok(self.network.add_node(n))
|
|
|
|
|
self.adj_matrix = Option::None; |
|
|
|
|
self.nodes.push(n); |
|
|
|
|
Ok(self.nodes.len() -1)
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn add_edge(&mut self, parent: &petgraph::stable_graph::NodeIndex, child: &petgraph::graph::NodeIndex) { |
|
|
|
|
self.network.add_edge(parent.clone(), child.clone(), {}); |
|
|
|
|
let mut p = self.network.node_weight_mut(child.clone()).unwrap(); |
|
|
|
|
p.reset_params(); |
|
|
|
|
fn add_edge(&mut self, parent: usize, child: usize) { |
|
|
|
|
if let None = self.adj_matrix { |
|
|
|
|
self.initialize_adj_matrix(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn get_node_indices(&self) -> petgraph::stable_graph::NodeIndices<node::Node>{ |
|
|
|
|
self.network.node_indices()
|
|
|
|
|
if let Some(network) = &mut self.adj_matrix { |
|
|
|
|
network[[parent, child]] = 1; |
|
|
|
|
self.nodes[child].reset_params(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn get_node(&self, node_idx: &petgraph::stable_graph::NodeIndex) -> &node::Node{ |
|
|
|
|
self.network.node_weight(node_idx.clone()).unwrap() |
|
|
|
|
fn get_node_indices(&self) -> std::ops::Range<usize>{ |
|
|
|
|
0..self.nodes.len() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn get_param_index_parents(&self, node: &petgraph::stable_graph::NodeIndex, u: &Vec<StateType>) -> usize{ |
|
|
|
|
self.network.neighbors_directed(node.clone(), Direction::Incoming).zip(u).fold((0, 1), |mut acc, x| { |
|
|
|
|
let n = self.get_node(node); |
|
|
|
|
acc.0 += n.state_to_index(&x.1) * acc.1; |
|
|
|
|
acc.1 *= n.get_reserved_space_as_parent(); |
|
|
|
|
acc |
|
|
|
|
fn get_node(&self, node_idx: usize) -> &node::Node{ |
|
|
|
|
&self.nodes[node_idx] |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{ |
|
|
|
|
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { |
|
|
|
|
if x.1 > &0 { |
|
|
|
|
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; |
|
|
|
|
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); |
|
|
|
|
} |
|
|
|
|
acc |
|
|
|
|
}).0 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn get_param_index_network(&self, node: &petgraph::stable_graph::NodeIndex, current_state: &Vec<StateType>) -> usize{ |
|
|
|
|
self.get_param_index_parents(node, ¤t_state.iter() |
|
|
|
|
.zip(self.get_node_indices()) |
|
|
|
|
.filter_map(|x| { |
|
|
|
|
match self.network.find_edge(x.1, node.clone()) { |
|
|
|
|
Some(_) => Some(x.0.clone()), |
|
|
|
|
None => None |
|
|
|
|
fn get_parent_set(&self, node: usize) -> Vec<usize> { |
|
|
|
|
self.adj_matrix.as_ref() |
|
|
|
|
.unwrap() |
|
|
|
|
.column(node) |
|
|
|
|
.iter() |
|
|
|
|
.enumerate() |
|
|
|
|
.filter_map(|(idx, x)| { |
|
|
|
|
if x > &0 { |
|
|
|
|
Some(idx) |
|
|
|
|
} else { |
|
|
|
|
None |
|
|
|
|
} |
|
|
|
|
}).collect() |
|
|
|
|
} |
|
|
|
|
fn get_children_set(&self, node: usize) -> Vec<usize>{ |
|
|
|
|
self.adj_matrix.as_ref() |
|
|
|
|
.unwrap() |
|
|
|
|
.row(node) |
|
|
|
|
.iter() |
|
|
|
|
.enumerate() |
|
|
|
|
.filter_map(|(idx, x)| { |
|
|
|
|
if x > &0 { |
|
|
|
|
Some(idx) |
|
|
|
|
} else { |
|
|
|
|
None |
|
|
|
|
} |
|
|
|
|
}).collect() |
|
|
|
|
) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|