joko333 committed on
Commit
dee2852
·
1 Parent(s): 42d8a45

Add BiLSTMAttentionBERT model implementation and update import statements

Files changed (2)
  1. utils/model.py +25 -0
  2. utils/prediction.py +1 -1
utils/model.py ADDED
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel, AutoModel, PretrainedConfig
+
+class BiLSTMAttentionBERT(PreTrainedModel):
+    def __init__(self, hidden_dim, num_classes, num_layers, dropout):
+        super().__init__(PretrainedConfig())  # minimal placeholder config; weights live in the modules below
+        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
+        self.lstm = nn.LSTM(768, hidden_dim, num_layers, batch_first=True, bidirectional=True)
+        self.dropout = nn.Dropout(dropout)
+        self.fc = nn.Linear(hidden_dim * 2, num_classes)
+
+    @classmethod
+    def from_pretrained(cls, model_path, hidden_dim, num_classes, num_layers, dropout):
+        model = cls(hidden_dim, num_classes, num_layers, dropout)
+        state_dict = torch.load(model_path, map_location='cpu')
+        model.load_state_dict(state_dict)
+        return model
+
+    def forward(self, input_ids, attention_mask):
+        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]  # last hidden states (batch, seq, 768)
+        lstm_output, _ = self.lstm(bert_output)
+        dropped = self.dropout(lstm_output[:, -1, :])  # final BiLSTM timestep, (batch, hidden_dim * 2)
+        output = self.fc(dropped)
+        return output
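For reference, a minimal usage sketch of the new module. The hyperparameter values and the example sentence are illustrative assumptions, not taken from this commit:

import torch
from transformers import AutoTokenizer
from utils.model import BiLSTMAttentionBERT

# Hypothetical hyperparameters; real values must match the trained checkpoint.
model = BiLSTMAttentionBERT(hidden_dim=256, num_classes=5, num_layers=2, dropout=0.3)
model.eval()

tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
encoded = tokenizer("Aspirin reduces the risk of heart attack.", return_tensors='pt')

with torch.no_grad():
    logits = model(encoded['input_ids'], encoded['attention_mask'])
print(logits.shape)  # torch.Size([1, 5]) -- one row of class logits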
utils/prediction.py CHANGED
@@ -1,7 +1,7 @@
+from utils.model import BiLSTMAttentionBERT
 import torch
 from transformers import AutoTokenizer
 from sklearn.preprocessing import LabelEncoder
-from utils.BiLSTM import BiLSTMAttentionBERT
 import numpy as np
 import streamlit as st
 import requests
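With the import path updated, prediction code can restore a saved model through the custom from_pretrained classmethod defined above. A sketch, where the checkpoint path and hyperparameters are placeholders that must match the saved state dict:

from utils.model import BiLSTMAttentionBERT

# 'checkpoint.pt' is a hypothetical path; hyperparameters must match training.
model = BiLSTMAttentionBERT.from_pretrained(
    'checkpoint.pt', hidden_dim=256, num_classes=5, num_layers=2, dropout=0.3)
model.eval()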