import os
import random
import warnings
from types import SimpleNamespace
from urllib.parse import urlparse

import yaml

from . import autodecode, cache, filters, shardlists, utils
from .filters import pipelinefilter, reraise_exception
from .pipeline import DataPipeline
from .pytorch import DataLoader
from .tariterators import group_by_keys, tar_file_expander


class FluidInterface:
    def batched(self, batchsize, collation_fn=filters.default_collation_fn, partial=True):
        """Create batches of the given size.

        This method forwards to the filters.batched function.

        Args:
            batchsize (int): Target batch size.
            collation_fn (callable, optional): Function to collate samples into a batch.
                Defaults to filters.default_collation_fn.
            partial (bool, optional): Whether to return partial batches. Defaults to True.

        Returns:
            FluidInterface: Updated pipeline with batched filter.
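
        Example (a sketch; the shard pattern and keys are hypothetical):
            >>> ds = WebDataset("data-{000000..000009}.tar", shardshuffle=False)
            >>> ds = ds.decode("rgb").to_tuple("jpg", "cls").batched(16)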
        """
        return self.compose(filters.batched(batchsize, collation_fn=collation_fn, partial=partial))

    def unbatched(self):
        """Turn batched data back into unbatched data.

        This method forwards to the filters.unbatched function.

        Returns:
            FluidInterface: Updated pipeline with unbatched filter.
        """
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        """Create lists of samples without collation.

        This method forwards to the filters.batched function with collation_fn set to None.

        Args:
            batchsize (int): Target list size.
            partial (bool, optional): Whether to return partial lists. Defaults to True.

        Returns:
            FluidInterface: Updated pipeline with listed filter.
        """
        return self.compose(filters.batched(batchsize=batchsize, collation_fn=None, partial=partial))

    def unlisted(self):
        """Turn listed data back into individual samples.

        This method forwards to the filters.unlisted function.

        Returns:
            FluidInterface: Updated pipeline with unlisted filter.
        """
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        """Log keys of samples passing through the pipeline.

        This method forwards to the filters.log_keys function.

        Args:
            logfile (str, optional): Path to the log file. If None, logging is disabled.

        Returns:
            FluidInterface: Updated pipeline with log_keys filter.
        """
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        """Shuffle the data in the stream.

        This method forwards to the filters.shuffle function if size > 0.

        Args:
            size (int): Buffer size for shuffling.
            **kw: Additional keyword arguments for filters.shuffle.

        Returns:
            FluidInterface: Updated pipeline with shuffle filter, or self if size < 1.
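
        Example (a sketch; the buffer size is illustrative):
            >>> ds = ds.shuffle(1000)  # keep up to 1000 samples in a buffer and emit them in random order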
        """
        if size < 1:
            return self
        else:
            return self.compose(filters.shuffle(size, **kw))

    def map(self, f, handler=reraise_exception):
        """Apply a function to each sample in the stream.

        This method forwards to the filters.map function.

        Args:
            f (callable): Function to apply to each sample.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with map filter.
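
        Example (a sketch; assumes samples are dicts carrying a "__key__" entry):
            >>> ds = ds.map(lambda sample: {**sample, "id": sample["__key__"]})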
        """
        return self.compose(filters.map(f, handler=handler))

    def decode(
        self,
        *args,
        pre=None,
        post=None,
        only=None,
        partial=False,
        handler=reraise_exception,
    ):
        """Decode data based on the decoding functions given as arguments.

        This method creates a decoder using autodecode.Decoder and applies it using filters.map.

        Args:
            *args: Decoding functions or strings representing image handlers.
            pre (callable, optional): Pre-processing function.
            post (callable, optional): Post-processing function.
            only (list, optional): List of keys to decode.
            partial (bool, optional): Whether to allow partial decoding. Defaults to False.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with decode filter.
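
        Example (a sketch; "rgb" and "torchrgb" are image specs from autodecode):
            >>> ds_np = ds.decode("rgb")       # images as float HWC numpy arrays in [0, 1]
            >>> ds_pt = ds.decode("torchrgb")  # images as float CHW torch tensors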
        """
        handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
        decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
        """Map the entries in a dict sample with individual functions.

        This method forwards to the filters.map_dict function.

        Args:
            handler (callable, optional): Exception handler. Defaults to reraise_exception.
            **kw: Mapping of keys to functions to apply.

        Returns:
            FluidInterface: Updated pipeline with map_dict filter.
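
        Example (a sketch; assumes the "jpg" entry was decoded to a numeric array):
            >>> ds = ds.map_dict(jpg=lambda img: img * 2 - 1)  # rescale [0, 1] images to [-1, 1]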
        """
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        """Select samples based on a predicate.

        This method forwards to the filters.select function.

        Args:
            predicate (callable): Function that returns True for samples to keep.
            **kw: Additional keyword arguments for filters.select.

        Returns:
            FluidInterface: Updated pipeline with select filter.
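
        Example (a sketch; the "jpg" key is hypothetical):
            >>> ds = ds.select(lambda sample: "jpg" in sample)  # keep only samples with an image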
        """
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, **kw):
        """Convert dict samples to tuples.

        This method forwards to the filters.to_tuple function.

        Args:
            *args: Keys to extract from the dict.
            **kw: Additional keyword arguments for filters.to_tuple.

        Returns:
            FluidInterface: Updated pipeline with to_tuple filter.
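
        Example (a sketch; the key names are hypothetical):
            >>> ds = ds.to_tuple("jpg;png", "cls")  # (image, label); "jpg;png" accepts either extension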
        """
        return self.compose(filters.to_tuple(*args, **kw))

    def map_tuple(self, *args, handler=reraise_exception):
        """Map the entries of a tuple with individual functions.

        This method forwards to the filters.map_tuple function.

        Args:
            *args: Functions to apply to each element of the tuple.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with map_tuple filter.
        """
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        """Slice the data stream.

        This method forwards to the filters.slice function.

        Args:
            *args: Arguments for slicing (start, stop, step).

        Returns:
            FluidInterface: Updated pipeline with slice filter.
        """
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        """Rename samples based on keyword arguments.

        This method forwards to the filters.rename function.

        Args:
            **kw: Mapping of old names to new names.

        Returns:
            FluidInterface: Updated pipeline with rename filter.
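
        Example (a sketch; the key names are hypothetical):
            >>> ds = ds.rename(image="jpg;png", label="cls")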
        """
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        """Randomly subsample a stream of data.

        This method forwards to the filters.rsample function.

        Args:
            p (float, optional): Probability of keeping each sample. Defaults to 0.5.

        Returns:
            FluidInterface: Updated pipeline with rsample filter.
        """
        return self.compose(filters.rsample(p))

    def rename_keys(self, *args, **kw):
        """Rename keys in samples based on patterns.

        This method forwards to the filters.rename_keys function.

        Args:
            *args: Positional arguments for filters.rename_keys.
            **kw: Keyword arguments for filters.rename_keys.

        Returns:
            FluidInterface: Updated pipeline with rename_keys filter.
        """
        return self.compose(filters.rename_keys(*args, **kw))

    def extract_keys(self, *args, **kw):
        """Extract specific keys from samples.

        This method forwards to the filters.extract_keys function.

        Args:
            *args: Keys or patterns to extract.
            **kw: Additional keyword arguments for filters.extract_keys.

        Returns:
            FluidInterface: Updated pipeline with extract_keys filter.
        """
        return self.compose(filters.extract_keys(*args, **kw))

    def xdecode(self, *args, **kw):
        """Decode data based on file extensions.

        This method forwards to the filters.xdecode function.

        Args:
            *args: Positional arguments for filters.xdecode.
            **kw: Keyword arguments for filters.xdecode.

        Returns:
            FluidInterface: Updated pipeline with xdecode filter.
        """
        return self.compose(filters.xdecode(*args, **kw))

    def mcached(self):
        """Cache samples in memory.

        This method forwards to the filters.Cached class.

        Returns:
            FluidInterface: Updated pipeline with memory caching.
        """
        return self.compose(filters.Cached())

    def lmdb_cached(self, *args, **kw):
        """Cache samples using LMDB.

        This method forwards to the filters.LMDBCached class.

        Args:
            *args: Positional arguments for filters.LMDBCached.
            **kw: Keyword arguments for filters.LMDBCached.

        Returns:
            FluidInterface: Updated pipeline with LMDB caching.
        """
        return self.compose(filters.LMDBCached(*args, **kw))


def check_empty(source):
    """Check if the dataset is empty and yield samples.

    Args:
        source: An iterable source of samples.

    Yields:
        The samples from the source.

    Raises:
        ValueError: If no samples are found in the dataset.
    """
    count = 0
    for sample in source:
        yield sample
        count += 1
    if count == 0:
        raise ValueError(
            "No samples found in dataset; perhaps you have fewer shards than workers.\n"
            + "Turn off using empty_check=False in the WebDataset constructor."
        )


class WebDataset(DataPipeline, FluidInterface):
    """Create a WebDataset pipeline for efficient data loading.

    This class sets up a data pipeline for loading and processing WebDataset-format data.
    It handles URL generation, shard shuffling, caching, and sample grouping.

    Args:
        urls: The source URLs or specifications for the dataset.
        handler: Function to handle exceptions. Defaults to reraise_exception.
        mode: The mode of operation; setting it to "resampled" is equivalent to resampled=True. Defaults to None.
        resampled: Whether to use resampled mode. Defaults to False.
        repeat: Whether to repeat the dataset. Defaults to False.
        shardshuffle: Buffer size for shard shuffling; 0 or False disables it. Defaults to None, which emits a warning.
        cache_size: The size of the cache in bytes. Defaults to -1 (unlimited).
        cache_dir: The directory to use for caching. Defaults to None.
        url_to_name: Function to convert URLs to cache names. Defaults to pipe_cleaner.
        detshuffle: Whether to use deterministic shuffling. Defaults to False.
        nodesplitter: Function to split data by node. Defaults to single_node_only.
        workersplitter: Function to split data by worker. Defaults to split_by_worker.
        select_files: Function to select files from tar archives. Defaults to None.
        rename_files: Function to rename files from tar archives. Defaults to None.
        empty_check: Whether to check for empty datasets. Defaults to True.
        verbose: Whether to print verbose output. Defaults to False.
        seed: Random seed for shuffling. Defaults to None.

    Raises:
        ValueError: If the cache directory does not exist or if the URL type is not supported.
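
    Example (a sketch; the shard pattern and sample keys are hypothetical):
        >>> ds = WebDataset("shards/train-{000000..000099}.tar", shardshuffle=100)
        >>> ds = ds.shuffle(1000).decode("rgb").to_tuple("jpg", "cls").batched(32)
        >>> images, labels = next(iter(ds))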
    """

    def __init__(
        self,
        urls,
        handler=reraise_exception,
        mode=None,
        resampled=False,
        repeat=False,
        shardshuffle=None,
        cache_size=-1,
        cache_dir=None,
        url_to_name=cache.pipe_cleaner,
        detshuffle=False,
        nodesplitter=shardlists.single_node_only,
        workersplitter=shardlists.split_by_worker,
        select_files=None,
        rename_files=None,
        empty_check=True,
        verbose=False,
        seed=None,
    ):
        super().__init__()
        if resampled:
            mode = "resampled"
        if mode == "resampled" and shardshuffle not in (False, None):
            warnings.warn("WebDataset(shardshuffle=...) is ignored for resampled datasets")
        elif shardshuffle is None:
            warnings.warn("WebDataset(shardshuffle=...) is None; set explicitly to False or a number")
        if shardshuffle is True:
            warnings.warn("set WebDataset(shardshuffle=...) to a positive integer or 0 or False")
            shardshuffle = 100
        args = SimpleNamespace(**locals())
        self.seed = int(os.environ.get("WDS_SEED", random.randint(0, 1000000))) if seed is None else seed
        self.update_cache_info(args)

        # first, we add a generator for the URLs to be used;
        # this generates a stream of dict(url=...)
        self.create_url_iterator(args)

        # split by node (for distributed processing)
        if nodesplitter is not None:
            self.append(nodesplitter)

        # split by worker (for DataLoader)
        if workersplitter:
            self.append(workersplitter)

        # add a shard shuffler
        if args.shardshuffle is not None:
            if args.detshuffle:
                self.append(filters.detshuffle(args.shardshuffle, seed=self.seed))
            else:
                self.append(filters.shuffle(args.shardshuffle, seed=self.seed))

        # next, we select a URL opener, either with or without caching
        # this generates a stream of dict(url=..., stream=...)
        if cache_dir is None or cache_size == 0:
            opener = cache.StreamingOpen(handler=handler)
        else:
            opener = cache.FileCache(cache_dir=cache_dir, cache_size=cache_size, handler=handler)
        self.append(opener)

        # now we need to open each stream and read the tar files contained in it
        # this generates a stream of dict(fname=..., data=...) objects
        expander = pipelinefilter(tar_file_expander)
        self.append(expander(handler=handler, select_files=select_files, rename_files=rename_files))

        # finally, the files need to be grouped into samples;
        # this generates a stream of dict(__key__=..., ...=...) objects
        grouper = pipelinefilter(group_by_keys)
        self.append(grouper(handler=handler))

        # check for empty datasets
        if empty_check:
            self.append(check_empty)

    def update_cache_info(self, args):
        """Update cache information based on arguments and environment variables.

        Args:
            args: A SimpleNamespace object containing the arguments.

        Raises:
            ValueError: If the specified cache directory does not exist.
        """
        args.cache_size = int(os.environ.get("WDS_CACHE_SIZE", args.cache_size))
        args.cache_dir = os.environ.get("WDS_CACHE", args.cache_dir)
        if args.cache_dir is not None:
            args.cache_dir = os.path.expanduser(args.cache_dir)
            if not os.path.exists(args.cache_dir):
                raise ValueError(f"cache directory {args.cache_dir} does not exist")

    def create_url_iterator(self, args):
        """Create an appropriate URL iterator based on the input type.

        This method determines the type of URL input and creates the corresponding
        iterator for the dataset.

        Args:
            args: A SimpleNamespace object containing the arguments.

        Raises:
            ValueError: If the URL type is not supported or implemented.
        """
        urls = args.urls

        # .yaml specification files
        if isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
            with open(args.urls) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
            self.append(shardlists.MultiShardSample(spec))
            return

        # .yaml specifications already loaded as dictionaries
        if isinstance(args.urls, dict):
            assert "datasets" in args.urls
            self.append(shardlists.MultiShardSample(args.urls))
            return

        # .json specification files (from wids)
        if isinstance(urls, str) and urls.endswith(".json"):
            raise ValueError("unimplemented")

        # any URL ending in "/" is assumed to be a directory
        if isinstance(urls, str) and urlparse(urls).path.endswith("/"):
            self.append(shardlists.DirectoryShardList(urls, mode=args.mode))
            return

        # the rest is either a shard list or a resampled shard list
        if isinstance(args.urls, str) or utils.is_iterable(args.urls):
            if args.mode == "resampled":
                self.append(shardlists.ResampledShardList(args.urls))
            else:
                self.append(shardlists.SimpleShardList(args.urls))
            return

        raise ValueError(f"cannot handle urls of type {type(args.urls)}")

    def __enter__(self):
        """Enter the runtime context for the WebDataset.

        Returns:
            self: The WebDataset instance.
        """
        return self

    def __exit__(self, *args):
        """Exit the runtime context for the WebDataset.

        Args:
            *args: Exception type, value, and traceback if an exception occurred.
        """
        self.close()


class FluidWrapper(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(self, initial):
        super().__init__()
        self.append(initial)


class WebLoader(DataPipeline, FluidInterface):
    """A wrapper for DataLoader that adds a fluid interface."""

    def __init__(self, *args, **kw):
        super().__init__(DataLoader(*args, **kw))
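

# Typical composition of a WebDataset with a WebLoader (a sketch; `dataset`, the
# worker count, and the batch sizes below are illustrative assumptions):
#
#     loader = WebLoader(dataset, batch_size=None, num_workers=4)
#     # unbatch, reshuffle across workers, then rebatch so that each batch
#     # mixes samples coming from different workers
#     loader = loader.unbatched().shuffle(1000).batched(32)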