|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Train PyTorch models directly from POSIX tar archive. |
|
|
|
Code works locally or over HTTP connections. |
|
""" |
|
|
|
|
|
from . import utils |
|
from .pytorch import IterableDataset |
|
from .utils import PipelineStage |
|
|
|
|
|
class MockDataset(IterableDataset):
    """Mock dataset that returns a single fixed sample repeatedly.

    Useful for performance measurements and unit tests where the
    content of the samples does not matter.

    Args:
        sample: The sample to be returned repeatedly.
        length (int): The length of the mock dataset.
    """

    def __init__(self, sample, length):
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Yield the stored sample exactly ``self.length`` times.

        Returns:
            Iterator: An iterator producing the same sample on each step.
        """
        remaining = self.length
        while remaining > 0:
            yield self.sample
            remaining -= 1
|
|
|
|
|
class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset.

    Args:
        source: The source dataset to repeat.
        nepochs (int, optional): Maximum number of epochs to repeat.
        nbatches (int, optional): Maximum number of batches to repeat.
        length (int, optional): Length of the repeated dataset.
    """

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        super().__init__()
        self.source = source
        self.length = length
        # BUG FIX: nepochs was previously never stored, so invoke()
        # raised AttributeError when reading self.nepochs.
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source.

        Args:
            source: The source dataset to repeat.

        Returns:
            Iterator: An iterator that repeatedly yields samples from the source.
        """
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )
|
|
|
|
|
class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    This will continuously iterate through the original dataset, but
    impose new epoch boundaries at the given length/nominal.
    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.

    Args:
        dataset: The source IterableDataset.
        length (int): Declared length of the dataset.

    NOTE(review): ``dataset`` is accepted for interface compatibility but
    is not stored; the actual data source is the argument to ``invoke``.
    """

    def __init__(self, dataset, length):
        super().__init__()
        self.length = length
        # Persisted iterator over the underlying data; carried across
        # invoke() calls so iteration continues where it left off.
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.

        This resets the dataset iterator, since that can't be pickled.

        Returns:
            dict: A dictionary representing the pickled state of the dataset.
        """
        # Drop the live iterator from the pickled state.
        return {**self.__dict__, "source": None}

    def invoke(self, dataset):
        """Yield up to ``self.length`` samples from ``dataset``.

        Iteration resumes from wherever the previous call stopped; when
        the underlying iterator is exhausted mid-epoch, it is restarted
        transparently.

        Args:
            dataset: The source dataset to iterate over.

        Yields:
            Sample: The next sample from the dataset.
        """
        if self.source is None:
            self.source = iter(dataset)
        emitted = 0
        while emitted < self.length:
            try:
                sample = next(self.source)
            except StopIteration:
                # Underlying dataset exhausted: restart it and keep going.
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    # The dataset is empty; give up rather than loop forever.
                    return
            yield sample
            emitted += 1
        # Full nominal epoch delivered; reset so the next call restarts.
        self.source = None
|
|
|
|
|
class with_length(IterableDataset, PipelineStage):
    """Wrap a dataset and report a caller-specified ``len()``.

    Args:
        dataset: The source dataset.
        length (int): The stated length of the dataset.
    """

    def __init__(self, dataset, length):
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Return an iterator that iterates over the source dataset.

        Args:
            dataset: The source dataset to iterate over.

        Returns:
            Iterator: An iterator over the source dataset.
        """
        it = iter(dataset)
        return it

    def __len__(self):
        """Return the user specified length.

        Returns:
            int: The specified length of the dataset.
        """
        return self.length
|
|