File size: 8,568 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import itertools
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import pandas as pd
import pyarrow as pa
import datasets
import datasets.config
from datasets.features.features import require_storage_cast
from datasets.table import table_cast
from datasets.utils.py_utils import Literal
# Module logger obtained through the `datasets` logging utilities.
logger = datasets.utils.logging.get_logger(name=__name__)
# read_csv parameters without a usable default: only forwarded when the user
# changed them from CsvConfig's default value.
_PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS = [
    "names",
    "prefix",
]
# Deprecated read_csv parameters: likewise only forwarded when changed from
# CsvConfig's default value.
_PANDAS_READ_CSV_DEPRECATED_PARAMETERS = [
    "warn_bad_lines",
    "error_bad_lines",
    "mangle_dupe_cols",
]
# Parameters added in pandas 1.3; removed on older pandas versions.
_PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS = [
    "encoding_errors",
    "on_bad_lines",
]
# Parameters added in pandas 2.0; removed on older pandas versions.
_PANDAS_READ_CSV_NEW_2_0_0_PARAMETERS = [
    "date_format",
]
# Parameters deprecated in pandas 2.2; dropped there when left at their default.
_PANDAS_READ_CSV_DEPRECATED_2_2_0_PARAMETERS = [
    "verbose",
]
@dataclass
class CsvConfig(datasets.BuilderConfig):
    """BuilderConfig for CSV.

    Most fields mirror keyword arguments of ``pandas.read_csv`` and are forwarded
    through :attr:`pd_read_csv_kwargs`. ``delimiter`` and ``column_names`` are
    aliases which, when set, override ``sep`` and ``names`` respectively.
    ``features``, when set, is the schema each read batch is cast to by the builder.
    """

    sep: str = ","
    delimiter: Optional[str] = None
    header: Optional[Union[int, list[int], str]] = "infer"
    names: Optional[list[str]] = None
    column_names: Optional[list[str]] = None
    index_col: Optional[Union[int, str, list[int], list[str]]] = None
    usecols: Optional[Union[list[int], list[str]]] = None
    prefix: Optional[str] = None
    mangle_dupe_cols: bool = True
    engine: Optional[Literal["c", "python", "pyarrow"]] = None
    converters: Optional[dict[Union[int, str], Callable[[Any], Any]]] = None
    true_values: Optional[list] = None
    false_values: Optional[list] = None
    skipinitialspace: bool = False
    skiprows: Optional[Union[int, list[int]]] = None
    nrows: Optional[int] = None
    na_values: Optional[Union[str, list[str]]] = None
    keep_default_na: bool = True
    na_filter: bool = True
    verbose: bool = False
    skip_blank_lines: bool = True
    thousands: Optional[str] = None
    decimal: str = "."
    lineterminator: Optional[str] = None
    quotechar: str = '"'
    quoting: int = 0
    escapechar: Optional[str] = None
    comment: Optional[str] = None
    encoding: Optional[str] = None
    dialect: Optional[str] = None
    error_bad_lines: bool = True
    warn_bad_lines: bool = True
    skipfooter: int = 0
    doublequote: bool = True
    memory_map: bool = False
    float_precision: Optional[str] = None
    chunksize: int = 10_000
    features: Optional[datasets.Features] = None
    encoding_errors: Optional[str] = "strict"
    on_bad_lines: Literal["error", "warn", "skip"] = "error"
    date_format: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()
        # `delimiter` and `column_names` are aliases that take precedence over `sep` and `names`.
        if self.delimiter is not None:
            self.sep = self.delimiter
        if self.column_names is not None:
            self.names = self.column_names

    @property
    def pd_read_csv_kwargs(self):
        """Build the keyword arguments passed to ``pandas.read_csv``.

        Parameters without a pandas default, and deprecated parameters, are dropped
        when they still hold CsvConfig's default value. Parameters the installed
        pandas version does not know are always dropped.

        Returns:
            dict: ``read_csv`` parameter name -> value.
        """
        pd_read_csv_kwargs = {
            "sep": self.sep,
            "header": self.header,
            "names": self.names,
            "index_col": self.index_col,
            "usecols": self.usecols,
            "prefix": self.prefix,
            "mangle_dupe_cols": self.mangle_dupe_cols,
            "engine": self.engine,
            "converters": self.converters,
            "true_values": self.true_values,
            "false_values": self.false_values,
            "skipinitialspace": self.skipinitialspace,
            "skiprows": self.skiprows,
            "nrows": self.nrows,
            "na_values": self.na_values,
            "keep_default_na": self.keep_default_na,
            "na_filter": self.na_filter,
            "verbose": self.verbose,
            "skip_blank_lines": self.skip_blank_lines,
            "thousands": self.thousands,
            "decimal": self.decimal,
            "lineterminator": self.lineterminator,
            "quotechar": self.quotechar,
            "quoting": self.quoting,
            "escapechar": self.escapechar,
            "comment": self.comment,
            "encoding": self.encoding,
            "dialect": self.dialect,
            "error_bad_lines": self.error_bad_lines,
            "warn_bad_lines": self.warn_bad_lines,
            "skipfooter": self.skipfooter,
            "doublequote": self.doublequote,
            "memory_map": self.memory_map,
            "float_precision": self.float_precision,
            "chunksize": self.chunksize,
            "encoding_errors": self.encoding_errors,
            "on_bad_lines": self.on_bad_lines,
            "date_format": self.date_format,
        }

        # Compare full release tuples, e.g. (2, 1, 4). Checking major/minor separately
        # (`major >= 1 and minor >= 3`) would wrongly treat pandas 2.0-2.2 as older than 1.3.
        pandas_release = datasets.config.PANDAS_VERSION.release
        # Reference instance holding the class defaults, built once instead of per parameter.
        default_config = CsvConfig()

        # Some kwargs must not be passed if they don't have a default value;
        # deprecated ones are also not passed when left at their default value.
        for pd_read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
            if pd_read_csv_kwargs[pd_read_csv_parameter] == getattr(default_config, pd_read_csv_parameter):
                del pd_read_csv_kwargs[pd_read_csv_parameter]

        # Remove arguments introduced in pandas 1.3.
        if pandas_release < (1, 3):
            for pd_read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS:
                del pd_read_csv_kwargs[pd_read_csv_parameter]

        # Remove arguments introduced in pandas 2.0.
        if pandas_release < (2, 0):
            for pd_read_csv_parameter in _PANDAS_READ_CSV_NEW_2_0_0_PARAMETERS:
                del pd_read_csv_kwargs[pd_read_csv_parameter]

        # Remove arguments deprecated in pandas 2.2 unless the user changed them.
        if pandas_release >= (2, 2):
            for pd_read_csv_parameter in _PANDAS_READ_CSV_DEPRECATED_2_2_0_PARAMETERS:
                if pd_read_csv_kwargs[pd_read_csv_parameter] == getattr(default_config, pd_read_csv_parameter):
                    del pd_read_csv_kwargs[pd_read_csv_parameter]

        return pd_read_csv_kwargs
