# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

import deepspeed.comm as dist
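
# Sequence-parallel (SP) cross entropy: the sequence dimension S is sharded
# across the P ranks of sp_group. Each rank computes per-token losses for its
# own S/P slice, and an all-gather assembles the full [S, B] loss on every rank.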
class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, vocab_seq_parallel_logits, target, sp_group):
        # vocab_seq_parallel_logits: [S/P, B, V], this rank's slice of the
        # sequence dimension (S = full sequence length, P = SP world size)
        # target: [S/P, B]
        # return: [S, B] per-token losses for the full sequence

        # Compute log-probabilities once and exponentiate to get the softmax
        # needed for backward; log_softmax is numerically safer than
        # softmax(...).log(), which can emit -inf for underflowed entries.
        log_probs = torch.nn.functional.log_softmax(vocab_seq_parallel_logits, dim=-1)
        softmax = log_probs.exp()
        ctx.vocab_size = vocab_seq_parallel_logits.size(2)
        loss = torch.nn.functional.nll_loss(log_probs.view(-1, ctx.vocab_size), target.view(-1), reduction='none')
        sp_world_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)
        ctx.sp_world_size = sp_world_size
        ctx.sp_rank = sp_rank
        ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_world_size
        batch_size = vocab_seq_parallel_logits.size(1)

        # Each rank contributes the [S/P * B] losses for its sequence slice;
        # the all-gather assembles the full [S, B] loss, with rank r filling
        # rows [r * S/P, (r + 1) * S/P).
        loss_all = torch.empty(ctx.seqlen,
                               batch_size,
                               dtype=vocab_seq_parallel_logits.dtype,
                               device=vocab_seq_parallel_logits.device)
        dist.all_gather_into_tensor(loss_all, loss, group=sp_group)

        ctx.save_for_backward(softmax, target)

        return loss_all
    @staticmethod
    def backward(ctx, grad_output):
        softmax, target = ctx.saved_tensors

        # grad_output covers the full sequence ([S, B]); slice out the rows
        # that belong to this rank's S/P chunk.
        step_seqlen = ctx.seqlen // ctx.sp_world_size
        sp_rank = ctx.sp_rank
        grad_output_part = grad_output[step_seqlen * sp_rank:step_seqlen * (sp_rank + 1), :]

        # For cross entropy, d(loss)/d(logits) = softmax - one_hot(target).
        # Reuse the saved softmax buffer in place: subtract 1 at each target
        # index, then scale by the incoming per-token gradient.
        grad_input = softmax
        grad_2d = grad_input.view(-1, ctx.vocab_size)
        arange_1d = torch.arange(start=0, end=grad_2d.size(0), device=grad_2d.device)
        grad_2d[arange_1d, target.view(-1)] -= 1
        grad_input.mul_(grad_output_part.unsqueeze(dim=-1))

        # forward() takes three inputs besides ctx (logits, target, sp_group),
        # so backward must return exactly three gradients.
        return grad_input, None, None


def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group):
    return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group)
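

# A minimal usage sketch, assuming the script is launched across P ranks with
# `torchrun` or the `deepspeed` launcher on GPUs, and that the sequence-
# parallel group simply spans all launched ranks; tensor sizes are toy values.
if __name__ == "__main__":
    import os

    import torch.distributed

    import deepspeed

    deepspeed.init_distributed()
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    world_size = torch.distributed.get_world_size()
    sp_group = torch.distributed.new_group(ranks=list(range(world_size)))

    s_local, batch, vocab = 16, 2, 1024  # S/P tokens held by this rank
    logits = torch.randn(s_local, batch, vocab, device='cuda', requires_grad=True)
    target = torch.randint(0, vocab, (s_local, batch), device='cuda')

    loss_all = vocab_sequence_parallel_cross_entropy(logits, target, sp_group)  # [S, B]
    loss_all.mean().backward()  # gradients flow only into this rank's S/P slice
    print(f"rank {torch.distributed.get_rank()}: mean loss {loss_all.mean().item():.4f}")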