File size: 9,452 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Miscellaneous utility functions."""

import fnmatch
import functools
import glob
import importlib
import itertools as itt
import os
import re
import sys
import warnings
from typing import Any, Callable, Iterable, Iterator, Union

import braceexpand
import numpy as np

# Module-level flag read once, at import time, from the WDS_SECURE environment
# variable: it must parse as an integer, and any non-zero value turns it on.
enforce_security = int(os.environ.get("WDS_SECURE", "0")) != 0


def glob_with_braces(pattern):
    """Glob a pattern that may contain shell-style ``{a,b}`` brace alternatives.

    ``glob.glob`` does not understand braces, so the pattern is first expanded
    with ``braceexpand`` and each resulting plain pattern is globbed in turn.

    Args:
        pattern (str): Glob pattern, optionally containing braces.

    Returns:
        list: Matching paths, grouped by expanded pattern in expansion order.
        The results are neither sorted nor de-duplicated, so overlapping
        alternatives can yield the same path more than once.
    """
    matches = []
    for expanded_pattern in braceexpand.braceexpand(pattern):
        matches.extend(glob.glob(expanded_pattern))
    return matches


def fnmatch_with_braces(filename, pattern):
    """Apply fnmatch to patterns with braces by pre-expanding the braces.

    ``fnmatch`` has no brace support, so ``pattern`` is expanded with
    ``braceexpand`` and the filename is tested against each alternative.
    Evaluation short-circuits on the first matching alternative.

    Args:
        filename (str): The filename to match against.
        pattern (str): The pattern, optionally containing ``{a,b}`` braces.

    Returns:
        bool: True if the filename matches any of the expanded patterns,
        False otherwise.
    """
    # braceexpand yields a one-shot generator, so it is consumed in a single
    # pass; a second pass over it would silently see nothing.
    return any(
        fnmatch.fnmatch(filename, expanded)
        for expanded in braceexpand.braceexpand(pattern)
    )


def make_seed(*args):
    """Fold the hashes of ``args`` into a non-negative 31-bit integer seed.

    Each argument's ``hash()`` is combined left to right as
    ``acc * 31 + hash(arg)``, masked to 31 bits after every step. Calling
    with no arguments returns 0.

    Note:
        The result depends on the built-in ``hash()``. For ``str`` and
        ``bytes`` that value is salted per interpreter process unless
        ``PYTHONHASHSEED`` is fixed, so seeds built from such arguments are
        not reproducible across runs.

    Args:
        *args: Hashable values to derive the seed from.

    Returns:
        int: An integer in ``[0, 2**31 - 1]``.
    """
    return functools.reduce(
        lambda acc, value: (acc * 31 + hash(value)) & 0x7FFFFFFF,
        args,
        0,
    )


def is_iterable(obj):
    """Check if an object is iterable, treating strings and bytes as atomic.

    Detection uses ``isinstance(obj, Iterable)``, which covers lists,
    tuples, dicts, sets, generators and any other iterator. Objects that are
    iterable only through the legacy ``__getitem__`` protocol are not
    detected.

    Args:
        obj: The object to check.

    Returns:
        bool: True if ``obj`` is iterable and is not a ``str`` or ``bytes``,
        False otherwise.
    """
    # str and bytes are technically iterable but are treated as scalars here.
    if isinstance(obj, (str, bytes)):
        return False
    # list and Iterator are both subtypes of Iterable, so one check covers them.
    return isinstance(obj, Iterable)


class PipelineStage:
    """Base class for pipeline stages; subclasses must override ``invoke``."""

    def invoke(self, *args, **kw):
        """Run this stage.

        The base implementation has no behavior of its own. Concrete stages
        replace it.

        Args:
            *args: Positional arguments supplied by the caller.
            **kw: Keyword arguments supplied by the caller.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()


def identity(x: Any) -> Any:
    """Return ``x`` itself, without copying or transforming it.

    Args:
        x (Any): Any value.

    Returns:
        Any: The very same object that was passed in.
    """
    result = x
    return result


def safe_eval(s: str, expr: str = "{}"):
    """Evaluate ``expr`` with ``s`` substituted in, after restricting ``s``.

    ``s`` may only contain ASCII letters, digits and underscores, so on its
    own it can be at most a bare name or a numeric literal. ``expr`` is a
    ``str.format`` template into which ``s`` is inserted before evaluation.

    Args:
        s (str): The token to evaluate.
        expr (str, optional): Format template for the evaluated expression.
            Defaults to ``"{}"``.

    Returns:
        Any: The result of ``eval`` on the formatted expression.

    Raises:
        ValueError: If ``s`` contains any character outside ``[A-Za-z0-9_]``.
    """
    if not re.fullmatch(r"[A-Za-z0-9_]*", s):
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    # NOTE(review): this still calls eval(). Only ``s`` is validated; ``expr``
    # is used unchecked and must come from trusted code, never from data.
    source = expr.format(s)
    return eval(source)


