From aa9ff5e05d61659acf9186233465df0a8e090f50 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Thu, 3 Mar 2022 20:53:29 +0100 Subject: [PATCH] MLE - Partial commit --- src/ctbn.rs | 16 ++++++++-------- src/lib.rs | 1 + src/network.rs | 5 +++-- src/parameter_learning.rs | 21 +++++++++++++++++++++ 4 files changed, 33 insertions(+), 10 deletions(-) create mode 100644 src/parameter_learning.rs diff --git a/src/ctbn.rs b/src/ctbn.rs index 9153cb5..6cfb9ba 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -1,8 +1,8 @@ -use std::collections::{HashMap, BTreeSet}; use ndarray::prelude::*; use crate::node; -use crate::params::{StateType, Params, ParamsTrait}; +use crate::params::{StateType, ParamsTrait}; use crate::network; +use std::collections::BTreeSet; @@ -52,7 +52,7 @@ use crate::network; /// /// //Get all the children of node X1 /// let cs = net.get_children_set(X1); -/// assert_eq!(X2, cs[0]); +/// assert_eq!(&X2, cs.iter().next().unwrap()); /// ``` pub struct CtbnNetwork { adj_matrix: Option>, @@ -117,7 +117,7 @@ impl network::Network for CtbnNetwork { }).0 } - fn get_parent_set(&self, node: usize) -> Vec { + fn get_parent_set(&self, node: usize) -> BTreeSet { self.adj_matrix.as_ref() .unwrap() .column(node) @@ -132,7 +132,7 @@ impl network::Network for CtbnNetwork { }).collect() } - fn get_children_set(&self, node: usize) -> Vec{ + fn get_children_set(&self, node: usize) -> BTreeSet{ self.adj_matrix.as_ref() .unwrap() .row(node) @@ -186,7 +186,7 @@ mod tests { let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); net.add_edge(n1, n2); let cs = net.get_children_set(n1); - assert_eq!(n2, cs[0]); + assert_eq!(&n2, cs.iter().next().unwrap()); } #[test] @@ -196,9 +196,9 @@ mod tests { let n2 = net.add_node(define_binary_node(String::from("n2"))).unwrap(); net.add_edge(n1, n2); let cs = net.get_children_set(n1); - assert_eq!(n2, cs[0]); + assert_eq!(&n2, cs.iter().next().unwrap()); let ps = net.get_parent_set(n2); - assert_eq!(n1, ps[0]); + assert_eq!(&n1, ps.iter().next().unwrap()); } diff --git a/src/lib.rs b/src/lib.rs index b2e9365..65e4b11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,4 +7,5 @@ pub mod params; pub mod network; pub mod ctbn; pub mod tools; +pub mod parameter_learning; diff --git a/src/network.rs b/src/network.rs index 4fb7e63..2f36738 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,6 +1,7 @@ use thiserror::Error; use crate::params; use crate::node; +use std::collections::BTreeSet; /// Error types for trait Network #[derive(Error, Debug)] @@ -26,6 +27,6 @@ pub trait Network { ///configuration of the network. Usually, the only values really used in *current_state* are ///the ones in the parent set of the *node*. fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize; - fn get_parent_set(&self, node: usize) -> Vec; - fn get_children_set(&self, node: usize) -> Vec; + fn get_parent_set(&self, node: usize) -> BTreeSet; + fn get_children_set(&self, node: usize) -> BTreeSet; } diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs new file mode 100644 index 0000000..16c0b8f --- /dev/null +++ b/src/parameter_learning.rs @@ -0,0 +1,21 @@ +use crate::params::*; +use crate::network; +use crate::tools; +use ndarray::prelude::*; +use std::collections::BTreeSet; + +pub fn MLE(net: Box, + dataset: &tools::Dataset, + node: usize, + parent_set: Option>) { + + let parent_set = match parent_set { + Some(p) => p, + None => net.get_parent_set(node) + }; + + + + + +}