File size: 6,534 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import importlib
import inspect
from functools import wraps
from typing import TYPE_CHECKING, Optional

from .download.download_config import DownloadConfig
from .utils.file_utils import (
    xbasename,
    xdirname,
    xet_parse,
    xexists,
    xgetsize,
    xglob,
    xgzip_open,
    xisdir,
    xisfile,
    xjoin,
    xlistdir,
    xnumpy_load,
    xopen,
    xpandas_read_csv,
    xpandas_read_excel,
    xPath,
    xpyarrow_parquet_read_table,
    xrelpath,
    xsio_loadmat,
    xsplit,
    xsplitext,
    xwalk,
    xxml_dom_minidom_parse,
)
from .utils.logging import get_logger
from .utils.patching import patch_submodule
from .utils.py_utils import get_imports, lock_importable_file


# Module-level logger, obtained from the package logging helper with this module's name.
logger = get_logger(__name__)


# `DatasetBuilder` is imported here only for static type checkers (it is used as a string annotation below).
# The runtime import is done lazily inside `extend_dataset_builder_for_streaming`
# (presumably to avoid a circular import with `.builder` -- TODO confirm).
if TYPE_CHECKING:
    from .builder import DatasetBuilder


def extend_module_for_streaming(module_path, download_config: Optional[DownloadConfig] = None):
    """Extend the module to support streaming.

    We patch some functions in the module to use `fsspec` to support data streaming:
    - We use `fsspec.open` to open and read remote files. We patch the module function:
      - `open`
    - We use the "::" hop separator to join paths and navigate remote compressed/archive files. We patch the module
      functions:
      - `os.path.join`
      - `pathlib.Path.joinpath` and `pathlib.Path.__truediv__` (called when using the "/" operator)
    - We also patch directory listing (`os.listdir`, `os.walk`, `glob.glob`), the other `os.path` helpers
      (`dirname`, `basename`, `relpath`, `split`, `splitext`, `exists`, `isdir`, `isfile`, `getsize`) and the file
      readers `gzip.open`, `numpy.load`, `pandas.read_csv`, `pandas.read_excel`, `scipy.io.loadmat`,
      `xml.etree.ElementTree.parse`, `xml.dom.minidom.parse` and `pyarrow.parquet.read_table`.

    The patched functions are replaced with custom functions defined to work with the
    :class:`~download.streaming_download_manager.StreamingDownloadManager`.

    Once patched, the module is flagged through its `_patched_for_streaming` attribute, which stores
    `download_config`. Calling this function again on a flagged module does not patch it a second time: it only
    copies `token` and `storage_options` from the new `download_config` into the stored one. If no
    `download_config` is passed on such a later call, the stored credentials are left untouched.

    Args:
        module_path: Path to the module to be extended.
        download_config: Mainly use `token` or `storage_options` to support different platforms and auth types.
    """

    module = importlib.import_module(module_path)

    # TODO(QL): always update the module to add subsequent new authentication without removing old ones
    stored_config = getattr(module, "_patched_for_streaming", None)
    if stored_config:
        # The wrappers installed by the first call closed over this very object, so the new credentials
        # must be written into it in place (rebinding the attribute would not reach the wrappers).
        # `download_config` is optional: without one there is nothing to copy, so keep what is stored.
        if isinstance(stored_config, DownloadConfig) and download_config is not None:
            stored_config.token = download_config.token
            stored_config.storage_options = download_config.storage_options
        return

    def wrap_auth(function):
        """Return `function` wrapped so that every call receives `download_config` as a keyword argument."""

        @wraps(function)
        def wrapper(*args, **kwargs):
            return function(*args, download_config=download_config, **kwargs)

        # tag the wrapper with the name of the decorator that produced it
        wrapper._decorator_name_ = "wrap_auth"
        return wrapper

    # open files in a streaming fashion
    patch_submodule(module, "open", wrap_auth(xopen)).start()
    patch_submodule(module, "os.listdir", wrap_auth(xlistdir)).start()
    patch_submodule(module, "os.walk", wrap_auth(xwalk)).start()
    patch_submodule(module, "glob.glob", wrap_auth(xglob)).start()
    # allow to navigate in remote zip files (pure path manipulation: no download_config needed)
    patch_submodule(module, "os.path.join", xjoin).start()
    patch_submodule(module, "os.path.dirname", xdirname).start()
    patch_submodule(module, "os.path.basename", xbasename).start()
    patch_submodule(module, "os.path.relpath", xrelpath).start()
    patch_submodule(module, "os.path.split", xsplit).start()
    patch_submodule(module, "os.path.splitext", xsplitext).start()
    # allow checks on paths
    patch_submodule(module, "os.path.exists", wrap_auth(xexists)).start()
    patch_submodule(module, "os.path.isdir", wrap_auth(xisdir)).start()
    patch_submodule(module, "os.path.isfile", wrap_auth(xisfile)).start()
    patch_submodule(module, "os.path.getsize", wrap_auth(xgetsize)).start()
    patch_submodule(module, "pathlib.Path", xPath).start()
    # file readers
    patch_submodule(module, "gzip.open", wrap_auth(xgzip_open)).start()
    patch_submodule(module, "numpy.load", wrap_auth(xnumpy_load)).start()
    patch_submodule(module, "pandas.read_csv", wrap_auth(xpandas_read_csv), attrs=["__version__"]).start()
    patch_submodule(module, "pandas.read_excel", wrap_auth(xpandas_read_excel), attrs=["__version__"]).start()
    patch_submodule(module, "scipy.io.loadmat", wrap_auth(xsio_loadmat), attrs=["__version__"]).start()
    patch_submodule(module, "xml.etree.ElementTree.parse", wrap_auth(xet_parse)).start()
    patch_submodule(module, "xml.dom.minidom.parse", wrap_auth(xxml_dom_minidom_parse)).start()
    # pyarrow: do not patch pyarrow attribute in packaged modules
    if not module.__name__.startswith("datasets.packaged_modules."):
        patch_submodule(module, "pyarrow.parquet.read_table", wrap_auth(xpyarrow_parquet_read_table)).start()
    # Flag the module as patched and keep the config so later calls can refresh its credentials.
    # NOTE: when `download_config` is None the flag is falsy, so a later call patches the module again.
    module._patched_for_streaming = download_config


