|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from datetime import datetime as dt |
|
import pyarrow as pa |
|
from pyarrow.vendored.version import Version |
|
import pytest |
|
|
|
try: |
|
import numpy as np |
|
except ImportError: |
|
np = None |
|
|
|
import pyarrow.interchange as pi |
|
from pyarrow.interchange.column import ( |
|
_PyArrowColumn, |
|
ColumnNullType, |
|
DtypeKind, |
|
) |
|
from pyarrow.interchange.from_dataframe import _from_dataframe |
|
|
|
try: |
|
import pandas as pd |
|
|
|
except ImportError: |
|
pass |
|
|
|
|
|
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
def test_datetime(unit, tz):
    """Check interchange metadata of a timestamp column holding one null."""
    values = [dt(2007, 7, 13), dt(2007, 7, 14), None]
    timestamps = pa.array(values, type=pa.timestamp(unit, tz=tz))
    frame = pa.table({"A": timestamps}).__dataframe__()
    column = frame.get_column_by_name("A")

    # Three slots, exactly one of them null, and no slicing applied.
    assert column.size() == 3
    assert column.offset == 0
    assert column.null_count == 1
    assert column.dtype[0] == DtypeKind.DATETIME
    assert column.describe_null == (ColumnNullType.USE_BITMASK, 0)
|
|
|
|
|
@pytest.mark.parametrize(
    ["test_data", "kind"],
    [
        (["foo", "bar"], 21),
        ([1.5, 2.5, 3.5], 2),
        ([1, 2, 3, 4], 0),
    ],
)
def test_array_to_pyarrowcolumn(test_data, kind):
    """Wrap a null-free array in ``_PyArrowColumn`` and check its metadata.

    ``kind`` is the integer compared against ``dtype[0]`` of the wrapped
    column (presumably the ``DtypeKind`` codes for string, float and int).
    """
    source = pa.array(test_data)
    wrapped = _PyArrowColumn(source)

    assert wrapped._col == source
    assert wrapped.size() == len(test_data)
    assert wrapped.dtype[0] == kind
    assert wrapped.num_chunks() == 1
    assert wrapped.null_count == 0
    assert wrapped.get_buffers()["validity"] is None
    assert len(list(wrapped.get_chunks())) == 1

    # Every chunk yielded must compare equal to the column itself.
    assert all(chunk == wrapped for chunk in wrapped.get_chunks())
|
|
|
|
|
def test_offset_of_sliced_array():
    """A sliced array reports its offset and round-trips as the slice only."""
    full = pa.array([1, 2, 3, 4])
    tail = full.slice(2, 2)

    full_table = pa.table([full], names=["arr"])
    tail_table = pa.table([tail], names=["arr_sliced"])

    # The slice starts two elements into the parent array.
    assert tail_table.__dataframe__().get_column(0).offset == 2

    roundtripped = _from_dataframe(tail_table.__dataframe__())
    assert tail_table.equals(roundtripped)
    assert not full_table.equals(roundtripped)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.pandas
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
def test_pandas_roundtrip(uint, int, float, np_float_str):
    """Round-trip unsigned, signed, float and boolean columns through pandas.

    The table is handed to pandas' interchange ``from_dataframe`` and the
    resulting pandas frame is converted back with ``pi.from_dataframe``.
    """
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    values = [1, 2, 3]
    float_values = np.array(values, dtype=np.dtype(np_float_str))
    columns = {
        "a": pa.array(values, type=uint),
        "b": pa.array(values, type=int),
        "c": pa.array(float_values, type=float),
        "d": [True, False, True],
    }
    table = pa.table(columns)

    result = pi.from_dataframe(pandas_from_dataframe(table))
    assert table.equals(result)

    # Frame-level protocol metadata must agree as well.
    source_dfi = table.__dataframe__()
    result_dfi = result.__dataframe__()
    for accessor in ("num_columns", "num_rows", "num_chunks", "column_names"):
        assert (getattr(source_dfi, accessor)()
                == getattr(result_dfi, accessor)())
|
|
|
|
|
@pytest.mark.pandas
def test_pandas_roundtrip_string():
    """Round-trip a string column (with an empty string) through pandas.

    Values must match; the type is expected to change from ``string`` in
    the source table to ``large_string`` in the result.
    """
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    source = pa.table({"a": pa.array(["a", "", "c"])})

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    roundtripped = pi.from_dataframe(pandas_from_dataframe(source))

    assert roundtripped["a"].to_pylist() == source["a"].to_pylist()
    assert pa.types.is_string(source["a"].type)
    assert pa.types.is_large_string(roundtripped["a"].type)

    source_dfi = source.__dataframe__()
    roundtripped_dfi = roundtripped.__dataframe__()

    assert source_dfi.num_columns() == roundtripped_dfi.num_columns()
    assert source_dfi.num_rows() == roundtripped_dfi.num_rows()
    assert source_dfi.num_chunks() == roundtripped_dfi.num_chunks()
    assert source_dfi.column_names() == roundtripped_dfi.column_names()
