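"""FastAPI service for a multi-output BERT compliance model.

Exposes endpoints under /v1/bert/ for training on uploaded CSV data, validating a saved
model, running single or batch predictions, and downloading trained model weights.
"""
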
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import uvicorn
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader
import logging
import os
import asyncio
import pandas as pd
from datetime import datetime
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split
import zipfile
import io
import numpy as np
import sys
import json

from dataset_utils import (
    ComplianceDataset,
    ComplianceDatasetWithMetadata,
    load_and_preprocess_data,
    get_tokenizer,
    save_label_encoders,
    get_num_labels,
    load_label_encoders
)
from train_utils import (
    initialize_criterions,
    train_model,
    evaluate_model,
    save_model,
    summarize_metrics,
    predict_probabilities
)
from models.bert_model import BertMultiOutputModel
from config import (
    TEXT_COLUMN,
    LABEL_COLUMNS,
    DEVICE,
    NUM_EPOCHS,
    LEARNING_RATE,
    MAX_LEN,
    BATCH_SIZE,
    METADATA_COLUMNS
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="BERT Compliance Predictor API")

UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

training_status = {
    "is_training": False,
    "current_epoch": 0,
    "total_epochs": 0,
    "current_loss": 0.0,
    "start_time": None,
    "end_time": None,
    "status": "idle",
    "metrics": None
}

model_path = MODEL_SAVE_DIR / "BERT_model.pth"
tokenizer = get_tokenizer('bert-base-uncased')

# Load label encoders and saved model weights at startup if they are available.
try:
    label_encoders = load_label_encoders()
    model = BertMultiOutputModel([len(label_encoders[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
    if model_path.exists():
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()
    else:
        logger.warning(f"Model file {model_path} not found. Model will be initialized but not loaded.")
except Exception as e:
    logger.warning(f"Could not load label encoders or model: {str(e)}")
    logger.warning("Model will be initialized when training starts.")
    model = None


class TrainingConfig(BaseModel):
    model_name: str = "bert-base-uncased"
    batch_size: int = 8
    learning_rate: float = 2e-5
    num_epochs: int = 2
    max_length: int = 128
    random_state: int = 42


class TrainingResponse(BaseModel):
    message: str
    training_id: str
    status: str
    download_url: Optional[str] = None


class ValidationResponse(BaseModel):
    message: str
    metrics: Dict[str, Any]
    predictions: List[Dict[str, Any]]


class TransactionData(BaseModel):
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool


class PredictionRequest(BaseModel):
    transaction_data: TransactionData
    model_name: str = "BERT_model"


class BatchPredictionResponse(BaseModel):
    message: str
    predictions: List[Dict[str, Any]]
    metrics: Optional[Dict[str, Any]] = None


@app.get("/")
async def root():
    return {"message": "BERT Compliance Predictor API"}


@app.get("/v1/bert/health")
async def health_check():
    return {"status": "healthy"}


@app.get("/v1/bert/training-status")
async def get_training_status():
    return training_status


@app.post("/v1/bert/train", response_model=TrainingResponse)
async def start_training(
    background_tasks: BackgroundTasks,
    config: str = Form(...),
    file: UploadFile = File(...)
):
    """Start training a BERT model on an uploaded CSV file."""
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")

    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        config_dict = json.loads(config)
        training_config = TrainingConfig(**config_dict)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid config JSON format")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")

    file_path = UPLOAD_DIR / file.filename
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")

    training_status.update({
        "is_training": True,
        "current_epoch": 0,
        "total_epochs": training_config.num_epochs,
        "start_time": datetime.now().isoformat(),
        "status": "starting"
    })

    background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)

    download_url = f"/v1/bert/download-model/{training_id}"

    return TrainingResponse(
        message="Training started successfully",
        training_id=training_id,
        status="started",
        download_url=download_url
    )


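# Illustrative client call for /v1/bert/train (not part of the service; assumes the API is
# running locally on the port used in the __main__ block and that "training_data.csv" is a
# placeholder file name). The config is sent as a JSON string in a multipart form field
# alongside the CSV upload:
#
#   import json
#   import requests
#   resp = requests.post(
#       "http://localhost:7861/v1/bert/train",
#       data={"config": json.dumps({"num_epochs": 2, "batch_size": 8})},
#       files={"file": open("training_data.csv", "rb")},
#   )
#   print(resp.json())  # TrainingResponse with training_id, status, and download_url

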
@app.post("/v1/bert/validate")
async def validate_model(
    file: UploadFile = File(...),
    model_name: str = "BERT_model"
):
    """Validate a BERT model on uploaded data"""
    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        # Persist the uploaded CSV so it can be read back from disk.
        file_path = UPLOAD_DIR / file.filename
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        data_df, label_encoders = load_and_preprocess_data(str(file_path))

        model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail="BERT model file not found")

        # Rebuild the architecture to match the label encoders, then load the saved weights.
        num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
        metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None

        if metadata_df is not None:
            metadata_dim = metadata_df.shape[1]
            model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
        else:
            model = BertMultiOutputModel(num_labels_list).to(DEVICE)

        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()

        texts = data_df[TEXT_COLUMN]
        labels_array = data_df[LABEL_COLUMNS].values
        tokenizer = get_tokenizer("bert-base-uncased")

        if metadata_df is not None:
            dataset = ComplianceDatasetWithMetadata(
                texts.tolist(),
                metadata_df.values,
                labels_array,
                tokenizer,
                MAX_LEN
            )
        else:
            dataset = ComplianceDataset(
                texts.tolist(),
                labels_array,
                tokenizer,
                MAX_LEN
            )

        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
        metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
        summary_metrics = summarize_metrics(metrics).to_dict()

        all_probs = predict_probabilities(model, dataloader)

        # Map encoded labels back to their original values for the response.
        predictions = []
        for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
            field = LABEL_COLUMNS[i]
            label_encoder = label_encoders[field]
            true_labels_orig = label_encoder.inverse_transform(true_labels)
            pred_labels_orig = label_encoder.inverse_transform(pred_labels)

            for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
                predictions.append({
                    "field": field,
                    "true_label": true,
                    "predicted_label": pred,
                    "probabilities": probs.tolist()
                })

        return ValidationResponse(
            message="Validation completed successfully",
            metrics=summary_metrics,
            predictions=predictions
        )

    except Exception as e:
        logger.error(f"Validation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
    finally:
        if os.path.exists(file_path):
            os.remove(file_path)


@app.post("/v1/bert/predict")
async def predict(
    request: Optional[PredictionRequest] = None,
    file: UploadFile = File(None),
    model_name: str = "BERT_model"
):
    """
    Make predictions on either a single transaction or a batch of transactions from a CSV file.

    You can either:
    1. Send a single transaction in the request body
    2. Upload a CSV file with multiple transactions

    Parameters:
    - file: CSV file containing transactions for batch prediction
    - model_name: Name of the model to use for prediction (default: "BERT_model")
    """
    try:
        model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

        num_labels_list = [len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]
        model = BertMultiOutputModel(num_labels_list).to(DEVICE)
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()

        if file and file.filename:
            if not file.filename.endswith('.csv'):
                raise HTTPException(status_code=400, detail="Only CSV files are allowed")

            file_path = UPLOAD_DIR / file.filename
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            try:
                data_df, _ = load_and_preprocess_data(str(file_path))
                texts = data_df[TEXT_COLUMN]

                # Dummy labels only satisfy the dataset interface; they are ignored at inference.
                dataset = ComplianceDataset(
                    texts.tolist(),
                    [[0] * len(LABEL_COLUMNS)] * len(texts),
                    tokenizer,
                    MAX_LEN
                )
                loader = DataLoader(dataset, batch_size=BATCH_SIZE)

                all_probabilities = predict_probabilities(model, loader)
                label_encoders = load_label_encoders()

                predictions = []
                # Use the positional row index so it lines up with the order of the
                # probability arrays regardless of the DataFrame's index values.
                for i, (_, row) in enumerate(data_df.iterrows()):
                    transaction_pred = {}
                    for col, probs in zip(LABEL_COLUMNS, all_probabilities):
                        class_probs = {
                            label: float(probs[i][j])
                            for j, label in enumerate(label_encoders[col].classes_)
                        }

                        sorted_probs = sorted(class_probs.items(), key=lambda x: x[1], reverse=True)

                        top_pred, top_prob = sorted_probs[0]

                        top_3_predictions = [
                            {"label": label, "probability": prob}
                            for label, prob in sorted_probs[:3]
                        ]

                        transaction_pred[col] = {
                            "top_prediction": {
                                "label": top_pred,
                                "probability": top_prob
                            },
                            "alternative_predictions": top_3_predictions[1:],
                            "all_probabilities": class_probs
                        }

                    predictions.append({
                        "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
                        "predictions": transaction_pred
                    })

                return BatchPredictionResponse(
                    message="Batch prediction completed successfully",
                    predictions=predictions
                )

            finally:
                if os.path.exists(file_path):
                    os.remove(file_path)

        elif request and request.transaction_data:
            input_data = pd.DataFrame([request.transaction_data.dict()])

            # Flatten the transaction fields into a single text block for the tokenizer.
            text_input = f"""
            Transaction ID: {input_data['Transaction_Id'].iloc[0]}
            Origin: {input_data['Origin'].iloc[0]}
            Designation: {input_data['Designation'].iloc[0]}
            Keywords: {input_data['Keywords'].iloc[0]}
            Name: {input_data['Name'].iloc[0]}
            SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]}
            Currency: {input_data['Currency'].iloc[0]}
            Entity: {input_data['Entity'].iloc[0]}
            Message: {input_data['Message'].iloc[0]}
            City: {input_data['City'].iloc[0]}
            Country: {input_data['Country'].iloc[0]}
            State: {input_data['State'].iloc[0]}
            Hit Type: {input_data['Hit_Type'].iloc[0]}
            Record Matching String: {input_data['Record_Matching_String'].iloc[0]}
            WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]}
            Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]}
            Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]}
            Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]}
            Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]}
            Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]}
            Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]}
            Risk Level: {input_data['Risk_Level'].iloc[0]}
            Risk Score: {input_data['Risk_Score'].iloc[0]}
            CDD Level: {input_data['CDD_Level'].iloc[0]}
            PEP Status: {input_data['PEP_Status'].iloc[0]}
            Sanction Description: {input_data['Sanction_Description'].iloc[0]}
            Checker Notes: {input_data['Checker_Notes'].iloc[0]}
            Sanction Context: {input_data['Sanction_Context'].iloc[0]}
            """

            dataset = ComplianceDataset(
                texts=[text_input],
                labels=[[0] * len(LABEL_COLUMNS)],
                tokenizer=tokenizer,
                max_len=MAX_LEN
            )

            loader = DataLoader(dataset, batch_size=1, shuffle=False)
            all_probabilities = predict_probabilities(model, loader)

            label_encoders = load_label_encoders()

            response = {}
            for col, probs in zip(LABEL_COLUMNS, all_probabilities):
                class_probs = {
                    label: float(probs[0][j])
                    for j, label in enumerate(label_encoders[col].classes_)
                }

                sorted_probs = sorted(class_probs.items(), key=lambda x: x[1], reverse=True)

                top_pred, top_prob = sorted_probs[0]

                top_3_predictions = [
                    {"label": label, "probability": prob}
                    for label, prob in sorted_probs[:3]
                ]

                response[col] = {
                    "top_prediction": {
                        "label": top_pred,
                        "probability": top_prob
                    },
                    "alternative_predictions": top_3_predictions[1:],
                    "all_probabilities": class_probs
                }

            return response

        else:
            raise HTTPException(
                status_code=400,
                detail="Either provide a transaction in the request body or upload a CSV file"
            )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


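# Illustrative client call for /v1/bert/predict (not part of the service; assumes the API is
# running locally on the port used in the __main__ block and that "transactions.csv" is a
# placeholder file name). Batch mode uploads a CSV; single-transaction mode instead sends a
# JSON body matching PredictionRequest, as described in the endpoint docstring above.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7861/v1/bert/predict",
#       files={"file": open("transactions.csv", "rb")},
#       params={"model_name": "BERT_model"},
#   )
#   print(resp.json()["predictions"])

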
@app.get("/v1/bert/download-model/{model_id}")
async def download_model(model_id: str):
    """Download a trained model"""
    model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
    if not model_path.exists():
        raise HTTPException(status_code=404, detail="Model not found")

    return FileResponse(
        path=model_path,
        filename=f"bert_model_{model_id}.pth",
        media_type="application/octet-stream"
    )


def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    """Background task that trains the model and updates the shared training_status dict.

    Defined as a regular (non-async) function so FastAPI's BackgroundTasks runs it in a
    worker thread instead of blocking the event loop for the duration of training.
    """
    try:
        data_df_original, label_encoders = load_and_preprocess_data(file_path)
        save_label_encoders(label_encoders)

        texts = data_df_original[TEXT_COLUMN]
        labels_array = data_df_original[LABEL_COLUMNS].values

        metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None

        num_labels_list = get_num_labels(label_encoders)
        tokenizer = get_tokenizer(config.model_name)

        if metadata_df is not None:
            metadata_dim = metadata_df.shape[1]
            dataset = ComplianceDatasetWithMetadata(
                texts.tolist(),
                metadata_df.values,
                labels_array,
                tokenizer,
                config.max_length
            )
            model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
        else:
            dataset = ComplianceDataset(
                texts.tolist(),
                labels_array,
                tokenizer,
                config.max_length
            )
            model = BertMultiOutputModel(num_labels_list).to(DEVICE)

        train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

        criterions = initialize_criterions(data_df_original, label_encoders)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

        for epoch in range(config.num_epochs):
            training_status["current_epoch"] = epoch + 1
            train_loss = train_model(model, train_loader, optimizer, criterions, epoch)
            training_status["current_loss"] = train_loss

        save_model(model, training_id, 'pth')

        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed"
        })

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7861))
    uvicorn.run(app, host="0.0.0.0", port=port)