|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pytest |
|
|
|
import pyarrow as pa |
|
import pyarrow.compute as pc |
|
from pyarrow.compute import field |
|
|
|
try: |
|
from pyarrow.acero import ( |
|
Declaration, |
|
TableSourceNodeOptions, |
|
FilterNodeOptions, |
|
ProjectNodeOptions, |
|
AggregateNodeOptions, |
|
OrderByNodeOptions, |
|
HashJoinNodeOptions, |
|
AsofJoinNodeOptions, |
|
) |
|
except ImportError: |
|
pass |
|
|
|
try: |
|
import pyarrow.dataset as ds |
|
from pyarrow.acero import ScanNodeOptions |
|
except ImportError: |
|
ds = None |
|
|
|
pytestmark = pytest.mark.acero |
|
|
|
|
|
@pytest.fixture
def table_source():
    """Declaration wrapping a small two-column int64 table, shared by tests."""
    data = pa.table({'a': [1, 2, 3], 'b': [4, 5, 6]})
    return Declaration("table_source", options=TableSourceNodeOptions(data))
|
|
|
|
|
def test_declaration():
    """Both ways of wiring nodes (from_sequence and inputs=) give one result."""
    table = pa.table({'a': [1, 2, 3], 'b': [4, 5, 6]})
    source_opts = TableSourceNodeOptions(table)
    predicate_opts = FilterNodeOptions(field('a') > 1)

    # Implicit chaining: each declaration feeds the next.
    plan = Declaration.from_sequence([
        Declaration("table_source", options=source_opts),
        Declaration("filter", options=predicate_opts),
    ])
    assert plan.to_table().equals(table.slice(1, 2))

    # Explicit wiring via the inputs= keyword.
    source = Declaration("table_source", options=source_opts)
    filtered = Declaration("filter", options=predicate_opts, inputs=[source])
    assert filtered.to_table().equals(table.slice(1, 2))
|
|
|
|
|
def test_declaration_repr(table_source):
    """str() and repr() of a Declaration both mention the node type."""
    for rendered in (str(table_source), repr(table_source)):
        assert "TableSourceNode" in rendered
|
|
|
|
|
def test_declaration_to_reader(table_source):
    """to_reader() streams the plan result with the expected schema."""
    expected = pa.table({'a': [1, 2, 3], 'b': [4, 5, 6]})
    with table_source.to_reader() as reader:
        assert reader.schema == pa.schema([("a", pa.int64()), ("b", pa.int64())])
        assert reader.read_all().equals(expected)
|
|
|
|
|
def test_table_source():
    """TableSourceNodeOptions validation: type check and null-table rejection."""
    # Only a Table is accepted, not a RecordBatch.
    batch = pa.record_batch([pa.array([1, 2, 3])], ["a"])
    with pytest.raises(TypeError):
        TableSourceNodeOptions(batch)

    # A None table is accepted at construction but rejected at execution time.
    decl = Declaration("table_source", TableSourceNodeOptions(None))
    with pytest.raises(
        ValueError, match="TableSourceNode requires table which is not null"
    ):
        _ = decl.to_table()
|
|
|
|
|
def test_filter(table_source):
    """Filter node errors: unknown column at runtime, bad option type eagerly."""
    # Referencing a column that does not exist fails when the plan executes.
    bad_filter = Declaration("filter", options=FilterNodeOptions(field("c") > 1))
    plan = Declaration.from_sequence([table_source, bad_filter])
    with pytest.raises(ValueError, match=r"No match for FieldRef.Name\(c\)"):
        _ = plan.to_table()

    # The filter must be an Expression: an array or None is rejected up front.
    for invalid in (pa.array([True, False, True]), None):
        with pytest.raises(TypeError):
            FilterNodeOptions(invalid)
|
|
|
|
|
def test_project(table_source):
    """Project node: default vs. explicit names, validation, scalar-only rule."""
    doubled = pc.multiply(field("a"), 2)

    # Without explicit names, the expression string becomes the column name.
    plan = Declaration.from_sequence([
        table_source,
        Declaration("project", ProjectNodeOptions([doubled])),
    ])
    out = plan.to_table()
    assert out.schema.names == ["multiply(a, 2)"]
    assert out[0].to_pylist() == [2, 4, 6]

    # Explicit output names override the defaults.
    plan = Declaration.from_sequence([
        table_source,
        Declaration("project", ProjectNodeOptions([doubled], ["a2"])),
    ])
    out = plan.to_table()
    assert out.schema.names == ["a2"]
    assert out["a2"].to_pylist() == [2, 4, 6]

    # Name/expression count mismatch is rejected at construction time.
    with pytest.raises(ValueError):
        ProjectNodeOptions([doubled], ["a2", "b2"])

    # Aggregate (non-scalar) expressions cannot be projected.
    plan = Declaration.from_sequence([
        table_source,
        Declaration("project", ProjectNodeOptions([pc.sum(field("a"))])),
    ])
    with pytest.raises(ValueError, match="cannot Execute non-scalar expression"):
        _ = plan.to_table()
