# XGB / app.py
# Last change: commit 1574c05 "Update app.py" by subbunanepalli (verified)
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import uvicorn
import logging
import os
import pandas as pd
from datetime import datetime
import shutil
from pathlib import Path
import numpy as np
import json
import joblib
from sklearn.metrics import classification_report
from sklearn.multioutput import MultiOutputClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
import xgboost as xgb
import traceback
from xgboost import XGBClassifier
# Import existing utilities
from dataset_utils import (
load_and_preprocess_data,
save_label_encoders,
load_label_encoders
)
from config import (
TEXT_COLUMN,
LABEL_COLUMNS,
BATCH_SIZE,
MODEL_SAVE_DIR
)
from models.tfidf_xgb import TfidfXGBoost
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="XGB Compliance Predictor API")

# Working directories, resolved relative to the process's current directory.
# NOTE(review): MODEL_SAVE_DIR rebinds the name imported from `config`. If the
# model class persists its artifacts under config.MODEL_SAVE_DIR and that value
# is not "saved_models", the existence checks in this file look in a different
# directory than the one the model writes to — confirm the two agree.
UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
for _directory in (UPLOAD_DIR, MODEL_SAVE_DIR):
    _directory.mkdir(parents=True, exist_ok=True)
del _directory
# Locations of the persisted artifacts, as plain strings.
# The vectorizer and model paths live under MODEL_SAVE_DIR; the encoders path
# is resolved next to this source file instead.
# NOTE(review): MODEL_PATH and ENCODERS_PATH are not read anywhere in this file;
# the routes build their own "<model_name>.pkl" path — confirm they are needed.
TFIDF_PATH = str(MODEL_SAVE_DIR / "tfidf_vectorizer.pkl")
MODEL_PATH = str(MODEL_SAVE_DIR / "xgb_models.pkl")
ENCODERS_PATH = str(Path(__file__).parent / "label_encoders.pkl")
# Process-wide, mutable snapshot of the current/last training job. It is
# returned verbatim by /v1/xgb/training-status and updated in place by
# start_training and train_model_task.
training_status = dict(
    is_training=False,
    current_epoch=0,
    total_epochs=0,
    current_loss=0.0,
    start_time=None,
    end_time=None,
    status="idle",
    metrics=None,
)
# Training options parsed from the JSON `config` form field of POST /v1/xgb/train.
# NOTE(review): train_model_task receives this object but does not pass any of
# these values to the model, so none of them currently affect training.
class TrainingConfig(BaseModel):
    batch_size: int = 32
    num_epochs: int = 1  # Not used by XGBoost training; kept for API compatibility
    random_state: int = 42
# Response body of POST /v1/xgb/train.
class TrainingResponse(BaseModel):
    message: str
    training_id: str  # timestamp id ("%Y%m%d_%H%M%S") generated by start_training
    status: str
    download_url: Optional[str] = None  # "/v1/xgb/download-model/<training_id>"
# Response body of POST /v1/xgb/validate.
class ValidationResponse(BaseModel):
    message: str
    # The `reports` object returned by the model's evaluate() call.
    metrics: Dict[str, Any]
    # One dict per (label column, row), grouped by label column, with keys
    # "field", "true_label", "predicted_label" and "probabilities".
    predictions: List[Dict[str, Any]]
# PredictionResponse is declared next to PredictionItem, directly above the
# /v1/xgb/predict route that uses it; it is intentionally not declared here.
# Flat schema of a single screening transaction record.
# Only Payment_Sender_Name and Payment_Reciever_Name are optional (default "");
# every other field is required. Date/time fields are plain strings — no
# format is validated here.
# NOTE(review): the "Reciever" spelling is part of the request contract;
# renaming it would break clients, so keep it unless the API is versioned.
# NOTE(review): only referenced via PredictionRequest, which no route in this
# file uses; the routes here take CSV uploads instead.
class TransactionData(BaseModel):
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
# JSON request shape for a single-transaction prediction.
# NOTE(review): no route in this file accepts this model; /v1/xgb/predict takes
# a CSV upload instead.
class PredictionRequest(BaseModel):
    transaction_data: TransactionData
    model_name: str = "xgb_models"  # stem of the "<model_name>.pkl" artifact
    class Config:
        # Allow a field named "model_name" despite pydantic's reserved
        # "model_" namespace (otherwise pydantic emits a warning).
        protected_namespaces = ()
