namanpenguin committed
Commit 39918b3 · verified · 1 Parent(s): e7e8222

Update app.py

Files changed (1)
app.py +570 -530
app.py CHANGED
@@ -1,530 +1,570 @@
- from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
- from fastapi.responses import FileResponse
- from pydantic import BaseModel
- from typing import Optional, Dict, Any, List
- import uvicorn
- import torch
- from torch.utils.data import DataLoader
- import logging
- import os
- import asyncio
- import pandas as pd
- from datetime import datetime
- import shutil
- from pathlib import Path
- from sklearn.model_selection import train_test_split
- import zipfile
- import io
- import numpy as np
- import sys
- import json
-
-
- # Import existing utilities
- from dataset_utils import (
-     ComplianceDataset,
-     ComplianceDatasetWithMetadata,
-     load_and_preprocess_data,
-     get_tokenizer,
-     save_label_encoders,
-     get_num_labels,
-     load_label_encoders
- )
- from train_utils import (
-     initialize_criterions,
-     train_model,
-     evaluate_model,
-     save_model,
-     summarize_metrics,
-     predict_probabilities
- )
- from models.deberta_model import DebertaMultiOutputModel
- from config import (
-     TEXT_COLUMN,
-     LABEL_COLUMNS,
-     DEVICE,
-     NUM_EPOCHS,
-     LEARNING_RATE,
-     MAX_LEN,
-     BATCH_SIZE,
-     METADATA_COLUMNS,
-     DEBERTA_MODEL_NAME
- )
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- app = FastAPI(title="DeBERTa Compliance Predictor API")
-
- # Create necessary directories
- UPLOAD_DIR = Path("uploads")
- MODEL_SAVE_DIR = Path("saved_models")
- UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
- MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
-
- # Global variables to track training status
- training_status = {
-     "is_training": False,
-     "current_epoch": 0,
-     "total_epochs": 0,
-     "current_loss": 0.0,
-     "start_time": None,
-     "end_time": None,
-     "status": "idle",
-     "metrics": None
- }
-
- # Load the model and tokenizer for prediction
- model_path = "DEBERTA_model.pth"
- tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
- model = DebertaMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
- if os.path.exists(model_path):
-     model.load_state_dict(torch.load(model_path, map_location=DEVICE))
-     model.eval()
-
- class TrainingConfig(BaseModel):
-     model_name: str = DEBERTA_MODEL_NAME
-     batch_size: int = 8
-     learning_rate: float = 2e-5
-     num_epochs: int = 2
-     max_length: int = 128
-     random_state: int = 42
-
- class TrainingResponse(BaseModel):
-     message: str
-     training_id: str
-     status: str
-     download_url: Optional[str] = None
-
- class ValidationResponse(BaseModel):
-     message: str
-     metrics: Dict[str, Any]
-     predictions: List[Dict[str, Any]]
-
- class TransactionData(BaseModel):
-     Transaction_Id: str
-     Hit_Seq: int
-     Hit_Id_List: str
-     Origin: str
-     Designation: str
-     Keywords: str
-     Name: str
-     SWIFT_Tag: str
-     Currency: str
-     Entity: str
-     Message: str
-     City: str
-     Country: str
-     State: str
-     Hit_Type: str
-     Record_Matching_String: str
-     WatchList_Match_String: str
-     Payment_Sender_Name: Optional[str] = ""
-     Payment_Reciever_Name: Optional[str] = ""
-     Swift_Message_Type: str
-     Text_Sanction_Data: str
-     Matched_Sanctioned_Entity: str
-     Is_Match: int
-     Red_Flag_Reason: str
-     Risk_Level: str
-     Risk_Score: float
-     Risk_Score_Description: str
-     CDD_Level: str
-     PEP_Status: str
-     Value_Date: str
-     Last_Review_Date: str
-     Next_Review_Date: str
-     Sanction_Description: str
-     Checker_Notes: str
-     Sanction_Context: str
-     Maker_Action: str
-     Customer_ID: int
-     Customer_Type: str
-     Industry: str
-     Transaction_Date_Time: str
-     Transaction_Type: str
-     Transaction_Channel: str
-     Originating_Bank: str
-     Beneficiary_Bank: str
-     Geographic_Origin: str
-     Geographic_Destination: str
-     Match_Score: float
-     Match_Type: str
-     Sanctions_List_Version: str
-     Screening_Date_Time: str
-     Risk_Category: str
-     Risk_Drivers: str
-     Alert_Status: str
-     Investigation_Outcome: str
-     Case_Owner_Analyst: str
-     Escalation_Level: str
-     Escalation_Date: str
-     Regulatory_Reporting_Flags: bool
-     Audit_Trail_Timestamp: str
-     Source_Of_Funds: str
-     Purpose_Of_Transaction: str
-     Beneficial_Owner: str
-     Sanctions_Exposure_History: bool
-
- class PredictionRequest(BaseModel):
-     transaction_data: TransactionData
-     model_name: str = "DEBERTA_model"  # Default to DeBERTa_model if not specified
-
- class BatchPredictionResponse(BaseModel):
-     message: str
-     predictions: List[Dict[str, Any]]
-     metrics: Optional[Dict[str, Any]] = None
-
- @app.get("/")
180
- async def root():
181
- return {"message": "DeBERTa Compliance Predictor API"}
182
-
183
- @app.get("/v1/deberta/health")
184
- async def health_check():
185
- return {"status": "healthy"}
186
-
187
- @app.get("/v1/deberta/training-status")
188
- async def get_training_status():
189
- return training_status
190
-
191
- @app.post("/v1/deberta/train", response_model=TrainingResponse)
192
- async def start_training(
193
- config: str = Form(...),
194
- background_tasks: BackgroundTasks = None,
195
- file: UploadFile = File(...)
196
- ):
197
- if training_status["is_training"]:
198
- raise HTTPException(status_code=400, detail="Training is already in progress")
199
-
200
- if not file.filename.endswith('.csv'):
201
- raise HTTPException(status_code=400, detail="Only CSV files are allowed")
202
-
203
- try:
204
- # Parse the config JSON string into a TrainingConfig object
205
- config_dict = json.loads(config)
206
- training_config = TrainingConfig(**config_dict)
207
- except json.JSONDecodeError:
208
- raise HTTPException(status_code=400, detail="Invalid config JSON format")
209
- except Exception as e:
210
- raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
211
-
212
- file_path = UPLOAD_DIR / file.filename
213
- with file_path.open("wb") as buffer:
214
- shutil.copyfileobj(file.file, buffer)
215
-
216
- training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
217
-
218
- training_status.update({
219
- "is_training": True,
220
- "current_epoch": 0,
221
- "total_epochs": training_config.num_epochs,
222
- "start_time": datetime.now().isoformat(),
223
- "status": "starting"
224
- })
225
-
226
- background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
227
-
228
- download_url = f"/v1/deberta/download-model/{training_id}"
229
-
230
- return TrainingResponse(
231
- message="Training started successfully",
232
- training_id=training_id,
233
- status="started",
234
- download_url=download_url
235
- )
236
-
237
- @app.post("/v1/deberta/validate")
238
- async def validate_model(
239
- file: UploadFile = File(...),
240
- model_name: str = "DEBERTA_model"
241
- ):
242
- """Validate a DeBERTa model on uploaded data"""
243
- if not file.filename.endswith('.csv'):
244
- raise HTTPException(status_code=400, detail="Only CSV files are allowed")
245
-
246
- try:
247
- file_path = UPLOAD_DIR / file.filename
248
- with file_path.open("wb") as buffer:
249
- shutil.copyfileobj(file.file, buffer)
250
-
251
- data_df, label_encoders = load_and_preprocess_data(str(file_path))
252
-
253
- model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
254
- if not model_path.exists():
255
- raise HTTPException(status_code=404, detail="DeBERTa model file not found")
256
-
257
- num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
258
- metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None
259
-
260
- if metadata_df is not None:
261
- metadata_dim = metadata_df.shape[1]
262
- model = DebertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
263
- else:
264
- model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
265
-
266
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
267
- model.eval()
268
-
269
- texts = data_df[TEXT_COLUMN]
270
- labels_array = data_df[LABEL_COLUMNS].values
271
- tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
272
-
273
- if metadata_df is not None:
274
- dataset = ComplianceDatasetWithMetadata(
275
- texts.tolist(),
276
- metadata_df.values,
277
- labels_array,
278
- tokenizer,
279
- MAX_LEN
280
- )
281
- else:
282
- dataset = ComplianceDataset(
283
- texts.tolist(),
284
- labels_array,
285
- tokenizer,
286
- MAX_LEN
287
- )
288
-
289
- dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
290
- metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
291
- summary_metrics = summarize_metrics(metrics).to_dict()
292
-
293
- all_probs = predict_probabilities(model, dataloader)
294
-
295
- predictions = []
296
- for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
297
- field = LABEL_COLUMNS[i]
298
- label_encoder = label_encoders[field]
299
- true_labels_orig = label_encoder.inverse_transform(true_labels)
300
- pred_labels_orig = label_encoder.inverse_transform(pred_labels)
301
-
302
- for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
303
- predictions.append({
304
- "field": field,
305
- "true_label": true,
306
- "predicted_label": pred,
307
- "probabilities": probs.tolist()
308
- })
309
-
310
- return ValidationResponse(
311
- message="Validation completed successfully",
312
- metrics=summary_metrics,
313
- predictions=predictions
314
- )
315
-
316
- except Exception as e:
317
- logger.error(f"Validation failed: {str(e)}")
318
- raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
319
- finally:
320
- if os.path.exists(file_path):
321
- os.remove(file_path)
322
-
323
- @app.post("/v1/deberta/predict")
324
- async def predict(
325
- request: Optional[PredictionRequest] = None,
326
- file: UploadFile = File(None),
327
- model_name: str = "DEBERTA_model"
328
- ):
329
- """
330
- Make predictions on either a single transaction or a batch of transactions from a CSV file.
331
-
332
- You can either:
333
- 1. Send a single transaction in the request body
334
- 2. Upload a CSV file with multiple transactions
335
-
336
- Parameters:
337
- - file: CSV file containing transactions for batch prediction
338
- - model_name: Name of the model to use for prediction (default: "DEBERTA_model")
339
- """
340
- try:
341
- # Load the model
342
- model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
343
- if not model_path.exists():
344
- raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
345
-
346
- num_labels_list = [len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]
347
- model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
348
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
349
- model.eval()
350
-
351
- # Handle batch prediction from CSV
352
- if file and file.filename:
353
- if not file.filename.endswith('.csv'):
354
- raise HTTPException(status_code=400, detail="Only CSV files are allowed")
355
-
356
- file_path = UPLOAD_DIR / file.filename
357
- with file_path.open("wb") as buffer:
358
- shutil.copyfileobj(file.file, buffer)
359
-
360
- try:
361
- # Load and preprocess the CSV data
362
- data_df, _ = load_and_preprocess_data(str(file_path))
363
- texts = data_df[TEXT_COLUMN]
364
-
365
- # Create dataset and dataloader
366
- dataset = ComplianceDataset(
367
- texts.tolist(),
368
- [[0] * len(LABEL_COLUMNS)] * len(texts), # Dummy labels for prediction
369
- tokenizer,
370
- MAX_LEN
371
- )
372
- loader = DataLoader(dataset, batch_size=BATCH_SIZE)
373
-
374
- # Get predictions
375
- all_probabilities = predict_probabilities(model, loader)
376
- label_encoders = load_label_encoders()
377
-
378
- # Process predictions
379
- predictions = []
380
- for i, row in data_df.iterrows():
381
- transaction_pred = {}
382
- for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
383
- pred = np.argmax(probs[i])
384
- decoded_pred = label_encoders[col].inverse_transform([pred])[0]
385
-
386
- class_probs = {
387
- label: float(probs[i][j])
388
- for j, label in enumerate(label_encoders[col].classes_)
389
- }
390
-
391
- transaction_pred[col] = {
392
- "prediction": decoded_pred,
393
- "probabilities": class_probs
394
- }
395
-
396
- predictions.append({
397
- "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
398
- "predictions": transaction_pred
399
- })
400
-
401
- return BatchPredictionResponse(
402
- message="Batch prediction completed successfully",
403
- predictions=predictions
404
- )
405
-
406
- finally:
407
- if os.path.exists(file_path):
408
- os.remove(file_path)
409
-
410
- # Handle single prediction
411
- elif request and request.transaction_data:
412
- input_data = pd.DataFrame([request.transaction_data.dict()])
413
-
414
- text_input = f"<s>Transaction ID: {input_data['Transaction_Id'].iloc[0]} Origin: {input_data['Origin'].iloc[0]} Designation: {input_data['Designation'].iloc[0]} Keywords: {input_data['Keywords'].iloc[0]} Name: {input_data['Name'].iloc[0]} SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]} Currency: {input_data['Currency'].iloc[0]} Entity: {input_data['Entity'].iloc[0]} Message: {input_data['Message'].iloc[0]} City: {input_data['City'].iloc[0]} Country: {input_data['Country'].iloc[0]} State: {input_data['State'].iloc[0]} Hit Type: {input_data['Hit_Type'].iloc[0]} Record Matching String: {input_data['Record_Matching_String'].iloc[0]} WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]} Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]} Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]} Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]} Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]} Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]} Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]} Risk Level: {input_data['Risk_Level'].iloc[0]} Risk Score: {input_data['Risk_Score'].iloc[0]} CDD Level: {input_data['CDD_Level'].iloc[0]} PEP Status: {input_data['PEP_Status'].iloc[0]} Sanction Description: {input_data['Sanction_Description'].iloc[0]} Checker Notes: {input_data['Checker_Notes'].iloc[0]} Sanction Context: {input_data['Sanction_Context'].iloc[0]}</s>"
415
-
416
- dataset = ComplianceDataset(
417
- texts=[text_input],
418
- labels=[[0] * len(LABEL_COLUMNS)],
419
- tokenizer=tokenizer,
420
- max_len=MAX_LEN
421
- )
422
-
423
- loader = DataLoader(dataset, batch_size=1, shuffle=False)
424
- all_probabilities = predict_probabilities(model, loader)
425
-
426
- label_encoders = load_label_encoders()
427
-
428
- response = {}
429
- for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
430
- pred = np.argmax(probs[0])
431
- decoded_pred = label_encoders[col].inverse_transform([pred])[0]
432
-
433
- class_probs = {
434
- label: float(probs[0][j])
435
- for j, label in enumerate(label_encoders[col].classes_)
436
- }
437
-
438
- response[col] = {
439
- "prediction": decoded_pred,
440
- "probabilities": class_probs
441
- }
442
-
443
- return response
444
-
445
- else:
446
- raise HTTPException(
447
- status_code=400,
448
- detail="Either provide a transaction in the request body or upload a CSV file"
449
- )
450
-
451
- except Exception as e:
452
- raise HTTPException(status_code=500, detail=str(e))
453
-
454
- @app.get("/v1/deberta/download-model/{model_id}")
455
- async def download_model(model_id: str):
456
- """Download a trained model"""
457
- model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
458
- if not model_path.exists():
459
- raise HTTPException(status_code=404, detail="Model not found")
460
-
461
- return FileResponse(
462
- path=model_path,
463
- filename=f"deberta_model_{model_id}.pth",
464
- media_type="application/octet-stream"
465
- )
466
-
467
- async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
468
- try:
469
- data_df_original, label_encoders = load_and_preprocess_data(file_path)
470
- save_label_encoders(label_encoders)
471
-
472
- texts = data_df_original[TEXT_COLUMN]
473
- labels_array = data_df_original[LABEL_COLUMNS].values
474
-
475
- metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
476
-
477
- num_labels_list = get_num_labels(label_encoders)
478
- tokenizer = get_tokenizer(config.model_name)
479
-
480
- if metadata_df is not None:
481
- metadata_dim = metadata_df.shape[1]
482
- dataset = ComplianceDatasetWithMetadata(
483
- texts.tolist(),
484
- metadata_df.values,
485
- labels_array,
486
- tokenizer,
487
- config.max_length
488
- )
489
- model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
490
- else:
491
- dataset = ComplianceDataset(
492
- texts.tolist(),
493
- labels_array,
494
- tokenizer,
495
- config.max_length
496
- )
497
- model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
498
-
499
- train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
500
-
501
- criterions = initialize_criterions(num_labels_list)
502
- optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
503
-
504
- for epoch in range(config.num_epochs):
505
- training_status["current_epoch"] = epoch + 1
506
-
507
- train_loss = train_model(model, train_loader, criterions, optimizer)
508
- training_status["current_loss"] = train_loss
509
-
510
- # Save model after each epoch
511
- save_model(model, training_id)
512
-
513
- training_status.update({
514
- "is_training": False,
515
- "end_time": datetime.now().isoformat(),
516
- "status": "completed"
517
- })
518
-
519
- except Exception as e:
520
- logger.error(f"Training failed: {str(e)}")
521
- training_status.update({
522
- "is_training": False,
523
- "end_time": datetime.now().isoformat(),
524
- "status": "failed",
525
- "error": str(e)
526
- })
527
-
528
- if __name__ == "__main__":
529
- port = int(os.environ.get("PORT", 7860))
530
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
+ from fastapi.responses import FileResponse
+ from pydantic import BaseModel
+ from typing import Optional, Dict, Any, List
+ import uvicorn
+ import torch
+ from torch.utils.data import DataLoader
+ import logging
+ import os
+ import asyncio
+ import pandas as pd
+ from datetime import datetime
+ import shutil
+ from pathlib import Path
+ from sklearn.model_selection import train_test_split
+ import zipfile
+ import io
+ import numpy as np
+ import sys
+ import json
+
+
+ # Import existing utilities
+ from dataset_utils import (
+     ComplianceDataset,
+     ComplianceDatasetWithMetadata,
+     load_and_preprocess_data,
+     get_tokenizer,
+     save_label_encoders,
+     get_num_labels,
+     load_label_encoders
+ )
+ from train_utils import (
+     initialize_criterions,
+     train_model,
+     evaluate_model,
+     save_model,
+     summarize_metrics,
+     predict_probabilities
+ )
+ from models.deberta_model import DebertaMultiOutputModel
+ from config import (
+     TEXT_COLUMN,
+     LABEL_COLUMNS,
+     DEVICE,
+     NUM_EPOCHS,
+     LEARNING_RATE,
+     MAX_LEN,
+     BATCH_SIZE,
+     METADATA_COLUMNS,
+     DEBERTA_MODEL_NAME
+ )
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Log environment information
+ logger.info("Starting DeBERTa Compliance Predictor API")
+ logger.info(f"Device: {DEVICE}")
+ logger.info(f"Model path: {MODEL_SAVE_DIR}")
+ logger.info(f"Label encoders path: {LABEL_ENCODERS_PATH}")
+
+ app = FastAPI(title="DeBERTa Compliance Predictor API")
+
+ # Create necessary directories
+ UPLOAD_DIR = Path("uploads")
+ MODEL_SAVE_DIR = Path("saved_models")
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+ MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
+
+ # Global variables to track training status
+ training_status = {
+     "is_training": False,
+     "current_epoch": 0,
+     "total_epochs": 0,
+     "current_loss": 0.0,
+     "start_time": None,
+     "end_time": None,
+     "status": "idle",
+     "metrics": None
+ }
+
+ # Load the model and tokenizer for prediction
+ model_path = MODEL_SAVE_DIR / "DEBERTA_model.pth"
+ tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
+
+ # Load label encoders with error handling
+ try:
+     label_encoders = load_label_encoders()
+     num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
+ except Exception as e:
+     logger.warning(f"Could not load label encoders: {str(e)}")
+     # Use default values if label encoders can't be loaded
+     num_labels_list = [2] * len(LABEL_COLUMNS)  # Default to binary classification
+
+ model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
+ if os.path.exists(model_path):
+     model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+     model.eval()
+
+ class TrainingConfig(BaseModel):
+     model_name: str = DEBERTA_MODEL_NAME
+     batch_size: int = 8
+     learning_rate: float = 2e-5
+     num_epochs: int = 2
+     max_length: int = 128
+     random_state: int = 42
+
+ class TrainingResponse(BaseModel):
+     message: str
+     training_id: str
+     status: str
+     download_url: Optional[str] = None
+
+ class ValidationResponse(BaseModel):
+     message: str
+     metrics: Dict[str, Any]
+     predictions: List[Dict[str, Any]]
+
+ class TransactionData(BaseModel):
+     Transaction_Id: str
+     Hit_Seq: int
+     Hit_Id_List: str
+     Origin: str
+     Designation: str
+     Keywords: str
+     Name: str
+     SWIFT_Tag: str
+     Currency: str
+     Entity: str
+     Message: str
+     City: str
+     Country: str
+     State: str
+     Hit_Type: str
+     Record_Matching_String: str
+     WatchList_Match_String: str
+     Payment_Sender_Name: Optional[str] = ""
+     Payment_Reciever_Name: Optional[str] = ""
+     Swift_Message_Type: str
+     Text_Sanction_Data: str
+     Matched_Sanctioned_Entity: str
+     Is_Match: int
+     Red_Flag_Reason: str
+     Risk_Level: str
+     Risk_Score: float
+     Risk_Score_Description: str
+     CDD_Level: str
+     PEP_Status: str
+     Value_Date: str
+     Last_Review_Date: str
+     Next_Review_Date: str
+     Sanction_Description: str
+     Checker_Notes: str
+     Sanction_Context: str
+     Maker_Action: str
+     Customer_ID: int
+     Customer_Type: str
+     Industry: str
+     Transaction_Date_Time: str
+     Transaction_Type: str
+     Transaction_Channel: str
+     Originating_Bank: str
+     Beneficiary_Bank: str
+     Geographic_Origin: str
+     Geographic_Destination: str
+     Match_Score: float
+     Match_Type: str
+     Sanctions_List_Version: str
+     Screening_Date_Time: str
+     Risk_Category: str
+     Risk_Drivers: str
+     Alert_Status: str
+     Investigation_Outcome: str
+     Case_Owner_Analyst: str
+     Escalation_Level: str
+     Escalation_Date: str
+     Regulatory_Reporting_Flags: bool
+     Audit_Trail_Timestamp: str
+     Source_Of_Funds: str
+     Purpose_Of_Transaction: str
+     Beneficial_Owner: str
+     Sanctions_Exposure_History: bool
+
+ class PredictionRequest(BaseModel):
+     transaction_data: TransactionData
+     model_name: str = "DEBERTA_model"  # Default to DeBERTa_model if not specified
+
+ class BatchPredictionResponse(BaseModel):
+     message: str
+     predictions: List[Dict[str, Any]]
+     metrics: Optional[Dict[str, Any]] = None
+
+ @app.get("/")
196
+ async def root():
197
+ return {"message": "DeBERTa Compliance Predictor API"}
198
+
199
+ @app.get("/v1/deberta/health")
200
+ async def health_check():
201
+ return {"status": "healthy"}
202
+
203
+ @app.get("/v1/deberta/training-status")
204
+ async def get_training_status():
205
+ return training_status
206
+
207
+ @app.post("/v1/deberta/train", response_model=TrainingResponse)
208
+ async def start_training(
209
+ config: str = Form(...),
210
+ background_tasks: BackgroundTasks = None,
211
+ file: UploadFile = File(...)
212
+ ):
213
+ if training_status["is_training"]:
214
+ raise HTTPException(status_code=400, detail="Training is already in progress")
215
+
216
+ if not file.filename.endswith('.csv'):
217
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
218
+
219
+ try:
220
+ # Parse the config JSON string into a TrainingConfig object
221
+ config_dict = json.loads(config)
222
+ training_config = TrainingConfig(**config_dict)
223
+ except json.JSONDecodeError:
224
+ raise HTTPException(status_code=400, detail="Invalid config JSON format")
225
+ except Exception as e:
226
+ raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
227
+
228
+ file_path = UPLOAD_DIR / file.filename
229
+ with file_path.open("wb") as buffer:
230
+ shutil.copyfileobj(file.file, buffer)
231
+
232
+ training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
233
+
234
+ training_status.update({
235
+ "is_training": True,
236
+ "current_epoch": 0,
237
+ "total_epochs": training_config.num_epochs,
238
+ "start_time": datetime.now().isoformat(),
239
+ "status": "starting"
240
+ })
241
+
242
+ background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
243
+
244
+ download_url = f"/v1/deberta/download-model/{training_id}"
245
+
246
+ return TrainingResponse(
247
+ message="Training started successfully",
248
+ training_id=training_id,
249
+ status="started",
250
+ download_url=download_url
251
+ )
252
+
253
+ @app.post("/v1/deberta/validate")
254
+ async def validate_model(
255
+ file: UploadFile = File(...),
256
+ model_name: str = "DEBERTA_model"
257
+ ):
258
+ """Validate a DeBERTa model on uploaded data"""
259
+ if not file.filename.endswith('.csv'):
260
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
261
+
262
+ try:
263
+ file_path = UPLOAD_DIR / file.filename
264
+ with file_path.open("wb") as buffer:
265
+ shutil.copyfileobj(file.file, buffer)
266
+
267
+ data_df, label_encoders = load_and_preprocess_data(str(file_path))
268
+
269
+ model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
270
+ if not model_path.exists():
271
+ raise HTTPException(status_code=404, detail="DeBERTa model file not found")
272
+
273
+ num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
274
+ metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None
275
+
276
+ if metadata_df is not None:
277
+ metadata_dim = metadata_df.shape[1]
278
+ model = DebertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
279
+ else:
280
+ model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
281
+
282
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
283
+ model.eval()
284
+
285
+ texts = data_df[TEXT_COLUMN]
286
+ labels_array = data_df[LABEL_COLUMNS].values
287
+ tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
288
+
289
+ if metadata_df is not None:
290
+ dataset = ComplianceDatasetWithMetadata(
291
+ texts.tolist(),
292
+ metadata_df.values,
293
+ labels_array,
294
+ tokenizer,
295
+ MAX_LEN
296
+ )
297
+ else:
298
+ dataset = ComplianceDataset(
299
+ texts.tolist(),
300
+ labels_array,
301
+ tokenizer,
302
+ MAX_LEN
303
+ )
304
+
305
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
306
+ metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
307
+ summary_metrics = summarize_metrics(metrics).to_dict()
308
+
309
+ all_probs = predict_probabilities(model, dataloader)
310
+
311
+ predictions = []
312
+ for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
313
+ field = LABEL_COLUMNS[i]
314
+ label_encoder = label_encoders[field]
315
+ true_labels_orig = label_encoder.inverse_transform(true_labels)
316
+ pred_labels_orig = label_encoder.inverse_transform(pred_labels)
317
+
318
+ for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
319
+ predictions.append({
320
+ "field": field,
321
+ "true_label": true,
322
+ "predicted_label": pred,
323
+ "probabilities": probs.tolist()
324
+ })
325
+
326
+ return ValidationResponse(
327
+ message="Validation completed successfully",
328
+ metrics=summary_metrics,
329
+ predictions=predictions
330
+ )
331
+
332
+ except Exception as e:
333
+ logger.error(f"Validation failed: {str(e)}")
334
+ raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
335
+ finally:
336
+ if os.path.exists(file_path):
337
+ os.remove(file_path)
338
+
339
+ @app.post("/v1/deberta/predict")
340
+ async def predict(
341
+ request: Optional[PredictionRequest] = None,
342
+ file: UploadFile = File(None),
343
+ model_name: str = "DEBERTA_model"
344
+ ):
345
+ """
346
+ Make predictions on either a single transaction or a batch of transactions from a CSV file.
347
+
348
+ You can either:
349
+ 1. Send a single transaction in the request body
350
+ 2. Upload a CSV file with multiple transactions
351
+
352
+ Parameters:
353
+ - file: CSV file containing transactions for batch prediction
354
+ - model_name: Name of the model to use for prediction (default: "DEBERTA_model")
355
+ """
356
+ try:
357
+ # Load the model
358
+ model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
359
+ if not model_path.exists():
360
+ raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
361
+
362
+ # Load label encoders with error handling
363
+ try:
364
+ label_encoders = load_label_encoders()
365
+ num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
366
+ except Exception as e:
367
+ logger.warning(f"Could not load label encoders: {str(e)}")
368
+ # Use default values if label encoders can't be loaded
369
+ num_labels_list = [2] * len(LABEL_COLUMNS) # Default to binary classification
370
+
371
+ model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
372
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
373
+ model.eval()
374
+
375
+ # Handle batch prediction from CSV
376
+ if file and file.filename:
377
+ if not file.filename.endswith('.csv'):
378
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
379
+
380
+ file_path = UPLOAD_DIR / file.filename
381
+ with file_path.open("wb") as buffer:
382
+ shutil.copyfileobj(file.file, buffer)
383
+
384
+ try:
385
+ # Load and preprocess the CSV data
386
+ data_df, _ = load_and_preprocess_data(str(file_path))
387
+ texts = data_df[TEXT_COLUMN]
388
+
389
+ # Create dataset and dataloader
390
+ dataset = ComplianceDataset(
391
+ texts.tolist(),
392
+ [[0] * len(LABEL_COLUMNS)] * len(texts), # Dummy labels for prediction
393
+ tokenizer,
394
+ MAX_LEN
395
+ )
396
+ loader = DataLoader(dataset, batch_size=BATCH_SIZE)
397
+
398
+ # Get predictions
399
+ all_probabilities = predict_probabilities(model, loader)
400
+
401
+ # Process predictions
402
+ predictions = []
403
+ for i, row in data_df.iterrows():
404
+ transaction_pred = {}
405
+ for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
406
+ pred = np.argmax(probs[i])
407
+
408
+ # Try to decode prediction, fallback to raw index if label encoder fails
409
+ try:
410
+ decoded_pred = label_encoders[col].inverse_transform([pred])[0]
411
+ except Exception:
412
+ decoded_pred = f"class_{pred}"
413
+
414
+ class_probs = {}
415
+ try:
416
+ for k, label in enumerate(label_encoders[col].classes_):
417
+ class_probs[label] = float(probs[i][k])
418
+ except Exception:
419
+ # Fallback: use generic class names
420
+ for k in range(len(probs[i])):
421
+ class_probs[f"class_{k}"] = float(probs[i][k])
422
+
423
+ transaction_pred[col] = {
424
+ "prediction": decoded_pred,
425
+ "probabilities": class_probs
426
+ }
427
+
428
+ predictions.append({
429
+ "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
430
+ "predictions": transaction_pred
431
+ })
432
+
433
+ return BatchPredictionResponse(
434
+ message="Batch prediction completed successfully",
435
+ predictions=predictions
436
+ )
437
+
438
+ finally:
439
+ if os.path.exists(file_path):
440
+ os.remove(file_path)
441
+
442
+ # Handle single prediction
443
+ elif request and request.transaction_data:
444
+ input_data = pd.DataFrame([request.transaction_data.dict()])
445
+
446
+ text_input = f"<s>Transaction ID: {input_data['Transaction_Id'].iloc[0]} Origin: {input_data['Origin'].iloc[0]} Designation: {input_data['Designation'].iloc[0]} Keywords: {input_data['Keywords'].iloc[0]} Name: {input_data['Name'].iloc[0]} SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]} Currency: {input_data['Currency'].iloc[0]} Entity: {input_data['Entity'].iloc[0]} Message: {input_data['Message'].iloc[0]} City: {input_data['City'].iloc[0]} Country: {input_data['Country'].iloc[0]} State: {input_data['State'].iloc[0]} Hit Type: {input_data['Hit_Type'].iloc[0]} Record Matching String: {input_data['Record_Matching_String'].iloc[0]} WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]} Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]} Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]} Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]} Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]} Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]} Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]} Risk Level: {input_data['Risk_Level'].iloc[0]} Risk Score: {input_data['Risk_Score'].iloc[0]} CDD Level: {input_data['CDD_Level'].iloc[0]} PEP Status: {input_data['PEP_Status'].iloc[0]} Sanction Description: {input_data['Sanction_Description'].iloc[0]} Checker Notes: {input_data['Checker_Notes'].iloc[0]} Sanction Context: {input_data['Sanction_Context'].iloc[0]}</s>"
447
+
448
+ dataset = ComplianceDataset(
449
+ texts=[text_input],
450
+ labels=[[0] * len(LABEL_COLUMNS)],
451
+ tokenizer=tokenizer,
452
+ max_len=MAX_LEN
453
+ )
454
+
455
+ loader = DataLoader(dataset, batch_size=1, shuffle=False)
456
+ all_probabilities = predict_probabilities(model, loader)
457
+
458
+ response = {}
459
+ for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
460
+ pred = np.argmax(probs[0])
461
+
462
+ # Try to decode prediction, fallback to raw index if label encoder fails
463
+ try:
464
+ decoded_pred = label_encoders[col].inverse_transform([pred])[0]
465
+ except Exception:
466
+ decoded_pred = f"class_{pred}"
467
+
468
+ class_probs = {}
469
+ try:
470
+ for j, label in enumerate(label_encoders[col].classes_):
471
+ class_probs[label] = float(probs[0][j])
472
+ except Exception:
473
+ # Fallback: use generic class names
474
+ for j in range(len(probs[0])):
475
+ class_probs[f"class_{j}"] = float(probs[0][j])
476
+
477
+ response[col] = {
478
+ "prediction": decoded_pred,
479
+ "probabilities": class_probs
480
+ }
481
+
482
+ return response
483
+
484
+ else:
485
+ raise HTTPException(
486
+ status_code=400,
487
+ detail="Either provide a transaction in the request body or upload a CSV file"
488
+ )
489
+
490
+ except Exception as e:
491
+ logger.error(f"Prediction failed: {str(e)}")
492
+ raise HTTPException(status_code=500, detail=str(e))
493
+
494
+ @app.get("/v1/deberta/download-model/{model_id}")
495
+ async def download_model(model_id: str):
496
+ """Download a trained model"""
497
+ model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
498
+ if not model_path.exists():
499
+ raise HTTPException(status_code=404, detail="Model not found")
500
+
501
+ return FileResponse(
502
+ path=model_path,
503
+ filename=f"deberta_model_{model_id}.pth",
504
+ media_type="application/octet-stream"
505
+ )
506
+
507
+ async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
508
+ try:
509
+ data_df_original, label_encoders = load_and_preprocess_data(file_path)
510
+ save_label_encoders(label_encoders)
511
+
512
+ texts = data_df_original[TEXT_COLUMN]
513
+ labels_array = data_df_original[LABEL_COLUMNS].values
514
+
515
+ metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
516
+
517
+ num_labels_list = get_num_labels(label_encoders)
518
+ tokenizer = get_tokenizer(config.model_name)
519
+
520
+ if metadata_df is not None:
521
+ metadata_dim = metadata_df.shape[1]
522
+ dataset = ComplianceDatasetWithMetadata(
523
+ texts.tolist(),
524
+ metadata_df.values,
525
+ labels_array,
526
+ tokenizer,
527
+ config.max_length
528
+ )
529
+ model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
530
+ else:
531
+ dataset = ComplianceDataset(
532
+ texts.tolist(),
533
+ labels_array,
534
+ tokenizer,
535
+ config.max_length
536
+ )
537
+ model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
538
+
539
+ train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
540
+
541
+ criterions = initialize_criterions(num_labels_list)
542
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
543
+
544
+ for epoch in range(config.num_epochs):
545
+ training_status["current_epoch"] = epoch + 1
546
+
547
+ train_loss = train_model(model, train_loader, criterions, optimizer)
548
+ training_status["current_loss"] = train_loss
549
+
550
+ # Save model after each epoch
551
+ save_model(model, training_id)
552
+
553
+ training_status.update({
554
+ "is_training": False,
555
+ "end_time": datetime.now().isoformat(),
556
+ "status": "completed"
557
+ })
558
+
559
+ except Exception as e:
560
+ logger.error(f"Training failed: {str(e)}")
561
+ training_status.update({
562
+ "is_training": False,
563
+ "end_time": datetime.now().isoformat(),
564
+ "status": "failed",
565
+ "error": str(e)
566
+ })
567
+
568
+ if __name__ == "__main__":
569
+ port = int(os.environ.get("PORT", 7860))
570
+ uvicorn.run(app, host="0.0.0.0", port=port)
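
For reference, a minimal client sketch against the endpoints defined above. This is not part of the commit: it assumes the service is running locally on the default port 7860 and that a local sample.csv (hypothetical path) matches the expected transaction schema.

# client_example.py - illustrative usage sketch, not part of the commit
import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment on the default port

# Health check
print(requests.get(f"{BASE_URL}/v1/deberta/health").json())

# Batch prediction from a CSV of transactions (sample.csv is a hypothetical file)
with open("sample.csv", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/v1/deberta/predict",
        files={"file": ("sample.csv", f, "text/csv")},
    )
print(resp.json())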