class Csv(datasets.ArrowBasedBuilder):
    """Arrow-based builder that reads CSV files with ``pandas.read_csv``.

    Each file is read in chunks (``CsvConfig.chunksize`` is forwarded to pandas),
    converted to an Arrow table, and cast to ``CsvConfig.features`` when it is set.
    """

    BUILDER_CONFIG_CLASS = CsvConfig

    def _info(self):
        """Return the DatasetInfo carrying the configured features (possibly None)."""
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """Create one SplitGenerator per split of ``config.data_files``.

        Each split's value (a single path string or a list of paths) is downloaded/extracted
        and wrapped with ``dl_manager.iter_files``. The resulting iterables are flattened
        later in ``_generate_tables``.

        Raises:
            ValueError: if no data files were configured.
        """
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        dl_manager.download_config.extract_on_the_fly = True
        data_files = dl_manager.download_and_extract(self.config.data_files)
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            files = [dl_manager.iter_files(file) for file in files]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        """Cast ``pa_table`` to the configured features' Arrow schema, if features are set.

        When no feature needs a storage cast, the table is rebuilt column by column
        with the target schema. Otherwise ``table_cast`` performs the full conversion.
        """
        if self.config.features is not None:
            schema = self.config.features.arrow_schema
            if all(not require_storage_cast(feature) for feature in self.config.features.values()):
                # cheaper cast
                pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
            else:
                # more expensive cast; allows str <-> int/float or str to Audio for example
                pa_table = table_cast(pa_table, schema)
        return pa_table

    def _generate_tables(self, files):
        """Yield ``((file_idx, batch_idx), pa.Table)`` for every chunk of every file.

        Args:
            files: iterable of iterables of file paths, as produced by ``_split_generators``.

        Raises:
            ValueError: re-raised (after logging) when pandas fails while reading a file.
        """
        schema = self.config.features.arrow_schema if self.config.features else None
        # dtype allows reading an int column as str; columns needing a storage cast are
        # read as `object` and converted afterwards by `_cast_table`.
        dtype = (
            {
                name: pa_type.to_pandas_dtype() if not require_storage_cast(feature) else object
                for name, pa_type, feature in zip(schema.names, schema.types, self.config.features.values())
            }
            if schema is not None
            else None
        )
        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
            csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
            try:
                for batch_idx, df in enumerate(csv_file_reader):
                    pa_table = pa.Table.from_pandas(df)
                    # Uncomment for debugging (will print the Arrow table size and elements)
                    # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
                    # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
                    yield (file_idx, batch_idx), self._cast_table(pa_table)
            except ValueError as e:
                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                raise
            finally:
                # Release the reader's file handle even when iteration stops early
                # (generator closed by the consumer) or an error is raised.
                csv_file_reader.close()
|