|
|
@ -1,13 +1,13 @@ |
|
|
|
#![allow(non_snake_case)] |
|
|
|
#![allow(non_snake_case)] |
|
|
|
|
|
|
|
|
|
|
|
mod utils; |
|
|
|
mod utils; |
|
|
|
use utils::*; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use ndarray::arr3; |
|
|
|
use ndarray::arr3; |
|
|
|
use reCTBN::ctbn::*; |
|
|
|
use reCTBN::ctbn::*; |
|
|
|
use reCTBN::network::Network; |
|
|
|
use reCTBN::network::Network; |
|
|
|
use reCTBN::parameter_learning::*; |
|
|
|
use reCTBN::parameter_learning::*; |
|
|
|
use reCTBN::{params, tools::*}; |
|
|
|
use reCTBN::params; |
|
|
|
|
|
|
|
use reCTBN::tools::*; |
|
|
|
|
|
|
|
use utils::*; |
|
|
|
|
|
|
|
|
|
|
|
extern crate approx; |
|
|
|
extern crate approx; |
|
|
|
|
|
|
|
|
|
|
@ -32,8 +32,14 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) { |
|
|
|
assert_eq!( |
|
|
|
assert_eq!( |
|
|
|
Ok(()), |
|
|
|
Ok(()), |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
[[-1.0, 1.0], [4.0, -4.0]], |
|
|
|
[ |
|
|
|
[[-6.0, 6.0], [2.0, -2.0]], |
|
|
|
[-1.0, 1.0], |
|
|
|
|
|
|
|
[4.0, -4.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-6.0, 6.0], |
|
|
|
|
|
|
|
[2.0, -2.0] |
|
|
|
|
|
|
|
], |
|
|
|
])) |
|
|
|
])) |
|
|
|
); |
|
|
|
); |
|
|
|
} |
|
|
|
} |
|
|
@ -41,11 +47,20 @@ fn learn_binary_cim<T: ParameterLearning>(pl: T) { |
|
|
|
|
|
|
|
|
|
|
|
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); |
|
|
|
let data = trajectory_generator(&net, 100, 100.0, Some(6347747169756259)); |
|
|
|
let p = match pl.fit(&net, &data, 1, None) { |
|
|
|
let p = match pl.fit(&net, &data, 1, None) { |
|
|
|
params::Params::DiscreteStatesContinousTime(p) => p |
|
|
|
params::Params::DiscreteStatesContinousTime(p) => p, |
|
|
|
}; |
|
|
|
}; |
|
|
|
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); |
|
|
|
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [2, 2, 2]); |
|
|
|
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
|
|
|
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
|
|
|
&arr3(&[[[-1.0, 1.0], [4.0, -4.0]], [[-6.0, 6.0], [2.0, -2.0]],]), |
|
|
|
&arr3(&[ |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-1.0, 1.0], |
|
|
|
|
|
|
|
[4.0, -4.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-6.0, 6.0], |
|
|
|
|
|
|
|
[2.0, -2.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
]), |
|
|
|
0.1 |
|
|
|
0.1 |
|
|
|
)); |
|
|
|
)); |
|
|
|
} |
|
|
|
} |
|
|
@ -76,11 +91,13 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) { |
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
assert_eq!( |
|
|
|
assert_eq!( |
|
|
|
Ok(()), |
|
|
|
Ok(()), |
|
|
|
param.set_cim(arr3(&[[ |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
|
|
|
|
[ |
|
|
|
[-3.0, 2.0, 1.0], |
|
|
|
[-3.0, 2.0, 1.0], |
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
[0.4, 0.6, -1.0] |
|
|
|
[0.4, 0.6, -1.0] |
|
|
|
]])) |
|
|
|
], |
|
|
|
|
|
|
|
])) |
|
|
|
); |
|
|
|
); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@ -90,24 +107,48 @@ fn learn_ternary_cim<T: ParameterLearning>(pl: T) { |
|
|
|
assert_eq!( |
|
|
|
assert_eq!( |
|
|
|
Ok(()), |
|
|
|
Ok(()), |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
|
|
|
[ |
|
|
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
|
|
|
[-1.0, 0.5, 0.5], |
|
|
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
|
|
|
[3.0, -4.0, 1.0], |
|
|
|
|
|
|
|
[0.9, 0.1, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-6.0, 2.0, 4.0], |
|
|
|
|
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
|
|
|
|
[3.0, 1.0, -4.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-1.0, 0.1, 0.9], |
|
|
|
|
|
|
|
[2.0, -2.5, 0.5], |
|
|
|
|
|
|
|
[0.9, 0.1, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
])) |
|
|
|
])) |
|
|
|
); |
|
|
|
); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); |
|
|
|
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); |
|
|
|
let p = match pl.fit(&net, &data, 1, None){ |
|
|
|
let p = match pl.fit(&net, &data, 1, None) { |
|
|
|
params::Params::DiscreteStatesContinousTime(p) => p |
|
|
|
params::Params::DiscreteStatesContinousTime(p) => p, |
|
|
|
}; |
|
|
|
}; |
|
|
|
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); |
|
|
|
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [3, 3, 3]); |
|
|
|
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
|
|
|
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
|
|
|
&arr3(&[ |
|
|
|
&arr3(&[ |
|
|
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
|
|
|
[ |
|
|
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
|
|
|
[-1.0, 0.5, 0.5], |
|
|
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
|
|
|
[3.0, -4.0, 1.0], |
|
|
|
|
|
|
|
[0.9, 0.1, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-6.0, 2.0, 4.0], |
|
|
|
|
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
|
|
|
|
[3.0, 1.0, -4.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-1.0, 0.1, 0.9], |
|
|
|
|
|
|
|
[2.0, -2.5, 0.5], |
|
|
|
|
|
|
|
[0.9, 0.1, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
]), |
|
|
|
]), |
|
|
|
0.1 |
|
|
|
0.1 |
|
|
|
)); |
|
|
|
)); |
|
|
@ -139,11 +180,13 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) { |
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
assert_eq!( |
|
|
|
assert_eq!( |
|
|
|
Ok(()), |
|
|
|
Ok(()), |
|
|
|
param.set_cim(arr3(&[[ |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
|
|
|
|
[ |
|
|
|
[-3.0, 2.0, 1.0], |
|
|
|
[-3.0, 2.0, 1.0], |
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
[0.4, 0.6, -1.0] |
|
|
|
[0.4, 0.6, -1.0] |
|
|
|
]])) |
|
|
|
] |
|
|
|
|
|
|
|
])) |
|
|
|
); |
|
|
|
); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@ -153,21 +196,39 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning>(pl: T) { |
|
|
|
assert_eq!( |
|
|
|
assert_eq!( |
|
|
|
Ok(()), |
|
|
|
Ok(()), |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
|
|
|
[ |
|
|
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
|
|
|
[-1.0, 0.5, 0.5], |
|
|
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
|
|
|
[3.0, -4.0, 1.0], |
|
|
|
|
|
|
|
[0.9, 0.1, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-6.0, 2.0, 4.0], |
|
|
|
|
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
|
|
|
|
[3.0, 1.0, -4.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-1.0, 0.1, 0.9], |
|
|
|
|
|
|
|
[2.0, -2.5, 0.5], |
|
|
|
|
|
|
|
[0.9, 0.1, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
])) |
|
|
|
])) |
|
|
|
); |
|
|
|
); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); |
|
|
|
let data = trajectory_generator(&net, 100, 200.0, Some(6347747169756259)); |
|
|
|
let p = match pl.fit(&net, &data, 0, None){ |
|
|
|
let p = match pl.fit(&net, &data, 0, None) { |
|
|
|
params::Params::DiscreteStatesContinousTime(p) => p |
|
|
|
params::Params::DiscreteStatesContinousTime(p) => p, |
|
|
|
}; |
|
|
|
}; |
|
|
|
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); |
|
|
|
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [1, 3, 3]); |
|
|
|
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
|
|
|
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
|
|
|
&arr3(&[[[-3.0, 2.0, 1.0], [1.5, -2.0, 0.5], [0.4, 0.6, -1.0]]]), |
|
|
|
&arr3(&[ |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-3.0, 2.0, 1.0], |
|
|
|
|
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
|
|
|
|
[0.4, 0.6, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
]), |
|
|
|
0.1 |
|
|
|
0.1 |
|
|
|
)); |
|
|
|
)); |
|
|
|
} |
|
|
|
} |
|
|
@ -204,11 +265,13 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) { |
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
assert_eq!( |
|
|
|
assert_eq!( |
|
|
|
Ok(()), |
|
|
|
Ok(()), |
|
|
|
param.set_cim(arr3(&[[ |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
|
|
|
|
[ |
|
|
|
[-3.0, 2.0, 1.0], |
|
|
|
[-3.0, 2.0, 1.0], |
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
[0.4, 0.6, -1.0] |
|
|
|
[0.4, 0.6, -1.0] |
|
|
|
]])) |
|
|
|
], |
|
|
|
|
|
|
|
])) |
|
|
|
); |
|
|
|
); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@ -218,9 +281,21 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) { |
|
|
|
assert_eq!( |
|
|
|
assert_eq!( |
|
|
|
Ok(()), |
|
|
|
Ok(()), |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
param.set_cim(arr3(&[ |
|
|
|
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], |
|
|
|
[ |
|
|
|
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], |
|
|
|
[-1.0, 0.5, 0.5], |
|
|
|
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], |
|
|
|
[3.0, -4.0, 1.0], |
|
|
|
|
|
|
|
[0.9, 0.1, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-6.0, 2.0, 4.0], |
|
|
|
|
|
|
|
[1.5, -2.0, 0.5], |
|
|
|
|
|
|
|
[3.0, 1.0, -4.0] |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
[-1.0, 0.1, 0.9], |
|
|
|
|
|
|
|
[2.0, -2.5, 0.5], |
|
|
|
|
|
|
|
[0.9, 0.1, -1.0] |
|
|
|
|
|
|
|
], |
|
|
|
])) |
|
|
|
])) |
|
|
|
); |
|
|
|
); |
|
|
|
} |
|
|
|
} |
|
|
@ -291,8 +366,8 @@ fn learn_mixed_discrete_cim<T: ParameterLearning>(pl: T) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); |
|
|
|
let data = trajectory_generator(&net, 300, 300.0, Some(6347747169756259)); |
|
|
|
let p = match pl.fit(&net, &data, 2, None){ |
|
|
|
let p = match pl.fit(&net, &data, 2, None) { |
|
|
|
params::Params::DiscreteStatesContinousTime(p) => p |
|
|
|
params::Params::DiscreteStatesContinousTime(p) => p, |
|
|
|
}; |
|
|
|
}; |
|
|
|
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); |
|
|
|
assert_eq!(p.get_cim().as_ref().unwrap().shape(), [9, 4, 4]); |
|
|
|
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
|
|
|
assert!(p.get_cim().as_ref().unwrap().abs_diff_eq( |
|
|
|