File size: 5,265 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import multiprocessing
import os
from typing import BinaryIO, Optional, Union
import fsspec
from .. import Dataset, Features, NamedSplit, config
from ..formatting import query_table
from ..packaged_modules.csv.csv import Csv
from ..utils import tqdm as hf_tqdm
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader
class CsvDatasetReader(AbstractDatasetReader):
    """Dataset reader that loads CSV file(s) through the packaged ``Csv`` builder."""

    def __init__(
        self,
        path_or_paths: NestedDataStructureLike[PathLike],
        split: Optional[NamedSplit] = None,
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
        streaming: bool = False,
        num_proc: Optional[int] = None,
        **kwargs,
    ):
        """Set up the reader and create the underlying ``Csv`` builder.

        Args:
            path_or_paths: CSV path(s). A ``dict`` is handed to the builder as-is;
                anything else is wrapped into a one-entry mapping keyed by ``self.split``.
            split: Forwarded to the base initializer.
            features: Forwarded to the base initializer and to the ``Csv`` builder.
            cache_dir: Forwarded to the base initializer and to the ``Csv`` builder.
            keep_in_memory: Forwarded to the base initializer.
            streaming: Forwarded to the base initializer.
            num_proc: Forwarded to the base initializer.
            **kwargs: Forwarded to both the base initializer and the ``Csv`` builder.
        """
        super().__init__(
            path_or_paths,
            split=split,
            features=features,
            cache_dir=cache_dir,
            keep_in_memory=keep_in_memory,
            streaming=streaming,
            num_proc=num_proc,
            **kwargs,
        )
        # `self.split` is read only after the base initializer has run
        # (presumably that is where it is assigned -- confirm in AbstractDatasetReader).
        if isinstance(path_or_paths, dict):
            data_files = path_or_paths
        else:
            data_files = {self.split: path_or_paths}
        self.builder = Csv(
            cache_dir=cache_dir,
            data_files=data_files,
            features=features,
            **kwargs,
        )

    def read(self):
        """Build and return the dataset for ``self.split``.

        When ``self.streaming`` is truthy the builder's streaming dataset is returned
        directly. Otherwise the builder first runs ``download_and_prepare`` and the
        result of ``as_dataset`` is returned.
        """
        # Iterable dataset: nothing has to be prepared up front.
        if self.streaming:
            return self.builder.as_streaming_dataset(split=self.split)

        # Regular (map-style) dataset: prepare first, leaving every optional
        # download/verification setting at None.
        self.builder.download_and_prepare(
            download_config=None,
            download_mode=None,
            verification_mode=None,
            base_path=None,
            num_proc=self.num_proc,
        )
        return self.builder.as_dataset(
            split=self.split,
            verification_mode=None,
            in_memory=self.keep_in_memory,
        )
