import torch
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
from torch import nn
from transformers import Wav2Vec2Model

class Wav2Vec2EncoderLayer(nn.Module):
    """Re-implementation of the Wav2Vec2 encoder layer (post-layer-norm variant) that also records its index."""

    def __init__(self, config, i):
        super().__init__()
        # Multi-head self-attention; is_decoder=False because this is an encoder-only model.
        self.attention = w2v2.Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = w2v2.Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config
        self.i = i  # index of this layer within the encoder stack

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Self-attention block with residual connection, followed by post-attention layer norm.
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)

        # Feed-forward block with residual connection, followed by the final layer norm.
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

class Wav2VecWrapper(nn.Module):
    """Wraps a pretrained Wav2Vec2Model and swaps its encoder layers for the custom layer above."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.backbone_model = Wav2Vec2Model.from_pretrained(
            config._name_or_path,
            output_hidden_states=config.output_hidden_states,
        )
        # Snapshot the pretrained weights before the encoder layers are replaced.
        state_dict = self.backbone_model.state_dict()

        self.model_config = self.backbone_model.config
        self.backbone_model.encoder.layers = nn.ModuleList(
            [Wav2Vec2EncoderLayer(self.model_config, i) for i in range(self.model_config.num_hidden_layers)]
        )
        # The custom layers mirror the original parameter names, so the pretrained
        # weights can be loaded back into the re-created encoder stack.
        self.backbone_model.load_state_dict(state_dict)
        
    def forward(
        self,
        input_features: torch.Tensor,
        length: torch.Tensor = None,
    ):
        # Run the convolutional feature extractor and feature projection without
        # gradient tracking, keeping these front-end modules frozen.
        with torch.no_grad():
            hidden_states = self.backbone_model.feature_extractor(input_features)
            hidden_states = hidden_states.transpose(1, 2)
            hidden_states, _ = self.backbone_model.feature_projection(hidden_states)

        # Map raw-audio lengths to the number of frames produced by the feature extractor.
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())

        hidden_states = self.backbone_model.encoder(
            hidden_states,
            output_hidden_states=self.config.output_hidden_states,
        ).hidden_states

        return {'encoder_hidden_states': hidden_states, 'length': length}
        
    def get_feat_extract_output_lengths(self, input_length):
        """Compute the number of frames the convolutional feature extractor produces."""
        def _conv_out_length(input_length, kernel_size, stride):
            # Standard 1D convolution output length (no padding, no dilation).
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length

def prepare_mask(length, shape, dtype):
    """Build a boolean mask of the given shape that is True for valid frames and False for padding."""
    mask = torch.zeros(shape, dtype=dtype)
    # Mark the last valid position of each sequence, then propagate it backwards with a
    # reversed cumulative sum so every position up to (and including) that index is True.
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask
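
# --- Illustrative usage sketch (not part of the original module) -----------------
# A minimal example of how the wrapper might be driven end to end: build it from a
# HuggingFace Wav2Vec2 config, run a padded batch of raw waveforms, and convert the
# returned frame lengths into a padding mask with prepare_mask. The checkpoint name,
# batch shapes, and sample rate below are assumptions for illustration only.
if __name__ == "__main__":
    from transformers import Wav2Vec2Config

    # Assumed checkpoint; any Wav2Vec2 checkpoint name should work the same way.
    config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base-960h")
    config.output_hidden_states = True

    model = Wav2VecWrapper(config)
    model.eval()

    # Two 1-second waveforms at an assumed 16 kHz; the second is partly padding.
    waveforms = torch.randn(2, 16000)
    lengths = torch.tensor([16000, 12000])

    with torch.no_grad():
        outputs = model(waveforms, length=lengths)

    # One hidden-state tensor per encoder layer plus the encoder input,
    # each of shape (batch, frames, hidden_size).
    print(len(outputs['encoder_hidden_states']), outputs['encoder_hidden_states'][-1].shape)

    # Frame-level lengths -> boolean padding mask over the encoder time axis.
    mask = prepare_mask(outputs['length'],
                        outputs['encoder_hidden_states'][-1].shape[:2],
                        torch.long)
    print(mask.shape, mask.sum(-1))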