import os import random import warnings from types import SimpleNamespace from urllib.parse import urlparse import yaml from . import autodecode, cache, filters, shardlists, utils from .filters import pipelinefilter, reraise_exception from .pipeline import DataPipeline from .pytorch import DataLoader from .tariterators import group_by_keys, tar_file_expander class FluidInterface: def batched(self, batchsize, collation_fn=filters.default_collation_fn, partial=True): """Create batches of the given size. This method forwards to the filters.batched function. Args: batchsize (int): Target batch size. collation_fn (callable, optional): Function to collate samples into a batch. Defaults to filters.default_collation_fn. partial (bool, optional): Whether to return partial batches. Defaults to True. Returns: FluidInterface: Updated pipeline with batched filter. """ return self.compose(filters.batched(batchsize, collation_fn=collation_fn, partial=partial)) def unbatched(self): """Turn batched data back into unbatched data. This method forwards to the filters.unbatched function. Returns: FluidInterface: Updated pipeline with unbatched filter. """ return self.compose(filters.unbatched()) def listed(self, batchsize, partial=True): """Create lists of samples without collation. This method forwards to the filters.batched function with collation_fn set to None. Args: batchsize (int): Target list size. partial (bool, optional): Whether to return partial lists. Defaults to True. Returns: FluidInterface: Updated pipeline with listed filter. """ return self.compose(filters.batched(batchsize=batchsize, collation_fn=None)) def unlisted(self): """Turn listed data back into individual samples. This method forwards to the filters.unlisted function. Returns: FluidInterface: Updated pipeline with unlisted filter. """ return self.compose(filters.unlisted()) def log_keys(self, logfile=None): """Log keys of samples passing through the pipeline. This method forwards to the filters.log_keys function. Args: logfile (str, optional): Path to the log file. If None, logging is disabled. Returns: FluidInterface: Updated pipeline with log_keys filter. """ return self.compose(filters.log_keys(logfile)) def shuffle(self, size, **kw): """Shuffle the data in the stream. This method forwards to the filters.shuffle function if size > 0. Args: size (int): Buffer size for shuffling. **kw: Additional keyword arguments for filters.shuffle. Returns: FluidInterface: Updated pipeline with shuffle filter, or self if size < 1. """ if size < 1: return self else: return self.compose(filters.shuffle(size, **kw)) def map(self, f, handler=reraise_exception): """Apply a function to each sample in the stream. This method forwards to the filters.map function. Args: f (callable): Function to apply to each sample. handler (callable, optional): Exception handler. Defaults to reraise_exception. Returns: FluidInterface: Updated pipeline with map filter. """ return self.compose(filters.map(f, handler=handler)) def decode( self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception, ): """Decode data based on the decoding functions given as arguments. This method creates a decoder using autodecode.Decoder and applies it using filters.map. Args: *args: Decoding functions or strings representing image handlers. pre (callable, optional): Pre-processing function. post (callable, optional): Post-processing function. only (list, optional): List of keys to decode. partial (bool, optional): Whether to allow partial decoding. Defaults to False. handler (callable, optional): Exception handler. Defaults to reraise_exception. Returns: FluidInterface: Updated pipeline with decode filter. """ handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args] decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial) return self.map(decoder, handler=handler) def map_dict(self, handler=reraise_exception, **kw): """Map the entries in a dict sample with individual functions. This method forwards to the filters.map_dict function. Args: handler (callable, optional): Exception handler. Defaults to reraise_exception. **kw: Mapping of keys to functions to apply. Returns: FluidInterface: Updated pipeline with map_dict filter. """ return self.compose(filters.map_dict(handler=handler, **kw)) def select(self, predicate, **kw): """Select samples based on a predicate. This method forwards to the filters.select function. Args: predicate (callable): Function that returns True for samples to keep. **kw: Additional keyword arguments for filters.select. Returns: FluidInterface: Updated pipeline with select filter. """ return self.compose(filters.select(predicate, **kw)) def to_tuple(self, *args, **kw): """Convert dict samples to tuples. This method forwards to the filters.to_tuple function. Args: *args: Keys to extract from the dict. **kw: Additional keyword arguments for filters.to_tuple. Returns: FluidInterface: Updated pipeline with to_tuple filter. """ return self.compose(filters.to_tuple(*args, **kw)) def map_tuple(self, *args, handler=reraise_exception): """Map the entries of a tuple with individual functions. This method forwards to the filters.map_tuple function. Args: *args: Functions to apply to each element of the tuple. handler (callable, optional): Exception handler. Defaults to reraise_exception. Returns: FluidInterface: Updated pipeline with map_tuple filter. """ return self.compose(filters.map_tuple(*args, handler=handler)) def slice(self, *args): """Slice the data stream. This method forwards to the filters.slice function. Args: *args: Arguments for slicing (start, stop, step). Returns: FluidInterface: Updated pipeline with slice filter. """ return self.compose(filters.slice(*args)) def rename(self, **kw): """Rename samples based on keyword arguments. This method forwards to the filters.rename function. Args: **kw: Mapping of old names to new names. Returns: FluidInterface: Updated pipeline with rename filter. """ return self.compose(filters.rename(**kw)) def rsample(self, p=0.5): """Randomly subsample a stream of data. This method forwards to the filters.rsample function. Args: p (float, optional): Probability of keeping each sample. Defaults to 0.5. Returns: FluidInterface: Updated pipeline with rsample filter. """ return self.compose(filters.rsample(p)) def rename_keys(self, *args, **kw): """Rename keys in samples based on patterns. This method forwards to the filters.rename_keys function. Args: *args: Positional arguments for filters.rename_keys. **kw: Keyword arguments for filters.rename_keys. Returns: FluidInterface: Updated pipeline with rename_keys filter. """ return self.compose(filters.rename_keys(*args, **kw)) def extract_keys(self, *args, **kw): """Extract specific keys from samples. This method forwards to the filters.extract_keys function. Args: *args: Keys or patterns to extract. **kw: Additional keyword arguments for filters.extract_keys. Returns: FluidInterface: Updated pipeline with extract_keys filter. """ return self.compose(filters.extract_keys(*args, **kw)) def xdecode(self, *args, **kw): """Decode data based on file extensions. This method forwards to the filters.xdecode function. Args: *args: Positional arguments for filters.xdecode. **kw: Keyword arguments for filters.xdecode. Returns: FluidInterface: Updated pipeline with xdecode filter. """ return self.compose(filters.xdecode(*args, **kw)) def mcached(self): """Cache samples in memory. This method forwards to the filters.Cached class. Returns: FluidInterface: Updated pipeline with memory caching. """ return self.compose(filters.Cached()) def lmdb_cached(self, *args, **kw): """Cache samples using LMDB. This method forwards to the filters.LMDBCached class. Args: *args: Positional arguments for filters.LMDBCached. **kw: Keyword arguments for filters.LMDBCached. Returns: FluidInterface: Updated pipeline with LMDB caching. """ return self.compose(filters.LMDBCached(*args, **kw)) def check_empty(source): """Check if the dataset is empty and yield samples. Args: source: An iterable source of samples. Yields: The samples from the source. Raises: ValueError: If no samples are found in the dataset. """ count = 0 for sample in source: yield sample count += 1 if count == 0: raise ValueError( "No samples found in dataset; perhaps you have fewer shards than workers.\n" + "Turn off using empty_check=False in the WebDataset constructor." ) class WebDataset(DataPipeline, FluidInterface): """Create a WebDataset pipeline for efficient data loading. This class sets up a data pipeline for loading and processing WebDataset-format data. It handles URL generation, shard shuffling, caching, and sample grouping. Args: urls: The source URLs or specifications for the dataset. handler: Function to handle exceptions. Defaults to reraise_exception. mode: The mode of operation. Defaults to None. resampled: Whether to use resampled mode. Defaults to False. repeat: Whether to repeat the dataset. Defaults to False. shardshuffle: The number of shards to shuffle, or None. Defaults to None. cache_size: The size of the cache in bytes. Defaults to -1 (unlimited). cache_dir: The directory to use for caching. Defaults to None. url_to_name: Function to convert URLs to cache names. Defaults to pipe_cleaner. detshuffle: Whether to use deterministic shuffling. Defaults to False. nodesplitter: Function to split data by node. Defaults to single_node_only. workersplitter: Function to split data by worker. Defaults to split_by_worker. select_files: Function to select files from tar archives. Defaults to None. rename_files: Function to rename files from tar archives. Defaults to None. empty_check: Whether to check for empty datasets. Defaults to True. verbose: Whether to print verbose output. Defaults to False. seed: Random seed for shuffling. Defaults to None. Raises: ValueError: If the cache directory does not exist or if the URL type is not supported. """ def __init__( self, urls, handler=reraise_exception, mode=None, resampled=False, repeat=False, shardshuffle=None, cache_size=-1, cache_dir=None, url_to_name=cache.pipe_cleaner, detshuffle=False, nodesplitter=shardlists.single_node_only, workersplitter=shardlists.split_by_worker, select_files=None, rename_files=None, empty_check=True, verbose=False, seed=None, ): super().__init__() if resampled: mode = "resampled" if mode == "resampled" and shardshuffle not in (False, None): warnings.warn("WebDataset(shardshuffle=...) is ignored for resampled datasets") elif shardshuffle is None: warnings.warn("WebDataset(shardshuffle=...) is None; set explicitly to False or a number") if shardshuffle is True: warnings.warn("set WebDataset(shardshuffle=...) to a positive integer or 0 or False") shardshuffle = 100 args = SimpleNamespace(**locals()) self.seed = os.environ.get("WDS_SEED", random.randint(0, 1000000)) if seed is None else seed self.update_cache_info(args) # first, we add a generator for the urls to used # this generates a stream of dict(url=...) self.create_url_iterator(args) # split by node (for distributed processing) if nodesplitter is not None: self.append(nodesplitter) # split by worker (for DataLoader) if workersplitter: self.append(workersplitter) # add a shard shuffler if args.shardshuffle is not None: if args.detshuffle: self.append(filters.detshuffle(args.shardshuffle, seed=self.seed)) else: self.append(filters.shuffle(args.shardshuffle, seed=self.seed)) # next, we select a URL opener, either with or without caching # this generates a stream of dict(url=..., stream=...) if cache_dir is None or cache_size == 0: opener = cache.StreamingOpen(handler=handler) else: opener = cache.FileCache(cache_dir=cache_dir, cache_size=cache_size, handler=handler) self.append(opener) # now we need to open each stream and read the tar files contained in it # this generates a stream of dict(fname=..., data=...) objects expander = pipelinefilter(tar_file_expander) self.append(expander(handler=handler, select_files=select_files, rename_files=rename_files)) # finally, the files need to be groups into samples # this generates a stream of dict(__key__=..., ...=...) objects grouper = pipelinefilter(group_by_keys) self.append(grouper(handler=handler)) # check for empty datasets if empty_check: self.append(check_empty) def update_cache_info(self, args): """Update cache information based on arguments and environment variables. Args: args: A SimpleNamespace object containing the arguments. Raises: ValueError: If the specified cache directory does not exist. """ args.cache_size = int(os.environ.get("WDS_CACHE_SIZE", args.cache_size)) args.cache_dir = os.environ.get("WDS_CACHE", args.cache_dir) if args.cache_dir is not None: args.cache_dir = os.path.expanduser(args.cache_dir) if not os.path.exists(args.cache_dir): raise ValueError(f"cache directory {args.cache_dir} does not exist") def create_url_iterator(self, args): """Create an appropriate URL iterator based on the input type. This method determines the type of URL input and creates the corresponding iterator for the dataset. Args: args: A SimpleNamespace object containing the arguments. Raises: ValueError: If the URL type is not supported or implemented. """ urls = args.urls # .yaml specification files if isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")): with open(args.urls) as stream: spec = yaml.safe_load(stream) assert "datasets" in spec self.append(shardlists.MultiShardSample(spec)) return # .yaml specifications already loaded as dictionaries if isinstance(args.urls, dict): assert "datasets" in args.urls self.append(shardlists.MultiShardSample(args.urls)) return # .json specification files (from wids) if isinstance(urls, str) and urls.endswith(".json"): raise ValueError("unimplemented") # any URL ending in "/" is assumed to be a directory if isinstance(urls, str) and urlparse(urls).path.endswith("/"): self.append(shardlists.DirectoryShardList(urls, mode=args.mode)) return # the rest is either a shard list or a resampled shard list if isinstance(args.urls, str) or utils.is_iterable(args.urls): if args.mode == "resampled": self.append(shardlists.ResampledShardList(args.urls)) else: self.append(shardlists.SimpleShardList(args.urls)) return raise ValueError(f"cannot handle urls of type {type(args.urls)}") def __enter__(self): """Enter the runtime context for the WebDataset. Returns: self: The WebDataset instance. """ return self def __exit__(self, *args): """Exit the runtime context for the WebDataset. Args: *args: Exception type, value, and traceback if an exception occurred. """ self.close() class FluidWrapper(DataPipeline, FluidInterface): """Small fluid-interface wrapper for DataPipeline.""" def __init__(self, initial): super().__init__() self.append(initial) class WebLoader(DataPipeline, FluidInterface): """A wrapper for DataLoader that adds a fluid interface.""" def __init__(self, *args, **kw): super().__init__(DataLoader(*args, **kw))