Removed `node.rs`

pull/45/head
Alessandro Bregoli 3 years ago
parent 42c457cf32
commit 9b7e683630
  1. 32
      src/ctbn.rs
  2. 1
      src/lib.rs
  3. 7
      src/network.rs
  4. 25
      src/node.rs
  5. 2
      src/parameter_learning.rs
  6. 12
      src/params.rs
  7. 2
      src/structure_learning/score_function.rs
  8. 8
      src/tools.rs
  9. 105
      tests/ctbn.rs
  10. 282
      tests/parameter_learning.rs
  11. 31
      tests/params.rs
  12. 185
      tests/structure_learning.rs
  13. 42
      tests/tools.rs
  14. 11
      tests/utils.rs

@ -1,6 +1,5 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::node; use crate::params::{StateType, Params, ParamsTrait};
use crate::params::{StateType, ParamsTrait};
use crate::network; use crate::network;
use std::collections::BTreeSet; use std::collections::BTreeSet;
@ -19,7 +18,6 @@ use std::collections::BTreeSet;
/// ///
/// use std::collections::BTreeSet; /// use std::collections::BTreeSet;
/// use reCTBN::network::Network; /// use reCTBN::network::Network;
/// use reCTBN::node;
/// use reCTBN::params; /// use reCTBN::params;
/// use reCTBN::ctbn::*; /// use reCTBN::ctbn::*;
/// ///
@ -29,16 +27,16 @@ use std::collections::BTreeSet;
/// domain.insert(String::from("B")); /// domain.insert(String::from("B"));
/// ///
/// //Create the parameters for a discrete node using the domain /// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new(domain); /// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
/// ///
/// //Create the node using the parameters /// //Create the node using the parameters
/// let X1 = node::Node::new(params::Params::DiscreteStatesContinousTime(param),String::from("X1")); /// let X1 = params::Params::DiscreteStatesContinousTime(param);
/// ///
/// let mut domain = BTreeSet::new(); /// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A")); /// domain.insert(String::from("A"));
/// domain.insert(String::from("B")); /// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new(domain); /// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = node::Node::new(params::Params::DiscreteStatesContinousTime(param), String::from("X2")); /// let X2 = params::Params::DiscreteStatesContinousTime(param);
/// ///
/// //Initialize a ctbn /// //Initialize a ctbn
/// let mut net = CtbnNetwork::new(); /// let mut net = CtbnNetwork::new();
@ -56,7 +54,7 @@ use std::collections::BTreeSet;
/// ``` /// ```
pub struct CtbnNetwork { pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>, adj_matrix: Option<Array2<u16>>,
nodes: Vec<node::Node> nodes: Vec<Params>
} }
@ -75,8 +73,8 @@ impl network::Network for CtbnNetwork {
} }
fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> { fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
n.params.reset_params(); n.reset_params();
self.adj_matrix = Option::None; self.adj_matrix = Option::None;
self.nodes.push(n); self.nodes.push(n);
Ok(self.nodes.len() -1) Ok(self.nodes.len() -1)
@ -89,7 +87,7 @@ impl network::Network for CtbnNetwork {
if let Some(network) = &mut self.adj_matrix { if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1; network[[parent, child]] = 1;
self.nodes[child].params.reset_params(); self.nodes[child].reset_params();
} }
} }
@ -101,12 +99,12 @@ impl network::Network for CtbnNetwork {
self.nodes.len() self.nodes.len()
} }
fn get_node(&self, node_idx: usize) -> &node::Node{ fn get_node(&self, node_idx: usize) -> &Params{
&self.nodes[node_idx] &self.nodes[node_idx]
} }
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{ fn get_node_mut(&mut self, node_idx: usize) -> &mut Params{
&mut self.nodes[node_idx] &mut self.nodes[node_idx]
} }
@ -114,8 +112,8 @@ impl network::Network for CtbnNetwork {
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{ fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| { self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
if x.1 > &0 { if x.1 > &0 {
acc.0 += self.nodes[x.0].params.state_to_index(&current_state[x.0]) * acc.1; acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent(); acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
} }
acc acc
}).0 }).0
@ -124,8 +122,8 @@ impl network::Network for CtbnNetwork {
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize { fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize {
parent_set.iter().fold((0, 1), |mut acc, x| { parent_set.iter().fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].params.state_to_index(&current_state[*x]) * acc.1; acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent(); acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc acc
}).0 }).0
} }

@ -2,7 +2,6 @@
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
pub mod node;
pub mod params; pub mod params;
pub mod network; pub mod network;
pub mod ctbn; pub mod ctbn;

