Removed `node.rs`

pull/45/head
Alessandro Bregoli 3 years ago
parent 42c457cf32
commit 9b7e683630
  1. 32
      src/ctbn.rs
  2. 1
      src/lib.rs
  3. 7
      src/network.rs
  4. 25
      src/node.rs
  5. 2
      src/parameter_learning.rs
  6. 12
      src/params.rs
  7. 2
      src/structure_learning/score_function.rs
  8. 8
      src/tools.rs
  9. 105
      tests/ctbn.rs
  10. 282
      tests/parameter_learning.rs
  11. 31
      tests/params.rs
  12. 185
      tests/structure_learning.rs
  13. 42
      tests/tools.rs
  14. 11
      tests/utils.rs

@ -1,6 +1,5 @@
use ndarray::prelude::*;
use crate::node;
use crate::params::{StateType, ParamsTrait};
use crate::params::{StateType, Params, ParamsTrait};
use crate::network;
use std::collections::BTreeSet;
@ -19,7 +18,6 @@ use std::collections::BTreeSet;
///
/// use std::collections::BTreeSet;
/// use reCTBN::network::Network;
/// use reCTBN::node;
/// use reCTBN::params;
/// use reCTBN::ctbn::*;
///
@ -29,16 +27,16 @@ use std::collections::BTreeSet;
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::new(domain);
/// let param = params::DiscreteStatesContinousTimeParams::new("X1".to_string(), domain);
///
/// //Create the node using the parameters
/// let X1 = node::Node::new(params::Params::DiscreteStatesContinousTime(param),String::from("X1"));
/// let X1 = params::Params::DiscreteStatesContinousTime(param);
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::new(domain);
/// let X2 = node::Node::new(params::Params::DiscreteStatesContinousTime(param), String::from("X2"));
/// let param = params::DiscreteStatesContinousTimeParams::new("X2".to_string(), domain);
/// let X2 = params::Params::DiscreteStatesContinousTime(param);
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::new();
@ -56,7 +54,7 @@ use std::collections::BTreeSet;
/// ```
pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>,
nodes: Vec<node::Node>
nodes: Vec<Params>
}
@ -75,8 +73,8 @@ impl network::Network for CtbnNetwork {
}
fn add_node(&mut self, mut n: node::Node) -> Result<usize, network::NetworkError> {
n.params.reset_params();
fn add_node(&mut self, mut n: Params) -> Result<usize, network::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() -1)
@ -89,7 +87,7 @@ impl network::Network for CtbnNetwork {
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].params.reset_params();
self.nodes[child].reset_params();
}
}
@ -101,12 +99,12 @@ impl network::Network for CtbnNetwork {
self.nodes.len()
}
fn get_node(&self, node_idx: usize) -> &node::Node{
fn get_node(&self, node_idx: usize) -> &Params{
&self.nodes[node_idx]
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node{
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params{
&mut self.nodes[node_idx]
}
@ -114,8 +112,8 @@ impl network::Network for CtbnNetwork {
fn get_param_index_network(&self, node: usize, current_state: &Vec<StateType>) -> usize{
self.adj_matrix.as_ref().unwrap().column(node).iter().enumerate().fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].params.state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].params.get_reserved_space_as_parent();
acc.0 += self.nodes[x.0].state_to_index(&current_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
}).0
@ -124,8 +122,8 @@ impl network::Network for CtbnNetwork {
fn get_param_index_from_custom_parent_set(&self, current_state: &Vec<StateType>, parent_set: &BTreeSet<usize>) -> usize {
parent_set.iter().fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].params.state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].params.get_reserved_space_as_parent();
acc.0 += self.nodes[*x].state_to_index(&current_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc
}).0
}

@ -2,7 +2,6 @@
#[macro_use]
extern crate approx;
pub mod node;
pub mod params;
pub mod network;
pub mod ctbn;

