# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from collections import UserList import datetime import io import pathlib import pytest import random import socket import threading import weakref try: import numpy as np except ImportError: np = None import pyarrow as pa from pyarrow.tests.util import changed_environ, invoke_script try: from pandas.testing import assert_frame_equal import pandas as pd except ImportError: pass class IpcFixture: write_stats = None def __init__(self, sink_factory=lambda: io.BytesIO()): self._sink_factory = sink_factory self.sink = self.get_sink() def get_sink(self): return self._sink_factory() def get_source(self): return self.sink.getvalue() def write_batches(self, num_batches=5, as_table=False): nrows = 5 schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())]) writer = self._get_writer(self.sink, schema) batches = [] for i in range(num_batches): batch = pa.record_batch( [[random.random() for _ in range(nrows)], ['foo', None, 'bar', 'bazbaz', 'qux']], schema=schema) batches.append(batch) if as_table: table = pa.Table.from_batches(batches) writer.write_table(table) else: for batch in batches: writer.write_batch(batch) self.write_stats = writer.stats writer.close() return batches class FileFormatFixture(IpcFixture): is_file = True options = None def _get_writer(self, sink, schema): return pa.ipc.new_file(sink, schema, options=self.options) def _check_roundtrip(self, as_table=False): batches = self.write_batches(as_table=as_table) file_contents = pa.BufferReader(self.get_source()) reader = pa.ipc.open_file(file_contents) assert reader.num_record_batches == len(batches) for i, batch in enumerate(batches): # it works. Must convert back to DataFrame batch = reader.get_batch(i) assert batches[i].equals(batch) assert reader.schema.equals(batches[0].schema) assert isinstance(reader.stats, pa.ipc.ReadStats) assert isinstance(self.write_stats, pa.ipc.WriteStats) assert tuple(reader.stats) == tuple(self.write_stats) class StreamFormatFixture(IpcFixture): # ARROW-6474, for testing writing old IPC protocol with 4-byte prefix use_legacy_ipc_format = False # ARROW-9395, for testing writing old metadata version options = None is_file = False def _get_writer(self, sink, schema): return pa.ipc.new_stream( sink, schema, use_legacy_format=self.use_legacy_ipc_format, options=self.options, ) class MessageFixture(IpcFixture): def _get_writer(self, sink, schema): return pa.RecordBatchStreamWriter(sink, schema) @pytest.fixture def ipc_fixture(): return IpcFixture() @pytest.fixture def file_fixture(): return FileFormatFixture() @pytest.fixture def stream_fixture(): return StreamFormatFixture() @pytest.fixture(params=[ pytest.param( 'file_fixture', id='File Format' ), pytest.param( 'stream_fixture', id='Stream Format' ) ]) def format_fixture(request): return request.getfixturevalue(request.param) def test_empty_file(): buf = b'' with pytest.raises(pa.ArrowInvalid): pa.ipc.open_file(pa.BufferReader(buf)) def test_file_simple_roundtrip(file_fixture): file_fixture._check_roundtrip(as_table=False) def test_file_write_table(file_fixture): file_fixture._check_roundtrip(as_table=True) @pytest.mark.parametrize("sink_factory", [ lambda: io.BytesIO(), lambda: pa.BufferOutputStream() ]) def test_file_read_all(sink_factory): fixture = FileFormatFixture(sink_factory) batches = fixture.write_batches() file_contents = pa.BufferReader(fixture.get_source()) reader = pa.ipc.open_file(file_contents) result = reader.read_all() expected = pa.Table.from_batches(batches) assert result.equals(expected) def test_open_file_from_buffer(file_fixture): # ARROW-2859; APIs accept the buffer protocol file_fixture.write_batches() source = file_fixture.get_source() reader1 = pa.ipc.open_file(source) reader2 = pa.ipc.open_file(pa.BufferReader(source)) reader3 = pa.RecordBatchFileReader(source) result1 = reader1.read_all() result2 = reader2.read_all() result3 = reader3.read_all() assert result1.equals(result2) assert result1.equals(result3) st1 = reader1.stats assert st1.num_messages == 6 assert st1.num_record_batches == 5 assert reader2.stats == st1 assert reader3.stats == st1 @pytest.mark.pandas def test_file_read_pandas(file_fixture): frames = [batch.to_pandas() for batch in file_fixture.write_batches()] file_contents = pa.BufferReader(file_fixture.get_source()) reader = pa.ipc.open_file(file_contents) result = reader.read_pandas() expected = pd.concat(frames).reset_index(drop=True) assert_frame_equal(result, expected) def test_file_pathlib(file_fixture, tmpdir): file_fixture.write_batches() source = file_fixture.get_source() path = tmpdir.join('file.arrow').strpath with open(path, 'wb') as f: f.write(source) t1 = pa.ipc.open_file(pathlib.Path(path)).read_all() t2 = pa.ipc.open_file(pa.OSFile(path)).read_all() assert t1.equals(t2) def test_empty_stream(): buf = io.BytesIO(b'') with pytest.raises(pa.ArrowInvalid): pa.ipc.open_stream(buf) @pytest.mark.pandas @pytest.mark.processes def test_read_year_month_nano_interval(tmpdir): """ARROW-15783: Verify to_pandas works for interval types. Interval types require static structures to be enabled. This test verifies that they are when no other library functions are invoked. """ mdn_interval_type = pa.month_day_nano_interval() schema = pa.schema([pa.field('nums', mdn_interval_type)]) path = tmpdir.join('file.arrow').strpath with pa.OSFile(path, 'wb') as sink: with pa.ipc.new_file(sink, schema) as writer: interval_array = pa.array([(1, 2, 3)], type=mdn_interval_type) batch = pa.record_batch([interval_array], schema) writer.write(batch) invoke_script('read_record_batch.py', path) @pytest.mark.pandas def test_stream_categorical_roundtrip(stream_fixture): df = pd.DataFrame({ 'one': np.random.randn(5), 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'], categories=['foo', 'bar'], ordered=True) }) batch = pa.RecordBatch.from_pandas(df) with stream_fixture._get_writer(stream_fixture.sink, batch.schema) as wr: wr.write_batch(batch) table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source())) .read_all()) assert_frame_equal(table.to_pandas(), df) def test_open_stream_from_buffer(stream_fixture): # ARROW-2859 stream_fixture.write_batches() source = stream_fixture.get_source() reader1 = pa.ipc.open_stream(source) reader2 = pa.ipc.open_stream(pa.BufferReader(source)) reader3 = pa.RecordBatchStreamReader(source) result1 = reader1.read_all() result2 = reader2.read_all() result3 = reader3.read_all() assert result1.equals(result2) assert result1.equals(result3) st1 = reader1.stats assert st1.num_messages == 6 assert st1.num_record_batches == 5 assert reader2.stats == st1 assert reader3.stats == st1 assert tuple(st1) == tuple(stream_fixture.write_stats) @pytest.mark.parametrize('options', [ pa.ipc.IpcReadOptions(), pa.ipc.IpcReadOptions(use_threads=False), ]) def test_open_stream_options(stream_fixture, options): stream_fixture.write_batches() source = stream_fixture.get_source() reader = pa.ipc.open_stream(source, options=options) reader.read_all() st = reader.stats assert st.num_messages == 6 assert st.num_record_batches == 5 assert tuple(st) == tuple(stream_fixture.write_stats) def test_open_stream_with_wrong_options(stream_fixture): stream_fixture.write_batches() source = stream_fixture.get_source() with pytest.raises(TypeError): pa.ipc.open_stream(source, options=True) @pytest.mark.parametrize('options', [ pa.ipc.IpcReadOptions(), pa.ipc.IpcReadOptions(use_threads=False), ]) def test_open_file_options(file_fixture, options): file_fixture.write_batches() source = file_fixture.get_source() reader = pa.ipc.open_file(source, options=options) reader.read_all() st = reader.stats assert st.num_messages == 6 assert st.num_record_batches == 5 def test_open_file_with_wrong_options(file_fixture): file_fixture.write_batches() source = file_fixture.get_source() with pytest.raises(TypeError): pa.ipc.open_file(source, options=True) @pytest.mark.pandas def test_stream_write_dispatch(stream_fixture): # ARROW-1616 df = pd.DataFrame({ 'one': np.random.randn(5), 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'], categories=['foo', 'bar'], ordered=True) }) table = pa.Table.from_pandas(df, preserve_index=False) batch = pa.RecordBatch.from_pandas(df, preserve_index=False) with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr: wr.write(table) wr.write(batch) table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source())) .read_all()) assert_frame_equal(table.to_pandas(), pd.concat([df, df], ignore_index=True)) @pytest.mark.pandas def test_stream_write_table_batches(stream_fixture): # ARROW-504 df = pd.DataFrame({ 'one': np.random.randn(20), }) b1 = pa.RecordBatch.from_pandas(df[:10], preserve_index=False) b2 = pa.RecordBatch.from_pandas(df, preserve_index=False) table = pa.Table.from_batches([b1, b2, b1]) with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr: wr.write_table(table, max_chunksize=15) batches = list(pa.ipc.open_stream(stream_fixture.get_source())) assert list(map(len, batches)) == [10, 15, 5, 10] result_table = pa.Table.from_batches(batches) assert_frame_equal(result_table.to_pandas(), pd.concat([df[:10], df, df[:10]], ignore_index=True)) @pytest.mark.parametrize('use_legacy_ipc_format', [False, True]) def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format): stream_fixture.use_legacy_ipc_format = use_legacy_ipc_format batches = stream_fixture.write_batches() file_contents = pa.BufferReader(stream_fixture.get_source()) reader = pa.ipc.open_stream(file_contents) assert reader.schema.equals(batches[0].schema) total = 0 for i, next_batch in enumerate(reader): assert next_batch.equals(batches[i]) total += 1 assert total == len(batches) with pytest.raises(StopIteration): reader.read_next_batch() @pytest.mark.zstd def test_compression_roundtrip(): sink = io.BytesIO() values = [random.randint(0, 3) for _ in range(10000)] table = pa.Table.from_arrays([values], names=["values"]) options = pa.ipc.IpcWriteOptions(compression='zstd') with pa.ipc.RecordBatchFileWriter( sink, table.schema, options=options) as writer: writer.write_table(table) len1 = len(sink.getvalue()) sink2 = io.BytesIO() codec = pa.Codec('zstd', compression_level=5) options = pa.ipc.IpcWriteOptions(compression=codec) with pa.ipc.RecordBatchFileWriter( sink2, table.schema, options=options) as writer: writer.write_table(table) len2 = len(sink2.getvalue()) # In theory len2 should be less than len1 but for this test we just want # to ensure compression_level is being correctly passed down to the C++ # layer so we don't really care if it makes it worse or better assert len2 != len1 t1 = pa.ipc.open_file(sink).read_all() t2 = pa.ipc.open_file(sink2).read_all() assert t1 == t2 def test_write_options(): options = pa.ipc.IpcWriteOptions() assert options.allow_64bit is False assert options.use_legacy_format is False assert options.metadata_version == pa.ipc.MetadataVersion.V5 options.allow_64bit = True assert options.allow_64bit is True options.use_legacy_format = True assert options.use_legacy_format is True options.metadata_version = pa.ipc.MetadataVersion.V4 assert options.metadata_version == pa.ipc.MetadataVersion.V4 for value in ('V5', 42): with pytest.raises((TypeError, ValueError)): options.metadata_version = value assert options.compression is None for value in ['lz4', 'zstd']: if pa.Codec.is_available(value): options.compression = value assert options.compression == value options.compression = value.upper() assert options.compression == value options.compression = None assert options.compression is None with pytest.raises(TypeError): options.compression = 0 assert options.use_threads is True options.use_threads = False assert options.use_threads is False if pa.Codec.is_available('lz4'): options = pa.ipc.IpcWriteOptions( metadata_version=pa.ipc.MetadataVersion.V4, allow_64bit=True, use_legacy_format=True, compression='lz4', use_threads=False) assert options.metadata_version == pa.ipc.MetadataVersion.V4 assert options.allow_64bit is True assert options.use_legacy_format is True assert options.compression == 'lz4' assert options.use_threads is False def test_write_options_legacy_exclusive(stream_fixture): with pytest.raises( ValueError, match="provide at most one of options and use_legacy_format"): stream_fixture.use_legacy_ipc_format = True stream_fixture.options = pa.ipc.IpcWriteOptions() stream_fixture.write_batches() @pytest.mark.parametrize('options', [ pa.ipc.IpcWriteOptions(), pa.ipc.IpcWriteOptions(allow_64bit=True), pa.ipc.IpcWriteOptions(use_legacy_format=True), pa.ipc.IpcWriteOptions(metadata_version=pa.ipc.MetadataVersion.V4), pa.ipc.IpcWriteOptions(use_legacy_format=True, metadata_version=pa.ipc.MetadataVersion.V4), ]) def test_stream_options_roundtrip(stream_fixture, options): stream_fixture.use_legacy_ipc_format = None stream_fixture.options = options batches = stream_fixture.write_batches() file_contents = pa.BufferReader(stream_fixture.get_source()) message = pa.ipc.read_message(stream_fixture.get_source()) assert message.metadata_version == options.metadata_version reader = pa.ipc.open_stream(file_contents) assert reader.schema.equals(batches[0].schema) total = 0 for i, next_batch in enumerate(reader): assert next_batch.equals(batches[i]) total += 1 assert total == len(batches) with pytest.raises(StopIteration): reader.read_next_batch() def test_read_options(): options = pa.ipc.IpcReadOptions() assert options.use_threads is True assert options.ensure_native_endian is True assert options.included_fields == [] options.ensure_native_endian = False assert options.ensure_native_endian is False options.use_threads = False assert options.use_threads is False options.included_fields = [0, 1] assert options.included_fields == [0, 1] with pytest.raises(TypeError): options.included_fields = None options = pa.ipc.IpcReadOptions( use_threads=False, ensure_native_endian=False, included_fields=[1] ) assert options.use_threads is False assert options.ensure_native_endian is False assert options.included_fields == [1] def test_read_options_included_fields(stream_fixture): options1 = pa.ipc.IpcReadOptions() options2 = pa.ipc.IpcReadOptions(included_fields=[1]) table = pa.Table.from_arrays([pa.array(['foo', 'bar', 'baz', 'qux']), pa.array([1, 2, 3, 4])], names=['a', 'b']) with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr: wr.write_table(table) source = stream_fixture.get_source() reader1 = pa.ipc.open_stream(source, options=options1) reader2 = pa.ipc.open_stream( source, options=options2, memory_pool=pa.system_memory_pool()) result1 = reader1.read_all() result2 = reader2.read_all() assert result1.num_columns == 2 assert result2.num_columns == 1 expected = pa.Table.from_arrays([pa.array([1, 2, 3, 4])], names=["b"]) assert result2 == expected assert result1 == table def test_dictionary_delta(format_fixture): ty = pa.dictionary(pa.int8(), pa.utf8()) data = [["foo", "foo", None], ["foo", "bar", "foo"], # potential delta ["foo", "bar"], # nothing new ["foo", None, "bar", "quux"], # potential delta ["bar", "quux"], # replacement ] batches = [ pa.RecordBatch.from_arrays([pa.array(v, type=ty)], names=['dicts']) for v in data] batches_delta_only = batches[:4] schema = batches[0].schema def write_batches(batches, as_table=False): with format_fixture._get_writer(pa.MockOutputStream(), schema) as writer: if as_table: table = pa.Table.from_batches(batches) writer.write_table(table) else: for batch in batches: writer.write_batch(batch) return writer.stats if format_fixture.is_file: # File format cannot handle replacement with pytest.raises(pa.ArrowInvalid): write_batches(batches) # File format cannot handle delta if emit_deltas # is not provided with pytest.raises(pa.ArrowInvalid): write_batches(batches_delta_only) else: st = write_batches(batches) assert st.num_record_batches == 5 assert st.num_dictionary_batches == 4 assert st.num_replaced_dictionaries == 3 assert st.num_dictionary_deltas == 0 format_fixture.use_legacy_ipc_format = None format_fixture.options = pa.ipc.IpcWriteOptions( emit_dictionary_deltas=True) if format_fixture.is_file: # File format cannot handle replacement with pytest.raises(pa.ArrowInvalid): write_batches(batches) else: st = write_batches(batches) assert st.num_record_batches == 5 assert st.num_dictionary_batches == 4 assert st.num_replaced_dictionaries == 1 assert st.num_dictionary_deltas == 2 st = write_batches(batches_delta_only) assert st.num_record_batches == 4 assert st.num_dictionary_batches == 3 assert st.num_replaced_dictionaries == 0 assert st.num_dictionary_deltas == 2 format_fixture.options = pa.ipc.IpcWriteOptions( unify_dictionaries=True ) st = write_batches(batches, as_table=True) assert st.num_record_batches == 5 if format_fixture.is_file: assert st.num_dictionary_batches == 1 assert st.num_replaced_dictionaries == 0 assert st.num_dictionary_deltas == 0 else: assert st.num_dictionary_batches == 4 assert st.num_replaced_dictionaries == 3 assert st.num_dictionary_deltas == 0 def test_envvar_set_legacy_ipc_format(): schema = pa.schema([pa.field('foo', pa.int32())]) writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) assert not writer._use_legacy_format assert writer._metadata_version == pa.ipc.MetadataVersion.V5 writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) assert not writer._use_legacy_format assert writer._metadata_version == pa.ipc.MetadataVersion.V5 with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'): writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) assert writer._use_legacy_format assert writer._metadata_version == pa.ipc.MetadataVersion.V5 writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) assert writer._use_legacy_format assert writer._metadata_version == pa.ipc.MetadataVersion.V5 with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'): writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) assert not writer._use_legacy_format assert writer._metadata_version == pa.ipc.MetadataVersion.V4 writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) assert not writer._use_legacy_format assert writer._metadata_version == pa.ipc.MetadataVersion.V4 with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'): with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'): writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) assert writer._use_legacy_format assert writer._metadata_version == pa.ipc.MetadataVersion.V4 writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) assert writer._use_legacy_format assert writer._metadata_version == pa.ipc.MetadataVersion.V4 def test_stream_read_all(stream_fixture): batches = stream_fixture.write_batches() file_contents = pa.BufferReader(stream_fixture.get_source()) reader = pa.ipc.open_stream(file_contents) result = reader.read_all() expected = pa.Table.from_batches(batches) assert result.equals(expected) @pytest.mark.pandas def test_stream_read_pandas(stream_fixture): frames = [batch.to_pandas() for batch in stream_fixture.write_batches()] file_contents = stream_fixture.get_source() reader = pa.ipc.open_stream(file_contents) result = reader.read_pandas() expected = pd.concat(frames).reset_index(drop=True) assert_frame_equal(result, expected) @pytest.fixture def example_messages(stream_fixture): batches = stream_fixture.write_batches() file_contents = stream_fixture.get_source() buf_reader = pa.BufferReader(file_contents) reader = pa.MessageReader.open_stream(buf_reader) return batches, list(reader) def test_message_ctors_no_segfault(): with pytest.raises(TypeError): repr(pa.Message()) with pytest.raises(TypeError): repr(pa.MessageReader()) def test_message_reader(example_messages): _, messages = example_messages assert len(messages) == 6 assert messages[0].type == 'schema' assert isinstance(messages[0].metadata, pa.Buffer) assert isinstance(messages[0].body, pa.Buffer) assert messages[0].metadata_version == pa.MetadataVersion.V5 for msg in messages[1:]: assert msg.type == 'record batch' assert isinstance(msg.metadata, pa.Buffer) assert isinstance(msg.body, pa.Buffer) assert msg.metadata_version == pa.MetadataVersion.V5 def test_message_serialize_read_message(example_messages): _, messages = example_messages msg = messages[0] buf = msg.serialize() reader = pa.BufferReader(buf.to_pybytes() * 2) restored = pa.ipc.read_message(buf) restored2 = pa.ipc.read_message(reader) restored3 = pa.ipc.read_message(buf.to_pybytes()) restored4 = pa.ipc.read_message(reader) assert msg.equals(restored) assert msg.equals(restored2) assert msg.equals(restored3) assert msg.equals(restored4) with pytest.raises(pa.ArrowInvalid, match="Corrupted message"): pa.ipc.read_message(pa.BufferReader(b'ab')) with pytest.raises(EOFError): pa.ipc.read_message(reader) @pytest.mark.gzip def test_message_read_from_compressed(example_messages): # Part of ARROW-5910 _, messages = example_messages for message in messages: raw_out = pa.BufferOutputStream() with pa.output_stream(raw_out, compression='gzip') as compressed_out: message.serialize_to(compressed_out) compressed_buf = raw_out.getvalue() result = pa.ipc.read_message(pa.input_stream(compressed_buf, compression='gzip')) assert result.equals(message) def test_message_read_schema(example_messages): batches, messages = example_messages schema = pa.ipc.read_schema(messages[0]) assert schema.equals(batches[1].schema) def test_message_read_record_batch(example_messages): batches, messages = example_messages for batch, message in zip(batches, messages[1:]): read_batch = pa.ipc.read_record_batch(message, batch.schema) assert read_batch.equals(batch) def test_read_record_batch_on_stream_error_message(): # ARROW-5374 batch = pa.record_batch([pa.array([b"foo"], type=pa.utf8())], names=['strs']) stream = pa.BufferOutputStream() with pa.ipc.new_stream(stream, batch.schema) as writer: writer.write_batch(batch) buf = stream.getvalue() with pytest.raises(IOError, match="type record batch but got schema"): pa.ipc.read_record_batch(buf, batch.schema) # ---------------------------------------------------------------------- # Socket streaming testa class StreamReaderServer(threading.Thread): def init(self, do_read_all): self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._sock.bind(('127.0.0.1', 0)) self._sock.listen(1) host, port = self._sock.getsockname() self._do_read_all = do_read_all self._schema = None self._batches = [] self._table = None return port def run(self): connection, client_address = self._sock.accept() try: source = connection.makefile(mode='rb') reader = pa.ipc.open_stream(source) self._schema = reader.schema if self._do_read_all: self._table = reader.read_all() else: for i, batch in enumerate(reader): self._batches.append(batch) finally: connection.close() self._sock.close() def get_result(self): return (self._schema, self._table if self._do_read_all else self._batches) class SocketStreamFixture(IpcFixture): def __init__(self): # XXX(wesm): test will decide when to start socket server. This should # probably be refactored pass def start_server(self, do_read_all): self._server = StreamReaderServer() port = self._server.init(do_read_all) self._server.start() self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._sock.connect(('127.0.0.1', port)) self.sink = self.get_sink() def stop_and_get_result(self): import struct self.sink.write(struct.pack('Q', 0)) self.sink.flush() self._sock.close() self._server.join() return self._server.get_result() def get_sink(self): return self._sock.makefile(mode='wb') def _get_writer(self, sink, schema): return pa.RecordBatchStreamWriter(sink, schema) @pytest.fixture def socket_fixture(): return SocketStreamFixture() @pytest.mark.sockets def test_socket_simple_roundtrip(socket_fixture): socket_fixture.start_server(do_read_all=False) writer_batches = socket_fixture.write_batches() reader_schema, reader_batches = socket_fixture.stop_and_get_result() assert reader_schema.equals(writer_batches[0].schema) assert len(reader_batches) == len(writer_batches) for i, batch in enumerate(writer_batches): assert reader_batches[i].equals(batch) @pytest.mark.sockets def test_socket_read_all(socket_fixture): socket_fixture.start_server(do_read_all=True) writer_batches = socket_fixture.write_batches() _, result = socket_fixture.stop_and_get_result() expected = pa.Table.from_batches(writer_batches) assert result.equals(expected) # ---------------------------------------------------------------------- # Miscellaneous IPC tests @pytest.mark.pandas def test_ipc_file_stream_has_eos(): # ARROW-5395 df = pd.DataFrame({'foo': [1.5]}) batch = pa.RecordBatch.from_pandas(df) sink = pa.BufferOutputStream() write_file(batch, sink) buffer = sink.getvalue() # skip the file magic reader = pa.ipc.open_stream(buffer[8:]) # will fail if encounters footer data instead of eos rdf = reader.read_pandas() assert_frame_equal(df, rdf) @pytest.mark.pandas def test_ipc_zero_copy_numpy(): df = pd.DataFrame({'foo': [1.5]}) batch = pa.RecordBatch.from_pandas(df) sink = pa.BufferOutputStream() write_file(batch, sink) buffer = sink.getvalue() reader = pa.BufferReader(buffer) batches = read_file(reader) data = batches[0].to_pandas() rdf = pd.DataFrame(data) assert_frame_equal(df, rdf) @pytest.mark.pandas @pytest.mark.parametrize("ipc_type", ["stream", "file"]) def test_batches_with_custom_metadata_roundtrip(ipc_type): df = pd.DataFrame({'foo': [1.5]}) batch = pa.RecordBatch.from_pandas(df) sink = pa.BufferOutputStream() batch_count = 2 file_factory = {"stream": pa.ipc.new_stream, "file": pa.ipc.new_file}[ipc_type] with file_factory(sink, batch.schema) as writer: for i in range(batch_count): writer.write_batch(batch, custom_metadata={"batch_id": str(i)}) # write a batch without custom metadata writer.write_batch(batch) buffer = sink.getvalue() if ipc_type == "stream": with pa.ipc.open_stream(buffer) as reader: batch_with_metas = list(reader.iter_batches_with_custom_metadata()) else: with pa.ipc.open_file(buffer) as reader: batch_with_metas = [reader.get_batch_with_custom_metadata(i) for i in range(reader.num_record_batches)] for i in range(batch_count): assert batch_with_metas[i].batch.num_rows == 1 assert isinstance( batch_with_metas[i].custom_metadata, pa.KeyValueMetadata) assert batch_with_metas[i].custom_metadata == {"batch_id": str(i)} # the last batch has no custom metadata assert batch_with_metas[batch_count].batch.num_rows == 1 assert batch_with_metas[batch_count].custom_metadata is None def test_ipc_stream_no_batches(): # ARROW-2307 table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]), pa.array(['foo', 'bar', 'baz', 'qux'])], names=['a', 'b']) sink = pa.BufferOutputStream() with pa.ipc.new_stream(sink, table.schema): pass source = sink.getvalue() with pa.ipc.open_stream(source) as reader: result = reader.read_all() assert result.schema.equals(table.schema) assert len(result) == 0 @pytest.mark.pandas def test_get_record_batch_size(): N = 10 itemsize = 8 df = pd.DataFrame({'foo': np.random.randn(N)}) batch = pa.RecordBatch.from_pandas(df) assert pa.ipc.get_record_batch_size(batch) > (N * itemsize) @pytest.mark.pandas def _check_serialize_pandas_round_trip(df, use_threads=False): buf = pa.serialize_pandas(df, nthreads=2 if use_threads else 1) result = pa.deserialize_pandas(buf, use_threads=use_threads) assert_frame_equal(result, df) @pytest.mark.pandas def test_pandas_serialize_round_trip(): index = pd.Index([1, 2, 3], name='my_index') columns = ['foo', 'bar'] df = pd.DataFrame( {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, index=index, columns=columns ) _check_serialize_pandas_round_trip(df) @pytest.mark.pandas def test_pandas_serialize_round_trip_nthreads(): index = pd.Index([1, 2, 3], name='my_index') columns = ['foo', 'bar'] df = pd.DataFrame( {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, index=index, columns=columns ) _check_serialize_pandas_round_trip(df, use_threads=True) @pytest.mark.pandas def test_pandas_serialize_round_trip_multi_index(): index1 = pd.Index([1, 2, 3], name='level_1') index2 = pd.Index(list('def'), name=None) index = pd.MultiIndex.from_arrays([index1, index2]) columns = ['foo', 'bar'] df = pd.DataFrame( {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, index=index, columns=columns, ) _check_serialize_pandas_round_trip(df) @pytest.mark.pandas def test_serialize_pandas_empty_dataframe(): df = pd.DataFrame() _check_serialize_pandas_round_trip(df) @pytest.mark.pandas def test_pandas_serialize_round_trip_not_string_columns(): df = pd.DataFrame(list(zip([1.5, 1.6, 1.7], 'abc'))) buf = pa.serialize_pandas(df) result = pa.deserialize_pandas(buf) assert_frame_equal(result, df) @pytest.mark.pandas def test_serialize_pandas_no_preserve_index(): df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3]) expected = pd.DataFrame({'a': [1, 2, 3]}) buf = pa.serialize_pandas(df, preserve_index=False) result = pa.deserialize_pandas(buf) assert_frame_equal(result, expected) buf = pa.serialize_pandas(df, preserve_index=True) result = pa.deserialize_pandas(buf) assert_frame_equal(result, df) @pytest.mark.pandas def test_schema_batch_serialize_methods(): nrows = 5 df = pd.DataFrame({ 'one': np.random.randn(nrows), 'two': ['foo', np.nan, 'bar', 'bazbaz', 'qux']}) batch = pa.RecordBatch.from_pandas(df) s_schema = batch.schema.serialize() s_batch = batch.serialize() recons_schema = pa.ipc.read_schema(s_schema) recons_batch = pa.ipc.read_record_batch(s_batch, recons_schema) assert recons_batch.equals(batch) def test_schema_serialization_with_metadata(): field_metadata = {b'foo': b'bar', b'kind': b'field'} schema_metadata = {b'foo': b'bar', b'kind': b'schema'} f0 = pa.field('a', pa.int8()) f1 = pa.field('b', pa.string(), metadata=field_metadata) schema = pa.schema([f0, f1], metadata=schema_metadata) s_schema = schema.serialize() recons_schema = pa.ipc.read_schema(s_schema) assert recons_schema.equals(schema) assert recons_schema.metadata == schema_metadata assert recons_schema[0].metadata is None assert recons_schema[1].metadata == field_metadata def write_file(batch, sink): with pa.ipc.new_file(sink, batch.schema) as writer: writer.write_batch(batch) def read_file(source): with pa.ipc.open_file(source) as reader: return [reader.get_batch(i) for i in range(reader.num_record_batches)] def test_write_empty_ipc_file(): # ARROW-3894: IPC file was not being properly initialized when no record # batches are being written schema = pa.schema([('field', pa.int64())]) sink = pa.BufferOutputStream() with pa.ipc.new_file(sink, schema): pass buf = sink.getvalue() with pa.RecordBatchFileReader(pa.BufferReader(buf)) as reader: table = reader.read_all() assert len(table) == 0 assert table.schema.equals(schema) def test_py_record_batch_reader(): def make_schema(): return pa.schema([('field', pa.int64())]) def make_batches(): schema = make_schema() batch1 = pa.record_batch([[1, 2, 3]], schema=schema) batch2 = pa.record_batch([[4, 5]], schema=schema) return [batch1, batch2] # With iterable batches = UserList(make_batches()) # weakrefable wr = weakref.ref(batches) with pa.RecordBatchReader.from_batches(make_schema(), batches) as reader: batches = None assert wr() is not None assert list(reader) == make_batches() assert wr() is None # With iterator batches = iter(UserList(make_batches())) # weakrefable wr = weakref.ref(batches) with pa.RecordBatchReader.from_batches(make_schema(), batches) as reader: batches = None assert wr() is not None assert list(reader) == make_batches() assert wr() is None # ensure we get proper error when not passing a schema # (https://issues.apache.org/jira/browse/ARROW-18229) batches = make_batches() with pytest.raises(TypeError): reader = pa.RecordBatchReader.from_batches( [('field', pa.int64())], batches) pass with pytest.raises(TypeError): reader = pa.RecordBatchReader.from_batches(None, batches) pass def test_record_batch_reader_from_arrow_stream(): class StreamWrapper: def __init__(self, batches): self.batches = batches def __arrow_c_stream__(self, requested_schema=None): reader = pa.RecordBatchReader.from_batches( self.batches[0].schema, self.batches) return reader.__arrow_c_stream__(requested_schema) data = [ pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']), pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a']) ] wrapper = StreamWrapper(data) # Can roundtrip a pyarrow stream-like object expected = pa.Table.from_batches(data) reader = pa.RecordBatchReader.from_stream(expected) assert reader.read_all() == expected # Can roundtrip through the wrapper. reader = pa.RecordBatchReader.from_stream(wrapper) assert reader.read_all() == expected # Passing schema works if already that schema reader = pa.RecordBatchReader.from_stream(wrapper, schema=data[0].schema) assert reader.read_all() == expected # Passing a different but castable schema works good_schema = pa.schema([pa.field("a", pa.int32())]) reader = pa.RecordBatchReader.from_stream(wrapper, schema=good_schema) assert reader.read_all() == expected.cast(good_schema) # If schema doesn't match, raises TypeError with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'): pa.RecordBatchReader.from_stream( wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))]) ) # Proper type errors for wrong input with pytest.raises(TypeError): pa.RecordBatchReader.from_stream(data[0]['a']) with pytest.raises(TypeError): pa.RecordBatchReader.from_stream(expected, schema=data[0]) def test_record_batch_reader_cast(): schema_src = pa.schema([pa.field('a', pa.int64())]) data = [ pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']), pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a']), ] table_src = pa.Table.from_batches(data) # Cast to same type should always work reader = pa.RecordBatchReader.from_batches(schema_src, data) assert reader.cast(schema_src).read_all() == table_src # Check non-trivial cast schema_dst = pa.schema([pa.field('a', pa.int32())]) reader = pa.RecordBatchReader.from_batches(schema_src, data) assert reader.cast(schema_dst).read_all() == table_src.cast(schema_dst) # Check error for field name/length mismatch reader = pa.RecordBatchReader.from_batches(schema_src, data) with pytest.raises(ValueError, match="Target schema's field names"): reader.cast(pa.schema([])) # Check error for impossible cast in call to .cast() reader = pa.RecordBatchReader.from_batches(schema_src, data) with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'): reader.cast(pa.schema([pa.field('a', pa.list_(pa.int32()))])) # Cast to same type should always work (also for types without a T->T cast function) # (https://github.com/apache/arrow/issues/41884) schema_src = pa.schema([pa.field('a', pa.date32())]) arr = pa.array([datetime.date(2024, 6, 11)], type=pa.date32()) data = [pa.record_batch([arr], names=['a']), pa.record_batch([arr], names=['a'])] table_src = pa.Table.from_batches(data) reader = pa.RecordBatchReader.from_batches(schema_src, data) assert reader.cast(schema_src).read_all() == table_src def test_record_batch_reader_cast_nulls(): schema_src = pa.schema([pa.field('a', pa.int64())]) data_with_nulls = [ pa.record_batch([pa.array([1, 2, None], type=pa.int64())], names=['a']), ] data_without_nulls = [ pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']), ] table_with_nulls = pa.Table.from_batches(data_with_nulls) table_without_nulls = pa.Table.from_batches(data_without_nulls) # Cast to nullable destination should work reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls) schema_dst = pa.schema([pa.field('a', pa.int32())]) assert reader.cast(schema_dst).read_all() == table_with_nulls.cast(schema_dst) # Cast to non-nullable destination should work if there are no nulls reader = pa.RecordBatchReader.from_batches(schema_src, data_without_nulls) schema_dst = pa.schema([pa.field('a', pa.int32(), nullable=False)]) assert reader.cast(schema_dst).read_all() == table_without_nulls.cast(schema_dst) # Cast to non-nullable destination should error if there are nulls # when the batch is pulled reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls) casted_reader = reader.cast(schema_dst) with pytest.raises(pa.lib.ArrowInvalid, match="Can't cast array"): casted_reader.read_all()