|
import itertools |
|
from dataclasses import dataclass |
|
from typing import Any, Callable, Optional, Union |
|
|
|
import pandas as pd |
|
import pyarrow as pa |
|
|
|
import datasets |
|
import datasets.config |
|
from datasets.features.features import require_storage_cast |
|
from datasets.table import table_cast |
|
from datasets.utils.py_utils import Literal |
|
|
|
|
|
# Module-level logger following the `datasets` logging conventions.
logger = datasets.utils.logging.get_logger(__name__)

# Parameter-name groups consumed by `CsvConfig.pd_read_csv_kwargs` to decide
# which entries to drop before forwarding the kwargs to `pandas.read_csv`:

# Dropped when still equal to the `CsvConfig` default (pandas has no usable
# default for these, so they are only forwarded when set explicitly).
_PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS = ["names", "prefix"]

# Dropped when still equal to the `CsvConfig` default (presumably deprecated
# in pandas, per the name — only forwarded when the user overrides them).
_PANDAS_READ_CSV_DEPRECATED_PARAMETERS = ["warn_bad_lines", "error_bad_lines", "mangle_dupe_cols"]

# Dropped entirely when the installed pandas predates 1.3.0.
_PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS = ["encoding_errors", "on_bad_lines"]

# Dropped entirely when the installed pandas predates 2.0.0.
_PANDAS_READ_CSV_NEW_2_0_0_PARAMETERS = ["date_format"]

# Dropped on pandas >= 2.2.0 when still equal to the `CsvConfig` default.
_PANDAS_READ_CSV_DEPRECATED_2_2_0_PARAMETERS = ["verbose"]
|
|
|
|
|
@dataclass
class CsvConfig(datasets.BuilderConfig):
    """BuilderConfig for CSV.

    Most attributes mirror the parameters of ``pandas.read_csv`` and are
    forwarded to it via :attr:`pd_read_csv_kwargs`; see the pandas
    documentation for their exact semantics. The dataset-specific attributes
    are:

    - ``column_names``: alias for ``names`` (takes precedence when set)
    - ``delimiter``: alias for ``sep`` (takes precedence when set)
    - ``chunksize``: number of CSV rows read per yielded Arrow table
    - ``features``: optional ``datasets.Features`` used to type the columns
    """

    sep: str = ","
    delimiter: Optional[str] = None  # alias for `sep`; overrides it when not None
    header: Optional[Union[int, list[int], str]] = "infer"
    names: Optional[list[str]] = None
    column_names: Optional[list[str]] = None  # alias for `names`; overrides it when not None
    index_col: Optional[Union[int, str, list[int], list[str]]] = None
    usecols: Optional[Union[list[int], list[str]]] = None
    prefix: Optional[str] = None
    mangle_dupe_cols: bool = True
    engine: Optional[Literal["c", "python", "pyarrow"]] = None
    # NOTE: was annotated `dict[...]` despite the `None` default; fixed to Optional.
    converters: Optional[dict[Union[int, str], Callable[[Any], Any]]] = None
    true_values: Optional[list] = None
    false_values: Optional[list] = None
    skipinitialspace: bool = False
    skiprows: Optional[Union[int, list[int]]] = None
    nrows: Optional[int] = None
    na_values: Optional[Union[str, list[str]]] = None
    keep_default_na: bool = True
    na_filter: bool = True
    verbose: bool = False
    skip_blank_lines: bool = True
    thousands: Optional[str] = None
    decimal: str = "."
    lineterminator: Optional[str] = None
    quotechar: str = '"'
    quoting: int = 0
    escapechar: Optional[str] = None
    comment: Optional[str] = None
    encoding: Optional[str] = None
    dialect: Optional[str] = None
    error_bad_lines: bool = True
    warn_bad_lines: bool = True
    skipfooter: int = 0
    doublequote: bool = True
    memory_map: bool = False
    float_precision: Optional[str] = None
    chunksize: int = 10_000
    features: Optional[datasets.Features] = None
    encoding_errors: Optional[str] = "strict"
    on_bad_lines: Literal["error", "warn", "skip"] = "error"
    date_format: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()
        # Resolve the datasets-style aliases onto the pandas parameter names.
        if self.delimiter is not None:
            self.sep = self.delimiter
        if self.column_names is not None:
            self.names = self.column_names

    @property
    def pd_read_csv_kwargs(self):
        """Build the kwargs dict to forward to ``pandas.read_csv``.

        Parameters the installed pandas version cannot accept — deprecated
        ones left at their default value, or ones newer than the installed
        version — are removed from the dict before it is returned.
        """
        pd_read_csv_kwargs = {
            "sep": self.sep,
            "header": self.header,
            "names": self.names,
            "index_col": self.index_col,
            "usecols": self.usecols,
            "prefix": self.prefix,
            "mangle_dupe_cols": self.mangle_dupe_cols,
            "engine": self.engine,
            "converters": self.converters,
            "true_values": self.true_values,
            "false_values": self.false_values,
            "skipinitialspace": self.skipinitialspace,
            "skiprows": self.skiprows,
            "nrows": self.nrows,
            "na_values": self.na_values,
            "keep_default_na": self.keep_default_na,
            "na_filter": self.na_filter,
            "verbose": self.verbose,
            "skip_blank_lines": self.skip_blank_lines,
            "thousands": self.thousands,
            "decimal": self.decimal,
            "lineterminator": self.lineterminator,
            "quotechar": self.quotechar,
            "quoting": self.quoting,
            "escapechar": self.escapechar,
            "comment": self.comment,
            "encoding": self.encoding,
            "dialect": self.dialect,
            "error_bad_lines": self.error_bad_lines,
            "warn_bad_lines": self.warn_bad_lines,
            "skipfooter": self.skipfooter,
            "doublequote": self.doublequote,
            "memory_map": self.memory_map,
            "float_precision": self.float_precision,
            "chunksize": self.chunksize,
            "encoding_errors": self.encoding_errors,
            "on_bad_lines": self.on_bad_lines,
            "date_format": self.date_format,
        }

        # Reference instance holding the default values, built once instead of
        # once per loop iteration.
        default_config = CsvConfig()

        # Drop no-default / deprecated parameters still at their default value,
        # so pandas never sees them unless the user set them explicitly.
        for pd_read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
            if pd_read_csv_kwargs[pd_read_csv_parameter] == getattr(default_config, pd_read_csv_parameter):
                del pd_read_csv_kwargs[pd_read_csv_parameter]

        # Drop parameters introduced in pandas 1.3.0 on older pandas.
        # BUGFIX: the previous check (`major >= 1 and minor >= 3`) compared the
        # components independently and so wrongly classified pandas 2.0/2.1
        # (minor < 3) as older than 1.3, stripping `encoding_errors` and
        # `on_bad_lines` there; compare the full release tuple instead, as the
        # 2.2 check below already does.
        if datasets.config.PANDAS_VERSION.release < (1, 3):
            for pd_read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS:
                del pd_read_csv_kwargs[pd_read_csv_parameter]

        # Drop parameters introduced in pandas 2.0.0 on older pandas.
        if datasets.config.PANDAS_VERSION.release < (2, 0):
            for pd_read_csv_parameter in _PANDAS_READ_CSV_NEW_2_0_0_PARAMETERS:
                del pd_read_csv_kwargs[pd_read_csv_parameter]

        # Drop parameters deprecated in pandas 2.2.0 when left at their default.
        if datasets.config.PANDAS_VERSION.release >= (2, 2):
            for pd_read_csv_parameter in _PANDAS_READ_CSV_DEPRECATED_2_2_0_PARAMETERS:
                if pd_read_csv_kwargs[pd_read_csv_parameter] == getattr(default_config, pd_read_csv_parameter):
                    del pd_read_csv_kwargs[pd_read_csv_parameter]

        return pd_read_csv_kwargs
