File size: 14,689 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 |
import os
import posixpath
import uuid
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import pyarrow as pa
import datasets
from datasets.arrow_writer import ArrowWriter, ParquetWriter
from datasets.config import MAX_SHARD_SIZE
from datasets.filesystems import (
is_remote_filesystem,
rename,
)
from datasets.iterable_dataset import _BaseExamplesIterable
from datasets.utils import experimental
from datasets.utils.py_utils import convert_file_size_to_int
logger = datasets.utils.logging.get_logger(__name__)
if TYPE_CHECKING:
import pyspark
import pyspark.sql
@dataclass
class SparkConfig(datasets.BuilderConfig):
    """Builder configuration for the Spark dataset builder.

    Extends ``datasets.BuilderConfig`` with a single optional ``features``
    field describing the dataset schema.
    """

    # Optional explicit feature schema; defaults to ``None``.
    features: Optional[datasets.Features] = None

    def __post_init__(self):
        # No Spark-specific processing: simply run the parent's post-init hook.
        super().__post_init__()
def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: list[int]):
df_combined = df.select("*").where(f"part_id = {new_partition_order[0]}")
for partition_id in new_partition_order[1:]:
partition_df = df.select("*").where(f"part_id = {partition_id}")
df_combined = df_combined.union(partition_df)
return df_combined
def _generate_iterable_examples(
    df: "pyspark.sql.DataFrame",
    partition_order: list[int],
    state_dict: Optional[dict] = None,
):
    """Yield ``(key, example)`` pairs from ``df``, one partition after another.

    Each key has the form ``"{part_id}_{row_id}"`` where ``row_id`` is the
    position of the row inside its partition. Each example is the row as a
    dict, with the helper ``part_id`` column removed.

    Args:
        df: Source DataFrame.
        partition_order: Partition ids in the order they should be read.
        state_dict: Optional resumable state with the keys ``"partition_idx"``
            (index into ``partition_order`` to start from) and
            ``"partition_example_idx"`` (number of rows to skip at the start).
            When provided and non-empty, it is updated in place while
            iterating.

    Yields:
        Tuples of ``(key, row_as_dict)``.
    """
    import pyspark

    df_with_partition_id = df.select("*", pyspark.sql.functions.spark_partition_id().alias("part_id"))
    partition_idx_start = state_dict["partition_idx"] if state_dict else 0
    partition_df = _reorder_dataframe_by_partition(df_with_partition_id, partition_order[partition_idx_start:])
    # pipeline next partition in parallel to hide latency
    rows = partition_df.toLocalIterator(prefetchPartitions=True)
    curr_partition = None
    # When resuming, this is both the number of rows to skip and the in-partition
    # index of the first row that will be yielded.
    row_id = state_dict["partition_example_idx"] if state_dict else 0
    for row in islice(rows, row_id, None):
        row_as_dict = row.asDict()
        part_id = row_as_dict.pop("part_id")
        if curr_partition != part_id:
            if curr_partition is not None:
                # Genuine transition to the next partition: advance the saved
                # partition index and restart the in-partition row counter.
                if state_dict:
                    state_dict["partition_idx"] += 1
                row_id = 0
            # On the very first row ``row_id`` is intentionally left untouched so
            # that a resumed iteration keeps its restored offset instead of
            # restarting from 0.
            curr_partition = part_id
        if state_dict:
            state_dict["partition_example_idx"] = row_id + 1
        yield f"{part_id}_{row_id}", row_as_dict
        row_id += 1
