File size: 6,534 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import importlib
import inspect
from functools import wraps
from typing import TYPE_CHECKING, Optional
from .download.download_config import DownloadConfig
from .utils.file_utils import (
xbasename,
xdirname,
xet_parse,
xexists,
xgetsize,
xglob,
xgzip_open,
xisdir,
xisfile,
xjoin,
xlistdir,
xnumpy_load,
xopen,
xpandas_read_csv,
xpandas_read_excel,
xPath,
xpyarrow_parquet_read_table,
xrelpath,
xsio_loadmat,
xsplit,
xsplitext,
xwalk,
xxml_dom_minidom_parse,
)
from .utils.logging import get_logger
from .utils.patching import patch_submodule
from .utils.py_utils import get_imports, lock_importable_file
# Module-level logger, named after this module.
logger = get_logger(__name__)

# `DatasetBuilder` is imported only for static type checkers here; it is
# referenced as a string annotation below and imported at runtime inside
# `extend_dataset_builder_for_streaming`.
# NOTE(review): presumably this avoids a circular import with `.builder` — confirm.
if TYPE_CHECKING:
    from .builder import DatasetBuilder
def extend_module_for_streaming(module_path, download_config: Optional[DownloadConfig] = None):
    """Extend the module to support streaming.

    We patch some functions in the module to use `fsspec` to support data streaming:
    - We use `fsspec.open` to open and read remote files. We patch the module function:
      - `open`
    - We use the "::" hop separator to join paths and navigate remote compressed/archive files. We patch the module
      functions:
      - `os.path.join`
      - `pathlib.Path.joinpath` and `pathlib.Path.__truediv__` (called when using the "/" operator)

    In addition, the following names are patched in the module: `os.listdir`, `os.walk`, `glob.glob`,
    `os.path.dirname`, `os.path.basename`, `os.path.relpath`, `os.path.split`, `os.path.splitext`,
    `os.path.exists`, `os.path.isdir`, `os.path.isfile`, `os.path.getsize`, `pathlib.Path`, `gzip.open`,
    `numpy.load`, `pandas.read_csv`, `pandas.read_excel`, `scipy.io.loadmat`, `xml.etree.ElementTree.parse`,
    `xml.dom.minidom.parse` and, except for modules under `datasets.packaged_modules.`,
    `pyarrow.parquet.read_table`.

    The patched functions are replaced with custom functions defined to work with the
    :class:`~download.streaming_download_manager.StreamingDownloadManager`.

    If the module was already patched, it is not patched again. In that case, when both the stored marker
    and `download_config` are `DownloadConfig` instances, the stored `token` and `storage_options` are
    overwritten with the new values.

    Args:
        module_path: Path to the module to be extended.
        download_config: Mainly use `token` or `storage_options` to support different platforms and auth types.
            May be `None`, in which case an already patched module is left untouched.
    """
    module = importlib.import_module(module_path)

    # TODO(QL): always update the module to add subsequent new authentication without removing old ones
    patched_marker = getattr(module, "_patched_for_streaming", None)
    if patched_marker:
        # The marker is the very object captured by the `wrap_auth` closures of the first call,
        # so mutating it in place is what makes the existing wrappers see the new credentials.
        # `download_config` defaults to None: skip the refresh instead of failing on `None.token`.
        if isinstance(patched_marker, DownloadConfig) and download_config is not None:
            patched_marker.token = download_config.token
            patched_marker.storage_options = download_config.storage_options
        return

    def wrap_auth(function):
        """Return `function` wrapped so that it always receives `download_config` as a keyword argument."""

        @wraps(function)
        def wrapper(*args, **kwargs):
            return function(*args, download_config=download_config, **kwargs)

        wrapper._decorator_name_ = "wrap_auth"
        return wrapper

    # open files in a streaming fashion
    patch_submodule(module, "open", wrap_auth(xopen)).start()
    patch_submodule(module, "os.listdir", wrap_auth(xlistdir)).start()
    patch_submodule(module, "os.walk", wrap_auth(xwalk)).start()
    patch_submodule(module, "glob.glob", wrap_auth(xglob)).start()
    # allow to navigate in remote zip files
    patch_submodule(module, "os.path.join", xjoin).start()
    patch_submodule(module, "os.path.dirname", xdirname).start()
    patch_submodule(module, "os.path.basename", xbasename).start()
    patch_submodule(module, "os.path.relpath", xrelpath).start()
    patch_submodule(module, "os.path.split", xsplit).start()
    patch_submodule(module, "os.path.splitext", xsplitext).start()
    # allow checks on paths
    patch_submodule(module, "os.path.exists", wrap_auth(xexists)).start()
    patch_submodule(module, "os.path.isdir", wrap_auth(xisdir)).start()
    patch_submodule(module, "os.path.isfile", wrap_auth(xisfile)).start()
    patch_submodule(module, "os.path.getsize", wrap_auth(xgetsize)).start()
    patch_submodule(module, "pathlib.Path", xPath).start()
    # file readers
    patch_submodule(module, "gzip.open", wrap_auth(xgzip_open)).start()
    patch_submodule(module, "numpy.load", wrap_auth(xnumpy_load)).start()
    patch_submodule(module, "pandas.read_csv", wrap_auth(xpandas_read_csv), attrs=["__version__"]).start()
    patch_submodule(module, "pandas.read_excel", wrap_auth(xpandas_read_excel), attrs=["__version__"]).start()
    patch_submodule(module, "scipy.io.loadmat", wrap_auth(xsio_loadmat), attrs=["__version__"]).start()
    patch_submodule(module, "xml.etree.ElementTree.parse", wrap_auth(xet_parse)).start()
    patch_submodule(module, "xml.dom.minidom.parse", wrap_auth(xxml_dom_minidom_parse)).start()
    # pyarrow: do not patch pyarrow attribute in packaged modules
    if not module.__name__.startswith("datasets.packaged_modules."):
        patch_submodule(module, "pyarrow.parquet.read_table", wrap_auth(xpyarrow_parquet_read_table)).start()
    # NOTE(review): when `download_config` is None the marker stored here is falsy, so a later call
    # will not take the early-return branch above and will patch the module again.
    module._patched_for_streaming = download_config