|
|
|
|
|
@pytest.mark.pandas
def test_pandas_roundtrip_large_string():
    """Round-trip a ``large_string`` column through pandas.

    pandas older than 2.0.1 is expected to raise ``AssertionError`` when
    consuming the table; from 2.0.1 on the data must come back unchanged
    and still typed as ``large_string``.
    """
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    values = ["a", "", "c"]
    table = pa.table({"a_large": pa.array(values, type=pa.large_string())})

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    if Version(pd.__version__) < Version("2.0.1"):
        # Older pandas: only the failure mode is checked.
        with pytest.raises(AssertionError):
            pandas_from_dataframe(table)
        return

    result = pi.from_dataframe(pandas_from_dataframe(table))

    assert result["a_large"].to_pylist() == table["a_large"].to_pylist()
    assert pa.types.is_large_string(table["a_large"].type)
    assert pa.types.is_large_string(result["a_large"].type)

    source_dfi = table.__dataframe__()
    result_dfi = result.__dataframe__()

    assert source_dfi.num_columns() == result_dfi.num_columns()
    assert source_dfi.num_rows() == result_dfi.num_rows()
    assert source_dfi.num_chunks() == result_dfi.num_chunks()
    assert source_dfi.column_names() == result_dfi.column_names()
|
|
|
|
|
@pytest.mark.pandas
def test_pandas_roundtrip_string_with_missing():
    """Round-trip string columns that contain a null through pandas.

    pandas older than 2.0.2 is expected to raise ``NotImplementedError``;
    from 2.0.2 on both columns must come back with equal values and a
    ``large_string`` type.
    """
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    values = ["a", "", "c", None]
    table = pa.table({
        "a": pa.array(values),
        "a_large": pa.array(values, type=pa.large_string()),
    })

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    if Version(pd.__version__) < Version("2.0.2"):
        # Older pandas: only the failure mode is checked.
        with pytest.raises(NotImplementedError):
            pandas_from_dataframe(table)
        return

    result = pi.from_dataframe(pandas_from_dataframe(table))

    # (column name, predicate the *source* column type must satisfy)
    expectations = (
        ("a", pa.types.is_string),
        ("a_large", pa.types.is_large_string),
    )
    for name, is_source_type in expectations:
        assert result[name].to_pylist() == table[name].to_pylist()
        assert is_source_type(table[name].type)
        assert pa.types.is_large_string(result[name].type)
|
|
|
|
|
@pytest.mark.pandas
def test_pandas_roundtrip_categorical():
    """Round-trip a dictionary-encoded column with a null through pandas.

    Compares values, Arrow types, frame-level protocol metadata and the
    column-level categorical description of the source table against the
    table rebuilt from the pandas dataframe.
    """
    if Version(pd.__version__) < Version("2.0.2"):
        pytest.skip("Bitmasks not supported in pandas interchange implementation")

    arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
    table = pa.table(
        {"weekday": pa.array(arr).dictionary_encode()}
    )

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    pandas_df = pandas_from_dataframe(table)
    result = pi.from_dataframe(pandas_df)

    assert result["weekday"].to_pylist() == table["weekday"].to_pylist()
    assert pa.types.is_dictionary(table["weekday"].type)
    assert pa.types.is_dictionary(result["weekday"].type)
    # The round trip is not type-preserving: dictionary values are expected
    # back as large_string and indices as int8 (string / int32 in the source).
    assert pa.types.is_string(table["weekday"].chunk(0).dictionary.type)
    assert pa.types.is_large_string(result["weekday"].chunk(0).dictionary.type)
    assert pa.types.is_int32(table["weekday"].chunk(0).indices.type)
    assert pa.types.is_int8(result["weekday"].chunk(0).indices.type)

    table_protocol = table.__dataframe__()
    result_protocol = result.__dataframe__()

    assert table_protocol.num_columns() == result_protocol.num_columns()
    assert table_protocol.num_rows() == result_protocol.num_rows()
    assert table_protocol.num_chunks() == result_protocol.num_chunks()
    assert table_protocol.column_names() == result_protocol.column_names()

    col_table = table_protocol.get_column(0)
    col_result = result_protocol.get_column(0)

    assert col_result.dtype[0] == DtypeKind.CATEGORICAL
    assert col_result.dtype[0] == col_table.dtype[0]
    assert col_result.size() == col_table.size()
    assert col_result.offset == col_table.offset

    # The source description must come from col_table; reading both from
    # col_result would make the comparisons below trivially true.
    desc_cat_table = col_table.describe_categorical
    desc_cat_result = col_result.describe_categorical

    assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"]
    assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"]
    assert isinstance(desc_cat_result["categories"]._col, pa.Array)
