joko333 committed on
Commit
67bf242
·
1 Parent(s): 3681591

Enhance model loading for prediction by integrating pre-trained BERT and refining checkpoint handling

Browse files
Files changed (1) hide show
  1. utils/prediction.py +19 -5
utils/prediction.py CHANGED
@@ -1,6 +1,6 @@
1
  from utils.model import BiLSTMAttentionBERT, BiLSTMConfig
2
  import torch
3
- from transformers import AutoTokenizer
4
  from sklearn.preprocessing import LabelEncoder
5
  import numpy as np
6
  import streamlit as st
@@ -12,6 +12,11 @@ from huggingface_hub import hf_hub_download
12
  def load_model_for_prediction():
13
  try:
14
  st.write("Starting model loading...")
 
 
 
 
 
15
  config = BiLSTMConfig(
16
  hidden_dim=128,
17
  num_classes=22,
@@ -19,20 +24,29 @@ def load_model_for_prediction():
19
  dropout=0.5
20
  )
21
 
22
- # Initialize model
23
  model = BiLSTMAttentionBERT(config)
 
24
 
25
- # Load checkpoint
26
  model_path = hf_hub_download(
27
  repo_id="joko333/BiLSTM_v01",
28
  filename="model_epoch8_acc72.53.pt"
29
  )
30
  checkpoint = torch.load(model_path, map_location='cpu')
31
 
32
- # Extract model state dict from checkpoint
 
 
33
  if 'model_state_dict' in checkpoint:
 
 
34
  state_dict = checkpoint['model_state_dict']
35
- model.load_state_dict(state_dict)
 
 
 
 
 
36
  st.write("Model loaded successfully")
37
  else:
38
  st.error("Invalid checkpoint format")
 
1
  from utils.model import BiLSTMAttentionBERT, BiLSTMConfig
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModel
4
  from sklearn.preprocessing import LabelEncoder
5
  import numpy as np
6
  import streamlit as st
 
12
  def load_model_for_prediction():
13
  try:
14
  st.write("Starting model loading...")
15
+
16
+ # Initialize BERT first
17
+ bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
18
+
19
+ # Initialize config and model
20
  config = BiLSTMConfig(
21
  hidden_dim=128,
22
  num_classes=22,
 
24
  dropout=0.5
25
  )
26
 
 
27
  model = BiLSTMAttentionBERT(config)
28
+ model.bert = bert # Set pre-trained BERT
29
 
30
+ # Load custom layers from checkpoint
31
  model_path = hf_hub_download(
32
  repo_id="joko333/BiLSTM_v01",
33
  filename="model_epoch8_acc72.53.pt"
34
  )
35
  checkpoint = torch.load(model_path, map_location='cpu')
36
 
37
+ # Debug checkpoint structure
38
+ st.write("Checkpoint keys:", checkpoint.keys())
39
+
40
  if 'model_state_dict' in checkpoint:
41
+ # Extract only custom layer weights
42
+ custom_state_dict = {}
43
  state_dict = checkpoint['model_state_dict']
44
+ for key, value in state_dict.items():
45
+ if not key.startswith('bert.'):
46
+ custom_state_dict[key] = value
47
+
48
+ # Load custom layers
49
+ model.load_state_dict(custom_state_dict, strict=False)
50
  st.write("Model loaded successfully")
51
  else:
52
  st.error("Invalid checkpoint format")