MLE - Partial commit

pull/19/head
AlessandroBregoli 3 years ago
parent 1aaa252653
commit aa9ff5e05d
  1. 16
      src/ctbn.rs
  2. 1
      src/lib.rs
  3. 5
      src/network.rs
  4. 21
      src/parameter_learning.rs

@ -1,8 +1,8 @@
use std::collections::{HashMap, BTreeSet};
use ndarray::prelude::*;
use crate::node;
use crate::params::{StateType, Params, ParamsTrait};
use crate::params::{StateType, ParamsTrait};
use crate::network;
use std::collections::BTreeSet;
@ -52,7 +52,7 @@ use crate::network;
///
/// //Get all the children of node X1
/// let cs = net.get_children_set(X1);
/// assert_eq!(X2, cs[0]);
/// assert_eq!(&X2, cs.iter().next().unwrap());
/// ```
pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>,
@ -117,7 +117,7 @@ impl network::Network for CtbnNetwork {
}).0
}
fn get_parent_set(&self, node: usize) -> Vec<usize> {
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix.as_ref()
.unwrap()
.column(node)
@ -132,7 +132,7 @@ impl network::Network for CtbnNetwork {
}).collect()
}
fn get_children_set(&self, node: usize) -> Vec<usize>{
fn get_children_set(&self, node: usize) -> BTreeSet<usize>{
self.adj_matrix.as_ref()
.unwrap()
.row(node)
@ -186,7 +186,7 @@ mod tests {
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(n2, cs[0]);
assert_eq!(&n2, cs.iter().next().unwrap());
}
#[test]
@ -196,9 +196,9 @@ mod tests {
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(n2, cs[0]);
assert_eq!(&n2, cs.iter().next().unwrap());
let ps = net.get_parent_set(n2);
assert_eq!(n1, ps[0]);
assert_eq!(&n1, ps.iter().next().unwrap());
}

@ -7,4 +7,5 @@ pub mod params;
pub mod network;
pub mod ctbn;
pub mod tools;
pub mod parameter_learning;

@ -1,6 +1,7 @@
use thiserror::Error;
use crate::params;
use crate::node;
use std::collections::BTreeSet;
/// Error types for trait Network
#[derive(Error, Debug)]
@ -26,6 +27,6 @@ pub trait Network {
///configuration of the network. Usually, the only values really used in *current_state* are
///the ones in the parent set of the *node*.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
fn get_parent_set(&self, node: usize) -> Vec<usize>;
fn get_children_set(&self, node: usize) -> Vec<usize>;
fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
}

@ -0,0 +1,21 @@
use crate::params::*;
use crate::network;
use crate::tools;
use ndarray::prelude::*;
use std::collections::BTreeSet;
pub fn MLE(net: Box<dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>) {
let parent_set = match parent_set {
Some(p) => p,
None => net.get_parent_set(node)
};
}
Loading…
Cancel
Save