# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed import comm as dist
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.runtime.zero.utils import is_zero_param
from abc import ABC, abstractmethod
from typing import Iterable, Any, Optional, List, Tuple, Union
from .fusedqkv_utils import shard_value_with_share_qk, shard_chunk_mlp, prepare_tp_fused_qkvw
from deepspeed.runtime.tensor_parallel import AUTOTP_MODE
from copy import deepcopy
__all__ = [
"TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
"Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer"
]
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
DS_IS_REPLACED_MODULE = 'ds_is_replaced_module'
DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel'
def get_auto_tp_mode():
global DEEPSPEED_AUTOTP_MODE
return DEEPSPEED_AUTOTP_MODE
def is_autotp_training_mode():
global DEEPSPEED_AUTOTP_MODE
return DEEPSPEED_AUTOTP_MODE == AUTOTP_MODE.TRAINING
def set_autotp_mode(training=False):
"""
Set the DEEPSPEED_AUTOTP_MODE based on the training flag
"""
global DEEPSPEED_AUTOTP_MODE
if training:
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.TRAINING
else:
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
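# Usage sketch (hypothetical caller, not part of this module): AutoTP training setups
# flip the mode before replacing modules and restore it afterwards, e.g.
#   set_autotp_mode(training=True)
#   assert is_autotp_training_mode()
#   set_autotp_mode(training=False)  # back to the default INFERENCE mode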
def add_bias(input, bias):
if bias is None:
return input
if is_autotp_training_mode():
# Training mode - avoid inplace to ensure correct autograd
input = input + bias
return input
else:
input += bias
return input
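# A minimal sketch (toy tensors, not part of this module) of the autograd hazard the
# training branch above avoids: ops that save their output for backward reject in-place
# edits of that output.
#   x = torch.randn(3, requires_grad=True)
#   y = torch.exp(x)      # exp() saves its output for the backward pass
#   y += 1                # in-place edit; y.sum().backward() now raises a RuntimeError
#   y = torch.exp(x) + 1  # the out-of-place form used in training mode is safe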
class RowParallel(torch.autograd.Function):
"""
A custom autograd function for performing row-wise parallelism.
"""
@staticmethod
def symbolic(graph, input):
"""Symbolic function for tracing."""
return input
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, is_inference_mode: bool) -> torch.Tensor:
"""
Forward pass.
"""
ctx.group = group
        if group is None:
return input
if is_inference_mode:
dist.inference_all_reduce(input, group=group)
else:
dist.all_reduce(input.contiguous(), group=group)
return input
@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None]:
"""
Backward pass.
"""
return None, grad_output, None
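# Sketch of the row-parallel pattern (see LinearAllreduce.forward below): each rank holds
# a column shard of the weight, computes a partial matmul, and the partial outputs are
# summed across the TP group.
#   partial = torch.matmul(x, w_shard.transpose(-1, -2))
#   full = RowParallel.apply(mp_group, partial, True)  # inference path: all-reduce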
class AsyncColumnParallel(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor:
"""
Forward pass.
"""
ctx.use_bias = bias is not None
ctx.group = group
output = torch.matmul(input, weight.transpose(-1, -2))
if bias is not None:
output = add_bias(output, bias)
ctx.save_for_backward(input, weight)
return output
@staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
input, weight = ctx.saved_tensors
grad_input = grad_output.matmul(weight)
handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
grad_bias = grad_output.sum(0) if ctx.use_bias else None
handle.wait()
return None, grad_input, grad_weight, grad_bias
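# Note: the backward above launches the all-reduce of grad_input asynchronously and
# overlaps it with the grad_weight/grad_bias GEMMs, waiting only before returning.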
class ColumnParallel(torch.autograd.Function):
"""
Custom autograd function for column-wise parallelism.
"""
@staticmethod
def symbolic(graph, input):
"""Symbolic function for tracing."""
return dist.all_reduce(input.contiguous(), dist.get_tensor_model_parallel_group())
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass.
"""
ctx.group = group
return input
@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
"""
Backward pass.
"""
        if ctx.group is None:
return None, grad_output
dist.all_reduce(grad_output.contiguous(), group=ctx.group)
return None, grad_output
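# ColumnParallel is the identity in forward (each rank consumes the full input and produces
# its own output shard) and all-reduces the input gradient across the TP group in backward.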
class TensorParallel_Layer(nn.Module, ABC):
"""
A base class for model layers with tensor parallelism support.
This class is designed to be extended by specific layers that require distributed
operations and parameter gather/partitioning during inference or training.
Attributes:
        mode (str): The mode of operation [INFERENCE or TRAINING]; default is "INFERENCE".
mp_group (Optional[dist.ProcessGroup]): The process group used for model parallelism.
tp_world_size (int): The world size of tensor parallelism, i.e., the number of parallel workers.
tp_index (int): The rank (ID) of the current worker in tensor parallelism.
support_training (bool): Flag indicating whether the layer supports training (default: False).
name (Optional[str]): The name of the layer, if provided.
"""
##### Initialize Parameter List #####
# keep_module_on_host determines whether to keep the module on the host.
# Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
# so an additional copy is unnecessary.
keep_module_on_host: bool = False
##### Runtime Parameter List #####
tp_overlap_comm: bool = False
""" Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
"""
Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
Args:
mp_group (Optional[dist.ProcessGroup]): The process group for model parallelism.
If None, no model parallelism is set.
"""
super().__init__()
self.support_training: bool = False
if mp_group is not None:
self.mp_group = mp_group
self.tp_world_size: int = dist.get_world_size(self.mp_group)
self.tp_index: int = dist.get_rank(mp_group)
# backward compatibility
self.world_size = self.tp_world_size
self.rank = self.tp_index
self.name = getattr(self, 'name', None)
if kwargs.get('name') is not None:
self.name = kwargs.get('name') # Set the layer name if provided.
@classmethod
def set_keep_module_on_host(cls, value: bool):
"""
Set the static variable keep_module_on_host.
Args:
value (bool): The new value for keep_module_on_host.
"""
cls.keep_module_on_host = value
@abstractmethod
def forward(self, input):
"""
Forward pass method. Must be implemented by subclasses to define layer-specific operations.
"""
pass
@abstractmethod
def gather_params(self, params_list):
"""
Gathers parameters across devices for distributed training. Must be implemented by subclasses in "TRAINING" mode.
"""
pass
@abstractmethod
def _tp_partition(self, params_list: List[torch.Tensor]):
"""
Partitions the parameters for tensor parallelism.
        This method must contain only the parameter-partitioning logic.
"""
pass
    def config_requires_grad(self, weight):
        if weight is not None:
            if self.is_training_mode():
                if not weight.requires_grad:
                    weight.requires_grad = True
            else:
                weight.requires_grad = False
def config_tp_params(self, weight):
"""
Configures the weight tensor for training with tensor parallelism. This includes enabling gradients
and associating necessary methods for parameter gathering and partitioning.
Args:
weight (Optional[torch.Tensor]): The weight tensor to configure for tensor parallelism.
If None, no action is taken.
"""
        # The RNG states have already been synchronized in init_inference.
if self.is_training_mode():
assert self.support_training, "No implementation of backward."
if weight is not None:
self.config_requires_grad(weight)
weight.gather_params = self.gather_params
weight._tp_partition = self._tp_partition
setattr(weight, DS_TENSOR_MODEL_PARALLEL, True)
setattr(weight, DS_IS_REPLACED_MODULE, True)
def is_training_mode(self):
global DEEPSPEED_AUTOTP_MODE
return DEEPSPEED_AUTOTP_MODE == AUTOTP_MODE.TRAINING
def __deepcopy__(self, memo):
        # 'mp_group' (a ProcessGroup) cannot be pickled during deepcopy in some usages,
        # so copy the reference instead of deep-copying it.
cls = self.__class__
new_obj = cls.__new__(cls)
for key, value in vars(self).items():
if key == 'mp_group':
new_obj.mp_group = self.mp_group
else:
setattr(new_obj, key, deepcopy(value, memo))
memo[id(self)] = new_obj
return new_obj
def extra_repr(self):
out_features, in_features = None, None
if self.weight is not None:
out_features, in_features = self.weight.ds_shape[-2:] if is_zero_param(
self.weight) else self.weight.shape[-2:]
dtype = self.weight.dtype if self.weight is not None else None
return "in_features={}, out_features={}, bias={}, dtype={}".format(in_features, out_features, self.bias
is not None, dtype)
def move(self, tensor):
        # TODO: consider the timing of deletion
        # to save host resources when DP > 1.
        # keep_module_on_host keeps the module on the host. Checkpoints are loaded to the host first
        # (sometimes directly from disk, to avoid filling host memory), so there is no need to create a new copy.
if tensor.is_meta:
# Keep tensor in meta device if tensor is meta.
return tensor
else:
device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name()
return_new_copy = not self.__class__.keep_module_on_host
            # Using a new tensor helps free memory (e.g., after a split); this was previously done by calling clone().
            # Using copy=True instead of clone() also helps in the cpu --> cpu case.
            # Otherwise, to() will not create a new copy of a view of the full tensor, and the full tensor will not be de-referenced.
cloned_tensor = tensor.to(device, copy=return_new_copy)
if return_new_copy:
# free the memory of the original tensor to reduce memory peak
# Equivalent to directly deleting the tensor reference outside the function.
# see https://github.com/microsoft/DeepSpeed/pull/4353
tensor.data = torch.empty(0, device=tensor.device)
return cloned_tensor
def configure_tensor_parallel_runtime(config):
runtime_keys = ['tp_overlap_comm']
for key in runtime_keys:
if hasattr(config, key):
setattr(TensorParallel_Layer, key, getattr(config, key))
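# Usage sketch (hypothetical config object): any runtime key present on the config is
# mirrored onto TensorParallel_Layer as a class attribute.
#   class _Cfg:
#       tp_overlap_comm = True
#   configure_tensor_parallel_runtime(_Cfg())
#   assert TensorParallel_Layer.tp_overlap_comm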
class GatherReplacedLayerParams:
"""
A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality
based on the configuration of the model.
"""
def __init__(self,
params: Union[Iterable[torch.Tensor], torch.Tensor],
module: torch.nn.Module,
enabled: bool = True):
"""
Initialize the context manager to handle parameter gathering and partitioning for a replaced layer.
Args:
params (Iterable or torch.Tensor): A collection or single parameter to manage.
module (torch.nn.Module): The module that these parameters belong to.
enabled (bool): Flag indicating whether the parameter management is enabled (default: True).
"""
self.enabled = enabled
self.module = module
if not enabled:
return
# Ensure params is a list, whether it's a single param or iterable (e.g., model.parameters())
if isinstance(params, Iterable) and not isinstance(params, torch.Tensor):
self.params: List[torch.Tensor] = list(params) # Convert generators to a list for multiple iterations
else:
self.params: List[torch.Tensor] = [params] # Wrap single parameter in a list for uniform processing
# Check if the parameters belong to a replaced layer (indicated by a specific attribute)
        if not any(self._is_replaced_module_weight(p) for p in self.params):
self.enabled = False
return
def _is_replaced_module_weight(self, param: torch.Tensor) -> bool:
"""
Helper function to determine if a parameter belongs to a replaced module.
Args:
param (torch.Tensor): The parameter to check.
Returns:
bool: True if the parameter belongs to a replaced module, False otherwise.
"""
return getattr(param, DS_IS_REPLACED_MODULE, False)
def __enter__(self) -> None:
"""
Enter the context manager. If enabled, gather parameters for the replaced module.
"""
if self.enabled:
self.params[0].gather_params(self.params)
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Exit the context manager. If enabled, partition the parameters for the replaced module.
"""
        # TODO: Check whether there are any missing attributes.
if self.enabled:
self.params[0]._tp_partition(self.params)
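# Usage sketch (hypothetical `layer` whose weights were replaced by one of the TP layers
# below): parameters are gathered on entry and re-partitioned on exit.
#   with GatherReplacedLayerParams(layer.parameters(), layer):
#       full_weight = layer.weight.data  # gathered (full) tensor inside the context
#   # outside the context the weight is sharded again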
class LinearAllreduce(TensorParallel_Layer):
def __init__(self, module, mp_group, **kwargs):
super(LinearAllreduce, self).__init__(mp_group, **kwargs)
self.weight = module.weight
self.bias = module.bias
self._tp_partition([self.weight, self.bias])
self.support_training = True
self.config_tp_params(self.weight)
if self.bias is not None:
            # the bias here is not a TP parameter
self.config_requires_grad(self.bias)
def forward(self, input):
output = torch.matmul(input, self.weight.transpose(-1, -2))
output = RowParallel.apply(self.mp_group, output, not self.is_training_mode())
if self.bias is not None:
output = add_bias(output, self.bias)
return output
@torch.no_grad()
def gather_params(self, params_list):
for idx, param in enumerate(params_list):
if param is None or idx > 0:
# don't gather bias
return
params_list[idx].data_partition = param.data
param = param.transpose(0, 1).contiguous()
output_param = torch.empty(self.tp_world_size * param.shape[0],
param.shape[1],
dtype=param.dtype,
device=param.device)
dist.all_gather_into_tensor(output_param, param, group=self.mp_group)
params_list[idx].data = output_param.transpose(0, 1).contiguous()
return
@torch.no_grad()
def _tp_partition(self, params_list):
if not self.is_training_mode():
self.uneven_partition(params_list)
return
else:
for idx, param in enumerate(params_list):
if param is None:
                    # don't split the bias
return
if idx > 0: # move bias to device at initialization
_partition = self.move(param).detach()
params_list[idx].data = _partition
return
_partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index]
_partition = self.move(_partition).detach()
params_list[idx].data = _partition
def uneven_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None or idx > 0:
                # don't split the bias
return
assert self.name is not None, "The module name must be provided in the initialization."
_partition = params_list[idx].split(get_shard_size_list(params_list[idx].shape[1], self.tp_world_size,
self.name),
dim=1)[self.tp_index]
_partition = self.move(_partition).detach()
params_list[idx].data = _partition
#remove kwargs from partition.
class LinearLayer(TensorParallel_Layer):
def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
super(LinearLayer, self).__init__(mp_group, **kwargs)
self.weight = module.weight
self.bias = module.bias
if not skip_partition:
self._tp_partition([self.weight, self.bias])
self.support_training = True
self.config_tp_params(self.weight)
if self.bias is not None:
self.config_tp_params(self.bias)
def forward(self, input):
if not self.__class__.tp_overlap_comm:
if getattr(self, 'mp_group', None) is not None:
input = ColumnParallel.apply(self.mp_group, input)
output = torch.matmul(input, self.weight.transpose(-1, -2))
if self.bias is not None:
output = add_bias(output, self.bias)
else:
output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)
return output
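    # When tp_overlap_comm is enabled, the AsyncColumnParallel path above overlaps the
    # backward all-reduce of the input gradient with the weight/bias gradient GEMMs.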
@torch.no_grad()
def gather_params(self, params_list):
# Does not support uneven shard.
for idx, param in enumerate(params_list):
params_list[idx].data_partition = param.data
output_param = torch.empty((self.tp_world_size * param.shape[0], *param.shape[1:]),
dtype=param.dtype,
device=param.device)
dist.all_gather_into_tensor(output_param, param, group=self.mp_group)
params_list[idx].data = output_param.contiguous()
@torch.no_grad()
def _tp_partition(self, params_list):
if not self.is_training_mode():
self.uneven_partition(params_list)
return
for idx, param in enumerate(params_list):
if param is None:
return
            # split the bias too, if provided
_partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index]
_partition = self.move(_partition).detach()
params_list[idx].data = _partition
def uneven_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
                # nothing to split if the bias was not provided
return
assert self.name is not None, "The module name must be provided in the initialization."
_partition = params_list[idx].split(get_shard_size_list(params_list[idx].shape[0], self.tp_world_size,
self.name),
dim=0)[self.tp_index]
_partition = self.move(_partition).detach()
params_list[idx].data = _partition
    # for backward compatibility (bwc)
@classmethod
def from_weights(cls, weight_shape=None, dtype=torch.half, weight=None, bias=None):
if weight is not None:
in_features = weight.shape[1]
out_features = weight.shape[0]
linear = nn.Linear(in_features, out_features, bias=(bias is not None))
linear.weight.data = weight
if bias is not None:
linear.bias.data = bias
else:
in_features = weight_shape[1]
out_features = weight_shape[0]
linear = nn.Linear(in_features, out_features, bias=(bias is not None))
return cls(linear, skip_partition=True)
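# Usage sketch of the bwc helper above (hypothetical shapes): build an unpartitioned
# LinearLayer directly from an existing weight tensor.
#   w = torch.randn(out_features, in_features, dtype=torch.half)
#   layer = LinearLayer.from_weights(weight=w)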
class FusedModuleWrapper:
def __init__(self, fused_module: nn.Module):
self.fused_module = fused_module
    def __getattr__(self, name):
        # Any attribute access (e.g. `.module`) resolves to the wrapped fused module.
        return self.fused_module
class fused_LinearLayer(LinearLayer):
def __init__(self, module, mp_group, skip_partition=False, **kwargs):
assert kwargs.get('fused_module') is not None, "'fused_module' is required but not provided"
        # Use the wrapper class to avoid a circular module reference.
self.fused_module = FusedModuleWrapper(kwargs.get('fused_module'))
super().__init__(module, mp_group, skip_partition, **kwargs)
@torch.no_grad()
def _tp_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
return
_partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index)
_partition = self.move(_partition).detach()
params_list[idx].data = _partition
class conv_LinearLayer(LinearLayer):
@torch.no_grad()
def _tp_partition(self, params_list):
weight = None
bias = None
if len(params_list) == 1:
weight = params_list[0]
elif len(params_list) == 2:
weight, bias = params_list[0], params_list[1]
_partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name),
dim=1)[self.tp_index]
_partition = self.move(_partition).detach()
weight.data = _partition
if bias is not None:
_partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name),
dim=0)[self.tp_index]
_partition = self.move(_partition).detach()
bias.data = _partition
# The subclasses below override weight splitting.
class Yuan_LinearAllreduce(LinearAllreduce):
#Yuan2
@torch.no_grad()
def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, False)
params_list[0].data = weight
if bias is not None:
params_list[1].data = bias
class Yuan_LinearLayer(LinearLayer):
#Yuan2
@torch.no_grad()
def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, True)
params_list[0].data = self.move(weight).detach()
if bias is not None:
params_list[1].data = self.move(bias).detach()
class GateUpPack_LinearLayer(LinearLayer):
    # chatGLM2
@torch.no_grad()
def _tp_partition(self, params_list):
weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
params_list[0].data = self.move(weight).detach()
if bias is not None:
params_list[1].data = self.move(bias).detach()
class Conv_LinearALlreduce(LinearAllreduce):
@torch.no_grad()
def _tp_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
return
param.data = param.data.transpose(-1, -2).contiguous()
_partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name),
dim=1)[self.tp_index]
_partition = self.move(_partition).detach()
params_list[idx].data = _partition
# The subclasses below override the forward/backward behavior.
class LmHeadLinearAllreduce(LinearAllreduce):
def __init__(self, module, mp_group, **kwargs):
# set the fixed name before partition
self.name = "lm_head"
# In some tied_embedding cases, only the lm head is sharded, while the word embedding is not.
# Reinitialization is used to decouple them and prevent the word embedding from being sharded.
# This should also be effective for cases where both are sharded in tied_embedding scenarios.
        # TODO: add tests for training scenarios; is it necessary to re-implement the vocab-parallel module?
module.weight = nn.Parameter(module.weight.clone().detach())
if hasattr(module, 'bias') and module.bias is not None:
module.bias = nn.Parameter(module.bias.clone().detach())
super().__init__(module, mp_group, **kwargs)
def forward(self, input):
input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
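        # Each rank selects its contiguous slice of the hidden dimension (matching the shard
        # sizes used to split the lm_head weight), computes partial logits locally, and the
        # partials are summed across the TP group below.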
output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
self.weight.transpose(-1, -2))
if self.mp_group is not None:
dist.inference_all_reduce(output, group=self.mp_group)
if self.bias is not None:
output = add_bias(output, self.bias)
return output
class TensorParallelConv2d(nn.Module):
def __init__(self, conv, rank, world_size, shard_by_oc):
super().__init__()
self.rank = rank
self.world_size = world_size
self.shard_by_oc = shard_by_oc
self.shard_weights(conv)
# Split along the input/output channel depending on whether it is the last conv layer.
def shard_weights(self, conv):
if self.shard_by_oc:
total_size = conv.weight.shape[0]
else:
total_size = conv.weight.shape[1]
bias_data = None
cols_per_rank = [0]
for i in range(self.world_size - 1, -1, -1):
cols = total_size // self.world_size
if i < total_size % self.world_size:
cols += 1
cols_per_rank.append(cols_per_rank[-1] + cols)
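        # cols_per_rank is a prefix sum of per-rank channel counts; when the channel count is
        # not divisible by world_size, the remainder channels go to the highest ranks.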
weight_data = conv.weight.data
if self.shard_by_oc:
# not last conv layer, split output channel
weight_data = weight_data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
if conv.bias is not None:
bias_data = conv.bias.data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
else:
# last conv layer, split input channel
weight_data = weight_data[:, cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
if conv.bias is not None:
bias_data = conv.bias.data / float(self.world_size)
self.conv = nn.Conv2d(weight_data.shape[1], weight_data.shape[0], conv.kernel_size, conv.stride, conv.padding,
conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode)
self.conv.weight = torch.nn.Parameter(weight_data)
if conv.bias is not None:
self.conv.bias = torch.nn.Parameter(bias_data)
del conv
def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.conv(input)
class TensorParallelOcShardConv2d(TensorParallelConv2d):
def __init__(self, conv, rank, world_size):
super().__init__(conv, rank, world_size, True)
class TensorParallelIcShardConv2d(TensorParallelConv2d):
def __init__(self, conv, rank, world_size):
super().__init__(conv, rank, world_size, False)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = self.conv(input)
if self.world_size > 1:
dist.inference_all_reduce(out)
return out
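# With input-channel sharding each rank computes a partial sum over its channel slice, so the
# outputs are all-reduced above; output-channel sharding needs no reduction.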
class Normalize(nn.Module):
def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None):
super(Normalize, self).__init__()
if weight is not None:
self.weight = weight
self.bias = bias
else:
self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
self.weight = self.norm.weight
self.bias = self.norm.bias
self.eps = eps
def forward(self, input):
return nn.functional.layer_norm(input, input.shape[-1:], self.weight, self.bias, eps=self.eps)
class EmbeddingLayer(nn.Module):
def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
super(EmbeddingLayer, self).__init__()
if weight is None:
self.weight = Parameter(
torch.empty(weight_shape[0],
weight_shape[1],
dtype=dtype,
device=get_accelerator().current_device_name()))
else:
self.weight = weight
def forward(self, input):
return F.embedding(input, self.weight)
class OPTEmbedding(EmbeddingLayer):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weight_shape=None, weight=None, bias=None):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(weight_shape, weight=weight)
    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0):
        """`attention_mask` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
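        # e.g. attention_mask [[1, 1, 1, 0]] -> cumsum * mask - 1 = [[0, 1, 2, -1]]:
        # real tokens count from 0 (before the +2 offset) and padded positions land on -1.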
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return super().forward(positions + self.offset)
class RMSNormalize(nn.Module):
def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
super(RMSNormalize, self).__init__()
if weight is not None:
self.weight = weight
else:
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=get_accelerator().current_device_name()))
self.eps = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return hidden_states * self.weight
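# Usage sketch (hypothetical sizes): RMSNormalize as used by Llama-style models.
#   norm = RMSNormalize(dim=4096, dtype=torch.float16)
#   x = torch.randn(2, 8, 4096, dtype=torch.float16, device=norm.weight.device)
#   out = norm(x)  # same shape and dtype as x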