import os
from typing import BinaryIO, Optional, Union

import fsspec
import pyarrow.parquet as pq

from .. import Dataset, Features, NamedSplit, config
from ..arrow_writer import get_writer_batch_size
from ..formatting import query_table
from ..packaged_modules import _PACKAGED_DATASETS_MODULES
from ..packaged_modules.parquet.parquet import Parquet
from ..utils import tqdm as hf_tqdm
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader


class ParquetDatasetReader(AbstractDatasetReader):
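    """Reads one or several Parquet files into a `Dataset` (or an `IterableDataset` when `streaming=True`).

    This is the reader behind `Dataset.from_parquet`: it forwards the paths, features and cache
    options to the packaged `Parquet` builder, which does the actual parsing.
    """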
def __init__(
self,
path_or_paths: NestedDataStructureLike[PathLike],
split: Optional[NamedSplit] = None,
features: Optional[Features] = None,
        cache_dir: Optional[str] = None,
keep_in_memory: bool = False,
streaming: bool = False,
num_proc: Optional[int] = None,
**kwargs,
):
super().__init__(
path_or_paths,
split=split,
features=features,
cache_dir=cache_dir,
keep_in_memory=keep_in_memory,
streaming=streaming,
num_proc=num_proc,
**kwargs,
)
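        # Normalize the input to the `{split: files}` mapping expected by the builder; `hash` is the
        # packaged parquet module's hash, used as part of the builder's cache fingerprint.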
path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
hash = _PACKAGED_DATASETS_MODULES["parquet"][1]
self.builder = Parquet(
cache_dir=cache_dir,
data_files=path_or_paths,
features=features,
hash=hash,
**kwargs,
)

    def read(self):
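        """Returns an `IterableDataset` in streaming mode; otherwise prepares the data in the cache
        and returns a materialized `Dataset` for `self.split`."""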
# Build iterable dataset
if self.streaming:
dataset = self.builder.as_streaming_dataset(split=self.split)
# Build regular (map-style) dataset
else:
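            # None for all of these means download_and_prepare falls back to its defaults
            # (e.g. reusing an already prepared cache and running basic checks).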
download_config = None
download_mode = None
verification_mode = None
base_path = None
self.builder.download_and_prepare(
download_config=download_config,
download_mode=download_mode,
verification_mode=verification_mode,
base_path=base_path,
num_proc=self.num_proc,
)
dataset = self.builder.as_dataset(
split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
)
return dataset


class ParquetDatasetWriter:
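    """Writes a `Dataset` to a Parquet file or binary file-like object, in batches.

    Backs `Dataset.to_parquet`. `write()` returns the number of bytes written (the sum of the
    in-memory Arrow batch sizes, not the size of the resulting Parquet file).
    """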
def __init__(
self,
dataset: Dataset,
path_or_buf: Union[PathLike, BinaryIO],
batch_size: Optional[int] = None,
storage_options: Optional[dict] = None,
**parquet_writer_kwargs,
):
self.dataset = dataset
self.path_or_buf = path_or_buf
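        # get_writer_batch_size may suggest a smaller batch size for binary-heavy features
        # (e.g. Image/Audio) and returns None otherwise; write() then falls back to the config default.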
self.batch_size = batch_size or get_writer_batch_size(dataset.features)
self.storage_options = storage_options or {}
self.parquet_writer_kwargs = parquet_writer_kwargs

    def write(self) -> int:
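        """Writes the dataset as Parquet and returns the number of bytes written.

        Path-like targets are opened through fsspec (so remote URLs work with `storage_options`);
        file-like objects are written to directly, and the caller is responsible for closing them.
        """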
batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
with fsspec.open(self.path_or_buf, "wb", **(self.storage_options or {})) as buffer:
written = self._write(file_obj=buffer, batch_size=batch_size, **self.parquet_writer_kwargs)
else:
written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.parquet_writer_kwargs)
return written

    def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) -> int:
        """Writes the pyarrow table as Parquet to a binary file handle.

        Caller is responsible for opening and closing the handle.
        """
written = 0
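        # Defensively drop a stray "path_or_buf" kwarg so it is never forwarded to pq.ParquetWriter.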
_ = parquet_writer_kwargs.pop("path_or_buf", None)
schema = self.dataset.features.arrow_schema
writer = pq.ParquetWriter(file_obj, schema=schema, **parquet_writer_kwargs)
for offset in hf_tqdm(
range(0, len(self.dataset), batch_size),
unit="ba",
desc="Creating parquet from Arrow format",
):
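            # query_table applies the dataset's indices mapping (e.g. after shuffle/select),
            # so each batch reflects the current view of the underlying Arrow table.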
batch = query_table(
table=self.dataset._data,
key=slice(offset, offset + batch_size),
indices=self.dataset._indices,
)
writer.write_table(batch)
written += batch.nbytes
writer.close()
return written
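

# Usage sketch (illustrative only; in practice the public wrappers `Dataset.to_parquet` and
# `Dataset.from_parquet` call these classes). The file name "data.parquet" is a placeholder.
#
#   from datasets import Dataset
#
#   ds = Dataset.from_dict({"text": ["a", "b"], "label": [0, 1]})
#   ParquetDatasetWriter(ds, "data.parquet").write()        # returns the number of bytes written
#   reloaded = ParquetDatasetReader("data.parquet").read()  # defaults to the "train" split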