# Response shape for batch predictions with optional metrics.
# NOTE(review): not used as a response_model by any route in this file.
class BatchPredictionResponse(BaseModel):
    message: str
    predictions: List[Dict[str, Any]]
    metrics: Optional[Dict[str, Any]] = None
@app.get("/")
async def root():
    """Identify the service."""
    banner = {"message": "XGB Compliance Predictor API"}
    return banner
@app.get("/v1/xgb/health")
async def health_check():
    """Liveness probe: always reports healthy while the process is serving."""
    return dict(status="healthy")
@app.get("/v1/xgb/training-status")
async def get_training_status():
    """Return the shared training-status dict (the live object, not a copy)."""
    current = training_status
    return current
@app.post("/v1/xgb/train", response_model=TrainingResponse)
async def start_training(
    config: str = Form(...),
    background_tasks: BackgroundTasks = None,
    file: UploadFile = File(...)
):
    """Store an uploaded training CSV and queue a background training job.

    Args:
        config: JSON object whose keys populate ``TrainingConfig``.
        background_tasks: Injected by FastAPI; the training job is queued here.
        file: Training data; the file name must end with ``.csv``.

    Returns:
        A ``TrainingResponse`` carrying the generated ``training_id`` and the
        download URL for the resulting model.

    Raises:
        HTTPException(400): training already running, the upload is not a
            ``.csv`` file, or ``config`` is not valid JSON for ``TrainingConfig``.
    """
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")
    # The client controls file.filename (and may omit it). Keep only the final
    # path component so a name like "../../x.csv" cannot escape UPLOAD_DIR.
    safe_name = Path(file.filename or "").name
    if not safe_name.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
    try:
        config_dict = json.loads(config)
        training_config = TrainingConfig(**config_dict)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}") from e
    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Prefix with the training id. /validate and /predict save uploads as
    # UPLOAD_DIR/<name> and delete them afterwards, so an unprefixed name could
    # be removed while the background job is still reading it.
    file_path = UPLOAD_DIR / f"{training_id}_{safe_name}"
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    # Reset every field a previous run may have set, so the status endpoint
    # never mixes a stale end_time/error/metrics with the new run.
    training_status.update({
        "is_training": True,
        "current_epoch": 0,
        "total_epochs": 1,
        "current_loss": 0.0,
        "start_time": datetime.now().isoformat(),
        "end_time": None,
        "status": "starting",
        "metrics": None,
        "error": None,
    })
    background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
    download_url = f"/v1/xgb/download-model/{training_id}"
    return TrainingResponse(
        message="Training started successfully",
        training_id=training_id,
        status="started",
        download_url=download_url
    )
