#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#


"""Train PyTorch models directly from POSIX tar archive.

Code works locally or over HTTP connections.
"""


from . import utils
from .pytorch import IterableDataset
from .utils import PipelineStage


class MockDataset(IterableDataset):
    """Create a mock dataset for performance testing and unit testing.

    Args:
        sample: The sample to be returned repeatedly.
        length (int): The length of the mock dataset.
    """

    def __init__(self, sample, length):
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Yield samples from the mock dataset.

        Returns:
            Iterator: An iterator that yields the same sample repeatedly.
        """
        for _ in range(self.length):
            yield self.sample
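

# Usage sketch (illustrative only; the sample dict and length below are
# arbitrary example values, e.g. for benchmarking a DataLoader without any I/O):
#
#     ds = MockDataset({"__key__": "000", "cls": 0}, length=1000)
#     samples = list(ds)   # 1000 references to the same sample dict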


class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset.

    Args:
        source: The source dataset to repeat.
        nepochs (int, optional): Maximum number of epochs to repeat.
        nbatches (int, optional): Maximum number of batches to repeat.
        length (int, optional): Length of the repeated dataset.
    """

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        self.source = source
        self.length = length
        # Store all repetition limits; invoke() reads both nepochs and nbatches.
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source.

        Args:
            source: The source dataset to repeat.

        Returns:
            Iterator: An iterator that repeatedly yields samples from the source.
        """
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )
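

# Usage sketch (illustrative only; `my_dataset` stands in for any iterable of
# samples, and the epoch limit is an arbitrary example value):
#
#     stage = repeatedly(my_dataset, nepochs=2)
#     for sample in stage.invoke(my_dataset):
#         ...   # samples from my_dataset, repeated for at most 2 epochs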


class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    This will continuously iterate through the original dataset, but
    impose new epoch boundaries after the given number of samples.
    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.

    Args:
        dataset: The source IterableDataset.
        length (int): Declared length of the dataset.
    """

    def __init__(self, dataset, length):
        super().__init__()
        self.length = length
        # The wrapped dataset is not stored; it is passed to invoke() at
        # iteration time, and only the iterator state is kept here.
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.

        This resets the dataset iterator, since that can't be pickled.

        Returns:
            dict: A dictionary representing the pickled state of the dataset.
        """
        result = dict(self.__dict__)
        result["source"] = None
        return result

    def invoke(self, dataset):
        """Return an iterator over the dataset.

        The iterator yields at most `length` samples, restarting the underlying
        iterator whenever it is exhausted.

        Args:
            dataset: The source dataset to iterate over.

        Yields:
            Sample: The next sample from the dataset.
        """
        if self.source is None:
            self.source = iter(dataset)
        for _ in range(self.length):
            try:
                sample = next(self.source)
            except StopIteration:
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    return
            yield sample
        self.source = None
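

# Usage sketch (illustrative only; `my_dataset` is any iterable of samples and
# 10000 is an arbitrary epoch length):
#
#     epoched = with_epoch(my_dataset, 10000)
#     for sample in epoched.invoke(my_dataset):
#         ...   # yields at most 10000 samples per pass over the iterator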


class with_length(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset with a specified length.

    Args:
        dataset: The source dataset.
        length (int): The stated length of the dataset.
    """

    def __init__(self, dataset, length):
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Return an iterator that iterates over the source dataset.

        Args:
            dataset: The source dataset to iterate over.

        Returns:
            Iterator: An iterator over the source dataset.
        """
        return iter(dataset)

    def __len__(self):
        """Return the user specified length.

        Returns:
            int: The specified length of the dataset.
        """
        return self.length
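

# Usage sketch (illustrative only; `my_dataset` and the stated length are
# arbitrary example values):
#
#     sized = with_length(my_dataset, 50000)
#     assert len(sized) == 50000              # the declared, not actual, length
#     it = sized.invoke(my_dataset)           # plain iterator over my_dataset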