import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, PretrainedConfig


class BiLSTMConfig(PretrainedConfig):
    model_type = "bilstm_attention"

    def __init__(self, hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.dropout = dropout


class BiLSTMAttentionBERT(PreTrainedModel):
    config_class = BiLSTMConfig
    base_model_prefix = "bilstm_attention"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # BioBERT encoder; its hidden size (768) is the BiLSTM input dimension.
        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        self.lstm = nn.LSTM(
            768,
            config.hidden_dim,
            config.num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(config.dropout)
        # Bidirectional LSTM doubles the feature dimension fed to the classifier.
        self.fc = nn.Linear(config.hidden_dim * 2, config.num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        bert_output = outputs[0]  # last_hidden_state: (batch, seq_len, 768)
        lstm_output, _ = self.lstm(bert_output)
        # Classify from the BiLSTM output at the last sequence position
        # (padded positions are not masked out here).
        dropped = self.dropout(lstm_output[:, -1, :])
        logits = self.fc(dropped)
        return logits
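
# --- Minimal usage sketch (illustrative, not part of the original code) ---
# Assumes the dmis-lab/biobert-base-cased-v1.2 tokenizer and the default
# 22-class configuration above; the example sentence is hypothetical.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = BiLSTMConfig()
    model = BiLSTMAttentionBERT(config)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

    batch = tokenizer(
        ["Aspirin inhibits platelet aggregation."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # expected: torch.Size([1, 22])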