diff --git a/src/ctbn.rs b/src/ctbn.rs index 69196f8..2cede4a 100644 --- a/src/ctbn.rs +++ b/src/ctbn.rs @@ -1,6 +1,5 @@ use ndarray::prelude::*; -use crate::node; -use crate::params::{StateType, ParamsTrait}; +use crate::params::{StateType, Params, ParamsTrait}; use crate::network; use std::collections::BTreeSet; @@ -19,7 +18,6 @@ use std::collections::BTreeSet; /// /// use std::collections::BTreeSet; /// use reCTBN::network::Network; -/// use reCTBN::node; /// use reCTBN::params; /// use reCTBN::ctbn::*; /// @@ -29,16 +27,16 @@ use std::collections::BTreeSet; /// domain.insert(String::from("B")); /// /// //Create the parameters for a discrete node using the domain -/// let param = params::DiscreteStatesContinousTimeParams::new(domain); +/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain); /// /// //Create the node using the parameters -/// let X1 = node::Node::new(params::Params::DiscreteStatesContinousTime(param),String::from("X1")); +/// let X1 = params::Params::DiscreteStatesContinousTime(param); /// /// let mut domain = BTreeSet::new(); /// domain.insert(String::from("A")); /// domain.insert(String::from("B")); -/// let param = params::DiscreteStatesContinousTimeParams::new(domain); -/// let X2 = node::Node::new(params::Params::DiscreteStatesContinousTime(param), String::from("X2")); +/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain); +/// let X2 = params::Params::DiscreteStatesContinousTime(param); /// /// //Initialize a ctbn /// let mut net = CtbnNetwork::new(); @@ -56,7 +54,7 @@ use std::collections::BTreeSet; /// ``` pub struct CtbnNetwork { adj_matrix: Option>, - nodes: Vec + nodes: Vec } @@ -75,8 +73,8 @@ impl network::Network for CtbnNetwork { } - fn add_node(&mut self, mut n: node::Node) -> Result { - n.params.reset_params(); + fn add_node(&mut self, mut n: Params) -> Result { + n.reset_params(); self.adj_matrix = Option::None; self.nodes.push(n); Ok(self.nodes.len() -1) @@ -89,7 +87,7 @@ impl network::Network for CtbnNetwork { if let Some(network) = &mut self.adj_matrix { network[[parent, child]] = 1; - self.nodes[child].params.reset_params(); + self.nodes[child].reset_params(); } } @@ -101,12 +99,12 @@ impl network::Network for CtbnNetwork { self.nodes.len() } - fn get_node(&self, node_idx: usize) -> &node::Node{ + fn get_node(&self, node_idx: usize) -> &Params{ &self.nodes[node_idx] } - fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{ + fn get_node_mut(&mut self, node_idx: usize) -> &mut Params{ &mut self.nodes[node_idx] } @@ -114,8 +112,8 @@ impl network::Network for CtbnNetwork { fn get_param_index_network(&self, node: usize, current_state: &Vec) -> usize{ self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { if x.1 > &0 { - acc.0 += self.nodes[x.0].params.state_to_index(¤t_state[x.0]) * acc.1; - acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent(); + acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1; + acc.1 *= self.nodes[x.0].get_reserved_space_as_parent(); } acc }).0 @@ -124,8 +122,8 @@ impl network::Network for CtbnNetwork { fn get_param_index_from_custom_parent_set(&self, current_state: &Vec, parent_set: &BTreeSet) -> usize { parent_set.iter().fold((0, 1), |mut acc, x| { - acc.0 += self.nodes[*x].params.state_to_index(¤t_state[*x]) * acc.1; - acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent(); + acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1; + acc.1 *= self.nodes[*x].get_reserved_space_as_parent(); acc }).0 } diff --git a/src/lib.rs b/src/lib.rs index ec12261..bcbde3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ #[macro_use] extern crate approx; -pub mod node; pub mod params; pub mod network; pub mod ctbn; diff --git a/src/network.rs b/src/network.rs index 3b6ce06..1c962b0 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,6 +1,5 @@ use thiserror::Error; use crate::params; -use crate::node; use std::collections::BTreeSet; /// Error types for trait Network @@ -15,14 +14,14 @@ pub enum NetworkError { ///The Network trait define the required methods for a structure used as pgm (such as ctbn). pub trait Network { fn initialize_adj_matrix(&mut self); - fn add_node(&mut self, n: node::Node) -> Result; + fn add_node(&mut self, n: params::Params) -> Result; fn add_edge(&mut self, parent: usize, child: usize); ///Get all the indices of the nodes contained inside the network fn get_node_indices(&self) -> std::ops::Range; fn get_number_of_nodes(&self) -> usize; - fn get_node(&self, node_idx: usize) -> &node::Node; - fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node; + fn get_node(&self, node_idx: usize) -> ¶ms::Params; + fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params; ///Compute the index that must be used to access the parameters of a node given a specific ///configuration of the network. Usually, the only values really used in *current_state* are diff --git a/src/node.rs b/src/node.rs deleted file mode 100644 index 3d8815f..0000000 --- a/src/node.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::params::*; - - -pub struct Node { - pub params: Params, - pub label: String -} - -impl Node { - pub fn new(params: Params, label: String) -> Node { - Node{ - params: params, - label:label - } - } - -} - -impl PartialEq for Node { - fn eq(&self, other: &Node) -> bool{ - self.label == other.label - } -} - - diff --git a/src/parameter_learning.rs b/src/parameter_learning.rs index c4221cb..5270d9e 100644 --- a/src/parameter_learning.rs +++ b/src/parameter_learning.rs @@ -24,7 +24,6 @@ pub fn sufficient_statistics( //Get the number of values assumable by the node let node_domain = net .get_node(node.clone()) - .params .get_reserved_space_as_parent(); //Get the number of values assumable by each parent of the node @@ -32,7 +31,6 @@ pub fn sufficient_statistics( .iter() .map(|x| { net.get_node(x.clone()) - .params .get_reserved_space_as_parent() }) .collect(); diff --git a/src/params.rs b/src/params.rs index d80fb43..e632b1b 100644 --- a/src/params.rs +++ b/src/params.rs @@ -49,6 +49,9 @@ pub trait ParamsTrait { /// Validate parameters against domain fn validate_params(&self) -> Result<(), ParamsError>; + + /// Return a reference to the associated label + fn get_label(&self) -> &String; } /// The Params enum is the core element for building different types of nodes. The goal is to @@ -70,6 +73,7 @@ pub enum Params { /// - **residence_time**: permanence time in each possible states given a specific /// realization of the parent set pub struct DiscreteStatesContinousTimeParams { + label: String, domain: BTreeSet, cim: Option>, transitions: Option>, @@ -77,8 +81,9 @@ pub struct DiscreteStatesContinousTimeParams { } impl DiscreteStatesContinousTimeParams { - pub fn new(domain: BTreeSet) -> DiscreteStatesContinousTimeParams { + pub fn new(label: String, domain: BTreeSet) -> DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams { + label, domain, cim: Option::None, transitions: Option::None, @@ -244,4 +249,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams { return Ok(()); } + + fn get_label(&self) -> &String { + &self.label + } + } diff --git a/src/structure_learning/score_function.rs b/src/structure_learning/score_function.rs index ad66b08..ea53db5 100644 --- a/src/structure_learning/score_function.rs +++ b/src/structure_learning/score_function.rs @@ -44,7 +44,7 @@ impl LogLikelihood { T: network::Network, { //Identify the type of node used - match &net.get_node(node).params { + match &net.get_node(node){ params::Params::DiscreteStatesContinousTime(_params) => { //Compute the sufficient statistics M (number of transistions) and T (residence //time) diff --git a/src/tools.rs b/src/tools.rs index b981f69..115fd67 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,5 +1,4 @@ use crate::network; -use crate::node; use crate::params; use crate::params::ParamsTrait; use ndarray::prelude::*; @@ -80,7 +79,7 @@ pub fn trajectory_generator( //Configuration of the process variables at time t initialized with an uniform //distribution. let mut current_state: Vec = net.get_node_indices() - .map(|x| net.get_node(x).params.get_random_state_uniform(&mut rng)) + .map(|x| net.get_node(x).get_random_state_uniform(&mut rng)) .collect(); //History of all the configurations of the process variables. let mut events: Vec> = Vec::new(); @@ -106,9 +105,8 @@ pub fn trajectory_generator( if let None = val { *val = Some( net.get_node(idx) - .params .get_random_residence_time( - net.get_node(idx).params.state_to_index(¤t_state[idx]), + net.get_node(idx).state_to_index(¤t_state[idx]), net.get_param_index_network(idx, ¤t_state), &mut rng, ) @@ -137,10 +135,8 @@ pub fn trajectory_generator( //Compute the new state of the transitioning variable. current_state[next_node_transition] = net .get_node(next_node_transition) - .params .get_random_state( net.get_node(next_node_transition) - .params .state_to_index(¤t_state[next_node_transition]), net.get_param_index_network(next_node_transition, ¤t_state), &mut rng, diff --git a/tests/ctbn.rs b/tests/ctbn.rs index 1458637..e5cad1e 100644 --- a/tests/ctbn.rs +++ b/tests/ctbn.rs @@ -1,10 +1,9 @@ mod utils; -use utils::generate_discrete_time_continous_node; +use reCTBN::ctbn::*; use reCTBN::network::Network; -use reCTBN::node; -use reCTBN::params; +use reCTBN::params::{self, ParamsTrait}; use std::collections::BTreeSet; -use reCTBN::ctbn::*; +use utils::generate_discrete_time_continous_node; #[test] fn define_simpe_ctbn() { @@ -15,15 +14,21 @@ fn define_simpe_ctbn() { #[test] fn add_node_to_ctbn() { let mut net = CtbnNetwork::new(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - assert_eq!(String::from("n1"), net.get_node(n1).label); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + assert_eq!(&String::from("n1"), net.get_node(n1).get_label()); } #[test] fn add_edge_to_ctbn() { let mut net = CtbnNetwork::new(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); net.add_edge(n1, n2); let cs = net.get_children_set(n1); assert_eq!(&n2, cs.iter().next().unwrap()); @@ -32,8 +37,12 @@ fn add_edge_to_ctbn() { #[test] fn children_and_parents() { let mut net = CtbnNetwork::new(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); net.add_edge(n1, n2); let cs = net.get_children_set(n1); assert_eq!(&n2, cs.iter().next().unwrap()); @@ -41,59 +50,81 @@ fn children_and_parents() { assert_eq!(&n1, ps.iter().next().unwrap()); } - #[test] fn compute_index_ctbn() { let mut net = CtbnNetwork::new(); - let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); + let n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); net.add_edge(n1, n2); net.add_edge(n3, n2); - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(1), - params::StateType::Discrete(1), - params::StateType::Discrete(1)]); + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(1), + params::StateType::Discrete(1), + ], + ); assert_eq!(3, idx); - - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(1), - params::StateType::Discrete(1)]); + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(1), + params::StateType::Discrete(1), + ], + ); assert_eq!(2, idx); - - let idx = net.get_param_index_network(n2, &vec![ - params::StateType::Discrete(1), - params::StateType::Discrete(1), - params::StateType::Discrete(0)]); + let idx = net.get_param_index_network( + n2, + &vec![ + params::StateType::Discrete(1), + params::StateType::Discrete(1), + params::StateType::Discrete(0), + ], + ); assert_eq!(1, idx); - } - - #[test] fn compute_index_from_custom_parent_set() { let mut net = CtbnNetwork::new(); - let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); - let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); - let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); - - - let idx = net.get_param_index_from_custom_parent_set(&vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(0), - params::StateType::Discrete(1)], - &BTreeSet::from([1])); + let _n1 = net + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) + .unwrap(); + let _n2 = net + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) + .unwrap(); + let _n3 = net + .add_node(generate_discrete_time_continous_node(String::from("n3"), 2)) + .unwrap(); + + let idx = net.get_param_index_from_custom_parent_set( + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(0), + params::StateType::Discrete(1), + ], + &BTreeSet::from([1]), + ); assert_eq!(0, idx); - - let idx = net.get_param_index_from_custom_parent_set(&vec![ - params::StateType::Discrete(0), - params::StateType::Discrete(0), - params::StateType::Discrete(1)], - &BTreeSet::from([1,2])); + let idx = net.get_param_index_from_custom_parent_set( + &vec![ + params::StateType::Discrete(0), + params::StateType::Discrete(0), + params::StateType::Discrete(1), + ], + &BTreeSet::from([1, 2]), + ); assert_eq!(2, idx); } diff --git a/tests/parameter_learning.rs b/tests/parameter_learning.rs index 4e22c14..cd980d0 100644 --- a/tests/parameter_learning.rs +++ b/tests/parameter_learning.rs @@ -1,263 +1,365 @@ mod utils; use utils::*; -use reCTBN::parameter_learning::*; +use ndarray::arr3; use reCTBN::ctbn::*; use reCTBN::network::Network; -use reCTBN::node; -use reCTBN::params; -use reCTBN::tools::*; -use ndarray::arr3; +use reCTBN::parameter_learning::*; +use reCTBN::{params, tools::*}; use std::collections::BTreeSet; - #[macro_use] extern crate approx; - -fn learn_binary_cim (pl: T) { +fn learn_binary_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),2)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 2)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 1.0], [4.0, -4.0]], + [[-6.0, 6.0], [2.0, -2.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [2, 2, 2]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 1.0], [4.0, -4.0]], - [[-6.0, 6.0], [2.0, -2.0]], - ]), 0.1)); + assert!(CIM.abs_diff_eq( + &arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), + 0.1 + )); } #[test] fn learn_binary_cim_MLE() { - let mle = MLE{}; + let mle = MLE {}; learn_binary_cim(mle); } - #[test] fn learn_binary_cim_BA() { - let ba = BayesianApproach{ - alpha: 1, - tau: 1.0}; + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_binary_cim(ba); } -fn learn_ternary_cim (pl: T) { +fn learn_ternary_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let (CIM, M, T) = pl.fit(&net, &data, 1, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [3, 3, 3]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]), 0.1)); + assert!(CIM.abs_diff_eq( + &arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ]), + 0.1 + )); } - #[test] fn learn_ternary_cim_MLE() { - let mle = MLE{}; + let mle = MLE {}; learn_ternary_cim(mle); } - #[test] fn learn_ternary_cim_BA() { - let ba = BayesianApproach{ - alpha: 1, - tau: 1.0}; + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim(ba); } -fn learn_ternary_cim_no_parents (pl: T) { +fn learn_ternary_cim_no_parents(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); let (CIM, M, T) = pl.fit(&net, &data, 0, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [1, 3, 3]); - assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]), 0.1)); + assert!(CIM.abs_diff_eq( + &arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), + 0.1 + )); } - #[test] fn learn_ternary_cim_no_parents_MLE() { - let mle = MLE{}; + let mle = MLE {}; learn_ternary_cim_no_parents(mle); } - #[test] fn learn_ternary_cim_no_parents_BA() { - let ba = BayesianApproach{ - alpha: 1, - tau: 1.0}; + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_ternary_cim_no_parents(ba); } - -fn learn_mixed_discrete_cim (pl: T) { +fn learn_mixed_discrete_cim(pl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); let n3 = net - .add_node(generate_discrete_time_continous_node(String::from("n3"),4)) + .add_node(generate_discrete_time_continous_node(String::from("n3"), 4)) .unwrap(); net.add_edge(n1, n2); net.add_edge(n1, n3); net.add_edge(n2, n3); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - - match &mut net.get_node_mut(n3).params { + match &mut net.get_node_mut(n3) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], - [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], - - [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], - [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], - [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], - - [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], - [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], - [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, -0.8] + ], + [ + [-2.0, 1.0, 0.7, 0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ])) + ); } } - - let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); let (CIM, M, T) = pl.fit(&net, &data, 2, None); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); assert_eq!(CIM.shape(), [9, 4, 4]); - assert!(CIM.abs_diff_eq(&arr3(&[ - [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], - [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], - - [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], - [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], - [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], - - [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], - [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], - [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]), 0.1)); + assert!(CIM.abs_diff_eq( + &arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, -0.8] + ], + [ + [-2.0, 1.0, 0.7, 0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ]), + 0.1 + )); } #[test] fn learn_mixed_discrete_cim_MLE() { - let mle = MLE{}; + let mle = MLE {}; learn_mixed_discrete_cim(mle); } - #[test] fn learn_mixed_discrete_cim_BA() { - let ba = BayesianApproach{ - alpha: 1, - tau: 1.0}; + let ba = BayesianApproach { alpha: 1, tau: 1.0 }; learn_mixed_discrete_cim(ba); } diff --git a/tests/params.rs b/tests/params.rs index fab150b..c002d7b 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,16 +1,15 @@ use ndarray::prelude::*; -use reCTBN::params::*; -use std::collections::BTreeSet; -use rand_chacha::ChaCha8Rng; -use rand_chacha::rand_core::SeedableRng; +use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; +use reCTBN::params::{ParamsTrait, *}; mod utils; #[macro_use] extern crate approx; + fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { - let mut params = utils::generate_discrete_time_continous_param(3); + let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; @@ -18,6 +17,12 @@ fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTime params } +#[test] +fn test_get_label() { + let param = create_ternary_discrete_time_continous_param(); + assert_eq!(&String::from("A"), param.get_label()) +} + #[test] fn test_uniform_generation() { let param = create_ternary_discrete_time_continous_param(); @@ -79,15 +84,19 @@ fn test_validate_params_valid_cim() { #[test] fn test_validate_params_valid_cim_with_huge_values() { - let mut param = utils::generate_discrete_time_continous_param(3); - let cim = array![[[-2e10, 1e10, 1e10], [1.5e10, -3e10, 1.5e10], [1e10, 1e10, -2e10]]]; + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); + let cim = array![[ + [-2e10, 1e10, 1e10], + [1.5e10, -3e10, 1.5e10], + [1e10, 1e10, -2e10] + ]]; let result = param.set_cim(cim); assert_eq!(Ok(()), result); } #[test] fn test_validate_params_cim_not_initialized() { - let param = utils::generate_discrete_time_continous_param(3); + let param = utils::generate_discrete_time_continous_params("A".to_string(), 3); assert_eq!( Err(ParamsError::ParametersNotInitialized(String::from( "CIM not initialized", @@ -98,7 +107,7 @@ fn test_validate_params_cim_not_initialized() { #[test] fn test_validate_params_wrong_shape() { - let mut param = utils::generate_discrete_time_continous_param(4); + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 4); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let result = param.set_cim(cim); assert_eq!( @@ -111,7 +120,7 @@ fn test_validate_params_wrong_shape() { #[test] fn test_validate_params_positive_diag() { - let mut param = utils::generate_discrete_time_continous_param(3); + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let result = param.set_cim(cim); assert_eq!( @@ -124,7 +133,7 @@ fn test_validate_params_positive_diag() { #[test] fn test_validate_params_row_not_sum_to_zero() { - let mut param = utils::generate_discrete_time_continous_param(3); + let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3); let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; let result = param.set_cim(cim); assert_eq!( diff --git a/tests/structure_learning.rs b/tests/structure_learning.rs index 42f948a..c91f508 100644 --- a/tests/structure_learning.rs +++ b/tests/structure_learning.rs @@ -1,17 +1,14 @@ - mod utils; use utils::*; +use ndarray::{arr1, arr2, arr3}; use reCTBN::ctbn::*; use reCTBN::network::Network; -use reCTBN::tools::*; +use reCTBN::params; use reCTBN::structure_learning::score_function::*; -use reCTBN::structure_learning::score_based_algorithm::*; -use reCTBN::structure_learning::StructureLearningAlgorithm; -use ndarray::{arr1, arr2, arr3}; +use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm}; +use reCTBN::tools::*; use std::collections::BTreeSet; -use reCTBN::params; - #[macro_use] extern crate approx; @@ -20,80 +17,86 @@ extern crate approx; fn simple_score_test() { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); - let trj = Trajectory::new( - arr1(&[0.0,0.1,0.3]), - arr2(&[[0],[1],[1]])); + let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]])); let dataset = Dataset::new(vec![trj]); let ll = LogLikelihood::new(1, 1.0); - assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); - + assert_abs_diff_eq!( + 0.04257, + ll.call(&net, n1, &BTreeSet::new(), &dataset), + epsilon = 1e-3 + ); } - #[test] fn simple_bic() { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),2)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .unwrap(); - let trj = Trajectory::new( - arr1(&[0.0,0.1,0.3]), - arr2(&[[0],[1],[1]])); + let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]])); let dataset = Dataset::new(vec![trj]); let bic = BIC::new(1, 1.0); - assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); - + assert_abs_diff_eq!( + -0.65058, + bic.call(&net, n1, &BTreeSet::new(), &dataset), + epsilon = 1e-3 + ); } - - -fn check_compatibility_between_dataset_and_network (sl: T) { +fn check_compatibility_between_dataset_and_network(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); let mut net = CtbnNetwork::new(); let _n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let net = sl.fit_transform(net, &data); } - #[test] #[should_panic] pub fn check_compatibility_between_dataset_and_network_hill_climbing() { @@ -102,42 +105,49 @@ pub fn check_compatibility_between_dataset_and_network_hill_climbing() { check_compatibility_between_dataset_and_network(hl); } -fn learn_ternary_net_2_nodes (sl: T) { +fn learn_ternary_net_2_nodes(sl: T) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); net.add_edge(n1, n2); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259)); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); } - #[test] pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -152,66 +162,117 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() { learn_ternary_net_2_nodes(hl); } - fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { let mut net = CtbnNetwork::new(); let n1 = net - .add_node(generate_discrete_time_continous_node(String::from("n1"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n1"), 3)) .unwrap(); let n2 = net - .add_node(generate_discrete_time_continous_node(String::from("n2"),3)) + .add_node(generate_discrete_time_continous_node(String::from("n2"), 3)) .unwrap(); let n3 = net - .add_node(generate_discrete_time_continous_node(String::from("n3"),4)) + .add_node(generate_discrete_time_continous_node(String::from("n3"), 4)) .unwrap(); net.add_edge(n1, n2); net.add_edge(n1, n3); net.add_edge(n2, n3); - match &mut net.get_node_mut(n1).params { + match &mut net.get_node_mut(n1) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], - [1.5, -2.0, 0.5], - [0.4, 0.6, -1.0]]]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[[ + [-3.0, 2.0, 1.0], + [1.5, -2.0, 0.5], + [0.4, 0.6, -1.0] + ]])) + ); } } - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], - [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], - [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], + [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], + [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], + ])) + ); } } - - match &mut net.get_node_mut(n3).params { + match &mut net.get_node_mut(n3) { params::Params::DiscreteStatesContinousTime(param) => { - assert_eq!(Ok(()), param.set_cim(arr3(&[ - [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], - [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], - [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], - - [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], - [[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], - [[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], - - [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], - [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], - [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], - ]))); + assert_eq!( + Ok(()), + param.set_cim(arr3(&[ + [ + [-1.0, 0.5, 0.3, 0.2], + [0.5, -4.0, 2.5, 1.0], + [2.5, 0.5, -4.0, 1.0], + [0.7, 0.2, 0.1, -1.0] + ], + [ + [-6.0, 2.0, 3.0, 1.0], + [1.5, -3.0, 0.5, 1.0], + [2.0, 1.3, -5.0, 1.7], + [2.5, 0.5, 1.0, -4.0] + ], + [ + [-1.3, 0.3, 0.1, 0.9], + [1.4, -4.0, 0.5, 2.1], + [1.0, 1.5, -3.0, 0.5], + [0.4, 0.3, 0.1, -0.8] + ], + [ + [-2.0, 1.0, 0.7, 0.3], + [1.3, -5.9, 2.7, 1.9], + [2.0, 1.5, -4.0, 0.5], + [0.2, 0.7, 0.1, -1.0] + ], + [ + [-6.0, 1.0, 2.0, 3.0], + [0.5, -3.0, 1.0, 1.5], + [1.4, 2.1, -4.3, 0.8], + [0.5, 1.0, 2.5, -4.0] + ], + [ + [-1.3, 0.9, 0.3, 0.1], + [0.1, -1.3, 0.2, 1.0], + [0.5, 1.0, -3.0, 1.5], + [0.1, 0.4, 0.3, -0.8] + ], + [ + [-2.0, 1.0, 0.6, 0.4], + [2.6, -7.1, 1.4, 3.1], + [5.0, 1.0, -8.0, 2.0], + [1.4, 0.4, 0.2, -2.0] + ], + [ + [-3.0, 1.0, 1.5, 0.5], + [3.0, -6.0, 1.0, 2.0], + [0.3, 0.5, -1.9, 1.1], + [5.0, 1.0, 2.0, -8.0] + ], + [ + [-2.6, 0.6, 0.2, 1.8], + [2.0, -6.0, 3.0, 1.0], + [0.1, 0.5, -1.3, 0.7], + [0.8, 0.6, 0.2, -1.6] + ], + ])) + ); } } - - let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259)); return (net, data); } -fn learn_mixed_discrete_net_3_nodes (sl: T) { +fn learn_mixed_discrete_net_3_nodes(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -219,7 +280,6 @@ fn learn_mixed_discrete_net_3_nodes (sl: T) { assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); } - #[test] pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { let ll = LogLikelihood::new(1, 1.0); @@ -234,9 +294,7 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() { learn_mixed_discrete_net_3_nodes(hl); } - - -fn learn_mixed_discrete_net_3_nodes_1_parent_constraint (sl: T) { +fn learn_mixed_discrete_net_3_nodes_1_parent_constraint(sl: T) { let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let net = sl.fit_transform(net, &data); assert_eq!(BTreeSet::new(), net.get_parent_set(0)); @@ -244,7 +302,6 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint { - param.set_cim(arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); + param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); } } - - match &mut net.get_node_mut(n2).params { + match &mut net.get_node_mut(n2) { params::Params::DiscreteStatesContinousTime(param) => { param.set_cim(arr3(&[ - [[-1.0,1.0],[4.0,-4.0]], - [[-6.0,6.0],[2.0,-2.0]]])); + [[-1.0, 1.0], [4.0, -4.0]], + [[-6.0, 6.0], [2.0, -2.0]], + ])); } } - let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),); + let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259)); assert_eq!(4, data.get_trajectories().len()); - assert_relative_eq!(1.0, data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len()-1]); + assert_relative_eq!( + 1.0, + data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len() - 1] + ); } - #[test] #[should_panic] - fn trajectory_wrong_shape() { +fn trajectory_wrong_shape() { let time = arr1(&[0.0, 0.2]); - let events = arr2(&[[0,3]]); + let events = arr2(&[[0, 3]]); Trajectory::new(time, events); } - #[test] #[should_panic] fn dataset_wrong_shape() { let time = arr1(&[0.0, 0.2]); - let events = arr2(&[[0,3], [1,2]]); + let events = arr2(&[[0, 3], [1, 2]]); let t1 = Trajectory::new(time, events); - let time = arr1(&[0.0, 0.2]); - let events = arr2(&[[0,3,3], [1,2,3]]); + let events = arr2(&[[0, 3, 3], [1, 2, 3]]); let t2 = Trajectory::new(time, events); Dataset::new(vec![t1, t2]); } diff --git a/tests/utils.rs b/tests/utils.rs index e9e5176..8648c46 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -1,16 +1,17 @@ use reCTBN::params; -use reCTBN::node; use std::collections::BTreeSet; -pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node { - node::Node::new(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) +pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params { + params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality)) } -pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ +pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ let domain: BTreeSet = (0..cardinality).map(|x| x.to_string()).collect(); - params::DiscreteStatesContinousTimeParams::new(domain) + params::DiscreteStatesContinousTimeParams::new(label, domain) } + +