@ -1,6 +1,5 @@
use thiserror::Error; use thiserror::Error;
use crate::params; use crate::params;
use crate::node;
use std::collections::BTreeSet; use std::collections::BTreeSet;
/// Error types for trait Network /// Error types for trait Network
@ -15,14 +14,14 @@ pub enum NetworkError {
///The Network trait define the required methods for a structure used as pgm (such as ctbn). ///The Network trait define the required methods for a structure used as pgm (such as ctbn).
pub trait Network { pub trait Network {
fn initialize_adj_matrix(&mut self); fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: node::Node) -> Result<usize, NetworkError>; fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
fn add_edge(&mut self, parent: usize, child: usize); fn add_edge(&mut self, parent: usize, child: usize);
///Get all the indices of the nodes contained inside the network ///Get all the indices of the nodes contained inside the network
fn get_node_indices(&self) -> std::ops::Range<usize>; fn get_node_indices(&self) -> std::ops::Range<usize>;
fn get_number_of_nodes(&self) -> usize; fn get_number_of_nodes(&self) -> usize;
fn get_node(&self, node_idx: usize) -> &node::Node; fn get_node(&self, node_idx: usize) -> &params::Params;
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node; fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
///Compute the index that must be used to access the parameters of a node given a specific ///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are ///configuration of the network. Usually, the only values really used in *current_state* are

@ -1,25 +0,0 @@
use crate::params::*;
pub struct Node {
pub params: Params,
pub label: String
}
impl Node {
pub fn new(params: Params, label: String) -> Node {
Node{
params: params,
label:label
}
}
}
impl PartialEq for Node {
fn eq(&self, other: &Node) -> bool{
self.label == other.label
}
}

@ -24,7 +24,6 @@ pub fn sufficient_statistics<T:network::Network>(
//Get the number of values assumable by the node //Get the number of values assumable by the node
let node_domain = net let node_domain = net
.get_node(node.clone()) .get_node(node.clone())
.params
.get_reserved_space_as_parent(); .get_reserved_space_as_parent();
//Get the number of values assumable by each parent of the node //Get the number of values assumable by each parent of the node
@ -32,7 +31,6 @@ pub fn sufficient_statistics<T:network::Network>(
.iter() .iter()
.map(|x| { .map(|x| {
net.get_node(x.clone()) net.get_node(x.clone())
.params
.get_reserved_space_as_parent() .get_reserved_space_as_parent()
}) })
.collect(); .collect();

@ -49,6 +49,9 @@ pub trait ParamsTrait {
/// Validate parameters against domain /// Validate parameters against domain
fn validate_params(&self) -> Result<(), ParamsError>; fn validate_params(&self) -> Result<(), ParamsError>;
/// Return a reference to the associated label
fn get_label(&self) -> &String;
} }
/// The Params enum is the core element for building different types of nodes. The goal is to /// The Params enum is the core element for building different types of nodes. The goal is to
@ -70,6 +73,7 @@ pub enum Params {
/// - **residence_time**: permanence time in each possible states given a specific /// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set /// realization of the parent set
pub struct DiscreteStatesContinousTimeParams { pub struct DiscreteStatesContinousTimeParams {
label: String,
domain: BTreeSet<String>, domain: BTreeSet<String>,
cim: Option<Array3<f64>>, cim: Option<Array3<f64>>,
transitions: Option<Array3<u64>>, transitions: Option<Array3<u64>>,
@ -77,8 +81,9 @@ pub struct DiscreteStatesContinousTimeParams {
} }
impl DiscreteStatesContinousTimeParams { impl DiscreteStatesContinousTimeParams {
pub fn new(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams { pub fn new(label: String, domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams {
label,
domain, domain,
cim: Option::None, cim: Option::None,
transitions: Option::None, transitions: Option::None,
@ -244,4 +249,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
return Ok(()); return Ok(());
} }
fn get_label(&self) -> &String {
&self.label
}
} }

@ -44,7 +44,7 @@ impl LogLikelihood {
T: network::Network, T: network::Network,
{ {
//Identify the type of node used //Identify the type of node used
match &net.get_node(node).params { match &net.get_node(node){
params::Params::DiscreteStatesContinousTime(_params) => { params::Params::DiscreteStatesContinousTime(_params) => {
//Compute the sufficient statistics M (number of transistions) and T (residence //Compute the sufficient statistics M (number of transistions) and T (residence
//time) //time)

@ -1,5 +1,4 @@
use crate::network; use crate::network;
use crate::node;
use crate::params; use crate::params;
use crate::params::ParamsTrait; use crate::params::ParamsTrait;
use ndarray::prelude::*; use ndarray::prelude::*;
@ -80,7 +79,7 @@ pub fn trajectory_generator<T: network::Network>(
//Configuration of the process variables at time t initialized with an uniform //Configuration of the process variables at time t initialized with an uniform
//distribution. //distribution.
let mut current_state: Vec<params::StateType> = net.get_node_indices() let mut current_state: Vec<params::StateType> = net.get_node_indices()
.map(|x| net.get_node(x).params.get_random_state_uniform(&mut rng)) .map(|x| net.get_node(x).get_random_state_uniform(&mut rng))
.collect(); .collect();
//History of all the configurations of the process variables. //History of all the configurations of the process variables.
let mut events: Vec<Array1<usize>> = Vec::new(); let mut events: Vec<Array1<usize>> = Vec::new();
@ -106,9 +105,8 @@ pub fn trajectory_generator<T: network::Network>(
if let None = val { if let None = val {
*val = Some( *val = Some(
net.get_node(idx) net.get_node(idx)
.params
.get_random_residence_time( .get_random_residence_time(
net.get_node(idx).params.state_to_index(&current_state[idx]), net.get_node(idx).state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state), net.get_param_index_network(idx, &current_state),
&mut rng, &mut rng,
) )
@ -137,10 +135,8 @@ pub fn trajectory_generator<T: network::Network>(
//Compute the new state of the transitioning variable. //Compute the new state of the transitioning variable.
current_state[next_node_transition] = net current_state[next_node_transition] = net
.get_node(next_node_transition) .get_node(next_node_transition)
.params
.get_random_state( .get_random_state(
net.get_node(next_node_transition) net.get_node(next_node_transition)
.params
.state_to_index(&current_state[next_node_transition]), .state_to_index(&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state), net.get_param_index_network(next_node_transition, &current_state),
&mut rng, &mut rng,

@ -1,10 +1,9 @@
mod utils; mod utils;
use utils::generate_discrete_time_continous_node; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::node; use reCTBN::params::{self, ParamsTrait};
use reCTBN::params;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use reCTBN::ctbn::*; use utils::generate_discrete_time_continous_node;
#[test] #[test]
fn define_simpe_ctbn() { fn define_simpe_ctbn() {
@ -15,15 +14,21 @@ fn define_simpe_ctbn() {
#[test] #[test]
fn add_node_to_ctbn() { fn add_node_to_ctbn() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let n1 = net
assert_eq!(String::from("n1"), net.get_node(n1).label); .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
} }
#[test] #[test]
fn add_edge_to_ctbn() { fn add_edge_to_ctbn() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let n1 = net
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
let cs = net.get_children_set(n1); let cs = net.get_children_set(n1);
assert_eq!(&n2, cs.iter().next().unwrap()); assert_eq!(&n2, cs.iter().next().unwrap());
@ -32,8 +37,12 @@ fn add_edge_to_ctbn() {
#[test] #[test]
fn children_and_parents() { fn children_and_parents() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let n1 = net
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
let cs = net.get_children_set(n1); let cs = net.get_children_set(n1);
assert_eq!(&n2, cs.iter().next().unwrap()); assert_eq!(&n2, cs.iter().next().unwrap());
@ -41,59 +50,81 @@ fn children_and_parents() {
assert_eq!(&n1, ps.iter().next().unwrap()); assert_eq!(&n1, ps.iter().next().unwrap());
} }
#[test] #[test]
fn compute_index_ctbn() { fn compute_index_ctbn() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let n1 = net
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); .unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
net.add_edge(n3, n2); net.add_edge(n3, n2);
let idx = net.get_param_index_network(n2, &vec![ let idx = net.get_param_index_network(
n2,
&vec![
params::StateType::Discrete(1), params::StateType::Discrete(1),
params::StateType::Discrete(1), params::StateType::Discrete(1),
params::StateType::Discrete(1)]); params::StateType::Discrete(1),
],
);
assert_eq!(3, idx); assert_eq!(3, idx);
let idx = net.get_param_index_network(
let idx = net.get_param_index_network(n2, &vec![ n2,
&vec![
params::StateType::Discrete(0), params::StateType::Discrete(0),
params::StateType::Discrete(1), params::StateType::Discrete(1),
params::StateType::Discrete(1)]); params::StateType::Discrete(1),
],
);
assert_eq!(2, idx); assert_eq!(2, idx);
let idx = net.get_param_index_network(
let idx = net.get_param_index_network(n2, &vec![ n2,
&vec![
params::StateType::Discrete(1), params::StateType::Discrete(1),
params::StateType::Discrete(1), params::StateType::Discrete(1),
params::StateType::Discrete(0)]); params::StateType::Discrete(0),
],
);
assert_eq!(1, idx); assert_eq!(1, idx);
} }
#[test] #[test]
fn compute_index_from_custom_parent_set() { fn compute_index_from_custom_parent_set() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let _n1 = net
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); .unwrap();
let _n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
let idx = net.get_param_index_from_custom_parent_set(&vec![ .unwrap();
let _n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
let idx = net.get_param_index_from_custom_parent_set(
&vec![
params::StateType::Discrete(0), params::StateType::Discrete(0),
params::StateType::Discrete(0), params::StateType::Discrete(0),
params::StateType::Discrete(1)], params::StateType::Discrete(1),
&BTreeSet::from([1])); ],
&BTreeSet::from([1]),
);
assert_eq!(0, idx); assert_eq!(0, idx);
let idx = net.get_param_index_from_custom_parent_set(
let idx = net.get_param_index_from_custom_parent_set(&vec![ &vec![
params::StateType::Discrete(0), params::StateType::Discrete(0),
params::StateType::Discrete(0), params::StateType::Discrete(0),
params::StateType::Discrete(1)], params::StateType::Discrete(1),
&BTreeSet::from([1,2])); ],
&BTreeSet::from([1, 2]),
);
assert_eq!(2, idx); assert_eq!(2, idx);
} }

@ -1,20 +1,16 @@
mod utils; mod utils;
use utils::*; use utils::*;
use reCTBN::parameter_learning::*; use ndarray::arr3;
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::node; use reCTBN::parameter_learning::*;
use reCTBN::params; use reCTBN::{params, tools::*};
use reCTBN::tools::*;
use ndarray::arr3;
use std::collections::BTreeSet; use std::collections::BTreeSet;
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
fn learn_binary_cim<T: ParameterLearning>(pl: T) { fn learn_binary_cim<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
@ -25,29 +21,32 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
.unwrap(); .unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]))); assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]], [[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]], [[-6.0, 6.0], [2.0, -2.0]],
]))); ]))
);
} }
} }
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),); let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 1, None); let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]); assert_eq!(CIM.shape(), [2, 2, 2]);
assert!(CIM.abs_diff_eq(&arr3(&[ assert!(CIM.abs_diff_eq(
[[-1.0, 1.0], [4.0, -4.0]], &arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]),
[[-6.0, 6.0], [2.0, -2.0]], 0.1
]), 0.1)); ));
} }
#[test] #[test]
@ -56,12 +55,9 @@ fn learn_binary_cim_MLE() {
learn_binary_cim(mle); learn_binary_cim(mle);
} }
#[test] #[test]
fn learn_binary_cim_BA() { fn learn_binary_cim_BA() {
let ba = BayesianApproach{ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
alpha: 1,
tau: 1.0};
learn_binary_cim(ba); learn_binary_cim(ba);
} }
@ -75,48 +71,55 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
.unwrap(); .unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]))); [0.4, 0.6, -1.0]
]]))
);
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]))); ]))
);
} }
} }
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 1, None); let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]); assert_eq!(CIM.shape(), [3, 3, 3]);
assert!(CIM.abs_diff_eq(&arr3(&[ assert!(CIM.abs_diff_eq(
&arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]), 0.1)); ]),
0.1
));
} }
#[test] #[test]
fn learn_ternary_cim_MLE() { fn learn_ternary_cim_MLE() {
let mle = MLE {}; let mle = MLE {};
learn_ternary_cim(mle); learn_ternary_cim(mle);
} }
#[test] #[test]
fn learn_ternary_cim_BA() { fn learn_ternary_cim_BA() {
let ba = BayesianApproach{ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
alpha: 1,
tau: 1.0};
learn_ternary_cim(ba); learn_ternary_cim(ba);
} }
@ -130,50 +133,54 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
.unwrap(); .unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]))); [0.4, 0.6, -1.0]
]]))
);
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]))); ]))
);
} }
} }
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),); let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 0, None); let (CIM, M, T) = pl.fit(&net, &data, 0, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [1, 3, 3]); assert_eq!(CIM.shape(), [1, 3, 3]);
assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0], assert!(CIM.abs_diff_eq(
[1.5, -2.0, 0.5], &arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]),
[0.4, 0.6, -1.0]]]), 0.1)); 0.1
));
} }
#[test] #[test]
fn learn_ternary_cim_no_parents_MLE() { fn learn_ternary_cim_no_parents_MLE() {
let mle = MLE {}; let mle = MLE {};
learn_ternary_cim_no_parents(mle); learn_ternary_cim_no_parents(mle);
} }
#[test] #[test]
fn learn_ternary_cim_no_parents_BA() { fn learn_ternary_cim_no_parents_BA() {
let ba = BayesianApproach{ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
alpha: 1,
tau: 1.0};
learn_ternary_cim_no_parents(ba); learn_ternary_cim_no_parents(ba);
} }
fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) { fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
@ -190,61 +197,159 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
net.add_edge(n1, n3); net.add_edge(n1, n3);
net.add_edge(n2, n3); net.add_edge(n2, n3);
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]))); [0.4, 0.6, -1.0]
]]))
);
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]))); ]))
);
} }
} }
match &mut net.get_node_mut(n3) {
match &mut net.get_node_mut(n3).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], Ok(()),
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], param.set_cim(arr3(&[
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], [
[-1.0, 0.5, 0.3, 0.2],
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], [0.5, -4.0, 2.5, 1.0],
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], [2.5, 0.5, -4.0, 1.0],
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], [0.7, 0.2, 0.1, -1.0]
],
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [-6.0, 2.0, 3.0, 1.0],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], [1.5, -3.0, 0.5, 1.0],
]))); [2.0, 1.3, -5.0, 1.7],
[2.5, 0.5, 1.0, -4.0]
],
[
[-1.3, 0.3, 0.1, 0.9],
[1.4, -4.0, 0.5, 2.1],
[1.0, 1.5, -3.0, 0.5],
[0.4, 0.3, 0.1, -0.8]
],
[
[-2.0, 1.0, 0.7, 0.3],
[1.3, -5.9, 2.7, 1.9],
[2.0, 1.5, -4.0, 0.5],
[0.2, 0.7, 0.1, -1.0]
],
[
[-6.0, 1.0, 2.0, 3.0],
[0.5, -3.0, 1.0, 1.5],
[1.4, 2.1, -4.3, 0.8],
[0.5, 1.0, 2.5, -4.0]
],
[
[-1.3, 0.9, 0.3, 0.1],
[0.1, -1.3, 0.2, 1.0],
[0.5, 1.0, -3.0, 1.5],
[0.1, 0.4, 0.3, -0.8]
],
[
[-2.0, 1.0, 0.6, 0.4],
[2.6, -7.1, 1.4, 3.1],
[5.0, 1.0, -8.0, 2.0],
[1.4, 0.4, 0.2, -2.0]
],
[
[-3.0, 1.0, 1.5, 0.5],
[3.0, -6.0, 1.0, 2.0],
[0.3, 0.5, -1.9, 1.1],
[5.0, 1.0, 2.0, -8.0]
],
[
[-2.6, 0.6, 0.2, 1.8],
[2.0, -6.0, 3.0, 1.0],
[0.1, 0.5, -1.3, 0.7],
[0.8, 0.6, 0.2, -1.6]
],
]))
);
} }
} }
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),);
let (CIM, M, T) = pl.fit(&net, &data, 2, None); let (CIM, M, T) = pl.fit(&net, &data, 2, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [9, 4, 4]); assert_eq!(CIM.shape(), [9, 4, 4]);
assert!(CIM.abs_diff_eq(&arr3(&[ assert!(CIM.abs_diff_eq(
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], &arr3(&[
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], [
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], [-1.0, 0.5, 0.3, 0.2],
[0.5, -4.0, 2.5, 1.0],
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], [2.5, 0.5, -4.0, 1.0],
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], [0.7, 0.2, 0.1, -1.0]
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], ],
[
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [-6.0, 2.0, 3.0, 1.0],
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [1.5, -3.0, 0.5, 1.0],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], [2.0, 1.3, -5.0, 1.7],
]), 0.1)); [2.5, 0.5, 1.0, -4.0]
],
[
[-1.3, 0.3, 0.1, 0.9],
[1.4, -4.0, 0.5, 2.1],
[1.0, 1.5, -3.0, 0.5],
[0.4, 0.3, 0.1, -0.8]
],
[
[-2.0, 1.0, 0.7, 0.3],
[1.3, -5.9, 2.7, 1.9],
[2.0, 1.5, -4.0, 0.5],
[0.2, 0.7, 0.1, -1.0]
],
[
[-6.0, 1.0, 2.0, 3.0],
[0.5, -3.0, 1.0, 1.5],
[1.4, 2.1, -4.3, 0.8],
[0.5, 1.0, 2.5, -4.0]
],
[
[-1.3, 0.9, 0.3, 0.1],
[0.1, -1.3, 0.2, 1.0],
[0.5, 1.0, -3.0, 1.5],
[0.1, 0.4, 0.3, -0.8]
],
[
[-2.0, 1.0, 0.6, 0.4],
[2.6, -7.1, 1.4, 3.1],
[5.0, 1.0, -8.0, 2.0],
[1.4, 0.4, 0.2, -2.0]
],
[
[-3.0, 1.0, 1.5, 0.5],
[3.0, -6.0, 1.0, 2.0],
[0.3, 0.5, -1.9, 1.1],
[5.0, 1.0, 2.0, -8.0]
],
[
[-2.6, 0.6, 0.2, 1.8],
[2.0, -6.0, 3.0, 1.0],
[0.1, 0.5, -1.3, 0.7],
[0.8, 0.6, 0.2, -1.6]
],
]),
0.1
));
} }
#[test] #[test]
@ -253,11 +358,8 @@ fn learn_mixed_discrete_cim_MLE() {
learn_mixed_discrete_cim(mle); learn_mixed_discrete_cim(mle);
} }
#[test] #[test]
fn learn_mixed_discrete_cim_BA() { fn learn_mixed_discrete_cim_BA() {
let ba = BayesianApproach{ let ba = BayesianApproach { alpha: 1, tau: 1.0 };
alpha: 1,
tau: 1.0};
learn_mixed_discrete_cim(ba); learn_mixed_discrete_cim(ba);
} }