class SparkExamplesIterable(_BaseExamplesIterable):
    """Examples iterable backed by a Spark DataFrame.

    Each Spark partition is treated as one shard; ``partition_order`` holds
    the partition ids in the order in which they are read.
    """

    def __init__(
        self,
        df: "pyspark.sql.DataFrame",
        partition_order=None,
    ):
        """Store the DataFrame and the order in which to read its partitions.

        Args:
            df: Source DataFrame.
            partition_order: Partition ids to read, in order. Any falsy value
                (``None`` or an empty sequence) selects every partition of
                ``df`` in natural order.
        """
        super().__init__()
        self.df = df
        if partition_order:
            self.partition_order = partition_order
        else:
            self.partition_order = range(self.df.rdd.getNumPartitions())

    def _init_state_dict(self) -> dict:
        """Create, store and return a fresh iteration state."""
        fresh_state = {"partition_idx": 0, "partition_example_idx": 0}
        self._state_dict = fresh_state
        return fresh_state

    @experimental
    def load_state_dict(self, state_dict: dict) -> dict:
        """Load ``state_dict`` by delegating to the parent implementation."""
        return super().load_state_dict(state_dict)

    def __iter__(self):
        """Yield ``(key, example)`` pairs following ``self.partition_order``."""
        examples = _generate_iterable_examples(self.df, self.partition_order, self._state_dict)
        yield from examples

    def shuffle_data_sources(self, generator: np.random.Generator) -> "SparkExamplesIterable":
        """Return a new iterable over all partitions of ``df`` in shuffled order.

        Args:
            generator: NumPy random generator used to shuffle the partition ids.
        """
        shuffled_order = list(range(self.df.rdd.getNumPartitions()))
        generator.shuffle(shuffled_order)
        return SparkExamplesIterable(self.df, partition_order=shuffled_order)

    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SparkExamplesIterable":
        """Return a new iterable restricted to the partitions assigned to one worker.

        The assignment itself is computed by ``split_shard_indices_by_worker``.
        """
        worker_partitions = self.split_shard_indices_by_worker(
            num_shards=num_shards, index=index, contiguous=contiguous
        )
        return SparkExamplesIterable(self.df, partition_order=worker_partitions)

    @property
    def num_shards(self) -> int:
        """Number of partitions this iterable reads."""
        return len(self.partition_order)
