parent
dc53e5167e
commit
4adfbfa4e4
@ -0,0 +1,99 @@ |
||||
mod utils; |
||||
use utils::generate_discrete_time_continous_node; |
||||
use rustyCTBN::network::Network; |
||||
use rustyCTBN::node; |
||||
use rustyCTBN::params; |
||||
use std::collections::BTreeSet; |
||||
use rustyCTBN::ctbn::*; |
||||
|
||||
#[test] |
||||
fn define_simpe_ctbn() { |
||||
let _ = CtbnNetwork::init(); |
||||
assert!(true); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_node_to_ctbn() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
assert_eq!(String::from("n1"), net.get_node(n1).label); |
||||
} |
||||
|
||||
#[test] |
||||
fn add_edge_to_ctbn() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
net.add_edge(n1, n2); |
||||
let cs = net.get_children_set(n1); |
||||
assert_eq!(&n2, cs.iter().next().unwrap()); |
||||
} |
||||
|
||||
#[test] |
||||
fn children_and_parents() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
net.add_edge(n1, n2); |
||||
let cs = net.get_children_set(n1); |
||||
assert_eq!(&n2, cs.iter().next().unwrap()); |
||||
let ps = net.get_parent_set(n2); |
||||
assert_eq!(&n1, ps.iter().next().unwrap()); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn compute_index_ctbn() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); |
||||
net.add_edge(n1, n2); |
||||
net.add_edge(n3, n2); |
||||
let idx = net.get_param_index_network(n2, &vec![ |
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(1)]); |
||||
assert_eq!(3, idx); |
||||
|
||||
|
||||
let idx = net.get_param_index_network(n2, &vec![ |
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(1)]); |
||||
assert_eq!(2, idx); |
||||
|
||||
|
||||
let idx = net.get_param_index_network(n2, &vec![ |
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(1),
|
||||
params::StateType::Discrete(0)]); |
||||
assert_eq!(1, idx); |
||||
|
||||
} |
||||
|
||||
|
||||
|
||||
#[test] |
||||
fn compute_index_from_custom_parent_set() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); |
||||
|
||||
|
||||
let idx = net.get_param_index_from_custom_parent_set(&vec![ |
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(1)], |
||||
&BTreeSet::from([1])); |
||||
assert_eq!(0, idx); |
||||
|
||||
|
||||
let idx = net.get_param_index_from_custom_parent_set(&vec![ |
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(0),
|
||||
params::StateType::Discrete(1)], |
||||
&BTreeSet::from([1,2])); |
||||
assert_eq!(2, idx); |
||||
} |
@ -0,0 +1,95 @@ |
||||
mod utils; |
||||
use utils::*; |
||||
|
||||
use rustyCTBN::parameter_learning::*; |
||||
use rustyCTBN::ctbn::*; |
||||
use rustyCTBN::network::Network; |
||||
use rustyCTBN::node; |
||||
use rustyCTBN::params; |
||||
use rustyCTBN::tools::*; |
||||
use ndarray::arr3; |
||||
use std::collections::BTreeSet; |
||||
|
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
|
||||
#[test] |
||||
fn learn_binary_cim_MLE() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),2)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[ |
||||
[[-1.0, 1.0], [4.0, -4.0]], |
||||
[[-6.0, 6.0], [2.0, -2.0]], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(Box::new(&net), 100, 100.0); |
||||
|
||||
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); |
||||
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||
assert_eq!(CIM.shape(), [2, 2, 2]); |
||||
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); |
||||
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); |
||||
assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2); |
||||
assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn learn_ternary_cim_MLE() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||
.unwrap(); |
||||
let n2 = net |
||||
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||
.unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
|
||||
[1.5, -2.0, 0.5], |
||||
[0.4, 0.6, -1.0]]])); |
||||
} |
||||
} |
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some(arr3(&[ |
||||
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||
])); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(Box::new(&net), 100, 200.0); |
||||
|
||||
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); |
||||
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||
assert_eq!(CIM.shape(), [3, 3, 3]); |
||||
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); |
||||
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); |
||||
assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2); |
||||
assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2); |
||||
} |
||||
|
@ -0,0 +1,64 @@ |
||||
use rustyCTBN::params::*; |
||||
use ndarray::prelude::*; |
||||
use std::collections::BTreeSet; |
||||
|
||||
mod utils; |
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
|
||||
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { |
||||
let mut params = utils::generate_discrete_time_continous_param(3); |
||||
|
||||
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; |
||||
|
||||
params.cim = Some(cim); |
||||
params |
||||
} |
||||
|
||||
#[test] |
||||
fn test_uniform_generation() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<usize>::zeros(10000); |
||||
|
||||
states.mapv_inplace(|_| { |
||||
if let StateType::Discrete(val) = param.get_random_state_uniform() { |
||||
val |
||||
} else { |
||||
panic!() |
||||
} |
||||
}); |
||||
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
||||
|
||||
assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_random_generation_state() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<usize>::zeros(10000); |
||||
|
||||
states.mapv_inplace(|_| { |
||||
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { |
||||
val |
||||
} else { |
||||
panic!() |
||||
} |
||||
}); |
||||
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; |
||||
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
||||
|
||||
assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01); |
||||
assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01); |
||||
} |
||||
|
||||
#[test] |
||||
fn test_random_generation_residence_time() { |
||||
let param = create_ternary_discrete_time_continous_param(); |
||||
let mut states = Array1::<f64>::zeros(10000); |
||||
|
||||
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap()); |
||||
|
||||
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); |
||||
} |
@ -0,0 +1,43 @@ |
||||
|
||||
use rustyCTBN::tools::*; |
||||
use rustyCTBN::network::Network; |
||||
use rustyCTBN::ctbn::*; |
||||
use rustyCTBN::node; |
||||
use rustyCTBN::params; |
||||
use std::collections::BTreeSet; |
||||
use ndarray::arr3; |
||||
|
||||
|
||||
|
||||
#[macro_use] |
||||
extern crate approx; |
||||
|
||||
mod utils; |
||||
|
||||
#[test] |
||||
fn run_sampling() { |
||||
let mut net = CtbnNetwork::init(); |
||||
let n1 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||
let n2 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||
net.add_edge(n1, n2); |
||||
|
||||
match &mut net.get_node_mut(n1).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); |
||||
} |
||||
} |
||||
|
||||
|
||||
match &mut net.get_node_mut(n2).params { |
||||
params::Params::DiscreteStatesContinousTime(param) => { |
||||
param.cim = Some (arr3(&[ |
||||
[[-1.0,1.0],[4.0,-4.0]], |
||||
[[-6.0,6.0],[2.0,-2.0]]])); |
||||
} |
||||
} |
||||
|
||||
let data = trajectory_generator(Box::new(&net), 4, 1.0); |
||||
|
||||
assert_eq!(4, data.trajectories.len()); |
||||
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); |
||||
} |
@ -0,0 +1,16 @@ |
||||
use rustyCTBN::params; |
||||
use rustyCTBN::node; |
||||
use std::collections::BTreeSet; |
||||
|
||||
pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node { |
||||
node::Node::init(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) |
||||
} |
||||
|
||||
|
||||
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ |
||||
let mut domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect(); |
||||
params::DiscreteStatesContinousTimeParams::init(domain) |
||||
} |
||||
|
||||
|
||||
|
Loading…
Reference in new issue