|
|
|
|
|
def test_aggregate_scalar(table_source):
    """Scalar (whole-table) aggregations via the "aggregate" node."""
    # Basic sum over column "a" of the fixture table ([1, 2, 3] -> 6).
    decl = Declaration.from_sequence([
        table_source,
        Declaration("aggregate", AggregateNodeOptions([("a", "sum", None, "a_sum")]))
    ])
    result = decl.to_table()
    assert result.schema.names == ["a_sum"]
    assert result["a_sum"].to_pylist() == [6]

    # With skip_nulls=False, a null input value propagates to a null result.
    table = pa.table({'a': [1, 2, None]})
    aggr_opts = AggregateNodeOptions(
        [("a", "sum", pc.ScalarAggregateOptions(skip_nulls=False), "a_sum")]
    )
    decl = Declaration.from_sequence([
        Declaration("table_source", TableSourceNodeOptions(table)),
        Declaration("aggregate", aggr_opts),
    ])
    result = decl.to_table()
    assert result.schema.names == ["a_sum"]
    assert result["a_sum"].to_pylist() == [None]

    # The aggregation target accepts several spellings: column name, field
    # reference, integer index, or a single-element list of any of those.
    for target in ["a", field("a"), 0, field(0), ["a"], [field("a")], [0]]:
        aggr_opts = AggregateNodeOptions([(target, "sum", None, "a_sum")])
        decl = Declaration.from_sequence(
            [table_source, Declaration("aggregate", aggr_opts)]
        )
        result = decl.to_table()
        assert result.schema.names == ["a_sum"]
        assert result["a_sum"].to_pylist() == [6]

    # "sum" is unary: two target columns are rejected at execution time.
    aggr_opts = AggregateNodeOptions([(["a", "b"], "sum", None, "a_sum")])
    decl = Declaration.from_sequence(
        [table_source, Declaration("aggregate", aggr_opts)]
    )
    with pytest.raises(
        ValueError, match="Function 'sum' accepts 1 arguments but 2 passed"
    ):
        _ = decl.to_table()

    # A hash (grouped) aggregate kernel cannot be used without grouping keys.
    aggr_opts = AggregateNodeOptions([("a", "hash_sum", None, "a_sum")])
    decl = Declaration.from_sequence(
        [table_source, Declaration("aggregate", aggr_opts)]
    )
    with pytest.raises(ValueError, match="is a hash aggregate function"):
        _ = decl.to_table()
|
|
|
|
|
def test_aggregate_hash():
    """Grouped ("hash_*") aggregations via the "aggregate" node."""
    table = pa.table({'a': [1, 2, None], 'b': ["foo", "bar", "foo"]})
    table_opts = TableSourceNodeOptions(table)
    table_source = Declaration("table_source", options=table_opts)

    # Default count options skip nulls: group "foo" has a = [1, None] and
    # therefore counts only 1.
    aggr_opts = AggregateNodeOptions(
        [("a", "hash_count", None, "count(a)")], keys=["b"])
    decl = Declaration.from_sequence([
        table_source, Declaration("aggregate", aggr_opts)
    ])
    result = decl.to_table()
    expected = pa.table({"b": ["foo", "bar"], "count(a)": [1, 1]})
    assert result.equals(expected)

    # CountOptions("all") counts nulls too -> group "foo" now counts 2.
    aggr_opts = AggregateNodeOptions(
        [("a", "hash_count", pc.CountOptions("all"), "count(a)")], keys=["b"]
    )
    decl = Declaration.from_sequence([
        table_source, Declaration("aggregate", aggr_opts)
    ])
    result = decl.to_table()
    expected_all = pa.table({"b": ["foo", "bar"], "count(a)": [2, 1]})
    assert result.equals(expected_all)

    # Grouping keys may be given as field references instead of names.
    aggr_opts = AggregateNodeOptions(
        [("a", "hash_count", None, "count(a)")], keys=[field("b")]
    )
    decl = Declaration.from_sequence([
        table_source, Declaration("aggregate", aggr_opts)
    ])
    result = decl.to_table()
    assert result.equals(expected)

    # A scalar aggregate function cannot be combined with grouping keys.
    aggr_opts = AggregateNodeOptions([("a", "sum", None, "a_sum")], keys=["b"])
    decl = Declaration.from_sequence([
        table_source, Declaration("aggregate", aggr_opts)
    ])
    with pytest.raises(ValueError):
        _ = decl.to_table()
