import torch

import deepspeed
from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank,
                                 bwc_tensor_model_parallel_group)


def _gather_tokens(input_, dim=0):
    """Gather tensors from all tensor-parallel ranks and concatenate them along `dim`."""
    mpu = deepspeed.utils.groups.mpu

    input_ = input_.contiguous()
    world_size = bwc_tensor_model_parallel_world_size(mpu)
    if world_size == 1:
        # Nothing to gather when tensor parallelism is not used.
        return input_

    # All-gather into a flat buffer holding every rank's (identically shaped) shard in rank order.
    gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
    deepspeed.comm.all_gather_into_tensor(gather_buffer, input_, group=bwc_tensor_model_parallel_group(mpu))
    if dim == 0:
        # The flat buffer is already laid out along the leading dimension,
        # so a single view avoids any extra copy.
        shape = list(input_.size())
        shape[0] = shape[0] * world_size
        output = gather_buffer.view(shape)
    else:
        # For other dims, slice each rank's shard out of the flat buffer,
        # restore its original shape, and concatenate along the requested dim.
        tensor_list = [
            gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_)
            for i in range(world_size)
        ]
        output = torch.cat(tensor_list, dim=dim).contiguous()

    return output
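
# Shape sketch for _gather_tokens (illustrative numbers, not taken from the source): with
# world_size = 4 and a local shard of shape [s, b, h], the flat gather buffer holds
# 4 * s * b * h elements in rank order. For dim == 0 a single view yields [4*s, b, h];
# for dim == 1 each rank's slice is re-viewed to [s, b, h] and concatenated to [s, 4*b, h].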


def _drop_tokens(input_, dim=0):
    """Divide a tensor among the tensor-parallel ranks, keeping only this rank's chunk."""
    mpu = deepspeed.utils.groups.mpu

    total_chunks = bwc_tensor_model_parallel_world_size(mpu)
    if total_chunks == 1:
        # Nothing to drop when tensor parallelism is not used.
        return input_
    this_chunk = bwc_tensor_model_parallel_rank(mpu)
    assert input_.shape[dim] % total_chunks == 0, \
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
    chunk_size = input_.shape[dim] // total_chunks

    # Return a view of this rank's contiguous slice along `dim`; no data is copied.
    return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
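
# A minimal single-process sketch (not part of the original module) of the slice-and-reassemble
# logic used by _drop_tokens and the dim != 0 branch of _gather_tokens. It simulates the flat
# all-gather buffer locally, so it runs without torch.distributed; the name
# `_local_roundtrip_demo` is hypothetical.
def _local_roundtrip_demo(world_size=2, dim=1):
    """Check that chunking along `dim` and rebuilding from a flat buffer is lossless."""
    full = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)
    chunk = full.shape[dim] // world_size
    # What _drop_tokens would keep on each rank.
    shards = [torch.narrow(full, dim, r * chunk, chunk).contiguous() for r in range(world_size)]
    # What the all-gather in _gather_tokens would produce: the shards laid out flat in rank order.
    flat = torch.cat([s.reshape(-1) for s in shards])
    # The dim != 0 reassembly path: narrow out each shard, restore its shape, concatenate.
    rebuilt = torch.cat(
        [flat.narrow(0, shards[0].numel() * r, shards[0].numel()).view_as(shards[0]) for r in range(world_size)],
        dim=dim)
    assert torch.equal(rebuilt, full)
    return rebuilt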


class _GatherTokens(torch.autograd.Function):
    """All-gather tokens among the tensor-parallel ranks."""

    @staticmethod
    def symbolic(graph, input_, dim):
        return _gather_tokens(input_, dim)

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        return _gather_tokens(input_, dim)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of an all-gather is a drop back to this rank's shard;
        # `dim` is a non-tensor argument, so its gradient slot is None.
        return _drop_tokens(grad_output, ctx.dim), None


class _DropTokens(torch.autograd.Function):
    """Divide tokens equally among the tensor-parallel ranks."""

    @staticmethod
    def symbolic(graph, input_, dim):
        return _drop_tokens(input_, dim)

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        return _drop_tokens(input_, dim)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of a drop is an all-gather of the per-rank gradients;
        # `dim` is a non-tensor argument, so its gradient slot is None.
        return _gather_tokens(grad_output, ctx.dim), None
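
# Note: _GatherTokens and _DropTokens are inverses of one another; the backward of a gather
# drops the gradient back to this rank's shard, and the backward of a drop all-gathers the
# per-rank gradients, so autograd stays consistent with how the tokens were redistributed
# in the forward pass.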


def gather_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu
    if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
        # No tensor parallelism; the input already holds all tokens.
        return input_
    return _GatherTokens.apply(input_, dim)


def drop_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu
    if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
        # No tensor parallelism; there is nothing to partition.
        return input_
    return _DropTokens.apply(input_, dim)
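
# Typical usage sketch (illustrative, not code from this module): wrap a tensor-parallel
# region so that each rank only processes its own shard of the tokens. This assumes a
# distributed launch with a Megatron-style mpu registered through deepspeed.utils.groups;
# `hidden_states` and `expert_fn` below are hypothetical placeholders.
#
#     hidden_states = drop_tokens(hidden_states, dim=1)  # keep only this rank's shard
#     expert_out = expert_fn(hidden_states)               # tensor-parallel work on the shard
#     expert_out = gather_tokens(expert_out, dim=1)       # reassemble the full tensor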