Update app.py
Browse files
app.py
CHANGED
@@ -76,7 +76,7 @@ training_status = {
|
|
76 |
}
|
77 |
|
78 |
# Load the model and tokenizer for prediction
|
79 |
-
model_path = MODEL_SAVE_DIR / "
|
80 |
tokenizer = get_tokenizer('bert-base-uncased')
|
81 |
|
82 |
# Initialize model and label encoders with error handling
|
@@ -260,7 +260,7 @@ async def validate_model(
|
|
260 |
|
261 |
data_df, label_encoders = load_and_preprocess_data(str(file_path))
|
262 |
|
263 |
-
model_path = MODEL_SAVE_DIR / f"{model_name}
|
264 |
if not model_path.exists():
|
265 |
raise HTTPException(status_code=404, detail="BERT model file not found")
|
266 |
|
@@ -349,7 +349,7 @@ async def predict(
|
|
349 |
"""
|
350 |
try:
|
351 |
# Load the model
|
352 |
-
model_path = MODEL_SAVE_DIR / f"{model_name}
|
353 |
if not model_path.exists():
|
354 |
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
355 |
|
@@ -390,17 +390,31 @@ async def predict(
|
|
390 |
for i, row in data_df.iterrows():
|
391 |
transaction_pred = {}
|
392 |
for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
|
393 |
-
|
394 |
-
decoded_pred = label_encoders[col].inverse_transform([pred])[0]
|
395 |
-
|
396 |
class_probs = {
|
397 |
label: float(probs[i][j])
|
398 |
for j, label in enumerate(label_encoders[col].classes_)
|
399 |
}
|
400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
transaction_pred[col] = {
|
402 |
-
"
|
403 |
-
|
|
|
|
|
|
|
|
|
404 |
}
|
405 |
|
406 |
predictions.append({
|
@@ -466,17 +480,31 @@ async def predict(
|
|
466 |
|
467 |
response = {}
|
468 |
for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
|
469 |
-
|
470 |
-
decoded_pred = label_encoders[col].inverse_transform([pred])[0]
|
471 |
-
|
472 |
class_probs = {
|
473 |
label: float(probs[0][j])
|
474 |
for j, label in enumerate(label_encoders[col].classes_)
|
475 |
}
|
476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
response[col] = {
|
478 |
-
"
|
479 |
-
|
|
|
|
|
|
|
|
|
480 |
}
|
481 |
|
482 |
return response
|
@@ -565,5 +593,5 @@ async def train_model_task(config: TrainingConfig, file_path: str, training_id:
|
|
565 |
})
|
566 |
|
567 |
if __name__ == "__main__":
|
568 |
-
port = int(os.environ.get("PORT",
|
569 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
76 |
}
|
77 |
|
78 |
# Load the model and tokenizer for prediction
|
79 |
+
model_path = MODEL_SAVE_DIR / "BERT_model.pth"
|
80 |
tokenizer = get_tokenizer('bert-base-uncased')
|
81 |
|
82 |
# Initialize model and label encoders with error handling
|
|
|
260 |
|
261 |
data_df, label_encoders = load_and_preprocess_data(str(file_path))
|
262 |
|
263 |
+
model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
|
264 |
if not model_path.exists():
|
265 |
raise HTTPException(status_code=404, detail="BERT model file not found")
|
266 |
|
|
|
349 |
"""
|
350 |
try:
|
351 |
# Load the model
|
352 |
+
model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
|
353 |
if not model_path.exists():
|
354 |
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
355 |
|
|
|
390 |
for i, row in data_df.iterrows():
|
391 |
transaction_pred = {}
|
392 |
for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
|
393 |
+
# Get probabilities for each class
|
|
|
|
|
394 |
class_probs = {
|
395 |
label: float(probs[i][j])
|
396 |
for j, label in enumerate(label_encoders[col].classes_)
|
397 |
}
|
398 |
|
399 |
+
# Sort probabilities in descending order
|
400 |
+
sorted_probs = sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
|
401 |
+
|
402 |
+
# Get top prediction and its probability
|
403 |
+
top_pred, top_prob = sorted_probs[0]
|
404 |
+
|
405 |
+
# Get top 3 predictions with probabilities
|
406 |
+
top_3_predictions = [
|
407 |
+
{"label": label, "probability": prob}
|
408 |
+
for label, prob in sorted_probs[:3]
|
409 |
+
]
|
410 |
+
|
411 |
transaction_pred[col] = {
|
412 |
+
"top_prediction": {
|
413 |
+
"label": top_pred,
|
414 |
+
"probability": top_prob
|
415 |
+
},
|
416 |
+
"alternative_predictions": top_3_predictions[1:], # Exclude the top prediction
|
417 |
+
"all_probabilities": class_probs # Keep all probabilities for reference
|
418 |
}
|
419 |
|
420 |
predictions.append({
|
|
|
480 |
|
481 |
response = {}
|
482 |
for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
|
483 |
+
# Get probabilities for each class
|
|
|
|
|
484 |
class_probs = {
|
485 |
label: float(probs[0][j])
|
486 |
for j, label in enumerate(label_encoders[col].classes_)
|
487 |
}
|
488 |
|
489 |
+
# Sort probabilities in descending order
|
490 |
+
sorted_probs = sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
|
491 |
+
|
492 |
+
# Get top prediction and its probability
|
493 |
+
top_pred, top_prob = sorted_probs[0]
|
494 |
+
|
495 |
+
# Get top 3 predictions with probabilities
|
496 |
+
top_3_predictions = [
|
497 |
+
{"label": label, "probability": prob}
|
498 |
+
for label, prob in sorted_probs[:3]
|
499 |
+
]
|
500 |
+
|
501 |
response[col] = {
|
502 |
+
"top_prediction": {
|
503 |
+
"label": top_pred,
|
504 |
+
"probability": top_prob
|
505 |
+
},
|
506 |
+
"alternative_predictions": top_3_predictions[1:], # Exclude the top prediction
|
507 |
+
"all_probabilities": class_probs # Keep all probabilities for reference
|
508 |
}
|
509 |
|
510 |
return response
|
|
|
593 |
})
|
594 |
|
595 |
if __name__ == "__main__":
|
596 |
+
port = int(os.environ.get("PORT", 7861))
|
597 |
uvicorn.run(app, host="0.0.0.0", port=port)
|