from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F


class MultiLevelDownstreamModel(nn.Module):
    """Classification head over a learned, softmax-weighted mix of hidden states.

    The head consumes the tuple of hidden states returned by an upstream model.
    That model must be configured with ``output_hidden_states``. The head then:

    1. mixes the states with one learnable weight per state (softmax-normalised);
    2. applies a stack of 1x1 ``Conv1d`` projections per frame;
    3. pools over time: a masked mean when ``length`` is given, a plain mean
       otherwise;
    4. maps the pooled vector to ``num_labels`` logits.

    Args:
        model_config: Upstream model config. Must expose ``output_hidden_states``,
            ``hidden_size``, ``classifier_proj_size``, ``num_hidden_layers`` and
            ``num_labels``.
        use_conv_output: If true, hidden state index 0 is included in the mix
            (``num_hidden_layers + 1`` states). Otherwise index 0 is dropped and
            only the remaining ``num_hidden_layers`` states are mixed.
    """

    def __init__(
        self,
        model_config,
        use_conv_output: Optional[bool] = True,
    ):
        super().__init__()
        assert model_config.output_hidden_states, "The upstream model must return all hidden states"
        self.model_config = model_config
        self.use_conv_output = use_conv_output

        proj = self.model_config.classifier_proj_size
        # Kernel size 1 => each Conv1d acts as a per-frame linear projection.
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, proj, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(proj, proj, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(proj, proj, 1, padding=0),
        )

        # One mixing weight per hidden state that forward() will stack.
        # Both initialisations give a uniform softmax.
        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + input embeddings
            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))

        self.out_layer = nn.Sequential(
            nn.Linear(proj, proj),
            nn.ReLU(),
            nn.Linear(proj, self.model_config.num_labels),
        )

    def forward(self, encoder_hidden_states, length=None):
        """Compute classification logits from upstream hidden states.

        Args:
            encoder_hidden_states: Sequence of same-shaped hidden-state tensors,
                one per upstream layer. Each is consumed as
                (batch, time, hidden_size), because the 1x1 ``Conv1d`` stack
                runs after ``transpose(1, 2)``.
            length: Optional 1-D tensor with the number of valid frames per
                example. When given, pooling is a mean over the first
                ``length`` frames of each example. An entry of 0 yields a zero
                pooled vector rather than NaN.

        Returns:
            Logits of shape (batch, num_labels).

        Raises:
            ValueError: If the number of hidden states used does not match the
                number of mixing weights.
        """
        # Index 0 is the "convolution output" per the original code; drop it
        # before stacking when it should not take part in the mix.
        hidden_states = encoder_hidden_states if self.use_conv_output else encoder_hidden_states[1:]
        stacked_feature = torch.stack(hidden_states, dim=0)

        num_stacked, *origin_shape = stacked_feature.shape
        if num_stacked != self.weights.numel():
            raise ValueError(
                f"Expected {self.weights.numel()} hidden states for the weighted mix, got {num_stacked}"
            )

        # Weighted sum over the layer axis: flatten each layer to a row,
        # scale row i by weight i, sum the rows, then restore the layer shape.
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature.reshape(num_stacked, -1)).sum(dim=0)
        features = weighted_feature.view(*origin_shape)

        # Conv1d expects (batch, channels, time); swap in and back out.
        features = self.model_seq(features.transpose(1, 2)).transpose(1, 2)

        if length is not None:
            # Build the mask on the same device as the features instead of
            # forcing CUDA, so the head also runs on CPU or any GPU.
            length = length.to(features.device)
            frame_idx = torch.arange(features.size(1), device=features.device)
            masks = (frame_idx.unsqueeze(0) < length.unsqueeze(1)).to(features.dtype)
            # Clamp the divisor so zero-length examples do not produce NaN.
            denom = length.unsqueeze(1).clamp(min=1).to(features.dtype)
            features = (features * masks.unsqueeze(-1)).sum(1) / denom
        else:
            features = torch.mean(features, dim=1)

        predicted = self.out_layer(features)
        return predicted