File size: 10,307 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpointing with preceding recomputation.

PyTorch already provides the official checkpointing utilities in
:mod:`torch.utils.checkpoint`. The official checkpointing combines
recomputation and recursive backpropagation into one autograd function named
``CheckpointFunction``. Hence, the recomputation can start only when the
gradients arrive at the function. In Pipe, the recomputation needs to precede
the gradient arrival to minimize GPU idle time.

We solve this problem by introducing separate autograd functions named
:class:`Recompute` and :class:`Checkpoint`. Each function represents
recomputation and recursive backpropagation, respectively. With this pair of
functions, we can manipulate the control flow from the perspective of both the
autograd engine and CUDA.

Specifically, we place CUDA stream synchronization between :class:`Recompute`
and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is
copied entirely.

"""
from collections import deque
from contextlib import contextmanager
import threading
from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union

import torch
from torch import ByteTensor, Tensor
import torch.autograd

from .dependency import fork, join
from .microbatch import Batch
from .phony import get_phony

# Names exported by ``from ... import *``; only the two flag queries are public.
__all__ = ["is_checkpointing", "is_recomputing"]


# A tuple of tensors, and "either a single tensor or such a tuple" — the
# input/output convention of the checkpointed function.
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]

# Types for shared memory between Checkpoint and Recompute.
Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
# The GPU slot is None when the device is not CUDA (see save_rng_states).
RNGStates = Tuple[ByteTensor, Optional[ByteTensor]]  # (cpu_rng_state, gpu_rng_state)


# typing_extensions is imported only for static type checkers; at runtime the
# Protocol base is replaced by plain ``object`` so nothing extra is imported.
if TYPE_CHECKING:
    from typing_extensions import Protocol
else:
    Protocol = object


# Protocol with __call__ instead of Callable can be used as an attribute type.
# See: https://github.com/python/mypy/issues/708#issuecomment-561735949
class Function(Protocol):
    """Structural type of the checkpointed callable: a tensor (or tuple of
    tensors) goes in, and a tensor (or tuple of tensors) comes out.
    """

    def __call__(self, input: TensorOrTensors) -> TensorOrTensors:
        """Apply the wrapped computation to ``input``."""


class Checkpointing:
    """Builds a :class:`Checkpoint` / :class:`Recompute` pair for one batch.

    The two autograd functions communicate through the ``recomputed`` and
    ``rng_states`` containers owned by this object.
    """

    def __init__(self, function: Function, batch: Batch) -> None:
        self.function = function
        self.batch = batch

        # Shared between Checkpoint and Recompute. A deque with maxlen=1 is a
        # mutable single-slot box: appending discards any previous entry.
        self.recomputed: Deque[Recomputed] = deque(maxlen=1)
        self.rng_states: Deque[RNGStates] = deque(maxlen=1)

    def checkpoint(self) -> Batch:
        """Apply :class:`Checkpoint` to the stored batch.

        Returns:
            Batch: the checkpointed outputs, carrying the stored batch's index.
        """
        source = self.batch
        atomic = source.atomic
        args = tuple(source)

        # A phony that requires grad makes sure autograd records Checkpoint
        # even when none of the real inputs require grad.
        phony = get_phony(source[0].device, requires_grad=True)

        result = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, atomic, *args)

        # Only floating-point tensors can carry gradients; detach the others.
        if isinstance(result, tuple):
            result = tuple(t if t.is_floating_point() else t.detach() for t in result)

        return Batch(result, source.index)

    def recompute(self, batch: Batch) -> None:
        """Attach :class:`Recompute` to ``batch`` in place.

        The recomputation replays ``function`` on the inputs stored at
        construction time (``self.batch``). ``batch`` itself is only used to
        fork and join a phony around the :class:`Recompute` call.
        """
        atomic = self.batch.atomic
        args = tuple(self.batch)

        # batch[0] always requires grad at this point, because it came out of
        # checkpoint() together with a grad-requiring phony.
        batch[0], anchor = fork(batch[0])
        anchor = Recompute.apply(anchor, self.recomputed, self.rng_states, self.function, atomic, *args)
        batch[0] = join(batch[0], anchor)


class ThreadLocal(threading.local):
    """Per-thread flags backing :func:`is_checkpointing` and :func:`is_recomputing`."""

    def __init__(self) -> None:
        # threading.local runs __init__ again in every thread that touches the
        # instance, so each thread starts with both flags cleared.
        self.is_checkpointing = self.is_recomputing = False


# Module-wide instance; each thread sees its own copy of the flags.
thread_local = ThreadLocal()


@contextmanager
def enable_checkpointing() -> Generator[None, None, None]:
    """Context manager during which :func:`is_checkpointing` reports :data:`True`.

    The previous flag value is restored on exit, even if the body raises, so
    nested uses compose.
    """
    state = thread_local
    previous = state.is_checkpointing
    state.is_checkpointing = True
    try:
        yield
    finally:
        state.is_checkpointing = previous


@contextmanager
def enable_recomputing() -> Generator[None, None, None]:
    """Context manager during which :func:`is_recomputing` reports :data:`True`.

    The previous flag value is restored on exit, even if the body raises, so
    nested uses compose.
    """
    state = thread_local
    previous = state.is_recomputing
    state.is_recomputing = True
    try:
        yield
    finally:
        state.is_recomputing = previous


def is_checkpointing() -> bool:
    """Report whether the current thread is inside a checkpointed forward pass.

    Returns:
        bool: :data:`True` while running under checkpointing.

    """
    flag: bool = thread_local.is_checkpointing
    return flag


def is_recomputing() -> bool:
    """Report whether the current thread is replaying a forward pass for
    checkpoint recomputation. Modules can consult it to avoid repeating side
    effects during that replay::

        class Counter(nn.Module):
            def __init__(self):
                super().__init__()
                self.counter = 0

            def forward(self, input):
                if not is_recomputing():
                    self.counter += 1
                return input

    Returns:
        bool: :data:`True` while running under checkpoint recomputation.

    .. seealso:: :ref:`Detecting Recomputation`

    """
    flag: bool = thread_local.is_recomputing
    return flag


class Context:
    """Typing stand-in for the ``ctx`` object received by the ``forward`` and
    ``backward`` methods of :class:`Checkpoint` and :class:`Recompute`.
    """

    # Attributes that both forward methods attach to ctx.
    recomputed: Deque[Recomputed]
    rng_states: Deque[RNGStates]
    function: Function
    input_atomic: bool

    # Tensors previously handed to save_for_backward().
    saved_tensors: Tuple[Tensor, ...]

    def save_for_backward(self, *tensors: Tensor) -> None:  # pragma: no cover
        """Signature-only placeholder; it intentionally does nothing."""


def save_rng_states(
    device: torch.device,
    rng_states: Deque[RNGStates],
) -> None:
    """Snapshot PyTorch's CPU random number generator state, plus the GPU
    state when ``device`` is a CUDA device, into ``rng_states``.

    :meth:`Checkpoint.forward` calls this so that :meth:`Recompute.backward`
    can replay the forward pass with the same randomness. Any snapshot already
    stored in ``rng_states`` is discarded first.

    .. seealso:: :ref:`Referential Transparency`

    """
    snapshot: RNGStates = (
        torch.get_rng_state(),
        torch.cuda.get_rng_state(device) if device.type == "cuda" else None,
    )
    rng_states.clear()
    rng_states.append(snapshot)


@contextmanager
def restore_rng_states(
    device: torch.device,
    rng_states: Deque[RNGStates],
) -> Generator[None, None, None]:
    """Within the context, reinstate the RNG states captured by
    :func:`save_rng_states`. :meth:`Recompute.backward` uses this.

    The snapshot is read from ``rng_states[0]`` but is not removed.
    ``torch.random.fork_rng`` puts the surrounding RNG state back when the
    context exits.

    .. seealso:: :ref:`Referential Transparency`

    """
    saved_cpu_state, saved_gpu_state = rng_states[0]

    # Only the CUDA device that the snapshot belongs to needs forking.
    forked_devices: List[torch.device] = [device] if device.type == "cuda" else []

    with torch.random.fork_rng(forked_devices):
        torch.set_rng_state(saved_cpu_state)
        if saved_gpu_state is not None:
            torch.cuda.set_rng_state(saved_gpu_state, device)
        yield


class Checkpoint(torch.autograd.Function):
    """Runs ``function`` without recording an autograd graph in forward. In
    backward, it backpropagates through the graph that :class:`Recompute`
    rebuilt and stored in the shared ``recomputed`` deque.
    """

    @staticmethod
    # type: ignore
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> TensorOrTensors:
        """Evaluate ``function`` under ``torch.no_grad()`` and save what backward needs.

        Args:
            phony: Tensor passed in by the caller; it is not otherwise used here.
            recomputed: Shared deque from which backward pops the recomputed graph.
            rng_states: Shared deque that receives the current RNG snapshot.
            function: The checkpointed computation.
            input_atomic: If true, ``function`` receives ``input[0]`` instead
                of the whole ``input`` tuple.
            *input: Input tensors, saved for backward.

        Returns:
            Whatever ``function`` returns (a tensor or a tuple of tensors).
        """
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        # Snapshot RNG state so that the later recomputation sees identical randomness.
        save_rng_states(input[0].device, ctx.rng_states)

        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.save_for_backward(*input)

        with torch.no_grad(), enable_checkpointing():
            output = function(input[0] if input_atomic else input)

        return output

    @staticmethod
    def backward(
        ctx: Context,
        *grad_output: Tensor,
    ) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
        """Backpropagate through the recomputed graph.

        Returns:
            ``None`` for each of the five leading forward arguments (phony,
            recomputed, rng_states, function, input_atomic), followed by the
            ``.grad`` of each recomputed input leaf.
        """
        output, input_leaf = ctx.recomputed.pop()

        if isinstance(output, tuple):
            tensors = output
        else:
            tensors = (output,)

        # grad_output[i] is the gradient of tensors[i]. Filter both sequences
        # together: dropping an output that does not require grad must not
        # shift the remaining gradients onto the wrong tensors, because
        # torch.autograd.backward pairs tensors and grads positionally.
        pairs = [(y, g) for y, g in zip(tensors, grad_output) if y.requires_grad]
        if pairs:
            outputs_with_grad, grads = zip(*pairs)
            torch.autograd.backward(outputs_with_grad, grads)

        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
        grad_input.extend(x.grad for x in input_leaf)
        return tuple(grad_input)


class Recompute(torch.autograd.Function):
    """Passes ``phony`` through unchanged in forward. In backward, it re-runs
    ``function`` with grad enabled and publishes the rebuilt graph through the
    shared ``recomputed`` deque.
    """

    @staticmethod
    # type: ignore
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> Tensor:
        """Stash everything the recomputation needs on ``ctx`` and return ``phony``."""
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.save_for_backward(*input)

        return phony

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:  # pragma: no cover
        """Rebuild the forward graph on fresh leaf tensors and append
        ``(output, leaves)`` to ``ctx.recomputed``. No gradients are returned.
        """
        saved = ctx.saved_tensors
        # Detached copies become new graph leaves; each keeps its original requires_grad.
        leaves = tuple(t.detach().requires_grad_(t.requires_grad) for t in saved)

        with restore_rng_states(saved[0].device, ctx.rng_states):
            with torch.enable_grad(), enable_recomputing():
                fn_input = leaves[0] if ctx.input_atomic else leaves
                output = ctx.function(fn_input)

        ctx.recomputed.append((output, leaves))

        # One None for each of the five leading arguments, then one per saved tensor.
        return (None,) * (5 + len(ctx.saved_tensors))