File size: 5,949 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Manipulation of micro-batches."""
import typing
from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast

import torch
from torch import Tensor
import torch.cuda.comm

# Nothing is exported through ``from <module> import *``.
__all__: List[str] = []


# A tuple holding any number of tensors.
Tensors = Tuple[Tensor, ...]
# Either a single tensor or a tuple of tensors.
TensorOrTensors = Union[Tensor, Tensors]
# A callable that maps a tensor-or-tensors value to a tensor-or-tensors value.
Function = Callable[[TensorOrTensors], TensorOrTensors]


class Batch:
    """An abstraction of an atomic tensor or a tuple of tensors. This
    eliminates every boilerplate code to classify an atomic tensor or a tuple
    of tensors.
    ::

        x = generate_tensor_or_tensors()
        x = Batch(x, 0)

        # in-place update
        x[0] = F.apply(x[0])
        x[:] = F.apply(*x)

        # f receives the underlying value as a single argument: the tensor
        # itself, or the whole tuple of tensors.
        # y is also a batch and keeps the index of x.
        y = x.call(f)

    Args:
        value: a single tensor (an "atomic" batch) or a tuple of tensors.
        index: position of this batch; exposed read-only through
            :attr:`index` and propagated by :meth:`call`.

    """

    def __init__(self, value: TensorOrTensors, index: int) -> None:
        self.value = value
        # Atomicity is decided once, at construction time.
        self.atomic = torch.is_tensor(value)
        self.__index = index

    @property
    def index(self) -> int:
        """The index given at construction time (read-only)."""
        return self.__index

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor.

        Raises:
            AttributeError: the batch is not atomic.

        """
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self.value)

    @property
    def tensors(self) -> Tensors:
        """Retrieves the underlying tensors.

        Raises:
            AttributeError: the batch is atomic.

        """
        if self.atomic:
            raise AttributeError("batch is atomic")
        return cast(Tensors, self.value)

    @property
    def tensor_or_tensors(self) -> TensorOrTensors:
        """Retrieves the underlying tensor or tensors regardless of type."""
        return self.value

    def call(self, function: Function) -> "Batch":
        """Calls a function by the underlying tensor or tensors. It also wraps
        the output with :class:`Batch`.

        The underlying value is passed as one positional argument (a tuple is
        not unpacked). The returned batch carries the same index as this one.
        """
        return Batch(function(self.value), self.index)

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self.value!r})"

    def __iter__(self) -> Iterator[Tensor]:
        """Yields the single tensor of an atomic batch, or each tensor of a
        non-atomic one.
        """
        if self.atomic:
            yield self.tensor
        else:
            yield from self.tensors

    def __len__(self) -> int:
        """Returns 1 for an atomic batch, otherwise the number of tensors."""
        return 1 if self.atomic else len(self.tensors)

    def __getitem__(self, index: int) -> Tensor:
        """Returns the tensor at ``index``.

        Raises:
            IndexError: the batch is atomic and ``index`` is not 0 (a
                non-atomic batch follows regular tuple indexing rules).

        """
        if not self.atomic:
            return self.tensors[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self.tensor

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None:
        """Replaces one tensor (``batch[i] = t``) or the whole content
        (``batch[:] = ts``) in place.
        """
        if isinstance(index, int):
            value = cast(Tensor, value)
            self._setitem_by_index(index, value)
        else:
            value = cast(Tensors, value)
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value: Tensor) -> None:
        """Replaces the tensor at ``index`` with ``value``.

        Raises:
            IndexError: ``index`` is out of range for a non-atomic batch, or
                is not 0 for an atomic batch.

        """
        if not self.atomic:
            tensors = self.tensors
            size = len(tensors)

            # Reject out-of-range indices instead of silently growing the
            # tuple, mirroring what __getitem__ does for the same index.
            if not -size <= index < size:
                raise IndexError("batch index out of range")

            # Normalize negative indices first: with a raw negative index,
            # "i + 1" can reach 0 and the tail slice would then duplicate
            # the whole tuple.
            i = index % size
            self.value = tensors[:i] + (value,) + tensors[i + 1 :]
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self.value = value

    def _setitem_by_slice(self, index: slice, value: Tensors) -> None:
        """Replaces the whole content of the batch; only ``[:]`` is accepted.

        Raises:
            NotImplementedError: the slice is anything other than ``[:]``.
            IndexError: the batch is atomic and ``value`` does not hold
                exactly one tensor.

        """
        if not (index.start is index.stop is index.step is None):
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self.value = value
            return

        # An atomic batch stays atomic: unwrap the single replacement tensor.
        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self.value = value[0]


def check(input: TensorOrTensors) -> None:
    """Validates that the input consists of tensors only.

    A tuple is validated element by element (recursively, so nested tuples
    are walked as well); anything else must itself be a tensor.

    Raises:
        TypeError: input is not a tensor or tensors.

    """
    # Fast path: a plain tensor is always acceptable.
    if isinstance(input, Tensor):
        return

    # Neither a tensor nor a container we know how to walk.
    if not isinstance(input, tuple):
        raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")

    for element in input:
        check(element)


def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches.

    A single tensor is split with ``Tensor.chunk(chunks)``. For a tuple of
    tensors, every tensor is chunked the same way and the i-th pieces are
    grouped together into the i-th micro-batch. Each resulting
    :class:`Batch` is numbered by its position in the returned list.
    """
    pieces: Iterable[TensorOrTensors]

    if not isinstance(input, Tensor):
        # One row of chunks per input tensor ...
        per_tensor: List[Tensors] = [cast(Tensors, t.chunk(chunks)) for t in input]
        # ... transposed so that each item holds one chunk of every tensor.
        pieces = zip(*per_tensor)
    else:
        pieces = input.chunk(chunks)

    return [Batch(piece, position) for position, piece in enumerate(pieces)]


def gather(outputs: List[Batch]) -> TensorOrTensors:
    """Concatenates output micro-batches into a mini-batch.

    Whether the result is a single tensor or a tuple of tensors is decided by
    the first micro-batch: atomic batches are concatenated with ``torch.cat``
    into one tensor, otherwise the tensors found at the same position in
    every micro-batch are concatenated and returned as a tuple.
    """
    if outputs[0].atomic:
        return torch.cat(tuple(batch.tensor for batch in outputs))

    # Collect every micro-batch's tuple first, then regroup by position.
    per_batch = [batch.tensors for batch in outputs]
    return tuple(torch.cat(column) for column in zip(*per_batch))