from typing import Any

import torch

from .initialize import get_model_parallel_group
from .utils import split_tensor_along_last_dim


def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across the model parallel group."""
    group = get_model_parallel_group()

    # When called from inside an autograd Function, flag the tensor as
    # modified in place so autograd tracks the mutation correctly.
    if ctx:
        ctx.mark_dirty(input_)

    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
        return input_

    # All-reduce: sums the tensor in place across the model parallel group.
    torch.distributed.all_reduce(input_, group=group)

    return input_


def _split(input_: torch.Tensor) -> torch.Tensor:
    """Split the tensor along its last dimension and keep the
    corresponding slice."""
    group = get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
        return input_

    # Split along the last dimension, one chunk per rank.
    world_size = torch.distributed.get_world_size(group=group)
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = torch.distributed.get_rank(group=group)
    output = input_list[rank].contiguous()

    return output


def _gather(input_: torch.Tensor) -> torch.Tensor:
    """Gather tensors and concatenate along the last dimension."""
    group = get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    # Reuse this rank's own tensor in the output list rather than
    # receiving a copy of it through all_gather.
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output


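# Each autograd Function below pairs a forward collective with its conjugate
# in backward: an identity forward backs onto an all-reduce (and vice versa),
# and a split forward backs onto a gather (and vice versa).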
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(None, grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def forward(ctx, input_):
        return _reduce(ctx, input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)


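# -----------------
# Helper functions.
# -----------------

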
def copy_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _GatherFromModelParallelRegion.apply(input_)
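

# Usage sketch (illustrative only, not part of this module's API): in a
# Megatron-style column-parallel linear layer, the input is copied into the
# model parallel region before the local matmul and the per-rank output
# shards are gathered afterwards. `weight_shard` below is a hypothetical
# per-rank slice of the full weight matrix.
#
#     input_parallel = copy_to_model_parallel_region(input_)
#     output_parallel = torch.nn.functional.linear(input_parallel, weight_shard)
#     output = gather_from_model_parallel_region(output_parallel)
#
# A row-parallel layer is the mirror image: scatter_to_model_parallel_region
# splits the input along the last dimension, and
# reduce_from_model_parallel_region sums the partial products across ranks.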