Enforced correct set of parameters (cim) when inserted manually

pull/26/head
Alessandro Bregoli 3 years ago
parent 331c2006e9
commit c178862664
  1. 50
      src/params.rs
  2. 38
      tests/parameter_learning.rs
  3. 14
      tests/params.rs
  4. 4
      tests/tools.rs
  5. 2
      tests/utils.rs

@ -69,21 +69,55 @@ pub enum Params {
/// - **residence_time**: permanence time in each possible states given a specific /// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set /// realization of the parent set
pub struct DiscreteStatesContinousTimeParams { pub struct DiscreteStatesContinousTimeParams {
pub domain: BTreeSet<String>, domain: BTreeSet<String>,
pub cim: Option<Array3<f64>>, cim: Option<Array3<f64>>,
pub transitions: Option<Array3<u64>>, transitions: Option<Array3<u64>>,
pub residence_time: Option<Array2<f64>>, residence_time: Option<Array2<f64>>,
} }
impl DiscreteStatesContinousTimeParams { impl DiscreteStatesContinousTimeParams {
pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams { pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
DiscreteStatesContinousTimeParams { DiscreteStatesContinousTimeParams {
domain: domain, domain,
cim: Option::None, cim: Option::None,
transitions: Option::None, transitions: Option::None,
residence_time: Option::None, residence_time: Option::None,
} }
} }
pub fn get_cim(&self) -> &Option<Array3<f64>> {
&self.cim
}
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError>{
self.cim = Some(cim);
match self.validate_params() {
Ok(()) => Ok(()),
Err(e) => {
self.cim = None;
Err(e)
}
}
}
pub fn get_transitions(&self) -> &Option<Array3<u64>> {
&self.transitions
}
pub fn set_transitions(&mut self, transitions: Array3<u64>) {
self.transitions = Some(transitions);
}
pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
&self.residence_time
}
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
self.residence_time = Some(residence_time);
}
} }
impl ParamsTrait for DiscreteStatesContinousTimeParams { impl ParamsTrait for DiscreteStatesContinousTimeParams {
@ -192,10 +226,8 @@ impl ParamsTrait for DiscreteStatesContinousTimeParams {
} }
// Check if each row sum up to 0 // Check if each row sum up to 0
let zeros = Array::zeros(domain_size); if cim.sum_axis(Axis(2)).iter()
if cim .any(|x| f64::abs(x.clone()) > f64::EPSILON * 3.0)
.axis_iter(Axis(0))
.any(|x| !x.sum_axis(Axis(1)).abs_diff_eq(&zeros, f64::MIN_POSITIVE))
{ {
return Err(ParamsError::InvalidCIM(String::from( return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0", "The sum of each row must be 0",

@ -27,16 +27,16 @@ fn learn_binary_cim<T: ParameterLearning> (pl: T) {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])); assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 3.0], [2.0, -2.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 1.0], [4.0, -4.0]], [[-1.0, 1.0], [4.0, -4.0]],
[[-6.0, 6.0], [2.0, -2.0]], [[-6.0, 6.0], [2.0, -2.0]],
])); ])));
} }
} }
@ -77,19 +77,19 @@ fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])); [0.4, 0.6, -1.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])); ])));
} }
} }
@ -132,19 +132,19 @@ fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])); [0.4, 0.6, -1.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])); ])));
} }
} }
@ -192,28 +192,28 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[[[-3.0, 2.0, 1.0], assert_eq!(Ok(()), param.set_cim(arr3(&[[[-3.0, 2.0, 1.0],
[1.5, -2.0, 0.5], [1.5, -2.0, 0.5],
[0.4, 0.6, -1.0]]])); [0.4, 0.6, -1.0]]])));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]], [[-1.0, 0.5, 0.5], [3.0, -4.0, 1.0], [0.9, 0.1, -1.0]],
[[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]], [[-6.0, 2.0, 4.0], [1.5, -2.0, 0.5], [3.0, 1.0, -4.0]],
[[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]], [[-1.0, 0.1, 0.9], [2.0, -2.5, 0.5], [0.9, 0.1, -1.0]],
])); ])));
} }
} }
match &mut net.get_node_mut(n3).params { match &mut net.get_node_mut(n3).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some(arr3(&[ assert_eq!(Ok(()), param.set_cim(arr3(&[
[[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]], [[-1.0, 0.5, 0.3, 0.2], [0.5, -4.0, 2.5, 1.0], [2.5, 0.5, -4.0, 1.0], [0.7, 0.2, 0.1, -1.0]],
[[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 , 1.7], [2.5, 0.5, 1.0, -4.0]], [[-6.0, 2.0, 3.0, 1.0], [1.5, -3.0, 0.5, 1.0], [2.0, 1.3, -5.0 ,1.7], [2.5, 0.5, 1.0, -4.0]],
[[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]], [[-1.3, 0.3, 0.1, 0.9], [1.4, -4.0, 0.5, 2.1], [1.0, 1.5, -3.0, 0.5], [0.4, 0.3, 0.1, -0.8]],
[[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]], [[-2.0, 1.0, 0.7, 0.3], [1.3, -5.9, 2.7, 1.9], [2.0, 1.5, -4.0, 0.5], [0.2, 0.7, 0.1, -1.0]],
@ -223,12 +223,12 @@ fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
[[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]], [[-2.0, 1.0, 0.6, 0.4], [2.6, -7.1, 1.4, 3.1], [5.0, 1.0, -8.0, 2.0], [1.4, 0.4, 0.2, -2.0]],
[[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]], [[-3.0, 1.0, 1.5, 0.5], [3.0, -6.0, 1.0, 2.0], [0.3, 0.5, -1.9, 1.1], [5.0, 1.0, 2.0, -8.0]],
[[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]], [[-2.6, 0.6, 0.2, 1.8], [2.0, -6.0, 3.0, 1.0], [0.1, 0.5, -1.3, 0.7], [0.8, 0.6, 0.2, -1.6]],
])); ])));
} }
} }
let data = trajectory_generator(&net, 300, 200.0); let data = trajectory_generator(&net, 300, 300.0);
let (CIM, M, T) = pl.fit(&net, &data, 2, None); let (CIM, M, T) = pl.fit(&net, &data, 2, None);
print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T); print!("CIM: {:?}\nM: {:?}\nT: {:?}\n", CIM, M, T);
assert_eq!(CIM.shape(), [9, 4, 4]); assert_eq!(CIM.shape(), [9, 4, 4]);

