File size: 6,372 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import copy
import sys
import warnings
from itertools import islice

from .pytorch import DataLoader, IterableDataset
from .utils import PipelineStage


def add_length_method(obj):
    """Give ``obj`` a ``__len__`` that reports its ``size`` attribute.

    A new class is created on the fly that derives from the object's current
    class and from ``IterableDataset``, and the object is switched over to
    that class in place.

    Args:
        obj: The object to modify. Its ``size`` attribute is read each time
            ``len()`` is called on it.

    Returns:
        The same object, now an instance of the derived ``<Class>_Length``
        class.
    """
    base = obj.__class__

    def length(self):
        # Looked up on every call, so later changes to ``size`` are reflected.
        return self.size

    # The instance is re-classed rather than copied, so existing references
    # to ``obj`` see the new ``__len__`` as well.
    obj.__class__ = type(
        f"{base.__name__}_Length",
        (base, IterableDataset),
        {"__len__": length},
    )
    return obj


class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters.

    The first stage acts as the sample source; every later stage is invoked
    with the iterator produced by the stage before it (see ``iterator1``).

    Args:
        *args: Pipeline stages. ``None`` entries are skipped and ``list``
            entries are flattened into the pipeline.
        **kwargs: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pipeline = []
        self.length = -1
        # Defaults: a single pass over the source, with no sample limit.
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def close(self):
        """Close the pipeline and release resources.

        Calls ``close()`` on every stage that provides one and then drops the
        stage list. Calling this on an already closed pipeline does nothing.
        """
        stages = getattr(self, "pipeline", None)
        if stages is None:
            # Already closed: the stage list was removed by an earlier call.
            return
        for step in stages:
            if hasattr(step, "close"):
                step.close()
        del self.pipeline

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage.

        Args:
            f: The pipeline stage to invoke.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            The result of invoking the pipeline stage.

        Raises:
            ValueError: If the pipeline stage is not valid.
        """
        # The order of these checks matters: a dataset/loader given no input
        # is treated as a source and simply iterated, even if it is also a
        # PipelineStage.
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            return f(*args, **kwargs)
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline.

        Returns:
            An iterator for one epoch of the pipeline.

        Raises:
            IndexError: If the pipeline has no stages.
        """
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions.

        Yields:
            Samples from the dataset.
        """
        for _ in range(self.repetitions):
            count = 0
            for sample in self.iterator1():
                yield sample
                count += 1
            if count == 0:
                # if the dataset is empty, don't keep looping
                break

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested.

        The sample limit is applied whenever ``nsamples`` is positive,
        including when the pipeline is configured for a single repetition
        (e.g. ``repeat(nepochs=1, nbatches=n)``).

        Returns:
            An iterator through the pipeline.
        """
        samples = self.iterator()
        if self.nsamples > 0:
            return islice(samples, self.nsamples)
        return samples

    def stage(self, i):
        """Return pipeline stage i.

        Args:
            i: The index of the pipeline stage to return.

        Returns:
            The pipeline stage at index i.
        """
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object).

        Args:
            f: The pipeline stage to append.
        """
        self.pipeline.append(f)

    def compose(self, *args):
        """Append pipeline stages to a copy of the pipeline and return the copy.

        The copy is shallow: it gets its own stage list, but the stage objects
        themselves are shared with the original pipeline.

        Args:
            *args: Variable length argument list of pipeline stages to append.

        Returns:
            A new DataPipeline object with the appended stages.
        """
        result = copy.copy(self)
        # Give the copy its own list so appending does not affect ``self``.
        result.pipeline = copy.copy(result.pipeline)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n, silent=False):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.

        Args:
            n: The length value to set.
            silent: If True, suppress the warning message.

        Returns:
            The modified DataPipeline object with a __len__ method.
        """
        if not silent:
            warnings.warn(
                ".with_length() only sets the value of __len__ for compatibility "
                + "with some training environments. It does not change the number of "
                + "samples in an epoch.",
                # Attribute the warning to the caller of with_length().
                stacklevel=2,
            )
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The pipeline is set to repeat (up to ``sys.maxsize`` times) and the
        resulting stream is cut off after ``max(nsamples, nbatches)`` items.
        If neither value is positive, no cut-off is applied.

        Args:
            nsamples: The number of samples per epoch.
            nbatches: The number of batches per epoch.

        Returns:
            The modified DataPipeline object.
        """
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given number of epochs up to the given number of samples.

        Args:
            nepochs: The number of epochs to repeat. A non-positive value
                repeats up to ``sys.maxsize`` times.
            nbatches: The total number of items to yield across all
                repetitions. A non-positive value means no limit.

        Returns:
            The modified DataPipeline object.
        """
        self.repetitions = nepochs if nepochs > 0 else sys.maxsize
        self.nsamples = nbatches
        return self