import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, PretrainedConfig


class BiLSTMConfig(PretrainedConfig):
    """Configuration for the BiLSTM classification head on top of BioBERT."""

    model_type = "bilstm_attention"

    def __init__(self, hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim    # hidden size of each LSTM direction
        self.num_classes = num_classes  # number of output labels
        self.num_layers = num_layers    # number of stacked LSTM layers
        self.dropout = dropout          # dropout rate applied before the classifier


class BiLSTMAttentionBERT(PreTrainedModel):
    config_class = BiLSTMConfig
    base_model_prefix = "bilstm_attention"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # BioBERT encoder; its hidden size (768) feeds the BiLSTM below.
        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        self.lstm = nn.LSTM(
            768,                 # BioBERT hidden size
            config.hidden_dim,
            config.num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(config.dropout)
        # Bidirectional LSTM doubles the feature dimension fed to the classifier.
        self.fc = nn.Linear(config.hidden_dim * 2, config.num_classes)

    def forward(self, input_ids, attention_mask):
        # Contextual token embeddings from BioBERT: (batch, seq_len, 768).
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        bert_output = outputs[0]
        # BiLSTM over the token embeddings: (batch, seq_len, hidden_dim * 2).
        lstm_output, _ = self.lstm(bert_output)
        # Use the last timestep's features as the sequence representation.
        dropped = self.dropout(lstm_output[:, -1, :])
        logits = self.fc(dropped)
        return logits
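

# Usage sketch (not part of the original file): instantiating the config and model
# and running a single batch through the classifier. The example sentence, padding
# settings, and tokenizer choice are illustrative assumptions.
#
#   from transformers import AutoTokenizer
#
#   config = BiLSTMConfig(hidden_dim=128, num_classes=22)
#   model = BiLSTMAttentionBERT(config)
#   tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
#   batch = tokenizer(["aspirin inhibits platelet aggregation"],
#                     return_tensors="pt", padding=True, truncation=True)
#   logits = model(batch["input_ids"], batch["attention_mask"])  # shape: (1, num_classes)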