def lookup_sym(sym: str, modules: list):
    """Return the first non-None attribute ``sym`` found in ``modules``.

    Modules are imported lazily and in order. Relative names are resolved
    against the ``webdataset`` package. Searching stops at the first module
    whose attribute is not None, so later modules are never imported.

    Args:
        sym (str): Attribute name to look up.
        modules (list): Module names to search, in priority order.

    Returns:
        Any: The found attribute, or None if no module provides a non-None
        value for it.
    """
    candidates = (
        getattr(importlib.import_module(name, package="webdataset"), sym, None)
        for name in modules
    )
    return next((found for found in candidates if found is not None), None)


def repeatedly0(loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize):
    """Yield up to ``nbatches`` items from ``loader`` per epoch, for ``nepochs`` epochs.

    Each epoch calls ``iter(loader)`` afresh via ``islice``. A re-iterable
    loader (such as a list or DataLoader) therefore restarts every epoch,
    whereas a one-shot iterator simply continues where it stopped.

    Args:
        loader (Iterator): Source of batches.
        nepochs (int, optional): Number of epochs. Defaults to sys.maxsize.
        nbatches (int, optional): Maximum batches per epoch. Defaults to
            sys.maxsize.

    Yields:
        Any: Batches from ``loader``.
    """
    per_epoch = (itt.islice(loader, nbatches) for _ in range(nepochs))
    yield from itt.chain.from_iterable(per_epoch)


def guess_batchsize(batch: Union[tuple, list]):
    """Estimate a batch's size as the length of its first field.

    Args:
        batch (Union[tuple, list]): A batch whose element 0 is a sized
            collection (e.g. the list or array of inputs).

    Returns:
        int: ``len(batch[0])``.
    """
    first_field = batch[0]
    return len(first_field)


def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield items from ``source`` until a limit is reached.

    One epoch is one ``for`` pass over ``source``. The per-item limits
    (``nbatches``, ``nsamples``) are checked after each item is yielded, and
    ``nepochs`` is checked at the end of each epoch. With no limits the
    generator runs for as long as ``source`` keeps producing items. A pass
    that produces nothing ends the generator, since repeating it would spin
    forever without yielding.

    Args:
        source (Iterator): Iterable to draw items from. A re-iterable object
            restarts each epoch. A one-shot iterator is exhausted after the
            first epoch.
        nepochs (int, optional): Stop after this many epochs. Defaults to None.
        nbatches (int, optional): Stop after this many items in total.
            Defaults to None.
        nsamples (int, optional): Stop once the summed ``batchsize(item)``
            reaches this value. Defaults to None.
        batchsize (Callable[..., int], optional): Returns the number of samples
            in an item; only used when ``nsamples`` is set. Defaults to
            guess_batchsize.

    Yields:
        Any: Items from ``source``.
    """
    epoch = 0
    batch = 0
    total = 0
    while True:
        produced = False
        for sample in source:
            produced = True
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                # Honor the caller-supplied counter rather than a fixed guess.
                total += batchsize(sample)
                if total >= nsamples:
                    return
        if not produced:
            # Empty source or exhausted one-shot iterator: another pass would
            # also be empty, so stop instead of looping forever.
            return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return


def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Determine this process's distributed rank and DataLoader worker slot.

    Environment variables take precedence over torch:
    - ``RANK`` and ``WORLD_SIZE`` (both must be set) override
      ``torch.distributed``.
    - ``WORKER`` and ``NUM_WORKERS`` (both must be set) override
      ``torch.utils.data.get_worker_info()``.

    If torch is not installed, distributed is not initialized, or the caller
    is not inside a DataLoader worker, the single-process defaults
    (0, 1, 0, 1) are used for the corresponding values.

    Args:
        group (optional): Process group for ``torch.distributed`` queries.
            When falsy, ``torch.distributed.group.WORLD`` is used.

    Returns:
        tuple: ``(rank, world_size, worker, num_workers)``.
    """
    environ = os.environ

    # Node-level position: explicit environment first, then torch.distributed.
    rank, world_size = 0, 1
    if "RANK" in environ and "WORLD_SIZE" in environ:
        rank = int(environ["RANK"])
        world_size = int(environ["WORLD_SIZE"])
    else:
        try:
            import torch.distributed

            dist = torch.distributed
            if dist.is_available() and dist.is_initialized():
                group = group or dist.group.WORLD
                rank = dist.get_rank(group=group)
                world_size = dist.get_world_size(group=group)
        except ModuleNotFoundError:
            # torch is optional; keep the single-node defaults.
            pass

    # Worker-level position: explicit environment first, then DataLoader info.
    worker, num_workers = 0, 1
    if "WORKER" in environ and "NUM_WORKERS" in environ:
        worker = int(environ["WORKER"])
        num_workers = int(environ["NUM_WORKERS"])
    else:
        try:
            import torch.utils.data

            info = torch.utils.data.get_worker_info()
            if info is not None:
                worker = info.id
                num_workers = info.num_workers
        except ModuleNotFoundError:
            # torch is optional; keep the single-worker defaults.
            pass

    return rank, world_size, worker, num_workers


