# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC
from dataclasses import dataclass
from queue import Empty as QueueEmpty
from queue import Queue
from typing import Dict, List, Optional

import torch

from fairscale.internal.object import pyobject_to_tensor, tensor_to_pyobject
from fairscale.nn.model_parallel import get_pipeline_parallel_group

from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors

# Fixed byte size of the tensor used to ship a serialized PipeMessage header.
MESSAGE_TENSOR_SIZE = 1024

# Queues used by RpcTransport: rpc_push_queue delivers each incoming message
# into the queue selected by its ``queue_name``.
MessageQueues: List[Queue] = [Queue() for _ in range(MESSAGE_GENERATION_START)]


def to_input_device(tensors: Tensors, input_device: InputDevice) -> Tensors:
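    """Move every tensor in ``tensors`` to ``input_device``; pass through unchanged if no device is set."""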
    if input_device is None:
        return tensors
    else:
        return tuple(t.to(input_device) for t in tensors)


def rpc_push_queue(message: PipeMessage) -> None:
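    """RPC target invoked on the receiving rank: enqueue ``message`` on the queue it names."""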
    globals()["MessageQueues"][message.queue_name].put(message)


@dataclass(frozen=True)
class Transport(ABC):
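    """Base class for point-to-point messaging between pipeline stages.

    ``recv_message`` is implemented here in terms of the transport-specific
    hooks below, which the ``RpcTransport`` and ``SendRecvTransport``
    subclasses provide.
    """
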
    worker_map: Optional[Dict[int, str]]
    input_device: InputDevice

    def recv_message(self, queue_name: int, *, nowait: bool = False) -> PipeMessage:
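        """Receive a full message: first the header, then its tensors."""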
        message = self.recv_message_header(queue_name, nowait)
        return self.recv_message_tensors(message)

    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
        ...

    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
        ...

    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
        ...

    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        ...


def MakeTransport(use_rpc: bool, worker_map: Optional[Dict[int, str]], input_device: InputDevice) -> Transport:
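    """Build the transport implementation selected by ``use_rpc``.

    Illustrative usage (the worker names and device below are hypothetical):

        transport = MakeTransport(
            use_rpc=True,
            worker_map={0: "worker0", 1: "worker1"},
            input_device=torch.device("cuda"),
        )
    """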
    if use_rpc:
        if worker_map is None:
            raise ValueError("'RpcTransport' requires 'worker_map' to be set")
        return RpcTransport(worker_map, input_device)
    else:
        return SendRecvTransport(worker_map, input_device)


class RpcTransport(Transport):
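    """Transport built on ``torch.distributed.rpc``: tensors travel inside the RPC payload."""
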
    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
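        """Push ``message`` onto the destination worker's queue via RPC; tensors are moved to CPU first."""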
        message.tensors = tuple(t.cpu() for t in message.tensors)
        assert self.worker_map
        name = self.worker_map[message.dest]
        if sync:
            torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,))
        else:
            torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,))

    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
        queue = MessageQueues[queue_name]
        if nowait:
            result = queue.get_nowait()
        else:
            result = queue.get()
        result.tensors = to_input_device(result.tensors, self.input_device)
        return result

    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
        # Tensors already contained within message
        message.tensors = to_input_device(message.tensors, self.input_device)
        return message

    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        """Receive a message with a known microbatch index, and handle out-of-order
        messages by placing them back on the queue"""

        queue = globals()["MessageQueues"][queue_name]
        out_of_order: List[PipeMessage] = []
        while True:
            message = self.recv_message(queue_name)
            got_index = message.args
            value = message.tensors
            if got_index == index:
                for b in out_of_order:
                    queue.put(b)
                return value
            else:
                out_of_order.append(message)


class SendRecvTransport(Transport):
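    """Transport built on ``torch.distributed.send``/``recv`` over the pipeline process group.

    A fixed-size header tensor describing shapes and dtypes is sent first,
    followed by one send per payload tensor.
    """
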
    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
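        """Send the message header (unless ``skip_header``) and then each tensor.

        The current CUDA stream is synchronized first so that any kernels still
        producing these tensors have finished before the blocking send.
        """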
        tensors = message.tensors
        message.tensors = tuple()
        torch.cuda.current_stream().synchronize()
        if not skip_header:
            message.tensor_shapes = [t.size() for t in tensors]
            message.tensor_dtypes = [t.dtype for t in tensors]
            torch.distributed.send(
                pyobject_to_tensor(message, MESSAGE_TENSOR_SIZE).cuda(),
                message.dest,
                tag=message.queue_name,
                group=get_pipeline_parallel_group(),
            )
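        # Sends over the pipeline group use CUDA tensors (as with the header
        # above), so copy any CPU tensor to the GPU first.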
        for index, t in enumerate(tensors):
            if t.device.type == "cpu":
                t = t.cuda()
            torch.distributed.send(
                t.contiguous(), message.dest, tag=message.tag + index, group=get_pipeline_parallel_group()
            )

    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
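        """Receive the fixed-size header tensor and deserialize it into a ``PipeMessage``."""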
        # FIXME(handle nowait)
        if nowait:
            raise QueueEmpty
        tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=self.input_device)
        torch.cuda.current_stream().synchronize()
        torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group())
        torch.cuda.current_stream().synchronize()
        return tensor_to_pyobject(tensor)

    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
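        """Receive each payload tensor described by the header into ``message.tensors``."""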
        torch.cuda.current_stream().synchronize()

        message_tensors = []
        for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)):
            t = torch.empty(*shape, dtype=dtype, device=self.input_device)
            torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group())
            message_tensors.append(t)

        message.tensors = tuple(message_tensors)

        torch.cuda.current_stream().synchronize()
        return message

    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        """Receive a message with a known microbatch index, and handle out-of-order
        messages by placing them back on the queue"""

        message = self.recv_message(queue_name)
        assert message.args == index
        return message.tensors