import os
import random
import warnings
from types import SimpleNamespace
from urllib.parse import urlparse
import yaml
from . import autodecode, cache, filters, shardlists, utils
from .filters import pipelinefilter, reraise_exception
from .pipeline import DataPipeline
from .pytorch import DataLoader
from .tariterators import group_by_keys, tar_file_expander
class FluidInterface:
def batched(self, batchsize, collation_fn=filters.default_collation_fn, partial=True):
"""Create batches of the given size.
This method forwards to the filters.batched function.
Args:
batchsize (int): Target batch size.
collation_fn (callable, optional): Function to collate samples into a batch.
Defaults to filters.default_collation_fn.
partial (bool, optional): Whether to return partial batches. Defaults to True.
Returns:
FluidInterface: Updated pipeline with batched filter.
"""
return self.compose(filters.batched(batchsize, collation_fn=collation_fn, partial=partial))
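# Usage sketch for batched/unbatched (assumes `dataset` is an existing
# WebDataset yielding individual samples):
#
#     batches = dataset.batched(16)    # collated batches of up to 16 samples
#     singles = batches.unbatched()    # back to individual samples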
def unbatched(self):
"""Turn batched data back into unbatched data.
This method forwards to the filters.unbatched function.
Returns:
FluidInterface: Updated pipeline with unbatched filter.
"""
return self.compose(filters.unbatched())
def listed(self, batchsize, partial=True):
"""Create lists of samples without collation.
This method forwards to the filters.batched function with collation_fn set to None.
Args:
batchsize (int): Target list size.
partial (bool, optional): Whether to return partial lists. Defaults to True.
Returns:
FluidInterface: Updated pipeline with listed filter.
"""
return self.compose(filters.batched(batchsize=batchsize, collation_fn=None, partial=partial))
def unlisted(self):
"""Turn listed data back into individual samples.
This method forwards to the filters.unlisted function.
Returns:
FluidInterface: Updated pipeline with unlisted filter.
"""
return self.compose(filters.unlisted())
def log_keys(self, logfile=None):
"""Log keys of samples passing through the pipeline.
This method forwards to the filters.log_keys function.
Args:
logfile (str, optional): Path to the log file. If None, logging is disabled.
Returns:
FluidInterface: Updated pipeline with log_keys filter.
"""
return self.compose(filters.log_keys(logfile))
def shuffle(self, size, **kw):
"""Shuffle the data in the stream.
This method forwards to the filters.shuffle function if size > 0.
Args:
size (int): Buffer size for shuffling.
**kw: Additional keyword arguments for filters.shuffle.
Returns:
FluidInterface: Updated pipeline with shuffle filter, or self if size < 1.
"""
if size < 1:
return self
else:
return self.compose(filters.shuffle(size, **kw))
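# Usage sketch for shuffle: samples are mixed through an in-memory buffer.
# A larger buffer gives better mixing at the cost of memory; size=0 leaves
# the stream unchanged. (Assumes `dataset` is an existing WebDataset.)
#
#     dataset = dataset.shuffle(1000)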
def map(self, f, handler=reraise_exception):
"""Apply a function to each sample in the stream.
This method forwards to the filters.map function.
Args:
f (callable): Function to apply to each sample.
handler (callable, optional): Exception handler. Defaults to reraise_exception.
Returns:
FluidInterface: Updated pipeline with map filter.
"""
return self.compose(filters.map(f, handler=handler))
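# Usage sketch for map: apply a function to every sample dict. The `add_id`
# helper below is illustrative only, not part of the library.
#
#     def add_id(sample):
#         sample["__id__"] = sample["__key__"]
#         return sample
#
#     dataset = dataset.map(add_id)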
def decode(
self,
*args,
pre=None,
post=None,
only=None,
partial=False,
handler=reraise_exception,
):
"""Decode data based on the decoding functions given as arguments.
This method creates a decoder using autodecode.Decoder and applies it using filters.map.
Args:
*args: Decoding functions or strings representing image handlers.
pre (callable, optional): Pre-processing function.
post (callable, optional): Post-processing function.
only (list, optional): List of keys to decode.
partial (bool, optional): Whether to allow partial decoding. Defaults to False.
handler (callable, optional): Exception handler. Defaults to reraise_exception.
Returns:
FluidInterface: Updated pipeline with decode filter.
"""
handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
return self.map(decoder, handler=handler)
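# Usage sketch for decode: string arguments are interpreted as image handler
# specs (e.g. "pil" for PIL images, "rgb" for float arrays), callables are
# used as-is. (Assumes `dataset` yields samples with image fields.)
#
#     dataset = dataset.decode("pil")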
def map_dict(self, handler=reraise_exception, **kw):
"""Map the entries in a dict sample with individual functions.
This method forwards to the filters.map_dict function.
Args:
handler (callable, optional): Exception handler. Defaults to reraise_exception.
**kw: Mapping of keys to functions to apply.
Returns:
FluidInterface: Updated pipeline with map_dict filter.
"""
return self.compose(filters.map_dict(handler=handler, **kw))
def select(self, predicate, **kw):
"""Select samples based on a predicate.
This method forwards to the filters.select function.
Args:
predicate (callable): Function that returns True for samples to keep.
**kw: Additional keyword arguments for filters.select.
Returns:
FluidInterface: Updated pipeline with select filter.
"""
return self.compose(filters.select(predicate, **kw))
def to_tuple(self, *args, **kw):
"""Convert dict samples to tuples.
This method forwards to the filters.to_tuple function.
Args:
*args: Keys to extract from the dict.
**kw: Additional keyword arguments for filters.to_tuple.
Returns:
FluidInterface: Updated pipeline with to_tuple filter.
"""
return self.compose(filters.to_tuple(*args, **kw))
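# Usage sketch for to_tuple: pick fields out of each dict sample by key,
# producing tuples in the given order; ";"-separated alternatives select the
# first matching extension.
#
#     dataset = dataset.to_tuple("jpg;png", "cls")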
def map_tuple(self, *args, handler=reraise_exception):
"""Map the entries of a tuple with individual functions.
This method forwards to the filters.map_tuple function.
Args:
*args: Functions to apply to each element of the tuple.
handler (callable, optional): Exception handler. Defaults to reraise_exception.
Returns:
FluidInterface: Updated pipeline with map_tuple filter.
"""
return self.compose(filters.map_tuple(*args, handler=handler))
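# Usage sketch for map_tuple: apply one function per tuple position. The
# `preprocess` transform below is a hypothetical image preprocessing function.
#
#     dataset = dataset.to_tuple("jpg;png", "cls").map_tuple(preprocess, int)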
def slice(self, *args):
"""Slice the data stream.
This method forwards to the filters.slice function.
Args:
*args: Arguments for slicing (start, stop, step).
Returns:
FluidInterface: Updated pipeline with slice filter.
"""
return self.compose(filters.slice(*args))
def rename(self, **kw):
"""Rename samples based on keyword arguments.
This method forwards to the filters.rename function.
Args:
**kw: Mapping of old names to new names.
Returns:
FluidInterface: Updated pipeline with rename filter.
"""
return self.compose(filters.rename(**kw))
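# Usage sketch for rename: keyword names become the new keys and their values
# name the old keys; ";"-separated alternatives are accepted.
#
#     dataset = dataset.rename(image="jpg;png;jpeg", label="cls")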
def rsample(self, p=0.5):
"""Randomly subsample a stream of data.
This method forwards to the filters.rsample function.
Args:
p (float, optional): Probability of keeping each sample. Defaults to 0.5.
Returns:
FluidInterface: Updated pipeline with rsample filter.
"""
return self.compose(filters.rsample(p))
def rename_keys(self, *args, **kw):
"""Rename keys in samples based on patterns.
This method forwards to the filters.rename_keys function.
Args:
*args: Positional arguments for filters.rename_keys.
**kw: Keyword arguments for filters.rename_keys.
Returns:
FluidInterface: Updated pipeline with rename_keys filter.
"""
return self.compose(filters.rename_keys(*args, **kw))
def extract_keys(self, *args, **kw):
"""Extract specific keys from samples.
This method forwards to the filters.extract_keys function.
Args:
*args: Keys or patterns to extract.
**kw: Additional keyword arguments for filters.extract_keys.
Returns:
FluidInterface: Updated pipeline with extract_keys filter.
"""
return self.compose(filters.extract_keys(*args, **kw))
def xdecode(self, *args, **kw):
"""Decode data based on file extensions.
This method forwards to the filters.xdecode function.
Args:
*args: Positional arguments for filters.xdecode.
**kw: Keyword arguments for filters.xdecode.
Returns:
FluidInterface: Updated pipeline with xdecode filter.
"""
return self.compose(filters.xdecode(*args, **kw))
def mcached(self):
"""Cache samples in memory.
This method forwards to the filters.Cached class.
Returns:
FluidInterface: Updated pipeline with memory caching.
"""
return self.compose(filters.Cached())
def lmdb_cached(self, *args, **kw):
"""Cache samples using LMDB.
This method forwards to the filters.LMDBCached class.
Args:
*args: Positional arguments for filters.LMDBCached.
**kw: Keyword arguments for filters.LMDBCached.
Returns:
FluidInterface: Updated pipeline with LMDB caching.
"""
return self.compose(filters.LMDBCached(*args, **kw))
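# Usage sketch for the caching helpers: mcached keeps iterated samples in
# memory, lmdb_cached persists them on disk (assuming filters.LMDBCached takes
# an LMDB database path as its first argument; the path below is illustrative).
#
#     dataset = dataset.mcached()
#     dataset = dataset.lmdb_cached("/tmp/samples.lmdb")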
def check_empty(source):
"""Check if the dataset is empty and yield samples.
Args:
source: An iterable source of samples.
Yields:
The samples from the source.
Raises:
ValueError: If no samples are found in the dataset.
"""
count = 0
for sample in source:
yield sample
count += 1
if count == 0:
raise ValueError(
"No samples found in dataset; perhaps you have fewer shards than workers.\n"
+ "Turn off using empty_check=False in the WebDataset constructor."
)
class WebDataset(DataPipeline, FluidInterface):
"""Create a WebDataset pipeline for efficient data loading.
This class sets up a data pipeline for loading and processing WebDataset-format data.
It handles URL generation, shard shuffling, caching, and sample grouping.
Args:
urls: The source URLs or specifications for the dataset.
handler: Function to handle exceptions. Defaults to reraise_exception.
mode: The mode of operation. Defaults to None.
resampled: Whether to use resampled mode. Defaults to False.
repeat: Whether to repeat the dataset. Defaults to False.
shardshuffle: The number of shards to shuffle, or None. Defaults to None.
cache_size: The size of the cache in bytes. Defaults to -1 (unlimited).
cache_dir: The directory to use for caching. Defaults to None.
url_to_name: Function to convert URLs to cache names. Defaults to pipe_cleaner.
detshuffle: Whether to use deterministic shuffling. Defaults to False.
nodesplitter: Function to split data by node. Defaults to single_node_only.
workersplitter: Function to split data by worker. Defaults to split_by_worker.
select_files: Function to select files from tar archives. Defaults to None.
rename_files: Function to rename files from tar archives. Defaults to None.
empty_check: Whether to check for empty datasets. Defaults to True.
verbose: Whether to print verbose output. Defaults to False.
seed: Random seed for shuffling. Defaults to None.
Raises:
ValueError: If the cache directory does not exist or if the URL type is not supported.
"""
def __init__(
self,
urls,
handler=reraise_exception,
mode=None,
resampled=False,
repeat=False,
shardshuffle=None,
cache_size=-1,
cache_dir=None,
url_to_name=cache.pipe_cleaner,
detshuffle=False,
nodesplitter=shardlists.single_node_only,
workersplitter=shardlists.split_by_worker,
select_files=None,
rename_files=None,
empty_check=True,
verbose=False,
seed=None,
):
super().__init__()
if resampled:
mode = "resampled"
if mode == "resampled" and shardshuffle not in (False, None):
warnings.warn("WebDataset(shardshuffle=...) is ignored for resampled datasets")
elif shardshuffle is None:
warnings.warn("WebDataset(shardshuffle=...) is None; set explicitly to False or a number")
if shardshuffle is True:
warnings.warn("set WebDataset(shardshuffle=...) to a positive integer or 0 or False")
shardshuffle = 100
args = SimpleNamespace(**locals())
self.seed = int(os.environ.get("WDS_SEED", random.randint(0, 1000000))) if seed is None else seed
self.update_cache_info(args)
# first, we add a generator for the urls to be used
# this generates a stream of dict(url=...)
self.create_url_iterator(args)
# split by node (for distributed processing)
if nodesplitter is not None:
self.append(nodesplitter)
# split by worker (for DataLoader)
if workersplitter:
self.append(workersplitter)
# add a shard shuffler
if args.shardshuffle is not None:
if args.detshuffle:
self.append(filters.detshuffle(args.shardshuffle, seed=self.seed))
else:
self.append(filters.shuffle(args.shardshuffle, seed=self.seed))
# next, we select a URL opener, either with or without caching
# this generates a stream of dict(url=..., stream=...)
if cache_dir is None or cache_size == 0:
opener = cache.StreamingOpen(handler=handler)
else:
opener = cache.FileCache(cache_dir=cache_dir, cache_size=cache_size, handler=handler)
self.append(opener)
# now we need to open each stream and read the tar files contained in it
# this generates a stream of dict(fname=..., data=...) objects
expander = pipelinefilter(tar_file_expander)
self.append(expander(handler=handler, select_files=select_files, rename_files=rename_files))
# finally, the files need to be grouped into samples
# this generates a stream of dict(__key__=..., ...=...) objects
grouper = pipelinefilter(group_by_keys)
self.append(grouper(handler=handler))
# check for empty datasets
if empty_check:
self.append(check_empty)
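# Constructor usage sketch (the shard pattern and cache directory are
# illustrative; resampled=True draws shards at random with replacement):
#
#     ds = WebDataset(
#         "shards/data-{000000..000099}.tar",
#         shardshuffle=100,
#         resampled=True,
#         cache_dir="./wds-cache",
#     )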
def update_cache_info(self, args):
"""Update cache information based on arguments and environment variables.
Args:
args: A SimpleNamespace object containing the arguments.
Raises:
ValueError: If the specified cache directory does not exist.
"""
args.cache_size = int(os.environ.get("WDS_CACHE_SIZE", args.cache_size))
args.cache_dir = os.environ.get("WDS_CACHE", args.cache_dir)
if args.cache_dir is not None:
args.cache_dir = os.path.expanduser(args.cache_dir)
if not os.path.exists(args.cache_dir):
raise ValueError(f"cache directory {args.cache_dir} does not exist")
def create_url_iterator(self, args):
"""Create an appropriate URL iterator based on the input type.
This method determines the type of URL input and creates the corresponding
iterator for the dataset.
Args:
args: A SimpleNamespace object containing the arguments.
Raises:
ValueError: If the URL type is not supported or implemented.
"""
urls = args.urls
# .yaml specification files
if isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
with open(args.urls) as stream:
spec = yaml.safe_load(stream)
assert "datasets" in spec
self.append(shardlists.MultiShardSample(spec))
return
# .yaml specifications already loaded as dictionaries
if isinstance(args.urls, dict):
assert "datasets" in args.urls
self.append(shardlists.MultiShardSample(args.urls))
return
# .json specification files (from wids)
if isinstance(urls, str) and urls.endswith(".json"):
raise ValueError("unimplemented")
# any URL ending in "/" is assumed to be a directory
if isinstance(urls, str) and urlparse(urls).path.endswith("/"):
self.append(shardlists.DirectoryShardList(urls, mode=args.mode))
return
# the rest is either a shard list or a resampled shard list
if isinstance(args.urls, str) or utils.is_iterable(args.urls):
if args.mode == "resampled":
self.append(shardlists.ResampledShardList(args.urls))
else:
self.append(shardlists.SimpleShardList(args.urls))
return
raise ValueError(f"cannot handle urls of type {type(args.urls)}")
def __enter__(self):
"""Enter the runtime context for the WebDataset.
Returns:
self: The WebDataset instance.
"""
return self
def __exit__(self, *args):
"""Exit the runtime context for the WebDataset.
Args:
*args: Exception type, value, and traceback if an exception occurred.
"""
self.close()
class FluidWrapper(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
def __init__(self, initial):
super().__init__()
self.append(initial)
class WebLoader(DataPipeline, FluidInterface):
"""A wrapper for DataLoader that adds a fluid interface."""
def __init__(self, *args, **kw):
super().__init__(DataLoader(*args, **kw))
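# End-to-end usage sketch (shard pattern and field names are illustrative):
#
#     dataset = (
#         WebDataset("data-{000000..000009}.tar", shardshuffle=100)
#         .shuffle(1000)
#         .decode("pil")
#         .to_tuple("jpg;png", "cls")
#         .batched(16)
#     )
#     loader = WebLoader(dataset, batch_size=None, num_workers=4)
#     for images, labels in loader:
#         ...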