@ -1,6 +1,5 @@
use thiserror::Error;
use crate::params;
use crate::node;
use std::collections::BTreeSet;
/// Error types for trait Network
@ -15,14 +14,14 @@ pub enum NetworkError {
///The Network trait define the required methods for a structure used as pgm (such as ctbn).
pub trait Network {
fn initialize_adj_matrix(&mut self);
fn add_node(&mut self, n: node::Node) -> Result<usize, NetworkError>;
fn add_node(&mut self, n: params::Params) -> Result<usize, NetworkError>;
fn add_edge(&mut self, parent: usize, child: usize);
///Get all the indices of the nodes contained inside the network
fn get_node_indices(&self) -> std::ops::Range<usize>;
fn get_number_of_nodes(&self) -> usize;
fn get_node(&self, node_idx: usize) -> &node::Node;
fn get_node_mut(&mut self, node_idx: usize) -> &mut node::Node;
fn get_node(&self, node_idx: usize) -> &params::Params;
fn get_node_mut(&mut self, node_idx: usize) -> &mut params::Params;
///Compute the index that must be used to access the parameters of a node given a specific
///configuration of the network. Usually, the only values really used in *current_state* are

@ -1,25 +0,0 @@
use crate::params::*;
/// A network node: a set of parameters (`params`) tagged with an identifying `label`.
pub struct Node {
/// Parameters attached to this node (e.g. the CIM of a discrete-state continuous-time node).
pub params: Params,
/// Node identifier; equality between `Node`s compares only this field.
pub label: String
}
impl Node {
    /// Build a `Node` from its parameters and its identifying label.
    pub fn new(params: Params, label: String) -> Node {
        // Field-init shorthand: the argument names match the field names exactly.
        Node { params, label }
    }
}
impl PartialEq for Node {
    /// Two nodes compare equal iff their labels match; the attached
    /// parameters are deliberately ignored in the comparison.
    fn eq(&self, other: &Node) -> bool {
        self.label.eq(&other.label)
    }
}

@ -24,7 +24,6 @@ pub fn sufficient_statistics<T:network::Network>(
//Get the number of values assumable by the node
let node_domain = net
.get_node(node.clone())
.params
.get_reserved_space_as_parent();
//Get the number of values assumable by each parent of the node
@ -32,7 +31,6 @@ pub fn sufficient_statistics<T:network::Network>(
.iter()
.map(|x| {
net.get_node(x.clone())
.params
.get_reserved_space_as_parent()
})
.collect();

@ -49,6 +49,9 @@ pub trait ParamsTrait {
/// Validate parameters against domain
fn validate_params(&self) -> Result<(), ParamsError>;
/// Return a reference to the associated label
fn get_label(&self) -> &String;
}
/// The Params enum is the core element for building different types of nodes. The goal is to
@ -70,6 +73,7 @@ pub enum Params {
/// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set
pub struct DiscreteStatesContinousTimeParams {
label: String,
domain: BTreeSet<String>,
cim: Option<Array3<f64>>,
transitions: Option<Array3<u64>>,
@ -77,8 +81,9 @@ pub struct DiscreteStatesContinousTimeParams {
}
impl DiscreteStatesContinousTimeParams {
pub fn new(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
pub fn new(label: String, domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
DiscreteStatesContinousTimeParams {
label,
domain,
cim: Option::None,
transitions: Option::None,
@ -244,4 +249,9 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
return Ok(());
}
fn get_label(&self) -> &String {
&self.label
}
}

@ -44,7 +44,7 @@ impl LogLikelihood {
T: network::Network,
{
//Identify the type of node used
match &net.get_node(node).params {
match &net.get_node(node){
params::Params::DiscreteStatesContinousTime(_params) => {
//Compute the sufficient statistics M (number of transistions) and T (residence
//time)

@ -1,5 +1,4 @@
use crate::network;
use crate::node;
use crate::params;
use crate::params::ParamsTrait;
use ndarray::prelude::*;
@ -80,7 +79,7 @@ pub fn trajectory_generator<T: network::Network>(
//Configuration of the process variables at time t initialized with an uniform
//distribution.
let mut current_state: Vec<params::StateType> = net.get_node_indices()
.map(|x| net.get_node(x).params.get_random_state_uniform(&mut rng))
.map(|x| net.get_node(x).get_random_state_uniform(&mut rng))
.collect();
//History of all the configurations of the process variables.
let mut events: Vec<Array1<usize>> = Vec::new();
@ -106,9 +105,8 @@ pub fn trajectory_generator<T: network::Network>(
if let None = val {
*val = Some(
net.get_node(idx)
.params
.get_random_residence_time(
net.get_node(idx).params.state_to_index(&current_state[idx]),
net.get_node(idx).state_to_index(&current_state[idx]),
net.get_param_index_network(idx, &current_state),
&mut rng,
)
@ -137,10 +135,8 @@ pub fn trajectory_generator<T: network::Network>(
//Compute the new state of the transitioning variable.
current_state[next_node_transition] = net
.get_node(next_node_transition)
.params
.get_random_state(
net.get_node(next_node_transition)
.params
.state_to_index(&current_state[next_node_transition]),
net.get_param_index_network(next_node_transition, &current_state),
&mut rng,

@ -1,10 +1,9 @@
mod utils;
use utils::generate_discrete_time_continous_node;
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::node;
use reCTBN::params;
use reCTBN::params::{self, ParamsTrait};
use std::collections::BTreeSet;
use reCTBN::ctbn::*;
use utils::generate_discrete_time_continous_node;
#[test]
fn define_simpe_ctbn() {
@ -15,15 +14,21 @@ fn define_simpe_ctbn() {
#[test]
fn add_node_to_ctbn() {
let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
assert_eq!(String::from("n1"), net.get_node(n1).label);
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
assert_eq!(&String::from("n1"), net.get_node(n1).get_label());
}
#[test]
fn add_edge_to_ctbn() {
let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(&n2, cs.iter().next().unwrap());
@ -32,8 +37,12 @@ fn add_edge_to_ctbn() {
#[test]
fn children_and_parents() {
let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
net.add_edge(n1, n2);
let cs = net.get_children_set(n1);
assert_eq!(&n2, cs.iter().next().unwrap());
@ -41,59 +50,81 @@ fn children_and_parents() {
assert_eq!(&n1, ps.iter().next().unwrap());
}
#[test]
fn compute_index_ctbn() {
let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
net.add_edge(n1, n2);
net.add_edge(n3, n2);
let idx = net.get_param_index_network(n2, &vec![
let idx = net.get_param_index_network(
n2,
&vec![
params::StateType::Discrete(1),
params::StateType::Discrete(1),
params::StateType::Discrete(1)]);
params::StateType::Discrete(1),
],
);
assert_eq!(3, idx);
let idx = net.get_param_index_network(n2, &vec![
let idx = net.get_param_index_network(
n2,
&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(1),
params::StateType::Discrete(1)]);
params::StateType::Discrete(1),
],
);
assert_eq!(2, idx);
let idx = net.get_param_index_network(n2, &vec![
let idx = net.get_param_index_network(
n2,
&vec![
params::StateType::Discrete(1),
params::StateType::Discrete(1),
params::StateType::Discrete(0)]);
params::StateType::Discrete(0),
],
);
assert_eq!(1, idx);
}
#[test]
fn compute_index_from_custom_parent_set() {
let mut net = CtbnNetwork::new();
let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap();
let idx = net.get_param_index_from_custom_parent_set(&vec![
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let _n2 = net
.add_node(generate_discrete_time_continous_node(String::from("n2"), 2))
.unwrap();
let _n3 = net
.add_node(generate_discrete_time_continous_node(String::from("n3"), 2))
.unwrap();
let idx = net.get_param_index_from_custom_parent_set(
&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(0),
params::StateType::Discrete(1)],
&BTreeSet::from([1]));
params::StateType::Discrete(1),
],
&BTreeSet::from([1]),
);
assert_eq!(0, idx);
let idx = net.get_param_index_from_custom_parent_set(&vec![
let idx = net.get_param_index_from_custom_parent_set(
&vec![
params::StateType::Discrete(0),
params::StateType::Discrete(0),
params::StateType::Discrete(1)],
&BTreeSet::from([1,2]));
params::StateType::Discrete(1),
],
&BTreeSet::from([1, 2]),
);
assert_eq!(2, idx);
}

@ -1,20 +1,16 @@
mod utils;
use utils::*;
use reCTBN::parameter_learning::*;
use ndarray::arr3;
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::node;
use reCTBN::params;
use reCTBN::tools::*;
use ndarray::arr3;
use reCTBN::parameter_learning::*;
use reCTBN::{params, tools::*};
use std::collections::BTreeSet;
#[macro_use]
extern crate approx;
fn learn_binary_cim<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -25,29 +21,32 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
}
}
match &mut net.get_node_mut(n2).params {
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
])));
]))
);
}
}
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259),);
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [2, 2, 2]);
assert!(CIM.abs_diff_eq(&arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]],
]), 0.1));
assert!(CIM.abs_diff_eq(
&arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]),
0.1
));
}
#[test]
@ -56,12 +55,9 @@ fn learn_binary_cim_MLE() {
learn_binary_cim(mle);
}
#[test]
fn learn_binary_cim_BA() {
let ba = BayesianApproach{
alpha: 1,
tau: 1.0};
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_binary_cim(ba);
}
@ -75,48 +71,55 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])));
[0.4, 0.6, -1.0]
]]))
);
}
}
match &mut net.get_node_mut(n2).params {
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])));
]))
);
}
}
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 1, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [3, 3, 3]);
assert!(CIM.abs_diff_eq(&arr3(&[
assert!(CIM.abs_diff_eq(
&arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
]), 0.1));
]),
0.1
));
}
#[test]
fn learn_ternary_cim_MLE() {
let mle = MLE {};
learn_ternary_cim(mle);
}
#[test]
fn learn_ternary_cim_BA() {
let ba = BayesianApproach{
alpha: 1,
tau: 1.0};
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim(ba);
}
@ -130,50 +133,54 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])));
[0.4, 0.6, -1.0]
]]))
);
}
}
match &mut net.get_node_mut(n2).params {
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])));
]))
);
}
}
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259),);
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 0, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [1, 3, 3]);
assert!(CIM.abs_diff_eq(&arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]]), 0.1));
assert!(CIM.abs_diff_eq(
&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]),
0.1
));
}
#[test]
fn learn_ternary_cim_no_parents_MLE() {
let mle = MLE {};
learn_ternary_cim_no_parents(mle);
}
#[test]
fn learn_ternary_cim_no_parents_BA() {
let ba = BayesianApproach{
alpha: 1,
tau: 1.0};
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_ternary_cim_no_parents(ba);
}
fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -190,61 +197,159 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
net.add_edge(n1, n3);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1).params {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])));
[0.4, 0.6, -1.0]
]]))
);
}
}
match &mut net.get_node_mut(n2).params {
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])));
]))
);
}
}
match &mut net.get_node_mut(n3).params {
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]],
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]],
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]],
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]],
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]],
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]],
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
])));
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[
[-1.0, 0.5, 0.3, 0.2],
[0.5, -4.0, 2.5, 1.0],
[2.5, 0.5, -4.0, 1.0],
[0.7, 0.2, 0.1, -1.0]
],
[
[-6.0, 2.0, 3.0, 1.0],
[1.5, -3.0, 0.5, 1.0],
[2.0, 1.3, -5.0, 1.7],
[2.5, 0.5, 1.0, -4.0]
],
[
[-1.3, 0.3, 0.1, 0.9],
[1.4, -4.0, 0.5, 2.1],
[1.0, 1.5, -3.0, 0.5],
[0.4, 0.3, 0.1, -0.8]
],
[
[-2.0, 1.0, 0.7, 0.3],
[1.3, -5.9, 2.7, 1.9],
[2.0, 1.5, -4.0, 0.5],
[0.2, 0.7, 0.1, -1.0]
],
[
[-6.0, 1.0, 2.0, 3.0],
[0.5, -3.0, 1.0, 1.5],
[1.4, 2.1, -4.3, 0.8],
[0.5, 1.0, 2.5, -4.0]
],
[
[-1.3, 0.9, 0.3, 0.1],
[0.1, -1.3, 0.2, 1.0],
[0.5, 1.0, -3.0, 1.5],
[0.1, 0.4, 0.3, -0.8]
],
[
[-2.0, 1.0, 0.6, 0.4],
[2.6, -7.1, 1.4, 3.1],
[5.0, 1.0, -8.0, 2.0],
[1.4, 0.4, 0.2, -2.0]
],
[
[-3.0, 1.0, 1.5, 0.5],
[3.0, -6.0, 1.0, 2.0],
[0.3, 0.5, -1.9, 1.1],
[5.0, 1.0, 2.0, -8.0]
],
[
[-2.6, 0.6, 0.2, 1.8],
[2.0, -6.0, 3.0, 1.0],
[0.1, 0.5, -1.3, 0.7],
[0.8, 0.6, 0.2, -1.6]
],
]))
);
}
}
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259),);
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259));
let (CIM, M, T) = pl.fit(&net, &data, 2, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [9, 4, 4]);
assert!(CIM.abs_diff_eq(&arr3(&[
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]],
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]],
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]],
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]],
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]],
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]],
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
]), 0.1));
assert!(CIM.abs_diff_eq(
&arr3(&[
[
[-1.0, 0.5, 0.3, 0.2],
[0.5, -4.0, 2.5, 1.0],
[2.5, 0.5, -4.0, 1.0],
[0.7, 0.2, 0.1, -1.0]
],
[
[-6.0, 2.0, 3.0, 1.0],
[1.5, -3.0, 0.5, 1.0],
[2.0, 1.3, -5.0, 1.7],
[2.5, 0.5, 1.0, -4.0]
],
[
[-1.3, 0.3, 0.1, 0.9],
[1.4, -4.0, 0.5, 2.1],
[1.0, 1.5, -3.0, 0.5],
[0.4, 0.3, 0.1, -0.8]
],
[
[-2.0, 1.0, 0.7, 0.3],
[1.3, -5.9, 2.7, 1.9],
[2.0, 1.5, -4.0, 0.5],
[0.2, 0.7, 0.1, -1.0]
],
[
[-6.0, 1.0, 2.0, 3.0],
[0.5, -3.0, 1.0, 1.5],
[1.4, 2.1, -4.3, 0.8],
[0.5, 1.0, 2.5, -4.0]
],
[
[-1.3, 0.9, 0.3, 0.1],
[0.1, -1.3, 0.2, 1.0],
[0.5, 1.0, -3.0, 1.5],
[0.1, 0.4, 0.3, -0.8]
],
[
[-2.0, 1.0, 0.6, 0.4],
[2.6, -7.1, 1.4, 3.1],
[5.0, 1.0, -8.0, 2.0],
[1.4, 0.4, 0.2, -2.0]
],
[
[-3.0, 1.0, 1.5, 0.5],
[3.0, -6.0, 1.0, 2.0],
[0.3, 0.5, -1.9, 1.1],
[5.0, 1.0, 2.0, -8.0]
],
[
[-2.6, 0.6, 0.2, 1.8],
[2.0, -6.0, 3.0, 1.0],
[0.1, 0.5, -1.3, 0.7],
[0.8, 0.6, 0.2, -1.6]
],
]),
0.1
));
}
#[test]
@ -253,11 +358,8 @@ fn learn_mixed_discrete_cim_MLE() {
learn_mixed_discrete_cim(mle);
}
#[test]
fn learn_mixed_discrete_cim_BA() {
let ba = BayesianApproach{
alpha: 1,
tau: 1.0};
let ba = BayesianApproach { alpha: 1, tau: 1.0 };
learn_mixed_discrete_cim(ba);
}

