jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
from . import utils
from .pytorch import IterableDataset
from .utils import PipelineStage
class MockDataset(IterableDataset):
    """Synthetic dataset that emits one fixed sample a set number of times.

    Intended for performance testing and unit testing, where the content
    of the samples does not matter.

    Args:
        sample: The object handed out on every iteration step.
        length (int): How many times the sample is yielded per pass.
    """

    def __init__(self, sample, length):
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Iterate over the mock dataset.

        Returns:
            Iterator: Yields ``self.sample`` exactly ``self.length`` times.
        """
        # The same object is handed out each time; no copy is made.
        yield from (self.sample for _ in range(self.length))
class repeatedly(IterableDataset, PipelineStage):
    """Pipeline stage that repeatedly yields samples from a dataset.

    The actual repetition is delegated to ``utils.repeatedly``; this class
    only records the limits and forwards them.

    Args:
        source: The source dataset to repeat.
        nepochs (int, optional): Maximum number of epochs to repeat.
        nbatches (int, optional): Maximum number of batches to repeat.
        length (int, optional): Length of the repeated dataset.
    """

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        self.source = source
        self.length = length
        # invoke() forwards self.nepochs, so it must be stored here;
        # leaving it unset made invoke() fail with AttributeError.
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source.

        Args:
            source: The source dataset to repeat.

        Returns:
            Iterator: The result of ``utils.repeatedly`` called with the
            configured ``nepochs`` and ``nbatches`` limits.
        """
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )
class with_epoch(IterableDataset):
    """Impose fixed-size epochs on an IterableDataset.

    Each call to :meth:`invoke` hands out at most ``length`` samples,
    restarting the underlying dataset whenever it runs dry, so epoch
    boundaries are set by ``length`` rather than by the dataset itself.

    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.

    Args:
        dataset: The source IterableDataset. The constructor does not keep
            it; the dataset to read from is passed to :meth:`invoke`.
        length (int): Number of samples per epoch.
    """

    def __init__(self, dataset, length):
        super().__init__()
        self.length = length
        # Live iterator over the dataset, created lazily by invoke().
        self.source = None

    def __getstate__(self):
        """Build the state used for pickling.

        The live iterator cannot be pickled, so the copy of the instance
        dictionary has its ``source`` entry cleared.

        Returns:
            dict: A copy of the instance attributes with ``source`` set to None.
        """
        return {**self.__dict__, "source": None}

    def invoke(self, dataset):
        """Yield up to ``length`` samples from the dataset.

        Args:
            dataset: The source dataset to iterate over.

        Yields:
            Sample: The next sample from the dataset.
        """
        exhausted = object()  # unique marker for "iterator ran out"
        if self.source is None:
            self.source = iter(dataset)
        for _ in range(self.length):
            item = next(self.source, exhausted)
            if item is exhausted:
                # Underlying iterator ran dry mid-epoch: start it over.
                self.source = iter(dataset)
                item = next(self.source, exhausted)
                if item is exhausted:
                    # A freshly restarted dataset is empty; give up.
                    return
            yield item
        # A full epoch was delivered; drop the iterator so the next call
        # starts from a fresh one.
        self.source = None
class with_length(IterableDataset, PipelineStage):
    """Attach a user-specified length to a dataset.

    Iteration passes straight through to the source dataset; only the
    value reported by ``len()`` is supplied by this wrapper.

    Args:
        dataset: The source dataset.
        length (int): The stated length of the dataset.
    """

    def __init__(self, dataset, length):
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Create an iterator over the given dataset.

        Args:
            dataset: The source dataset to iterate over.

        Returns:
            Iterator: An iterator over the source dataset.
        """
        iterator = iter(dataset)
        return iterator

    def __len__(self):
        """Report the length supplied at construction time.

        Returns:
            int: The specified length of the dataset.
        """
        declared_length = self.length
        return declared_length