import os
import posixpath
import shutil
import uuid
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import pyarrow as pa

import datasets
from datasets.arrow_writer import ArrowWriter, ParquetWriter
from datasets.config import MAX_SHARD_SIZE
from datasets.filesystems import (
    is_remote_filesystem,
    rename,
)
from datasets.iterable_dataset import _BaseExamplesIterable
from datasets.utils import experimental
from datasets.utils.py_utils import convert_file_size_to_int


logger = datasets.utils.logging.get_logger(__name__)


if TYPE_CHECKING:
    import pyspark
    import pyspark.sql


@dataclass
class SparkConfig(datasets.BuilderConfig):
    """BuilderConfig for Spark."""

    features: Optional[datasets.Features] = None

    def __post_init__(self):
        super().__post_init__()
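

# The DataFrame is filtered by Spark partition id ("part_id") and the per-partition frames are unioned
# back together in the requested order, so that a later toLocalIterator() walks the partitions in
# exactly that order.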
def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: list[int]):
    df_combined = df.select("*").where(f"part_id = {new_partition_order[0]}")
    for partition_id in new_partition_order[1:]:
        partition_df = df.select("*").where(f"part_id = {partition_id}")
        df_combined = df_combined.union(partition_df)
    return df_combined
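

# _generate_iterable_examples streams rows to the driver one partition at a time and yields
# (key, example) pairs with keys of the form "{part_id}_{row_id}". The optional state_dict records the
# current partition index ("partition_idx") and row index ("partition_example_idx") so that iteration
# can be resumed from a checkpoint.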
def _generate_iterable_examples(
    df: "pyspark.sql.DataFrame",
    partition_order: list[int],
    state_dict: Optional[dict] = None,
):
    import pyspark

    df_with_partition_id = df.select("*", pyspark.sql.functions.spark_partition_id().alias("part_id"))
    partition_idx_start = state_dict["partition_idx"] if state_dict else 0
    partition_df = _reorder_dataframe_by_partition(df_with_partition_id, partition_order[partition_idx_start:])

    # Prefetch the next partition in the background to hide the latency of fetching it to the driver.
    rows = partition_df.toLocalIterator(prefetchPartitions=True)
    curr_partition = None
    row_id = state_dict["partition_example_idx"] if state_dict else 0
    for row in islice(rows, row_id, None):
        row_as_dict = row.asDict()
        part_id = row_as_dict["part_id"]
        row_as_dict.pop("part_id")
        if curr_partition != part_id:
            if state_dict and curr_partition is not None:
                state_dict["partition_idx"] += 1
            curr_partition = part_id
            row_id = 0
        if state_dict:
            state_dict["partition_example_idx"] = row_id + 1
        yield f"{part_id}_{row_id}", row_as_dict
        row_id += 1
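

# Each Spark partition is treated as one shard: shuffling permutes the partition order, and sharding
# hands subsets of the partitions to different workers.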
class SparkExamplesIterable(_BaseExamplesIterable):
    def __init__(
        self,
        df: "pyspark.sql.DataFrame",
        partition_order=None,
    ):
        super().__init__()
        self.df = df
        self.partition_order = partition_order or range(self.df.rdd.getNumPartitions())

    def _init_state_dict(self) -> dict:
        self._state_dict = {"partition_idx": 0, "partition_example_idx": 0}
        return self._state_dict

    @experimental
    def load_state_dict(self, state_dict: dict) -> dict:
        return super().load_state_dict(state_dict)

    def __iter__(self):
        yield from _generate_iterable_examples(self.df, self.partition_order, self._state_dict)

    def shuffle_data_sources(self, generator: np.random.Generator) -> "SparkExamplesIterable":
        partition_order = list(range(self.df.rdd.getNumPartitions()))
        generator.shuffle(partition_order)
        return SparkExamplesIterable(self.df, partition_order=partition_order)

    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SparkExamplesIterable":
        partition_order = self.split_shard_indices_by_worker(num_shards=num_shards, index=index, contiguous=contiguous)
        return SparkExamplesIterable(self.df, partition_order=partition_order)

    @property
    def num_shards(self) -> int:
        return len(self.partition_order)
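

# Builder that materializes a Spark DataFrame as a Hugging Face dataset. It is normally not
# instantiated directly; it is the builder behind `Dataset.from_spark`. A minimal sketch, for
# illustration only (the `spark` session variable and the example arguments are assumptions, not part
# of this module):
#
#     from datasets import Dataset
#
#     df = spark.createDataFrame([("positive", 1), ("negative", 0)], ["text", "label"])
#     ds = Dataset.from_spark(df, cache_dir="/shared/cache")
#
# The config name is derived from `df.semanticHash()`, so the same logical query plan maps to the same
# cache entry.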
class Spark(datasets.DatasetBuilder):
    BUILDER_CONFIG_CLASS = SparkConfig

    def __init__(
        self,
        df: "pyspark.sql.DataFrame",
        cache_dir: Optional[str] = None,
        working_dir: Optional[str] = None,
        **config_kwargs,
    ):
        import pyspark

        self._spark = pyspark.sql.SparkSession.builder.getOrCreate()
        self.df = df
        self._working_dir = working_dir

        super().__init__(
            cache_dir=cache_dir,
            config_name=str(self.df.semanticHash()),
            **config_kwargs,
        )
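
    # On a multi-node cluster the cache directory must live on storage that is shared between the driver
    # and the workers (e.g. NFS): a probe file is written from a worker task and then checked for
    # existence on the driver.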
    def _validate_cache_dir(self):
        # Bind this to a local variable so that the closure below captures the path rather than `self`.
        cache_dir = self._cache_dir

        def create_cache_and_write_probe(context):
            # makedirs with exist_ok=True recursively creates the directory and does not raise if it
            # already exists.
            os.makedirs(cache_dir, exist_ok=True)
            probe_file = os.path.join(cache_dir, "fs_test" + uuid.uuid4().hex)
            # Opening the file in append mode creates it if it does not exist and leaves any existing
            # contents untouched.
            open(probe_file, "a")
            return [probe_file]

        if self._spark.conf.get("spark.master", "").startswith("local"):
            return

        if self._cache_dir:
            probe = (
                self._spark.sparkContext.parallelize(range(1), 1).mapPartitions(create_cache_and_write_probe).collect()
            )
            if os.path.isfile(probe[0]):
                return

        raise ValueError(
            "When using Dataset.from_spark on a multi-node cluster, the driver and all workers should be able to access cache_dir"
        )

    def _info(self):
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager: datasets.download.download_manager.DownloadManager):
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN)]
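
    # Estimate the dataset size from a sample of at most 100 rows (measured as Arrow batch bytes) and,
    # if the estimate exceeds max_shard_size, repartition the DataFrame so that each partition should
    # produce roughly one shard of at most max_shard_size bytes.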
    def _repartition_df_if_needed(self, max_shard_size):
        import pyspark

        def get_arrow_batch_size(it):
            for batch in it:
                yield pa.RecordBatch.from_pydict({"batch_bytes": [batch.nbytes]})

        df_num_rows = self.df.count()
        sample_num_rows = df_num_rows if df_num_rows <= 100 else 100

        approx_bytes_per_row = (
            self.df.limit(sample_num_rows)
            .repartition(1)
            .mapInArrow(get_arrow_batch_size, "batch_bytes: long")
            .agg(pyspark.sql.functions.sum("batch_bytes").alias("sample_bytes"))
            .collect()[0]
            .sample_bytes
            / sample_num_rows
        )
        approx_total_size = approx_bytes_per_row * df_num_rows
        if approx_total_size > max_shard_size:
            # Cap the number of partitions at the number of rows so that every partition gets at least one row.
            new_num_partitions = min(df_num_rows, int(approx_total_size / max_shard_size))
            self.df = self.df.repartition(new_num_partitions)
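
    # Each Spark task writes its partition to one or more shard files (rolling over to a new file once
    # max_shard_size is reached) and emits one (task_id, num_examples, num_bytes) row per finished shard;
    # the rows are then aggregated per task on the driver.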
    def _prepare_split_single(
        self,
        fpath: str,
        file_format: str,
        max_shard_size: int,
    ) -> Iterable[tuple[int, tuple[int, int, int, list[int]]]]:
        import pyspark

        writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
        working_fpath = os.path.join(self._working_dir, os.path.basename(fpath)) if self._working_dir else fpath
        embed_local_files = file_format == "parquet"

        # Bind these to local variables so that write_arrow does not reference `self`, which would fail to
        # pickle because of the attached SparkContext.
        features = self.config.features
        writer_batch_size = self._writer_batch_size
        storage_options = self._fs.storage_options

        def write_arrow(it):
            # Task attempt ids are unique within a SparkContext, so they give collision-free file names.
            task_id = pyspark.TaskContext().taskAttemptId()
            first_batch = next(it, None)
            if first_batch is None:
                # Some partitions may be empty; report zero examples and bytes for this task.
                yield pa.RecordBatch.from_arrays(
                    [[task_id], [0], [0]],
                    names=["task_id", "num_examples", "num_bytes"],
                )
                return
            shard_id = 0
            writer = writer_class(
                features=features,
                path=working_fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"),
                writer_batch_size=writer_batch_size,
                storage_options=storage_options,
                embed_local_files=embed_local_files,
            )
            table = pa.Table.from_batches([first_batch])
            writer.write_table(table)
            for batch in it:
                if max_shard_size is not None and writer._num_bytes >= max_shard_size:
                    num_examples, num_bytes = writer.finalize()
                    writer.close()
                    yield pa.RecordBatch.from_arrays(
                        [[task_id], [num_examples], [num_bytes]],
                        names=["task_id", "num_examples", "num_bytes"],
                    )
                    shard_id += 1
                    writer = writer_class(
                        features=writer._features,
                        path=working_fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"),
                        writer_batch_size=writer_batch_size,
                        storage_options=storage_options,
                        embed_local_files=embed_local_files,
                    )
                table = pa.Table.from_batches([batch])
                writer.write_table(table)

            if writer._num_bytes > 0:
                num_examples, num_bytes = writer.finalize()
                writer.close()
                yield pa.RecordBatch.from_arrays(
                    [[task_id], [num_examples], [num_bytes]],
                    names=["task_id", "num_examples", "num_bytes"],
                )

            if working_fpath != fpath:
                # Move the shards written to the local working directory to their final location.
                for file in os.listdir(os.path.dirname(working_fpath)):
                    dest = os.path.join(os.path.dirname(fpath), os.path.basename(file))
                    shutil.move(os.path.join(os.path.dirname(working_fpath), file), dest)

        stats = (
            self.df.mapInArrow(write_arrow, "task_id: long, num_examples: long, num_bytes: long")
            .groupBy("task_id")
            .agg(
                pyspark.sql.functions.sum("num_examples").alias("total_num_examples"),
                pyspark.sql.functions.sum("num_bytes").alias("total_num_bytes"),
                pyspark.sql.functions.count("num_bytes").alias("num_shards"),
                pyspark.sql.functions.collect_list("num_examples").alias("shard_lengths"),
            )
            .collect()
        )
        for row in stats:
            yield row.task_id, (row.total_num_examples, row.total_num_bytes, row.num_shards, row.shard_lengths)
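
    # After all tasks have written their shards, the task-scoped file names ("TTTTT-SSSSS") are renamed
    # to a single global, contiguous shard numbering ("00000-of-NNNNN"). With many shards the renames are
    # distributed over the cluster; with a single shard the pattern is simply dropped from the file name.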
    def _prepare_split(
        self,
        split_generator: "datasets.SplitGenerator",
        file_format: str = "arrow",
        max_shard_size: Optional[Union[str, int]] = None,
        num_proc: Optional[int] = None,
        **kwargs,
    ):
        self._validate_cache_dir()

        max_shard_size = convert_file_size_to_int(max_shard_size or MAX_SHARD_SIZE)
        self._repartition_df_if_needed(max_shard_size)
        is_local = not is_remote_filesystem(self._fs)
        path_join = os.path.join if is_local else posixpath.join

        SUFFIX = "-TTTTT-SSSSS-of-NNNNN"
        fname = f"{self.name}-{split_generator.name}{SUFFIX}.{file_format}"
        fpath = path_join(self._output_dir, fname)

        total_num_examples = 0
        total_num_bytes = 0
        total_shards = 0
        task_id_and_num_shards = []
        all_shard_lengths = []

        for task_id, content in self._prepare_split_single(fpath, file_format, max_shard_size):
            (
                num_examples,
                num_bytes,
                num_shards,
                shard_lengths,
            ) = content
            if num_bytes > 0:
                total_num_examples += num_examples
                total_num_bytes += num_bytes
                total_shards += num_shards
                task_id_and_num_shards.append((task_id, num_shards))
                all_shard_lengths.extend(shard_lengths)

        split_generator.split_info.num_examples = total_num_examples
        split_generator.split_info.num_bytes = total_num_bytes

        logger.debug(f"Renaming {total_shards} shards.")
        if total_shards > 1:
            split_generator.split_info.shard_lengths = all_shard_lengths

            # Bind fs to a local variable so that _rename_shard does not reference `self`, which would fail
            # to pickle because of the attached SparkContext.
            fs = self._fs

            def _rename_shard(
                task_id: int,
                shard_id: int,
                global_shard_id: int,
            ):
                rename(
                    fs,
                    fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"),
                    fpath.replace("TTTTT-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"),
                )

            args = []
            global_shard_id = 0
            for i in range(len(task_id_and_num_shards)):
                task_id, num_shards = task_id_and_num_shards[i]
                for shard_id in range(num_shards):
                    args.append([task_id, shard_id, global_shard_id])
                    global_shard_id += 1
            self._spark.sparkContext.parallelize(args, len(args)).map(lambda args: _rename_shard(*args)).collect()
        else:
            # There is only one shard: drop the shard pattern from the file name entirely.
            shard_id = 0
            task_id = task_id_and_num_shards[0][0]
            self._rename(
                fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"),
                fpath.replace(SUFFIX, ""),
            )

    def _get_examples_iterable_for_split(
        self,
        split_generator: "datasets.SplitGenerator",
    ) -> SparkExamplesIterable:
        return SparkExamplesIterable(self.df)