import copy
import sys
import warnings
from itertools import islice

from .pytorch import DataLoader, IterableDataset
from .utils import PipelineStage


def add_length_method(obj):
    """Give ``obj`` a ``__len__`` method that reports ``obj.size``.

    A new class named ``<CurrentClassName>_Length`` is created on the fly. It
    derives from the object's current class and from ``IterableDataset`` and
    defines ``__len__``; the object's ``__class__`` is then switched to it.

    Args:
        obj: The object to modify in place. ``obj.size`` must be set before
            ``len(obj)`` is called, because ``__len__`` simply returns it.

    Returns:
        The same object, now an instance of the generated subclass.
    """

    def length(self):
        return self.size

    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length},
    )
    obj.__class__ = Combined
    return obj


class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters.

    The first stage acts as the source; every later stage is invoked with the
    iterator produced by the stage before it (see ``invoke``).

    Args:
        *args: Pipeline stages. ``None`` entries are skipped and list entries
            are flattened into the pipeline.
        **kwargs: Accepted for call compatibility; not read by this class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Ordered list of stages; index 0 is the source.
        self.pipeline = []
        # Initialised to -1 and not read anywhere in this module.
        self.length = -1
        # Number of passes over the underlying pipeline made by `iterator`.
        self.repetitions = 1
        # Upper bound on items yielded by one `iter()` call; <= 0 means no bound.
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def close(self):
        """Close the pipeline and release resources.

        Calls ``close()`` on every stage that has one and then deletes the
        ``pipeline`` attribute. Calling ``close`` again afterwards is a no-op.
        """
        pipeline = getattr(self, "pipeline", None)
        if pipeline is None:
            # Already closed: nothing left to release.
            return
        for step in pipeline:
            if hasattr(step, "close"):
                step.close()
        del self.pipeline

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage.

        The checks are tried in this order:

        1. An ``IterableDataset`` or ``DataLoader`` called without positional
           arguments is treated as a source and iterated.
        2. A ``PipelineStage`` has its ``run`` method called with the arguments.
        3. A list is iterated (any arguments are ignored).
        4. Any other callable is called with the arguments.

        Args:
            f: The pipeline stage to invoke.
            *args: Positional arguments, normally the upstream iterator.
            **kwargs: Keyword arguments forwarded to the stage.

        Returns:
            The result of invoking the pipeline stage.

        Raises:
            ValueError: If ``f`` matches none of the cases above.
        """
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            return f(*args, **kwargs)
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline.

        Returns:
            The result of chaining all stages: the first stage is invoked
            without arguments and each following stage receives the output of
            the previous one.

        Raises:
            IndexError: If the pipeline has no stages.
        """
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions.

        Yields:
            Samples from the dataset, one full pass per repetition.
        """
        for _ in range(self.repetitions):
            count = 0
            for sample in self.iterator1():
                yield sample
                count += 1
            if count == 0:
                # If the dataset is empty, don't keep looping.
                break

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested.

        The ``nsamples`` limit is applied whenever it is positive, regardless
        of the repetition count, so ``repeat(nepochs=1, nbatches=k)`` is
        truncated to ``k`` items as well.

        Returns:
            An iterator through the pipeline.
        """
        source = self.iterator()
        if self.nsamples > 0:
            return islice(source, self.nsamples)
        return source

    def stage(self, i):
        """Return pipeline stage i.

        Args:
            i: The index of the pipeline stage to return.

        Returns:
            The pipeline stage at index i.
        """
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object).

        Args:
            f: The pipeline stage to append.
        """
        self.pipeline.append(f)

    def compose(self, *args):
        """Append pipeline stages to a copy of the pipeline and return the copy.

        The copy is shallow: the stage list is duplicated so the original is
        not modified, but the stage objects themselves are shared.

        Args:
            *args: Pipeline stages to append.

        Returns:
            A new DataPipeline object with the appended stages.
        """
        result = copy.copy(self)
        result.pipeline = copy.copy(result.pipeline)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n, silent=False):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.

        Args:
            n: The length value to set.
            silent: If True, suppress the warning message.

        Returns:
            The modified DataPipeline object with a __len__ method.
        """
        if not silent:
            warnings.warn(
                ".with_length() only sets the value of __len__ for compatibility "
                + "with some training environments. It does not change the number of "
                + "samples in an epoch.",
                # Attribute the warning to the caller rather than this module.
                stacklevel=2,
            )
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The pipeline is set to repeat ``sys.maxsize`` times and each ``iter()``
        call is cut off after ``max(nsamples, nbatches)`` items. The two
        arguments are interchangeable; the limit counts whatever items the
        pipeline yields.

        Args:
            nsamples: The number of samples per epoch.
            nbatches: The number of batches per epoch.

        Returns:
            The modified DataPipeline object.
        """
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given number of epochs.

        Args:
            nepochs: The number of passes over the dataset. Values <= 0 mean
                ``sys.maxsize`` passes.
            nbatches: Upper bound on the total number of items yielded by one
                ``iter()`` call, counted across all repetitions. Values <= 0
                mean no bound.

        Returns:
            The modified DataPipeline object.
        """
        self.repetitions = nepochs if nepochs > 0 else sys.maxsize
        self.nsamples = nbatches
        return self