import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, PretrainedConfig


class BiLSTMAttentionBERT(PreTrainedModel):
    def __init__(self, hidden_dim, num_classes, num_layers, dropout):
        # PreTrainedModel requires a config object; a bare PretrainedConfig
        # suffices here since all hyperparameters are passed in directly.
        super().__init__(PretrainedConfig())
        # BioBERT encoder producing 768-dimensional token embeddings.
        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        # Bidirectional LSTM over the BERT token embeddings.
        self.lstm = nn.LSTM(768, hidden_dim, num_layers,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        # Forward and backward LSTM states are concatenated, so the
        # classifier input size is hidden_dim * 2.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    @classmethod
    def from_pretrained(cls, model_path, hidden_dim, num_classes, num_layers, dropout):
        # Note: this shadows transformers' PreTrainedModel.from_pretrained and
        # instead loads a plain state_dict saved with torch.save().
        model = cls(hidden_dim, num_classes, num_layers, dropout)
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        return model

    def forward(self, input_ids, attention_mask):
        # Last hidden states from BERT: (batch, seq_len, 768).
        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]
        lstm_output, _ = self.lstm(bert_output)
        # Classify from the final timestep of the BiLSTM output.
        dropped = self.dropout(lstm_output[:, -1, :])
        output = self.fc(dropped)
        return output
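
# A minimal usage sketch (not part of the original class): the tokenizer name
# matches the encoder above, but the hyperparameters, example sentence, and
# sequence length are illustrative assumptions; adjust them to your setup.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
model = BiLSTMAttentionBERT(hidden_dim=256, num_classes=2, num_layers=2, dropout=0.3)
model.eval()

# Tokenize a single sentence into input_ids and attention_mask tensors.
encoded = tokenizer(
    "Aspirin reduces the risk of myocardial infarction.",
    padding='max_length', truncation=True, max_length=128, return_tensors='pt'
)

with torch.no_grad():
    logits = model(encoded['input_ids'], encoded['attention_mask'])
predicted_class = logits.argmax(dim=-1)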