def pytorch_worker_seed(group=None):
    """Derive a deterministic RNG seed from this process's rank and worker id.

    The seed is ``rank * 1000 + worker``. It is therefore distinct across
    nodes and workers only while every worker id stays below 1000.

    Args:
        group (optional): Process group forwarded to ``pytorch_worker_info``.
            Defaults to None.

    Returns:
        int: The combined seed.
    """
    rank, _world_size, worker, _num_workers = pytorch_worker_info(group=group)
    return worker + 1000 * rank


def deprecated(arg=None):
    """Decorator that issues a ``DeprecationWarning`` each time the function is called.

    Supports three forms:

        @deprecated
        @deprecated()
        @deprecated("use other_func instead")

    The warning text is ``"Call to deprecated function <name>."``, followed by
    ``" Reason: <reason>"`` when a reason was given.

    Args:
        arg: Either the function being decorated (bare form) or an optional
            reason string (called form).

    Returns:
        The wrapped function (bare form) or a decorator (called form).
    """
    used_bare = callable(arg)
    reason = None if used_bare else arg

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            text = f"Call to deprecated function {func.__name__}."
            if reason is not None:
                text += " Reason: " + reason
            # stacklevel=2 attributes the warning to the caller of the
            # deprecated function rather than to this wrapper.
            warnings.warn(text, category=DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator(arg) if used_bare else decorator


class ObsoleteException(Exception):
    """Raised when a function wrapped by ``obsolete`` is called without ALLOW_OBSOLETE set."""


def obsolete(func=None, *, reason=None):
    """Decorator that refuses calls unless the ALLOW_OBSOLETE env var is non-zero.

    Usable bare (``@obsolete``) or with a keyword reason
    (``@obsolete(reason="...")``). The environment variable is read on every
    call and must parse as an integer.

    Args:
        func: The function being decorated (bare form), or None.
        reason (str, optional): Extra text appended to the error message.

    Returns:
        The guarded function, or a decorator when called with keywords only.

    Raises:
        ObsoleteException: From the wrapped function, when ALLOW_OBSOLETE is
            unset or zero.
    """
    if func is None:
        # Called as @obsolete(reason=...): return a decorator awaiting func.
        return functools.partial(obsolete, reason=reason)

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        allowed = int(os.environ.get("ALLOW_OBSOLETE", "0"))
        if allowed:
            return func(*args, **kwargs)
        message = f"Call to obsolete function {func.__name__}. Set env ALLOW_OBSOLETE=1 to permit."
        if reason is not None:
            message += " Reason: " + reason
        raise ObsoleteException(message)

    return guarded


def compute_sample_weights(n_w_pairs):
    """Compute relative sampling weights from ``(n, w)`` pairs.

    Each pair's ``n`` is multiplied by its ``w``, and every product is divided
    by the largest product. The maximum entry therefore becomes 1.0 when that
    maximum is positive.

    Args:
        n_w_pairs: Iterable of ``(n, w)`` pairs. Any iterable is accepted,
            including one-shot generators or iterators.

    Returns:
        numpy.ndarray: ``n * w / max(n * w)`` for each pair, in input order.

    Raises:
        ValueError: If ``n_w_pairs`` is empty.
    """
    # Materialize once: the pairs are read twice below, and a generator would
    # be exhausted by the first read, leaving the second one empty.
    pairs = list(n_w_pairs)
    if not pairs:
        raise ValueError("compute_sample_weights: n_w_pairs must not be empty")
    ns = np.array([p[0] for p in pairs])
    ws = np.array([p[1] for p in pairs])
    weighted = ns * ws
    # NOTE(review): if every product is 0 this division yields NaN (numpy
    # emits a RuntimeWarning); confirm callers never pass all-zero weights.
    return weighted / np.amax(weighted)