File size: 6,372 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
import copy
import sys
import warnings
from itertools import islice
from .pytorch import DataLoader, IterableDataset
from .utils import PipelineStage
def add_length_method(obj):
    """Add a length method to the given object.

    The object's class is replaced by a freshly synthesized subclass that
    also derives from IterableDataset and defines ``__len__`` returning
    ``self.size``.

    Args:
        obj: The object to which the length method will be added.

    Returns:
        The modified object with a new length method.
    """

    def _len(self):
        return self.size

    base = obj.__class__
    # Build a derived class on the fly; mixing in IterableDataset keeps
    # the result usable wherever a dataset is expected.
    augmented = type(
        base.__name__ + "_Length",
        (base, IterableDataset),
        {"__len__": _len},
    )
    obj.__class__ = augmented
    return obj
class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters.

    Stages are invoked in order; the first stage produces an iterator and
    each subsequent stage transforms the iterator produced by its
    predecessor.

    Args:
        *args: Variable length argument list of pipeline stages; ``None``
            entries are skipped and lists are flattened into the pipeline.
        **kwargs: Arbitrary keyword arguments (currently unused).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pipeline = []        # ordered list of pipeline stages
        self.length = -1          # nominal length; -1 means unknown
        self.repetitions = 1      # how many epochs self.iterator() runs
        self.nsamples = -1        # overall sample cap; -1 means unlimited
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def close(self):
        """Close the pipeline and release resources.

        Calls ``close()`` on every stage that provides one, then drops the
        stage list; the pipeline is unusable afterwards.
        """
        for step in self.pipeline:
            if hasattr(step, "close"):
                step.close()
        del self.pipeline

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage.

        Args:
            f: The pipeline stage to invoke.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            The result of invoking the pipeline stage.

        Raises:
            ValueError: If the pipeline stage is not valid.
        """
        # Datasets/loaders are only valid as a *source* (no upstream input).
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            result = f(*args, **kwargs)
            return result
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline.

        Returns:
            An iterator for one epoch of the pipeline.
        """
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions.

        Yields:
            Samples from the dataset.
        """
        for _ in range(self.repetitions):
            count = 0
            for sample in self.iterator1():
                yield sample
                count += 1
            if count == 0:
                # if the dataset is empty, don't keep looping
                break

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested.

        Returns:
            An iterator through the pipeline.
        """
        # Fix: apply the sample cap whenever nsamples is set. Previously the
        # cap was only honored when repetitions != 1, so repeat(nepochs=1,
        # nbatches=n) silently ignored n, contradicting repeat()'s contract.
        if self.nsamples > 0:
            return islice(self.iterator(), self.nsamples)
        return self.iterator()

    def stage(self, i):
        """Return pipeline stage i.

        Args:
            i: The index of the pipeline stage to return.

        Returns:
            The pipeline stage at index i.
        """
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object).

        Args:
            f: The pipeline stage to append.
        """
        self.pipeline.append(f)

    def compose(self, *args):
        """Append pipeline stages to a copy of the pipeline and return the copy.

        Args:
            *args: Variable length argument list of pipeline stages to append.

        Returns:
            A new DataPipeline object with the appended stages.
        """
        result = copy.copy(self)
        # Copy the stage list too so the original pipeline is not mutated.
        result.pipeline = copy.copy(result.pipeline)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n, silent=False):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.

        Args:
            n: The length value to set.
            silent: If True, suppress the warning message.

        Returns:
            The modified DataPipeline object with a __len__ method.
        """
        if not silent:
            warnings.warn(
                ".with_length() only sets the value of __len__ for compatibility "
                + "with some training environments. It does not change the number of "
                + "samples in an epoch."
            )
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        Args:
            nsamples: The number of samples per epoch.
            nbatches: The number of batches per epoch.

        Returns:
            The modified DataPipeline object.
        """
        # Repeat effectively forever; the islice in __iter__ ends the epoch.
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given number of epochs up to the given number of samples.

        Args:
            nepochs: The number of epochs to repeat; values <= 0 mean
                "repeat indefinitely".
            nbatches: The number of batches to limit per repetition.

        Returns:
            The modified DataPipeline object.
        """
        self.repetitions = nepochs if nepochs > 0 else sys.maxsize
        self.nsamples = nbatches
        return self
|