import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, PretrainedConfig


class BiLSTMAttentionBERT(PreTrainedModel):
    def __init__(self, hidden_dim, num_classes, num_layers, dropout):
        # PreTrainedModel requires a config object; a default PretrainedConfig suffices here.
        super().__init__(PretrainedConfig())
        # BioBERT encoder; its hidden size (768) sets the LSTM input dimension.
        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        self.lstm = nn.LSTM(768, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        # The bidirectional LSTM doubles the feature size fed to the classifier head.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    @classmethod
    def from_pretrained(cls, model_path, hidden_dim, num_classes, num_layers, dropout):
        # Rebuild the architecture, then restore the saved weights from disk.
        model = cls(hidden_dim, num_classes, num_layers, dropout)
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        return model

    def forward(self, input_ids, attention_mask):
        # [0] is the last hidden state: (batch, seq_len, 768).
        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]
        lstm_output, _ = self.lstm(bert_output)
        # Classify from the final timestep of the LSTM output sequence.
        dropped = self.dropout(lstm_output[:, -1, :])
        output = self.fc(dropped)
        return output
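

# Minimal usage sketch (not part of the original code): the hyperparameter values,
# the example sentence, and the tokenizer pairing with BioBERT are illustrative
# assumptions, shown only to demonstrate how the class might be called.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
    model = BiLSTMAttentionBERT(hidden_dim=256, num_classes=2, num_layers=1, dropout=0.3)
    model.eval()

    batch = tokenizer(['aspirin reduces fever'], return_tensors='pt',
                      padding=True, truncation=True)
    with torch.no_grad():
        logits = model(batch['input_ids'], batch['attention_mask'])
    print(logits.shape)  # torch.Size([1, 2])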