joko333 committed on
Commit
621f6b2
·
1 Parent(s): 2705cb4

Refactor prediction functions to remove unnecessary device handling and clean up code

Browse files
Files changed (1) hide show
  1. utils/prediction.py +3 -29
utils/prediction.py CHANGED
@@ -7,9 +7,6 @@ import numpy as np
7
 
8
 
9
  def load_model_for_prediction():
10
- # Force CPU
11
- device = torch.device('cpu')
12
- torch.backends.mps.enabled = False
13
 
14
  try:
15
  # Load model from Hugging Face Hub
@@ -19,7 +16,7 @@ def load_model_for_prediction():
19
  num_classes=22,
20
  num_layers=2,
21
  dropout=0.5
22
- ).to(device)
23
 
24
  model.eval()
25
 
@@ -45,12 +42,10 @@ def load_model_for_prediction():
45
  print(f"Error loading model components: {str(e)}")
46
  return None, None, None
47
 
48
- def predict_sentence(model, sentence, tokenizer, label_encoder, device=None):
49
  """
50
  Make prediction for a single sentence with label validation.
51
  """
52
- device = torch.device('cpu')
53
- model = model.to(device)
54
  model.eval()
55
 
56
  # Tokenize
@@ -61,7 +56,7 @@ def predict_sentence(model, sentence, tokenizer, label_encoder, device=None):
61
  padding='max_length',
62
  truncation=True,
63
  return_tensors='pt'
64
- ).to(device)
65
 
66
  try:
67
  with torch.no_grad():
@@ -98,25 +93,4 @@ def print_labels(label_encoder, show_counts=False):
98
  print("-" * 40)
99
  print(f"Total number of classes: {len(label_encoder.classes_)}\n")
100
 
101
- def predict_sentence2(sentence, model, tokenizer, label_encoder):
102
- # Tokenize the input
103
- inputs = tokenizer(sentence,
104
- padding=True,
105
- truncation=True,
106
- return_tensors='pt',
107
- max_length=512)
108
-
109
- # Move inputs to the same device as model
110
- device = next(model.parameters()).device
111
- inputs = {k: v.to(device) for k, v in inputs.items()}
112
-
113
- # Make prediction
114
- with torch.no_grad():
115
- outputs = model(**inputs)
116
- predictions = torch.argmax(outputs.logits, dim=1)
117
-
118
- # Convert prediction to label
119
- predicted_label = label_encoder.inverse_transform(predictions.cpu().numpy())[0]
120
-
121
- return predicted_label
122
 
 
7
 
8
 
9
  def load_model_for_prediction():
 
 
 
10
 
11
  try:
12
  # Load model from Hugging Face Hub
 
16
  num_classes=22,
17
  num_layers=2,
18
  dropout=0.5
19
+ )
20
 
21
  model.eval()
22
 
 
42
  print(f"Error loading model components: {str(e)}")
43
  return None, None, None
44
 
45
+ def predict_sentence(model, sentence, tokenizer, label_encoder):
46
  """
47
  Make prediction for a single sentence with label validation.
48
  """
 
 
49
  model.eval()
50
 
51
  # Tokenize
 
56
  padding='max_length',
57
  truncation=True,
58
  return_tensors='pt'
59
+ )
60
 
61
  try:
62
  with torch.no_grad():
 
93
  print("-" * 40)
94
  print(f"Total number of classes: {len(label_encoder.classes_)}\n")
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96