@app.post("/v1/xgb/validate")
async def validate_model(
    file: UploadFile = File(...),
    model_name: str = "xgb_models"
):
    """Evaluate a saved model against a labelled CSV upload.

    Args:
        file: Labelled data; the file name must end with ``.csv``.
        model_name: Stem of the ``<model_name>.pkl`` file in MODEL_SAVE_DIR.

    Returns:
        A ``ValidationResponse`` with the evaluation report and, grouped by
        label column, one entry per row with true/predicted labels and class
        probabilities.

    Raises:
        HTTPException(400): non-CSV upload, invalid ``model_name``, or the text
            column is not a string Series.
        HTTPException(404): the model file does not exist.
        HTTPException(500): any other failure while loading or evaluating.
    """
    # Keep only the final path component of the client-supplied name so the
    # upload cannot be written (and later deleted) outside UPLOAD_DIR.
    safe_name = Path(file.filename or "").name
    if not safe_name.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
    # model_name becomes part of a filesystem path and the referenced pickle is
    # deserialized, so only a bare file stem is accepted.
    if not model_name or Path(model_name).name != model_name:
        raise HTTPException(status_code=400, detail="Invalid model name")
    # Bound before `try` so the `finally` clean-up can always reference it.
    file_path = UPLOAD_DIR / safe_name
    try:
        # Fail fast before spending time on preprocessing.
        model_path = MODEL_SAVE_DIR / f"{model_name}.pkl"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail="XGB model file not found")
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        # NOTE(review): these encoders come from preprocessing the validation
        # file itself; if load_and_preprocess_data fits new encoders, their
        # class indices may differ from the ones used at training time
        # (load_label_encoders is imported but unused) — confirm.
        data_df, label_encoders = load_and_preprocess_data(str(file_path))
        model = TfidfXGBoost(label_encoders)
        model.load_model(model_name)
        tfidf = joblib.load(TFIDF_PATH)
        # Extract and vectorize text
        X_text = data_df[TEXT_COLUMN]
        y = data_df[LABEL_COLUMNS]
        if not isinstance(X_text, pd.Series) or not pd.api.types.is_string_dtype(X_text):
            raise HTTPException(status_code=400, detail=f"TEXT_COLUMN ('{TEXT_COLUMN}') must be a pandas Series of strings. Got type: {type(X_text)}, dtype: {getattr(X_text, 'dtype', None)}")
        X_vec = tfidf.transform(X_text)
        reports, y_true_list, y_pred_list = model.evaluate(X_vec, y)
        all_probs = model.predict_proba(X_vec)
        predictions = []
        for i, col in enumerate(LABEL_COLUMNS):
            label_encoder = label_encoders[col]
            # .tolist() turns numpy scalars into native Python values so the
            # response can always be JSON-serialized.
            class_labels = np.asarray(label_encoder.classes_).tolist()
            true_labels = np.asarray(label_encoder.inverse_transform(y_true_list[i])).tolist()
            pred_labels = np.asarray(label_encoder.inverse_transform(y_pred_list[i])).tolist()
            for true_label, pred_label, probs in zip(true_labels, pred_labels, all_probs[i]):
                predictions.append({
                    "field": col,
                    "true_label": true_label,
                    "predicted_label": pred_label,
                    "probabilities": {
                        label: float(prob) for label, prob in zip(class_labels, probs)
                    },
                })
        return ValidationResponse(
            message="Validation completed successfully",
            metrics=reports,
            predictions=predictions
        )
    except HTTPException:
        # Deliberate client-facing errors keep their own status code instead of
        # being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.exception("Validation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}") from e
    finally:
        if file_path.exists():
            file_path.unlink()
    # NOTE(review): this handler never awaits, so the CPU-bound work above runs
    # on the event loop and blocks other requests while it executes.
# Pydantic response schema for POST /v1/xgb/predict.
# One PredictionItem is produced per (input row, label column), in row-major order.
class PredictionItem(BaseModel):
    field: str  # label column name taken from LABEL_COLUMNS
    predicted_label: str  # class label decoded with that column's label encoder
    probabilities: dict  # class label -> predicted probability (float)
# The /v1/xgb/predict route declared immediately below uses this definition as
# its response_model.
class PredictionResponse(BaseModel):
    message: str
    predictions: List[PredictionItem]
