import os
from typing import BinaryIO, Optional, Union

import fsspec
import pyarrow.parquet as pq

from .. import Dataset, Features, NamedSplit, config
from ..arrow_writer import get_writer_batch_size
from ..formatting import query_table
from ..packaged_modules import _PACKAGED_DATASETS_MODULES
from ..packaged_modules.parquet.parquet import Parquet
from ..utils import tqdm as hf_tqdm
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader


class ParquetDatasetReader(AbstractDatasetReader):
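    """Reads Parquet files into a `Dataset` (or an `IterableDataset` when `streaming=True`)
    by delegating to the packaged `parquet` builder.

    Illustrative usage (the file path is a placeholder)::

        dataset = ParquetDatasetReader("path/to/file.parquet").read()
    """
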
    def __init__(
        self,
        path_or_paths: NestedDataStructureLike[PathLike],
        split: Optional[NamedSplit] = None,
        features: Optional[Features] = None,
        cache_dir: Optional[str] = None,
        keep_in_memory: bool = False,
        streaming: bool = False,
        num_proc: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            path_or_paths,
            split=split,
            features=features,
            cache_dir=cache_dir,
            keep_in_memory=keep_in_memory,
            streaming=streaming,
            num_proc=num_proc,
            **kwargs,
        )
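        # Normalize the input into the {split: data_files} mapping expected by the builder,
        # then configure the packaged Parquet builder (the hash identifies the packaged module).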
        path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
        hash = _PACKAGED_DATASETS_MODULES["parquet"][1]
        self.builder = Parquet(
            cache_dir=cache_dir,
            data_files=path_or_paths,
            features=features,
            hash=hash,
            **kwargs,
        )

    def read(self):
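        # Streaming yields an IterableDataset straight from the builder; otherwise the data
        # is downloaded and prepared first, and the requested split is materialized as a Dataset.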
        if self.streaming:
            dataset = self.builder.as_streaming_dataset(split=self.split)
        else:
            download_config = None
            download_mode = None
            verification_mode = None
            base_path = None

            self.builder.download_and_prepare(
                download_config=download_config,
                download_mode=download_mode,
                verification_mode=verification_mode,
                base_path=base_path,
                num_proc=self.num_proc,
            )
            dataset = self.builder.as_dataset(
                split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
            )
        return dataset


class ParquetDatasetWriter:
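    """Writes a `Dataset` as Parquet, either to a path/URI (opened via fsspec, so remote
    destinations work with `storage_options`) or to an already-open binary file object.

    Illustrative usage (the output path is a placeholder)::

        written = ParquetDatasetWriter(dataset, "path/to/output.parquet").write()
    """
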
    def __init__(
        self,
        dataset: Dataset,
        path_or_buf: Union[PathLike, BinaryIO],
        batch_size: Optional[int] = None,
        storage_options: Optional[dict] = None,
        **parquet_writer_kwargs,
    ):
        self.dataset = dataset
        self.path_or_buf = path_or_buf
        self.batch_size = batch_size or get_writer_batch_size(dataset.features)
        self.storage_options = storage_options or {}
        self.parquet_writer_kwargs = parquet_writer_kwargs

    def write(self) -> int:
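        # Fall back to the library-wide default batch size when neither an explicit value
        # nor a feature-derived writer batch size was set in __init__.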
        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE

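        # A path-like destination is opened through fsspec (honoring storage_options);
        # a file-like object is written to directly.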
        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
            with fsspec.open(self.path_or_buf, "wb", **(self.storage_options or {})) as buffer:
                written = self._write(file_obj=buffer, batch_size=batch_size, **self.parquet_writer_kwargs)
        else:
            written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.parquet_writer_kwargs)
        return written

    def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) -> int:
"""Writes the pyarrow table as Parquet to a binary file handle. |
|
|
|
Caller is responsible for opening and closing the handle. |
|
""" |
        written = 0
        _ = parquet_writer_kwargs.pop("path_or_buf", None)
        schema = self.dataset.features.arrow_schema

        writer = pq.ParquetWriter(file_obj, schema=schema, **parquet_writer_kwargs)

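        # Iterate over the dataset in row slices; query_table applies the dataset's indices
        # mapping (if any) so selected/shuffled rows are written in their current order.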
        for offset in hf_tqdm(
            range(0, len(self.dataset), batch_size),
            unit="ba",
            desc="Creating parquet from Arrow format",
        ):
            batch = query_table(
                table=self.dataset._data,
                key=slice(offset, offset + batch_size),
                indices=self.dataset._indices,
            )
            writer.write_table(batch)
            written += batch.nbytes
        writer.close()
        return written