Old engine for Continuous Time Bayesian Networks. Superseded by reCTBN. 🐍
https://github.com/madlabunimib/PyCTBN
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
2.8 KiB
103 lines
2.8 KiB
"""
|
|
Assertion helpers for arithmetic tests.
|
|
"""
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from pandas import DataFrame, Index, Series
|
|
import pandas._testing as tm
|
|
|
|
|
|
def assert_invalid_addsub_type(left, right, msg=None):
    """
    Check that addition and subtraction of left and right both raise TypeError.

    Each of ``left + right``, ``right + left``, ``left - right`` and
    ``right - left`` is attempted in that order, and every one of them is
    required to raise.

    Parameters
    ----------
    left : object
    right : object
    msg : str or None, default None
        Passed through as the ``match`` argument of ``pytest.raises``.
    """
    # Deferred evaluation keeps each operation inside its own raises-context.
    attempts = (
        lambda: left + right,
        lambda: right + left,
        lambda: left - right,
        lambda: right - left,
    )
    for attempt in attempts:
        with pytest.raises(TypeError, match=msg):
            attempt()
|
|
|
|
|
|
def get_upcast_box(box, vector):
    """
    Pick the container type that takes priority between ``box`` and ``vector``.

    Candidates are tried in the order DataFrame, Series, Index. A candidate
    wins when ``box`` is that class itself or ``vector`` is an instance of it.
    When no candidate matches, ``box`` is returned unchanged.
    """
    for candidate in (DataFrame, Series, Index):
        if box is candidate or isinstance(vector, candidate):
            return candidate
    return box
|
|
|
|
|
|
def assert_invalid_comparison(left, right, box):
    """
    Check how comparisons between mismatched types behave.

    ``==`` must be False everywhere and ``!=`` must be True everywhere, in
    either operand order. Each of ``<``, ``<=``, ``>`` and ``>=`` must raise
    TypeError, again in either operand order.

    Parameters
    ----------
    left : np.ndarray, ExtensionArray, Index, or Series
    right : object
    box : {pd.DataFrame, pd.Series, pd.Index, tm.to_array}
    """
    # Not for tznaive-tzaware comparison

    # Note: not quite the same as how we do this for tm.box_expected
    if box is Index:
        wrap = np.array
    else:
        wrap = box

    eq_result = left == right
    all_false = wrap(np.zeros(eq_result.shape, dtype=np.bool_))

    tm.assert_equal(eq_result, all_false)
    tm.assert_equal(right == left, all_false)

    tm.assert_equal(left != right, ~all_false)
    tm.assert_equal(right != left, ~all_false)

    npdev_pattern = (
        # GH#36706 npdev 1.20.0 2020-09-28
        r"The DTypes <class 'numpy.dtype\[datetime64\]'> and "
        r"<class 'numpy.dtype\[int64\]'> do not have a common DType. "
        "For example they cannot be stored in a single array unless the "
        "dtype is `object`."
    )
    # Any one of these alternatives is an acceptable error message.
    msg = "|".join(
        (
            "Invalid comparison between",
            "Cannot compare type",
            "not supported between",
            "invalid type promotion",
            npdev_pattern,
        )
    )

    # Deferred evaluation keeps each comparison inside its own raises-context.
    ordering_checks = (
        lambda: left < right,
        lambda: left <= right,
        lambda: left > right,
        lambda: left >= right,
        lambda: right < left,
        lambda: right <= left,
        lambda: right > left,
        lambda: right >= left,
    )
    for check in ordering_checks:
        with pytest.raises(TypeError, match=msg):
            check()
|
|
|