import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, PretrainedConfig


class BiLSTMConfig(PretrainedConfig):
    model_type = "bilstm_attention"

    def __init__(self, hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.dropout = dropout


class BiLSTMAttentionBERT(PreTrainedModel):
    config_class = BiLSTMConfig
    base_model_prefix = "bilstm_attention"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # BioBERT encoder; its hidden size (768) is the LSTM input dimension.
        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        self.lstm = nn.LSTM(
            768,
            config.hidden_dim,
            config.num_layers,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = nn.Dropout(config.dropout)
        # The bidirectional LSTM concatenates both directions, hence hidden_dim * 2.
        self.fc = nn.Linear(config.hidden_dim * 2, config.num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        bert_output = outputs[0]  # last hidden states, shape (batch, seq_len, 768)
        lstm_output, _ = self.lstm(bert_output)
        # Use the final timestep's hidden state as the sequence representation.
        dropped = self.dropout(lstm_output[:, -1, :])
        logits = self.fc(dropped)
        return logits
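
A minimal usage sketch of the classes above, not part of the original code: it assumes the BioBERT tokenizer is loaded from the same 'dmis-lab/biobert-base-cased-v1.2' checkpoint the encoder uses, and the example sentence is purely illustrative.

from transformers import AutoTokenizer

config = BiLSTMConfig(hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5)
model = BiLSTMAttentionBERT(config)
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

# Tokenize a single example sentence into input_ids and attention_mask tensors.
batch = tokenizer(
    ["Patient presents with chest pain and shortness of breath."],
    padding=True,
    truncation=True,
    return_tensors="pt",
)

model.eval()
with torch.no_grad():
    logits = model(batch["input_ids"], batch["attention_mask"])

print(logits.shape)  # torch.Size([1, 22]) -- one raw score per class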