@ -1,16 +1,15 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use reCTBN::params::*; use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
use std::collections::BTreeSet; use reCTBN::params::{ParamsTrait, *};
use rand_chacha::ChaCha8Rng;
use rand_chacha::rand_core::SeedableRng;
mod utils; mod utils;
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut params = utils::generate_discrete_time_continous_param(3); let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
@ -18,6 +17,12 @@ fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTime
params params
} }
#[test]
fn test_get_label() {
let param = create_ternary_discrete_time_continous_param();
assert_eq!(&String::from("A"), param.get_label())
}
#[test] #[test]
fn test_uniform_generation() { fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param(); let param = create_ternary_discrete_time_continous_param();
@ -79,15 +84,19 @@ fn test_validate_params_valid_cim() {
#[test] #[test]
fn test_validate_params_valid_cim_with_huge_values() { fn test_validate_params_valid_cim_with_huge_values() {
let mut param = utils::generate_discrete_time_continous_param(3); let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
let cim = array![[[-2e10, 1e10, 1e10], [1.5e10, -3e10, 1.5e10], [1e10, 1e10, -2e10]]]; let cim = array![[
[-2e10, 1e10, 1e10],
[1.5e10, -3e10, 1.5e10],
[1e10, 1e10, -2e10]
]];
let result = param.set_cim(cim); let result = param.set_cim(cim);
assert_eq!(Ok(()), result); assert_eq!(Ok(()), result);
} }
#[test] #[test]
fn test_validate_params_cim_not_initialized() { fn test_validate_params_cim_not_initialized() {
let param = utils::generate_discrete_time_continous_param(3); let param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
assert_eq!( assert_eq!(
Err(ParamsError::ParametersNotInitialized(String::from( Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized", "CIM not initialized",
@ -98,7 +107,7 @@ fn test_validate_params_cim_not_initialized() {
#[test] #[test]
fn test_validate_params_wrong_shape() { fn test_validate_params_wrong_shape() {
let mut param = utils::generate_discrete_time_continous_param(4); let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 4);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
let result = param.set_cim(cim); let result = param.set_cim(cim);
assert_eq!( assert_eq!(
@ -111,7 +120,7 @@ fn test_validate_params_wrong_shape() {
#[test] #[test]
fn test_validate_params_positive_diag() { fn test_validate_params_positive_diag() {
let mut param = utils::generate_discrete_time_continous_param(3); let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
let result = param.set_cim(cim); let result = param.set_cim(cim);
assert_eq!( assert_eq!(
@ -124,7 +133,7 @@ fn test_validate_params_positive_diag() {
#[test] #[test]
fn test_validate_params_row_not_sum_to_zero() { fn test_validate_params_row_not_sum_to_zero() {
let mut param = utils::generate_discrete_time_continous_param(3); let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
let result = param.set_cim(cim); let result = param.set_cim(cim);
assert_eq!( assert_eq!(

@ -1,17 +1,14 @@
mod utils; mod utils;
use utils::*; use utils::*;
use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::network::Network; use reCTBN::network::Network;
use reCTBN::tools::*; use reCTBN::params;
use reCTBN::structure_learning::score_function::*; use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::score_based_algorithm::*; use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm};
use reCTBN::structure_learning::StructureLearningAlgorithm; use reCTBN::tools::*;
use ndarray::{arr1, arr2, arr3};
use std::collections::BTreeSet; use std::collections::BTreeSet;
use reCTBN::params;
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
@ -23,19 +20,19 @@ fn simple_score_test() {
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap(); .unwrap();
let trj = Trajectory::new( let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]]));
arr1(&[0.0,0.1,0.3]),
arr2(&[[0],[1],[1]]));
let dataset = Dataset::new(vec![trj]); let dataset = Dataset::new(vec![trj]);
let ll = LogLikelihood::new(1, 1.0); let ll = LogLikelihood::new(1, 1.0);
assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); assert_abs_diff_eq!(
0.04257,
ll.call(&net, n1, &BTreeSet::new(), &dataset),
epsilon = 1e-3
);
} }
#[test] #[test]
fn simple_bic() { fn simple_bic() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
@ -43,19 +40,18 @@ fn simple_bic() {
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2)) .add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap(); .unwrap();
let trj = Trajectory::new( let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]]));
arr1(&[0.0,0.1,0.3]),
arr2(&[[0],[1],[1]]));
let dataset = Dataset::new(vec![trj]); let dataset = Dataset::new(vec![trj]);
let bic = BIC::new(1, 1.0); let bic = BIC::new(1, 1.0);
assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3); assert_abs_diff_eq!(
-0.65058,
bic.call(&net, n1, &BTreeSet::new(), &dataset),
epsilon = 1e-3
);
} }
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) { fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
@ -66,25 +62,33 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
.unwrap(); .unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]))); [0.4, 0.6, -1.0]
]]))
);
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]))); ]))
);
} }
} }
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let _n1 = net let _n1 = net
@ -93,7 +97,6 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
let net = sl.fit_transform(net, &data); let net = sl.fit_transform(net, &data);
} }
#[test] #[test]
#[should_panic] #[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing() { pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
@ -112,32 +115,39 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
.unwrap(); .unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]))); [0.4, 0.6, -1.0]
]]))
);
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]))); ]))
);
} }
} }
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),); let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
let net = sl.fit_transform(net, &data); let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2)); assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
assert_eq!(BTreeSet::new(), net.get_parent_set(n1)); assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
} }
#[test] #[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() { pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
let ll = LogLikelihood::new(1, 1.0); let ll = LogLikelihood::new(1, 1.0);
@ -152,7 +162,6 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
learn_ternary_net_2_nodes(hl); learn_ternary_net_2_nodes(hl);
} }
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) { fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net let n1 = net
@ -169,45 +178,97 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
net.add_edge(n1, n3); net.add_edge(n1, n3);
net.add_edge(n2, n3); net.add_edge(n2, n3);
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]))); [0.4, 0.6, -1.0]
]]))
);
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]))); ]))
);
} }
} }
match &mut net.get_node_mut(n3) {
match &mut net.get_node_mut(n3).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[ assert_eq!(
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], Ok(()),
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]], param.set_cim(arr3(&[
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], [
[-1.0, 0.5, 0.3, 0.2],
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], [0.5, -4.0, 2.5, 1.0],
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]], [2.5, 0.5, -4.0, 1.0],
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]], [0.7, 0.2, 0.1, -1.0]
],
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [-6.0, 2.0, 3.0, 1.0],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], [1.5, -3.0, 0.5, 1.0],
]))); [2.0, 1.3, -5.0, 1.7],
[2.5, 0.5, 1.0, -4.0]
],
[
[-1.3, 0.3, 0.1, 0.9],
[1.4, -4.0, 0.5, 2.1],
[1.0, 1.5, -3.0, 0.5],
[0.4, 0.3, 0.1, -0.8]
],
[
[-2.0, 1.0, 0.7, 0.3],
[1.3, -5.9, 2.7, 1.9],
[2.0, 1.5, -4.0, 0.5],
[0.2, 0.7, 0.1, -1.0]
],
[
[-6.0, 1.0, 2.0, 3.0],
[0.5, -3.0, 1.0, 1.5],
[1.4, 2.1, -4.3, 0.8],
[0.5, 1.0, 2.5, -4.0]
],
[
[-1.3, 0.9, 0.3, 0.1],
[0.1, -1.3, 0.2, 1.0],
[0.5, 1.0, -3.0, 1.5],
[0.1, 0.4, 0.3, -0.8]
],
[
[-2.0, 1.0, 0.6, 0.4],
[2.6, -7.1, 1.4, 3.1],
[5.0, 1.0, -8.0, 2.0],
[1.4, 0.4, 0.2, -2.0]
],
[
[-3.0, 1.0, 1.5, 0.5],
[3.0, -6.0, 1.0, 2.0],
[0.3, 0.5, -1.9, 1.1],
[5.0, 1.0, 2.0, -8.0]
],
[
[-2.6, 0.6, 0.2, 1.8],
[2.0, -6.0, 3.0, 1.0],
[0.1, 0.5, -1.3, 0.7],
[0.8, 0.6, 0.2, -1.6]
],
]))
);
} }
} }
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259));
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),);
return (net, data); return (net, data);
} }
@ -219,7 +280,6 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) {
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2)); assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
} }
#[test] #[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() { pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
let ll = LogLikelihood::new(1, 1.0); let ll = LogLikelihood::new(1, 1.0);
@ -234,8 +294,6 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
learn_mixed_discrete_net_3_nodes(hl); learn_mixed_discrete_net_3_nodes(hl);
} }
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) { fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data(); let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data); let net = sl.fit_transform(net, &data);
@ -244,7 +302,6 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgo
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2)); assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
} }
#[test] #[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() { pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
let ll = LogLikelihood::new(1, 1.0); let ll = LogLikelihood::new(1, 1.0);