@app.post("/v1/xgb/predict", response_model=PredictionResponse)
async def predict_model(
    file: UploadFile = File(...),
    model_name: str = "xgb_models"
):
    """Predict every label column for every row of an uploaded CSV.

    Args:
        file: Input data; the file name must end with ``.csv``.
        model_name: Stem of the ``<model_name>.pkl`` file in MODEL_SAVE_DIR.

    Returns:
        A ``PredictionResponse`` holding one ``PredictionItem`` per
        (row, label column), in row-major order.

    Raises:
        HTTPException(400): non-CSV upload, invalid ``model_name``, or the text
            column is not a string Series.
        HTTPException(404): the model file does not exist.
        HTTPException(500): a label encoder is missing, or any other failure.
    """
    # Keep only the final path component of the client-supplied name so the
    # upload cannot be written (and later deleted) outside UPLOAD_DIR.
    safe_name = Path(file.filename or "").name
    if not safe_name.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
    # model_name becomes part of a filesystem path and the referenced pickle is
    # deserialized, so only a bare file stem is accepted.
    if not model_name or Path(model_name).name != model_name:
        raise HTTPException(status_code=400, detail="Invalid model name")
    file_path = UPLOAD_DIR / safe_name
    try:
        # Fail fast before spending time on preprocessing.
        model_path = MODEL_SAVE_DIR / f"{model_name}.pkl"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail=f"Model file '{model_name}.pkl' not found")
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        data_df, label_encoders = load_and_preprocess_data(str(file_path))
        # Resolve every encoder once, up front, instead of once per row.
        for col in LABEL_COLUMNS:
            if label_encoders.get(col) is None:
                raise HTTPException(status_code=500, detail=f"Label encoder not found for column: {col}")
        model = TfidfXGBoost(label_encoders)
        model.load_model(model_name)
        tfidf = joblib.load(TFIDF_PATH)
        # Extract and validate text
        X_text = data_df[TEXT_COLUMN]
        if not isinstance(X_text, pd.Series) or not pd.api.types.is_string_dtype(X_text):
            raise HTTPException(status_code=400, detail=f"TEXT_COLUMN ('{TEXT_COLUMN}') must be a pandas Series of strings.")
        X_vec = tfidf.transform(X_text)
        # Indexed as [row, label column] below.
        y_pred_array = np.asarray(model.predict(X_vec))
        # One (rows x classes) probability matrix per label column.
        all_probs_list = model.predict_proba(X_vec)
        # Decode each label column with a single vectorised inverse_transform
        # call. .tolist() yields native Python values for JSON serialization.
        decoded_labels = []
        class_names = []
        for label_idx, col in enumerate(LABEL_COLUMNS):
            label_encoder = label_encoders[col]
            decoded_labels.append(
                np.asarray(label_encoder.inverse_transform(y_pred_array[:, label_idx])).tolist()
            )
            class_names.append(np.asarray(label_encoder.classes_).tolist())
        predictions = []
        for row_idx in range(X_vec.shape[0]):
            for label_idx, col in enumerate(LABEL_COLUMNS):
                names = class_names[label_idx]
                class_prob_dist = all_probs_list[label_idx][row_idx]
                predictions.append({
                    "field": col,
                    # The response schema declares predicted_label as str.
                    "predicted_label": str(decoded_labels[label_idx][row_idx]),
                    "probabilities": {
                        names[i]: float(prob) for i, prob in enumerate(class_prob_dist)
                    },
                })
        return PredictionResponse(
            message="Prediction completed successfully",
            predictions=predictions
        )
    except HTTPException:
        # Deliberate client-facing errors keep their own status code instead of
        # being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.exception("Prediction failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
    finally:
        if file_path.exists():
            file_path.unlink()
@app.get("/v1/xgb/download-model/{model_id}")
async def download_model(model_id: str):
    """Send ``MODEL_SAVE_DIR/<model_id>.pkl`` as a binary attachment.

    Raises:
        HTTPException(404): no such file exists.
    """
    # NOTE(review): /train advertises this URL with model_id=training_id; it
    # only resolves if the model's save_model(training_id) writes
    # "<training_id>.pkl" into this same directory — confirm.
    artifact = MODEL_SAVE_DIR / f"{model_id}.pkl"
    if artifact.exists():
        return FileResponse(
            path=artifact,
            filename=f"xgb_model_{model_id}.pkl",
            media_type="application/octet-stream",
        )
    raise HTTPException(status_code=404, detail="Model not found")
async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    """Background job: train a model on ``file_path`` and record the outcome.

    The blocking load/fit/save work runs in the default thread-pool executor.
    That keeps the event loop free, so /health and /training-status still
    respond while a model is training. ``training_status`` is updated only from
    this coroutine, on the event loop.

    Args:
        config: Parsed training options. NOTE(review): currently unused — the
            model constructor visible here only takes the label encoders.
        file_path: Path of the uploaded training CSV.
        training_id: Identifier passed to ``save_model``.
    """
    import asyncio  # only this background job needs the event-loop helpers

    try:
        training_status["status"] = "training"
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _run_training, file_path, training_id)
        training_status.update({
            "is_training": False,
            "current_epoch": training_status.get("total_epochs", 1),
            "end_time": datetime.now().isoformat(),
            "status": "completed",
            "error": None,
        })
    except Exception as e:
        logger.exception("Training failed: %s", e)
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })


def _run_training(file_path: str, training_id: str) -> None:
    """Blocking part of a training job: preprocess, fit, and persist the model."""
    data_df_original, label_encoders = load_and_preprocess_data(file_path)
    save_label_encoders(label_encoders)
    X = data_df_original[TEXT_COLUMN]
    y = data_df_original[LABEL_COLUMNS]
    # TfidfXGBoost is the model class imported from models.tfidf_xgb (the
    # previous `TfidfXGB` name was undefined and made every run fail).
    model = TfidfXGBoost(label_encoders)
    model.train(X, y)
    model.save_model(training_id)
if __name__ == "__main__":
    # Serve on all interfaces; the port comes from $PORT, defaulting to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))