|
|
|
|
|
class Csv(datasets.ArrowBasedBuilder):
    """Arrow-based builder that reads CSV files in chunks via pandas."""

    BUILDER_CONFIG_CLASS = CsvConfig

    def _info(self):
        # The dataset features, if any, come straight from the user config.
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """We handle string, list and dicts in datafiles"""
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        dl_manager.download_config.extract_on_the_fly = True
        data_files = dl_manager.download_and_extract(self.config.data_files)
        # A split's files may be a single path or a list of paths; normalize to
        # a list, then wrap each path in a lazy file iterator.
        return [
            datasets.SplitGenerator(
                name=split_name,
                gen_kwargs={
                    "files": [
                        dl_manager.iter_files(file)
                        for file in ([files] if isinstance(files, str) else files)
                    ]
                },
            )
            for split_name, files in data_files.items()
        ]

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        """Cast a raw Arrow table to the schema of the configured features."""
        features = self.config.features
        if features is None:
            return pa_table
        schema = features.arrow_schema
        if any(require_storage_cast(feature) for feature in features.values()):
            # At least one feature needs a storage cast: go through table_cast.
            return table_cast(pa_table, schema)
        # No storage casts needed: rebuild the table column-by-column in schema
        # order, which is cheaper than a full cast.
        return pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)

    def _generate_tables(self, files):
        """Yield ((file_idx, batch_idx), table) pairs, one per pandas chunk."""
        features = self.config.features
        schema = features.arrow_schema if features else None
        # Per-column pandas dtypes derived from the schema; columns whose
        # feature requires a storage cast are read as raw `object` instead.
        if schema is None:
            dtype = None
        else:
            dtype = {}
            for column_name, pa_type, feature in zip(schema.names, schema.types, features.values()):
                dtype[column_name] = object if require_storage_cast(feature) else pa_type.to_pandas_dtype()
        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
            csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
            try:
                for batch_idx, df in enumerate(csv_file_reader):
                    yield (file_idx, batch_idx), self._cast_table(pa.Table.from_pandas(df))
            except ValueError as e:
                # Surface which file failed before re-raising for the caller.
                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                raise
|
|