from threading import Event, Lock, Thread
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import torch
from torch import nn
from torch.distributed import ProcessGroup, rpc
from torch.distributed.distributed_c10d import _get_global_rank

from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group

from .async_pipe import AsyncPipe
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
|
|
|
DEFAULT_MAX_SOURCE_POSITIONS = 1024 |
|
DEFAULT_MAX_TARGET_POSITIONS = 1024 |
|
|
|
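# Module-level globals populated on each worker over RPC: `PipeModel` holds that
# worker's local AsyncPipe instance, and `PipeResult` holds the output of its
# most recent forward pass (set on the final stage only).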
PipeModel: AsyncPipe |
|
PipeResult: TensorOrTensors |
|
|
|
|
|
SizeOrSizes = Union[torch.Size, List[torch.Size]] |
|
DtypeOrDtypes = Union[torch.dtype, List[torch.dtype]] |
|
|
|
|
|
def set_device_based_on_group(group: ProcessGroup) -> None: |
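    # Bind this process to a GPU round-robin by global rank; `group` is accepted
    # for symmetry with callers, but the device choice depends only on the global rank.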
|
|
|
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count()) |
|
|
|
|
|
def get_shapes(tensor: TensorOrTensors) -> SizeOrSizes: |
|
if isinstance(tensor, torch.Tensor): |
|
return tensor.shape |
|
else: |
|
return [t.shape for t in tensor] |
|
|
|
|
|
def get_dtype(tensor: TensorOrTensors) -> DtypeOrDtypes: |
|
if isinstance(tensor, torch.Tensor): |
|
return tensor.dtype |
|
else: |
|
return [t.dtype for t in tensor] |
|
|
|
|
|
def get_global_ranks_from_group(group: ProcessGroup) -> List[int]: |
|
return [_get_global_rank(group, r) for r in range(group.size())] |
|
|
|
|
|
class PipeBackRedirect(torch.autograd.Function): |
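    """Autograd identity used on rank 0 when it does not hold the final stage.

    `forward` returns `inputs` unchanged. `backward` does not hand gradients back
    to autograd; instead it ships them to the final pipeline stage over the pipe
    transport and sets `event` to notify the background first-stage forward.
    """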
|
@staticmethod |
|
|
|
def forward(ctx, inputs, dest, event, message, transport, futures): |
|
ctx.dest = dest |
|
ctx.event = event |
|
ctx.message = message |
|
ctx.transport = transport |
|
ctx.futures = futures |
|
return inputs |
|
|
|
@staticmethod |
|
|
|
def backward(ctx, *grad): |
|
ctx.message.tensors = tuple(grad) |
|
ctx.transport.send_message(ctx.message, sync=False, skip_header=True) |
|
ctx.event.set() |
|
|
|
return (None, None, None, None, None, None) |
|
|
|
|
|
def callback_with_model(callback: Callable[[Any, AsyncPipe], None], ctx: Any) -> None: |
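    # Executed on a remote worker via `PipeRPCWrapper.foreach_worker`: bind the
    # correct CUDA device for this rank, then run `callback` against the worker's
    # local AsyncPipe while holding its lock.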
|
try: |
|
group = get_pipeline_parallel_group() |
|
set_device_based_on_group(group) |
|
|
|
with PipeModel.lock: |
|
callback(ctx, PipeModel) |
|
except Exception as e: |
|
print(f"callback_with_model got {e}") |
|
|
|
|
|
class PipeRPCWrapper(nn.Module): |
|
"""A wrapper for Pipe to control the entire pipeline from a single process. |
|
    In the typical use case, rank 0 constructs a `PipeRPCWrapper` and runs the
    training loop as normal, while every other rank simply calls
    `torch.distributed.rpc.shutdown()`.

    To run code on each worker, e.g. to run the optimizer, use `foreach_worker`.
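
    A minimal usage sketch on rank 0 (illustrative only: `make_stages()`, the
    `balance`/`chunks` values and the `worker_map` are placeholders, not part of
    this module)::

        pipe = PipeRPCWrapper(make_stages(), balance=[2, 2], worker_map=worker_map, chunks=4)
        output = pipe(inputs)
        output.sum().backward()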
|
""" |
|
|
|
def __init__(self, *args: Any, **kwargs: Any): |
|
super().__init__() |
|
self.group = cast(ProcessGroup, kwargs.get("group")) or get_pipeline_parallel_group() |
|
assert self.group.rank() == 0 |
|
self.lock = Lock() |
|
|
|
        # Process groups can't be pickled over RPC, so remote workers always look up
        # the pipeline group themselves; only the default pipeline-parallel group is
        # supported here.
        assert (
            self.group == get_pipeline_parallel_group()
        ), "Can't pickle groups, so group must be `get_pipeline_parallel_group()`"
        kwargs["group"] = None
|
|
|
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) |
|
|
|
self.model = AsyncPipe(*args, **kwargs) |
|
self.worker_map = kwargs["worker_map"] |
|
self._foreach_worker(self._register_remote_model, args=(args, kwargs)) |
|
self.model.cuda() |
|
|
|
def _get_rpc_name(self, rank: int) -> str: |
|
return self.worker_map[_get_global_rank(self.group, rank)] |
|
|
|
    def _foreach_worker(self, callback: Callable, args: Any = None) -> None:
        futures = [rpc.rpc_async(self._get_rpc_name(rank), callback, args=args) for rank in range(1, self.group.size())]
        # Block until every worker has finished running the callback.
        for fut in futures:
            fut.wait()
|
|
|
def foreach_worker( |
|
self, callback: Callable[[Any, AsyncPipe], None], ctx: Any = None, *, include_self: bool = False |
|
    ) -> None:
        """Call `callback` on each worker, passing `ctx` and the model local to
        that worker. If `include_self` is True, the callback also runs on this
        (rank 0) process. e.g.

            def register_optimizer(ctx, model):
                args, kwargs = ctx
                model.optimizer = torch.optim.SGD(model.parameters(), *args, **kwargs)

            pipe_model = PipeRPCWrapper( ... )

            pipe_model.foreach_worker(
                register_optimizer,
                ([], {"lr" : 0.01, "momentum" : 0.9})
            )
        """
|
|
|
self._foreach_worker(callback_with_model, args=(callback, ctx)) |
|
|
|
if include_self: |
|
with self.model.lock: |
|
callback(ctx, self.model) |
|
|
|
def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: |
|
shape = get_shapes(tensor) |
|
dtype = get_dtype(tensor) |
|
|
|
if isinstance(tensor, torch.Tensor): |
|
num_tensors = 1 |
|
else: |
|
num_tensors = len(tensor) |
|
|
|
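        # Kick off the forward pass on every other pipeline rank over RPC; each
        # worker builds a placeholder input from the shapes/dtypes passed here.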
futures = [ |
|
rpc.rpc_async(self._get_rpc_name(rank), self._model_forward, args=(self.model.training, shape, dtype)) |
|
for rank in range(1, self.group.size()) |
|
] |
|
|
|
if self.model.final_stage: |
|
return self.model(tensor) |
|
else: |
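            # This rank holds the first stage: run it on a background thread while we
            # arrange to receive the final stage's output directly over the transport.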
|
event = Event() |
|
t = Thread(target=self._model_forward_first_stage, args=(tensor, event)) |
|
t.start() |
|
|
|
shape, dtype = futures.pop().wait() |
|
dest_rank = self.group.size() - 1 |
|
dest = self._get_rpc_name(dest_rank) |
|
dest_global_rank = _get_global_rank(self.group, dest_rank) |
|
src_global_rank = torch.distributed.get_rank() |
|
queue = EVENT_LOOP_QUEUE |
|
|
|
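            # Message envelopes: `activations` receives the final stage's output on
            # this rank; `grads` carries gradients back to the final stage in backward.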
activations = PipeMessage(dest_global_rank, src_global_rank, queue_name=queue, tensor_count=num_tensors) |
|
grads = PipeMessage(src_global_rank, dest_global_rank, queue_name=queue, tensor_count=num_tensors) |
|
|
|
back_fut = rpc.rpc_async( |
|
dest, self._send_result_and_do_backwards, args=(self.model.training, activations, grads) |
|
) |
|
futures.append(back_fut) |
|
|
|
result = self._recv_result(self.model, shape, dtype, activations) |
|
if isinstance(result, torch.Tensor): |
|
result.requires_grad_() |
|
else: |
|
for r in result: |
|
r.requires_grad_() |
|
|
|
assert self.model.pipeline |
|
return PipeBackRedirect.apply( |
|
result, dest_global_rank, event, grads, self.model.pipeline.transport, futures |
|
) |
|
|
|
@property |
|
def final_stage(self) -> bool: |
|
return self.model.final_stage |
|
|
|
@staticmethod |
|
def _recv_result( |
|
model: AsyncPipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage |
|
) -> TensorOrTensors: |
|
group = get_pipeline_parallel_group() |
|
set_device_based_on_group(group) |
|
|
|
assert model.pipeline |
|
transport = model.pipeline.transport |
|
|
|
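        # Fill in the expected shapes/dtypes so the transport knows what to receive.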
if isinstance(shapes, torch.Size): |
|
message.tensor_shapes = [cast(torch.Size, shapes)] |
|
message.tensor_dtypes = [cast(torch.dtype, dtypes)] |
|
message = transport.recv_message_tensors(message) |
|
return message.tensors[0] |
|
else: |
|
message.tensor_shapes = cast(List[torch.Size], shapes) |
|
message.tensor_dtypes = cast(List[torch.dtype], dtypes) |
|
message = transport.recv_message_tensors(message) |
|
return message.tensors |
|
|
|
@staticmethod |
|
def _send_result_and_do_backwards(training: bool, message: PipeMessage, grads_message: PipeMessage) -> None: |
|
group = get_pipeline_parallel_group() |
|
set_device_based_on_group(group) |
|
result = PipeResult |
|
model = PipeModel |
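        # `PipeModel`/`PipeResult` were stashed by `_model_forward` on this (final)
        # stage: send the result to rank 0 and, if training, wait for the matching
        # gradients and run backward through the local shard.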
|
|
|
if isinstance(result, torch.Tensor): |
|
result = tuple([result]) |
|
|
|
message.tensors = tuple(result) |
|
assert model.pipeline |
|
transport = model.pipeline.transport |
|
transport.send_message(message, sync=False, skip_header=True) |
|
|
|
if training: |
|
grads_message.tensor_shapes = [r.shape for r in result] |
|
grads_message.tensor_dtypes = [r.dtype for r in result] |
|
grads_message = transport.recv_message_tensors(grads_message) |
|
|
|
with model.lock: |
|
torch.autograd.backward(result, grads_message.tensors, retain_graph=True) |
|
|
|
@staticmethod |
|
def _register_remote_model(args: List[Any], kwargs: Dict[str, Any]) -> None: |
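        # Runs on each remote worker: build that worker's AsyncPipe shard on its own
        # GPU and stash it in the module-level `PipeModel` global for later RPC calls.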
|
group = get_pipeline_parallel_group() |
|
set_device_based_on_group(group) |
|
kwargs["group"] = group |
|
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) |
|
model = AsyncPipe(*args, **kwargs) |
|
model.cuda() |
|
global PipeModel |
|
PipeModel = model |
|
|
|
@staticmethod |
|
    def _model_forward(
        training: bool, shape: SizeOrSizes, dtype: DtypeOrDtypes
    ) -> Optional[Tuple[SizeOrSizes, DtypeOrDtypes]]:
|
try: |
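            # Build a placeholder input matching the shapes/dtypes sent from rank 0.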
|
if isinstance(shape, torch.Size): |
|
tensor = torch.empty(shape, dtype=dtype) |
|
else: |
|
tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)]) |
|
|
|
model = PipeModel |
|
assert model.group |
|
set_device_based_on_group(model.group) |
|
|
|
model.train(training) |
|
result = model(tensor) |
|
if model.final_stage: |
|
global PipeResult |
|
PipeResult = result |
|
return (get_shapes(result), get_dtype(result)) |
|
|
|
return None |
|
except Exception as e: |
|
print(f"_model_forward got {e}") |
|
raise e |
|
|
|
def _model_forward_first_stage(self, tensor: TensorOrTensors, event: Event) -> None: |
|
try: |
|
assert self.model.group |
|
set_device_based_on_group(self.model.group) |
|
self.model(tensor, event=event) |
|
except Exception as e: |
|
            print(f"_model_forward_first_stage got {e}")
|
raise e |
|
|