File size: 1,797 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from typing import Optional
import pyspark
from .. import Features, NamedSplit
from ..download import DownloadMode
from ..packaged_modules.spark.spark import Spark
from .abc import AbstractDatasetReader
class SparkDatasetReader(AbstractDatasetReader):
    """A dataset reader that reads from a Spark DataFrame.

    When caching, cache materialization is parallelized over Spark; an NFS that is accessible to the driver must be
    provided.

    Streaming is the default (``streaming=True``): in that mode ``read()`` returns the builder's streaming dataset
    and the ``download_and_prepare`` step is skipped. Pass ``streaming=False`` to prepare the data through the
    builder and load the prepared result instead.
    """

    def __init__(
        self,
        df: pyspark.sql.DataFrame,
        split: Optional[NamedSplit] = None,
        features: Optional[Features] = None,
        streaming: bool = True,
        cache_dir: Optional[str] = None,
        keep_in_memory: bool = False,
        working_dir: Optional[str] = None,
        load_from_cache_file: bool = True,
        file_format: str = "arrow",
        **kwargs,
    ):
        """Configure the reader and create the underlying ``Spark`` builder.

        Args:
            df: Spark DataFrame handed to the ``Spark`` builder.
            split: Split forwarded to the parent reader; ``read()`` later passes ``self.split`` to the builder.
            features: Features forwarded to both the parent reader and the ``Spark`` builder.
            streaming: Forwarded to the parent reader; ``read()`` branches on ``self.streaming``
                (presumably stored by the parent ``__init__`` -- confirm in ``AbstractDatasetReader``).
            cache_dir: Cache directory forwarded to both the parent reader and the ``Spark`` builder.
            keep_in_memory: Forwarded to the parent reader.
            working_dir: Working directory forwarded to the ``Spark`` builder.
            load_from_cache_file: When ``False``, ``read()`` prepares the data with
                ``DownloadMode.FORCE_REDOWNLOAD``; when ``True`` no download mode is passed (``None``).
            file_format: File format passed to ``download_and_prepare`` in non-streaming mode.
            **kwargs: Extra keyword arguments. NOTE: the same kwargs are forwarded to BOTH the parent
                reader and the ``Spark`` builder.
        """
        super().__init__(
            split=split,
            features=features,
            cache_dir=cache_dir,
            keep_in_memory=keep_in_memory,
            streaming=streaming,
            **kwargs,
        )
        # Only consulted by read() in the non-streaming path.
        self._load_from_cache_file = load_from_cache_file
        self._file_format = file_format
        self.builder = Spark(
            df=df,
            features=features,
            cache_dir=cache_dir,
            working_dir=working_dir,
            **kwargs,
        )

    def read(self):
        """Return the dataset for ``self.split`` built from the Spark DataFrame.

        Returns:
            The result of ``builder.as_streaming_dataset`` when ``self.streaming`` is truthy; otherwise the
            result of ``builder.as_dataset`` after ``builder.download_and_prepare`` has run.
        """
        if self.streaming:
            # Streaming path: return immediately, no preparation step.
            return self.builder.as_streaming_dataset(split=self.split)
        # Force a rebuild only when the caller opted out of reusing the cache.
        download_mode = None if self._load_from_cache_file else DownloadMode.FORCE_REDOWNLOAD
        self.builder.download_and_prepare(
            download_mode=download_mode,
            file_format=self._file_format,
        )
        return self.builder.as_dataset(split=self.split)
|