joko333 commited on
Commit
04dc908
·
1 Parent(s): 621f6b2

Enhance prediction function with validation checks and improved error handling

Browse files
Files changed (1) hide show
  1. utils/prediction.py +26 -26
utils/prediction.py CHANGED
@@ -46,43 +46,43 @@ def predict_sentence(model, sentence, tokenizer, label_encoder):
46
  """
47
  Make prediction for a single sentence with label validation.
48
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  model.eval()
50
 
51
  # Tokenize
52
- encoding = tokenizer(
53
- sentence,
54
- add_special_tokens=True,
55
- max_length=512,
56
- padding='max_length',
57
- truncation=True,
58
- return_tensors='pt'
59
- )
60
-
61
  try:
 
 
 
 
 
 
 
 
 
62
  with torch.no_grad():
63
- # Get model outputs
64
  outputs = model(encoding['input_ids'], encoding['attention_mask'])
65
  probabilities = torch.softmax(outputs, dim=1)
66
-
67
- # Get prediction and probability
68
  prob, pred_idx = torch.max(probabilities, dim=1)
 
 
69
 
70
- # Validate prediction index
71
- if pred_idx.item() >= len(label_encoder.classes_):
72
- print(f"Warning: Model predicted invalid label index {pred_idx.item()}")
73
- return "Unknown", 0.0
74
-
75
- # Convert to label
76
- try:
77
- predicted_class = label_encoder.classes_[pred_idx.item()]
78
- return predicted_class, prob.item()
79
- except IndexError:
80
- print(f"Warning: Invalid label index {pred_idx.item()}")
81
- return "Unknown", 0.0
82
-
83
  except Exception as e:
84
  print(f"Prediction error: {str(e)}")
85
- return "Error", 0.0
86
 
87
  def print_labels(label_encoder, show_counts=False):
88
  """Print all labels and their corresponding indices"""
 
46
  """
47
  Make prediction for a single sentence with label validation.
48
  """
49
+ # Validation checks
50
+ if model is None:
51
+ print("Error: Model not loaded")
52
+ return "Error: Model not loaded", 0.0
53
+ if tokenizer is None:
54
+ print("Error: Tokenizer not loaded")
55
+ return "Error: Tokenizer not loaded", 0.0
56
+ if label_encoder is None:
57
+ print("Error: Label encoder not loaded")
58
+ return "Error: Label encoder not loaded", 0.0
59
+
60
+ # Force CPU device
61
+ device = torch.device('cpu')
62
+ model = model.to(device)
63
  model.eval()
64
 
65
  # Tokenize
 
 
 
 
 
 
 
 
 
66
  try:
67
+ encoding = tokenizer(
68
+ sentence,
69
+ add_special_tokens=True,
70
+ max_length=512,
71
+ padding='max_length',
72
+ truncation=True,
73
+ return_tensors='pt'
74
+ ).to(device)
75
+
76
  with torch.no_grad():
 
77
  outputs = model(encoding['input_ids'], encoding['attention_mask'])
78
  probabilities = torch.softmax(outputs, dim=1)
 
 
79
  prob, pred_idx = torch.max(probabilities, dim=1)
80
+ predicted_label = label_encoder.classes_[pred_idx.item()]
81
+ return predicted_label, prob.item()
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  except Exception as e:
84
  print(f"Prediction error: {str(e)}")
85
+ return f"Error: {str(e)}", 0.0
86
 
87
  def print_labels(label_encoder, show_counts=False):
88
  """Print all labels and their corresponding indices"""