def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"):
    """Patch every module a dataset builder relies on so that it can stream data.

    Three groups of modules are extended, all with the same `DownloadConfig` built from the
    builder's `storage_options` and `token`:

    1. the module that defines the builder class;
    2. for builders whose module is not under `datasets.`, the "internal" imports reported by
       `get_imports` for the builder's source file (like wmt14 -> wmt_utils);
    3. the modules of parent builder classes, other than the module of `DatasetBuilder` itself.

    Args:
        builder (:class:`DatasetBuilder`): Dataset builder instance.
    """
    streaming_config = DownloadConfig(storage_options=builder.storage_options, token=builder.token)

    # 1. the builder's own module: this extends `open`, `os.path.join`, etc. for data streaming
    extend_module_for_streaming(builder.__module__, download_config=streaming_config)

    # 2. internal imports of the builder script; packaged builders (like csv) live under `datasets.`
    #    and are skipped
    if not builder.__module__.startswith("datasets."):
        script_file = inspect.getfile(builder.__class__)
        with lock_importable_file(script_file):
            for import_entry in get_imports(script_file):
                if import_entry[0] != "internal":
                    continue
                # an internal import is a sibling of the builder module: swap the last dotted component
                package_parts = builder.__module__.split(".")[:-1]
                sibling_module_name = ".".join(package_parts + [import_entry[1]])
                extend_module_for_streaming(sibling_module_name, download_config=streaming_config)

    # 3. parent builders may implement the streaming-dependent logic themselves (for example,
    #    ImageFolder and AudioFolder inherit from FolderBuilder which implements examples generation),
    #    but they are never instantiated and hence never patched automatically
    from .builder import DatasetBuilder

    ancestor_modules = []
    for ancestor in type(builder).__mro__[1:]:  # [1:] skips the builder's own class, handled in step 1
        if not issubclass(ancestor, DatasetBuilder):
            continue
        if ancestor.__module__ == DatasetBuilder.__module__:
            # standard builder from datasets.builder: nothing to patch
            continue
        ancestor_modules.append(ancestor.__module__)

    for ancestor_module in ancestor_modules:
        extend_module_for_streaming(ancestor_module, download_config=streaming_config)
|