joko333 committed on
Commit
3681591
·
1 Parent(s): 1b17d16

Refactor model loading function to handle checkpoints and improve error handling

Browse files
Files changed (1) hide show
  1. utils/prediction.py +19 -36
utils/prediction.py CHANGED
@@ -19,57 +19,40 @@ def load_model_for_prediction():
19
  dropout=0.5
20
  )
21
 
 
22
  model = BiLSTMAttentionBERT(config)
 
 
23
  model_path = hf_hub_download(
24
  repo_id="joko333/BiLSTM_v01",
25
  filename="model_epoch8_acc72.53.pt"
26
  )
27
- state_dict = torch.load(model_path, map_location='cpu')
28
- model.load_state_dict(state_dict)
29
 
30
- # Test Hugging Face connectivity
31
- st.write("Testing connection to Hugging Face...")
32
- response = requests.get("https://huggingface.co/joko333/BiLSTM_v01")
33
- if response.status_code != 200:
34
- st.error(f"Cannot connect to Hugging Face. Status code: {response.status_code}")
 
 
35
  return None, None, None
36
 
37
- # Load model with logging
38
- st.write("Loading BiLSTM model...")
39
- model = BiLSTMAttentionBERT.from_pretrained(
40
- "joko333/BiLSTM_v01",
41
- hidden_dim=128,
42
- num_classes=22,
43
- num_layers=2,
44
- dropout=0.5
45
- )
46
- st.write("Model loaded successfully")
47
-
48
- # Initialize label encoder
49
- st.write("Initializing label encoder...")
50
  label_encoder = LabelEncoder()
51
- label_encoder.classes_ = np.array(['Addition', 'Causal', 'Cause and Effect',
52
- 'Clarification', 'Comparison', 'Concession',
53
- 'Conditional', 'Contrast', 'Contrastive Emphasis',
54
- 'Definition', 'Elaboration', 'Emphasis',
55
- 'Enumeration', 'Explanation', 'Generalization',
56
- 'Illustration', 'Inference', 'Problem Solution',
57
- 'Purpose', 'Sequential', 'Summary',
58
- 'Temporal Sequence'])
59
- st.write("Label encoder initialized")
60
-
61
  # Load tokenizer
62
- st.write("Loading tokenizer...")
63
  tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
64
- st.write("Tokenizer loaded successfully")
65
 
66
  return model, label_encoder, tokenizer
67
 
68
  except Exception as e:
69
- st.error(f"Detailed error: {str(e)}")
70
- st.error(f"Error type: {type(e).__name__}")
71
- import traceback
72
- st.error(f"Traceback: {traceback.format_exc()}")
73
  return None, None, None
74
 
75
  def predict_sentence(model, sentence, tokenizer, label_encoder):
 
19
  dropout=0.5
20
  )
21
 
22
+ # Initialize model
23
  model = BiLSTMAttentionBERT(config)
24
+
25
+ # Load checkpoint
26
  model_path = hf_hub_download(
27
  repo_id="joko333/BiLSTM_v01",
28
  filename="model_epoch8_acc72.53.pt"
29
  )
30
+ checkpoint = torch.load(model_path, map_location='cpu')
 
31
 
32
+ # Extract model state dict from checkpoint
33
+ if 'model_state_dict' in checkpoint:
34
+ state_dict = checkpoint['model_state_dict']
35
+ model.load_state_dict(state_dict)
36
+ st.write("Model loaded successfully")
37
+ else:
38
+ st.error("Invalid checkpoint format")
39
  return None, None, None
40
 
41
+ # Initialize label encoder from checkpoint
 
 
 
 
 
 
 
 
 
 
 
 
42
  label_encoder = LabelEncoder()
43
+ if 'label_encoder_classes' in checkpoint:
44
+ label_encoder.classes_ = checkpoint['label_encoder_classes']
45
+ else:
46
+ st.error("Label encoder data not found in checkpoint")
47
+ return None, None, None
48
+
 
 
 
 
49
  # Load tokenizer
 
50
  tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
 
51
 
52
  return model, label_encoder, tokenizer
53
 
54
  except Exception as e:
55
+ st.error(f"Error loading model: {str(e)}")
 
 
 
56
  return None, None, None
57
 
58
  def predict_sentence(model, sentence, tokenizer, label_encoder):