# Roberta / app.py: FastAPI service for the RoBERTa compliance predictor
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import uvicorn
import torch
from torch.utils.data import DataLoader
import logging
import os
import asyncio
import pandas as pd
from datetime import datetime
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split
import zipfile
import io
import numpy as np
import sys
import json
# Import existing utilities
from dataset_utils import (
    ComplianceDataset,
    ComplianceDatasetWithMetadata,
    load_and_preprocess_data,
    get_tokenizer,
    save_label_encoders,
    get_num_labels,
    load_label_encoders
)
from train_utils import (
    initialize_criterions,
    train_model,
    evaluate_model,
    save_model,
    summarize_metrics,
    predict_probabilities
)
from models.roberta_model import RobertaMultiOutputModel
from config import (
    TEXT_COLUMN,
    LABEL_COLUMNS,
    DEVICE,
    NUM_EPOCHS,
    LEARNING_RATE,
    MAX_LEN,
    BATCH_SIZE,
    METADATA_COLUMNS
)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="RoBERTa Compliance Predictor API")
# Create necessary directories
UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
# Global variables to track training status
training_status = {
    "is_training": False,
    "current_epoch": 0,
    "total_epochs": 0,
    "current_loss": 0.0,
    "start_time": None,
    "end_time": None,
    "status": "idle",
    "metrics": None
}
# Load the model and tokenizer for prediction
model_path = MODEL_SAVE_DIR / "ROBERTA_model.pth"
tokenizer = get_tokenizer('roberta-base')
# Initialize model and label encoders with error handling
try:
    label_encoders = load_label_encoders()
    model = RobertaMultiOutputModel([len(label_encoders[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
    if model_path.exists():
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()
    else:
        print(f"Warning: Model file {model_path} not found. Model will be initialized but not loaded.")
except Exception as e:
    print(f"Warning: Could not load label encoders or model: {str(e)}")
    print("Model will be initialized when training starts.")
    model = None
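# NOTE: the startup loader above looks for "ROBERTA_model.pth", whereas the
# /predict, /validate and /download-model endpoints below look for files named
# "<model_name>_model.pth" in the same directory. Keep checkpoint filenames
# consistent with whichever convention your deployment uses.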
class TrainingConfig(BaseModel):
    model_name: str = "roberta-base"
    batch_size: int = 8
    learning_rate: float = 2e-5
    num_epochs: int = 2
    max_length: int = 128
    random_state: int = 42
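# The /train endpoint receives this configuration as a JSON string in the
# "config" form field. An illustrative payload (omitted fields fall back to the
# defaults above):
#   {"model_name": "roberta-base", "batch_size": 8, "learning_rate": 2e-5,
#    "num_epochs": 2, "max_length": 128, "random_state": 42}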
class TrainingResponse(BaseModel):
    message: str
    training_id: str
    status: str
    download_url: Optional[str] = None

class ValidationResponse(BaseModel):
    message: str
    metrics: Dict[str, Any]
    predictions: List[Dict[str, Any]]
class TransactionData(BaseModel):
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    transaction_data: TransactionData
    model_name: str = "ROBERTA_model"  # default model used when none is specified

class BatchPredictionResponse(BaseModel):
    message: str
    predictions: List[Dict[str, Any]]
    metrics: Optional[Dict[str, Any]] = None
@app.get("/")
async def root():
return {"message": "RoBERTa Compliance Predictor API"}
@app.get("/v1/roberta/health")
async def health_check():
return {"status": "healthy"}
@app.get("/v1/roberta/training-status")
async def get_training_status():
return training_status
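# Illustrative status check (assumes the default local deployment on port 7860):
#   curl http://localhost:7860/v1/roberta/training-status
# returns the training_status dict above, e.g. {"is_training": false, "status": "idle", ...}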
@app.post("/v1/roberta/train", response_model=TrainingResponse)
async def start_training(
config: str = Form(...),
background_tasks: BackgroundTasks = None,
file: UploadFile = File(...)
):
if training_status["is_training"]:
raise HTTPException(status_code=400, detail="Training is already in progress")
if not file.filename.endswith('.csv'):
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
try:
# Parse the config JSON string into a TrainingConfig object
config_dict = json.loads(config)
training_config = TrainingConfig(**config_dict)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid config JSON format")
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
file_path = UPLOAD_DIR / file.filename
with file_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
training_status.update({
"is_training": True,
"current_epoch": 0,
"total_epochs": training_config.num_epochs,
"start_time": datetime.now().isoformat(),
"status": "starting"
})
background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
download_url = f"/v1/roberta/download-model/{training_id}"
return TrainingResponse(
message="Training started successfully",
training_id=training_id,
status="started",
download_url=download_url
)
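# Illustrative client call (assumes the service is reachable on localhost:7860;
# adjust host, port and file name for your deployment):
#   import requests, json
#   cfg = {"model_name": "roberta-base", "num_epochs": 2, "batch_size": 8}
#   resp = requests.post(
#       "http://localhost:7860/v1/roberta/train",
#       data={"config": json.dumps(cfg)},
#       files={"file": ("train.csv", open("train.csv", "rb"), "text/csv")},
#   )
#   training_id = resp.json()["training_id"]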
@app.post("/v1/roberta/validate")
async def validate_model(
file: UploadFile = File(...),
model_name: str = "ROBERTA_model"
):
"""Validate a RoBERTa model on uploaded data"""
if not file.filename.endswith('.csv'):
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
try:
file_path = UPLOAD_DIR / file.filename
with file_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
data_df, label_encoders = load_and_preprocess_data(str(file_path))
model_path = MODEL_SAVE_DIR / f"{model_name}_model.pth"
if not model_path.exists():
raise HTTPException(status_code=404, detail="RoBERTa model file not found")
num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None
if metadata_df is not None:
metadata_dim = metadata_df.shape[1]
model = RobertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
else:
model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()
texts = data_df[TEXT_COLUMN]
labels_array = data_df[LABEL_COLUMNS].values
tokenizer = get_tokenizer("roberta-base")
if metadata_df is not None:
dataset = ComplianceDatasetWithMetadata(
texts.tolist(),
metadata_df.values,
labels_array,
tokenizer,
MAX_LEN
)
else:
dataset = ComplianceDataset(
texts.tolist(),
labels_array,
tokenizer,
MAX_LEN
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
summary_metrics = summarize_metrics(metrics).to_dict()
all_probs = predict_probabilities(model, dataloader)
predictions = []
for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
field = LABEL_COLUMNS[i]
label_encoder = label_encoders[field]
true_labels_orig = label_encoder.inverse_transform(true_labels)
pred_labels_orig = label_encoder.inverse_transform(pred_labels)
for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
predictions.append({
"field": field,
"true_label": true,
"predicted_label": pred,
"probabilities": probs.tolist()
})
return ValidationResponse(
message="Validation completed successfully",
metrics=summary_metrics,
predictions=predictions
)
except Exception as e:
logger.error(f"Validation failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
finally:
if os.path.exists(file_path):
os.remove(file_path)
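# Illustrative validation call (assumes a checkpoint saved as
# saved_models/ROBERTA_model_model.pth and a labelled CSV on disk):
#   curl -X POST "http://localhost:7860/v1/roberta/validate?model_name=ROBERTA_model" \
#        -F "file=@validation.csv"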
@app.post("/v1/roberta/predict")
async def predict(
request: Optional[PredictionRequest] = None,
file: UploadFile = File(None),
model_name: str = "ROBERTA_model"
):
"""
Make predictions on either a single transaction or a batch of transactions from a CSV file.
You can either:
1. Send a single transaction in the request body
2. Upload a CSV file with multiple transactions
Parameters:
- file: CSV file containing transactions for batch prediction
- model_name: Name of the model to use for prediction (default: "ROBERTA_model")
"""
try:
# Load the model
model_path = MODEL_SAVE_DIR / f"{model_name}_model.pth"
if not model_path.exists():
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
# Load label encoders
try:
label_encoders = load_label_encoders()
num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
except Exception as e:
raise HTTPException(status_code=500, detail=f"Could not load label encoders: {str(e)}")
model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()
# Handle batch prediction from CSV
if file and file.filename:
if not file.filename.endswith('.csv'):
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
file_path = UPLOAD_DIR / file.filename
with file_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
try:
# Load and preprocess the CSV data
data_df, _ = load_and_preprocess_data(str(file_path))
texts = data_df[TEXT_COLUMN]
# Create dataset and dataloader
dataset = ComplianceDataset(
texts.tolist(),
[[0] * len(LABEL_COLUMNS)] * len(texts), # Dummy labels for prediction
tokenizer,
MAX_LEN
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE)
# Get predictions
all_probabilities = predict_probabilities(model, loader)
# Process predictions
predictions = []
for i, row in data_df.iterrows():
transaction_pred = {}
for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
pred = np.argmax(probs[i])
decoded_pred = label_encoders[col].inverse_transform([pred])[0]
class_probs = {
label: float(probs[i][j])
for j, label in enumerate(label_encoders[col].classes_)
}
transaction_pred[col] = {
"prediction": decoded_pred,
"probabilities": class_probs
}
predictions.append({
"transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
"predictions": transaction_pred
})
return BatchPredictionResponse(
message="Batch prediction completed successfully",
predictions=predictions
)
finally:
if os.path.exists(file_path):
os.remove(file_path)
# Handle single prediction
elif request and request.transaction_data:
input_data = pd.DataFrame([request.transaction_data.dict()])
text_input = f"<s>Transaction ID: {input_data['Transaction_Id'].iloc[0]} Origin: {input_data['Origin'].iloc[0]} Designation: {input_data['Designation'].iloc[0]} Keywords: {input_data['Keywords'].iloc[0]} Name: {input_data['Name'].iloc[0]} SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]} Currency: {input_data['Currency'].iloc[0]} Entity: {input_data['Entity'].iloc[0]} Message: {input_data['Message'].iloc[0]} City: {input_data['City'].iloc[0]} Country: {input_data['Country'].iloc[0]} State: {input_data['State'].iloc[0]} Hit Type: {input_data['Hit_Type'].iloc[0]} Record Matching String: {input_data['Record_Matching_String'].iloc[0]} WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]} Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]} Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]} Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]} Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]} Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]} Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]} Risk Level: {input_data['Risk_Level'].iloc[0]} Risk Score: {input_data['Risk_Score'].iloc[0]} CDD Level: {input_data['CDD_Level'].iloc[0]} PEP Status: {input_data['PEP_Status'].iloc[0]} Sanction Description: {input_data['Sanction_Description'].iloc[0]} Checker Notes: {input_data['Checker_Notes'].iloc[0]} Sanction Context: {input_data['Sanction_Context'].iloc[0]}</s>"
dataset = ComplianceDataset(
texts=[text_input],
labels=[[0] * len(LABEL_COLUMNS)],
tokenizer=tokenizer,
max_len=MAX_LEN
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
all_probabilities = predict_probabilities(model, loader)
response = {}
for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
pred = np.argmax(probs[0])
decoded_pred = label_encoders[col].inverse_transform([pred])[0]
class_probs = {
label: float(probs[0][j])
for j, label in enumerate(label_encoders[col].classes_)
}
response[col] = {
"prediction": decoded_pred,
"probabilities": class_probs
}
return response
else:
raise HTTPException(
status_code=400,
detail="Either provide a transaction in the request body or upload a CSV file"
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
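# Illustrative batch prediction call (host and port assumed; adjust for your deployment):
#   curl -X POST "http://localhost:7860/v1/roberta/predict?model_name=ROBERTA_model" \
#        -F "file=@transactions.csv"
# For a single transaction, send a PredictionRequest payload in the request body
# instead of a file (see the docstring above for the two modes).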
@app.get("/v1/roberta/download-model/{model_id}")
async def download_model(model_id: str):
"""Download a trained model"""
model_path = MODEL_SAVE_DIR / f"{model_id}_model.pth"
if not model_path.exists():
raise HTTPException(status_code=404, detail="Model not found")
return FileResponse(
path=model_path,
filename=f"roberta_model_{model_id}.pth",
media_type="application/octet-stream"
)
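# Illustrative download (model_id is the training_id returned by /train, assuming
# the background task saved the checkpoint as "<training_id>_model.pth"):
#   curl -OJ http://localhost:7860/v1/roberta/download-model/20240101_120000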
async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    try:
        data_df_original, label_encoders = load_and_preprocess_data(file_path)
        save_label_encoders(label_encoders)
        texts = data_df_original[TEXT_COLUMN]
        labels_array = data_df_original[LABEL_COLUMNS].values
        # Use metadata features only when every configured metadata column is present
        metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
        num_labels_list = get_num_labels(label_encoders)
        tokenizer = get_tokenizer(config.model_name)
        if metadata_df is not None:
            metadata_dim = metadata_df.shape[1]
            dataset = ComplianceDatasetWithMetadata(
                texts.tolist(),
                metadata_df.values,
                labels_array,
                tokenizer,
                config.max_length
            )
            model = RobertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
        else:
            dataset = ComplianceDataset(
                texts.tolist(),
                labels_array,
                tokenizer,
                config.max_length
            )
            model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
        train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
        criterions = initialize_criterions(num_labels_list)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
        for epoch in range(config.num_epochs):
            training_status["current_epoch"] = epoch + 1
            train_loss = train_model(model, train_loader, criterions, optimizer)
            training_status["current_loss"] = train_loss
            # Save model after each epoch
            save_model(model, training_id, 'pth')
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed"
        })
    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })
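# Note: train_model_task is declared async but does blocking (CPU/GPU-bound) work,
# so when scheduled via BackgroundTasks it runs on the event loop and can block
# other requests while training; a plain "def" background task would instead be
# run in a worker thread by Starlette.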
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)