class Spark(datasets.DatasetBuilder):
    """Dataset builder that writes a Spark DataFrame out as Arrow or Parquet shards."""

    BUILDER_CONFIG_CLASS = SparkConfig

    def __init__(
        self,
        df: "pyspark.sql.DataFrame",
        cache_dir: Optional[str] = None,
        working_dir: Optional[str] = None,
        **config_kwargs,
    ):
        """Create the builder for ``df``.

        Args:
            df: DataFrame to materialize. Its semantic hash is used as the
                config name.
            cache_dir: Cache directory forwarded to the parent builder.
            working_dir: Optional directory in which shards are first written
                before being moved next to their final path.
            **config_kwargs: Extra keyword arguments forwarded to the parent
                builder.
        """
        import pyspark

        self._spark = pyspark.sql.SparkSession.builder.getOrCreate()
        self.df = df
        self._working_dir = working_dir

        super().__init__(
            cache_dir=cache_dir,
            config_name=str(self.df.semanticHash()),
            **config_kwargs,
        )

    def _validate_cache_dir(self):
        """Check that the cache directory written by a worker is visible from the driver.

        Nothing is checked when ``spark.master`` starts with ``"local"``.

        Raises:
            ValueError: If no cache directory is set, or if the probe file
                written by a worker cannot be found from the driver.
        """
        # Define this so that we don't reference self in create_cache_and_write_probe, which will result in a pickling
        # error due to pickling the SparkContext.
        cache_dir = self._cache_dir

        # Returns the path of the created file.
        def create_cache_and_write_probe(context):
            # makedirs with exist_ok will recursively create the directory. It will not throw an error if directories
            # already exist.
            os.makedirs(cache_dir, exist_ok=True)
            probe_file = os.path.join(cache_dir, "fs_test" + uuid.uuid4().hex)
            # Opening the file in append mode will create a new file unless it already exists, in which case it will not
            # change the file contents. The context manager closes the handle right away instead of leaking it.
            with open(probe_file, "a"):
                pass
            return [probe_file]

        if self._spark.conf.get("spark.master", "").startswith("local"):
            return

        # If the cluster is multi-node, make sure that the user provided a cache_dir and that it is on an NFS
        # accessible to the driver.
        # TODO: Stream batches to the driver using ArrowCollectSerializer instead of throwing an error.
        if self._cache_dir:
            probe = (
                self._spark.sparkContext.parallelize(range(1), 1).mapPartitions(create_cache_and_write_probe).collect()
            )
            if os.path.isfile(probe[0]):
                return

        raise ValueError(
            "When using Dataset.from_spark on a multi-node cluster, the driver and all workers should be able to access cache_dir"
        )

    def _info(self):
        """Return a ``DatasetInfo`` carrying the features declared in the config."""
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager: datasets.download.download_manager.DownloadManager):
        """Return a single ``TRAIN`` split generator; ``dl_manager`` is not used."""
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN)]

    def _repartition_df_if_needed(self, max_shard_size):
        """Repartition ``self.df`` when its estimated Arrow size exceeds ``max_shard_size``.

        The total size is estimated from the Arrow size of a sample of at most
        100 rows. An empty DataFrame is left untouched.

        Args:
            max_shard_size: Size threshold, in bytes.
        """
        import pyspark

        def get_arrow_batch_size(it):
            for batch in it:
                yield pa.RecordBatch.from_pydict({"batch_bytes": [batch.nbytes]})

        df_num_rows = self.df.count()
        if df_num_rows == 0:
            # Nothing to size or repartition; also avoids dividing by a zero-row sample below.
            return
        sample_num_rows = min(df_num_rows, 100)
        # Approximate the size of each row (in Arrow format) by averaging over a max-100-row sample.
        approx_bytes_per_row = (
            self.df.limit(sample_num_rows)
            .repartition(1)
            .mapInArrow(get_arrow_batch_size, "batch_bytes: long")
            .agg(pyspark.sql.functions.sum("batch_bytes").alias("sample_bytes"))
            .collect()[0]
            .sample_bytes
            / sample_num_rows
        )
        approx_total_size = approx_bytes_per_row * df_num_rows
        if approx_total_size > max_shard_size:
            # Make sure there is at least one row per partition.
            new_num_partitions = min(df_num_rows, int(approx_total_size / max_shard_size))
            self.df = self.df.repartition(new_num_partitions)

    def _prepare_split_single(
        self,
        fpath: str,
        file_format: str,
        max_shard_size: int,
    ) -> Iterable[tuple[int, tuple[int, int, int, list[int]]]]:
        """Write the DataFrame to shard files and yield per-task statistics.

        Args:
            fpath: Output path template containing the ``TTTTT`` (task id) and
                ``SSSSS`` (shard id within the task) placeholders.
            file_format: ``"parquet"`` selects ``ParquetWriter``; any other
                value selects ``ArrowWriter``.
            max_shard_size: Byte count after which a task starts a new shard.

        Yields:
            ``(task_id, (total_num_examples, total_num_bytes, num_shards, shard_lengths))``
            for every task that reported statistics.
        """
        import shutil

        import pyspark

        writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
        working_fpath = os.path.join(self._working_dir, os.path.basename(fpath)) if self._working_dir else fpath
        embed_local_files = file_format == "parquet"

        # Define these so that we don't reference self in write_arrow, which will result in a pickling error due to
        # pickling the SparkContext.
        features = self.config.features
        writer_batch_size = self._writer_batch_size
        storage_options = self._fs.storage_options

        def write_arrow(it):
            # Within the same SparkContext, no two task attempts will share the same attempt ID.
            task_id = pyspark.TaskContext().taskAttemptId()
            first_batch = next(it, None)
            if first_batch is None:
                # Some partitions might not receive any data. Because this function is a generator, the
                # placeholder record has to be yielded: a returned value would be silently discarded.
                yield pa.RecordBatch.from_arrays(
                    [[task_id], [0], [0]],
                    names=["task_id", "num_examples", "num_bytes"],
                )
                return
            shard_id = 0
            writer = writer_class(
                features=features,
                path=working_fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"),
                writer_batch_size=writer_batch_size,
                storage_options=storage_options,
                embed_local_files=embed_local_files,
            )
            table = pa.Table.from_batches([first_batch])
            writer.write_table(table)
            for batch in it:
                if max_shard_size is not None and writer._num_bytes >= max_shard_size:
                    # Current shard is full: finalize it, report it, and open the next one.
                    num_examples, num_bytes = writer.finalize()
                    writer.close()
                    yield pa.RecordBatch.from_arrays(
                        [[task_id], [num_examples], [num_bytes]],
                        names=["task_id", "num_examples", "num_bytes"],
                    )
                    shard_id += 1
                    writer = writer_class(
                        features=writer._features,
                        path=working_fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"),
                        writer_batch_size=writer_batch_size,
                        storage_options=storage_options,
                        embed_local_files=embed_local_files,
                    )
                table = pa.Table.from_batches([batch])
                writer.write_table(table)

            if writer._num_bytes > 0:
                num_examples, num_bytes = writer.finalize()
                writer.close()
                yield pa.RecordBatch.from_arrays(
                    [[task_id], [num_examples], [num_bytes]],
                    names=["task_id", "num_examples", "num_bytes"],
                )

            if working_fpath != fpath:
                working_dir = os.path.dirname(working_fpath)
                output_dir = os.path.dirname(fpath)
                for file_name in os.listdir(working_dir):
                    # os.listdir returns bare names, so the source must be re-anchored to the working
                    # directory; otherwise it would be resolved against the current working directory.
                    shutil.move(os.path.join(working_dir, file_name), os.path.join(output_dir, file_name))

        stats = (
            self.df.mapInArrow(write_arrow, "task_id: long, num_examples: long, num_bytes: long")
            .groupBy("task_id")
            .agg(
                pyspark.sql.functions.sum("num_examples").alias("total_num_examples"),
                pyspark.sql.functions.sum("num_bytes").alias("total_num_bytes"),
                pyspark.sql.functions.count("num_bytes").alias("num_shards"),
                pyspark.sql.functions.collect_list("num_examples").alias("shard_lengths"),
            )
            .collect()
        )
        for row in stats:
            yield row.task_id, (row.total_num_examples, row.total_num_bytes, row.num_shards, row.shard_lengths)

    def _prepare_split(
        self,
        split_generator: "datasets.SplitGenerator",
        file_format: str = "arrow",
        max_shard_size: Optional[Union[str, int]] = None,
        num_proc: Optional[int] = None,
        **kwargs,
    ):
        """Write the split to disk and rename the shards to their final names.

        Args:
            split_generator: Split being prepared; its ``split_info`` is
                updated with example and byte counts (and shard lengths when
                more than one shard is written).
            file_format: Extension and format of the output files.
            max_shard_size: Maximum shard size; falls back to
                ``MAX_SHARD_SIZE`` when falsy.
            num_proc: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        self._validate_cache_dir()

        max_shard_size = convert_file_size_to_int(max_shard_size or MAX_SHARD_SIZE)
        self._repartition_df_if_needed(max_shard_size)
        is_local = not is_remote_filesystem(self._fs)
        path_join = os.path.join if is_local else posixpath.join

        # TTTTT: task id, SSSSS: shard id within the task, NNNNN: total number of shards.
        SUFFIX = "-TTTTT-SSSSS-of-NNNNN"
        fname = f"{self.name}-{split_generator.name}{SUFFIX}.{file_format}"
        fpath = path_join(self._output_dir, fname)

        total_num_examples = 0
        total_num_bytes = 0
        total_shards = 0
        task_id_and_num_shards = []
        all_shard_lengths = []

        for task_id, content in self._prepare_split_single(fpath, file_format, max_shard_size):
            (
                num_examples,
                num_bytes,
                num_shards,
                shard_lengths,
            ) = content
            # Tasks that wrote no data are ignored.
            if num_bytes > 0:
                total_num_examples += num_examples
                total_num_bytes += num_bytes
                total_shards += num_shards
                task_id_and_num_shards.append((task_id, num_shards))
                all_shard_lengths.extend(shard_lengths)

        split_generator.split_info.num_examples = total_num_examples
        split_generator.split_info.num_bytes = total_num_bytes

        # should rename everything at the end
        logger.debug(f"Renaming {total_shards} shards.")
        if total_shards > 1:
            split_generator.split_info.shard_lengths = all_shard_lengths

            # Define fs outside of _rename_shard so that we don't reference self in the function, which will result in a
            # pickling error due to pickling the SparkContext.
            fs = self._fs

            # use the -SSSSS-of-NNNNN pattern
            def _rename_shard(
                task_id: int,
                shard_id: int,
                global_shard_id: int,
            ):
                rename(
                    fs,
                    fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"),
                    fpath.replace("TTTTT-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"),
                )

            # Assign consecutive global shard ids, task after task.
            args = []
            global_shard_id = 0
            for task_id, num_shards in task_id_and_num_shards:
                for shard_id in range(num_shards):
                    args.append([task_id, shard_id, global_shard_id])
                    global_shard_id += 1
            self._spark.sparkContext.parallelize(args, len(args)).map(
                lambda shard_args: _rename_shard(*shard_args)
            ).collect()
        else:
            # don't use any pattern
            shard_id = 0
            task_id = task_id_and_num_shards[0][0]
            self._rename(
                fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"),
                fpath.replace(SUFFIX, ""),
            )

    def _get_examples_iterable_for_split(
        self,
        split_generator: "datasets.SplitGenerator",
    ) -> SparkExamplesIterable:
        """Return an examples iterable over ``self.df``; ``split_generator`` is not used."""
        return SparkExamplesIterable(self.df)
|