from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import uvicorn
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader
import logging
import os
import asyncio
import pandas as pd
from datetime import datetime
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split
import zipfile
import io
import numpy as np
import sys
import json

# Import existing utilities
from dataset_utils import (
    ComplianceDataset,
    ComplianceDatasetWithMetadata,
    load_and_preprocess_data,
    get_tokenizer,
    save_label_encoders,
    get_num_labels,
    load_label_encoders
)
from train_utils import (
    initialize_criterions,
    train_model,
    evaluate_model,
    save_model,
    summarize_metrics,
    predict_probabilities
)
from models.bert_model import BertMultiOutputModel
from config import (
    TEXT_COLUMN,
    LABEL_COLUMNS,
    DEVICE,
    NUM_EPOCHS,
    LEARNING_RATE,
    MAX_LEN,
    BATCH_SIZE,
    METADATA_COLUMNS
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="BERT Compliance Predictor API")

# Create necessary directories
UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Global variables to track training status
training_status = {
    "is_training": False,
    "current_epoch": 0,
    "total_epochs": 0,
    "current_loss": 0.0,
    "start_time": None,
    "end_time": None,
    "status": "idle",
    "metrics": None
}

# Load the model and tokenizer for prediction
model_path = MODEL_SAVE_DIR / "BERT_model.pth"
tokenizer = get_tokenizer('bert-base-uncased')

# Initialize model and label encoders with error handling
try:
    label_encoders = load_label_encoders()
    model = BertMultiOutputModel(
        [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
    ).to(DEVICE)
    if model_path.exists():
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()
    else:
        logger.warning(f"Model file {model_path} not found. Model will be initialized but not loaded.")
except Exception as e:
    logger.warning(f"Could not load label encoders or model: {str(e)}")
    logger.warning("Model will be initialized when training starts.")
    model = None


class TrainingConfig(BaseModel):
    model_name: str = "bert-base-uncased"
    batch_size: int = 8
    learning_rate: float = 2e-5
    num_epochs: int = 2
    max_length: int = 128
    random_state: int = 42


class TrainingResponse(BaseModel):
    message: str
    training_id: str
    status: str
    download_url: Optional[str] = None


class ValidationResponse(BaseModel):
    message: str
    metrics: Dict[str, Any]
    predictions: List[Dict[str, Any]]


class TransactionData(BaseModel):
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool


class PredictionRequest(BaseModel):
    transaction_data: TransactionData
    model_name: str = "BERT_model"  # Default to BERT_model if not specified


class BatchPredictionResponse(BaseModel):
    message: str
    predictions: List[Dict[str, Any]]
    metrics: Optional[Dict[str, Any]] = None


@app.get("/")
async def root():
    return {"message": "BERT Compliance Predictor API"}


@app.get("/v1/bert/health")
async def health_check():
    return {"status": "healthy"}


@app.get("/v1/bert/training-status")
async def get_training_status():
    return training_status


@app.post("/v1/bert/train", response_model=TrainingResponse)
async def start_training(
    background_tasks: BackgroundTasks,
    config: str = Form(...),
    file: UploadFile = File(...)
):
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")

    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        # Parse the config JSON string into a TrainingConfig object
        config_dict = json.loads(config)
        training_config = TrainingConfig(**config_dict)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid config JSON format")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")

    file_path = UPLOAD_DIR / file.filename
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")

    training_status.update({
        "is_training": True,
        "current_epoch": 0,
        "total_epochs": training_config.num_epochs,
        "start_time": datetime.now().isoformat(),
        "status": "starting"
    })

    background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)

    download_url = f"/v1/bert/download-model/{training_id}"
    return TrainingResponse(
        message="Training started successfully",
        training_id=training_id,
        status="started",
        download_url=download_url
    )


@app.post("/v1/bert/validate")
async def validate_model(
    file: UploadFile = File(...),
    model_name: str = "BERT_model"
):
    """Validate a BERT model on uploaded data."""
    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        file_path = UPLOAD_DIR / file.filename
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        data_df, label_encoders = load_and_preprocess_data(str(file_path))

        model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail="BERT model file not found")

        num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
        metadata_df = (
            data_df[METADATA_COLUMNS]
            if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS)
            else None
        )

        if metadata_df is not None:
            metadata_dim = metadata_df.shape[1]
            model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
        else:
            model = BertMultiOutputModel(num_labels_list).to(DEVICE)

        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()

        texts = data_df[TEXT_COLUMN]
        labels_array = data_df[LABEL_COLUMNS].values
        tokenizer = get_tokenizer("bert-base-uncased")

        if metadata_df is not None:
            dataset = ComplianceDatasetWithMetadata(
                texts.tolist(),
                metadata_df.values,
                labels_array,
                tokenizer,
                MAX_LEN
            )
        else:
            dataset = ComplianceDataset(
                texts.tolist(),
                labels_array,
                tokenizer,
                MAX_LEN
            )

        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
        metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
        summary_metrics = summarize_metrics(metrics).to_dict()
        all_probs = predict_probabilities(model, dataloader)

        # Map encoded labels back to their original values for the response
        predictions = []
        for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
            field = LABEL_COLUMNS[i]
            label_encoder = label_encoders[field]
            true_labels_orig = label_encoder.inverse_transform(true_labels)
            pred_labels_orig = label_encoder.inverse_transform(pred_labels)
            for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
                predictions.append({
                    "field": field,
                    "true_label": true,
                    "predicted_label": pred,
                    "probabilities": probs.tolist()
                })

        return ValidationResponse(
            message="Validation completed successfully",
            metrics=summary_metrics,
            predictions=predictions
        )

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Validation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
    finally:
        if os.path.exists(file_path):
            os.remove(file_path)


@app.post("/v1/bert/predict")
async def predict(
    request: Optional[PredictionRequest] = None,
    file: UploadFile = File(None),
    model_name: str = "BERT_model"
):
    """
    Make predictions on either a single transaction or a batch of transactions from a CSV file.

    You can either:
    1. Send a single transaction in the request body
    2. Upload a CSV file with multiple transactions

    Parameters:
    - file: CSV file containing transactions for batch prediction
    - model_name: Name of the model to use for prediction (default: "BERT_model")
    """
    try:
        # Load the model
        model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

        num_labels_list = [len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]
        model = BertMultiOutputModel(num_labels_list).to(DEVICE)
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()

        # Handle batch prediction from CSV
        if file and file.filename:
            if not file.filename.endswith('.csv'):
                raise HTTPException(status_code=400, detail="Only CSV files are allowed")

            file_path = UPLOAD_DIR / file.filename
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            try:
                # Load and preprocess the CSV data
                data_df, _ = load_and_preprocess_data(str(file_path))
                texts = data_df[TEXT_COLUMN]

                # Create dataset and dataloader (labels are dummies for prediction)
                dataset = ComplianceDataset(
                    texts.tolist(),
                    [[0] * len(LABEL_COLUMNS)] * len(texts),
                    tokenizer,
                    MAX_LEN
                )
                loader = DataLoader(dataset, batch_size=BATCH_SIZE)

                # Get predictions
                all_probabilities = predict_probabilities(model, loader)
                label_encoders = load_label_encoders()

                # Process predictions row by row (assumes a default RangeIndex from preprocessing)
                predictions = []
                for i, row in data_df.iterrows():
                    transaction_pred = {}
                    for col, probs in zip(LABEL_COLUMNS, all_probabilities):
                        # Get probabilities for each class of this label column
                        class_probs = {
                            label: float(probs[i][j])
                            for j, label in enumerate(label_encoders[col].classes_)
                        }
                        # Sort probabilities in descending order
                        sorted_probs = sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
                        # Get top prediction and its probability
                        top_pred, top_prob = sorted_probs[0]
                        # Get top 3 predictions with probabilities
                        top_3_predictions = [
                            {"label": label, "probability": prob}
                            for label, prob in sorted_probs[:3]
                        ]
                        transaction_pred[col] = {
                            "top_prediction": {
                                "label": top_pred,
                                "probability": top_prob
                            },
                            "alternative_predictions": top_3_predictions[1:],  # Exclude the top prediction
                            "all_probabilities": class_probs  # Keep all probabilities for reference
                        }
                    predictions.append({
                        "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
                        "predictions": transaction_pred
                    })

                return BatchPredictionResponse(
                    message="Batch prediction completed successfully",
                    predictions=predictions
                )
            finally:
                if os.path.exists(file_path):
                    os.remove(file_path)

        # Handle single prediction
        elif request and request.transaction_data:
            input_data = pd.DataFrame([request.transaction_data.dict()])

            # Flatten the transaction fields into a single text block for BERT
            text_input = f"""
Transaction ID: {input_data['Transaction_Id'].iloc[0]}
Origin: {input_data['Origin'].iloc[0]}
Designation: {input_data['Designation'].iloc[0]}
Keywords: {input_data['Keywords'].iloc[0]}
Name: {input_data['Name'].iloc[0]}
SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]}
Currency: {input_data['Currency'].iloc[0]}
Entity: {input_data['Entity'].iloc[0]}
Message: {input_data['Message'].iloc[0]}
City: {input_data['City'].iloc[0]}
Country: {input_data['Country'].iloc[0]}
State: {input_data['State'].iloc[0]}
Hit Type: {input_data['Hit_Type'].iloc[0]}
Record Matching String: {input_data['Record_Matching_String'].iloc[0]}
WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]}
Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]}
Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]}
Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]}
Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]}
Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]}
Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]}
Risk Level: {input_data['Risk_Level'].iloc[0]}
Risk Score: {input_data['Risk_Score'].iloc[0]}
CDD Level: {input_data['CDD_Level'].iloc[0]}
PEP Status: {input_data['PEP_Status'].iloc[0]}
Sanction Description: {input_data['Sanction_Description'].iloc[0]}
Checker Notes: {input_data['Checker_Notes'].iloc[0]}
Sanction Context: {input_data['Sanction_Context'].iloc[0]}
"""

            dataset = ComplianceDataset(
                texts=[text_input],
                labels=[[0] * len(LABEL_COLUMNS)],
                tokenizer=tokenizer,
                max_len=MAX_LEN
            )
            loader = DataLoader(dataset, batch_size=1, shuffle=False)

            all_probabilities = predict_probabilities(model, loader)
            label_encoders = load_label_encoders()

            response = {}
            for col, probs in zip(LABEL_COLUMNS, all_probabilities):
                # Get probabilities for each class of this label column
                class_probs = {
                    label: float(probs[0][j])
                    for j, label in enumerate(label_encoders[col].classes_)
                }
                # Sort probabilities in descending order
                sorted_probs = sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
                # Get top prediction and its probability
                top_pred, top_prob = sorted_probs[0]
                # Get top 3 predictions with probabilities
                top_3_predictions = [
                    {"label": label, "probability": prob}
                    for label, prob in sorted_probs[:3]
                ]
                response[col] = {
                    "top_prediction": {
                        "label": top_pred,
                        "probability": top_prob
                    },
                    "alternative_predictions": top_3_predictions[1:],  # Exclude the top prediction
                    "all_probabilities": class_probs  # Keep all probabilities for reference
                }

            return response

        else:
            raise HTTPException(
                status_code=400,
                detail="Either provide a transaction in the request body or upload a CSV file"
            )

    except HTTPException:
        # Let deliberate HTTP errors (400/404) pass through unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/bert/download-model/{model_id}")
async def download_model(model_id: str):
    """Download a trained model."""
    model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
    if not model_path.exists():
        raise HTTPException(status_code=404, detail="Model not found")
    return FileResponse(
        path=model_path,
        filename=f"bert_model_{model_id}.pth",
        media_type="application/octet-stream"
    )


async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    try:
        data_df_original, label_encoders = load_and_preprocess_data(file_path)
        save_label_encoders(label_encoders)

        texts = data_df_original[TEXT_COLUMN]
        labels_array = data_df_original[LABEL_COLUMNS].values
        metadata_df = (
            data_df_original[METADATA_COLUMNS]
            if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS)
            else None
        )

        num_labels_list = get_num_labels(label_encoders)
        tokenizer = get_tokenizer(config.model_name)

        if metadata_df is not None:
            metadata_dim = metadata_df.shape[1]
            dataset = ComplianceDatasetWithMetadata(
                texts.tolist(),
                metadata_df.values,
                labels_array,
                tokenizer,
                config.max_length
            )
            model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
        else:
            dataset = ComplianceDataset(
                texts.tolist(),
                labels_array,
                tokenizer,
                config.max_length
            )
            model = BertMultiOutputModel(num_labels_list).to(DEVICE)

        train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

        criterions = initialize_criterions(data_df_original, label_encoders)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

        for epoch in range(config.num_epochs):
            training_status["current_epoch"] = epoch + 1
            train_loss = train_model(model, train_loader, optimizer, criterions, epoch)
            training_status["current_loss"] = train_loss

            # Save model after each epoch
            save_model(model, training_id, 'pth')

        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed"
        })

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7861))
    uvicorn.run(app, host="0.0.0.0", port=port)
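
# ---------------------------------------------------------------------------
# Example client usage (illustrative sketch only, not executed by this module).
# It assumes the API is running locally on the default port configured above
# and that the `requests` package is available; the file names and config
# values shown are placeholders, and a single-transaction request must supply
# every field declared on TransactionData.
#
#   import requests
#
#   BASE_URL = "http://localhost:7861"
#
#   # Batch prediction from a CSV file
#   with open("transactions.csv", "rb") as f:
#       resp = requests.post(
#           f"{BASE_URL}/v1/bert/predict",
#           files={"file": ("transactions.csv", f, "text/csv")},
#       )
#   print(resp.json())
#
#   # Start training: "config" is a JSON-encoded TrainingConfig, "file" a CSV
#   with open("training_data.csv", "rb") as f:
#       resp = requests.post(
#           f"{BASE_URL}/v1/bert/train",
#           data={"config": '{"num_epochs": 2, "batch_size": 8}'},
#           files={"file": ("training_data.csv", f, "text/csv")},
#       )
#   print(resp.json())
#
#   # Poll training progress while the background task runs
#   print(requests.get(f"{BASE_URL}/v1/bert/training-status").json())
# ---------------------------------------------------------------------------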