File size: 19,047 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, auto
from threading import Event
from typing import Dict, Iterable, List, Optional, Tuple

import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from torch.distributed import ProcessGroup

from fairscale.nn.model_parallel import get_pipeline_parallel_ranks

from .checkpoint import Checkpointing
from .messages import Transport
from .microbatch import Batch
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors, Tensors
from .worker import Task


def create_task(
    checkpoint_stop: int,
    chunk_id: int,
    part_id: int,
    batch: Batch,
    partition: nn.Sequential,
    skip_trackers: List[SkipTrackerThroughPotals],
) -> Task:
    # Determine whether checkpointing or not.
    if chunk_id < checkpoint_stop:

        def function(
            input: TensorOrTensors,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[chunk_id],
            chunk_id: int = chunk_id,
            part_id: int = part_id,
        ) -> TensorOrTensors:
            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                ret = partition(input)
                # We do a check here because the backtrace from the checkpoint backward code path
                # is very hard to make sense. It would be much easier to check earlier at this point.
                assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
                return ret

        chk = Checkpointing(function, batch)
        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
        del function, chk  # TODO(tom) maybe remove

    else:

        def compute(
            batch: Batch = batch,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[chunk_id],
            chunk_id: int = chunk_id,
            part_id: int = part_id,
        ) -> Batch:
            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                return batch.call(partition)

        task = Task(None, compute=compute, finalize=None)
        del compute  # TODO(tom) maybe remove

    return task


@dataclass(frozen=True)
class Location:
    stage: int
    index: int

    def __repr__(self) -> str:
        return f"{self.stage}@{self.index}"


@dataclass(frozen=True)
class Invocation:
    order: int
    this: Location
    source: Optional[Location]
    dest: Optional[Location]


Activations = Dict[int, Dict[int, Dict[int, Batch]]]
Invocations = Dict[int, Invocation]


@dataclass(frozen=True)
class TailBackwardContext:
    activations: Activations
    invocations: Invocations
    count_per_order: Dict[int, int]
    expected_gradients: int


class ModuleWrapper:
    def __init__(self, module: nn.Sequential, location: Location, invocations: Optional[List[Invocation]] = None):
        self.module: nn.Sequential = module
        self.location: Location = location
        self.invocations: List[Invocation] = invocations or []

    def __repr__(self) -> str:
        return f"{self.location}:\n" + "\n".join(map(str, self.invocations)) + "\n\t" + str(self.module)

    def __len__(self) -> int:
        return len(self.module)

    def __iter__(self) -> Iterable:
        yield from self.module


class AsyncMessageType(Enum):
    Activations = auto()
    Gradients = auto()


@dataclass(frozen=True)
class AsyncMessageBody:
    message_type: AsyncMessageType
    microbatch_index: int
    source: Location
    dest: Location
    order: int


class AutogradWithoutActivations(torch.autograd.Function):
    """A helper class to add another edge in the autograd graph which allows us
    to delete the potentially large activations and still perform a backward
    pass. Returns return a phony tensor which is connected to the graph."""

    @staticmethod
    # type: ignore
    def forward(ctx, *x):
        return torch.tensor(1.0)

    @staticmethod
    # type: ignore
    def backward(ctx, grad):
        assert ctx.grad_from_pipeline is not None
        return ctx.grad_from_pipeline


class AsyncRecvOperator(torch.autograd.Function):
    """Receive activations to the previous pipeline stage"""

    @staticmethod
    # type: ignore
    def forward(ctx, phony: Tensor, transport: Transport, message: PipeMessage, queue_name: int) -> Tensors:
        ctx.transport = transport
        ctx.index = message.args.microbatch_index
        ctx.queue_name = queue_name
        result = transport.recv_message_tensors(message)

        ctx.args = result.args

        def maybe_requires_grad(t: Tensor) -> Tensor:
            if t.dtype.is_floating_point:
                return t.requires_grad_()
            return t

        return tuple(maybe_requires_grad(r) for r in result.tensors)

    @staticmethod
    # type: ignore
    def backward(
        ctx,
        *grad: Tensor,
    ) -> Tuple[Optional[Tensor], ...]:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()
        body = AsyncMessageBody(
            AsyncMessageType.Gradients, ctx.index, source=ctx.args.dest, dest=ctx.args.source, order=ctx.args.order - 1
        )
        ctx.transport.send_message(
            PipeMessage(
                this_rank,
                ranks[ctx.args.source.stage],
                queue_name=ctx.queue_name,
                args=body,
                tensors=tuple(grad),
            ),
            sync=True,
        )

        tail_ctx = getattr(ctx, "tail_ctx", None)
        if tail_ctx:
            expected_gradients = tail_ctx.expected_gradients
            while expected_gradients > 0:
                message = ctx.transport.recv_message_header(ctx.queue_name)

                args: AsyncMessageBody = message.args
                assert args.message_type is AsyncMessageType.Gradients

                invocation = tail_ctx.invocations[args.order]
                expected_gradients -= tail_ctx.count_per_order[invocation.order]
                AsyncEventLoop.perform_backward_for_invocation(ctx.transport, message, tail_ctx.activations, invocation)

        return (None, None, None, None, None)