|
|
|
|
|
@pytest.mark.pandas
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
def test_pandas_roundtrip_datetime(unit):
    """Round-trip a null-free timestamp column of the given unit via pandas.

    For pandas older than 1.6 the result is expected in nanosecond
    resolution whatever ``unit`` was; otherwise the source table itself is
    the expected result.
    """
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    # ``dt`` is the module-level ``datetime.datetime`` alias.
    dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
    table = pa.table({"a": pa.array(dt_arr, type=pa.timestamp(unit))})

    if Version(pd.__version__) < Version("1.6"):
        # Older pandas: expect the data back as 'ns' timestamps.
        expected = pa.table({"a": pa.array(dt_arr, type=pa.timestamp('ns'))})
    else:
        expected = table

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    pandas_df = pandas_from_dataframe(table)
    result = pi.from_dataframe(pandas_df)

    assert expected.equals(result)

    expected_protocol = expected.__dataframe__()
    result_protocol = result.__dataframe__()

    assert expected_protocol.num_columns() == result_protocol.num_columns()
    assert expected_protocol.num_rows() == result_protocol.num_rows()
    assert expected_protocol.num_chunks() == result_protocol.num_chunks()
    assert expected_protocol.column_names() == result_protocol.column_names()
|
|
|
|
|
@pytest.mark.pandas
@pytest.mark.parametrize(
    "np_float_str", ["float32", "float64"]
)
def test_pandas_to_pyarrow_with_missing(np_float_str):
    """Convert a pandas frame holding NaN and a missing datetime to Arrow.

    The expected table is built directly with pyarrow from the same float
    array (``from_pandas=True``) and the same list of datetimes.
    """
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    floats = np.array([0, np.nan, 2], dtype=np.dtype(np_float_str))
    stamps = [None, dt(2007, 7, 14), dt(2007, 7, 15)]

    frame = pd.DataFrame({
        "a": floats,
        "dt": np.array(stamps, dtype="datetime64[ns]"),
    })
    expected = pa.table({
        "a": pa.array(floats, from_pandas=True),
        "dt": pa.array(stamps, type=pa.timestamp("ns")),
    })

    converted = pi.from_dataframe(frame)
    assert converted.equals(expected)
