Update train_utils.py
train_utils.py  CHANGED  (+40 −20)
```diff
@@ -287,24 +287,44 @@ def predict_probabilities(model, loader):
     all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]
 
     with torch.no_grad():
-        # (previous inline prediction loop, 20 lines; collapsed in the source view)
-    return all_probabilities
+        # Check if we're in a non-interactive environment (like a server)
+        import sys
+        is_interactive = hasattr(sys, 'ps1') or sys.stdout.isatty()
+
+        if is_interactive:
+            # Use tqdm with proper error handling
+            try:
+                for batch in tqdm(loader, desc="Predicting Probabilities", leave=False):
+                    _process_batch_for_prediction(batch, model, all_probabilities)
+            except Exception as e:
+                print(f"Warning: Error during prediction with progress bar: {str(e)}")
+                # Fallback: process without progress bar
+                for batch in loader:
+                    _process_batch_for_prediction(batch, model, all_probabilities)
+        else:
+            # Non-interactive environment, process without progress bar
+            for batch in loader:
+                _process_batch_for_prediction(batch, model, all_probabilities)
+
+    return all_probabilities
```
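A note on the new environment check: `hasattr(sys, 'ps1')` is true only inside an interactive interpreter, and `sys.stdout.isatty()` is false once stdout is piped or redirected, so the tqdm bar is skipped in exactly the server/log-file case the comment targets. A minimal, self-contained sketch of the same check (the function name is illustrative, not part of the commit):

```python
import sys

def running_interactively() -> bool:
    # sys.ps1 is defined only in an interactive interpreter (REPL);
    # isatty() returns False when stdout is piped, redirected, or captured.
    return hasattr(sys, 'ps1') or sys.stdout.isatty()

if __name__ == "__main__":
    # True from a terminal session; False under e.g.
    #   python script.py > run.log
    # or most CI runners and job schedulers.
    print(running_interactively())
```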
The same hunk adds the module-level helper that both loops call:

```diff
+
+
+def _process_batch_for_prediction(batch, model, all_probabilities):
+    """Helper function to process a batch for prediction"""
+    if len(batch) == 2:
+        inputs, _ = batch
+        input_ids = inputs['input_ids'].to(DEVICE)
+        attention_mask = inputs['attention_mask'].to(DEVICE)
+        outputs = model(input_ids, attention_mask)
+    elif len(batch) == 3:
+        inputs, metadata, _ = batch
+        input_ids = inputs['input_ids'].to(DEVICE)
+        attention_mask = inputs['attention_mask'].to(DEVICE)
+        metadata = metadata.to(DEVICE)
+        outputs = model(input_ids, attention_mask, metadata)
+    else:
+        raise ValueError("Unsupported batch format.")
 
+    for i, out_logits in enumerate(outputs):
+        # Apply softmax to logits to get probabilities
+        probs = torch.softmax(out_logits, dim=1).cpu().numpy()
+        all_probabilities[i].extend(probs)
```
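One way to sanity-check the helper's two accepted batch formats is to drive it with a stub model. This is a hypothetical sketch: the stub, tensor shapes, and two-head setup are stand-ins for whatever train_utils.py actually defines; only `_process_batch_for_prediction` comes from the commit.

```python
# Hypothetical sanity check -- StubModel and the shapes below are stand-ins.
import torch
from train_utils import _process_batch_for_prediction

class StubModel(torch.nn.Module):
    """Mimics a multi-head classifier: one (batch, n_classes) logit tensor per label."""
    def forward(self, input_ids, attention_mask, metadata=None):
        n = input_ids.shape[0]
        return [torch.randn(n, 3), torch.randn(n, 5)]

inputs = {
    "input_ids": torch.randint(0, 100, (4, 16)),
    "attention_mask": torch.ones(4, 16, dtype=torch.long),
}
labels = torch.zeros(4)          # ignored by the helper
metadata = torch.randn(4, 8)

all_probabilities = [[] for _ in range(2)]   # one slot per label head
model = StubModel()

_process_batch_for_prediction((inputs, labels), model, all_probabilities)            # 2-tuple form
_process_batch_for_prediction((inputs, metadata, labels), model, all_probabilities)  # 3-tuple form

# Each stored row is a softmax output, so it should sum to ~1.
assert all(abs(row.sum() - 1.0) < 1e-5 for row in all_probabilities[0])
```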