@ -1,13 +1,9 @@
use ndarray::{arr1, arr2, arr3};
use reCTBN::tools::*;
use reCTBN::network::Network;
use reCTBN::ctbn::*; use reCTBN::ctbn::*;
use reCTBN::node; use reCTBN::network::Network;
use reCTBN::params; use reCTBN::params;
use reCTBN::tools::*;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3};
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
@ -17,32 +13,44 @@ mod utils;
#[test] #[test]
fn run_sampling() { fn run_sampling() {
let mut net = CtbnNetwork::new(); let mut net = CtbnNetwork::new();
let n1 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); let n1 = net
let n2 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); .add_node(utils::generate_discrete_time_continous_node(
String::from("n1"),
2,
))
.unwrap();
let n2 = net
.add_node(utils::generate_discrete_time_continous_node(
String::from("n2"),
2,
))
.unwrap();
net.add_edge(n1, n2); net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
} }
} }
match &mut net.get_node_mut(n2) {
match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[ param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]], [[-1.0, 1.0], [4.0, -4.0]],
[[-6.0,6.0],[2.0,-2.0]]])); [[-6.0, 6.0], [2.0, -2.0]],
]));
} }
} }
let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),); let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259));
assert_eq!(4, data.get_trajectories().len()); assert_eq!(4, data.get_trajectories().len());
assert_relative_eq!(1.0, data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len()-1]); assert_relative_eq!(
1.0,
data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len() - 1]
);
} }
#[test] #[test]
#[should_panic] #[should_panic]
fn trajectory_wrong_shape() { fn trajectory_wrong_shape() {
@ -51,7 +59,6 @@ fn run_sampling() {
Trajectory::new(time, events); Trajectory::new(time, events);
} }
#[test] #[test]
#[should_panic] #[should_panic]
fn dataset_wrong_shape() { fn dataset_wrong_shape() {
@ -59,7 +66,6 @@ fn dataset_wrong_shape() {
let events = arr2(&[[0, 3], [1, 2]]); let events = arr2(&[[0, 3], [1, 2]]);
let t1 = Trajectory::new(time, events); let t1 = Trajectory::new(time, events);
let time = arr1(&[0.0, 0.2]); let time = arr1(&[0.0, 0.2]);
let events = arr2(&[[0, 3, 3], [1, 2, 3]]); let events = arr2(&[[0, 3, 3], [1, 2, 3]]);
let t2 = Trajectory::new(time, events); let t2 = Trajectory::new(time, events);

@ -1,16 +1,17 @@
use reCTBN::params; use reCTBN::params;
use reCTBN::node;
use std::collections::BTreeSet; use std::collections::BTreeSet;
pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node { pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params {
node::Node::new(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality))
} }
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect(); let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
params::DiscreteStatesContinousTimeParams::new(domain) params::DiscreteStatesContinousTimeParams::new(label, domain)
} }

Loading…
Cancel
Save