import multiprocessing
import os
from typing import BinaryIO, Optional, Union

import fsspec

from .. import Dataset, Features, NamedSplit, config
from ..formatting import query_table
from ..packaged_modules.json.json import Json
from ..utils import tqdm as hf_tqdm
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader


class JsonDatasetReader(AbstractDatasetReader):
    def __init__(
        self,
        path_or_paths: NestedDataStructureLike[PathLike],
        split: Optional[NamedSplit] = None,
        features: Optional[Features] = None,
        cache_dir: Optional[str] = None,
        keep_in_memory: bool = False,
        streaming: bool = False,
        field: Optional[str] = None,
        num_proc: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            path_or_paths,
            split=split,
            features=features,
            cache_dir=cache_dir,
            keep_in_memory=keep_in_memory,
            streaming=streaming,
            num_proc=num_proc,
            **kwargs,
        )
        self.field = field
        # The Json builder expects a mapping of split names to data files,
        # so wrap a bare path (or list of paths) under the requested split.
        path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
        self.builder = Json(
            cache_dir=cache_dir,
            data_files=path_or_paths,
            features=features,
            field=field,
            **kwargs,
        )

    def read(self):
        # Build an iterable dataset lazily when streaming is requested.
        if self.streaming:
            dataset = self.builder.as_streaming_dataset(split=self.split)
        # Otherwise download and prepare the data, then materialize the requested split.
        else:
            download_config = None
            download_mode = None
            verification_mode = None
            base_path = None

            self.builder.download_and_prepare(
                download_config=download_config,
                download_mode=download_mode,
                verification_mode=verification_mode,
                base_path=base_path,
                num_proc=self.num_proc,
            )
            dataset = self.builder.as_dataset(
                split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
            )
        return dataset

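# A minimal usage sketch (the file path below is illustrative, not part of this module):
# `Dataset.from_json` builds a `JsonDatasetReader` under the hood, so reading a local
# JSON Lines file typically looks like this:
#
#     from datasets import Dataset
#
#     ds = Dataset.from_json("data/train.jsonl")
#     # equivalent to: JsonDatasetReader("data/train.jsonl").read()

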
class JsonDatasetWriter:
    def __init__(
        self,
        dataset: Dataset,
        path_or_buf: Union[PathLike, BinaryIO],
        batch_size: Optional[int] = None,
        num_proc: Optional[int] = None,
        storage_options: Optional[dict] = None,
        **to_json_kwargs,
    ):
        if num_proc is not None and num_proc <= 0:
            raise ValueError(f"num_proc {num_proc} must be an integer > 0.")

        self.dataset = dataset
        self.path_or_buf = path_or_buf
        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
        self.num_proc = num_proc
        self.encoding = "utf-8"
        self.storage_options = storage_options or {}
        self.to_json_kwargs = to_json_kwargs

    def write(self) -> int:
        _ = self.to_json_kwargs.pop("path_or_buf", None)
        # Default to JSON Lines output: orient="records" with one record per line.
        orient = self.to_json_kwargs.pop("orient", "records")
        lines = self.to_json_kwargs.pop("lines", True if orient == "records" else False)
        if "index" not in self.to_json_kwargs and orient in ["split", "table"]:
            self.to_json_kwargs["index"] = False

        # Infer compression from the file extension only when writing to a path, not a buffer.
        default_compression = "infer" if isinstance(self.path_or_buf, (str, bytes, os.PathLike)) else None
        compression = self.to_json_kwargs.pop("compression", default_compression)

        if compression not in [None, "infer", "gzip", "bz2", "xz"]:
            raise NotImplementedError(f"`datasets` currently does not support {compression} compression")

        if not lines and self.batch_size < self.dataset.num_rows:
            raise NotImplementedError(
                "Output JSON will not be formatted correctly when lines = False and batch_size < number of rows in the dataset. Use pandas.DataFrame.to_json() instead."
            )

        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
            with fsspec.open(
                self.path_or_buf, "wb", compression=compression, **(self.storage_options or {})
            ) as buffer:
                written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
        else:
            if compression:
                raise NotImplementedError(
                    f"The compression parameter is not supported when writing to a buffer, but compression={compression}"
                    " was passed. Please provide a local path instead."
                )
            written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
        return written

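    # Note on write(), as a sketch (the filenames are illustrative, not part of this module):
    # with the default compression="infer", fsspec picks the codec from the extension, so
    # `JsonDatasetWriter(ds, "out.jsonl.gz").write()` produces a gzip-compressed JSON Lines
    # file, while writing to an already-open binary buffer requires leaving compression unset.
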
    def _batch_json(self, args):
        offset, orient, lines, to_json_kwargs = args

        # Slice one batch out of the underlying Arrow table, honoring any indices mapping.
        batch = query_table(
            table=self.dataset.data,
            key=slice(offset, offset + self.batch_size),
            indices=self.dataset._indices,
        )
        json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
        if not json_str.endswith("\n"):
            json_str += "\n"
        return json_str.encode(self.encoding)

    def _write(
        self,
        file_obj: BinaryIO,
        orient,
        lines,
        **to_json_kwargs,
    ) -> int:
        """Writes the pyarrow table as JSON lines to a binary file handle.

        Caller is responsible for opening and closing the handle.
        """
        written = 0

        if self.num_proc is None or self.num_proc == 1:
            # Serialize batches sequentially in the current process.
            for offset in hf_tqdm(
                range(0, len(self.dataset), self.batch_size),
                unit="ba",
                desc="Creating json from Arrow format",
            ):
                json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
                written += file_obj.write(json_str)
        else:
            # Serialize batches in a process pool; imap preserves batch order in the output file.
            num_rows, batch_size = len(self.dataset), self.batch_size
            with multiprocessing.Pool(self.num_proc) as pool:
                for json_str in hf_tqdm(
                    pool.imap(
                        self._batch_json,
                        [(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
                    ),
                    total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                    unit="ba",
                    desc="Creating json from Arrow format",
                ):
                    written += file_obj.write(json_str)

        return written
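

# A minimal usage sketch (the output paths are illustrative, not part of this module):
# `Dataset.to_json` wraps `JsonDatasetWriter`, so exporting a dataset typically looks like:
#
#     num_bytes = JsonDatasetWriter(ds, "out.jsonl", num_proc=4).write()
#     # or, through the public API: ds.to_json("out.jsonl.gz")  # gzip inferred from the extension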