namanpenguin committed on
Commit b81f538 · verified · 1 Parent(s): 9f213b7

Upload 10 files

app.py ADDED
@@ -0,0 +1,529 @@
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
2
+ from fastapi.responses import FileResponse
3
+ from pydantic import BaseModel
4
+ from typing import Optional, Dict, Any, List
5
+ import uvicorn
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ import logging
9
+ import os
10
+ import asyncio
11
+ import pandas as pd
12
+ from datetime import datetime
13
+ import shutil
14
+ from pathlib import Path
15
+ from sklearn.model_selection import train_test_split
16
+ import zipfile
17
+ import io
18
+ import numpy as np
19
+ import sys
20
+ import json
21
+
22
+
23
+ # Import existing utilities
24
+ from dataset_utils import (
25
+ ComplianceDataset,
26
+ ComplianceDatasetWithMetadata,
27
+ load_and_preprocess_data,
28
+ get_tokenizer,
29
+ save_label_encoders,
30
+ get_num_labels,
31
+ load_label_encoders
32
+ )
33
+ from train_utils import (
34
+ initialize_criterions,
35
+ train_model,
36
+ evaluate_model,
37
+ save_model,
38
+ summarize_metrics,
39
+ predict_probabilities
40
+ )
41
+ from models.roberta_model import RobertaMultiOutputModel
42
+ from config import (
43
+ TEXT_COLUMN,
44
+ LABEL_COLUMNS,
45
+ DEVICE,
46
+ NUM_EPOCHS,
47
+ LEARNING_RATE,
48
+ MAX_LEN,
49
+ BATCH_SIZE,
50
+ METADATA_COLUMNS
51
+ )
52
+
53
+ # Configure logging
54
+ logging.basicConfig(level=logging.INFO)
55
+ logger = logging.getLogger(__name__)
56
+
57
+ app = FastAPI(title="RoBERTa Compliance Predictor API")
58
+
59
+ # Create necessary directories
60
+ UPLOAD_DIR = Path("uploads")
61
+ MODEL_SAVE_DIR = Path("saved_models")
62
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
63
+ MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
64
+
65
+ # Global variables to track training status
66
+ training_status = {
67
+ "is_training": False,
68
+ "current_epoch": 0,
69
+ "total_epochs": 0,
70
+ "current_loss": 0.0,
71
+ "start_time": None,
72
+ "end_time": None,
73
+ "status": "idle",
74
+ "metrics": None
75
+ }
76
+
77
+ # Load the model and tokenizer for prediction
78
+ model_path = MODEL_SAVE_DIR / "ROBERTA_model.pth"  # bundled weights live under saved_models/
79
+ tokenizer = get_tokenizer('roberta-base')
80
+ model = RobertaMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
81
+ if os.path.exists(model_path):
82
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
83
+ model.eval()
84
+
85
+ class TrainingConfig(BaseModel):
86
+ model_name: str = "roberta-base"
87
+ batch_size: int = 8
88
+ learning_rate: float = 2e-5
89
+ num_epochs: int = 2
90
+ max_length: int = 128
91
+ random_state: int = 42
92
+
93
+ class TrainingResponse(BaseModel):
94
+ message: str
95
+ training_id: str
96
+ status: str
97
+ download_url: Optional[str] = None
98
+
99
+ class ValidationResponse(BaseModel):
100
+ message: str
101
+ metrics: Dict[str, Any]
102
+ predictions: List[Dict[str, Any]]
103
+
104
+ class TransactionData(BaseModel):
105
+ Transaction_Id: str
106
+ Hit_Seq: int
107
+ Hit_Id_List: str
108
+ Origin: str
109
+ Designation: str
110
+ Keywords: str
111
+ Name: str
112
+ SWIFT_Tag: str
113
+ Currency: str
114
+ Entity: str
115
+ Message: str
116
+ City: str
117
+ Country: str
118
+ State: str
119
+ Hit_Type: str
120
+ Record_Matching_String: str
121
+ WatchList_Match_String: str
122
+ Payment_Sender_Name: Optional[str] = ""
123
+ Payment_Reciever_Name: Optional[str] = ""
124
+ Swift_Message_Type: str
125
+ Text_Sanction_Data: str
126
+ Matched_Sanctioned_Entity: str
127
+ Is_Match: int
128
+ Red_Flag_Reason: str
129
+ Risk_Level: str
130
+ Risk_Score: float
131
+ Risk_Score_Description: str
132
+ CDD_Level: str
133
+ PEP_Status: str
134
+ Value_Date: str
135
+ Last_Review_Date: str
136
+ Next_Review_Date: str
137
+ Sanction_Description: str
138
+ Checker_Notes: str
139
+ Sanction_Context: str
140
+ Maker_Action: str
141
+ Customer_ID: int
142
+ Customer_Type: str
143
+ Industry: str
144
+ Transaction_Date_Time: str
145
+ Transaction_Type: str
146
+ Transaction_Channel: str
147
+ Originating_Bank: str
148
+ Beneficiary_Bank: str
149
+ Geographic_Origin: str
150
+ Geographic_Destination: str
151
+ Match_Score: float
152
+ Match_Type: str
153
+ Sanctions_List_Version: str
154
+ Screening_Date_Time: str
155
+ Risk_Category: str
156
+ Risk_Drivers: str
157
+ Alert_Status: str
158
+ Investigation_Outcome: str
159
+ Case_Owner_Analyst: str
160
+ Escalation_Level: str
161
+ Escalation_Date: str
162
+ Regulatory_Reporting_Flags: bool
163
+ Audit_Trail_Timestamp: str
164
+ Source_Of_Funds: str
165
+ Purpose_Of_Transaction: str
166
+ Beneficial_Owner: str
167
+ Sanctions_Exposure_History: bool
168
+
169
+ class PredictionRequest(BaseModel):
170
+ transaction_data: TransactionData
171
+ model_name: str = "ROBERTA_model" # Default to ROBERTA_model if not specified
172
+
173
+ class BatchPredictionResponse(BaseModel):
174
+ message: str
175
+ predictions: List[Dict[str, Any]]
176
+ metrics: Optional[Dict[str, Any]] = None
177
+
178
+ @app.get("/")
179
+ async def root():
180
+ return {"message": "RoBERTa Compliance Predictor API"}
181
+
182
+ @app.get("/v1/roberta/health")
183
+ async def health_check():
184
+ return {"status": "healthy"}
185
+
186
+ @app.get("/v1/roberta/training-status")
187
+ async def get_training_status():
188
+ return training_status
189
+
190
+ @app.post("/v1/roberta/train", response_model=TrainingResponse)
191
+ async def start_training(
192
+ config: str = Form(...),
193
+ background_tasks: BackgroundTasks = None,
194
+ file: UploadFile = File(...)
195
+ ):
196
+ if training_status["is_training"]:
197
+ raise HTTPException(status_code=400, detail="Training is already in progress")
198
+
199
+ if not file.filename.endswith('.csv'):
200
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
201
+
202
+ try:
203
+ # Parse the config JSON string into a TrainingConfig object
204
+ config_dict = json.loads(config)
205
+ training_config = TrainingConfig(**config_dict)
206
+ except json.JSONDecodeError:
207
+ raise HTTPException(status_code=400, detail="Invalid config JSON format")
208
+ except Exception as e:
209
+ raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
210
+
211
+ file_path = UPLOAD_DIR / file.filename
212
+ with file_path.open("wb") as buffer:
213
+ shutil.copyfileobj(file.file, buffer)
214
+
215
+ training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
216
+
217
+ training_status.update({
218
+ "is_training": True,
219
+ "current_epoch": 0,
220
+ "total_epochs": training_config.num_epochs,
221
+ "start_time": datetime.now().isoformat(),
222
+ "status": "starting"
223
+ })
224
+
225
+ background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
226
+
227
+ download_url = f"/v1/roberta/download-model/{training_id}"
228
+
229
+ return TrainingResponse(
230
+ message="Training started successfully",
231
+ training_id=training_id,
232
+ status="started",
233
+ download_url=download_url
234
+ )
235
+
236
+ @app.post("/v1/roberta/validate")
237
+ async def validate_model(
238
+ file: UploadFile = File(...),
239
+ model_name: str = "ROBERTA_model"
240
+ ):
241
+ """Validate a RoBERTa model on uploaded data"""
242
+ if not file.filename.endswith('.csv'):
243
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
244
+
245
+ try:
246
+ file_path = UPLOAD_DIR / file.filename
247
+ with file_path.open("wb") as buffer:
248
+ shutil.copyfileobj(file.file, buffer)
249
+
250
+ data_df, label_encoders = load_and_preprocess_data(str(file_path))
251
+
252
+ model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
253
+ if not model_path.exists():
254
+ raise HTTPException(status_code=404, detail="RoBERTa model file not found")
255
+
256
+ num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
257
+ metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None
258
+
259
+ if metadata_df is not None:
260
+ metadata_dim = metadata_df.shape[1]
261
+ model = RobertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
262
+ else:
263
+ model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
264
+
265
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
266
+ model.eval()
267
+
268
+ texts = data_df[TEXT_COLUMN]
269
+ labels_array = data_df[LABEL_COLUMNS].values
270
+ tokenizer = get_tokenizer("roberta-base")
271
+
272
+ if metadata_df is not None:
273
+ dataset = ComplianceDatasetWithMetadata(
274
+ texts.tolist(),
275
+ metadata_df.values,
276
+ labels_array,
277
+ tokenizer,
278
+ MAX_LEN
279
+ )
280
+ else:
281
+ dataset = ComplianceDataset(
282
+ texts.tolist(),
283
+ labels_array,
284
+ tokenizer,
285
+ MAX_LEN
286
+ )
287
+
288
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
289
+ metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
290
+ summary_metrics = summarize_metrics(metrics).to_dict()
291
+
292
+ all_probs = predict_probabilities(model, dataloader)
293
+
294
+ predictions = []
295
+ for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
296
+ field = LABEL_COLUMNS[i]
297
+ label_encoder = label_encoders[field]
298
+ true_labels_orig = label_encoder.inverse_transform(true_labels)
299
+ pred_labels_orig = label_encoder.inverse_transform(pred_labels)
300
+
301
+ for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
302
+ predictions.append({
303
+ "field": field,
304
+ "true_label": true,
305
+ "predicted_label": pred,
306
+ "probabilities": probs.tolist()
307
+ })
308
+
309
+ return ValidationResponse(
310
+ message="Validation completed successfully",
311
+ metrics=summary_metrics,
312
+ predictions=predictions
313
+ )
314
+
315
+ except Exception as e:
316
+ logger.error(f"Validation failed: {str(e)}")
317
+ raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
318
+ finally:
319
+ if os.path.exists(file_path):
320
+ os.remove(file_path)
321
+
322
+ @app.post("/v1/roberta/predict")
323
+ async def predict(
324
+ request: Optional[PredictionRequest] = None,
325
+ file: UploadFile = File(None),
326
+ model_name: str = "ROBERTA_model"
327
+ ):
328
+ """
329
+ Make predictions on either a single transaction or a batch of transactions from a CSV file.
330
+
331
+ You can either:
332
+ 1. Send a single transaction in the request body
333
+ 2. Upload a CSV file with multiple transactions
334
+
335
+ Parameters:
336
+ - file: CSV file containing transactions for batch prediction
337
+ - model_name: Name of the model to use for prediction (default: "ROBERTA_model")
338
+ """
339
+ try:
340
+ # Load the model
341
+ model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
342
+ if not model_path.exists():
343
+ raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
344
+
345
+ num_labels_list = [len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]
346
+ model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
347
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
348
+ model.eval()
349
+
350
+ # Handle batch prediction from CSV
351
+ if file and file.filename:
352
+ if not file.filename.endswith('.csv'):
353
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
354
+
355
+ file_path = UPLOAD_DIR / file.filename
356
+ with file_path.open("wb") as buffer:
357
+ shutil.copyfileobj(file.file, buffer)
358
+
359
+ try:
360
+ # Load and preprocess the CSV data
361
+ data_df, _ = load_and_preprocess_data(str(file_path))
362
+ texts = data_df[TEXT_COLUMN]
363
+
364
+ # Create dataset and dataloader
365
+ dataset = ComplianceDataset(
366
+ texts.tolist(),
367
+ [[0] * len(LABEL_COLUMNS)] * len(texts), # Dummy labels for prediction
368
+ tokenizer,
369
+ MAX_LEN
370
+ )
371
+ loader = DataLoader(dataset, batch_size=BATCH_SIZE)
372
+
373
+ # Get predictions
374
+ all_probabilities = predict_probabilities(model, loader)
375
+ label_encoders = load_label_encoders()
376
+
377
+ # Process predictions
378
+ predictions = []
379
+ for i, row in data_df.iterrows():
380
+ transaction_pred = {}
381
+ for j, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
382
+ pred = np.argmax(probs[i])
383
+ decoded_pred = label_encoders[col].inverse_transform([pred])[0]
384
+
385
+ class_probs = {
386
+ label: float(probs[i][j])
387
+ for j, label in enumerate(label_encoders[col].classes_)
388
+ }
389
+
390
+ transaction_pred[col] = {
391
+ "prediction": decoded_pred,
392
+ "probabilities": class_probs
393
+ }
394
+
395
+ predictions.append({
396
+ "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
397
+ "predictions": transaction_pred
398
+ })
399
+
400
+ return BatchPredictionResponse(
401
+ message="Batch prediction completed successfully",
402
+ predictions=predictions
403
+ )
404
+
405
+ finally:
406
+ if os.path.exists(file_path):
407
+ os.remove(file_path)
408
+
409
+ # Handle single prediction
410
+ elif request and request.transaction_data:
411
+ input_data = pd.DataFrame([request.transaction_data.dict()])
412
+
413
+ text_input = f"<s>Transaction ID: {input_data['Transaction_Id'].iloc[0]} Origin: {input_data['Origin'].iloc[0]} Designation: {input_data['Designation'].iloc[0]} Keywords: {input_data['Keywords'].iloc[0]} Name: {input_data['Name'].iloc[0]} SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]} Currency: {input_data['Currency'].iloc[0]} Entity: {input_data['Entity'].iloc[0]} Message: {input_data['Message'].iloc[0]} City: {input_data['City'].iloc[0]} Country: {input_data['Country'].iloc[0]} State: {input_data['State'].iloc[0]} Hit Type: {input_data['Hit_Type'].iloc[0]} Record Matching String: {input_data['Record_Matching_String'].iloc[0]} WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]} Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]} Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]} Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]} Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]} Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]} Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]} Risk Level: {input_data['Risk_Level'].iloc[0]} Risk Score: {input_data['Risk_Score'].iloc[0]} CDD Level: {input_data['CDD_Level'].iloc[0]} PEP Status: {input_data['PEP_Status'].iloc[0]} Sanction Description: {input_data['Sanction_Description'].iloc[0]} Checker Notes: {input_data['Checker_Notes'].iloc[0]} Sanction Context: {input_data['Sanction_Context'].iloc[0]}</s>"
414
+
415
+ dataset = ComplianceDataset(
416
+ texts=[text_input],
417
+ labels=[[0] * len(LABEL_COLUMNS)],
418
+ tokenizer=tokenizer,
419
+ max_len=MAX_LEN
420
+ )
421
+
422
+ loader = DataLoader(dataset, batch_size=1, shuffle=False)
423
+ all_probabilities = predict_probabilities(model, loader)
424
+
425
+ label_encoders = load_label_encoders()
426
+
427
+ response = {}
428
+ for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
429
+ pred = np.argmax(probs[0])
430
+ decoded_pred = label_encoders[col].inverse_transform([pred])[0]
431
+
432
+ class_probs = {
433
+ label: float(probs[0][j])
434
+ for j, label in enumerate(label_encoders[col].classes_)
435
+ }
436
+
437
+ response[col] = {
438
+ "prediction": decoded_pred,
439
+ "probabilities": class_probs
440
+ }
441
+
442
+ return response
443
+
444
+ else:
445
+ raise HTTPException(
446
+ status_code=400,
447
+ detail="Either provide a transaction in the request body or upload a CSV file"
448
+ )
449
+
450
+ except Exception as e:
451
+ raise HTTPException(status_code=500, detail=str(e))
452
+
453
+ @app.get("/v1/roberta/download-model/{model_id}")
454
+ async def download_model(model_id: str):
455
+ """Download a trained model"""
456
+ model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
457
+ if not model_path.exists():
458
+ raise HTTPException(status_code=404, detail="Model not found")
459
+
460
+ return FileResponse(
461
+ path=model_path,
462
+ filename=f"roberta_model_{model_id}.pth",
463
+ media_type="application/octet-stream"
464
+ )
465
+
466
+ async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
467
+ try:
468
+ data_df_original, label_encoders = load_and_preprocess_data(file_path)
469
+ save_label_encoders(label_encoders)
470
+
471
+ texts = data_df_original[TEXT_COLUMN]
472
+ labels_array = data_df_original[LABEL_COLUMNS].values
473
+
474
+ metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
475
+
476
+ num_labels_list = get_num_labels(label_encoders)
477
+ tokenizer = get_tokenizer(config.model_name)
478
+
479
+ if metadata_df is not None:
480
+ metadata_dim = metadata_df.shape[1]
481
+ dataset = ComplianceDatasetWithMetadata(
482
+ texts.tolist(),
483
+ metadata_df.values,
484
+ labels_array,
485
+ tokenizer,
486
+ config.max_length
487
+ )
488
+ model = RobertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
489
+ else:
490
+ dataset = ComplianceDataset(
491
+ texts.tolist(),
492
+ labels_array,
493
+ tokenizer,
494
+ config.max_length
495
+ )
496
+ model = RobertaMultiOutputModel(num_labels_list).to(DEVICE)
497
+
498
+ train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
499
+
500
+ criterions = initialize_criterions(data_df_original, label_encoders)  # train_utils expects (data_df, label_encoders)
501
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
502
+
503
+ for epoch in range(config.num_epochs):
504
+ training_status["current_epoch"] = epoch + 1
505
+
506
+ train_loss = train_model(model, train_loader, optimizer, criterions, epoch)  # matches the train_utils.train_model signature
507
+ training_status["current_loss"] = train_loss
508
+
509
+ # Save model after each epoch
510
+ save_model(model, training_id)
511
+
512
+ training_status.update({
513
+ "is_training": False,
514
+ "end_time": datetime.now().isoformat(),
515
+ "status": "completed"
516
+ })
517
+
518
+ except Exception as e:
519
+ logger.error(f"Training failed: {str(e)}")
520
+ training_status.update({
521
+ "is_training": False,
522
+ "end_time": datetime.now().isoformat(),
523
+ "status": "failed",
524
+ "error": str(e)
525
+ })
526
+
527
+ if __name__ == "__main__":
528
+ port = int(os.environ.get("PORT", 7860))
529
+ uvicorn.run(app, host="0.0.0.0", port=port)
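
The endpoints above can be exercised with a short client script. The following is a minimal sketch, assuming the service is running locally on its default port (7860), that a transactions.csv with the expected columns exists, and that the requests library is installed on the client side; the config values and file names are illustrative placeholders rather than part of this upload.

    import json
    import requests

    BASE = "http://localhost:7860/v1/roberta"

    # Start a training run: /train expects a JSON-encoded TrainingConfig in the
    # "config" form field plus a CSV upload.
    config = {"model_name": "roberta-base", "batch_size": 8, "num_epochs": 2}
    with open("transactions.csv", "rb") as f:
        resp = requests.post(
            f"{BASE}/train",
            data={"config": json.dumps(config)},
            files={"file": ("transactions.csv", f, "text/csv")},
        )
    print(resp.json())  # message, training_id, status, download_url

    # Poll training progress.
    print(requests.get(f"{BASE}/training-status").json())

    # Batch prediction from an uploaded CSV, as described in the /predict docstring.
    with open("transactions.csv", "rb") as f:
        preds = requests.post(
            f"{BASE}/predict",
            files={"file": ("transactions.csv", f, "text/csv")},
        ).json()
    print(preds["predictions"][:1])
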
config.py ADDED
@@ -0,0 +1,69 @@
1
+ # config.py
2
+
3
+ import torch
4
+ import os
5
+
6
+ # --- Paths ---
7
+ # Adjust DATA_PATH to your actual data location
8
+ DATA_PATH = './data/synthetic_transactions_samples_5000.csv'
9
+ TOKENIZER_PATH = './tokenizer/'
10
+ LABEL_ENCODERS_PATH = './label_encoders.pkl'
11
+ MODEL_SAVE_DIR = './saved_models/'
12
+ PREDICTIONS_SAVE_DIR = './predictions/' # To save predictions for voting ensemble
13
+
14
+ # --- Data Columns ---
15
+ TEXT_COLUMN = "Sanction_Context"
16
+ # Define all your target label columns
17
+ LABEL_COLUMNS = [
18
+ "Red_Flag_Reason",
19
+ "Maker_Action",
20
+ "Escalation_Level",
21
+ "Risk_Category",
22
+ "Risk_Drivers",
23
+ "Investigation_Outcome"
24
+ ]
25
+ # Example metadata columns. Add actual numerical/categorical metadata if available in your CSV.
26
+ # For now, it's an empty list. If you add metadata, ensure these columns exist and are numeric or can be encoded.
27
+ METADATA_COLUMNS = [] # e.g., ["Risk_Score", "Transaction_Amount"]
28
+
29
+ # --- Model Hyperparameters ---
30
+ MAX_LEN = 128 # Maximum sequence length for transformer tokenizers
31
+ BATCH_SIZE = 16 # Batch size for training and evaluation
32
+ LEARNING_RATE = 2e-5 # Learning rate for AdamW optimizer
33
+ NUM_EPOCHS = 3 # Number of training epochs. Adjust based on convergence.
34
+ DROPOUT_RATE = 0.3 # Dropout rate for regularization
35
+
36
+ # --- Device Configuration ---
37
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+
39
+ # --- Specific Model Configurations ---
40
+ ROBERTA_MODEL_NAME = 'roberta-base'
41
+ BERT_MODEL_NAME = 'bert-base-uncased'
42
+ DEBERTA_MODEL_NAME = 'microsoft/deberta-base'
43
+
44
+ # TF-IDF
45
+ TFIDF_MAX_FEATURES = 5000 # Max features for TF-IDF vectorizer
46
+
47
+ # --- Field-Specific Strategy (Conceptual) ---
48
+ # This dictionary provides conceptual strategies for enhancing specific fields.
49
+ # Actual implementation requires adapting the models (e.g., custom loss functions, metadata integration).
50
+ FIELD_STRATEGIES = {
51
+ "Maker_Action": {
52
+ "loss": "focal_loss", # Requires custom Focal Loss implementation
53
+ "enhancements": ["action_templates", "context_prompt_tuning"] # Advanced NLP concepts
54
+ },
55
+ "Risk_Category": {
56
+ "enhancements": ["numerical_metadata", "transaction_patterns"] # Integrate METADATA_COLUMNS
57
+ },
58
+ "Escalation_Level": {
59
+ "enhancements": ["class_balancing", "policy_keyword_patterns"] # Handled by class weights/metadata
60
+ },
61
+ "Investigation_Outcome": {
62
+ "type": "classification_or_generation" # If generation, T5/BART would be needed.
63
+ }
64
+ }
65
+
66
+ # Ensure model save and predictions directories exist
67
+ os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
68
+ os.makedirs(PREDICTIONS_SAVE_DIR, exist_ok=True)
69
+ os.makedirs(TOKENIZER_PATH, exist_ok=True)
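
FIELD_STRATEGIES above flags "focal_loss" for Maker_Action and notes that it requires a custom Focal Loss implementation, which is not part of this upload. The following is a hedged sketch of such a criterion, assuming the usual multi-class formulation; the class name and default gamma are choices made here, not the author's. It could be swapped in for CrossEntropyLoss for that one field inside train_utils.initialize_criterions.

    from typing import Optional

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FocalLoss(nn.Module):
        """Multi-class focal loss: down-weights easy examples relative to plain cross-entropy."""

        def __init__(self, gamma: float = 2.0, weight: Optional[torch.Tensor] = None):
            super().__init__()
            self.gamma = gamma        # focusing parameter; 2.0 is a common default
            self.weight = weight      # optional per-class weights, same role as in CrossEntropyLoss

        def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            log_probs = F.log_softmax(logits, dim=-1)
            log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the true class
            focal = (1.0 - log_p_t.exp()) ** self.gamma * (-log_p_t)
            if self.weight is not None:
                focal = focal * self.weight[targets]  # per-class weighting, e.g. from get_class_weights
            return focal.mean()

    # e.g. field_criterions["Maker_Action"] = FocalLoss(weight=class_weights.to(DEVICE))
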
dataset_utils.py ADDED
@@ -0,0 +1,165 @@
1
+ # dataset_utils.py
2
+
3
+ import pandas as pd
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from sklearn.preprocessing import LabelEncoder
7
+ from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
8
+ import pickle
9
+ import os
10
+
11
+ from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS
12
+
13
+ class ComplianceDataset(Dataset):
14
+ """
15
+ Custom Dataset class for handling text and multi-output labels for PyTorch models.
16
+ """
17
+ def __init__(self, texts, labels, tokenizer, max_len):
18
+ self.texts = texts
19
+ self.labels = labels
20
+ self.tokenizer = tokenizer
21
+ self.max_len = max_len
22
+
23
+ def __len__(self):
24
+ """Returns the total number of samples in the dataset."""
25
+ return len(self.texts)
26
+
27
+ def __getitem__(self, idx):
28
+ """
29
+ Retrieves a sample from the dataset at the given index.
30
+ Tokenizes the text and converts labels to a PyTorch tensor.
31
+ """
32
+ text = str(self.texts[idx])
33
+ # Tokenize the text, padding to max_length and truncating if longer.
34
+ # return_tensors="pt" ensures PyTorch tensors are returned.
35
+ inputs = self.tokenizer(
36
+ text,
37
+ padding='max_length',
38
+ truncation=True,
39
+ max_length=self.max_len,
40
+ return_tensors="pt"
41
+ )
42
+ # Squeeze removes the batch dimension (which is 1 here because we process one sample at a time)
43
+ inputs = {key: val.squeeze(0) for key, val in inputs.items()}
44
+ # Convert labels to a PyTorch long tensor
45
+ labels = torch.tensor(self.labels[idx], dtype=torch.long)
46
+ return inputs, labels
47
+
48
+ class ComplianceDatasetWithMetadata(Dataset):
49
+ """
50
+ Custom Dataset class for handling text, additional numerical metadata, and multi-output labels.
51
+ Used for hybrid models combining text and tabular features.
52
+ """
53
+ def __init__(self, texts, metadata, labels, tokenizer, max_len):
54
+ self.texts = texts
55
+ self.metadata = metadata # Expects metadata as a NumPy array or list of lists
56
+ self.labels = labels
57
+ self.tokenizer = tokenizer
58
+ self.max_len = max_len
59
+
60
+ def __len__(self):
61
+ """Returns the total number of samples in the dataset."""
62
+ return len(self.texts)
63
+
64
+ def __getitem__(self, idx):
65
+ """
66
+ Retrieves a sample, its metadata, and labels from the dataset at the given index.
67
+ Tokenizes text, converts metadata and labels to PyTorch tensors.
68
+ """
69
+ text = str(self.texts[idx])
70
+ inputs = self.tokenizer(
71
+ text,
72
+ padding='max_length',
73
+ truncation=True,
74
+ max_length=self.max_len,
75
+ return_tensors="pt"
76
+ )
77
+ inputs = {key: val.squeeze(0) for key, val in inputs.items()}
78
+ # Convert metadata for the current sample to a float tensor
79
+ metadata = torch.tensor(self.metadata[idx], dtype=torch.float)
80
+ labels = torch.tensor(self.labels[idx], dtype=torch.long)
81
+ return inputs, metadata, labels
82
+
83
+ def load_and_preprocess_data(data_path):
84
+ """
85
+ Loads data from a CSV, fills missing values, and encodes categorical labels.
86
+ Also handles converting specified METADATA_COLUMNS to numeric.
87
+
88
+ Args:
89
+ data_path (str): Path to the CSV data file.
90
+
91
+ Returns:
92
+ tuple: A tuple containing:
93
+ - data (pd.DataFrame): The preprocessed DataFrame.
94
+ - label_encoders (dict): A dictionary of LabelEncoder objects for each label column.
95
+ """
96
+ data = pd.read_csv(data_path)
97
+ data.fillna("Unknown", inplace=True) # Fill any missing text values with "Unknown"
98
+
99
+ # Convert metadata columns to numeric, coercing errors and filling NaNs with 0
100
+ # This ensures metadata is suitable for neural networks.
101
+ for col in METADATA_COLUMNS:
102
+ if col in data.columns:
103
+ data[col] = pd.to_numeric(data[col], errors='coerce').fillna(0) # Fill NaN with 0 or a suitable value
104
+
105
+ label_encoders = {col: LabelEncoder() for col in LABEL_COLUMNS}
106
+ for col in LABEL_COLUMNS:
107
+ # Fit and transform each label column using its respective LabelEncoder
108
+ data[col] = label_encoders[col].fit_transform(data[col])
109
+ return data, label_encoders
110
+
111
+ def get_tokenizer(model_name):
112
+ """
113
+ Returns the appropriate Hugging Face tokenizer based on the model name.
114
+
115
+ Args:
116
+ model_name (str): The name of the pre-trained model (e.g., 'bert-base-uncased').
117
+
118
+ Returns:
119
+ transformers.PreTrainedTokenizer: The initialized tokenizer.
120
+ """
121
+ if "roberta" in model_name.lower():
122
+ return RobertaTokenizer.from_pretrained(model_name)
123
+ elif "bert" in model_name.lower():
124
+ return BertTokenizer.from_pretrained(model_name)
125
+ elif "deberta" in model_name.lower():
126
+ return DebertaTokenizer.from_pretrained(model_name)
127
+ else:
128
+ raise ValueError(f"Unsupported tokenizer for model: {model_name}")
129
+
130
+ def save_label_encoders(label_encoders):
131
+ """
132
+ Saves a dictionary of label encoders to a pickle file.
133
+ This is crucial for decoding predictions back to original labels.
134
+
135
+ Args:
136
+ label_encoders (dict): Dictionary of LabelEncoder objects.
137
+ """
138
+ with open(LABEL_ENCODERS_PATH, "wb") as f:
139
+ pickle.dump(label_encoders, f)
140
+ print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")
141
+
142
+ def load_label_encoders():
143
+ """
144
+ Loads a dictionary of label encoders from a pickle file.
145
+
146
+ Returns:
147
+ dict: Loaded dictionary of LabelEncoder objects.
148
+ """
149
+ with open(LABEL_ENCODERS_PATH, "rb") as f:
150
+ label_encoders = pickle.load(f)
151
+ print(f"Label encoders loaded from {LABEL_ENCODERS_PATH}")
+ return label_encoders
152
+
153
+
154
+ def get_num_labels(label_encoders):
155
+ """
156
+ Returns a list containing the number of unique classes for each label column.
157
+ This list is used to define the output dimensions of the model's classification heads.
158
+
159
+ Args:
160
+ label_encoders (dict): Dictionary of LabelEncoder objects.
161
+
162
+ Returns:
163
+ list: A list of integers, where each integer is the number of classes for a label.
164
+ """
165
+ return [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
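
As a quick orientation, the helpers above compose into the same pipeline app.py builds at training time. This is a minimal round-trip sketch, assuming DATA_PATH from config.py points at an existing CSV that contains TEXT_COLUMN and all LABEL_COLUMNS.

    from torch.utils.data import DataLoader

    from config import (
        DATA_PATH, TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, BATCH_SIZE, ROBERTA_MODEL_NAME,
    )
    from dataset_utils import (
        ComplianceDataset, load_and_preprocess_data, get_tokenizer,
        save_label_encoders, get_num_labels,
    )

    data, label_encoders = load_and_preprocess_data(DATA_PATH)
    save_label_encoders(label_encoders)           # needed later to decode predictions
    num_labels = get_num_labels(label_encoders)   # one class count per entry in LABEL_COLUMNS

    tokenizer = get_tokenizer(ROBERTA_MODEL_NAME)
    dataset = ComplianceDataset(
        data[TEXT_COLUMN].tolist(),
        data[LABEL_COLUMNS].values,
        tokenizer,
        MAX_LEN,
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    inputs, labels = next(iter(loader))           # inputs: dict of input_ids / attention_mask tensors
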
dockerfile ADDED
@@ -0,0 +1,51 @@
1
+ # Use Python 3.9 as base image
2
+ FROM python:3.9-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies
8
+ RUN apt-get update && apt-get install -y \
9
+ build-essential \
10
+ curl \
11
+ software-properties-common \
12
+ git \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Copy requirements file
16
+ COPY requirements.txt .
17
+
18
+ # Install Python dependencies
19
+ RUN pip install --no-cache-dir -r requirements.txt
20
+
21
+ # Create necessary directories with proper permissions
22
+ RUN mkdir -p /app/uploads \
23
+ /app/saved_models/bert \
24
+ /app/predictions \
25
+ /app/tokenizer \
26
+ /app/cache \
27
+ && chmod -R 777 /app/uploads \
28
+ /app/saved_models \
29
+ /app/predictions \
30
+ /app/tokenizer \
31
+ /app/cache
32
+
33
+ # Copy the application code and utilities
34
+ COPY . /app/
35
+ # dataset_utils.py, train_utils.py, config.py, models/roberta_model.py and
36
+ # label_encoders.pkl are part of the build context and are already included by the
37
+ # COPY . /app/ line above; Docker cannot COPY paths outside the build context
38
+ # (e.g. ../dataset_utils.py), so such lines would fail the build.
40
+
41
+ # Set environment variables
42
+ ENV PYTHONPATH=/app
43
+ ENV PYTHONUNBUFFERED=1
44
+ ENV PORT=7860
45
+ ENV TRANSFORMERS_CACHE=/app/cache
46
+
47
+ # Expose the port the app runs on
48
+ EXPOSE 7860
49
+
50
+ # Command to run the application
51
+ CMD ["python", "app.py"]
label_encoders.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c336fd07858af76d40c7200de1a769099abeec25d4f48b999351318680d4e4d6
3
+ size 2047
models/__pycache__/roberta_model.cpython-311.pyc ADDED
Binary file (3.19 kB).
 
models/roberta_model.py ADDED
@@ -0,0 +1,56 @@
1
+ # models/roberta_model.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import RobertaModel
6
+ from config import DROPOUT_RATE, ROBERTA_MODEL_NAME # Import ROBERTA_MODEL_NAME
7
+
8
+ class RobertaMultiOutputModel(nn.Module):
9
+ """
10
+ RoBERTa-based model for multi-output classification.
11
+ Uses a pre-trained RoBERTa model as its backbone. RoBERTa is an optimized
12
+ version of BERT, often performing better.
13
+ """
14
+ # Statically set tokenizer name for easy access in main.py
15
+ tokenizer_name = ROBERTA_MODEL_NAME
16
+
17
+ def __init__(self, num_labels):
18
+ """
19
+ Initializes the RobertaMultiOutputModel.
20
+
21
+ Args:
22
+ num_labels (list): A list where each element is the number of classes
23
+ for a corresponding label column.
24
+ """
25
+ super(RobertaMultiOutputModel, self).__init__()
26
+ # Load the pre-trained RoBERTa model.
27
+ # RoBERTa's pooler_output typically corresponds to the hidden state of the
28
+ # first token (<s>), which is often used for sequence classification.
29
+ self.roberta = RobertaModel.from_pretrained(ROBERTA_MODEL_NAME)
30
+ self.dropout = nn.Dropout(DROPOUT_RATE) # Dropout layer
31
+
32
+ # Create classification heads for each label column.
33
+ self.classifiers = nn.ModuleList([
34
+ nn.Linear(self.roberta.config.hidden_size, n_classes) for n_classes in num_labels
35
+ ])
36
+
37
+ def forward(self, input_ids, attention_mask):
38
+ """
39
+ Performs the forward pass of the model.
40
+
41
+ Args:
42
+ input_ids (torch.Tensor): Tensor of token IDs.
43
+ attention_mask (torch.Tensor): Tensor indicating attention.
44
+
45
+ Returns:
46
+ list: A list of logit tensors, one for each classification head.
47
+ """
48
+ # Pass input_ids and attention_mask through RoBERTa.
49
+ # .pooler_output is used for classification.
50
+ pooled_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask).pooler_output
51
+
52
+ # Apply dropout
53
+ pooled_output = self.dropout(pooled_output)
54
+
55
+ # Pass the pooled output through each classification head.
56
+ return [classifier(pooled_output) for classifier in self.classifiers]
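
Note that app.py and train_utils.py construct this model with metadata_dim=... and call it with a third metadata tensor whenever METADATA_COLUMNS is non-empty, which the class above does not accept. The sketch below shows one way that call pattern could be supported; the class name and the concatenation scheme are assumptions made here, not part of the upload.

    import torch
    import torch.nn as nn
    from transformers import RobertaModel
    from config import DROPOUT_RATE, ROBERTA_MODEL_NAME

    class RobertaMultiOutputModelWithMetadata(nn.Module):
        """RoBERTa backbone whose pooled output is concatenated with numeric metadata features."""

        def __init__(self, num_labels, metadata_dim: int = 0):
            super().__init__()
            self.roberta = RobertaModel.from_pretrained(ROBERTA_MODEL_NAME)
            self.dropout = nn.Dropout(DROPOUT_RATE)
            head_in = self.roberta.config.hidden_size + metadata_dim
            # One classification head per label column, fed with [pooled_output ; metadata].
            self.classifiers = nn.ModuleList([nn.Linear(head_in, n) for n in num_labels])

        def forward(self, input_ids, attention_mask, metadata=None):
            pooled = self.roberta(input_ids=input_ids, attention_mask=attention_mask).pooler_output
            pooled = self.dropout(pooled)
            if metadata is not None:
                pooled = torch.cat([pooled, metadata], dim=1)  # append the numeric metadata columns
            return [classifier(pooled) for classifier in self.classifiers]
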
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ fastapi==0.104.1
2
+ uvicorn==0.24.0
3
+ pydantic==2.4.2
4
+ torch==2.1.0
5
+ transformers==4.35.0
6
+ pandas==2.1.2
7
+ numpy==1.24.3
8
+ scikit-learn==1.3.2
9
+ python-multipart==0.0.6
10
+ python-jose==3.3.0
11
+ passlib==1.7.4
12
+ bcrypt==4.0.1
13
+ python-dotenv==1.0.0
saved_models/ROBERTA_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9850fc49de688e8971e6950fe06656da078f65c92eba58dd60a569172b0a089c
3
+ size 498897494
train_utils.py ADDED
@@ -0,0 +1,310 @@
1
+ # train_utils.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.optim import AdamW
6
+ from sklearn.metrics import classification_report
7
+ from sklearn.utils.class_weight import compute_class_weight
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ import pandas as pd
11
+ import os
12
+ import joblib
13
+
14
+ from config import DEVICE, LABEL_COLUMNS, NUM_EPOCHS, LEARNING_RATE, MODEL_SAVE_DIR
15
+
16
+ def get_class_weights(data_df, field, label_encoder):
17
+ """
18
+ Computes balanced class weights for a given target field.
19
+ These weights can be used in the loss function to mitigate class imbalance.
20
+
21
+ Args:
22
+ data_df (pd.DataFrame): The DataFrame containing the original (unencoded) label data.
23
+ field (str): The name of the label column for which to compute weights.
24
+ label_encoder (sklearn.preprocessing.LabelEncoder): The label encoder fitted for this field.
25
+
26
+ Returns:
27
+ torch.Tensor: A tensor of class weights for the specified field.
28
+ """
29
+ # Get the original labels for the specified field
30
+ y = data_df[field].values
31
+ # Use label_encoder.transform directly - it will handle unseen labels
32
+ try:
33
+ y_encoded = label_encoder.transform(y)
34
+ except ValueError as e:
35
+ print(f"Warning: {e}")
36
+ print(f"Using only seen labels for class weights calculation")
37
+ # Filter out unseen labels
38
+ seen_labels = set(label_encoder.classes_)
39
+ y_filtered = [label for label in y if label in seen_labels]
40
+ y_encoded = label_encoder.transform(y_filtered)
41
+
42
+ # Ensure y_encoded is integer type
43
+ y_encoded = y_encoded.astype(int)
44
+
45
+ # Initialize counts for all possible classes
46
+ n_classes = len(label_encoder.classes_)
47
+ class_counts = np.zeros(n_classes, dtype=int)
48
+
49
+ # Count occurrences of each class
50
+ for i in range(n_classes):
51
+ class_counts[i] = np.sum(y_encoded == i)
52
+
53
+ # Calculate weights for all classes
54
+ total_samples = len(y_encoded)
55
+ class_weights = np.ones(n_classes) # Default weight of 1 for unseen classes
56
+ seen_classes = class_counts > 0
57
+ if np.any(seen_classes):
58
+ class_weights[seen_classes] = total_samples / (np.sum(seen_classes) * class_counts[seen_classes])
59
+
60
+ return torch.tensor(class_weights, dtype=torch.float)
61
+
62
+ def initialize_criterions(data_df, label_encoders):
63
+ """
64
+ Initializes CrossEntropyLoss criteria for each label column, applying class weights.
65
+
66
+ Args:
67
+ data_df (pd.DataFrame): The original (unencoded) DataFrame. Used to compute class weights.
68
+ label_encoders (dict): Dictionary of LabelEncoder objects.
69
+
70
+ Returns:
71
+ dict: A dictionary where keys are label column names and values are
72
+ initialized `torch.nn.CrossEntropyLoss` objects.
73
+ """
74
+ field_criterions = {}
75
+ for field in LABEL_COLUMNS:
76
+ # Get class weights for the current field
77
+ weights = get_class_weights(data_df, field, label_encoders[field])
78
+ # Initialize CrossEntropyLoss with the computed weights and move to the device
79
+ field_criterions[field] = torch.nn.CrossEntropyLoss(weight=weights.to(DEVICE))
80
+ return field_criterions
81
+
82
+ def train_model(model, loader, optimizer, field_criterions, epoch):
83
+ """
84
+ Trains the given PyTorch model for one epoch.
85
+
86
+ Args:
87
+ model (torch.nn.Module): The model to train.
88
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
89
+ optimizer (torch.optim.Optimizer): Optimizer for model parameters.
90
+ field_criterions (dict): Dictionary of loss functions for each label.
91
+ epoch (int): Current epoch number (for progress bar description).
92
+
93
+ Returns:
94
+ float: Average training loss for the epoch.
95
+ """
96
+ model.train() # Set the model to training mode
97
+ total_loss = 0
98
+ # Use tqdm for a progress bar during training
99
+ tqdm_loader = tqdm(loader, desc=f"Epoch {epoch + 1} Training")
100
+
101
+ for batch in tqdm_loader:
102
+ # Unpack batch based on whether it contains metadata
103
+ if len(batch) == 2: # Text-only models (inputs, labels)
104
+ inputs, labels = batch
105
+ input_ids = inputs['input_ids'].to(DEVICE)
106
+ attention_mask = inputs['attention_mask'].to(DEVICE)
107
+ labels = labels.to(DEVICE)
108
+ # Forward pass through the model
109
+ outputs = model(input_ids, attention_mask)
110
+ elif len(batch) == 3: # Text + Metadata models (inputs, metadata, labels)
111
+ inputs, metadata, labels = batch
112
+ input_ids = inputs['input_ids'].to(DEVICE)
113
+ attention_mask = inputs['attention_mask'].to(DEVICE)
114
+ metadata = metadata.to(DEVICE)
115
+ labels = labels.to(DEVICE)
116
+ # Forward pass through the hybrid model
117
+ outputs = model(input_ids, attention_mask, metadata)
118
+ else:
119
+ raise ValueError("Unsupported batch format. Expected 2 or 3 items in batch.")
120
+
121
+ loss = 0
122
+ # Calculate total loss by summing loss for each label column
123
+ # `outputs` is a list of logits, one for each label column
124
+ for i, output_logits in enumerate(outputs):
125
+ # `labels[:, i]` gets the true labels for the i-th label column
126
+ # `field_criterions[LABEL_COLUMNS[i]]` selects the appropriate loss function
127
+ loss += field_criterions[LABEL_COLUMNS[i]](output_logits, labels[:, i])
128
+
129
+ optimizer.zero_grad() # Clear previous gradients
130
+ loss.backward() # Backpropagation
131
+ optimizer.step() # Update model parameters
132
+ total_loss += loss.item() # Accumulate loss
133
+ tqdm_loader.set_postfix(loss=loss.item()) # Update progress bar with current batch loss
134
+
135
+ return total_loss / len(loader) # Return average loss for the epoch
136
+
137
+ def evaluate_model(model, loader):
138
+ """
139
+ Evaluates the given PyTorch model on a validation/test set.
140
+
141
+ Args:
142
+ model (torch.nn.Module): The model to evaluate.
143
+ loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
144
+
145
+ Returns:
146
+ tuple: A tuple containing:
147
+ - reports (dict): Classification reports (dict format) for each label column.
148
+ - truths (list): List of true label arrays for each label column.
149
+ - predictions (list): List of predicted label arrays for each label column.
150
+ """
151
+ model.eval() # Set the model to evaluation mode (disables dropout, batch norm updates, etc.)
152
+ # Initialize lists to store predictions and true labels for each output head
153
+ predictions = [[] for _ in range(len(LABEL_COLUMNS))]
154
+ truths = [[] for _ in range(len(LABEL_COLUMNS))]
155
+
156
+ with torch.no_grad(): # Disable gradient calculations during evaluation for efficiency
157
+ for batch in tqdm(loader, desc="Evaluation"):
158
+ if len(batch) == 2:
159
+ inputs, labels = batch
160
+ input_ids = inputs['input_ids'].to(DEVICE)
161
+ attention_mask = inputs['attention_mask'].to(DEVICE)
162
+ labels = labels.to(DEVICE)
163
+ outputs = model(input_ids, attention_mask)
164
+ elif len(batch) == 3:
165
+ inputs, metadata, labels = batch
166
+ input_ids = inputs['input_ids'].to(DEVICE)
167
+ attention_mask = inputs['attention_mask'].to(DEVICE)
168
+ metadata = metadata.to(DEVICE)
169
+ labels = labels.to(DEVICE)
170
+ outputs = model(input_ids, attention_mask, metadata)
171
+ else:
172
+ raise ValueError("Unsupported batch format.")
173
+
174
+ for i, output_logits in enumerate(outputs):
175
+ # Get the predicted class by taking the argmax of the logits
176
+ preds = torch.argmax(output_logits, dim=1).cpu().numpy()
177
+ predictions[i].extend(preds)
178
+ # Get the true labels for the current output head
179
+ truths[i].extend(labels[:, i].cpu().numpy())
180
+
181
+ reports = {}
182
+ # Generate classification report for each label column
183
+ for i, col in enumerate(LABEL_COLUMNS):
184
+ try:
185
+ # `zero_division=0` handles cases where a class might have no true or predicted samples
186
+ reports[col] = classification_report(truths[i], predictions[i], output_dict=True, zero_division=0)
187
+ except ValueError:
188
+ # Handle cases where a label might not appear in the validation set,
189
+ # which could cause classification_report to fail.
190
+ print(f"Warning: Could not generate classification report for {col}. Skipping.")
191
+ reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
192
+ return reports, truths, predictions
193
+
194
+ def summarize_metrics(metrics):
195
+ """
196
+ Summarizes classification reports into a readable Pandas DataFrame.
197
+
198
+ Args:
199
+ metrics (dict): Dictionary of classification reports, as returned by `evaluate_model`.
200
+
201
+ Returns:
202
+ pd.DataFrame: A DataFrame summarizing precision, recall, f1-score, accuracy, and support for each field.
203
+ """
204
+ summary = []
205
+ for field, report in metrics.items():
206
+ # Safely get metrics, defaulting to 0 if not present (e.g., for empty reports)
207
+ precision = report['weighted avg']['precision'] if 'weighted avg' in report else 0
208
+ recall = report['weighted avg']['recall'] if 'weighted avg' in report else 0
209
+ f1 = report['weighted avg']['f1-score'] if 'weighted avg' in report else 0
210
+ support = report['weighted avg']['support'] if 'weighted avg' in report else 0
211
+ accuracy = report['accuracy'] if 'accuracy' in report else 0 # Accuracy is usually top-level
212
+ summary.append({
213
+ "Field": field,
214
+ "Precision": precision,
215
+ "Recall": recall,
216
+ "F1-Score": f1,
217
+ "Accuracy": accuracy,
218
+ "Support": support
219
+ })
220
+ return pd.DataFrame(summary)
221
+
222
+ def save_model(model, model_name, save_format='pth'):
223
+ """
224
+ Saves the state dictionary of a PyTorch model.
225
+
226
+ Args:
227
+ model (torch.nn.Module): The trained PyTorch model.
228
+ model_name (str): A descriptive name for the model (used for filename).
229
+ save_format (str): Format to save the model in ('pth' for PyTorch models, 'pickle' for traditional ML models).
230
+ """
231
+ # Construct the save path dynamically relative to the project root
232
+ if save_format == 'pth':
233
+ model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
234
+ torch.save(model.state_dict(), model_path)
235
+ elif save_format == 'pickle':
236
+ model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
237
+ joblib.dump(model, model_path)
238
+ else:
239
+ raise ValueError(f"Unsupported save format: {save_format}")
240
+
241
+ print(f"Model saved to {model_path}")
242
+
243
+ def load_model_state(model, model_name, model_class, num_labels, metadata_dim=0):
244
+ """
245
+ Loads the state dictionary into a PyTorch model.
246
+
247
+ Args:
248
+ model (torch.nn.Module): An initialized model instance (architecture).
249
+ model_name (str): The name of the model to load.
250
+ model_class (class): The class of the model (e.g., RobertaMultiOutputModel).
251
+ num_labels (list): List of number of classes for each label.
252
+ metadata_dim (int): Dimensionality of metadata features, if applicable (default 0 for text-only).
253
+
254
+ Returns:
255
+ torch.nn.Module: The model with loaded state_dict, moved to the correct device, and set to eval mode.
256
+ """
257
+ model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
258
+ if not os.path.exists(model_path):
259
+ print(f"Warning: Model file not found at {model_path}. Returning a newly initialized model instance.")
260
+ # Re-initialize the model if not found, to ensure it has the correct architecture
261
+ if metadata_dim > 0:
262
+ return model_class(num_labels, metadata_dim=metadata_dim).to(DEVICE)
263
+ else:
264
+ return model_class(num_labels).to(DEVICE)
265
+
266
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
267
+ model.to(DEVICE)
268
+ model.eval() # Set to evaluation mode after loading
269
+ print(f"Model loaded from {model_path}")
270
+ return model
271
+
272
+ def predict_probabilities(model, loader):
273
+ """
274
+ Generates prediction probabilities for each label for a given model.
275
+ This is used for confidence scoring and feeding into a voting ensemble.
276
+
277
+ Args:
278
+ model (torch.nn.Module): The trained PyTorch model.
279
+ loader (torch.utils.data.DataLoader): DataLoader for the data to predict on.
280
+
281
+ Returns:
282
+ list: A list of lists of numpy arrays. Each inner list corresponds to a label column,
283
+ containing the softmax probabilities for each sample for that label.
284
+ """
285
+ model.eval() # Set to evaluation mode
286
+ # List to store probabilities for each output head
287
+ all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]
288
+
289
+ with torch.no_grad():
290
+ for batch in tqdm(loader, desc="Predicting Probabilities"):
291
+ # Unpack batch, ignoring labels as we only need inputs
292
+ if len(batch) == 2:
293
+ inputs, _ = batch
294
+ input_ids = inputs['input_ids'].to(DEVICE)
295
+ attention_mask = inputs['attention_mask'].to(DEVICE)
296
+ outputs = model(input_ids, attention_mask)
297
+ elif len(batch) == 3:
298
+ inputs, metadata, _ = batch
299
+ input_ids = inputs['input_ids'].to(DEVICE)
300
+ attention_mask = inputs['attention_mask'].to(DEVICE)
301
+ metadata = metadata.to(DEVICE)
302
+ outputs = model(input_ids, attention_mask, metadata)
303
+ else:
304
+ raise ValueError("Unsupported batch format.")
305
+
306
+ for i, out_logits in enumerate(outputs):
307
+ # Apply softmax to logits to get probabilities
308
+ probs = torch.softmax(out_logits, dim=1).cpu().numpy()
309
+ all_probabilities[i].extend(probs)
310
+ return all_probabilities
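
config.py describes PREDICTIONS_SAVE_DIR as a place to save predictions for a voting ensemble, and predict_probabilities above returns per-field softmax scores in exactly the shape such an ensemble needs. The following is a hedged sketch of the soft-voting step this implies; the function name and plain averaging are assumptions, not code from this upload.

    import numpy as np

    from config import LABEL_COLUMNS

    def soft_vote(per_model_probabilities):
        """per_model_probabilities: one entry per model, each the output of predict_probabilities,
        i.e. a list over LABEL_COLUMNS of per-sample probability vectors."""
        ensembled = []
        for field_idx, _ in enumerate(LABEL_COLUMNS):
            stacked = np.stack([np.asarray(m[field_idx]) for m in per_model_probabilities])
            avg = stacked.mean(axis=0)            # average softmax scores across models
            ensembled.append(avg.argmax(axis=1))  # predicted class index per sample for this field
        return ensembled

    # Usage: preds = soft_vote([predict_probabilities(roberta_model, loader),
    #                           predict_probabilities(bert_model, loader)])
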