parent
dc53e5167e
commit
4adfbfa4e4
@ -0,0 +1,99 @@ |
|||||||
|
mod utils; |
||||||
|
use utils::generate_discrete_time_continous_node; |
||||||
|
use rustyCTBN::network::Network; |
||||||
|
use rustyCTBN::node; |
||||||
|
use rustyCTBN::params; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
use rustyCTBN::ctbn::*; |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn define_simpe_ctbn() { |
||||||
|
let _ = CtbnNetwork::init(); |
||||||
|
assert!(true); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn add_node_to_ctbn() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||||
|
assert_eq!(String::from("n1"), net.get_node(n1).label); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn add_edge_to_ctbn() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||||
|
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
let cs = net.get_children_set(n1); |
||||||
|
assert_eq!(&n2, cs.iter().next().unwrap()); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn children_and_parents() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||||
|
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
let cs = net.get_children_set(n1); |
||||||
|
assert_eq!(&n2, cs.iter().next().unwrap()); |
||||||
|
let ps = net.get_parent_set(n2); |
||||||
|
assert_eq!(&n1, ps.iter().next().unwrap()); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn compute_index_ctbn() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||||
|
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||||
|
let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
net.add_edge(n3, n2); |
||||||
|
let idx = net.get_param_index_network(n2, &vec![ |
||||||
|
params::StateType::Discrete(1),
|
||||||
|
params::StateType::Discrete(1),
|
||||||
|
params::StateType::Discrete(1)]); |
||||||
|
assert_eq!(3, idx); |
||||||
|
|
||||||
|
|
||||||
|
let idx = net.get_param_index_network(n2, &vec![ |
||||||
|
params::StateType::Discrete(0),
|
||||||
|
params::StateType::Discrete(1),
|
||||||
|
params::StateType::Discrete(1)]); |
||||||
|
assert_eq!(2, idx); |
||||||
|
|
||||||
|
|
||||||
|
let idx = net.get_param_index_network(n2, &vec![ |
||||||
|
params::StateType::Discrete(1),
|
||||||
|
params::StateType::Discrete(1),
|
||||||
|
params::StateType::Discrete(0)]); |
||||||
|
assert_eq!(1, idx); |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn compute_index_from_custom_parent_set() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||||
|
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||||
|
let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap(); |
||||||
|
|
||||||
|
|
||||||
|
let idx = net.get_param_index_from_custom_parent_set(&vec![ |
||||||
|
params::StateType::Discrete(0),
|
||||||
|
params::StateType::Discrete(0),
|
||||||
|
params::StateType::Discrete(1)], |
||||||
|
&BTreeSet::from([1])); |
||||||
|
assert_eq!(0, idx); |
||||||
|
|
||||||
|
|
||||||
|
let idx = net.get_param_index_from_custom_parent_set(&vec![ |
||||||
|
params::StateType::Discrete(0),
|
||||||
|
params::StateType::Discrete(0),
|
||||||
|
params::StateType::Discrete(1)], |
||||||
|
&BTreeSet::from([1,2])); |
||||||
|
assert_eq!(2, idx); |
||||||
|
} |
@ -0,0 +1,95 @@ |
|||||||
|
mod utils; |
||||||
|
use utils::*; |
||||||
|
|
||||||
|
use rustyCTBN::parameter_learning::*; |
||||||
|
use rustyCTBN::ctbn::*; |
||||||
|
use rustyCTBN::network::Network; |
||||||
|
use rustyCTBN::node; |
||||||
|
use rustyCTBN::params; |
||||||
|
use rustyCTBN::tools::*; |
||||||
|
use ndarray::arr3; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
|
||||||
|
#[macro_use] |
||||||
|
extern crate approx; |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn learn_binary_cim_MLE() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),2)) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"),2)) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some(arr3(&[ |
||||||
|
[[-1.0, 1.0], [4.0, -4.0]], |
||||||
|
[[-6.0, 6.0], [2.0, -2.0]], |
||||||
|
])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
let data = trajectory_generator(Box::new(&net), 100, 100.0); |
||||||
|
|
||||||
|
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); |
||||||
|
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||||
|
assert_eq!(CIM.shape(), [2, 2, 2]); |
||||||
|
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-6.0, CIM[[1, 0, 0]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-2.0, CIM[[1, 1, 1]], epsilon=0.2); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#[test] |
||||||
|
fn learn_ternary_cim_MLE() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n1"),3)) |
||||||
|
.unwrap(); |
||||||
|
let n2 = net |
||||||
|
.add_node(generate_discrete_time_continous_node(String::from("n2"),3)) |
||||||
|
.unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0],
|
||||||
|
[1.5, -2.0, 0.5], |
||||||
|
[0.4, 0.6, -1.0]]])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some(arr3(&[ |
||||||
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
||||||
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
||||||
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
||||||
|
])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
let data = trajectory_generator(Box::new(&net), 100, 200.0); |
||||||
|
|
||||||
|
let (CIM, M, T) = MLE(Box::new(&net), &data, 1, None); |
||||||
|
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); |
||||||
|
assert_eq!(CIM.shape(), [3, 3, 3]); |
||||||
|
assert_relative_eq!(-1.0, CIM[[0, 0, 0]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-4.0, CIM[[0, 1, 1]], epsilon=0.2); |
||||||
|
assert_relative_eq!(-1.0, CIM[[0, 2, 2]], epsilon=0.2); |
||||||
|
assert_relative_eq!(0.5, CIM[[0, 0, 1]], epsilon=0.2); |
||||||
|
} |
||||||
|
|
@ -0,0 +1,64 @@ |
|||||||
|
use rustyCTBN::params::*; |
||||||
|
use ndarray::prelude::*; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
mod utils; |
||||||
|
|
||||||
|
#[macro_use] |
||||||
|
extern crate approx; |
||||||
|
|
||||||
|
|
||||||
|
fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTimeParams { |
||||||
|
let mut params = utils::generate_discrete_time_continous_param(3); |
||||||
|
|
||||||
|
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [3.2, 1.7, -4.0]]]; |
||||||
|
|
||||||
|
params.cim = Some(cim); |
||||||
|
params |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn test_uniform_generation() { |
||||||
|
let param = create_ternary_discrete_time_continous_param(); |
||||||
|
let mut states = Array1::<usize>::zeros(10000); |
||||||
|
|
||||||
|
states.mapv_inplace(|_| { |
||||||
|
if let StateType::Discrete(val) = param.get_random_state_uniform() { |
||||||
|
val |
||||||
|
} else { |
||||||
|
panic!() |
||||||
|
} |
||||||
|
}); |
||||||
|
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
||||||
|
|
||||||
|
assert_relative_eq!(1.0 / 3.0, zero_freq, epsilon = 0.01); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn test_random_generation_state() { |
||||||
|
let param = create_ternary_discrete_time_continous_param(); |
||||||
|
let mut states = Array1::<usize>::zeros(10000); |
||||||
|
|
||||||
|
states.mapv_inplace(|_| { |
||||||
|
if let StateType::Discrete(val) = param.get_random_state(1, 0).unwrap() { |
||||||
|
val |
||||||
|
} else { |
||||||
|
panic!() |
||||||
|
} |
||||||
|
}); |
||||||
|
let two_freq = states.mapv(|a| (a == 2) as u64).sum() as f64 / 10000.0; |
||||||
|
let zero_freq = states.mapv(|a| (a == 0) as u64).sum() as f64 / 10000.0; |
||||||
|
|
||||||
|
assert_relative_eq!(4.0 / 5.0, two_freq, epsilon = 0.01); |
||||||
|
assert_relative_eq!(1.0 / 5.0, zero_freq, epsilon = 0.01); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn test_random_generation_residence_time() { |
||||||
|
let param = create_ternary_discrete_time_continous_param(); |
||||||
|
let mut states = Array1::<f64>::zeros(10000); |
||||||
|
|
||||||
|
states.mapv_inplace(|_| param.get_random_residence_time(1, 0).unwrap()); |
||||||
|
|
||||||
|
assert_relative_eq!(1.0 / 5.0, states.mean().unwrap(), epsilon = 0.01); |
||||||
|
} |
@ -0,0 +1,43 @@ |
|||||||
|
|
||||||
|
use rustyCTBN::tools::*; |
||||||
|
use rustyCTBN::network::Network; |
||||||
|
use rustyCTBN::ctbn::*; |
||||||
|
use rustyCTBN::node; |
||||||
|
use rustyCTBN::params; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
use ndarray::arr3; |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#[macro_use] |
||||||
|
extern crate approx; |
||||||
|
|
||||||
|
mod utils; |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn run_sampling() { |
||||||
|
let mut net = CtbnNetwork::init(); |
||||||
|
let n1 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n1"),2)).unwrap(); |
||||||
|
let n2 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n2"),2)).unwrap(); |
||||||
|
net.add_edge(n1, n2); |
||||||
|
|
||||||
|
match &mut net.get_node_mut(n1).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
match &mut net.get_node_mut(n2).params { |
||||||
|
params::Params::DiscreteStatesContinousTime(param) => { |
||||||
|
param.cim = Some (arr3(&[ |
||||||
|
[[-1.0,1.0],[4.0,-4.0]], |
||||||
|
[[-6.0,6.0],[2.0,-2.0]]])); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
let data = trajectory_generator(Box::new(&net), 4, 1.0); |
||||||
|
|
||||||
|
assert_eq!(4, data.trajectories.len()); |
||||||
|
assert_relative_eq!(1.0, data.trajectories[0].time[data.trajectories[0].time.len()-1]); |
||||||
|
} |
@ -0,0 +1,16 @@ |
|||||||
|
use rustyCTBN::params; |
||||||
|
use rustyCTBN::node; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node { |
||||||
|
node::Node::init(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name) |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ |
||||||
|
let mut domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect(); |
||||||
|
params::DiscreteStatesContinousTimeParams::init(domain) |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in new issue