namanpenguin committed
Commit 87e48e1 · verified · 1 parent: 5506da7

Update app.py

Files changed (1)
  1. app.py +542 -529
app.py CHANGED
@@ -1,529 +1,542 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
2
- from fastapi.responses import FileResponse
3
- from pydantic import BaseModel
4
- from typing import Optional, Dict, Any, List
5
- import uvicorn
6
- import torch
7
- from torch.utils.data import DataLoader
8
- import logging
9
- import os
10
- import asyncio
11
- import pandas as pd
12
- from datetime import datetime
13
- import shutil
14
- from pathlib import Path
15
- from sklearn.model_selection import train_test_split
16
- import zipfile
17
- import io
18
- import numpy as np
19
- import sys
20
- import json
21
-
22
-
23
- # Import existing utilities
24
- from dataset_utils import (
25
- ComplianceDataset,
26
- ComplianceDatasetWithMetadata,
27
- load_and_preprocess_data,
28
- get_tokenizer,
29
- save_label_encoders,
30
- get_num_labels,
31
- load_label_encoders
32
- )
33
- from train_utils import (
34
- initialize_criterions,
35
- train_model,
36
- evaluate_model,
37
- save_model,
38
- summarize_metrics,
39
- predict_probabilities
40
- )
41
- from models.roberta_model import RobertaMultiOutputModel
42
- from config import (
43
- TEXT_COLUMN,
44
- LABEL_COLUMNS,
45
- DEVICE,
46
- NUM_EPOCHS,
47
- LEARNING_RATE,
48
- MAX_LEN,
49
- BATCH_SIZE,
50
- METADATA_COLUMNS
51
- )
52
-
53
- # Configure logging
54
- logging.basicConfig(level=logging.INFO)
55
- logger = logging.getLogger(__name__)
56
-
57
- app = FastAPI(title="RoBERTa Compliance Predictor API")
58
-
59
- # Create necessary directories
60
- UPLOAD_DIR = Path("uploads")
61
- MODEL_SAVE_DIR = Path("saved_models")
62
- UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
63
- MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
64
-
65
- # Global variables to track training status
66
- training_status = {
67
- "is_training": False,
68
- "current_epoch": 0,
69
- "total_epochs": 0,
70
- "current_loss": 0.0,
71
- "start_time": None,
72
- "end_time": None,
73
- "status": "idle",
74
- "metrics": None
75
- }
76
-
77
- # Load the model and tokenizer for prediction
78
- model_path = "ROBERTA_model.pth"
79
- tokenizer = get_tokenizer('roberta-base')
80
- model = RobertaMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
81
- if os.path.exists(model_path):
82
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
83
- model.eval()
84
-
85
- class TrainingConfig(BaseModel):
86
- model_name: str = "roberta-base"
87
- batch_size: int = 8
88
- learning_rate: float = 2e-5
89
- num_epochs: int = 2
90
- max_length: int = 128
91
- random_state: int = 42
92
-
93
- class TrainingResponse(BaseModel):
94
- message: str
95
- training_id: str
96
- status: str
97
- download_url: Optional[str] = None
98
-
99
- class ValidationResponse(BaseModel):
100
- message: str
101
- metrics: Dict[str, Any]
102
- predictions: List[Dict[str, Any]]
103
-
104
- class TransactionData(BaseModel):
105
- Transaction_Id: str
106
- Hit_Seq: int
107
- Hit_Id_List: str
108
- Origin: str
109
- Designation: str
110
- Keywords: str
111
- Name: str
112
- SWIFT_Tag: str
113
- Currency: str
114
- Entity: str
115
- Message: str
116
- City: str
117
- Country: str
118
- State: str
119
- Hit_Type: str
120
- Record_Matching_String: str
121
- WatchList_Match_String: str
122
- Payment_Sender_Name: Optional[str] = ""
123
- Payment_Reciever_Name: Optional[str] = ""
124
- Swift_Message_Type: str
125
- Text_Sanction_Data: str
126
- Matched_Sanctioned_Entity: str
127
- Is_Match: int
128
- Red_Flag_Reason: str
129
- Risk_Level: str
130
- Risk_Score: float
131
- Risk_Score_Description: str
132
- CDD_Level: str
133
- PEP_Status: str
134
- Value_Date: str
135
- Last_Review_Date: str
136
- Next_Review_Date: str
137
- Sanction_Description: str
138
- Checker_Notes: str
139
- Sanction_Context: str
140
- Maker_Action: str
141
- Customer_ID: int
142
- Customer_Type: str
143
- Industry: str
144
- Transaction_Date_Time: str
145
- Transaction_Type: str
146
- Transaction_Channel: str
147
- Originating_Bank: str
148
- Beneficiary_Bank: str
149
- Geographic_Origin: str
150
- Geographic_Destination: str
151
- Match_Score: float
152
- Match_Type: str
153
- Sanctions_List_Version: str
154
- Screening_Date_Time: str
155
- Risk_Category: str
156
- Risk_Drivers: str
157
- Alert_Status: str
158
- Investigation_Outcome: str
159
- Case_Owner_Analyst: str
160
- Escalation_Level: str
161
- Escalation_Date: str
162
- Regulatory_Reporting_Flags: bool
163
- Audit_Trail_Timestamp: str
164
- Source_Of_Funds: str
165
- Purpose_Of_Transaction: str
166
- Beneficial_Owner: str
167
- Sanctions_Exposure_History: bool
168
-
169
- class PredictionRequest(BaseModel):
170
- transaction_data: TransactionData
171
- model_name: str = "ROBERTA_model" # Default to RoBERTa_model if not specified
172
-
173
- class BatchPredictionResponse(BaseModel):
174
- message: str
175
- predictions: List[Dict[str, Any]]
176
- metrics: Optional[Dict[str, Any]] = None
177
-
178
- @app.get("/")
179
- async def root():
180
- return {"message": "RoBERTa Compliance Predictor API"}
181
-
182
- @app.get("/v1/roberta/health")
183
- async def health_check():
184
- return {"status": "healthy"}
185
-
186
- @app.get("/v1/roberta/training-status")
187
- async def get_training_status():
188
- return training_status
189
-
190
- @app.post("/v1/roberta/train", response_model=TrainingResponse)
191
- async def start_training(
192
- config: str = Form(...),
193
- background_tasks: BackgroundTasks = None,
194
- file: UploadFile = File(...)
195
- ):
196
- if training_status["is_training"]:
197
- raise HTTPException(status_code=400, detail="Training is already in progress")
198
-
199
- if not file.filename.endswith('.csv'):
200
- raise HTTPException(status_code=400, detail="Only CSV files are allowed")
201
-
202
- try:
203
- # Parse the config JSON string into a TrainingConfig object
204
- config_dict = json.loads(config)
205
- training_config = TrainingConfig(**config_dict)
206
- except json.JSONDecodeError:
207
- raise HTTPException(status_code=400, detail="Invalid config JSON format")
208
- except Exception as e:
209
- raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
210
-
211
- file_path = UPLOAD_DIR / file.filename
212
- with file_path.open("wb") as buffer:
213
- shutil.copyfileobj(file.file, buffer)
214
-
215
- training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
216
-
217
- training_status.update({
218
- "is_training": True,
219
- "current_epoch": 0,
220
- "total_epochs": training_config.num_epochs,
221
- "start_time": datetime.now().isoformat(),
222
- "status": "starting"
223
- })
224
-
225
- background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
226
-
227
- download_url = f"/v1/roberta/download-model/{training_id}"
228
-
229
- return TrainingResponse(
230
- message="Training started successfully",
231
- training_id=training_id,
232
- status="started",
233
- download_url=download_url
234
- )
235
-
236
- @app.post("/v1/roberta/validate")
237
- async def validate_model(
238
- file: UploadFile = File(...),
239
- model_name: str = "ROBERTA_model"
240
- ):
241
- """Validate a RoBERTa model on uploaded data"""
242
- if not file.filename.endswith('.csv'):
243
- raise HTTPException(status_code=400, detail="Only CSV files are allowed")
244
-
245
- try:
246
- file_path = UPLOAD_DIR / file.filename
247
- with file_path.open("wb") as buffer:
248
- shutil.copyfileobj(file.file, buffer)
249
-
250
- data_df, label_encoders = load_and_preprocess_data(str(file_path))
251
-
252
- model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
253
- if not model_path.exists():
254
- raise HTTPException(status_code=404, detail="RoBERTa model file not found")
255
-
256
- num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
257
- metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None
258
-
259
- if metadata_df is not None:
260
- metadata_dim = metadata_df.shape[1]
261
- model = RobertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
262
- else:
263
- model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
264
-
265
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
266
- model.eval()
267
-
268
- texts = data_df[TEXT_COLUMN]
269
- labels_array = data_df[LABEL_COLUMNS].values
270
- tokenizer = get_tokenizer("roberta-base")
271
-
272
- if metadata_df is not None:
273
- dataset = ComplianceDatasetWithMetadata(
274
- texts.tolist(),
275
- metadata_df.values,
276
- labels_array,
277
- tokenizer,
278
- MAX_LEN
279
- )
280
- else:
281
- dataset = ComplianceDataset(
282
- texts.tolist(),
283
- labels_array,
284
- tokenizer,
285
- MAX_LEN
286
- )
287
-
288
- dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
289
- metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
290
- summary_metrics = summarize_metrics(metrics).to_dict()
291
-
292
- all_probs = predict_probabilities(model, dataloader)
293
-
294
- predictions = []
295
- for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
296
- field = LABEL_COLUMNS[i]
297
- label_encoder = label_encoders[field]
298
- true_labels_orig = label_encoder.inverse_transform(true_labels)
299
- pred_labels_orig = label_encoder.inverse_transform(pred_labels)
300
-
301
- for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
302
- predictions.append({
303
- "field": field,
304
- "true_label": true,
305
- "predicted_label": pred,
306
- "probabilities": probs.tolist()
307
- })
308
-
309
- return ValidationResponse(
310
- message="Validation completed successfully",
311
- metrics=summary_metrics,
312
- predictions=predictions
313
- )
314
-
315
- except Exception as e:
316
- logger.error(f"Validation failed: {str(e)}")
317
- raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
318
- finally:
319
- if os.path.exists(file_path):
320
- os.remove(file_path)
321
-
322
- @app.post("/v1/roberta/predict")
323
- async def predict(
324
- request: Optional[PredictionRequest] = None,
325
- file: UploadFile = File(None),
326
- model_name: str = "ROBERTA_model"
327
- ):
328
- """
329
- Make predictions on either a single transaction or a batch of transactions from a CSV file.
330
-
331
- You can either:
332
- 1. Send a single transaction in the request body
333
- 2. Upload a CSV file with multiple transactions
334
-
335
- Parameters:
336
- - file: CSV file containing transactions for batch prediction
337
- - model_name: Name of the model to use for prediction (default: "ROBERTA_model")
338
- """
339
- try:
340
- # Load the model
341
- model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
342
- if not model_path.exists():
343
- raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
344
-
345
- num_labels_list = [len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]
346
- model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
347
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
348
- model.eval()
349
-
350
- # Handle batch prediction from CSV
351
- if file and file.filename:
352
- if not file.filename.endswith('.csv'):
353
- raise HTTPException(status_code=400, detail="Only CSV files are allowed")
354
-
355
- file_path = UPLOAD_DIR / file.filename
356
- with file_path.open("wb") as buffer:
357
- shutil.copyfileobj(file.file, buffer)
358
-
359
- try:
360
- # Load and preprocess the CSV data
361
- data_df, _ = load_and_preprocess_data(str(file_path))
362
- texts = data_df[TEXT_COLUMN]
363
-
364
- # Create dataset and dataloader
365
- dataset = ComplianceDataset(
366
- texts.tolist(),
367
- [[0] * len(LABEL_COLUMNS)] * len(texts), # Dummy labels for prediction
368
- tokenizer,
369
- MAX_LEN
370
- )
371
- loader = DataLoader(dataset, batch_size=BATCH_SIZE)
372
-
373
- # Get predictions
374
- all_probabilities = predict_probabilities(model, loader)
375
- label_encoders = load_label_encoders()
376
-
377
- # Process predictions
378
- predictions = []
379
- for i, row in data_df.iterrows():
380
- transaction_pred = {}
381
- for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
382
- pred = np.argmax(probs[i])
383
- decoded_pred = label_encoders[col].inverse_transform([pred])[0]
384
-
385
- class_probs = {
386
- label: float(probs[i][j])
387
- for j, label in enumerate(label_encoders[col].classes_)
388
- }
389
-
390
- transaction_pred[col] = {
391
- "prediction": decoded_pred,
392
- "probabilities": class_probs
393
- }
394
-
395
- predictions.append({
396
- "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
397
- "predictions": transaction_pred
398
- })
399
-
400
- return BatchPredictionResponse(
401
- message="Batch prediction completed successfully",
402
- predictions=predictions
403
- )
404
-
405
- finally:
406
- if os.path.exists(file_path):
407
- os.remove(file_path)
408
-
409
- # Handle single prediction
410
- elif request and request.transaction_data:
411
- input_data = pd.DataFrame([request.transaction_data.dict()])
412
-
413
- text_input = f"<s>Transaction ID: {input_data['Transaction_Id'].iloc[0]} Origin: {input_data['Origin'].iloc[0]} Designation: {input_data['Designation'].iloc[0]} Keywords: {input_data['Keywords'].iloc[0]} Name: {input_data['Name'].iloc[0]} SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]} Currency: {input_data['Currency'].iloc[0]} Entity: {input_data['Entity'].iloc[0]} Message: {input_data['Message'].iloc[0]} City: {input_data['City'].iloc[0]} Country: {input_data['Country'].iloc[0]} State: {input_data['State'].iloc[0]} Hit Type: {input_data['Hit_Type'].iloc[0]} Record Matching String: {input_data['Record_Matching_String'].iloc[0]} WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]} Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]} Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]} Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]} Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]} Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]} Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]} Risk Level: {input_data['Risk_Level'].iloc[0]} Risk Score: {input_data['Risk_Score'].iloc[0]} CDD Level: {input_data['CDD_Level'].iloc[0]} PEP Status: {input_data['PEP_Status'].iloc[0]} Sanction Description: {input_data['Sanction_Description'].iloc[0]} Checker Notes: {input_data['Checker_Notes'].iloc[0]} Sanction Context: {input_data['Sanction_Context'].iloc[0]}</s>"
414
-
415
- dataset = ComplianceDataset(
416
- texts=[text_input],
417
- labels=[[0] * len(LABEL_COLUMNS)],
418
- tokenizer=tokenizer,
419
- max_len=MAX_LEN
420
- )
421
-
422
- loader = DataLoader(dataset, batch_size=1, shuffle=False)
423
- all_probabilities = predict_probabilities(model, loader)
424
-
425
- label_encoders = load_label_encoders()
426
-
427
- response = {}
428
- for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
429
- pred = np.argmax(probs[0])
430
- decoded_pred = label_encoders[col].inverse_transform([pred])[0]
431
-
432
- class_probs = {
433
- label: float(probs[0][j])
434
- for j, label in enumerate(label_encoders[col].classes_)
435
- }
436
-
437
- response[col] = {
438
- "prediction": decoded_pred,
439
- "probabilities": class_probs
440
- }
441
-
442
- return response
443
-
444
- else:
445
- raise HTTPException(
446
- status_code=400,
447
- detail="Either provide a transaction in the request body or upload a CSV file"
448
- )
449
-
450
- except Exception as e:
451
- raise HTTPException(status_code=500, detail=str(e))
452
-
453
- @app.get("/v1/roberta/download-model/{model_id}")
454
- async def download_model(model_id: str):
455
- """Download a trained model"""
456
- model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
457
- if not model_path.exists():
458
- raise HTTPException(status_code=404, detail="Model not found")
459
-
460
- return FileResponse(
461
- path=model_path,
462
- filename=f"roberta_model_{model_id}.pth",
463
- media_type="application/octet-stream"
464
- )
465
-
466
- async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
467
- try:
468
- data_df_original, label_encoders = load_and_preprocess_data(file_path)
469
- save_label_encoders(label_encoders)
470
-
471
- texts = data_df_original[TEXT_COLUMN]
472
- labels_array = data_df_original[LABEL_COLUMNS].values
473
-
474
- metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
475
-
476
- num_labels_list = get_num_labels(label_encoders)
477
- tokenizer = get_tokenizer(config.model_name)
478
-
479
- if metadata_df is not None:
480
- metadata_dim = metadata_df.shape[1]
481
- dataset = ComplianceDatasetWithMetadata(
482
- texts.tolist(),
483
- metadata_df.values,
484
- labels_array,
485
- tokenizer,
486
- config.max_length
487
- )
488
- model = RobertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
489
- else:
490
- dataset = ComplianceDataset(
491
- texts.tolist(),
492
- labels_array,
493
- tokenizer,
494
- config.max_length
495
- )
496
- model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
497
-
498
- train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
499
-
500
- criterions = initialize_criterions(num_labels_list)
501
- optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
502
-
503
- for epoch in range(config.num_epochs):
504
- training_status["current_epoch"] = epoch + 1
505
-
506
- train_loss = train_model(model, train_loader, criterions, optimizer)
507
- training_status["current_loss"] = train_loss
508
-
509
- # Save model after each epoch
510
- save_model(model, training_id)
511
-
512
- training_status.update({
513
- "is_training": False,
514
- "end_time": datetime.now().isoformat(),
515
- "status": "completed"
516
- })
517
-
518
- except Exception as e:
519
- logger.error(f"Training failed: {str(e)}")
520
- training_status.update({
521
- "is_training": False,
522
- "end_time": datetime.now().isoformat(),
523
- "status": "failed",
524
- "error": str(e)
525
- })
526
-
527
- if __name__ == "__main__":
528
- port = int(os.environ.get("PORT", 7860))
529
- uvicorn.run(app, host="0.0.0.0", port=port)
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
2
+ from fastapi.responses import FileResponse
3
+ from pydantic import BaseModel
4
+ from typing import Optional, Dict, Any, List
5
+ import uvicorn
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ import logging
9
+ import os
10
+ import asyncio
11
+ import pandas as pd
12
+ from datetime import datetime
13
+ import shutil
14
+ from pathlib import Path
15
+ from sklearn.model_selection import train_test_split
16
+ import zipfile
17
+ import io
18
+ import numpy as np
19
+ import sys
20
+ import json
21
+
22
+
23
+ # Import existing utilities
24
+ from dataset_utils import (
25
+ ComplianceDataset,
26
+ ComplianceDatasetWithMetadata,
27
+ load_and_preprocess_data,
28
+ get_tokenizer,
29
+ save_label_encoders,
30
+ get_num_labels,
31
+ load_label_encoders
32
+ )
33
+ from train_utils import (
34
+ initialize_criterions,
35
+ train_model,
36
+ evaluate_model,
37
+ save_model,
38
+ summarize_metrics,
39
+ predict_probabilities
40
+ )
41
+ from models.roberta_model import RobertaMultiOutputModel
42
+ from config import (
43
+ TEXT_COLUMN,
44
+ LABEL_COLUMNS,
45
+ DEVICE,
46
+ NUM_EPOCHS,
47
+ LEARNING_RATE,
48
+ MAX_LEN,
49
+ BATCH_SIZE,
50
+ METADATA_COLUMNS
51
+ )
52
+
53
+ # Configure logging
54
+ logging.basicConfig(level=logging.INFO)
55
+ logger = logging.getLogger(__name__)
56
+
57
+ app = FastAPI(title="RoBERTa Compliance Predictor API")
58
+
59
+ # Create necessary directories
60
+ UPLOAD_DIR = Path("uploads")
61
+ MODEL_SAVE_DIR = Path("saved_models")
62
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
63
+ MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
64
+
65
+ # Global variables to track training status
66
+ training_status = {
67
+ "is_training": False,
68
+ "current_epoch": 0,
69
+ "total_epochs": 0,
70
+ "current_loss": 0.0,
71
+ "start_time": None,
72
+ "end_time": None,
73
+ "status": "idle",
74
+ "metrics": None
75
+ }
76
+
77
+ # Load the model and tokenizer for prediction
78
+ model_path = MODEL_SAVE_DIR / "ROBERTA_model.pth"
79
+ tokenizer = get_tokenizer('roberta-base')
80
+
81
+ # Initialize model and label encoders with error handling
82
+ try:
83
+ label_encoders = load_label_encoders()
84
+ model = RobertaMultiOutputModel([len(label_encoders[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
85
+ if model_path.exists():
86
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
87
+ model.eval()
88
+ else:
89
+ print(f"Warning: Model file {model_path} not found. Model will be initialized but not loaded.")
90
+ except Exception as e:
91
+ print(f"Warning: Could not load label encoders or model: {str(e)}")
92
+ print("Model will be initialized when training starts.")
93
+ model = None
94
+
95
+ class TrainingConfig(BaseModel):
96
+ model_name: str = "roberta-base"
97
+ batch_size: int = 8
98
+ learning_rate: float = 2e-5
99
+ num_epochs: int = 2
100
+ max_length: int = 128
101
+ random_state: int = 42
102
+
103
+ class TrainingResponse(BaseModel):
104
+ message: str
105
+ training_id: str
106
+ status: str
107
+ download_url: Optional[str] = None
108
+
109
+ class ValidationResponse(BaseModel):
110
+ message: str
111
+ metrics: Dict[str, Any]
112
+ predictions: List[Dict[str, Any]]
113
+
114
+ class TransactionData(BaseModel):
115
+ Transaction_Id: str
116
+ Hit_Seq: int
117
+ Hit_Id_List: str
118
+ Origin: str
119
+ Designation: str
120
+ Keywords: str
121
+ Name: str
122
+ SWIFT_Tag: str
123
+ Currency: str
124
+ Entity: str
125
+ Message: str
126
+ City: str
127
+ Country: str
128
+ State: str
129
+ Hit_Type: str
130
+ Record_Matching_String: str
131
+ WatchList_Match_String: str
132
+ Payment_Sender_Name: Optional[str] = ""
133
+ Payment_Reciever_Name: Optional[str] = ""
134
+ Swift_Message_Type: str
135
+ Text_Sanction_Data: str
136
+ Matched_Sanctioned_Entity: str
137
+ Is_Match: int
138
+ Red_Flag_Reason: str
139
+ Risk_Level: str
140
+ Risk_Score: float
141
+ Risk_Score_Description: str
142
+ CDD_Level: str
143
+ PEP_Status: str
144
+ Value_Date: str
145
+ Last_Review_Date: str
146
+ Next_Review_Date: str
147
+ Sanction_Description: str
148
+ Checker_Notes: str
149
+ Sanction_Context: str
150
+ Maker_Action: str
151
+ Customer_ID: int
152
+ Customer_Type: str
153
+ Industry: str
154
+ Transaction_Date_Time: str
155
+ Transaction_Type: str
156
+ Transaction_Channel: str
157
+ Originating_Bank: str
158
+ Beneficiary_Bank: str
159
+ Geographic_Origin: str
160
+ Geographic_Destination: str
161
+ Match_Score: float
162
+ Match_Type: str
163
+ Sanctions_List_Version: str
164
+ Screening_Date_Time: str
165
+ Risk_Category: str
166
+ Risk_Drivers: str
167
+ Alert_Status: str
168
+ Investigation_Outcome: str
169
+ Case_Owner_Analyst: str
170
+ Escalation_Level: str
171
+ Escalation_Date: str
172
+ Regulatory_Reporting_Flags: bool
173
+ Audit_Trail_Timestamp: str
174
+ Source_Of_Funds: str
175
+ Purpose_Of_Transaction: str
176
+ Beneficial_Owner: str
177
+ Sanctions_Exposure_History: bool
178
+
179
+ class PredictionRequest(BaseModel):
180
+ transaction_data: TransactionData
181
+ model_name: str = "ROBERTA_model" # Default to RoBERTa_model if not specified
182
+
183
+ class BatchPredictionResponse(BaseModel):
184
+ message: str
185
+ predictions: List[Dict[str, Any]]
186
+ metrics: Optional[Dict[str, Any]] = None
187
+
188
+ @app.get("/")
189
+ async def root():
190
+ return {"message": "RoBERTa Compliance Predictor API"}
191
+
192
+ @app.get("/v1/roberta/health")
193
+ async def health_check():
194
+ return {"status": "healthy"}
195
+
196
+ @app.get("/v1/roberta/training-status")
197
+ async def get_training_status():
198
+ return training_status
199
+
200
+ @app.post("/v1/roberta/train", response_model=TrainingResponse)
201
+ async def start_training(
202
+ config: str = Form(...),
203
+ background_tasks: BackgroundTasks = None,
204
+ file: UploadFile = File(...)
205
+ ):
206
+ if training_status["is_training"]:
207
+ raise HTTPException(status_code=400, detail="Training is already in progress")
208
+
209
+ if not file.filename.endswith('.csv'):
210
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
211
+
212
+ try:
213
+ # Parse the config JSON string into a TrainingConfig object
214
+ config_dict = json.loads(config)
215
+ training_config = TrainingConfig(**config_dict)
216
+ except json.JSONDecodeError:
217
+ raise HTTPException(status_code=400, detail="Invalid config JSON format")
218
+ except Exception as e:
219
+ raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
220
+
221
+ file_path = UPLOAD_DIR / file.filename
222
+ with file_path.open("wb") as buffer:
223
+ shutil.copyfileobj(file.file, buffer)
224
+
225
+ training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
226
+
227
+ training_status.update({
228
+ "is_training": True,
229
+ "current_epoch": 0,
230
+ "total_epochs": training_config.num_epochs,
231
+ "start_time": datetime.now().isoformat(),
232
+ "status": "starting"
233
+ })
234
+
235
+ background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
236
+
237
+ download_url = f"/v1/roberta/download-model/{training_id}"
238
+
239
+ return TrainingResponse(
240
+ message="Training started successfully",
241
+ training_id=training_id,
242
+ status="started",
243
+ download_url=download_url
244
+ )
245
+
246
+ @app.post("/v1/roberta/validate")
247
+ async def validate_model(
248
+ file: UploadFile = File(...),
249
+ model_name: str = "ROBERTA_model"
250
+ ):
251
+ """Validate a RoBERTa model on uploaded data"""
252
+ if not file.filename.endswith('.csv'):
253
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
254
+
255
+ try:
256
+ file_path = UPLOAD_DIR / file.filename
257
+ with file_path.open("wb") as buffer:
258
+ shutil.copyfileobj(file.file, buffer)
259
+
260
+ data_df, label_encoders = load_and_preprocess_data(str(file_path))
261
+
262
+ model_path = MODEL_SAVE_DIR / f"{model_name}_model.pth"
263
+ if not model_path.exists():
264
+ raise HTTPException(status_code=404, detail="RoBERTa model file not found")
265
+
266
+ num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
267
+ metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None
268
+
269
+ if metadata_df is not None:
270
+ metadata_dim = metadata_df.shape[1]
271
+ model = RobertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
272
+ else:
273
+ model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
274
+
275
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
276
+ model.eval()
277
+
278
+ texts = data_df[TEXT_COLUMN]
279
+ labels_array = data_df[LABEL_COLUMNS].values
280
+ tokenizer = get_tokenizer("roberta-base")
281
+
282
+ if metadata_df is not None:
283
+ dataset = ComplianceDatasetWithMetadata(
284
+ texts.tolist(),
285
+ metadata_df.values,
286
+ labels_array,
287
+ tokenizer,
288
+ MAX_LEN
289
+ )
290
+ else:
291
+ dataset = ComplianceDataset(
292
+ texts.tolist(),
293
+ labels_array,
294
+ tokenizer,
295
+ MAX_LEN
296
+ )
297
+
298
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
299
+ metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
300
+ summary_metrics = summarize_metrics(metrics).to_dict()
301
+
302
+ all_probs = predict_probabilities(model, dataloader)
303
+
304
+ predictions = []
305
+ for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
306
+ field = LABEL_COLUMNS[i]
307
+ label_encoder = label_encoders[field]
308
+ true_labels_orig = label_encoder.inverse_transform(true_labels)
309
+ pred_labels_orig = label_encoder.inverse_transform(pred_labels)
310
+
311
+ for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
312
+ predictions.append({
313
+ "field": field,
314
+ "true_label": true,
315
+ "predicted_label": pred,
316
+ "probabilities": probs.tolist()
317
+ })
318
+
319
+ return ValidationResponse(
320
+ message="Validation completed successfully",
321
+ metrics=summary_metrics,
322
+ predictions=predictions
323
+ )
324
+
325
+ except Exception as e:
326
+ logger.error(f"Validation failed: {str(e)}")
327
+ raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
328
+ finally:
329
+ if os.path.exists(file_path):
330
+ os.remove(file_path)
331
+
332
+ @app.post("/v1/roberta/predict")
333
+ async def predict(
334
+ request: Optional[PredictionRequest] = None,
335
+ file: UploadFile = File(None),
336
+ model_name: str = "ROBERTA_model"
337
+ ):
338
+ """
339
+ Make predictions on either a single transaction or a batch of transactions from a CSV file.
340
+
341
+ You can either:
342
+ 1. Send a single transaction in the request body
343
+ 2. Upload a CSV file with multiple transactions
344
+
345
+ Parameters:
346
+ - file: CSV file containing transactions for batch prediction
347
+ - model_name: Name of the model to use for prediction (default: "ROBERTA_model")
348
+ """
349
+ try:
350
+ # Load the model
351
+ model_path = MODEL_SAVE_DIR / f"{model_name}_model.pth"
352
+ if not model_path.exists():
353
+ raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
354
+
355
+ # Load label encoders
356
+ try:
357
+ label_encoders = load_label_encoders()
358
+ num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
359
+ except Exception as e:
360
+ raise HTTPException(status_code=500, detail=f"Could not load label encoders: {str(e)}")
361
+
362
+ model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
363
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
364
+ model.eval()
365
+
366
+ # Handle batch prediction from CSV
367
+ if file and file.filename:
368
+ if not file.filename.endswith('.csv'):
369
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
370
+
371
+ file_path = UPLOAD_DIR / file.filename
372
+ with file_path.open("wb") as buffer:
373
+ shutil.copyfileobj(file.file, buffer)
374
+
375
+ try:
376
+ # Load and preprocess the CSV data
377
+ data_df, _ = load_and_preprocess_data(str(file_path))
378
+ texts = data_df[TEXT_COLUMN]
379
+
380
+ # Create dataset and dataloader
381
+ dataset = ComplianceDataset(
382
+ texts.tolist(),
383
+ [[0] * len(LABEL_COLUMNS)] * len(texts), # Dummy labels for prediction
384
+ tokenizer,
385
+ MAX_LEN
386
+ )
387
+ loader = DataLoader(dataset, batch_size=BATCH_SIZE)
388
+
389
+ # Get predictions
390
+ all_probabilities = predict_probabilities(model, loader)
391
+
392
+ # Process predictions
393
+ predictions = []
394
+ for i, row in data_df.iterrows():
395
+ transaction_pred = {}
396
+ for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
397
+ pred = np.argmax(probs[i])
398
+ decoded_pred = label_encoders[col].inverse_transform([pred])[0]
399
+
400
+ class_probs = {
401
+ label: float(probs[i][j])
402
+ for j, label in enumerate(label_encoders[col].classes_)
403
+ }
404
+
405
+ transaction_pred[col] = {
406
+ "prediction": decoded_pred,
407
+ "probabilities": class_probs
408
+ }
409
+
410
+ predictions.append({
411
+ "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
412
+ "predictions": transaction_pred
413
+ })
414
+
415
+ return BatchPredictionResponse(
416
+ message="Batch prediction completed successfully",
417
+ predictions=predictions
418
+ )
419
+
420
+ finally:
421
+ if os.path.exists(file_path):
422
+ os.remove(file_path)
423
+
424
+ # Handle single prediction
425
+ elif request and request.transaction_data:
426
+ input_data = pd.DataFrame([request.transaction_data.dict()])
427
+
428
+ text_input = f"<s>Transaction ID: {input_data['Transaction_Id'].iloc[0]} Origin: {input_data['Origin'].iloc[0]} Designation: {input_data['Designation'].iloc[0]} Keywords: {input_data['Keywords'].iloc[0]} Name: {input_data['Name'].iloc[0]} SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]} Currency: {input_data['Currency'].iloc[0]} Entity: {input_data['Entity'].iloc[0]} Message: {input_data['Message'].iloc[0]} City: {input_data['City'].iloc[0]} Country: {input_data['Country'].iloc[0]} State: {input_data['State'].iloc[0]} Hit Type: {input_data['Hit_Type'].iloc[0]} Record Matching String: {input_data['Record_Matching_String'].iloc[0]} WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]} Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]} Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]} Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]} Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]} Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]} Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]} Risk Level: {input_data['Risk_Level'].iloc[0]} Risk Score: {input_data['Risk_Score'].iloc[0]} CDD Level: {input_data['CDD_Level'].iloc[0]} PEP Status: {input_data['PEP_Status'].iloc[0]} Sanction Description: {input_data['Sanction_Description'].iloc[0]} Checker Notes: {input_data['Checker_Notes'].iloc[0]} Sanction Context: {input_data['Sanction_Context'].iloc[0]}</s>"
429
+
430
+ dataset = ComplianceDataset(
431
+ texts=[text_input],
432
+ labels=[[0] * len(LABEL_COLUMNS)],
433
+ tokenizer=tokenizer,
434
+ max_len=MAX_LEN
435
+ )
436
+
437
+ loader = DataLoader(dataset, batch_size=1, shuffle=False)
438
+ all_probabilities = predict_probabilities(model, loader)
439
+
440
+ response = {}
441
+ for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
442
+ pred = np.argmax(probs[0])
443
+ decoded_pred = label_encoders[col].inverse_transform([pred])[0]
444
+
445
+ class_probs = {
446
+ label: float(probs[0][j])
447
+ for j, label in enumerate(label_encoders[col].classes_)
448
+ }
449
+
450
+ response[col] = {
451
+ "prediction": decoded_pred,
452
+ "probabilities": class_probs
453
+ }
454
+
455
+ return response
456
+
457
+ else:
458
+ raise HTTPException(
459
+ status_code=400,
460
+ detail="Either provide a transaction in the request body or upload a CSV file"
461
+ )
462
+
463
+ except Exception as e:
464
+ raise HTTPException(status_code=500, detail=str(e))
465
+
466
+ @app.get("/v1/roberta/download-model/{model_id}")
467
+ async def download_model(model_id: str):
468
+ """Download a trained model"""
469
+ model_path = MODEL_SAVE_DIR / f"{model_id}_model.pth"
470
+ if not model_path.exists():
471
+ raise HTTPException(status_code=404, detail="Model not found")
472
+
473
+ return FileResponse(
474
+ path=model_path,
475
+ filename=f"roberta_model_{model_id}.pth",
476
+ media_type="application/octet-stream"
477
+ )
478
+
479
+ async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
480
+ try:
481
+ data_df_original, label_encoders = load_and_preprocess_data(file_path)
482
+ save_label_encoders(label_encoders)
483
+
484
+ texts = data_df_original[TEXT_COLUMN]
485
+ labels_array = data_df_original[LABEL_COLUMNS].values
486
+
487
+ metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
488
+
489
+ num_labels_list = get_num_labels(label_encoders)
490
+ tokenizer = get_tokenizer(config.model_name)
491
+
492
+ if metadata_df is not None:
493
+ metadata_dim = metadata_df.shape[1]
494
+ dataset = ComplianceDatasetWithMetadata(
495
+ texts.tolist(),
496
+ metadata_df.values,
497
+ labels_array,
498
+ tokenizer,
499
+ config.max_length
500
+ )
501
+ model = RobertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
502
+ else:
503
+ dataset = ComplianceDataset(
504
+ texts.tolist(),
505
+ labels_array,
506
+ tokenizer,
507
+ config.max_length
508
+ )
509
+ model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
510
+
511
+ train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
512
+
513
+ criterions = initialize_criterions(num_labels_list)
514
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
515
+
516
+ for epoch in range(config.num_epochs):
517
+ training_status["current_epoch"] = epoch + 1
518
+
519
+ train_loss = train_model(model, train_loader, criterions, optimizer)
520
+ training_status["current_loss"] = train_loss
521
+
522
+ # Save model after each epoch
523
+ save_model(model, training_id, 'pth')
524
+
525
+ training_status.update({
526
+ "is_training": False,
527
+ "end_time": datetime.now().isoformat(),
528
+ "status": "completed"
529
+ })
530
+
531
+ except Exception as e:
532
+ logger.error(f"Training failed: {str(e)}")
533
+ training_status.update({
534
+ "is_training": False,
535
+ "end_time": datetime.now().isoformat(),
536
+ "status": "failed",
537
+ "error": str(e)
538
+ })
539
+
540
+ if __name__ == "__main__":
541
+ port = int(os.environ.get("PORT", 7860))
542
+ uvicorn.run(app, host="0.0.0.0", port=port)
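For context, a minimal client sketch for the updated endpoints might look like the following; it is not part of the commit. It assumes the app is running locally on its default port 7860, that local files train.csv and transactions.csv exist with the columns the API expects, that a model saved under the name ROBERTA_model is available, and that the requests library is installed; the base URL and file names are illustrative.

# Hypothetical client sketch for the RoBERTa Compliance Predictor API (assumptions noted above).
import json
import requests

BASE_URL = "http://localhost:7860/v1/roberta"

# Start training: /train expects a multipart form with a JSON-encoded "config"
# field (parsed into TrainingConfig) and a CSV upload named "file".
train_config = {"model_name": "roberta-base", "batch_size": 8, "num_epochs": 2}
with open("train.csv", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/train",
        data={"config": json.dumps(train_config)},
        files={"file": ("train.csv", f, "text/csv")},
    )
resp.raise_for_status()
print(resp.json())  # includes training_id, status, and download_url

# Poll progress via the shared training_status dictionary.
print(requests.get(f"{BASE_URL}/training-status").json())

# Batch prediction: upload a CSV; the response lists per-transaction
# predictions and class probabilities for each label column.
with open("transactions.csv", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/predict",
        params={"model_name": "ROBERTA_model"},
        files={"file": ("transactions.csv", f, "text/csv")},
    )
resp.raise_for_status()
print(resp.json()["predictions"][:1])

The config field is sent as a JSON string because /v1/roberta/train parses it with json.loads before validating it as a TrainingConfig, and model_name is read as a query parameter on /v1/roberta/predict.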