import torch
from torch import nn
from transformers import AutoModel


class BiLSTMAttentionBERT(nn.Module):
    def __init__(self,
                 hidden_dim=256,
                 num_classes=22,  # Based on the label distribution
                 num_layers=2,    # Multiple LSTM layers
                 dropout=0.1):
        super().__init__()

        # Load BioBERT instead of BERT
        self.bert_model = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        bert_dim = self.bert_model.config.hidden_size  # Still 768 for BioBERT base

        # Dropout for BERT outputs
        self.dropout_bert = nn.Dropout(dropout)

        # Multi-layer BiLSTM
        self.lstm = nn.LSTM(
            input_size=bert_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Multi-head attention
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,  # *2 for bidirectional
            num_heads=1,
            dropout=dropout,
            batch_first=True
        )

        # Regularization layers
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout + 0.1)
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
        self.batch_norm = nn.BatchNorm1d(hidden_dim * 2)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, num_classes)
        )
    def forward(self, input_ids, attention_mask):
        # BERT encoding
        bert_output = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        sequence_output = self.dropout_bert(bert_output.last_hidden_state)

        # BiLSTM processing
        lstm_out, _ = self.lstm(sequence_output)
        lstm_out = self.layer_norm(lstm_out)

        # Self-attention
        attn_out, _ = self.attention(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            need_weights=False
        )

        # Pooling and normalization
        pooled = torch.mean(attn_out, dim=1)
        pooled = self.batch_norm(pooled)
        pooled = self.dropout2(pooled)

        # Classification
        return self.classifier(pooled)