class CsvDatasetWriter:
    """Write a ``Dataset`` as CSV, one batch at a time, to a path or an open binary buffer."""

    def __init__(
        self,
        dataset: Dataset,
        path_or_buf: Union[PathLike, BinaryIO],
        batch_size: Optional[int] = None,
        num_proc: Optional[int] = None,
        storage_options: Optional[dict] = None,
        **to_csv_kwargs,
    ):
        """Store the export configuration.

        Args:
            dataset: The dataset to export.
            path_or_buf: Destination. A ``str``/``bytes``/``os.PathLike`` value is opened
                with ``fsspec.open``; anything else is used directly as a binary handle.
            batch_size: Number of rows converted per batch. A falsy value (``None`` or ``0``)
                falls back to ``config.DEFAULT_MAX_BATCH_SIZE``.
            num_proc: Number of worker processes. ``None`` or ``1`` converts batches
                sequentially in the current process.
            storage_options: Extra keyword arguments for ``fsspec.open``.
            **to_csv_kwargs: Extra keyword arguments for ``pandas.DataFrame.to_csv``.
                ``header`` and ``index`` are honoured; ``path_or_buf`` is ignored.

        Raises:
            ValueError: If ``num_proc`` is given and is not strictly positive.
        """
        if num_proc is not None and num_proc <= 0:
            raise ValueError(f"num_proc {num_proc} must be an integer > 0.")
        self.dataset = dataset
        self.path_or_buf = path_or_buf
        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
        self.num_proc = num_proc
        # Codec used to turn each batch's CSV text into bytes before writing.
        self.encoding = "utf-8"
        self.storage_options = storage_options or {}
        self.to_csv_kwargs = to_csv_kwargs

    def write(self) -> int:
        """Write the whole dataset as CSV and return the accumulated ``write()`` result.

        Returns:
            The sum of the values returned by the destination's ``write`` calls.
        """
        # Pop from a per-call copy: mutating self.to_csv_kwargs would make a second
        # write() on the same writer silently drop the caller's header/index settings.
        to_csv_kwargs = dict(self.to_csv_kwargs)
        # The destination is controlled by this writer, never by the pandas kwarg.
        to_csv_kwargs.pop("path_or_buf", None)
        header = to_csv_kwargs.pop("header", True)
        index = to_csv_kwargs.pop("index", False)

        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
            with fsspec.open(self.path_or_buf, "wb", **(self.storage_options or {})) as buffer:
                written = self._write(file_obj=buffer, header=header, index=index, **to_csv_kwargs)
        else:
            # Caller-provided handle: it is written to but neither opened nor closed here.
            written = self._write(file_obj=self.path_or_buf, header=header, index=index, **to_csv_kwargs)
        return written

    def _batch_csv(self, args):
        """Convert one batch of rows to encoded CSV bytes.

        Args:
            args: Tuple ``(offset, header, index, to_csv_kwargs)``. It is a single
                argument so the method can be used directly with ``Pool.imap``.

        Returns:
            The CSV text for rows ``offset`` to ``offset + self.batch_size``,
            encoded with ``self.encoding``.
        """
        offset, header, index, to_csv_kwargs = args
        batch = query_table(
            table=self.dataset.data,
            key=slice(offset, offset + self.batch_size),
            indices=self.dataset._indices,
        )
        # Only the first batch may carry the header row; later batches are appended
        # after it in the same output.
        csv_str = batch.to_pandas().to_csv(
            path_or_buf=None, header=header if (offset == 0) else False, index=index, **to_csv_kwargs
        )
        return csv_str.encode(self.encoding)

    def _write(self, file_obj: BinaryIO, header, index, **to_csv_kwargs) -> int:
        """Writes the pyarrow table as CSV to a binary file handle.

        Caller is responsible for opening and closing the handle.

        Args:
            file_obj: Binary handle that receives the encoded CSV batches.
            header: ``header`` argument for ``to_csv``; applied to the first batch only.
            index: ``index`` argument for ``to_csv``.
            **to_csv_kwargs: Remaining keyword arguments for ``to_csv``.

        Returns:
            The sum of the values returned by ``file_obj.write``.
        """
        written = 0
        offsets = range(0, len(self.dataset), self.batch_size)
        if self.num_proc is None or self.num_proc == 1:
            for offset in hf_tqdm(
                offsets,
                unit="ba",
                desc="Creating CSV from Arrow format",
            ):
                csv_bytes = self._batch_csv((offset, header, index, to_csv_kwargs))
                written += file_obj.write(csv_bytes)
        else:
            batch_args = [(offset, header, index, to_csv_kwargs) for offset in offsets]
            # NOTE(review): the bound method is shipped to the workers, which presumably
            # pickles `self` (including `path_or_buf`) -- confirm that a caller-supplied
            # file handle works together with num_proc > 1.
            with multiprocessing.Pool(self.num_proc) as pool:
                # imap yields results in submission order, so batches land in row order.
                for csv_bytes in hf_tqdm(
                    pool.imap(self._batch_csv, batch_args),
                    total=len(batch_args),
                    unit="ba",
                    desc="Creating CSV from Arrow format",
                ):
                    written += file_obj.write(csv_bytes)
        return written
|