import torch
from torch import nn
from transformers import AutoModel


class BiLSTMAttentionBERT(nn.Module):
    def __init__(self,
                 hidden_dim=256,
                 num_classes=22,  # Based on the label distribution
                 num_layers=2,    # Multiple LSTM layers
                 dropout=0.1):
        super().__init__()

        # Load the pretrained BioBERT encoder
        self.bert_model = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        bert_dim = self.bert_model.config.hidden_size  # 768 for BioBERT base
        # Dropout for BERT outputs
        self.dropout_bert = nn.Dropout(dropout)
        # Multi-layer BiLSTM
        self.lstm = nn.LSTM(
            input_size=bert_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Self-attention block (nn.MultiheadAttention with a single head)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,  # *2 for bidirectional
            num_heads=1,
            dropout=dropout,
            batch_first=True
        )

        # Regularization layers
        self.dropout1 = nn.Dropout(dropout)        # defined but not used in forward()
        self.dropout2 = nn.Dropout(dropout + 0.1)  # slightly heavier dropout before classification
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
        self.batch_norm = nn.BatchNorm1d(hidden_dim * 2)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids, attention_mask):
        # BERT encoding
        bert_output = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        sequence_output = self.dropout_bert(bert_output.last_hidden_state)

        # BiLSTM processing
        lstm_out, _ = self.lstm(sequence_output)
        lstm_out = self.layer_norm(lstm_out)

        # Self-attention
        attn_out, _ = self.attention(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            need_weights=False
        )

        # Mean-pool over the sequence dimension (note: padding positions are included in the mean)
        pooled = torch.mean(attn_out, dim=1)
        pooled = self.batch_norm(pooled)
        pooled = self.dropout2(pooled)

        # Classification
        return self.classifier(pooled)
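

# Minimal usage sketch, kept out of the model class. Assumptions not present in the
# original file: the tokenizer checkpoint matches the encoder above, and the two
# example sentences are purely illustrative placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
    model = BiLSTMAttentionBERT(hidden_dim=256, num_classes=22)
    model.eval()  # use running stats in the BatchNorm layers

    # Tokenize a small hypothetical batch
    batch = tokenizer(
        ["Patient presents with chest pain.", "No abnormalities detected."],
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # expected: torch.Size([2, 22])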