class AsyncEventLoop:
    def __init__(
        self,
        partitions: List[ModuleWrapper],
        group: ProcessGroup,
        transport: Transport,
        training: bool,
        checkpoint_stop: int,
    ):
        self.training = training
        self.checkpoint_stop = checkpoint_stop
        self.transport = transport
        self.group = group
        self.partitions: List[ModuleWrapper] = partitions

    def send_async_message(self, dst_rank: int, result: Batch, invocation: Invocation) -> Batch:
        """Send batch to dst_rank, and use AutogradWithoutActivations to delete
        the activations since we no longer need them"""

        assert invocation.dest
        src_rank = torch.distributed.get_rank()

        body = AsyncMessageBody(
            AsyncMessageType.Activations, result.index, invocation.this, invocation.dest, invocation.order + 1
        )
        self.transport.send_message(
            PipeMessage(src_rank, dst_rank, queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple([*result])),
            sync=True,
        )

        phony = AutogradWithoutActivations.apply(*result)
        return Batch(phony, result.index)

    def run_invocation(
        self,
        batch: Batch,
        partition: ModuleWrapper,
        skip_trackers: List[SkipTrackerThroughPotals],
        invocation: Invocation,
    ) -> Batch:
        """Actually run the forward pass for a given module, and send the result
        to the next stage in the pipeline if needed."""

        task = create_task(
            self.checkpoint_stop,
            batch.index,
            self.group.rank(),
            batch,
            partition.module,
            skip_trackers,
        )
        result = task.compute()
        task.finalize(result)

        if invocation.dest and invocation.dest.stage != invocation.this.stage:
            ranks = get_pipeline_parallel_ranks()
            dst_rank = ranks[invocation.dest.stage]
            result = self.send_async_message(dst_rank, result, invocation)
        return result

    @staticmethod
    def perform_backward_for_invocation(
        transport: Transport, message: PipeMessage, activations: Activations, invocation: Invocation
    ) -> None:
        """Perform the backward pass by looking up the appropriate `Batch` and
        then calling `backward` on the tensor"""

        recvd_grads = transport.recv_message_tensors(message)

        batch: Batch = activations[invocation.this.index][invocation.order][message.args.microbatch_index]

        # All batches saved in `activations` are generated by AutogradWithoutActivations,
        # so we store the gradients in `grad_from_pipeline` so it will be used
        # during the backward pass
        batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors)
        batch.tensor.backward(retain_graph=True)

    def run_invocations_on_batch(
        self,
        batch: Batch,
        invocations: Invocations,
        order: int,
        skip_trackers: List[SkipTrackerThroughPotals],
        activations: Activations,
    ) -> Tuple[int, int]:
        """Run invocations on the batch until we hit one that receives its input
        from a different stage (i.e. another process)"""

        invocations_handled = 0
        last_order = 0
        for invocation in invocations.values():
            if invocation.order < order:
                continue
            pi = invocation.this.index
            partition = self.partitions[pi]

            if invocation.order == order:
                invocations_handled += 1
                last_order = invocation.order
                activations[pi][invocation.order][batch.index] = self.run_invocation(
                    batch, partition, skip_trackers, invocation
                )
            elif invocation.source and invocation.source.stage == self.group.rank():
                invocations_handled += 1
                last_order = invocation.order
                batch = activations[invocation.source.index][invocation.order - 1][batch.index]
                activations[pi][invocation.order][batch.index] = self.run_invocation(
                    batch, partition, skip_trackers, invocation
                )
                del activations[invocation.source.index][invocation.order - 1][batch.index]

            elif invocation.source and invocation.source.stage != self.group.rank():
                break

        return (invocations_handled, last_order)

    def event_loop_head(
        self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals], event: Optional[Event]
    ) -> None:
        """The event loop for the "head", which first performs the forward pass
        on any applicable layers for this stage, and then enters the common
        `event_loop_inner`"""

        invocations, activations = self.get_invocations_and_activations()

        expected_invocations = len(invocations) * len(batches)
        actual_invocations = 0

        count_per_order = dict()

        for batch in batches:
            inv_count, last_order = self.run_invocations_on_batch(batch, invocations, 0, skip_trackers, activations)
            actual_invocations += inv_count
            count_per_order[last_order] = inv_count

        if actual_invocations < expected_invocations or self.training:
            self.event_loop_inner(
                expected_invocations,
                skip_trackers,
                activations,
                invocations,
                count_per_order,
                already_received=actual_invocations,
                event=event,
            )

    def get_batch_from_message(self, message: PipeMessage) -> Batch:
        """Get the tensor(s) wrapped in a `Batch` from a `PipeMessage`, applying
        AsyncRecvOperator so we can intercept the backward pass"""

        microbatch_index = message.args.microbatch_index
        phony = torch.empty(0, device=self.transport.input_device, requires_grad=True)
        result = AsyncRecvOperator.apply(phony, self.transport, message, EVENT_LOOP_QUEUE)
        if len(result) == 1:
            batch = Batch(result[0], microbatch_index)
        else:
            batch = Batch(result, microbatch_index)
        return batch

    def event_loop_tail(self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals]) -> None:
        """The event loop for the "tail", or final stage which only processes
        activations and then returns to the caller so that the loss can be
        calculated. This also handles the first/only stage for the special
        case of a 1-stage pipeline."""

        invocations, activations = self.get_invocations_and_activations()
        expected_invocations = len(invocations) * len(batches)
        actual_invocations = 0

        rank = self.group.rank()
        count_per_order = dict()

        for batch in batches:
            if rank == 0:
                order = 0
            else:
                message = self.transport.recv_message_header(EVENT_LOOP_QUEUE)
                args: AsyncMessageBody = message.args

                batch = self.get_batch_from_message(message)
                order = args.order

            inv_count, last_order = self.run_invocations_on_batch(batch, invocations, order, skip_trackers, activations)
            actual_invocations += inv_count
            count_per_order[last_order] = inv_count

            if invocations[last_order].dest is None:
                self.prepare_tail_backward(
                    batch, activations, invocations, count_per_order, len(invocations) - inv_count
                )

        if actual_invocations < expected_invocations:
            expected_gradients = 0  # (len(invocations) - 1) * len(batches)

            self.event_loop_inner(
                expected_invocations,
                skip_trackers,
                activations,
                invocations,
                count_per_order,
                already_received=actual_invocations,
                ignore_gradients=True,
                tail=True,
            )

        _, last_invocation = invocations.popitem()

        for index, batch in activations[len(self.partitions) - 1][last_invocation.order].items():
            batches[index] = batch

    def get_invocations_and_activations(self) -> Tuple[Invocations, Activations]:
        activations: Activations = dict()
        invocations: Invocations = OrderedDict()

        for pi, partition in enumerate(self.partitions):
            activations[pi] = dict()
            for invocation in partition.invocations:
                activations[pi][invocation.order] = dict()
                invocations[invocation.order] = invocation

        invocations = OrderedDict(sorted(invocations.items(), key=lambda entry: entry[0]))

        return (invocations, activations)

    def event_loop(self, num_microbatch: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
        """The event loop for the "middle", i.e. neither the head nor the tail"""

        invocations, activations = self.get_invocations_and_activations()

        expected_invocations = len(invocations) * num_microbatch

        self.event_loop_inner(expected_invocations, skip_trackers, activations, invocations, dict())

    def event_loop_inner(
        self,
        expected_invocations: int,
        skip_trackers: List[SkipTrackerThroughPotals],
        activations: Activations,
        invocations: Invocations,
        count_per_order: Dict[int, int],
        *,
        already_received: int = 0,
        ignore_gradients: bool = False,
        event: Optional[Event] = None,
        tail: bool = False,
    ) -> None:
        """The common event loop shared by all stages. This processses
        activations for the forward pass, and if `self.training` is true,
        processes gradients for the backward pass."""

        num_activations = already_received
        if self.training and not ignore_gradients:
            num_gradients = 0
        else:
            num_gradients = expected_invocations

        while num_activations < expected_invocations or num_gradients < expected_invocations:
            if num_activations == expected_invocations and num_gradients == 0 and event is not None:
                # We are ready to do the backward pass, but must wait for
                # PipeRPCWrapper to signal that it is safe to proceed, otherwise
                # deadlock
                event.wait()

            message = self.transport.recv_message_header(EVENT_LOOP_QUEUE)
            args: AsyncMessageBody = message.args

            invocation = invocations[args.order]

            # FIXME(tom) for combining pipeline with megatron, I currently don't
            # control the order of received activations or gradients, so it is
            # possible for a reused ColumnParallelLinear for example to receive
            # a different order of activations w.r.t. the sending stage, which
            # would result in incorrect values being used for the all_gather
            if args.message_type is AsyncMessageType.Activations:
                batch = self.get_batch_from_message(message)

                inv_count, last_order = self.run_invocations_on_batch(
                    batch, invocations, args.order, skip_trackers, activations
                )
                count_per_order[last_order] = inv_count
                num_activations += inv_count
                if tail and invocations[last_order].dest is None:
                    self.prepare_tail_backward(
                        batch, activations, invocations, count_per_order, len(invocations) - inv_count
                    )

                assert num_activations <= expected_invocations

            elif args.message_type is AsyncMessageType.Gradients:
                num_gradients += count_per_order[invocation.order]
                self.perform_backward_for_invocation(self.transport, message, activations, invocation)

    @staticmethod
    def prepare_tail_backward(
        batch: Batch,
        activations: Activations,
        invocations: Invocations,
        count_per_order: Dict[int, int],
        expected_gradients: int,
    ) -> None:
        if expected_gradients > 0:
            grad_fn = next(b.grad_fn for b in batch if b.requires_grad)
            assert grad_fn
            grad_fn.tail_ctx = TailBackwardContext(activations, invocations, count_per_order, expected_gradients)