MLE - Partial commit

pull/19/head
AlessandroBregoli 3 years ago
parent 1aaa252653
commit aa9ff5e05d
  1. 16
      src/ctbn.rs
  2. 1
      src/lib.rs
  3. 5
      src/network.rs
  4. 21
      src/parameter_learning.rs

@ -1,8 +1,8 @@
use std::collections::{HashMap, BTreeSet};
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::node; use crate::node;
use crate::params::{StateType, Params, ParamsTrait}; use crate::params::{StateType, ParamsTrait};
use crate::network; use crate::network;
use std::collections::BTreeSet;
@ -52,7 +52,7 @@ use crate::network;
/// ///
/// //Get all the children of node X1 /// //Get all the children of node X1
/// let cs = net.get_children_set(X1); /// let cs = net.get_children_set(X1);
/// assert_eq!(X2, cs[0]); /// assert_eq!(&X2, cs.iter().next().unwrap());
/// ``` /// ```
pub struct CtbnNetwork { pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>, adj_matrix: Option<Array2<u16>>,
@ -117,7 +117,7 @@ impl network::Network for CtbnNetwork {
}).0 }).0
} }
fn get_parent_set(&self, node: usize) -> Vec<usize> { fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix.as_ref() self.adj_matrix.as_ref()
.unwrap() .unwrap()
.column(node) .column(node)
@ -132,7 +132,7 @@ impl network::Network for CtbnNetwork {
}).collect() }).collect()
} }
fn get_children_set(&self, node: usize) -> Vec<usize>{ fn get_children_set(&self, node: usize) -> BTreeSet<usize>{
self.adj_matrix.as_ref() self.adj_matrix.as_ref()
.unwrap() .unwrap()
.row(node) .row(node)
@ -186,7 +186,7 @@ mod tests {
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
let cs = net.get_children_set(n1); let cs = net.get_children_set(n1);
assert_eq!(n2, cs[0]); assert_eq!(&n2, cs.iter().next().unwrap());
} }
#[test] #[test]
@ -196,9 +196,9 @@ mod tests {
let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
let cs = net.get_children_set(n1); let cs = net.get_children_set(n1);
assert_eq!(n2, cs[0]); assert_eq!(&n2, cs.iter().next().unwrap());
let ps = net.get_parent_set(n2); let ps = net.get_parent_set(n2);
assert_eq!(n1, ps[0]); assert_eq!(&n1, ps.iter().next().unwrap());
} }

@ -7,4 +7,5 @@ pub mod params;
pub mod network; pub mod network;
pub mod ctbn; pub mod ctbn;
pub mod tools; pub mod tools;
pub mod parameter_learning;

@ -1,6 +1,7 @@
use thiserror::Error; use thiserror::Error;
use crate::params; use crate::params;
use crate::node; use crate::node;
use std::collections::BTreeSet;
/// Error types for trait Network /// Error types for trait Network
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -26,6 +27,6 @@ pub trait Network {
///configuration of the network. Usually, the only values really used in *current_state* are ///configuration of the network. Usually, the only values really used in *current_state* are
///the ones in the parent set of the *node*. ///the ones in the parent set of the *node*.
fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize; fn get_param_index_network(&self, node: usize, current_state: &Vec<params::StateType>) -> usize;
fn get_parent_set(&self, node: usize) -> Vec<usize>; fn get_parent_set(&self, node: usize) -> BTreeSet<usize>;
fn get_children_set(&self, node: usize) -> Vec<usize>; fn get_children_set(&self, node: usize) -> BTreeSet<usize>;
} }

@ -0,0 +1,21 @@
use crate::params::*;
use crate::network;
use crate::tools;
use ndarray::prelude::*;
use std::collections::BTreeSet;
pub fn MLE(net: Box<dyn network::Network>,
dataset: &tools::Dataset,
node: usize,
parent_set: Option<BTreeSet<usize>>) {
let parent_set = match parent_set {
Some(p) => p,
None => net.get_parent_set(node)
};
}
Loading…
Cancel
Save