namanpenguin committed
Commit da6cfe2 · verified · 1 Parent(s): 066dbd4

Update app.py

Files changed (1)
1. app.py +600 -504
app.py CHANGED
@@ -1,504 +1,600 @@
 from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
 from fastapi.responses import FileResponse
 from pydantic import BaseModel
 from typing import Optional, Dict, Any, List
 import uvicorn
 import torch
 from transformers import BertTokenizer, BertForSequenceClassification
 from torch.utils.data import DataLoader
 import logging
 import os
 import asyncio
 import pandas as pd
 from datetime import datetime
 import shutil
 from pathlib import Path
 from sklearn.model_selection import train_test_split
 import zipfile
 import io
 import numpy as np
 import sys
 
 
 # Import existing utilities
 from dataset_utils import (
     ComplianceDataset,
     ComplianceDatasetWithMetadata,
     load_and_preprocess_data,
     get_tokenizer,
     save_label_encoders,
     get_num_labels,
     load_label_encoders
 )
 from train_utils import (
     initialize_criterions,
     train_model,
     evaluate_model,
     save_model,
     summarize_metrics,
     predict_probabilities
 )
 from models.bert_model import BertMultiOutputModel
 from config import (
     TEXT_COLUMN,
     LABEL_COLUMNS,
     DEVICE,
     NUM_EPOCHS,
     LEARNING_RATE,
     MAX_LEN,
     BATCH_SIZE,
     METADATA_COLUMNS
 )
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 app = FastAPI(title="BERT Compliance Predictor API")
 
 # Create necessary directories
 UPLOAD_DIR = Path("uploads")
 MODEL_SAVE_DIR = Path("saved_models")
 UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
 MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
 
 # Global variables to track training status
 training_status = {
     "is_training": False,
     "current_epoch": 0,
     "total_epochs": 0,
     "current_loss": 0.0,
     "start_time": None,
     "end_time": None,
     "status": "idle",
     "metrics": None
 }
 
 # Load the model and tokenizer for prediction
 model_path = "BERT_model.pth"
 tokenizer = get_tokenizer('bert-base-uncased')
 model = BertMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
 if os.path.exists(model_path):
     model.load_state_dict(torch.load(model_path, map_location=DEVICE))
 model.eval()
 
 class TrainingConfig(BaseModel):
     model_name: str = "bert-base-uncased"
     batch_size: int = 8
     learning_rate: float = 2e-5
     num_epochs: int = 2
     max_length: int = 128
     test_size: float = 0.2
     random_state: int = 42
 
 class TrainingResponse(BaseModel):
     message: str
     training_id: str
     status: str
     download_url: Optional[str] = None
 
 class ValidationResponse(BaseModel):
     message: str
     metrics: Dict[str, Any]
     predictions: List[Dict[str, Any]]
 
 class TransactionData(BaseModel):
     Transaction_Id: str
     Hit_Seq: int
     Hit_Id_List: str
     Origin: str
     Designation: str
     Keywords: str
     Name: str
     SWIFT_Tag: str
     Currency: str
     Entity: str
     Message: str
     City: str
     Country: str
     State: str
     Hit_Type: str
     Record_Matching_String: str
     WatchList_Match_String: str
     Payment_Sender_Name: Optional[str] = ""
     Payment_Reciever_Name: Optional[str] = ""
     Swift_Message_Type: str
     Text_Sanction_Data: str
     Matched_Sanctioned_Entity: str
     Is_Match: int
     Red_Flag_Reason: str
     Risk_Level: str
     Risk_Score: float
     Risk_Score_Description: str
     CDD_Level: str
     PEP_Status: str
     Value_Date: str
     Last_Review_Date: str
     Next_Review_Date: str
     Sanction_Description: str
     Checker_Notes: str
     Sanction_Context: str
     Maker_Action: str
     Customer_ID: int
     Customer_Type: str
     Industry: str
     Transaction_Date_Time: str
     Transaction_Type: str
     Transaction_Channel: str
     Originating_Bank: str
     Beneficiary_Bank: str
     Geographic_Origin: str
     Geographic_Destination: str
     Match_Score: float
     Match_Type: str
     Sanctions_List_Version: str
     Screening_Date_Time: str
     Risk_Category: str
     Risk_Drivers: str
     Alert_Status: str
     Investigation_Outcome: str
     Case_Owner_Analyst: str
     Escalation_Level: str
     Escalation_Date: str
     Regulatory_Reporting_Flags: bool
     Audit_Trail_Timestamp: str
     Source_Of_Funds: str
     Purpose_Of_Transaction: str
     Beneficial_Owner: str
     Sanctions_Exposure_History: bool
 
 class PredictionRequest(BaseModel):
     transaction_data: TransactionData
+    model_name: str = "BERT_model" # Default to BERT_model if not specified
 
+class BatchPredictionResponse(BaseModel):
+    message: str
+    predictions: List[Dict[str, Any]]
+    metrics: Optional[Dict[str, Any]] = None
+
 @app.get("/")
 async def root():
     return {"message": "BERT Compliance Predictor API"}
 
 @app.get("/health")
 async def health_check():
     return {"status": "healthy"}
 
 @app.get("/training-status")
 async def get_training_status():
     return training_status
 
 @app.post("/upload")
 async def upload_file(file: UploadFile = File(...)):
     """Upload a CSV file for training or validation"""
     if not file.filename.endswith('.csv'):
         raise HTTPException(status_code=400, detail="Only CSV files are allowed")
 
     file_path = UPLOAD_DIR / file.filename
     with file_path.open("wb") as buffer:
         shutil.copyfileobj(file.file, buffer)
 
     return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
 
 @app.post("/bert/train", response_model=TrainingResponse)
 async def start_training(
     config: TrainingConfig,
     background_tasks: BackgroundTasks,
     file_path: str
 ):
     if training_status["is_training"]:
         raise HTTPException(status_code=400, detail="Training is already in progress")
 
     if not os.path.exists(file_path):
         raise HTTPException(status_code=404, detail="Training file not found")
 
     training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
 
     training_status.update({
         "is_training": True,
         "current_epoch": 0,
         "total_epochs": config.num_epochs,
         "start_time": datetime.now().isoformat(),
         "status": "starting"
     })
 
     background_tasks.add_task(train_model_task, config, file_path, training_id)
 
     download_url = f"/bert/download-model/{training_id}"
 
     return TrainingResponse(
         message="Training started successfully",
         training_id=training_id,
         status="started",
         download_url=download_url
     )
 
 @app.post("/bert/validate")
 async def validate_model(
     file: UploadFile = File(...),
-    model_name: str = "bert_model_latest"
+    model_name: str = "BERT_model"
 ):
     """Validate a BERT model on uploaded data"""
     if not file.filename.endswith('.csv'):
         raise HTTPException(status_code=400, detail="Only CSV files are allowed")
 
     try:
         file_path = UPLOAD_DIR / file.filename
         with file_path.open("wb") as buffer:
             shutil.copyfileobj(file.file, buffer)
 
         data_df, label_encoders = load_and_preprocess_data(str(file_path))
 
         model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
         if not model_path.exists():
             raise HTTPException(status_code=404, detail="BERT model file not found")
 
         num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
         metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None
 
         if metadata_df is not None:
             metadata_dim = metadata_df.shape[1]
             model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
         else:
             model = BertMultiOutputModel(num_labels_list).to(DEVICE)
 
         model.load_state_dict(torch.load(model_path, map_location=DEVICE))
         model.eval()
 
         texts = data_df[TEXT_COLUMN]
         labels_array = data_df[LABEL_COLUMNS].values
         tokenizer = get_tokenizer("bert-base-uncased")
 
         if metadata_df is not None:
             dataset = ComplianceDatasetWithMetadata(
                 texts.tolist(),
                 metadata_df.values,
                 labels_array,
                 tokenizer,
                 MAX_LEN
             )
         else:
             dataset = ComplianceDataset(
                 texts.tolist(),
                 labels_array,
                 tokenizer,
                 MAX_LEN
             )
 
         dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
         metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
         summary_metrics = summarize_metrics(metrics).to_dict()
 
         all_probs = predict_probabilities(model, dataloader)
 
         predictions = []
         for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
             field = LABEL_COLUMNS[i]
             label_encoder = label_encoders[field]
             true_labels_orig = label_encoder.inverse_transform(true_labels)
             pred_labels_orig = label_encoder.inverse_transform(pred_labels)
 
             for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
                 predictions.append({
                     "field": field,
                     "true_label": true,
                     "predicted_label": pred,
                     "probabilities": probs.tolist()
                 })
 
         return ValidationResponse(
             message="Validation completed successfully",
             metrics=summary_metrics,
             predictions=predictions
         )
 
     except Exception as e:
         logger.error(f"Validation failed: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
     finally:
         if os.path.exists(file_path):
             os.remove(file_path)
 
 @app.post("/bert/predict")
-async def predict(request: PredictionRequest):
-    """Make predictions on a single transaction"""
-    try:
-        input_data = pd.DataFrame([request.transaction_data.dict()])
-
-        text_input = f"""
-        Transaction ID: {input_data['Transaction_Id'].iloc[0]}
-        Origin: {input_data['Origin'].iloc[0]}
-        Designation: {input_data['Designation'].iloc[0]}
-        Keywords: {input_data['Keywords'].iloc[0]}
-        Name: {input_data['Name'].iloc[0]}
-        SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]}
-        Currency: {input_data['Currency'].iloc[0]}
-        Entity: {input_data['Entity'].iloc[0]}
-        Message: {input_data['Message'].iloc[0]}
-        City: {input_data['City'].iloc[0]}
-        Country: {input_data['Country'].iloc[0]}
-        State: {input_data['State'].iloc[0]}
-        Hit Type: {input_data['Hit_Type'].iloc[0]}
-        Record Matching String: {input_data['Record_Matching_String'].iloc[0]}
-        WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]}
-        Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]}
-        Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]}
-        Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]}
-        Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]}
-        Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]}
-        Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]}
-        Risk Level: {input_data['Risk_Level'].iloc[0]}
-        Risk Score: {input_data['Risk_Score'].iloc[0]}
-        CDD Level: {input_data['CDD_Level'].iloc[0]}
-        PEP Status: {input_data['PEP_Status'].iloc[0]}
-        Sanction Description: {input_data['Sanction_Description'].iloc[0]}
-        Checker Notes: {input_data['Checker_Notes'].iloc[0]}
-        Sanction Context: {input_data['Sanction_Context'].iloc[0]}
-        Maker Action: {input_data['Maker_Action'].iloc[0]}
-        Customer Type: {input_data['Customer_Type'].iloc[0]}
-        Industry: {input_data['Industry'].iloc[0]}
-        Transaction Type: {input_data['Transaction_Type'].iloc[0]}
-        Transaction Channel: {input_data['Transaction_Channel'].iloc[0]}
-        Geographic Origin: {input_data['Geographic_Origin'].iloc[0]}
-        Geographic Destination: {input_data['Geographic_Destination'].iloc[0]}
-        Risk Category: {input_data['Risk_Category'].iloc[0]}
-        Risk Drivers: {input_data['Risk_Drivers'].iloc[0]}
-        Alert Status: {input_data['Alert_Status'].iloc[0]}
-        Investigation Outcome: {input_data['Investigation_Outcome'].iloc[0]}
-        Source of Funds: {input_data['Source_Of_Funds'].iloc[0]}
-        Purpose of Transaction: {input_data['Purpose_Of_Transaction'].iloc[0]}
-        Beneficial Owner: {input_data['Beneficial_Owner'].iloc[0]}
-        """
-
-        dataset = ComplianceDataset(
-            texts=[text_input],
-            labels=[[0] * len(LABEL_COLUMNS)],
-            tokenizer=tokenizer,
-            max_len=MAX_LEN
-        )
-
-        loader = DataLoader(dataset, batch_size=1, shuffle=False)
-        all_probabilities = predict_probabilities(model, loader)
-
-        label_encoders = load_label_encoders()
-
-        response = {}
-        for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
-            pred = np.argmax(probs[0])
-            decoded_pred = label_encoders[col].inverse_transform([pred])[0]
-
-            class_probs = {
-                label: float(probs[0][j])
-                for j, label in enumerate(label_encoders[col].classes_)
-            }
-
-            response[col] = {
-                "prediction": decoded_pred,
-                "probabilities": class_probs
-            }
-
-        return response
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+async def predict(
+    request: Optional[PredictionRequest] = None,
+    file: Optional[UploadFile] = File(None),
+    model_name: str = "BERT_model"
+):
+    """
+    Make predictions on either a single transaction or a batch of transactions from a CSV file.
+
+    You can either:
+    1. Send a single transaction in the request body
+    2. Upload a CSV file with multiple transactions
+
+    Parameters:
+    - model_name: Name of the model to use for prediction (default: "BERT_model")
+    """
+    try:
+        # Load the model
+        model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
+        if not model_path.exists():
+            raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
+
+        num_labels_list = [len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]
+        model = BertMultiOutputModel(num_labels_list).to(DEVICE)
+        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+        model.eval()
+
+        # Handle batch prediction from CSV
+        if file is not None and file.filename:
+            if not file.filename.endswith('.csv'):
+                raise HTTPException(status_code=400, detail="Only CSV files are allowed")
+
+            file_path = UPLOAD_DIR / file.filename
+            with file_path.open("wb") as buffer:
+                shutil.copyfileobj(file.file, buffer)
+
+            try:
+                # Load and preprocess the CSV data
+                data_df, _ = load_and_preprocess_data(str(file_path))
+                texts = data_df[TEXT_COLUMN]
+
+                # Create dataset and dataloader
+                dataset = ComplianceDataset(
+                    texts.tolist(),
+                    [[0] * len(LABEL_COLUMNS)] * len(texts), # Dummy labels for prediction
+                    tokenizer,
+                    MAX_LEN
+                )
+                loader = DataLoader(dataset, batch_size=BATCH_SIZE)
+
+                # Get predictions
+                all_probabilities = predict_probabilities(model, loader)
+                label_encoders = load_label_encoders()
+
+                # Process predictions
+                predictions = []
+                for i, row in data_df.iterrows():
+                    transaction_pred = {}
+                    for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
+                        pred = np.argmax(probs[i])
+                        decoded_pred = label_encoders[col].inverse_transform([pred])[0]
+
+                        class_probs = {
+                            label: float(probs[i][j])
+                            for j, label in enumerate(label_encoders[col].classes_)
+                        }
+
+                        transaction_pred[col] = {
+                            "prediction": decoded_pred,
+                            "probabilities": class_probs
+                        }
+
+                    predictions.append({
+                        "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
+                        "predictions": transaction_pred
+                    })
+
+                return BatchPredictionResponse(
+                    message="Batch prediction completed successfully",
+                    predictions=predictions
+                )
+
+            finally:
+                if os.path.exists(file_path):
+                    os.remove(file_path)
+
+        # Handle single prediction
+        elif request is not None and request.transaction_data:
+            input_data = pd.DataFrame([request.transaction_data.dict()])
+
+            text_input = f"""
+            Transaction ID: {input_data['Transaction_Id'].iloc[0]}
+            Origin: {input_data['Origin'].iloc[0]}
+            Designation: {input_data['Designation'].iloc[0]}
+            Keywords: {input_data['Keywords'].iloc[0]}
+            Name: {input_data['Name'].iloc[0]}
+            SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]}
+            Currency: {input_data['Currency'].iloc[0]}
+            Entity: {input_data['Entity'].iloc[0]}
+            Message: {input_data['Message'].iloc[0]}
+            City: {input_data['City'].iloc[0]}
+            Country: {input_data['Country'].iloc[0]}
+            State: {input_data['State'].iloc[0]}
+            Hit Type: {input_data['Hit_Type'].iloc[0]}
+            Record Matching String: {input_data['Record_Matching_String'].iloc[0]}
+            WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]}
+            Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]}
+            Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]}
+            Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]}
+            Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]}
+            Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]}
+            Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]}
+            Risk Level: {input_data['Risk_Level'].iloc[0]}
+            Risk Score: {input_data['Risk_Score'].iloc[0]}
+            CDD Level: {input_data['CDD_Level'].iloc[0]}
+            PEP Status: {input_data['PEP_Status'].iloc[0]}
+            Sanction Description: {input_data['Sanction_Description'].iloc[0]}
+            Checker Notes: {input_data['Checker_Notes'].iloc[0]}
+            Sanction Context: {input_data['Sanction_Context'].iloc[0]}
+            Maker Action: {input_data['Maker_Action'].iloc[0]}
+            Customer Type: {input_data['Customer_Type'].iloc[0]}
+            Industry: {input_data['Industry'].iloc[0]}
+            Transaction Type: {input_data['Transaction_Type'].iloc[0]}
+            Transaction Channel: {input_data['Transaction_Channel'].iloc[0]}
+            Geographic Origin: {input_data['Geographic_Origin'].iloc[0]}
+            Geographic Destination: {input_data['Geographic_Destination'].iloc[0]}
+            Risk Category: {input_data['Risk_Category'].iloc[0]}
+            Risk Drivers: {input_data['Risk_Drivers'].iloc[0]}
+            Alert Status: {input_data['Alert_Status'].iloc[0]}
+            Investigation Outcome: {input_data['Investigation_Outcome'].iloc[0]}
+            Source of Funds: {input_data['Source_Of_Funds'].iloc[0]}
+            Purpose of Transaction: {input_data['Purpose_Of_Transaction'].iloc[0]}
+            Beneficial Owner: {input_data['Beneficial_Owner'].iloc[0]}
+            """
+
+            dataset = ComplianceDataset(
+                texts=[text_input],
+                labels=[[0] * len(LABEL_COLUMNS)],
+                tokenizer=tokenizer,
+                max_len=MAX_LEN
+            )
+
+            loader = DataLoader(dataset, batch_size=1, shuffle=False)
+            all_probabilities = predict_probabilities(model, loader)
+
+            label_encoders = load_label_encoders()
+
+            response = {}
+            for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
+                pred = np.argmax(probs[0])
+                decoded_pred = label_encoders[col].inverse_transform([pred])[0]
+
+                class_probs = {
+                    label: float(probs[0][j])
+                    for j, label in enumerate(label_encoders[col].classes_)
+                }
+
+                response[col] = {
+                    "prediction": decoded_pred,
+                    "probabilities": class_probs
+                }
+
+            return response
+
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="Either provide a transaction in the request body or upload a CSV file"
+            )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/bert/download-model/{model_id}")
 async def download_model(model_id: str):
     """Download a trained model"""
     model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
     if not model_path.exists():
         raise HTTPException(status_code=404, detail="Model not found")
 
     return FileResponse(
         path=model_path,
         filename=f"bert_model_{model_id}.pth",
         media_type="application/octet-stream"
     )
 
 async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
     try:
         data_df_original, label_encoders = load_and_preprocess_data(file_path)
         save_label_encoders(label_encoders)
 
         train_df, val_df = train_test_split(
             data_df_original,
             test_size=config.test_size,
             random_state=config.random_state,
             stratify=data_df_original[LABEL_COLUMNS[0]]
         )
 
         train_texts = train_df[TEXT_COLUMN]
         val_texts = val_df[TEXT_COLUMN]
         train_labels_array = train_df[LABEL_COLUMNS].values
         val_labels_array = val_df[LABEL_COLUMNS].values
 
         train_metadata_df = train_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in train_df.columns for col in METADATA_COLUMNS) else None
         val_metadata_df = val_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in val_df.columns for col in METADATA_COLUMNS) else None
 
         num_labels_list = get_num_labels(label_encoders)
         tokenizer = get_tokenizer(config.model_name)
 
         if train_metadata_df is not None and val_metadata_df is not None:
             metadata_dim = train_metadata_df.shape[1]
             train_dataset = ComplianceDatasetWithMetadata(
                 train_texts.tolist(),
                 train_metadata_df.values,
                 train_labels_array,
                 tokenizer,
                 config.max_length
             )
             val_dataset = ComplianceDatasetWithMetadata(
                 val_texts.tolist(),
                 val_metadata_df.values,
                 val_labels_array,
                 tokenizer,
                 config.max_length
             )
             model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
         else:
             train_dataset = ComplianceDataset(
                 train_texts.tolist(),
                 train_labels_array,
                 tokenizer,
                 config.max_length
             )
             val_dataset = ComplianceDataset(
                 val_texts.tolist(),
                 val_labels_array,
                 tokenizer,
                 config.max_length
             )
             model = BertMultiOutputModel(num_labels_list).to(DEVICE)
 
         train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
         val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
 
         criterions = initialize_criterions(num_labels_list)
         optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
 
         best_val_loss = float('inf')
         for epoch in range(config.num_epochs):
             training_status["current_epoch"] = epoch + 1
 
             train_loss = train_model(model, train_loader, criterions, optimizer)
             val_metrics, _, _ = evaluate_model(model, val_loader)
 
             training_status["current_loss"] = train_loss
 
             if val_metrics["loss"] < best_val_loss:
                 best_val_loss = val_metrics["loss"]
                 save_model(model, training_id)
 
         training_status.update({
             "is_training": False,
             "end_time": datetime.now().isoformat(),
             "status": "completed",
             "metrics": summarize_metrics(val_metrics).to_dict()
         })
 
     except Exception as e:
         logger.error(f"Training failed: {str(e)}")
         training_status.update({
             "is_training": False,
             "end_time": datetime.now().isoformat(),
             "status": "failed",
             "error": str(e)
         })
 
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", 7860))
     uvicorn.run(app, host="0.0.0.0", port=port)
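
For reference, here is a minimal client sketch (not part of the commit) for exercising the updated API, in particular the reworked /bert/predict endpoint that now accepts either a single transaction or an uploaded CSV. It assumes the service is running locally on port 7860 (the default in app.py), that a file named transactions.csv with the expected columns exists, and it uses the third-party requests package; hostnames and file names are placeholders to adjust.

import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment

# Health check
print(requests.get(f"{BASE_URL}/health").json())

# Upload a CSV for training or validation (hypothetical file name)
with open("transactions.csv", "rb") as fh:
    upload = requests.post(
        f"{BASE_URL}/upload",
        files={"file": ("transactions.csv", fh, "text/csv")},
    ).json()

# Start training: the TrainingConfig goes in the JSON body, file_path as a query parameter
train = requests.post(
    f"{BASE_URL}/bert/train",
    params={"file_path": upload["file_path"]},
    json={"num_epochs": 2, "batch_size": 8},
).json()
print(train["training_id"], train["download_url"])

# Poll training progress
print(requests.get(f"{BASE_URL}/training-status").json())

# Batch prediction with the updated /bert/predict: upload a CSV and pick a saved model by name
with open("transactions.csv", "rb") as fh:
    result = requests.post(
        f"{BASE_URL}/bert/predict",
        params={"model_name": "BERT_model"},
        files={"file": ("transactions.csv", fh, "text/csv")},
    ).json()
print(result["message"], len(result["predictions"]))

# A single-transaction prediction instead posts JSON of the form
# {"transaction_data": {...all TransactionData fields...}} with no file attached.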