|
|
@ -1,8 +1,8 @@ |
|
|
|
use std::collections::{HashMap, BTreeSet}; |
|
|
|
|
|
|
|
use ndarray::prelude::*; |
|
|
|
use ndarray::prelude::*; |
|
|
|
use crate::node; |
|
|
|
use crate::node; |
|
|
|
use crate::params::{StateType, Params, ParamsTrait}; |
|
|
|
use crate::params::{StateType, ParamsTrait}; |
|
|
|
use crate::network; |
|
|
|
use crate::network; |
|
|
|
|
|
|
|
use std::collections::BTreeSet; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -52,7 +52,7 @@ use crate::network; |
|
|
|
///
|
|
|
|
///
|
|
|
|
/// //Get all the children of node X1
|
|
|
|
/// //Get all the children of node X1
|
|
|
|
/// let cs = net.get_children_set(X1);
|
|
|
|
/// let cs = net.get_children_set(X1);
|
|
|
|
/// assert_eq!(X2, cs[0]);
|
|
|
|
/// assert_eq!(&X2, cs.iter().next().unwrap());
|
|
|
|
/// ```
|
|
|
|
/// ```
|
|
|
|
pub struct CtbnNetwork { |
|
|
|
pub struct CtbnNetwork { |
|
|
|
adj_matrix: Option<Array2<u16>>, |
|
|
|
adj_matrix: Option<Array2<u16>>, |
|
|
@ -117,7 +117,7 @@ impl network::Network for CtbnNetwork { |
|
|
|
}).0 |
|
|
|
}).0 |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
fn get_parent_set(&self, node: usize) -> Vec<usize> { |
|
|
|
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
|
|
|
self.adj_matrix.as_ref() |
|
|
|
self.adj_matrix.as_ref() |
|
|
|
.unwrap() |
|
|
|
.unwrap() |
|
|
|
.column(node) |
|
|
|
.column(node) |
|
|
@ -132,7 +132,7 @@ impl network::Network for CtbnNetwork { |
|
|
|
}).collect() |
|
|
|
}).collect() |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
fn get_children_set(&self, node: usize) -> Vec<usize>{ |
|
|
|
fn get_children_set(&self, node: usize) -> BTreeSet<usize>{ |
|
|
|
self.adj_matrix.as_ref() |
|
|
|
self.adj_matrix.as_ref() |
|
|
|
.unwrap() |
|
|
|
.unwrap() |
|
|
|
.row(node) |
|
|
|
.row(node) |
|
|
@ -186,7 +186,7 @@ mod tests { |
|
|
|
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); |
|
|
|
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); |
|
|
|
net.add_edge(n1, n2); |
|
|
|
net.add_edge(n1, n2); |
|
|
|
let cs = net.get_children_set(n1); |
|
|
|
let cs = net.get_children_set(n1); |
|
|
|
assert_eq!(n2, cs[0]); |
|
|
|
assert_eq!(&n2, cs.iter().next().unwrap()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
#[test] |
|
|
@ -196,9 +196,9 @@ mod tests { |
|
|
|
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); |
|
|
|
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); |
|
|
|
net.add_edge(n1, n2); |
|
|
|
net.add_edge(n1, n2); |
|
|
|
let cs = net.get_children_set(n1); |
|
|
|
let cs = net.get_children_set(n1); |
|
|
|
assert_eq!(n2, cs[0]); |
|
|
|
assert_eq!(&n2, cs.iter().next().unwrap()); |
|
|
|
let ps = net.get_parent_set(n2); |
|
|
|
let ps = net.get_parent_set(n2); |
|
|
|
assert_eq!(n1, ps[0]); |
|
|
|
assert_eq!(&n1, ps.iter().next().unwrap()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|