|
|
|
|
|
@pytest.mark.pandas
def test_pandas_to_pyarrow_float16_with_missing():
    """A float16 pandas column with NaN must raise ``NotImplementedError``."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    half_floats = np.array([0, np.nan, 2], dtype=np.float16)
    frame = pd.DataFrame({"a": half_floats})

    pytest.raises(NotImplementedError, pi.from_dataframe, frame)
|
|
|
|
|
@pytest.mark.numpy
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        (pa.float16(), "float16"),
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['America/New_York', '+07:30', '-04:30'])
@pytest.mark.parametrize("offset, length", [(0, 3), (0, 2), (1, 2), (2, 1)])
def test_pyarrow_roundtrip(uint, int, float, np_float_str,
                           unit, tz, offset, length):
    """Round-trip a sliced, mixed-type table through ``_from_dataframe``.

    The table holds unsigned, signed, float, boolean, string and
    timezone-aware timestamp columns, most of them containing a null, and
    is sliced with ``offset``/``length`` before the conversion.
    """
    # ``dt`` is the module-level ``datetime.datetime`` alias.
    arr = [1, 2, None]
    dt_arr = [dt(2007, 7, 13), None, dt(2007, 7, 15)]

    table = pa.table(
        {
            "a": pa.array(arr, type=uint),
            "b": pa.array(arr, type=int),
            "c": pa.array(np.array(arr, dtype=np.dtype(np_float_str)),
                          type=float, from_pandas=True),
            "d": [True, False, True],
            "e": [True, False, None],
            "f": ["a", None, "c"],
            "g": pa.array(dt_arr, type=pa.timestamp(unit, tz=tz))
        }
    )
    table = table.slice(offset, length)
    result = _from_dataframe(table.__dataframe__())

    assert table.equals(result)

    table_protocol = table.__dataframe__()
    result_protocol = result.__dataframe__()

    assert table_protocol.num_columns() == result_protocol.num_columns()
    assert table_protocol.num_rows() == result_protocol.num_rows()
    assert table_protocol.num_chunks() == result_protocol.num_chunks()
    assert table_protocol.column_names() == result_protocol.column_names()
|
|
|
|
|
@pytest.mark.parametrize("offset, length", [(0, 10), (0, 2), (7, 3), (2, 1)])
def test_pyarrow_roundtrip_categorical(offset, length):
    """Round-trip a sliced dictionary-encoded column via ``_from_dataframe``.

    Checks table equality, frame-level protocol metadata and the
    column-level categorical description of source and result.
    """
    weekdays = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", None, "Sun"]
    encoded = pa.array(weekdays).dictionary_encode()
    source = pa.table({"weekday": encoded}).slice(offset, length)

    roundtripped = _from_dataframe(source.__dataframe__())
    assert source.equals(roundtripped)

    source_dfi = source.__dataframe__()
    roundtripped_dfi = roundtripped.__dataframe__()
    for accessor in ("num_columns", "num_rows", "num_chunks", "column_names"):
        assert (getattr(source_dfi, accessor)()
                == getattr(roundtripped_dfi, accessor)())

    source_col = source_dfi.get_column(0)
    roundtripped_col = roundtripped_dfi.get_column(0)

    assert roundtripped_col.dtype[0] == DtypeKind.CATEGORICAL
    assert roundtripped_col.dtype[0] == source_col.dtype[0]
    assert roundtripped_col.size() == source_col.size()
    assert roundtripped_col.offset == source_col.offset

    source_desc = source_col.describe_categorical
    roundtripped_desc = roundtripped_col.describe_categorical

    for flag in ("is_ordered", "is_dictionary"):
        assert source_desc[flag] == roundtripped_desc[flag]
    assert isinstance(roundtripped_desc["categories"]._col, pa.Array)
|
|
|
|
|
@pytest.mark.numpy
@pytest.mark.large_memory
def test_pyarrow_roundtrip_large_string():
    """Round-trip a multi-gigabyte ``large_string`` column.

    Carries the ``numpy`` marker because the input is built with
    ``np.array`` and the module leaves ``np`` as ``None`` when numpy
    cannot be imported.
    """
    # 3 * 1024**2 values of 1 KiB each.
    data = np.array([b'x'*1024]*(3*1024**2), dtype='object')
    arr = pa.array(data, type=pa.large_string())
    table = pa.table([arr], names=["large_string"])

    result = _from_dataframe(table.__dataframe__())
    col = result.__dataframe__().get_column(0)

    assert col.size() == 3*1024**2
    assert pa.types.is_large_string(table[0].type)
    assert pa.types.is_large_string(result[0].type)

    assert table.equals(result)
|
|
|
|
|
def test_nan_as_null():
    """``__dataframe__(nan_as_null=True)`` must raise ``RuntimeError``."""
    int_table = pa.table({"a": [1, 2, 3, 4]})
    pytest.raises(RuntimeError, int_table.__dataframe__, nan_as_null=True)
|
|
|
|
|
@pytest.mark.pandas
def test_allow_copy_false():
    """``allow_copy=False`` must raise for float and datetime pandas frames.

    Both frames are expected to be rejected with ``RuntimeError``
    (presumably because converting them would need a copy).
    """
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    payloads = (
        {"a": [0, 1.0, 2.0]},
        {"dt": [None, dt(2007, 7, 14), dt(2007, 7, 15)]},
    )
    for payload in payloads:
        frame = pd.DataFrame(payload)
        with pytest.raises(RuntimeError):
            pi.from_dataframe(frame, allow_copy=False)
|
|
|
|
|
@pytest.mark.pandas
def test_allow_copy_false_bool_categorical():
    """``allow_copy=False`` must raise for boolean and categorical frames.

    Covers each kind both with and without a missing value; every case is
    expected to end in ``RuntimeError``.
    """
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    # (frame contents, dtype to cast to or None to keep the inferred dtype)
    cases = (
        ({"a": [None, False, True]}, None),
        ({"a": [True, False, True]}, None),
        ({"weekday": ["a", "b", None]}, "category"),
        ({"weekday": ["a", "b", "c"]}, "category"),
    )
    for payload, cast_to in cases:
        frame = pd.DataFrame(payload)
        if cast_to is not None:
            frame = frame.astype(cast_to)
        with pytest.raises(RuntimeError):
            pi.from_dataframe(frame, allow_copy=False)
|
|
|
|
|
def test_empty_dataframe():
    """A zero-row int8 table survives ``pi.from_dataframe`` unchanged."""
    empty = pa.table([[]], schema=pa.schema([('col1', pa.int8())]))
    roundtripped = pi.from_dataframe(empty.__dataframe__())
    assert roundtripped == empty
|
|