|
|
|
|
|
def test_order_by():
    """Sorting via the "order_by" node: key forms, null placement, errors."""
    table = pa.table({'a': [1, 2, 3, 4], 'b': [1, 3, None, 2]})
    table_source = Declaration("table_source", TableSourceNodeOptions(table))

    # Ascending sort by column name; nulls are placed at the end by default.
    ord_opts = OrderByNodeOptions([("b", "ascending")])
    decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)])
    result = decl.to_table()
    expected = pa.table({"a": [1, 4, 2, 3], "b": [1, 2, 3, None]})
    assert result.equals(expected)

    # Sort key given as a field reference, descending order.
    ord_opts = OrderByNodeOptions([(field("b"), "descending")])
    decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)])
    result = decl.to_table()
    expected = pa.table({"a": [2, 4, 1, 3], "b": [3, 2, 1, None]})
    assert result.equals(expected)

    # Sort key given as a column index, with nulls explicitly placed first.
    ord_opts = OrderByNodeOptions([(1, "descending")], null_placement="at_start")
    decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)])
    result = decl.to_table()
    expected = pa.table({"a": [3, 2, 4, 1], "b": [None, 3, 2, 1]})
    assert result.equals(expected)

    # An empty ordering is only rejected when the plan executes.
    ord_opts = OrderByNodeOptions([])
    decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)])
    with pytest.raises(
        ValueError, match="`ordering` must be an explicit non-empty ordering"
    ):
        _ = decl.to_table()

    # Invalid sort-order / null-placement strings are rejected eagerly.
    with pytest.raises(ValueError, match="\"decreasing\" is not a valid sort order"):
        _ = OrderByNodeOptions([("b", "decreasing")])

    with pytest.raises(ValueError, match="\"start\" is not a valid null placement"):
        _ = OrderByNodeOptions([("b", "ascending")], null_placement="start")
|
|
|
|
|
def test_hash_join():
    """Hash joins: join types, key spellings, suffixes, and output selection."""
    left = pa.table({'key': [1, 2, 3], 'a': [4, 5, 6]})
    left_source = Declaration("table_source", options=TableSourceNodeOptions(left))
    right = pa.table({'key': [2, 3, 4], 'b': [4, 5, 6]})
    right_source = Declaration("table_source", options=TableSourceNodeOptions(right))

    # Inner join on "key": both key columns appear in the output.
    join_opts = HashJoinNodeOptions("inner", left_keys="key", right_keys="key")
    joined = Declaration(
        "hashjoin", options=join_opts, inputs=[left_source, right_source])
    result = joined.to_table()
    expected = pa.table(
        [[2, 3], [5, 6], [2, 3], [4, 5]],
        names=["key", "a", "key", "b"])
    assert result.equals(expected)

    # Keys also accept a field reference or a (single-element) list.
    for keys in [field("key"), ["key"], [field("key")]]:
        join_opts = HashJoinNodeOptions("inner", left_keys=keys, right_keys=keys)
        joined = Declaration(
            "hashjoin", options=join_opts, inputs=[left_source, right_source])
        result = joined.to_table()
        assert result.equals(expected)

    # Left outer join keeps the unmatched left row (key=1) with nulls on the
    # right side; sort_by("a") because the output row order is not guaranteed.
    join_opts = HashJoinNodeOptions(
        "left outer", left_keys="key", right_keys="key")
    joined = Declaration(
        "hashjoin", options=join_opts, inputs=[left_source, right_source])
    result = joined.to_table()
    expected = pa.table(
        [[1, 2, 3], [4, 5, 6], [None, 2, 3], [None, 4, 5]],
        names=["key", "a", "key", "b"]
    )
    assert result.sort_by("a").equals(expected)

    # Output suffixes disambiguate the otherwise-duplicated "key" column name.
    join_opts = HashJoinNodeOptions(
        "left outer", left_keys="key", right_keys="key",
        output_suffix_for_left="_left", output_suffix_for_right="_right")
    joined = Declaration(
        "hashjoin", options=join_opts, inputs=[left_source, right_source])
    result = joined.to_table()
    expected = pa.table(
        [[1, 2, 3], [4, 5, 6], [None, 2, 3], [None, 4, 5]],
        names=["key_left", "a", "key_right", "b"]
    )
    assert result.sort_by("a").equals(expected)

    # left_output / right_output restrict which columns are emitted
    # (here the right "key" column is dropped).
    join_opts = HashJoinNodeOptions(
        "left outer", left_keys="key", right_keys="key",
        left_output=["key", "a"], right_output=[field("b")])
    joined = Declaration(
        "hashjoin", options=join_opts, inputs=[left_source, right_source])
    result = joined.to_table()
    expected = pa.table(
        [[1, 2, 3], [4, 5, 6], [None, 4, 5]],
        names=["key", "a", "b"]
    )
    assert result.sort_by("a").equals(expected)
