Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -258,6 +258,8 @@ async def validate_model(
|
|
258 |
if os.path.exists(file_path):
|
259 |
os.remove(file_path)
|
260 |
|
|
|
|
|
261 |
@app.post("/v1/xgb/predict")
|
262 |
async def predict_model(
|
263 |
file: UploadFile = File(...),
|
@@ -267,61 +269,15 @@ async def predict_model(
|
|
267 |
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
268 |
|
269 |
try:
|
270 |
-
# Save uploaded CSV
|
271 |
file_path = UPLOAD_DIR / file.filename
|
272 |
with file_path.open("wb") as buffer:
|
273 |
shutil.copyfileobj(file.file, buffer)
|
274 |
|
275 |
-
# Load CSV into DataFrame
|
276 |
data_df = pd.read_csv(file_path)
|
277 |
if TEXT_COLUMN not in data_df.columns:
|
278 |
raise HTTPException(status_code=400, detail=f"Missing required text column: {TEXT_COLUMN}")
|
279 |
|
280 |
-
|
281 |
-
tfidf = joblib.load(TFIDF_PATH)
|
282 |
-
|
283 |
-
# Load label encoders
|
284 |
-
label_encoders = joblib.load(ENCODERS_PATH)
|
285 |
-
logger.info(f"Loaded encoders: {list(label_encoders.keys())}")
|
286 |
-
|
287 |
-
# Load model
|
288 |
-
model = TfidfXGBoost(label_encoders)
|
289 |
-
model.load_model(model_name)
|
290 |
-
|
291 |
-
# Vectorize text
|
292 |
-
X_text = data_df[TEXT_COLUMN]
|
293 |
-
X_vec = tfidf.transform(X_text)
|
294 |
-
|
295 |
-
# Make predictions
|
296 |
-
y_pred_list = model.predict(X_vec)
|
297 |
-
all_probs = model.predict_proba(X_vec)
|
298 |
-
|
299 |
-
logger.info(f"y_pred_list: {len(y_pred_list)}, all_probs: {len(all_probs)}")
|
300 |
-
|
301 |
-
# Build prediction response
|
302 |
-
predictions = []
|
303 |
-
for i, col in enumerate(LABEL_COLUMNS):
|
304 |
-
if col not in label_encoders:
|
305 |
-
logger.warning(f"Missing encoder for column: {col}")
|
306 |
-
continue
|
307 |
-
|
308 |
-
label_encoder = label_encoders[col]
|
309 |
-
try:
|
310 |
-
pred_labels = label_encoder.inverse_transform(y_pred_list[i])
|
311 |
-
except Exception as e:
|
312 |
-
logger.error(f"Error in inverse_transform for {col}: {str(e)}")
|
313 |
-
raise HTTPException(status_code=500, detail=f"Label decoding failed for {col}: {str(e)}")
|
314 |
-
|
315 |
-
probs = all_probs[i]
|
316 |
-
|
317 |
-
for pred, prob in zip(pred_labels, probs):
|
318 |
-
predictions.append({
|
319 |
-
"field": col,
|
320 |
-
"predicted_label": pred,
|
321 |
-
"probabilities": {
|
322 |
-
label: float(p) for label, p in zip(label_encoder.classes_, prob)
|
323 |
-
}
|
324 |
-
})
|
325 |
|
326 |
return {
|
327 |
"message": "Prediction completed successfully",
|
|
|
258 |
if os.path.exists(file_path):
|
259 |
os.remove(file_path)
|
260 |
|
261 |
+
from predict_utils import run_prediction
|
262 |
+
|
263 |
@app.post("/v1/xgb/predict")
|
264 |
async def predict_model(
|
265 |
file: UploadFile = File(...),
|
|
|
269 |
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
270 |
|
271 |
try:
|
|
|
272 |
file_path = UPLOAD_DIR / file.filename
|
273 |
with file_path.open("wb") as buffer:
|
274 |
shutil.copyfileobj(file.file, buffer)
|
275 |
|
|
|
276 |
data_df = pd.read_csv(file_path)
|
277 |
if TEXT_COLUMN not in data_df.columns:
|
278 |
raise HTTPException(status_code=400, detail=f"Missing required text column: {TEXT_COLUMN}")
|
279 |
|
280 |
+
predictions = run_prediction(data_df, model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
return {
|
283 |
"message": "Prediction completed successfully",
|