@ -1,16 +1,15 @@
use ndarray::prelude::*;
use reCTBN::params::*;
use std::collections::BTreeSet;
use rand_chacha::ChaCha8Rng;
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
use reCTBN::params::{ParamsTrait, *};
mod utils;
#[macro_use]
extern crate approx;
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams {
let mut params = utils::generate_discrete_time_continous_param(3);
let mut params = utils::generate_discrete_time_continous_params("A".to_string(), 3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
@ -18,6 +17,12 @@ fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTime
params
}
#[test]
fn test_get_label() {
let param = create_ternary_discrete_time_continous_param();
assert_eq!(&String::from("A"), param.get_label())
}
#[test]
fn test_uniform_generation() {
let param = create_ternary_discrete_time_continous_param();
@ -79,15 +84,19 @@ fn test_validate_params_valid_cim() {
#[test]
fn test_validate_params_valid_cim_with_huge_values() {
let mut param = utils::generate_discrete_time_continous_param(3);
let cim = array![[[-2e10, 1e10, 1e10], [1.5e10, -3e10, 1.5e10], [1e10, 1e10, -2e10]]];
let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
let cim = array![[
[-2e10, 1e10, 1e10],
[1.5e10, -3e10, 1.5e10],
[1e10, 1e10, -2e10]
]];
let result = param.set_cim(cim);
assert_eq!(Ok(()), result);
}
#[test]
fn test_validate_params_cim_not_initialized() {
let param = utils::generate_discrete_time_continous_param(3);
let param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
assert_eq!(
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
@ -98,7 +107,7 @@ fn test_validate_params_cim_not_initialized() {
#[test]
fn test_validate_params_wrong_shape() {
let mut param = utils::generate_discrete_time_continous_param(4);
let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 4);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
let result = param.set_cim(cim);
assert_eq!(
@ -111,7 +120,7 @@ fn test_validate_params_wrong_shape() {
#[test]
fn test_validate_params_positive_diag() {
let mut param = utils::generate_discrete_time_continous_param(3);
let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
let result = param.set_cim(cim);
assert_eq!(
@ -124,7 +133,7 @@ fn test_validate_params_positive_diag() {
#[test]
fn test_validate_params_row_not_sum_to_zero() {
let mut param = utils::generate_discrete_time_continous_param(3);
let mut param = utils::generate_discrete_time_continous_params("A".to_string(), 3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
let result = param.set_cim(cim);
assert_eq!(

@ -1,17 +1,14 @@
mod utils;
use utils::*;
use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*;
use reCTBN::network::Network;
use reCTBN::tools::*;
use reCTBN::params;
use reCTBN::structure_learning::score_function::*;
use reCTBN::structure_learning::score_based_algorithm::*;
use reCTBN::structure_learning::StructureLearningAlgorithm;
use ndarray::{arr1, arr2, arr3};
use reCTBN::structure_learning::{score_based_algorithm::*, StructureLearningAlgorithm};
use reCTBN::tools::*;
use std::collections::BTreeSet;
use reCTBN::params;
#[macro_use]
extern crate approx;
@ -23,19 +20,19 @@ fn simple_score_test() {
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let trj = Trajectory::new(
arr1(&[0.0,0.1,0.3]),
arr2(&[[0],[1],[1]]));
let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]]));
let dataset = Dataset::new(vec![trj]);
let ll = LogLikelihood::new(1, 1.0);
assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
assert_abs_diff_eq!(
0.04257,
ll.call(&net, n1, &BTreeSet::new(), &dataset),
epsilon = 1e-3
);
}
#[test]
fn simple_bic() {
let mut net = CtbnNetwork::new();
@ -43,19 +40,18 @@ fn simple_bic() {
.add_node(generate_discrete_time_continous_node(String::from("n1"), 2))
.unwrap();
let trj = Trajectory::new(
arr1(&[0.0,0.1,0.3]),
arr2(&[[0],[1],[1]]));
let trj = Trajectory::new(arr1(&[0.0, 0.1, 0.3]), arr2(&[[0], [1], [1]]));
let dataset = Dataset::new(vec![trj]);
let bic = BIC::new(1, 1.0);
assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
assert_abs_diff_eq!(
-0.65058,
bic.call(&net, n1, &BTreeSet::new(), &dataset),
epsilon = 1e-3
);
}
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm>(sl: T) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -66,25 +62,33 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])));
[0.4, 0.6, -1.0]
]]))
);
}
}
match &mut net.get_node_mut(n2).params {
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])));
]))
);
}
}
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),);
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
let mut net = CtbnNetwork::new();
let _n1 = net
@ -93,7 +97,6 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
let net = sl.fit_transform(net, &data);
}
#[test]
#[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
@ -112,32 +115,39 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])));
[0.4, 0.6, -1.0]
]]))
);
}
}
match &mut net.get_node_mut(n2).params {
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])));
]))
);
}
}
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),);
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259));
let net = sl.fit_transform(net, &data);
assert_eq!(BTreeSet::from_iter(vec![n1]), net.get_parent_set(n2));
assert_eq!(BTreeSet::new(), net.get_parent_set(n1));
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
let ll = LogLikelihood::new(1, 1.0);
@ -152,7 +162,6 @@ pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
learn_ternary_net_2_nodes(hl);
}
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
let mut net = CtbnNetwork::new();
let n1 = net
@ -169,45 +178,97 @@ fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
net.add_edge(n1, n3);
net.add_edge(n2, n3);
match &mut net.get_node_mut(n1).params {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
assert_eq!(
Ok(()),
param.set_cim(arr3(&[[
[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])));
[0.4, 0.6, -1.0]
]]))
);
}
}
match &mut net.get_node_mut(n2).params {
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])));
]))
);
}
}
match &mut net.get_node_mut(n3).params {
match &mut net.get_node_mut(n3) {
params::Params::DiscreteStatesContinousTime(param) => {
assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]],
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]],
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]],
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]],
[[-6.0, 1.0, 2.0, 3.0], [0.5, -3.0, 1.0, 1.5], [1.4, 2.1, -4.3, 0.8], [0.5, 1.0, 2.5, -4.0]],
[[-1.3, 0.9, 0.3, 0.1], [0.1, -1.3, 0.2, 1.0], [0.5, 1.0, -3.0, 1.5], [0.1, 0.4, 0.3, -0.8]],
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
])));
assert_eq!(
Ok(()),
param.set_cim(arr3(&[
[
[-1.0, 0.5, 0.3, 0.2],
[0.5, -4.0, 2.5, 1.0],
[2.5, 0.5, -4.0, 1.0],
[0.7, 0.2, 0.1, -1.0]
],
[
[-6.0, 2.0, 3.0, 1.0],
[1.5, -3.0, 0.5, 1.0],
[2.0, 1.3, -5.0, 1.7],
[2.5, 0.5, 1.0, -4.0]
],
[
[-1.3, 0.3, 0.1, 0.9],
[1.4, -4.0, 0.5, 2.1],
[1.0, 1.5, -3.0, 0.5],
[0.4, 0.3, 0.1, -0.8]
],
[
[-2.0, 1.0, 0.7, 0.3],
[1.3, -5.9, 2.7, 1.9],
[2.0, 1.5, -4.0, 0.5],
[0.2, 0.7, 0.1, -1.0]
],
[
[-6.0, 1.0, 2.0, 3.0],
[0.5, -3.0, 1.0, 1.5],
[1.4, 2.1, -4.3, 0.8],
[0.5, 1.0, 2.5, -4.0]
],
[
[-1.3, 0.9, 0.3, 0.1],
[0.1, -1.3, 0.2, 1.0],
[0.5, 1.0, -3.0, 1.5],
[0.1, 0.4, 0.3, -0.8]
],
[
[-2.0, 1.0, 0.6, 0.4],
[2.6, -7.1, 1.4, 3.1],
[5.0, 1.0, -8.0, 2.0],
[1.4, 0.4, 0.2, -2.0]
],
[
[-3.0, 1.0, 1.5, 0.5],
[3.0, -6.0, 1.0, 2.0],
[0.3, 0.5, -1.9, 1.1],
[5.0, 1.0, 2.0, -8.0]
],
[
[-2.6, 0.6, 0.2, 1.8],
[2.0, -6.0, 3.0, 1.0],
[0.1, 0.5, -1.3, 0.7],
[0.8, 0.6, 0.2, -1.6]
],
]))
);
}
}
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259),);
let data = trajectory_generator(&net, 300, 30.0, Some(6347747169756259));
return (net, data);
}
@ -219,7 +280,6 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) {
assert_eq!(BTreeSet::from_iter(vec![0, 1]), net.get_parent_set(2));
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
let ll = LogLikelihood::new(1, 1.0);
@ -234,8 +294,6 @@ pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
learn_mixed_discrete_net_3_nodes(hl);
}
fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgorithm>(sl: T) {
let (net, data) = get_mixed_discrete_net_3_nodes_with_data();
let net = sl.fit_transform(net, &data);
@ -244,7 +302,6 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgo
assert_eq!(BTreeSet::from_iter(vec![0]), net.get_parent_set(2));
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
let ll = LogLikelihood::new(1, 1.0);

@ -1,13 +1,9 @@
use reCTBN::tools::*;
use reCTBN::network::Network;
use ndarray::{arr1, arr2, arr3};
use reCTBN::ctbn::*;
use reCTBN::node;
use reCTBN::network::Network;
use reCTBN::params;
use reCTBN::tools::*;
use std::collections::BTreeSet;
use ndarray::{arr1, arr2, arr3};
#[macro_use]
extern crate approx;
@ -17,32 +13,44 @@ mod utils;
#[test]
fn run_sampling() {
let mut net = CtbnNetwork::new();
let n1 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let n1 = net
.add_node(utils::generate_discrete_time_continous_node(
String::from("n1"),
2,
))
.unwrap();
let n2 = net
.add_node(utils::generate_discrete_time_continous_node(
String::from("n2"),
2,
))
.unwrap();
net.add_edge(n1, n2);
match &mut net.get_node_mut(n1).params {
match &mut net.get_node_mut(n1) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]]));
}
}
match &mut net.get_node_mut(n2).params {
match &mut net.get_node_mut(n2) {
params::Params::DiscreteStatesContinousTime(param) => {
param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]],
[[-6.0,6.0],[2.0,-2.0]]]));
[[-6.0, 6.0], [2.0, -2.0]],
]));
}
}
let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259),);
let data = trajectory_generator(&net, 4, 1.0, Some(6347747169756259));
assert_eq!(4, data.get_trajectories().len());
assert_relative_eq!(1.0, data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len()-1]);
assert_relative_eq!(
1.0,
data.get_trajectories()[0].get_time()[data.get_trajectories()[0].get_time().len() - 1]
);
}
#[test]
#[should_panic]
fn trajectory_wrong_shape() {
@ -51,7 +59,6 @@ fn run_sampling() {
Trajectory::new(time, events);
}
#[test]
#[should_panic]
fn dataset_wrong_shape() {
@ -59,7 +66,6 @@ fn dataset_wrong_shape() {
let events = arr2(&[[0, 3], [1, 2]]);
let t1 = Trajectory::new(time, events);
let time = arr1(&[0.0, 0.2]);
let events = arr2(&[[0, 3, 3], [1, 2, 3]]);
let t2 = Trajectory::new(time, events);

@ -1,16 +1,17 @@
use reCTBN::params;
use reCTBN::node;
use std::collections::BTreeSet;
pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node {
node::Node::new(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name)
pub fn generate_discrete_time_continous_node(label: String, cardinality: usize) -> params::Params {
params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_params(label, cardinality))
}
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
pub fn generate_discrete_time_continous_params(label: String, cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
params::DiscreteStatesContinousTimeParams::new(domain)
params::DiscreteStatesContinousTimeParams::new(label, domain)
}

Loading…
Cancel
Save