import os
import random
import warnings
from types import SimpleNamespace
from urllib.parse import urlparse

import yaml

from . import autodecode, cache, filters, shardlists, utils
from .filters import pipelinefilter, reraise_exception
from .pipeline import DataPipeline
from .pytorch import DataLoader
from .tariterators import group_by_keys, tar_file_expander


class FluidInterface:
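    """Mixin that adds a fluent, chainable interface for building pipelines.

    Each method composes a filter onto the pipeline and returns the resulting
    pipeline, so calls can be chained, for example::

        dataset.shuffle(1000).decode("rgb").to_tuple("jpg", "cls").batched(16)
    """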

    def batched(self, batchsize, collation_fn=filters.default_collation_fn, partial=True):
        """Create batches of the given size.

        This method forwards to the filters.batched function.

        Args:
            batchsize (int): Target batch size.
            collation_fn (callable, optional): Function to collate samples into a batch.
                Defaults to filters.default_collation_fn.
            partial (bool, optional): Whether to return partial batches. Defaults to True.

        Returns:
            FluidInterface: Updated pipeline with batched filter.
        """
        return self.compose(filters.batched(batchsize, collation_fn=collation_fn, partial=partial))
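
    # Illustrative sketch (hypothetical values): batched and unbatched are
    # inverses, so a stream can be rebatched later in the pipeline.
    #
    #     batches = dataset.batched(16)    # lists of 16 samples, collated
    #     singles = batches.unbatched()    # back to individual samples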

    def unbatched(self):
        """Turn batched data back into unbatched data.

        This method forwards to the filters.unbatched function.

        Returns:
            FluidInterface: Updated pipeline with unbatched filter.
        """
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        """Create lists of samples without collation.

        This method forwards to the filters.batched function with collation_fn set to None.

        Args:
            batchsize (int): Target list size.
            partial (bool, optional): Whether to return partial lists. Defaults to True.

        Returns:
            FluidInterface: Updated pipeline with listed filter.
        """
        return self.compose(filters.batched(batchsize=batchsize, collation_fn=None, partial=partial))

    def unlisted(self):
        """Turn listed data back into individual samples.

        This method forwards to the filters.unlisted function.

        Returns:
            FluidInterface: Updated pipeline with unlisted filter.
        """
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        """Log keys of samples passing through the pipeline.

        This method forwards to the filters.log_keys function.

        Args:
            logfile (str, optional): Path to the log file. If None, logging is disabled.

        Returns:
            FluidInterface: Updated pipeline with log_keys filter.
        """
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        """Shuffle the data in the stream.

        This method forwards to the filters.shuffle function if size > 0.

        Args:
            size (int): Buffer size for shuffling.
            **kw: Additional keyword arguments for filters.shuffle.

        Returns:
            FluidInterface: Updated pipeline with shuffle filter, or self if size < 1.
        """
        if size < 1:
            return self
        else:
            return self.compose(filters.shuffle(size, **kw))
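
    # Illustrative sketch (hypothetical value): a buffer of 1000 trades memory
    # for shuffling quality; size=0 leaves the stream order unchanged.
    #
    #     shuffled = dataset.shuffle(1000)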

    def map(self, f, handler=reraise_exception):
        """Apply a function to each sample in the stream.

        This method forwards to the filters.map function.

        Args:
            f (callable): Function to apply to each sample.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with map filter.
        """
        return self.compose(filters.map(f, handler=handler))

    def decode(
        self,
        *args,
        pre=None,
        post=None,
        only=None,
        partial=False,
        handler=reraise_exception,
    ):
        """Decode data based on the decoding functions given as arguments.

        This method creates a decoder using autodecode.Decoder and applies it using filters.map.

        Args:
            *args: Decoding functions or strings representing image handlers.
            pre (callable, optional): Pre-processing function.
            post (callable, optional): Post-processing function.
            only (list, optional): List of keys to decode.
            partial (bool, optional): Whether to allow partial decoding. Defaults to False.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with decode filter.
        """
        handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
        decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)
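
    # Illustrative sketch: "rgb" is one of the standard autodecode image handler
    # names; string specs and custom callables can be mixed.
    #
    #     decoded = dataset.decode("rgb")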

    def map_dict(self, handler=reraise_exception, **kw):
        """Map the entries in a dict sample with individual functions.

        This method forwards to the filters.map_dict function.

        Args:
            handler (callable, optional): Exception handler. Defaults to reraise_exception.
            **kw: Mapping of keys to functions to apply.

        Returns:
            FluidInterface: Updated pipeline with map_dict filter.
        """
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        """Select samples based on a predicate.

        This method forwards to the filters.select function.

        Args:
            predicate (callable): Function that returns True for samples to keep.
            **kw: Additional keyword arguments for filters.select.

        Returns:
            FluidInterface: Updated pipeline with select filter.
        """
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, **kw):
        """Convert dict samples to tuples.

        This method forwards to the filters.to_tuple function.

        Args:
            *args: Keys to extract from the dict.
            **kw: Additional keyword arguments for filters.to_tuple.

        Returns:
            FluidInterface: Updated pipeline with to_tuple filter.
        """
        return self.compose(filters.to_tuple(*args, **kw))
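
    # Illustrative sketch: each positional argument names a field to extract;
    # "jpg;png" means "jpg, or png if jpg is absent" (webdataset key syntax).
    #
    #     pairs = dataset.to_tuple("jpg;png", "cls")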

    def map_tuple(self, *args, handler=reraise_exception):
        """Map the entries of a tuple with individual functions.

        This method forwards to the filters.map_tuple function.

        Args:
            *args: Functions to apply to each element of the tuple.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with map_tuple filter.
        """
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        """Slice the data stream.

        This method forwards to the filters.slice function.

        Args:
            *args: Arguments for slicing (start, stop, step).

        Returns:
            FluidInterface: Updated pipeline with slice filter.
        """
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        """Rename fields within samples based on keyword arguments.

        This method forwards to the filters.rename function.

        Args:
            **kw: Mapping of new field names to old field names.

        Returns:
            FluidInterface: Updated pipeline with rename filter.
        """
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        """Randomly subsample a stream of data.

        This method forwards to the filters.rsample function.

        Args:
            p (float, optional): Probability of keeping each sample. Defaults to 0.5.

        Returns:
            FluidInterface: Updated pipeline with rsample filter.
        """
        return self.compose(filters.rsample(p))

    def rename_keys(self, *args, **kw):
        """Rename keys in samples based on patterns.

        This method forwards to the filters.rename_keys function.

        Args:
            *args: Positional arguments for filters.rename_keys.
            **kw: Keyword arguments for filters.rename_keys.

        Returns:
            FluidInterface: Updated pipeline with rename_keys filter.
        """
        return self.compose(filters.rename_keys(*args, **kw))

    def extract_keys(self, *args, **kw):
        """Extract specific keys from samples.

        This method forwards to the filters.extract_keys function.

        Args:
            *args: Keys or patterns to extract.
            **kw: Additional keyword arguments for filters.extract_keys.

        Returns:
            FluidInterface: Updated pipeline with extract_keys filter.
        """
        return self.compose(filters.extract_keys(*args, **kw))

    def xdecode(self, *args, **kw):
        """Decode data based on file extensions.

        This method forwards to the filters.xdecode function.

        Args:
            *args: Positional arguments for filters.xdecode.
            **kw: Keyword arguments for filters.xdecode.

        Returns:
            FluidInterface: Updated pipeline with xdecode filter.
        """
        return self.compose(filters.xdecode(*args, **kw))

    def mcached(self):
        """Cache samples in memory.

        This method forwards to the filters.Cached class.

        Returns:
            FluidInterface: Updated pipeline with memory caching.
        """
        return self.compose(filters.Cached())

    def lmdb_cached(self, *args, **kw):
        """Cache samples using LMDB.

        This method forwards to the filters.LMDBCached class.

        Args:
            *args: Positional arguments for filters.LMDBCached.
            **kw: Keyword arguments for filters.LMDBCached.

        Returns:
            FluidInterface: Updated pipeline with LMDB caching.
        """
        return self.compose(filters.LMDBCached(*args, **kw))
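

# Illustrative sketch (hypothetical shard pattern): the fluid methods above
# chain, so caching can cap off a decode pipeline.
#
#     ds = WebDataset("shards-{000..009}.tar", shardshuffle=10)
#     ds = ds.decode("rgb").to_tuple("jpg", "cls").mcached()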


def check_empty(source):
    """Yield samples from the source, raising an error if there are none.

    Args:
        source: An iterable source of samples.

    Yields:
        The samples from the source.

    Raises:
        ValueError: If no samples are found in the dataset.
    """
    count = 0
    for sample in source:
        yield sample
        count += 1
    if count == 0:
        raise ValueError(
            "No samples found in dataset; perhaps you have fewer shards than workers.\n"
            + "Disable this check by passing empty_check=False to the WebDataset constructor."
        )


class WebDataset(DataPipeline, FluidInterface):
    """Create a WebDataset pipeline for efficient data loading.

    This class sets up a data pipeline for loading and processing WebDataset-format data.
    It handles URL generation, shard shuffling, caching, and sample grouping.

    Args:
        urls: The source URLs or specifications for the dataset.
        handler: Function to handle exceptions. Defaults to reraise_exception.
        mode: The mode of operation. Defaults to None.
        resampled: Whether to use resampled mode. Defaults to False.
        repeat: Whether to repeat the dataset. Defaults to False.
        shardshuffle: Buffer size for shard shuffling; set to False or 0 to disable.
            True is accepted with a warning and treated as 100. Defaults to None,
            which emits a warning.
        cache_size: The size of the cache in bytes. Defaults to -1 (unlimited).
        cache_dir: The directory to use for caching. Defaults to None.
        url_to_name: Function to convert URLs to cache names. Defaults to pipe_cleaner.
        detshuffle: Whether to use deterministic shuffling. Defaults to False.
        nodesplitter: Function to split data by node. Defaults to single_node_only.
        workersplitter: Function to split data by worker. Defaults to split_by_worker.
        select_files: Function to select files from tar archives. Defaults to None.
        rename_files: Function to rename files from tar archives. Defaults to None.
        empty_check: Whether to check for empty datasets. Defaults to True.
        verbose: Whether to print verbose output. Defaults to False.
        seed: Random seed for shuffling. Defaults to None.

    Raises:
        ValueError: If the cache directory does not exist or if the URL type is not supported.
    """

    def __init__(
        self,
        urls,
        handler=reraise_exception,
        mode=None,
        resampled=False,
        repeat=False,
        shardshuffle=None,
        cache_size=-1,
        cache_dir=None,
        url_to_name=cache.pipe_cleaner,
        detshuffle=False,
        nodesplitter=shardlists.single_node_only,
        workersplitter=shardlists.split_by_worker,
        select_files=None,
        rename_files=None,
        empty_check=True,
        verbose=False,
        seed=None,
    ):
        super().__init__()
        if resampled:
            mode = "resampled"
        if mode == "resampled" and shardshuffle not in (False, None):
            warnings.warn("WebDataset(shardshuffle=...) is ignored for resampled datasets")
        elif shardshuffle is None:
            warnings.warn("WebDataset(shardshuffle=...) is None; set explicitly to False or a number")
        if shardshuffle is True:
            warnings.warn("set WebDataset(shardshuffle=...) to a positive integer or 0 or False")
            shardshuffle = 100
        args = SimpleNamespace(**locals())
        # WDS_SEED arrives from the environment as a string; coerce it to int so
        # seed arithmetic (e.g., in detshuffle) works.
        self.seed = int(os.environ.get("WDS_SEED", random.randint(0, 1000000))) if seed is None else seed
        self.update_cache_info(args)

        # Stage 1: generate shard URLs from the given specification.
        self.create_url_iterator(args)

        # Stage 2: split shards across nodes and across DataLoader workers.
        if nodesplitter is not None:
            self.append(nodesplitter)
        if workersplitter:
            self.append(workersplitter)

        # Stage 3: shuffle shards, deterministically per epoch if detshuffle is set.
        if args.shardshuffle is not None:
            if args.detshuffle:
                self.append(filters.detshuffle(args.shardshuffle, seed=self.seed))
            else:
                self.append(filters.shuffle(args.shardshuffle, seed=self.seed))

        # Stage 4: open shards, streaming directly or through an on-disk cache.
        if cache_dir is None or cache_size == 0:
            opener = cache.StreamingOpen(handler=handler)
        else:
            opener = cache.FileCache(cache_dir=cache_dir, cache_size=cache_size, handler=handler)
        self.append(opener)

        # Stage 5: expand each tar shard into a stream of files.
        expander = pipelinefilter(tar_file_expander)
        self.append(expander(handler=handler, select_files=select_files, rename_files=rename_files))

        # Stage 6: group files sharing a key (basename) into samples.
        grouper = pipelinefilter(group_by_keys)
        self.append(grouper(handler=handler))

        # Stage 7: optionally fail loudly if the pipeline yields no samples.
        if empty_check:
            self.append(check_empty)

    def update_cache_info(self, args):
        """Update cache information based on arguments and environment variables.

        Args:
            args: A SimpleNamespace object containing the arguments.

        Raises:
            ValueError: If the specified cache directory does not exist.
        """
        args.cache_size = int(os.environ.get("WDS_CACHE_SIZE", args.cache_size))
        args.cache_dir = os.environ.get("WDS_CACHE", args.cache_dir)
        if args.cache_dir is not None:
            args.cache_dir = os.path.expanduser(args.cache_dir)
            if not os.path.exists(args.cache_dir):
                raise ValueError(f"cache directory {args.cache_dir} does not exist")

    def create_url_iterator(self, args):
        """Create an appropriate URL iterator based on the input type.

        This method determines the type of URL input and creates the corresponding
        iterator for the dataset.

        Args:
            args: A SimpleNamespace object containing the arguments.

        Raises:
            ValueError: If the URL type is not supported or implemented.
        """
        urls = args.urls

        # A YAML file path: load a multi-shard dataset specification.
        if isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
            with open(args.urls) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
            self.append(shardlists.MultiShardSample(spec))
            return

        # A dict: treat it as an already-parsed multi-shard specification.
        if isinstance(args.urls, dict):
            assert "datasets" in args.urls
            self.append(shardlists.MultiShardSample(args.urls))
            return

        # JSON dataset specifications are not supported yet.
        if isinstance(urls, str) and urls.endswith(".json"):
            raise ValueError("unimplemented")

        # A URL whose path ends in "/": treat it as a directory of shards.
        if isinstance(urls, str) and urlparse(urls).path.endswith("/"):
            self.append(shardlists.DirectoryShardList(urls, mode=args.mode))
            return

        # A brace-notation string or an iterable of URLs.
        if isinstance(args.urls, str) or utils.is_iterable(args.urls):
            if args.mode == "resampled":
                self.append(shardlists.ResampledShardList(args.urls))
            else:
                self.append(shardlists.SimpleShardList(args.urls))
            return

        raise ValueError(f"cannot handle urls of type {type(args.urls)}")

    def __enter__(self):
        """Enter the runtime context for the WebDataset.

        Returns:
            self: The WebDataset instance.
        """
        return self

    def __exit__(self, *args):
        """Exit the runtime context for the WebDataset.

        Args:
            *args: Exception type, value, and traceback if an exception occurred.
        """
        self.close()
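

# Illustrative sketch (hypothetical shard pattern): WebDataset works as a
# context manager, and decoding/tupling chain off the constructed pipeline.
#
#     with WebDataset("data-{000000..000099}.tar", shardshuffle=100) as ds:
#         for image, label in ds.decode("rgb").to_tuple("jpg", "cls"):
#             ...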


class FluidWrapper(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(self, initial):
        super().__init__()
        self.append(initial)


class WebLoader(DataPipeline, FluidInterface):
    """A wrapper for DataLoader that adds a fluid interface."""

    def __init__(self, *args, **kw):
        super().__init__(DataLoader(*args, **kw))
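

# Illustrative sketch (hypothetical parameters): wrapping a dataset in WebLoader
# gives DataLoader semantics plus the fluid interface, so the stream can be
# rebatched after loading.
#
#     loader = WebLoader(ds, batch_size=None, num_workers=4)
#     loader = loader.unbatched().shuffle(1000).batched(32)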