import torch
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
from torch import nn
from transformers import Wav2Vec2Model


class Wav2Vec2EncoderLayer(nn.Module):
    """Re-implementation of the Hugging Face wav2vec 2.0 encoder layer.

    Mirrors the attribute names of the original layer so the pretrained
    weights can be loaded back after the encoder layers are swapped in.
    """

    def __init__(self, config, i):
        super().__init__()
        self.attention = w2v2.Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = w2v2.Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config
        self.i = i  # layer index

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Self-attention block with residual connection and post-layer norm.
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)

        # Feed-forward block with residual connection and final layer norm.
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class Wav2VecWrapper(nn.Module):
    """Wraps a pretrained wav2vec 2.0 backbone and exposes its encoder hidden states."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone_model = Wav2Vec2Model.from_pretrained(
            config._name_or_path,
            output_hidden_states=config.output_hidden_states,
        )
        # Snapshot the pretrained weights before the encoder layers are replaced.
        state_dict = self.backbone_model.state_dict()
        self.model_config = self.backbone_model.config
        self.backbone_model.encoder.layers = nn.ModuleList(
            [Wav2Vec2EncoderLayer(self.model_config, i) for i in range(self.model_config.num_hidden_layers)]
        )
        # Load the pretrained weights back into the replaced layers
        # (the attribute names match the original encoder layers).
        self.backbone_model.load_state_dict(state_dict, strict=False)

    def forward(self, input_features: torch.Tensor, length: torch.Tensor = None):
        # The convolutional feature extractor and projection stay frozen.
        with torch.no_grad():
            hidden_states = self.backbone_model.feature_extractor(input_features)
            hidden_states = hidden_states.transpose(1, 2)
            hidden_states, _ = self.backbone_model.feature_projection(hidden_states)

        # Convert raw-audio sample lengths to feature-frame lengths.
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())

        hidden_states = self.backbone_model.encoder(
            hidden_states, output_hidden_states=self.config.output_hidden_states
        ).hidden_states
        return {'encoder_hidden_states': hidden_states, 'length': length}

    def get_feat_extract_output_lengths(self, input_length):
        """Compute the output length of the convolutional feature extractor."""

        def _conv_out_length(input_length, kernel_size, stride):
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(
            self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride
        ):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length


def prepare_mask(length, shape, dtype):
    """Build a boolean padding mask that is True for the first `length` frames of each row."""
    mask = torch.zeros(shape, dtype=dtype)
    # Mark the last valid frame of each sequence, then back-fill with a reversed cumsum.
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask
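

# Usage sketch (not part of the original module): a minimal end-to-end smoke test of the
# wrapper. The checkpoint name "facebook/wav2vec2-base-960h" and the dummy 16 kHz
# waveforms are assumptions for illustration only.
if __name__ == "__main__":
    from transformers import Wav2Vec2Config

    config = Wav2Vec2Config.from_pretrained(
        "facebook/wav2vec2-base-960h", output_hidden_states=True
    )
    model = Wav2VecWrapper(config)

    waveforms = torch.randn(2, 16000)        # (batch, samples) of raw 16 kHz audio
    lengths = torch.tensor([16000, 12000])   # valid samples per utterance (second one is padded)

    out = model(waveforms, lengths)
    hidden_states = out['encoder_hidden_states']  # tuple: input embedding + one tensor per layer
    print(len(hidden_states), hidden_states[-1].shape, out['length'])

    # The frame-level lengths can be turned into a padding mask, e.g. for mean pooling.
    mask = prepare_mask(out['length'], hidden_states[-1].shape[:2], hidden_states[-1].dtype)
    print(mask.shape, mask.sum(dim=-1))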