# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn

from fairscale.nn.moe.moe_layer import MOELayer
from fairscale.nn.moe.top2gate import Top2Gate

# TODO(anj-s): Identify if we need this initialization logic for the below wrapped layers.


class EmbeddingLayer(nn.Embedding):
    """Wrapped nn.Embedding layer to allow for weight initialization."""

    def __init__(self, ntoken, ninp, initrange):
        super().__init__(ntoken, ninp)
        self.ninp_sqrt = math.sqrt(ninp)
        self.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        return super().forward(src) * self.ninp_sqrt


class PositionalEncodingLayer(nn.Module):
    """PositionalEncoding layer for a given Transformer model."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncodingLayer, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class FeedForwardLayer(nn.Module):
    """FeedForward layer for a given Transformer model."""

    def __init__(self, d_model, dim_feedforward, activation, dropout) -> None:
        super(FeedForwardLayer, self).__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = activation
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.linear2(self.dropout1(self.activation(self.linear1(x)))))
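

# A minimal shape-check sketch (the sizes below are illustrative assumptions, not benchmark
# settings). PositionalEncodingLayer expects activations in [seq_len, batch_size, d_model]
# layout (positional terms are added along the first dimension), and both layers preserve
# the input shape, which is why they can be chained inside the encoder layers defined below.
def _check_building_block_shapes(seq_len=10, batch_size=4, d_model=16):
    x = torch.rand(seq_len, batch_size, d_model)
    x = PositionalEncodingLayer(d_model)(x)  # adds pe[:seq_len] along the first dimension
    x = FeedForwardLayer(d_model, dim_feedforward=32, activation=nn.ReLU(), dropout=0.1)(x)
    assert x.shape == (seq_len, batch_size, d_model)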


# Forked from https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer.
# Parameters is_moe and num_local_experts are added.
class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and a feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, a unary callable
            such as ``nn.ReLU()`` (default: ``nn.ReLU()``).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        norm_first: if ``True``, layer norm is done prior to the attention and feedforward
            operations, respectively. Otherwise it is done after. Default: ``False`` (after).
        is_moe: if ``True``, the feedforward layer will have MOE enabled.
        num_local_experts: number of local experts for MOE.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    __constants__ = ["norm_first"]

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU(),
        layer_norm_eps=1e-5,
        norm_first=False,
        is_moe=False,
        num_local_experts=1,
    ):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)

        self.is_moe = is_moe

        if is_moe:
            world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
            num_global_experts = num_local_experts * world_size
            self.gate = Top2Gate(d_model, num_global_experts)
            experts = nn.ModuleList(
                [FeedForwardLayer(d_model, dim_feedforward, activation, dropout) for _ in range(num_local_experts)]
            )
            self.moe_layer = MOELayer(self.gate, experts)
        else:
            self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in the Transformer class.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
        return self.dropout(x)

    # feed forward block
    def _ff_block(self, x):
        if self.is_moe:
            return self.moe_layer(x)
        else:
            return self.ff_block(x)
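

# Sketch of constructing an MoE-enabled encoder layer (the hyperparameters below are
# illustrative assumptions). With is_moe=True each rank hosts `num_local_experts`
# FeedForwardLayer experts and Top2Gate routes tokens across
# num_local_experts * world_size global experts. This file sizes the gate with
# world_size treated as 1 when torch.distributed is not initialized, but MOELayer itself
# uses all-to-all communication, so the MoE path is normally run under an initialized
# process group.
def _example_moe_encoder_layer(d_model=512, nhead=8, num_local_experts=2):
    return TransformerEncoderLayer(d_model, nhead, is_moe=True, num_local_experts=num_local_experts)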


class TransformerDecoderLayer(TransformerEncoderLayer):
    """TransformerDecoder layer which inherits from TransformerEncoderLayer."""

    def __init__(self, ninp, nhead, nhid, dropout, is_moe=False, num_local_experts=1):
        super().__init__(ninp, nhead, nhid, dropout, is_moe=is_moe, num_local_experts=num_local_experts)
        self.src_mask = None

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src):
        # TODO(anj-s): Fix the data format so that we have [seq_len, batch_size, embedding dim].
        # Currently real data has seq_len as the second dimension and batch_size as the first dimension.
        # We need to mask the sequence length dimension and not the batch size.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        return super().forward(src, self.src_mask)


class LinearLayer(nn.Linear):
    """Wrapped nn.Linear layer to allow for weight initialization."""

    def __init__(self, ninp, ntoken, initrange):
        super().__init__(ninp, ntoken)
        self.bias.data.zero_()
        self.weight.data.uniform_(-initrange, initrange)


class TransformerLM(nn.Sequential):
    """A GPT-2 based nn.Sequential language model."""

    def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder, is_moe=False, num_local_experts=1):
        layers = [
            EmbeddingLayer(ntokens, ninp, initrange),
            PositionalEncodingLayer(ninp, dropout),
        ]
        for _ in range(ndecoder):
            layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout, is_moe, num_local_experts))

        layers.append(LinearLayer(ninp, ntokens, initrange))
        super(TransformerLM, self).__init__(*layers)
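

# A minimal smoke-test sketch for the dense (non-MoE) path; the hyperparameters below are
# illustrative assumptions rather than the benchmark's settings. The MoE path (is_moe=True)
# additionally relies on torch.distributed, so it is not exercised here.
if __name__ == "__main__":
    ntokens, ninp, nhead, nhid = 1000, 64, 4, 256
    model = TransformerLM(ntokens, ninp, nhead, nhid, dropout=0.1, initrange=0.1, ndecoder=2)
    # Token ids in [seq_len, batch_size] layout (see the TODO in TransformerDecoderLayer.forward
    # about the data format).
    src = torch.randint(0, ntokens, (35, 8))
    out = model(src)
    print(out.shape)  # expected: torch.Size([35, 8, 1000])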