Old engine for Continuous Time Bayesian Networks. Superseded by reCTBN. 🐍
https://github.com/madlabunimib/PyCTBN
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
233 lines
6.7 KiB
233 lines
6.7 KiB
from contextlib import ExitStack as does_not_raise
|
|
from datetime import datetime
|
|
import os
|
|
import platform
|
|
import random
|
|
import string
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import pytz
|
|
|
|
import pandas as pd
|
|
from pandas import DataFrame
|
|
|
|
api_exceptions = pytest.importorskip("google.api_core.exceptions")
|
|
bigquery = pytest.importorskip("google.cloud.bigquery")
|
|
service_account = pytest.importorskip("google.oauth2.service_account")
|
|
pandas_gbq = pytest.importorskip("pandas_gbq")
|
|
|
|
PROJECT_ID = None
|
|
PRIVATE_KEY_JSON_PATH = None
|
|
PRIVATE_KEY_JSON_CONTENTS = None
|
|
|
|
VERSION = platform.python_version()
|
|
|
|
|
|
def _skip_if_no_project_id():
|
|
if not _get_project_id():
|
|
pytest.skip("Cannot run integration tests without a project id")
|
|
|
|
|
|
def _skip_if_no_private_key_path():
|
|
if not _get_private_key_path():
|
|
pytest.skip("Cannot run integration tests without a private key json file path")
|
|
|
|
|
|
def _in_travis_environment():
|
|
return "TRAVIS_BUILD_DIR" in os.environ and "GBQ_PROJECT_ID" in os.environ
|
|
|
|
|
|
def _get_project_id():
|
|
if _in_travis_environment():
|
|
return os.environ.get("GBQ_PROJECT_ID")
|
|
return PROJECT_ID or os.environ.get("GBQ_PROJECT_ID")
|
|
|
|
|
|
def _get_private_key_path():
|
|
if _in_travis_environment():
|
|
return os.path.join(
|
|
*[os.environ.get("TRAVIS_BUILD_DIR"), "ci", "travis_gbq.json"]
|
|
)
|
|
|
|
private_key_path = PRIVATE_KEY_JSON_PATH
|
|
if not private_key_path:
|
|
private_key_path = os.environ.get("GBQ_GOOGLE_APPLICATION_CREDENTIALS")
|
|
return private_key_path
|
|
|
|
|
|
def _get_credentials():
|
|
private_key_path = _get_private_key_path()
|
|
if private_key_path:
|
|
return service_account.Credentials.from_service_account_file(private_key_path)
|
|
|
|
|
|
def _get_client():
|
|
project_id = _get_project_id()
|
|
credentials = _get_credentials()
|
|
return bigquery.Client(project=project_id, credentials=credentials)
|
|
|
|
|
|
def generate_rand_str(length: int = 10) -> str:
|
|
return "".join(random.choices(string.ascii_lowercase, k=length))
|
|
|
|
|
|
def make_mixed_dataframe_v2(test_size):
|
|
# create df to test for all BQ datatypes except RECORD
|
|
bools = np.random.randint(2, size=(1, test_size)).astype(bool)
|
|
flts = np.random.randn(1, test_size)
|
|
ints = np.random.randint(1, 10, size=(1, test_size))
|
|
strs = np.random.randint(1, 10, size=(1, test_size)).astype(str)
|
|
times = [datetime.now(pytz.timezone("US/Arizona")) for t in range(test_size)]
|
|
return DataFrame(
|
|
{
|
|
"bools": bools[0],
|
|
"flts": flts[0],
|
|
"ints": ints[0],
|
|
"strs": strs[0],
|
|
"times": times[0],
|
|
},
|
|
index=range(test_size),
|
|
)
|
|
|
|
|
|
def test_read_gbq_without_deprecated_kwargs(monkeypatch):
|
|
captured_kwargs = {}
|
|
|
|
def mock_read_gbq(sql, **kwargs):
|
|
captured_kwargs.update(kwargs)
|
|
return DataFrame([[1.0]])
|
|
|
|
monkeypatch.setattr("pandas_gbq.read_gbq", mock_read_gbq)
|
|
pd.read_gbq("SELECT 1")
|
|
|
|
assert "verbose" not in captured_kwargs
|
|
assert "private_key" not in captured_kwargs
|
|
|
|
|
|
def test_read_gbq_with_new_kwargs(monkeypatch):
|
|
captured_kwargs = {}
|
|
|
|
def mock_read_gbq(sql, **kwargs):
|
|
captured_kwargs.update(kwargs)
|
|
return DataFrame([[1.0]])
|
|
|
|
monkeypatch.setattr("pandas_gbq.read_gbq", mock_read_gbq)
|
|
pd.read_gbq("SELECT 1", use_bqstorage_api=True, max_results=1)
|
|
|
|
assert captured_kwargs["use_bqstorage_api"]
|
|
assert captured_kwargs["max_results"]
|
|
|
|
|
|
def test_read_gbq_without_new_kwargs(monkeypatch):
|
|
captured_kwargs = {}
|
|
|
|
def mock_read_gbq(sql, **kwargs):
|
|
captured_kwargs.update(kwargs)
|
|
return DataFrame([[1.0]])
|
|
|
|
monkeypatch.setattr("pandas_gbq.read_gbq", mock_read_gbq)
|
|
pd.read_gbq("SELECT 1")
|
|
|
|
assert "use_bqstorage_api" not in captured_kwargs
|
|
assert "max_results" not in captured_kwargs
|
|
|
|
|
|
@pytest.mark.parametrize("progress_bar", [None, "foo"])
|
|
def test_read_gbq_progress_bar_type_kwarg(monkeypatch, progress_bar):
|
|
# GH 29857
|
|
captured_kwargs = {}
|
|
|
|
def mock_read_gbq(sql, **kwargs):
|
|
captured_kwargs.update(kwargs)
|
|
return DataFrame([[1.0]])
|
|
|
|
monkeypatch.setattr("pandas_gbq.read_gbq", mock_read_gbq)
|
|
pd.read_gbq("SELECT 1", progress_bar_type=progress_bar)
|
|
assert "progress_bar_type" in captured_kwargs
|
|
|
|
|
|
@pytest.mark.single
|
|
class TestToGBQIntegrationWithServiceAccountKeyPath:
|
|
@pytest.fixture()
|
|
def gbq_dataset(self):
|
|
# Setup Dataset
|
|
_skip_if_no_project_id()
|
|
_skip_if_no_private_key_path()
|
|
|
|
dataset_id = "pydata_pandas_bq_testing_" + generate_rand_str()
|
|
|
|
self.client = _get_client()
|
|
self.dataset = self.client.dataset(dataset_id)
|
|
|
|
# Create the dataset
|
|
self.client.create_dataset(bigquery.Dataset(self.dataset))
|
|
|
|
table_name = generate_rand_str()
|
|
destination_table = f"{dataset_id}.{table_name}"
|
|
yield destination_table
|
|
|
|
# Teardown Dataset
|
|
self.client.delete_dataset(self.dataset, delete_contents=True)
|
|
|
|
def test_roundtrip(self, gbq_dataset):
|
|
destination_table = gbq_dataset
|
|
|
|
test_size = 20001
|
|
df = make_mixed_dataframe_v2(test_size)
|
|
|
|
df.to_gbq(
|
|
destination_table,
|
|
_get_project_id(),
|
|
chunksize=None,
|
|
credentials=_get_credentials(),
|
|
)
|
|
|
|
result = pd.read_gbq(
|
|
f"SELECT COUNT(*) AS num_rows FROM {destination_table}",
|
|
project_id=_get_project_id(),
|
|
credentials=_get_credentials(),
|
|
dialect="standard",
|
|
)
|
|
assert result["num_rows"][0] == test_size
|
|
|
|
@pytest.mark.parametrize(
|
|
"if_exists, expected_num_rows, expectation",
|
|
[
|
|
("append", 300, does_not_raise()),
|
|
("fail", 200, pytest.raises(pandas_gbq.gbq.TableCreationError)),
|
|
("replace", 100, does_not_raise()),
|
|
],
|
|
)
|
|
def test_gbq_if_exists(
|
|
self, if_exists, expected_num_rows, expectation, gbq_dataset
|
|
):
|
|
# GH 29598
|
|
destination_table = gbq_dataset
|
|
|
|
test_size = 200
|
|
df = make_mixed_dataframe_v2(test_size)
|
|
|
|
df.to_gbq(
|
|
destination_table,
|
|
_get_project_id(),
|
|
chunksize=None,
|
|
credentials=_get_credentials(),
|
|
)
|
|
|
|
with expectation:
|
|
df.iloc[:100].to_gbq(
|
|
destination_table,
|
|
_get_project_id(),
|
|
if_exists=if_exists,
|
|
chunksize=None,
|
|
credentials=_get_credentials(),
|
|
)
|
|
|
|
result = pd.read_gbq(
|
|
f"SELECT COUNT(*) AS num_rows FROM {destination_table}",
|
|
project_id=_get_project_id(),
|
|
credentials=_get_credentials(),
|
|
dialect="standard",
|
|
)
|
|
assert result["num_rows"][0] == expected_num_rows
|
|
|