joko333 committed on
Commit 1b17d16 · 1 Parent(s): 41047a5

Add BiLSTMConfig and update BiLSTMAttentionBERT for improved model configuration and loading

Files changed (2)
  1. utils/model.py +12 -2
  2. utils/prediction.py +16 -1
utils/model.py CHANGED
@@ -3,6 +3,8 @@ import torch.nn as nn
 from transformers import PreTrainedModel, AutoModel, PretrainedConfig
 
 class BiLSTMConfig(PretrainedConfig):
+    model_type = "bilstm_attention"
+
     def __init__(self, hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5, **kwargs):
         super().__init__(**kwargs)
         self.hidden_dim = hidden_dim
@@ -11,12 +13,20 @@ class BiLSTMConfig(PretrainedConfig):
         self.dropout = dropout
 
 class BiLSTMAttentionBERT(PreTrainedModel):
+    config_class = BiLSTMConfig
+    base_model_prefix = "bilstm_attention"
+
     def __init__(self, config):
         super().__init__(config)
         self.config = config
         self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
-        self.lstm = nn.LSTM(768, config.hidden_dim, config.num_layers,
-                            batch_first=True, bidirectional=True)
+        self.lstm = nn.LSTM(
+            768,
+            config.hidden_dim,
+            config.num_layers,
+            batch_first=True,
+            bidirectional=True
+        )
         self.dropout = nn.Dropout(config.dropout)
         self.fc = nn.Linear(config.hidden_dim * 2, config.num_classes)
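Setting `model_type` on the config and `config_class` on the model wires these custom classes into the standard Hugging Face serialization machinery. A minimal sketch of the round-trip this enables; the local directory name here is hypothetical and not part of this commit:

from utils.model import BiLSTMAttentionBERT, BiLSTMConfig

# Build the model from an explicit config and save weights + config together.
config = BiLSTMConfig(hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5)
model = BiLSTMAttentionBERT(config)
model.save_pretrained("bilstm_checkpoint")  # hypothetical local directory

# Because BiLSTMAttentionBERT.config_class is BiLSTMConfig, from_pretrained
# can rebuild the config and the model in one call.
reloaded = BiLSTMAttentionBERT.from_pretrained("bilstm_checkpoint")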
utils/prediction.py CHANGED
@@ -1,16 +1,31 @@
-from utils.model import BiLSTMAttentionBERT
+from utils.model import BiLSTMAttentionBERT, BiLSTMConfig
 import torch
 from transformers import AutoTokenizer
 from sklearn.preprocessing import LabelEncoder
 import numpy as np
 import streamlit as st
 import requests
+from huggingface_hub import hf_hub_download
 
 
 
 def load_model_for_prediction():
     try:
         st.write("Starting model loading...")
+        config = BiLSTMConfig(
+            hidden_dim=128,
+            num_classes=22,
+            num_layers=2,
+            dropout=0.5
+        )
+
+        model = BiLSTMAttentionBERT(config)
+        model_path = hf_hub_download(
+            repo_id="joko333/BiLSTM_v01",
+            filename="model_epoch8_acc72.53.pt"
+        )
+        state_dict = torch.load(model_path, map_location='cpu')
+        model.load_state_dict(state_dict)
 
         # Test Hugging Face connectivity
         st.write("Testing connection to Hugging Face...")