File size: 4,482 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
from . import utils
from .pytorch import IterableDataset
from .utils import PipelineStage
class MockDataset(IterableDataset):
    """Fixed-size dataset that emits one and the same sample every time.

    Useful for performance testing and unit testing.

    Args:
        sample: The object handed out at every iteration step.
        length (int): Number of times ``sample`` is produced per pass.
    """

    def __init__(self, sample, length):
        self.length = length
        self.sample = sample

    def __iter__(self):
        """Iterate over the mock dataset.

        Yields:
            The stored ``sample``, exactly ``length`` times.
        """
        yield from (self.sample for _ in range(self.length))
class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset.

    The actual repetition is delegated to ``utils.repeatedly``; this class
    only stores the limits and forwards them.

    Args:
        source: The source dataset to repeat. Stored on the instance; note
            that ``invoke`` iterates over the ``source`` argument it is
            given, not over this stored value.
        nepochs (int, optional): Maximum number of epochs to repeat.
        nbatches (int, optional): Maximum number of batches to repeat.
        length (int, optional): Length of the repeated dataset. Stored on
            the instance but not read anywhere in this class.
    """

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        self.source = source
        self.length = length
        # invoke() reads self.nepochs, so it has to be stored here;
        # otherwise every call to invoke() fails with AttributeError.
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source.

        Args:
            source: The source dataset to repeat.

        Returns:
            Iterator: Whatever ``utils.repeatedly`` returns for ``source``
            with the configured ``nepochs`` and ``nbatches`` limits.
        """
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )
class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    Samples are drawn from the underlying dataset, and an epoch boundary is
    imposed after ``length`` samples. If the underlying dataset runs out
    in the middle of an epoch, it is restarted and the epoch continues.

    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.

    Args:
        dataset: The source IterableDataset. Accepted for interface
            reasons; the constructor does not store it. The dataset to
            read from is passed to ``invoke``.
        length (int): Number of samples that make up one epoch.
    """

    def __init__(self, dataset, length):
        super().__init__()
        # Live iterator over the dataset; created lazily inside invoke().
        self.source = None
        self.length = length

    def __getstate__(self):
        """Return the state used for pickling.

        The live iterator cannot be pickled, so it is replaced by ``None``
        in the returned copy. The instance itself is left unchanged.

        Returns:
            dict: A copy of the instance attributes with ``"source"`` set
            to ``None``.
        """
        return {**self.__dict__, "source": None}

    def invoke(self, dataset):
        """Yield one epoch worth of samples from the dataset.

        At most ``length`` samples are produced. When the current iterator
        is exhausted, a fresh one is obtained from ``dataset`` and
        iteration continues. If that fresh iterator is empty as well, the
        generator stops early.

        Args:
            dataset: The source dataset to iterate over.

        Yields:
            Sample: The next sample from the dataset.
        """
        exhausted = object()  # private marker for "iterator had nothing left"

        if self.source is None:
            self.source = iter(dataset)

        for _ in range(self.length):
            item = next(self.source, exhausted)
            if item is exhausted:
                # Ran off the end of the dataset: restart it once.
                self.source = iter(dataset)
                item = next(self.source, exhausted)
                if item is exhausted:
                    # Even a fresh iterator yields nothing; give up.
                    return
            yield item

        # A full epoch was delivered; drop the iterator so that the next
        # call starts again from a fresh iter(dataset).
        self.source = None
class with_length(IterableDataset, PipelineStage):
    """Attach a user-specified length to a dataset.

    Iteration is passed straight through to the dataset given to
    ``invoke``; only ``len()`` is overridden to report the stated value.

    Args:
        dataset: The source dataset. Stored on the instance as ``dataset``.
        length (int): The stated length of the dataset.
    """

    def __init__(self, dataset, length):
        super().__init__()
        self.length = length
        self.dataset = dataset

    def __len__(self):
        """Return the user specified length.

        Returns:
            int: The ``length`` value given at construction time.
        """
        return self.length

    def invoke(self, dataset):
        """Return an iterator over the given dataset.

        Args:
            dataset: The source dataset to iterate over.

        Returns:
            Iterator: The result of ``iter(dataset)``.
        """
        return iter(dataset)
|