@ -12,7 +12,7 @@ fn create_ternary_discrete_time_continous_param() -> DiscreteStatesContinousTime
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
params.cim = Some(cim); params.set_cim(cim);
params params
} }
@ -85,12 +85,12 @@ fn test_validate_params_cim_not_initialized() {
fn test_validate_params_wrong_shape() { fn test_validate_params_wrong_shape() {
let mut param = utils::generate_discrete_time_continous_param(4); let mut param = utils::generate_discrete_time_continous_param(4);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
param.cim = Some(cim); let result = param.set_cim(cim);
assert_eq!( assert_eq!(
Err(ParamsError::InvalidCIM(String::from( Err(ParamsError::InvalidCIM(String::from(
"Incompatible shape [1, 3, 3] with domain 4" "Incompatible shape [1, 3, 3] with domain 4"
))), ))),
param.validate_params() result
); );
} }
@ -99,12 +99,12 @@ fn test_validate_params_wrong_shape() {
fn test_validate_params_positive_diag() { fn test_validate_params_positive_diag() {
let mut param = utils::generate_discrete_time_continous_param(3); let mut param = utils::generate_discrete_time_continous_param(3);
let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]]; let cim = array![[[2.0, -3.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.7, -4.0]]];
param.cim = Some(cim); let result = param.set_cim(cim);
assert_eq!( assert_eq!(
Err(ParamsError::InvalidCIM(String::from( Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive", "The diagonal of each cim must be non-positive",
))), ))),
param.validate_params() result
); );
} }
@ -113,11 +113,11 @@ fn test_validate_params_positive_diag() {
fn test_validate_params_row_not_sum_to_zero() { fn test_validate_params_row_not_sum_to_zero() {
let mut param = utils::generate_discrete_time_continous_param(3); let mut param = utils::generate_discrete_time_continous_param(3);
let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]]; let cim = array![[[-3.0, 2.0, 1.0], [1.0, -5.0, 4.0], [2.3, 1.701, -4.0]]];
param.cim = Some(cim); let result = param.set_cim(cim);
assert_eq!( assert_eq!(
Err(ParamsError::InvalidCIM(String::from( Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0" "The sum of each row must be 0"
))), ))),
param.validate_params() result
); );
} }

@ -23,14 +23,14 @@ fn run_sampling() {
match &mut net.get_node_mut(n1).params { match &mut net.get_node_mut(n1).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[[[-3.0,3.0],[2.0,-2.0]]])); param.set_cim(arr3(&[[[-3.0,3.0],[2.0,-2.0]]]));
} }
} }
match &mut net.get_node_mut(n2).params { match &mut net.get_node_mut(n2).params {
params::Params::DiscreteStatesContinousTime(param) => { params::Params::DiscreteStatesContinousTime(param) => {
param.cim = Some (arr3(&[ param.set_cim(arr3(&[
[[-1.0,1.0],[4.0,-4.0]], [[-1.0,1.0],[4.0,-4.0]],
[[-6.0,6.0],[2.0,-2.0]]])); [[-6.0,6.0],[2.0,-2.0]]]));
} }

@ -8,7 +8,7 @@ pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{ pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
let mut domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect(); let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
params::DiscreteStatesContinousTimeParams::init(domain) params::DiscreteStatesContinousTimeParams::init(domain)
} }

Loading…
Cancel
Save