from abc import ABC, abstractmethod
from dataclasses import dataclass
from queue import Empty as QueueEmpty
from queue import Queue
from typing import Dict, List, Optional

import torch

from fairscale.internal.object import pyobject_to_tensor, tensor_to_pyobject
from fairscale.nn.model_parallel import get_pipeline_parallel_group

from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors
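
# Fixed size (in bytes) of the buffer used to exchange a pickled PipeMessage header;
# a receiver can post a recv of this size before knowing anything about the payload.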
MESSAGE_TENSOR_SIZE = 1024
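
# One local queue per message queue name below MESSAGE_GENERATION_START; rpc_push_queue
# deposits messages arriving over torch RPC here and RpcTransport.recv_message_header
# drains them.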
MessageQueues: List[Queue] = [Queue() for _ in range(MESSAGE_GENERATION_START)]


def to_input_device(tensors: Tensors, input_device: InputDevice) -> Tensors:
    if input_device is None:
        return tensors
    else:
        return tuple(t.to(input_device) for t in tensors)
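

# Runs on the destination rank via torch.distributed.rpc: push the received
# message onto the local queue it is addressed to.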
def rpc_push_queue(message: PipeMessage) -> None:
    globals()["MessageQueues"][message.queue_name].put(message)


@dataclass(frozen=True)
class Transport(ABC):
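    """Base class for exchanging PipeMessages between pipeline-parallel ranks,
    either over torch.distributed.rpc (RpcTransport) or over point-to-point
    send/recv on the pipeline process group (SendRecvTransport)."""
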
    worker_map: Optional[Dict[int, str]]
    input_device: InputDevice

    def recv_message(self, queue_name: int, *, nowait: bool = False) -> PipeMessage:
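        # Receiving happens in two steps: fetch the message header for this queue,
        # then materialize the payload tensors it refers to on the input device.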
        message = self.recv_message_header(queue_name, nowait)
        return self.recv_message_tensors(message)

    @abstractmethod
    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
        ...

    @abstractmethod
    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
        ...

    @abstractmethod
    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
        ...

    @abstractmethod
    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        ...


def MakeTransport(use_rpc: bool, worker_map: Optional[Dict[int, str]], input_device: InputDevice) -> Transport:
    if use_rpc:
        if worker_map is None:
            raise ValueError("'RpcTransport' requires 'worker_map' to be set")
        return RpcTransport(worker_map, input_device)
    else:
        return SendRecvTransport(worker_map, input_device)


class RpcTransport(Transport):
    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
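        # Move all tensors to CPU before handing the message to torch.distributed.rpc.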
        message.tensors = tuple(t.cpu() for t in message.tensors)
        assert self.worker_map
        name = self.worker_map[message.dest]
        if sync:
            torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,))
        else:
            torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,))

    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
        queue = MessageQueues[queue_name]
        if nowait:
            result = queue.get_nowait()
        else:
            result = queue.get()
        result.tensors = to_input_device(result.tensors, self.input_device)
        return result

    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
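        # Unlike the send/recv transport, the tensors already arrived together with
        # the header; just move them onto the input device.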
        message.tensors = to_input_device(message.tensors, self.input_device)
        return message

    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        """Receive a message with a known microbatch index, and handle out-of-order
        messages by placing them back on the queue"""

        queue = globals()["MessageQueues"][queue_name]
        out_of_order: List[PipeMessage] = []
        while True:
            message = self.recv_message(queue_name)
            got_index = message.args
            value = message.tensors
            if got_index == index:
                for b in out_of_order:
                    queue.put(b)
                return value
            else:
                out_of_order.append(message)


class SendRecvTransport(Transport):
    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
        tensors = message.tensors
        message.tensors = tuple()
        torch.cuda.current_stream().synchronize()
        if not skip_header:
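            # Send the pickled header first, in a fixed-size buffer tagged with the
            # queue name, so the receiver learns the shapes and dtypes of the payload
            # tensors; each payload tensor is then sent tagged with message.tag plus
            # its position in the tuple.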
            message.tensor_shapes = [t.size() for t in tensors]
            message.tensor_dtypes = [t.dtype for t in tensors]
            torch.distributed.send(
                pyobject_to_tensor(message, MESSAGE_TENSOR_SIZE).cuda(),
                message.dest,
                tag=message.queue_name,
                group=get_pipeline_parallel_group(),
            )
        for index, t in enumerate(tensors):
            if t.device.type == "cpu":
                t = t.cuda()
            torch.distributed.send(
                t.contiguous(), message.dest, tag=message.tag + index, group=get_pipeline_parallel_group()
            )

    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
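        # Non-blocking receive is not implemented for this transport; report the
        # queue as empty instead.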
        if nowait:
            raise QueueEmpty

        tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=self.input_device)
        torch.cuda.current_stream().synchronize()
        torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group())
        torch.cuda.current_stream().synchronize()
        return tensor_to_pyobject(tensor)

    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
        torch.cuda.current_stream().synchronize()
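
        # Receive each payload tensor into a freshly allocated buffer, matched to the
        # sender by tag (message.tag + position in the tuple).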
        message_tensors = []
        for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)):
            t = torch.empty(*shape, dtype=dtype, device=self.input_device)
            torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group())
            message_tensors.append(t)

        message.tensors = tuple(message_tensors)

        torch.cuda.current_stream().synchronize()
        return message

    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        """Receive a message with a known microbatch index, and handle out-of-order
        messages by placing them back on the queue"""
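
        # Unlike the RPC transport, point-to-point send/recv delivers messages in
        # order, so the next message on this queue is expected to carry the requested
        # index.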
        message = self.recv_message(queue_name)
        assert message.args == index
        return message.tensors
|