namanpenguin commited on
Commit
7cee3e9
·
verified ·
1 Parent(s): 28cd8d3

Update train_utils.py

Browse files
Files changed (1) hide show
  1. train_utils.py +40 -20
train_utils.py CHANGED
@@ -287,24 +287,44 @@ def predict_probabilities(model, loader):
287
  all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]
288
 
289
  with torch.no_grad():
290
- for batch in tqdm(loader, desc="Predicting Probabilities"):
291
- # Unpack batch, ignoring labels as we only need inputs
292
- if len(batch) == 2:
293
- inputs, _ = batch
294
- input_ids = inputs['input_ids'].to(DEVICE)
295
- attention_mask = inputs['attention_mask'].to(DEVICE)
296
- outputs = model(input_ids, attention_mask)
297
- elif len(batch) == 3:
298
- inputs, metadata, _ = batch
299
- input_ids = inputs['input_ids'].to(DEVICE)
300
- attention_mask = inputs['attention_mask'].to(DEVICE)
301
- metadata = metadata.to(DEVICE)
302
- outputs = model(input_ids, attention_mask, metadata)
303
- else:
304
- raise ValueError("Unsupported batch format.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
- for i, out_logits in enumerate(outputs):
307
- # Apply softmax to logits to get probabilities
308
- probs = torch.softmax(out_logits, dim=1).cpu().numpy()
309
- all_probabilities[i].extend(probs)
310
- return all_probabilities
 
def predict_probabilities(model, loader):
    """Collect per-label class probabilities for every sample in *loader*.

    Runs *model* under ``torch.no_grad()`` and applies softmax to each
    label head's logits. Returns a list with one entry per label in
    ``LABEL_COLUMNS``; each entry is a list of per-sample probability
    arrays (numpy, on CPU).

    A tqdm progress bar is shown only when stdout looks interactive,
    since tqdm can misbehave when output is redirected (e.g. server logs).
    """
    import sys

    all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]

    # Detect an interactive session: a REPL sets sys.ps1; otherwise fall
    # back to a tty check. Guard isatty() — wrapped/replaced stdout
    # objects may not implement it.
    is_interactive = hasattr(sys, 'ps1') or (
        hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
    )

    with torch.no_grad():
        if is_interactive:
            try:
                for batch in tqdm(loader, desc="Predicting Probabilities", leave=False):
                    _process_batch_for_prediction(batch, model, all_probabilities)
            except Exception as e:
                print(f"Warning: Error during prediction with progress bar: {str(e)}")
                # Fallback: process without progress bar.
                # BUG FIX: discard any partial results collected before the
                # failure — otherwise the retry below would append every
                # already-processed batch a second time, duplicating rows.
                all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]
                for batch in loader:
                    _process_batch_for_prediction(batch, model, all_probabilities)
        else:
            # Non-interactive environment: plain loop, no progress bar.
            for batch in loader:
                _process_batch_for_prediction(batch, model, all_probabilities)

    return all_probabilities
+
311
def _process_batch_for_prediction(batch, model, all_probabilities):
    """Run *model* on one batch and append softmax probabilities in place.

    Accepts two batch layouts — ``(inputs, labels)`` or
    ``(inputs, metadata, labels)`` — where ``inputs`` is a dict holding
    ``input_ids`` and ``attention_mask``. Labels are ignored; results are
    extended into *all_probabilities* (one sub-list per label head).
    """
    # Reject unknown layouts up front (guard clause).
    if len(batch) not in (2, 3):
        raise ValueError("Unsupported batch format.")

    # Both layouts put the tokenized inputs first.
    inputs = batch[0]
    input_ids = inputs['input_ids'].to(DEVICE)
    attention_mask = inputs['attention_mask'].to(DEVICE)

    if len(batch) == 3:
        # Three-element batches carry an extra metadata tensor.
        metadata = batch[1].to(DEVICE)
        outputs = model(input_ids, attention_mask, metadata)
    else:
        outputs = model(input_ids, attention_mask)

    # One logits tensor per label head: softmax over classes, move to CPU.
    for label_idx, logits in enumerate(outputs):
        probabilities = torch.softmax(logits, dim=1).cpu().numpy()
        all_probabilities[label_idx].extend(probabilities)