def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"):
    """Patch everything a dataset builder relies on so that it can be used in streaming mode.

    Three groups of modules are passed to `extend_module_for_streaming`, all with one shared
    `DownloadConfig` built from the builder's `storage_options` and `token`:

    1. the module that defines the builder class;
    2. for builders living outside the `datasets.` namespace, the "internal" imports of the builder's
       source file (like wmt14 -> wmt_utils);
    3. the modules of the builder's parent builder classes, except the module of `DatasetBuilder` itself.

    Args:
        builder (:class:`DatasetBuilder`): Dataset builder instance.
    """
    # a single config object is shared by every module patched below
    streaming_config = DownloadConfig(storage_options=builder.storage_options, token=builder.token)
    builder_module = builder.__module__

    # 1. the builder's own module: this extends the open and os.path.join functions for data streaming
    extend_module_for_streaming(builder_module, download_config=streaming_config)

    # 2. additional internal imports, skipped for packaged builders like csv
    is_packaged_builder = builder_module.startswith("datasets.")
    if not is_packaged_builder:
        source_file = inspect.getfile(builder.__class__)
        # internal imports are siblings of the builder module, so they share its parent package path
        parent_package_parts = builder_module.split(".")[:-1]
        with lock_importable_file(source_file):
            for import_entry in get_imports(source_file):
                if import_entry[0] != "internal":
                    continue
                sibling_name = import_entry[1]
                sibling_module = ".".join(parent_package_parts + [sibling_name])
                extend_module_for_streaming(sibling_module, download_config=streaming_config)

    # 3. builders can inherit from other builders that might use streaming functionality
    # (for example, ImageFolder and AudioFolder inherit from FolderBuilder which implements examples generation)
    # but these parents builders are not patched automatically as they are not instantiated, so we patch them here
    from .builder import DatasetBuilder

    ancestor_modules = []
    # skip the first MRO entry: it is the builder's own class, whose module is already patched
    for ancestor in type(builder).__mro__[1:]:
        if not issubclass(ancestor, DatasetBuilder):
            continue
        # the standard builders from datasets.builder do not need patching
        if ancestor.__module__ == DatasetBuilder.__module__:
            continue
        ancestor_modules.append(ancestor.__module__)

    for ancestor_module in ancestor_modules:
        extend_module_for_streaming(ancestor_module, download_config=streaming_config)