|
|
|
|
|
def test_asof_join():
    """As-of join: per-key match of right rows within the "on" tolerance."""
    left = pa.table({'key': [1, 2, 3], 'ts': [1, 1, 1], 'a': [4, 5, 6]})
    left_source = Declaration("table_source", options=TableSourceNodeOptions(left))
    right = pa.table({'key': [2, 3, 4], 'ts': [2, 5, 2], 'b': [4, 5, 6]})
    right_source = Declaration("table_source", options=TableSourceNodeOptions(right))

    # With tolerance=1 only key=2 finds a right row (right ts=2 vs left ts=1);
    # key=1 has no right row at all and key=3's right row (ts=5) is too far,
    # so both get a null "b" in the expected output below.
    join_opts = AsofJoinNodeOptions(
        left_on="ts", left_by=["key"],
        right_on="ts", right_by=["key"],
        tolerance=1,
    )
    joined = Declaration(
        "asofjoin", options=join_opts, inputs=[left_source, right_source]
    )
    result = joined.to_table()
    expected = pa.table(
        [[1, 2, 3], [1, 1, 1], [4, 5, 6], [None, 4, None]],
        names=["key", "ts", "a", "b"])
    assert result == expected

    # "on" and "by" accept a name, a field reference, or lists thereof.
    for by in [field("key"), ["key"], [field("key")]]:
        for on in [field("ts"), "ts"]:
            join_opts = AsofJoinNodeOptions(
                left_on=on, left_by=by,
                right_on=on, right_by=by,
                tolerance=1,
            )
            joined = Declaration(
                "asofjoin", options=join_opts, inputs=[left_source, right_source])
            result = joined.to_table()
            assert result == expected
|
|
|
|
|
@pytest.mark.dataset
def test_scan(tempdir):
    """Dataset scanning via the "scan" node (requires pyarrow.dataset)."""
    table = pa.table({'a': [1, 2, 3], 'b': [4, 5, 6]})
    ds.write_dataset(table, tempdir / "dataset", format="parquet")
    dataset = ds.dataset(tempdir / "dataset", format="parquet")
    decl = Declaration("scan", ScanNodeOptions(dataset))
    result = decl.to_table()
    # The scan node augments the schema with fragment/batch bookkeeping columns.
    assert result.schema.names == [
        "a", "b", "__fragment_index", "__batch_index",
        "__last_in_fragment", "__filename"
    ]
    assert result.select(["a", "b"]).equals(table)

    # NOTE(review): per these asserts the scan filter is evidently a pushdown
    # hint used for pruning, not an exact post-filter — with a > 1 all 3 rows
    # still come back because the fragment cannot be pruned by statistics.
    scan_opts = ScanNodeOptions(dataset, filter=field('a') > 1)
    decl = Declaration("scan", scan_opts)

    assert decl.to_table().num_rows == 3

    scan_opts = ScanNodeOptions(dataset, filter=field('a') > 4)
    decl = Declaration("scan", scan_opts)
    # Here no row can match (max(a) == 3), so the whole fragment is skipped.
    assert decl.to_table().num_rows == 0

    # The columns projection is likewise not applied by the scan node itself;
    # it only determines which physical columns are loaded.
    scan_opts = ScanNodeOptions(dataset, columns={"a2": pc.multiply(field("a"), 2)})
    decl = Declaration("scan", scan_opts)
    result = decl.to_table()
    # "a" is needed by the projection expression, so its values are loaded...
    assert result["a"].to_pylist() == [1, 2, 3]
    # ...while "b" is not referenced and comes back as all-null.
    assert pc.all(result["b"].is_null()).as_py()
|
|