namanpenguin committed
Commit ad944b3 · verified · 1 Parent(s): 8cc5725

Upload 15 files

BERT_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7050d02ac599ef72d7b0410a79a72537fb44d4ac66eb8a1dc719329c8c4b07b
3
+ size 438239057
Dockerfile ADDED
@@ -0,0 +1,54 @@
1
+ # Use Python 3.9 as base image
2
+ FROM python:3.9-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies
8
+ RUN apt-get update && apt-get install -y \
9
+ build-essential \
10
+ curl \
11
+ software-properties-common \
12
+ git \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Create a non-root user
16
+ RUN useradd -m -u 1000 appuser
17
+
18
+ # Copy requirements file
19
+ COPY requirements.txt .
20
+
21
+ # Install Python dependencies
22
+ RUN pip install --no-cache-dir -r requirements.txt
23
+
24
+ # Create necessary directories with proper permissions
25
+ RUN mkdir -p /app/uploads \
26
+ /app/saved_models/bert \
27
+ /app/predictions \
28
+ /app/tokenizer \
29
+ && chmod -R 777 /app/uploads \
30
+ /app/saved_models \
31
+ /app/predictions \
32
+ /app/tokenizer
33
+
34
+ # Switch to non-root user
35
+ USER appuser
36
+
37
+ # Copy the application code and shared utilities into the image.
38
+ # NOTE: Docker cannot COPY paths outside the build context (e.g. ../config.py),
39
+ # so dataset_utils.py, train_utils.py, config.py, models/bert_model.py and
40
+ # label_encoders.pkl must be placed inside the build context before building;
41
+ # the single COPY below then includes them along with the app code.
42
+ COPY . /app/
44
+
45
+ # Set environment variables
46
+ ENV PYTHONPATH=/app
47
+ ENV PYTHONUNBUFFERED=1
48
+ ENV PORT=7860
49
+
50
+ # Expose the port the app runs on
51
+ EXPOSE 7860
52
+
53
+ # Command to run the application
54
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,504 @@
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
2
+ from fastapi.responses import FileResponse
3
+ from pydantic import BaseModel
4
+ from typing import Optional, Dict, Any, List
5
+ import uvicorn
6
+ import torch
7
+ from transformers import BertTokenizer, BertForSequenceClassification
8
+ from torch.utils.data import DataLoader
9
+ import logging
10
+ import os
11
+ import asyncio
12
+ import pandas as pd
13
+ from datetime import datetime
14
+ import shutil
15
+ from pathlib import Path
16
+ from sklearn.model_selection import train_test_split
17
+ import zipfile
18
+ import io
19
+ import numpy as np
20
+ import sys
21
+
22
+
23
+ # Import existing utilities
24
+ from dataset_utils import (
25
+ ComplianceDataset,
26
+ ComplianceDatasetWithMetadata,
27
+ load_and_preprocess_data,
28
+ get_tokenizer,
29
+ save_label_encoders,
30
+ get_num_labels,
31
+ load_label_encoders
32
+ )
33
+ from train_utils import (
34
+ initialize_criterions,
35
+ train_model,
36
+ evaluate_model,
37
+ save_model,
38
+ summarize_metrics,
39
+ predict_probabilities
40
+ )
41
+ from models.bert_model import BertMultiOutputModel
42
+ from config import (
43
+ TEXT_COLUMN,
44
+ LABEL_COLUMNS,
45
+ DEVICE,
46
+ NUM_EPOCHS,
47
+ LEARNING_RATE,
48
+ MAX_LEN,
49
+ BATCH_SIZE,
50
+ METADATA_COLUMNS
51
+ )
52
+
53
+ # Configure logging
54
+ logging.basicConfig(level=logging.INFO)
55
+ logger = logging.getLogger(__name__)
56
+
57
+ app = FastAPI(title="BERT Compliance Predictor API")
58
+
59
+ # Create necessary directories
60
+ UPLOAD_DIR = Path("uploads")
61
+ MODEL_SAVE_DIR = Path("saved_models")
62
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
63
+ MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
64
+
65
+ # Global variables to track training status
66
+ training_status = {
67
+ "is_training": False,
68
+ "current_epoch": 0,
69
+ "total_epochs": 0,
70
+ "current_loss": 0.0,
71
+ "start_time": None,
72
+ "end_time": None,
73
+ "status": "idle",
74
+ "metrics": None
75
+ }
76
+
77
+ # Load the model and tokenizer for prediction
78
+ model_path = "BERT_model.pth"
79
+ tokenizer = get_tokenizer('bert-base-uncased')
80
+ model = BertMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
81
+ if os.path.exists(model_path):
82
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
83
+ model.eval()
84
+
85
+ class TrainingConfig(BaseModel):
86
+ model_name: str = "bert-base-uncased"
87
+ batch_size: int = 8
88
+ learning_rate: float = 2e-5
89
+ num_epochs: int = 2
90
+ max_length: int = 128
91
+ test_size: float = 0.2
92
+ random_state: int = 42
93
+
94
+ class TrainingResponse(BaseModel):
95
+ message: str
96
+ training_id: str
97
+ status: str
98
+ download_url: Optional[str] = None
99
+
100
+ class ValidationResponse(BaseModel):
101
+ message: str
102
+ metrics: Dict[str, Any]
103
+ predictions: List[Dict[str, Any]]
104
+
105
+ class TransactionData(BaseModel):
106
+ Transaction_Id: str
107
+ Hit_Seq: int
108
+ Hit_Id_List: str
109
+ Origin: str
110
+ Designation: str
111
+ Keywords: str
112
+ Name: str
113
+ SWIFT_Tag: str
114
+ Currency: str
115
+ Entity: str
116
+ Message: str
117
+ City: str
118
+ Country: str
119
+ State: str
120
+ Hit_Type: str
121
+ Record_Matching_String: str
122
+ WatchList_Match_String: str
123
+ Payment_Sender_Name: Optional[str] = ""
124
+ Payment_Reciever_Name: Optional[str] = ""
125
+ Swift_Message_Type: str
126
+ Text_Sanction_Data: str
127
+ Matched_Sanctioned_Entity: str
128
+ Is_Match: int
129
+ Red_Flag_Reason: str
130
+ Risk_Level: str
131
+ Risk_Score: float
132
+ Risk_Score_Description: str
133
+ CDD_Level: str
134
+ PEP_Status: str
135
+ Value_Date: str
136
+ Last_Review_Date: str
137
+ Next_Review_Date: str
138
+ Sanction_Description: str
139
+ Checker_Notes: str
140
+ Sanction_Context: str
141
+ Maker_Action: str
142
+ Customer_ID: int
143
+ Customer_Type: str
144
+ Industry: str
145
+ Transaction_Date_Time: str
146
+ Transaction_Type: str
147
+ Transaction_Channel: str
148
+ Originating_Bank: str
149
+ Beneficiary_Bank: str
150
+ Geographic_Origin: str
151
+ Geographic_Destination: str
152
+ Match_Score: float
153
+ Match_Type: str
154
+ Sanctions_List_Version: str
155
+ Screening_Date_Time: str
156
+ Risk_Category: str
157
+ Risk_Drivers: str
158
+ Alert_Status: str
159
+ Investigation_Outcome: str
160
+ Case_Owner_Analyst: str
161
+ Escalation_Level: str
162
+ Escalation_Date: str
163
+ Regulatory_Reporting_Flags: bool
164
+ Audit_Trail_Timestamp: str
165
+ Source_Of_Funds: str
166
+ Purpose_Of_Transaction: str
167
+ Beneficial_Owner: str
168
+ Sanctions_Exposure_History: bool
169
+
170
+ class PredictionRequest(BaseModel):
171
+ transaction_data: TransactionData
172
+
173
+ @app.get("/")
174
+ async def root():
175
+ return {"message": "BERT Compliance Predictor API"}
176
+
177
+ @app.get("/health")
178
+ async def health_check():
179
+ return {"status": "healthy"}
180
+
181
+ @app.get("/training-status")
182
+ async def get_training_status():
183
+ return training_status
184
+
185
+ @app.post("/upload")
186
+ async def upload_file(file: UploadFile = File(...)):
187
+ """Upload a CSV file for training or validation"""
188
+ if not file.filename.endswith('.csv'):
189
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
190
+
191
+ file_path = UPLOAD_DIR / file.filename
192
+ with file_path.open("wb") as buffer:
193
+ shutil.copyfileobj(file.file, buffer)
194
+
195
+ return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
196
+
197
+ @app.post("/bert/train", response_model=TrainingResponse)
198
+ async def start_training(
199
+ config: TrainingConfig,
200
+ background_tasks: BackgroundTasks,
201
+ file_path: str
202
+ ):
203
+ if training_status["is_training"]:
204
+ raise HTTPException(status_code=400, detail="Training is already in progress")
205
+
206
+ if not os.path.exists(file_path):
207
+ raise HTTPException(status_code=404, detail="Training file not found")
208
+
209
+ training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
210
+
211
+ training_status.update({
212
+ "is_training": True,
213
+ "current_epoch": 0,
214
+ "total_epochs": config.num_epochs,
215
+ "start_time": datetime.now().isoformat(),
216
+ "status": "starting"
217
+ })
218
+
219
+ background_tasks.add_task(train_model_task, config, file_path, training_id)
220
+
221
+ download_url = f"/bert/download-model/{training_id}"
222
+
223
+ return TrainingResponse(
224
+ message="Training started successfully",
225
+ training_id=training_id,
226
+ status="started",
227
+ download_url=download_url
228
+ )
229
+
230
+ @app.post("/bert/validate")
231
+ async def validate_model(
232
+ file: UploadFile = File(...),
233
+ model_name: str = "bert_model_latest"
234
+ ):
235
+ """Validate a BERT model on uploaded data"""
236
+ if not file.filename.endswith('.csv'):
237
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
238
+
239
+ try:
240
+ file_path = UPLOAD_DIR / file.filename
241
+ with file_path.open("wb") as buffer:
242
+ shutil.copyfileobj(file.file, buffer)
243
+
244
+ data_df, label_encoders = load_and_preprocess_data(str(file_path))
245
+
246
+ model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
247
+ if not model_path.exists():
248
+ raise HTTPException(status_code=404, detail="BERT model file not found")
249
+
250
+ num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
251
+ metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None
252
+
253
+ if metadata_df is not None:
254
+ metadata_dim = metadata_df.shape[1]
255
+ model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
256
+ else:
257
+ model = BertMultiOutputModel(num_labels_list).to(DEVICE)
258
+
259
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
260
+ model.eval()
261
+
262
+ texts = data_df[TEXT_COLUMN]
263
+ labels_array = data_df[LABEL_COLUMNS].values
264
+ tokenizer = get_tokenizer("bert-base-uncased")
265
+
266
+ if metadata_df is not None:
267
+ dataset = ComplianceDatasetWithMetadata(
268
+ texts.tolist(),
269
+ metadata_df.values,
270
+ labels_array,
271
+ tokenizer,
272
+ MAX_LEN
273
+ )
274
+ else:
275
+ dataset = ComplianceDataset(
276
+ texts.tolist(),
277
+ labels_array,
278
+ tokenizer,
279
+ MAX_LEN
280
+ )
281
+
282
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
283
+ metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
284
+ summary_metrics = summarize_metrics(metrics).to_dict()
285
+
286
+ all_probs = predict_probabilities(model, dataloader)
287
+
288
+ predictions = []
289
+ for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
290
+ field = LABEL_COLUMNS[i]
291
+ label_encoder = label_encoders[field]
292
+ true_labels_orig = label_encoder.inverse_transform(true_labels)
293
+ pred_labels_orig = label_encoder.inverse_transform(pred_labels)
294
+
295
+ for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
296
+ predictions.append({
297
+ "field": field,
298
+ "true_label": true,
299
+ "predicted_label": pred,
300
+ "probabilities": probs.tolist()
301
+ })
302
+
303
+ return ValidationResponse(
304
+ message="Validation completed successfully",
305
+ metrics=summary_metrics,
306
+ predictions=predictions
307
+ )
308
+
309
+ except Exception as e:
310
+ logger.error(f"Validation failed: {str(e)}")
311
+ raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
312
+ finally:
313
+ if os.path.exists(file_path):
314
+ os.remove(file_path)
315
+
316
+ @app.post("/bert/predict")
317
+ async def predict(request: PredictionRequest):
318
+ """Make predictions on a single transaction"""
319
+ try:
320
+ # Pydantic v2 (pinned in requirements.txt) prefers model_dump() over the deprecated dict()
+ input_data = pd.DataFrame([request.transaction_data.model_dump()])
321
+
322
+ text_input = f"""
323
+ Transaction ID: {input_data['Transaction_Id'].iloc[0]}
324
+ Origin: {input_data['Origin'].iloc[0]}
325
+ Designation: {input_data['Designation'].iloc[0]}
326
+ Keywords: {input_data['Keywords'].iloc[0]}
327
+ Name: {input_data['Name'].iloc[0]}
328
+ SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]}
329
+ Currency: {input_data['Currency'].iloc[0]}
330
+ Entity: {input_data['Entity'].iloc[0]}
331
+ Message: {input_data['Message'].iloc[0]}
332
+ City: {input_data['City'].iloc[0]}
333
+ Country: {input_data['Country'].iloc[0]}
334
+ State: {input_data['State'].iloc[0]}
335
+ Hit Type: {input_data['Hit_Type'].iloc[0]}
336
+ Record Matching String: {input_data['Record_Matching_String'].iloc[0]}
337
+ WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]}
338
+ Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]}
339
+ Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]}
340
+ Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]}
341
+ Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]}
342
+ Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]}
343
+ Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]}
344
+ Risk Level: {input_data['Risk_Level'].iloc[0]}
345
+ Risk Score: {input_data['Risk_Score'].iloc[0]}
346
+ CDD Level: {input_data['CDD_Level'].iloc[0]}
347
+ PEP Status: {input_data['PEP_Status'].iloc[0]}
348
+ Sanction Description: {input_data['Sanction_Description'].iloc[0]}
349
+ Checker Notes: {input_data['Checker_Notes'].iloc[0]}
350
+ Sanction Context: {input_data['Sanction_Context'].iloc[0]}
351
+ Maker Action: {input_data['Maker_Action'].iloc[0]}
352
+ Customer Type: {input_data['Customer_Type'].iloc[0]}
353
+ Industry: {input_data['Industry'].iloc[0]}
354
+ Transaction Type: {input_data['Transaction_Type'].iloc[0]}
355
+ Transaction Channel: {input_data['Transaction_Channel'].iloc[0]}
356
+ Geographic Origin: {input_data['Geographic_Origin'].iloc[0]}
357
+ Geographic Destination: {input_data['Geographic_Destination'].iloc[0]}
358
+ Risk Category: {input_data['Risk_Category'].iloc[0]}
359
+ Risk Drivers: {input_data['Risk_Drivers'].iloc[0]}
360
+ Alert Status: {input_data['Alert_Status'].iloc[0]}
361
+ Investigation Outcome: {input_data['Investigation_Outcome'].iloc[0]}
362
+ Source of Funds: {input_data['Source_Of_Funds'].iloc[0]}
363
+ Purpose of Transaction: {input_data['Purpose_Of_Transaction'].iloc[0]}
364
+ Beneficial Owner: {input_data['Beneficial_Owner'].iloc[0]}
365
+ """
366
+
367
+ dataset = ComplianceDataset(
368
+ texts=[text_input],
369
+ labels=[[0] * len(LABEL_COLUMNS)],
370
+ tokenizer=tokenizer,
371
+ max_len=MAX_LEN
372
+ )
373
+
374
+ loader = DataLoader(dataset, batch_size=1, shuffle=False)
375
+ all_probabilities = predict_probabilities(model, loader)
376
+
377
+ label_encoders = load_label_encoders()
378
+
379
+ response = {}
380
+ for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
381
+ pred = np.argmax(probs[0])
382
+ decoded_pred = label_encoders[col].inverse_transform([pred])[0]
383
+
384
+ class_probs = {
385
+ label: float(probs[0][j])
386
+ for j, label in enumerate(label_encoders[col].classes_)
387
+ }
388
+
389
+ response[col] = {
390
+ "prediction": decoded_pred,
391
+ "probabilities": class_probs
392
+ }
393
+
394
+ return response
395
+
396
+ except Exception as e:
397
+ raise HTTPException(status_code=500, detail=str(e))
398
+
399
+ @app.get("/bert/download-model/{model_id}")
400
+ async def download_model(model_id: str):
401
+ """Download a trained model"""
402
+ model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
403
+ if not model_path.exists():
404
+ raise HTTPException(status_code=404, detail="Model not found")
405
+
406
+ return FileResponse(
407
+ path=model_path,
408
+ filename=f"bert_model_{model_id}.pth",
409
+ media_type="application/octet-stream"
410
+ )
411
+
412
+ async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
413
+ try:
414
+ data_df_original, label_encoders = load_and_preprocess_data(file_path)
415
+ save_label_encoders(label_encoders)
416
+
417
+ train_df, val_df = train_test_split(
418
+ data_df_original,
419
+ test_size=config.test_size,
420
+ random_state=config.random_state,
421
+ stratify=data_df_original[LABEL_COLUMNS[0]]
422
+ )
423
+
424
+ train_texts = train_df[TEXT_COLUMN]
425
+ val_texts = val_df[TEXT_COLUMN]
426
+ train_labels_array = train_df[LABEL_COLUMNS].values
427
+ val_labels_array = val_df[LABEL_COLUMNS].values
428
+
429
+ train_metadata_df = train_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in train_df.columns for col in METADATA_COLUMNS) else None
430
+ val_metadata_df = val_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in val_df.columns for col in METADATA_COLUMNS) else None
431
+
432
+ num_labels_list = get_num_labels(label_encoders)
433
+ tokenizer = get_tokenizer(config.model_name)
434
+
435
+ if train_metadata_df is not None and val_metadata_df is not None:
436
+ metadata_dim = train_metadata_df.shape[1]
437
+ train_dataset = ComplianceDatasetWithMetadata(
438
+ train_texts.tolist(),
439
+ train_metadata_df.values,
440
+ train_labels_array,
441
+ tokenizer,
442
+ config.max_length
443
+ )
444
+ val_dataset = ComplianceDatasetWithMetadata(
445
+ val_texts.tolist(),
446
+ val_metadata_df.values,
447
+ val_labels_array,
448
+ tokenizer,
449
+ config.max_length
450
+ )
451
+ model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
452
+ else:
453
+ train_dataset = ComplianceDataset(
454
+ train_texts.tolist(),
455
+ train_labels_array,
456
+ tokenizer,
457
+ config.max_length
458
+ )
459
+ val_dataset = ComplianceDataset(
460
+ val_texts.tolist(),
461
+ val_labels_array,
462
+ tokenizer,
463
+ config.max_length
464
+ )
465
+ model = BertMultiOutputModel(num_labels_list).to(DEVICE)
466
+
467
+ train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
468
+ val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
469
+
470
+ # initialize_criterions() expects unencoded labels plus the fitted encoders
471
+ # (see train_utils.py), so re-read the raw CSV for class-weight computation.
472
+ raw_df = pd.read_csv(file_path).fillna("Unknown")
473
+ criterions = initialize_criterions(raw_df, label_encoders)
474
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
475
+
476
+ # evaluate_model() returns per-field classification reports (no loss), so the
477
+ # best checkpoint is kept by mean weighted F1 across the label columns.
478
+ best_val_f1 = float('-inf')
479
+ for epoch in range(config.num_epochs):
480
+ training_status["current_epoch"] = epoch + 1
481
+
482
+ # train_model() signature: (model, loader, optimizer, field_criterions, epoch)
483
+ train_loss = train_model(model, train_loader, optimizer, criterions, epoch)
484
+ val_metrics, _, _ = evaluate_model(model, val_loader)
485
+ training_status["current_loss"] = train_loss
486
+
487
+ mean_f1 = summarize_metrics(val_metrics)["F1-Score"].mean()
488
+ if mean_f1 > best_val_f1:
489
+ best_val_f1 = mean_f1
490
+ save_model(model, training_id)
485
+
486
+ training_status.update({
487
+ "is_training": False,
488
+ "end_time": datetime.now().isoformat(),
489
+ "status": "completed",
490
+ "metrics": summarize_metrics(val_metrics).to_dict()
491
+ })
492
+
493
+ except Exception as e:
494
+ logger.error(f"Training failed: {str(e)}")
495
+ training_status.update({
496
+ "is_training": False,
497
+ "end_time": datetime.now().isoformat(),
498
+ "status": "failed",
499
+ "error": str(e)
500
+ })
501
+
502
+ if __name__ == "__main__":
503
+ port = int(os.environ.get("PORT", 7860))
504
+ uvicorn.run(app, host="0.0.0.0", port=port)
config.py ADDED
@@ -0,0 +1,69 @@
1
+ # config.py
2
+
3
+ import torch
4
+ import os
5
+
6
+ # --- Paths ---
7
+ # Adjust DATA_PATH to your actual data location
8
+ DATA_PATH = './data/synthetic_transactions_samples_5000.csv'
9
+ TOKENIZER_PATH = './tokenizer/'
10
+ LABEL_ENCODERS_PATH = './label_encoders.pkl'
11
+ MODEL_SAVE_DIR = './saved_models/'
12
+ PREDICTIONS_SAVE_DIR = './predictions/' # To save predictions for voting ensemble
13
+
14
+ # --- Data Columns ---
15
+ TEXT_COLUMN = "Sanction_Context"
16
+ # Define all your target label columns
17
+ LABEL_COLUMNS = [
18
+ "Red_Flag_Reason",
19
+ "Maker_Action",
20
+ "Escalation_Level",
21
+ "Risk_Category",
22
+ "Risk_Drivers",
23
+ "Investigation_Outcome"
24
+ ]
25
+ # Example metadata columns. Add actual numerical/categorical metadata if available in your CSV.
26
+ # For now, it's an empty list. If you add metadata, ensure these columns exist and are numeric or can be encoded.
27
+ METADATA_COLUMNS = [] # e.g., ["Risk_Score", "Transaction_Amount"]
28
+
29
+ # --- Model Hyperparameters ---
30
+ MAX_LEN = 128 # Maximum sequence length for transformer tokenizers
31
+ BATCH_SIZE = 16 # Batch size for training and evaluation
32
+ LEARNING_RATE = 2e-5 # Learning rate for AdamW optimizer
33
+ NUM_EPOCHS = 3 # Number of training epochs. Adjust based on convergence.
34
+ DROPOUT_RATE = 0.3 # Dropout rate for regularization
35
+
36
+ # --- Device Configuration ---
37
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+
39
+ # --- Specific Model Configurations ---
40
+ BERT_MODEL_NAME = 'bert-base-uncased'
41
+ ROBERTA_MODEL_NAME = 'roberta-base'
42
+ DEBERTA_MODEL_NAME = 'microsoft/deberta-base'
43
+
44
+ # TF-IDF
45
+ TFIDF_MAX_FEATURES = 5000 # Max features for TF-IDF vectorizer
46
+
47
+ # --- Field-Specific Strategy (Conceptual) ---
48
+ # This dictionary provides conceptual strategies for enhancing specific fields.
49
+ # Actual implementation requires adapting the models (e.g., custom loss functions, metadata integration).
50
+ FIELD_STRATEGIES = {
51
+ "Maker_Action": {
52
+ "loss": "focal_loss", # Requires custom Focal Loss implementation
53
+ "enhancements": ["action_templates", "context_prompt_tuning"] # Advanced NLP concepts
54
+ },
55
+ "Risk_Category": {
56
+ "enhancements": ["numerical_metadata", "transaction_patterns"] # Integrate METADATA_COLUMNS
57
+ },
58
+ "Escalation_Level": {
59
+ "enhancements": ["class_balancing", "policy_keyword_patterns"] # Handled by class weights/metadata
60
+ },
61
+ "Investigation_Outcome": {
62
+ "type": "classification_or_generation" # If generation, T5/BART would be needed.
63
+ }
64
+ }
65
+
66
+ # Ensure model save and predictions directories exist
67
+ os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
68
+ os.makedirs(PREDICTIONS_SAVE_DIR, exist_ok=True)
69
+ os.makedirs(TOKENIZER_PATH, exist_ok=True)
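
FIELD_STRATEGIES above mentions focal loss for Maker_Action, but no implementation ships in this commit. A minimal sketch of what such a criterion could look like is shown below; the gamma value is an illustrative default, and wiring it in would mean returning it from initialize_criterions() in train_utils.py for the relevant field instead of CrossEntropyLoss.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Cross-entropy scaled by (1 - p_t)^gamma so easy examples contribute less."""
    def __init__(self, gamma=2.0, weight=None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight  # optional per-class weights, as with CrossEntropyLoss

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, weight=self.weight, reduction="none")
        p_t = torch.exp(-ce)  # probability assigned to the true class
        return ((1.0 - p_t) ** self.gamma * ce).mean()
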
dataset_utils.py ADDED
@@ -0,0 +1,165 @@
1
+ # dataset_utils.py
2
+
3
+ import pandas as pd
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from sklearn.preprocessing import LabelEncoder
7
+ from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
8
+ import pickle
9
+ import os
10
+
11
+ from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS
12
+
13
+ class ComplianceDataset(Dataset):
14
+ """
15
+ Custom Dataset class for handling text and multi-output labels for PyTorch models.
16
+ """
17
+ def __init__(self, texts, labels, tokenizer, max_len):
18
+ self.texts = texts
19
+ self.labels = labels
20
+ self.tokenizer = tokenizer
21
+ self.max_len = max_len
22
+
23
+ def __len__(self):
24
+ """Returns the total number of samples in the dataset."""
25
+ return len(self.texts)
26
+
27
+ def __getitem__(self, idx):
28
+ """
29
+ Retrieves a sample from the dataset at the given index.
30
+ Tokenizes the text and converts labels to a PyTorch tensor.
31
+ """
32
+ text = str(self.texts[idx])
33
+ # Tokenize the text, padding to max_length and truncating if longer.
34
+ # return_tensors="pt" ensures PyTorch tensors are returned.
35
+ inputs = self.tokenizer(
36
+ text,
37
+ padding='max_length',
38
+ truncation=True,
39
+ max_length=self.max_len,
40
+ return_tensors="pt"
41
+ )
42
+ # Squeeze removes the batch dimension (which is 1 here because we process one sample at a time)
43
+ inputs = {key: val.squeeze(0) for key, val in inputs.items()}
44
+ # Convert labels to a PyTorch long tensor
45
+ labels = torch.tensor(self.labels[idx], dtype=torch.long)
46
+ return inputs, labels
47
+
48
+ class ComplianceDatasetWithMetadata(Dataset):
49
+ """
50
+ Custom Dataset class for handling text, additional numerical metadata, and multi-output labels.
51
+ Used for hybrid models combining text and tabular features.
52
+ """
53
+ def __init__(self, texts, metadata, labels, tokenizer, max_len):
54
+ self.texts = texts
55
+ self.metadata = metadata # Expects metadata as a NumPy array or list of lists
56
+ self.labels = labels
57
+ self.tokenizer = tokenizer
58
+ self.max_len = max_len
59
+
60
+ def __len__(self):
61
+ """Returns the total number of samples in the dataset."""
62
+ return len(self.texts)
63
+
64
+ def __getitem__(self, idx):
65
+ """
66
+ Retrieves a sample, its metadata, and labels from the dataset at the given index.
67
+ Tokenizes text, converts metadata and labels to PyTorch tensors.
68
+ """
69
+ text = str(self.texts[idx])
70
+ inputs = self.tokenizer(
71
+ text,
72
+ padding='max_length',
73
+ truncation=True,
74
+ max_length=self.max_len,
75
+ return_tensors="pt"
76
+ )
77
+ inputs = {key: val.squeeze(0) for key, val in inputs.items()}
78
+ # Convert metadata for the current sample to a float tensor
79
+ metadata = torch.tensor(self.metadata[idx], dtype=torch.float)
80
+ labels = torch.tensor(self.labels[idx], dtype=torch.long)
81
+ return inputs, metadata, labels
82
+
83
+ def load_and_preprocess_data(data_path):
84
+ """
85
+ Loads data from a CSV, fills missing values, and encodes categorical labels.
86
+ Also handles converting specified METADATA_COLUMNS to numeric.
87
+
88
+ Args:
89
+ data_path (str): Path to the CSV data file.
90
+
91
+ Returns:
92
+ tuple: A tuple containing:
93
+ - data (pd.DataFrame): The preprocessed DataFrame.
94
+ - label_encoders (dict): A dictionary of LabelEncoder objects for each label column.
95
+ """
96
+ data = pd.read_csv(data_path)
97
+ data.fillna("Unknown", inplace=True) # Fill any missing text values with "Unknown"
98
+
99
+ # Convert metadata columns to numeric, coercing errors and filling NaNs with 0
100
+ # This ensures metadata is suitable for neural networks.
101
+ for col in METADATA_COLUMNS:
102
+ if col in data.columns:
103
+ data[col] = pd.to_numeric(data[col], errors='coerce').fillna(0) # Fill NaN with 0 or a suitable value
104
+
105
+ label_encoders = {col: LabelEncoder() for col in LABEL_COLUMNS}
106
+ for col in LABEL_COLUMNS:
107
+ # Fit and transform each label column using its respective LabelEncoder
108
+ data[col] = label_encoders[col].fit_transform(data[col])
109
+ return data, label_encoders
110
+
111
+ def get_tokenizer(model_name):
112
+ """
113
+ Returns the appropriate Hugging Face tokenizer based on the model name.
114
+
115
+ Args:
116
+ model_name (str): The name of the pre-trained model (e.g., 'bert-base-uncased').
117
+
118
+ Returns:
119
+ transformers.PreTrainedTokenizer: The initialized tokenizer.
120
+ """
121
+ if "bert" in model_name.lower():
122
+ return BertTokenizer.from_pretrained(model_name)
123
+ elif "roberta" in model_name.lower():
124
+ return RobertaTokenizer.from_pretrained(model_name)
125
+ elif "deberta" in model_name.lower():
126
+ return DebertaTokenizer.from_pretrained(model_name)
127
+ else:
128
+ raise ValueError(f"Unsupported tokenizer for model: {model_name}")
129
+
130
+ def save_label_encoders(label_encoders):
131
+ """
132
+ Saves a dictionary of label encoders to a pickle file.
133
+ This is crucial for decoding predictions back to original labels.
134
+
135
+ Args:
136
+ label_encoders (dict): Dictionary of LabelEncoder objects.
137
+ """
138
+ with open(LABEL_ENCODERS_PATH, "wb") as f:
139
+ pickle.dump(label_encoders, f)
140
+ print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")
141
+
142
+ def load_label_encoders():
143
+ """
144
+ Loads a dictionary of label encoders from a pickle file.
145
+
146
+ Returns:
147
+ dict: Loaded dictionary of LabelEncoder objects.
148
+ """
149
+ with open(LABEL_ENCODERS_PATH, "rb") as f:
150
+ return pickle.load(f)
151
+ print(f"Label encoders loaded from {LABEL_ENCODERS_PATH}")
152
+
153
+
154
+ def get_num_labels(label_encoders):
155
+ """
156
+ Returns a list containing the number of unique classes for each label column.
157
+ This list is used to define the output dimensions of the model's classification heads.
158
+
159
+ Args:
160
+ label_encoders (dict): Dictionary of LabelEncoder objects.
161
+
162
+ Returns:
163
+ list: A list of integers, where each integer is the number of classes for a label.
164
+ """
165
+ return [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
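
Taken together, these helpers are meant to be used roughly as follows. This is a sketch: the CSV path matches DATA_PATH from config.py and must exist locally, and the batch size and column names also come from config.py.

from torch.utils.data import DataLoader
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, BATCH_SIZE
from dataset_utils import (load_and_preprocess_data, get_tokenizer,
                           save_label_encoders, ComplianceDataset)

# Load the CSV, fill missing values and label-encode the target columns.
data, label_encoders = load_and_preprocess_data("data/synthetic_transactions_samples_5000.csv")
save_label_encoders(label_encoders)  # persisted so predictions can be decoded later

tokenizer = get_tokenizer("bert-base-uncased")
dataset = ComplianceDataset(
    texts=data[TEXT_COLUMN].tolist(),
    labels=data[LABEL_COLUMNS].values,
    tokenizer=tokenizer,
    max_len=MAX_LEN,
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
inputs, labels = next(iter(loader))  # inputs: dict of tensors, labels: (batch, n_fields)
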
docker-compose.yml ADDED
@@ -0,0 +1,18 @@
1
+ version: '3.8'
2
+
3
+ services:
4
+ bert-api:
5
+ build: .
6
+ ports:
7
+ - "7860:7860"
8
+ volumes:
9
+ - ../saved_models:/app/saved_models
10
+ - ../tokenizer:/app/tokenizer
11
+ - ../predictions:/app/predictions
12
+ - ../label_encoders.pkl:/app/label_encoders.pkl
13
+ - ../.cache:/app/.cache
14
+ environment:
15
+ - PYTHONUNBUFFERED=1
16
+ - TRANSFORMERS_CACHE=/app/.cache
17
+ - PORT=7860
18
+ restart: unless-stopped
label_encoders.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c336fd07858af76d40c7200de1a769099abeec25d4f48b999351318680d4e4d6
3
+ size 2047
models/__pycache__/bert_model.cpython-311.pyc ADDED
Binary file (3.29 kB).
 
models/__pycache__/deberta_model.cpython-311.pyc ADDED
Binary file (3.15 kB).
 
models/__pycache__/parallel_bert_deberta.cpython-311.pyc ADDED
Binary file (6.45 kB).
 
models/__pycache__/roberta_model.cpython-311.pyc ADDED
Binary file (3.18 kB).
 
models/__pycache__/text_and_metadata_model.cpython-311.pyc ADDED
Binary file (4.09 kB).
 
models/bert_model.py ADDED
@@ -0,0 +1,59 @@
1
+ # models/bert_model.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import BertModel
6
+ from config import DROPOUT_RATE, BERT_MODEL_NAME # Import BERT_MODEL_NAME from config
7
+
8
+ class BertMultiOutputModel(nn.Module):
9
+ """
10
+ BERT-based model for multi-output classification.
11
+ It uses a pre-trained BERT model as its backbone and adds a dropout layer
12
+ followed by separate linear classification heads for each target label.
13
+ """
14
+ # Statically set tokenizer name for easy access in main.py
15
+ tokenizer_name = BERT_MODEL_NAME
16
+
17
+ def __init__(self, num_labels, metadata_dim=0):
18
+ """
19
+ Initializes the BertMultiOutputModel.
20
+
21
+ Args:
22
+ num_labels (list): A list where each element is the number of classes
23
+ for a corresponding label column.
+ metadata_dim (int, optional): Width of extra numeric metadata features
+ concatenated to BERT's pooled output; 0 (default) disables the metadata branch.
24
+ """
25
+ super(BertMultiOutputModel, self).__init__()
26
+ # Load the pre-trained BERT model.
27
+ # BertModel provides contextual embeddings and a pooled output for classification.
28
+ self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
29
+ self.dropout = nn.Dropout(DROPOUT_RATE) # Dropout layer for regularization
+ self.metadata_dim = metadata_dim # >0 only for the hybrid text + metadata variant
30
+
31
+ # Create a list of classification heads, one for each label column.
32
+ # Each head is a linear layer mapping BERT's pooled output size to the number of classes for that label.
33
+ self.classifiers = nn.ModuleList([
34
+ nn.Linear(self.bert.config.hidden_size + metadata_dim, n_classes) for n_classes in num_labels
35
+ ])
36
+
37
+ def forward(self, input_ids, attention_mask, metadata=None):
38
+ """
39
+ Performs the forward pass of the model.
40
+
41
+ Args:
42
+ input_ids (torch.Tensor): Tensor of token IDs (from tokenizer).
43
+ attention_mask (torch.Tensor): Tensor indicating attention (from tokenizer).
+ metadata (torch.Tensor, optional): Numeric metadata features; only used
+ when the model was constructed with metadata_dim > 0.
44
+
45
+ Returns:
46
+ list: A list of logit tensors, one for each classification head.
47
+ Each tensor has shape (batch_size, num_classes_for_that_label).
48
+ """
49
+ # Pass input_ids and attention_mask through BERT.
50
+ # .pooler_output typically represents the hidden state of the [CLS] token,
51
+ # processed through a linear layer and tanh activation, often used for classification.
52
+ pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
53
+
54
+ # Apply dropout for regularization
55
+ pooled_output = self.dropout(pooled_output)
+
+ # For the hybrid variant (metadata_dim > 0, as requested by app.py and
+ # train_utils.load_model_state), append the metadata features.
+ if metadata is not None and self.metadata_dim > 0:
+ pooled_output = torch.cat([pooled_output, metadata.float()], dim=1)
56
+
57
+ # Pass the pooled output through each classification head.
58
+ # The result is a list of logits (raw scores before softmax/sigmoid) for each label.
59
+ return [classifier(pooled_output) for classifier in self.classifiers]
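
As a quick sanity check, the model can be instantiated and run on a dummy batch. The class counts below are made-up values, and downloading the pre-trained BERT weights requires network access (or a populated TRANSFORMERS_CACHE, as set in docker-compose.yml).

import torch
from dataset_utils import get_tokenizer
from models.bert_model import BertMultiOutputModel

num_labels_list = [4, 6, 3]  # illustrative class counts per label column
model = BertMultiOutputModel(num_labels_list)
tokenizer = get_tokenizer("bert-base-uncased")

enc = tokenizer(["Example sanction context text"], padding="max_length",
                truncation=True, max_length=128, return_tensors="pt")
with torch.no_grad():
    logits_per_field = model(enc["input_ids"], enc["attention_mask"])

# One logits tensor per label column, each of shape (batch, n_classes)
print([t.shape for t in logits_per_field])
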
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ fastapi==0.104.1
2
+ uvicorn==0.24.0
3
+ pydantic==2.4.2
4
+ torch==2.1.0
5
+ transformers==4.35.0
6
+ pandas==2.1.2
7
+ numpy==1.24.3
8
+ scikit-learn==1.3.2
9
+ python-multipart==0.0.6
10
+ python-jose==3.3.0
11
+ passlib==1.7.4
12
+ bcrypt==4.0.1
13
+ python-dotenv==1.0.0
train_utils.py ADDED
@@ -0,0 +1,310 @@
1
+ # train_utils.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.optim import AdamW
6
+ from sklearn.metrics import classification_report
7
+ from sklearn.utils.class_weight import compute_class_weight
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ import pandas as pd
11
+ import os
12
+ import joblib
13
+
14
+ from config import DEVICE, LABEL_COLUMNS, NUM_EPOCHS, LEARNING_RATE, MODEL_SAVE_DIR
15
+
16
+ def get_class_weights(data_df, field, label_encoder):
17
+ """
18
+ Computes balanced class weights for a given target field.
19
+ These weights can be used in the loss function to mitigate class imbalance.
20
+
21
+ Args:
22
+ data_df (pd.DataFrame): The DataFrame containing the original (unencoded) label data.
23
+ field (str): The name of the label column for which to compute weights.
24
+ label_encoder (sklearn.preprocessing.LabelEncoder): The label encoder fitted for this field.
25
+
26
+ Returns:
27
+ torch.Tensor: A tensor of class weights for the specified field.
28
+ """
29
+ # Get the original labels for the specified field
30
+ y = data_df[field].values
31
+ # Use label_encoder.transform directly - it will handle unseen labels
32
+ try:
33
+ y_encoded = label_encoder.transform(y)
34
+ except ValueError as e:
35
+ print(f"Warning: {e}")
36
+ print(f"Using only seen labels for class weights calculation")
37
+ # Filter out unseen labels
38
+ seen_labels = set(label_encoder.classes_)
39
+ y_filtered = [label for label in y if label in seen_labels]
40
+ y_encoded = label_encoder.transform(y_filtered)
41
+
42
+ # Ensure y_encoded is integer type
43
+ y_encoded = y_encoded.astype(int)
44
+
45
+ # Initialize counts for all possible classes
46
+ n_classes = len(label_encoder.classes_)
47
+ class_counts = np.zeros(n_classes, dtype=int)
48
+
49
+ # Count occurrences of each class
50
+ for i in range(n_classes):
51
+ class_counts[i] = np.sum(y_encoded == i)
52
+
53
+ # Calculate weights for all classes
54
+ total_samples = len(y_encoded)
55
+ class_weights = np.ones(n_classes) # Default weight of 1 for unseen classes
56
+ seen_classes = class_counts > 0
57
+ if np.any(seen_classes):
58
+ class_weights[seen_classes] = total_samples / (np.sum(seen_classes) * class_counts[seen_classes])
59
+
60
+ return torch.tensor(class_weights, dtype=torch.float)
61
+
62
+ def initialize_criterions(data_df, label_encoders):
63
+ """
64
+ Initializes CrossEntropyLoss criteria for each label column, applying class weights.
65
+
66
+ Args:
67
+ data_df (pd.DataFrame): The original (unencoded) DataFrame. Used to compute class weights.
68
+ label_encoders (dict): Dictionary of LabelEncoder objects.
69
+
70
+ Returns:
71
+ dict: A dictionary where keys are label column names and values are
72
+ initialized `torch.nn.CrossEntropyLoss` objects.
73
+ """
74
+ field_criterions = {}
75
+ for field in LABEL_COLUMNS:
76
+ # Get class weights for the current field
77
+ weights = get_class_weights(data_df, field, label_encoders[field])
78
+ # Initialize CrossEntropyLoss with the computed weights and move to the device
79
+ field_criterions[field] = torch.nn.CrossEntropyLoss(weight=weights.to(DEVICE))
80
+ return field_criterions
81
+
82
+ def train_model(model, loader, optimizer, field_criterions, epoch):
83
+ """
84
+ Trains the given PyTorch model for one epoch.
85
+
86
+ Args:
87
+ model (torch.nn.Module): The model to train.
88
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
89
+ optimizer (torch.optim.Optimizer): Optimizer for model parameters.
90
+ field_criterions (dict): Dictionary of loss functions for each label.
91
+ epoch (int): Current epoch number (for progress bar description).
92
+
93
+ Returns:
94
+ float: Average training loss for the epoch.
95
+ """
96
+ model.train() # Set the model to training mode
97
+ total_loss = 0
98
+ # Use tqdm for a progress bar during training
99
+ tqdm_loader = tqdm(loader, desc=f"Epoch {epoch + 1} Training")
100
+
101
+ for batch in tqdm_loader:
102
+ # Unpack batch based on whether it contains metadata
103
+ if len(batch) == 2: # Text-only models (inputs, labels)
104
+ inputs, labels = batch
105
+ input_ids = inputs['input_ids'].to(DEVICE)
106
+ attention_mask = inputs['attention_mask'].to(DEVICE)
107
+ labels = labels.to(DEVICE)
108
+ # Forward pass through the model
109
+ outputs = model(input_ids, attention_mask)
110
+ elif len(batch) == 3: # Text + Metadata models (inputs, metadata, labels)
111
+ inputs, metadata, labels = batch
112
+ input_ids = inputs['input_ids'].to(DEVICE)
113
+ attention_mask = inputs['attention_mask'].to(DEVICE)
114
+ metadata = metadata.to(DEVICE)
115
+ labels = labels.to(DEVICE)
116
+ # Forward pass through the hybrid model
117
+ outputs = model(input_ids, attention_mask, metadata)
118
+ else:
119
+ raise ValueError("Unsupported batch format. Expected 2 or 3 items in batch.")
120
+
121
+ loss = 0
122
+ # Calculate total loss by summing loss for each label column
123
+ # `outputs` is a list of logits, one for each label column
124
+ for i, output_logits in enumerate(outputs):
125
+ # `labels[:, i]` gets the true labels for the i-th label column
126
+ # `field_criterions[LABEL_COLUMNS[i]]` selects the appropriate loss function
127
+ loss += field_criterions[LABEL_COLUMNS[i]](output_logits, labels[:, i])
128
+
129
+ optimizer.zero_grad() # Clear previous gradients
130
+ loss.backward() # Backpropagation
131
+ optimizer.step() # Update model parameters
132
+ total_loss += loss.item() # Accumulate loss
133
+ tqdm_loader.set_postfix(loss=loss.item()) # Update progress bar with current batch loss
134
+
135
+ return total_loss / len(loader) # Return average loss for the epoch
136
+
137
+ def evaluate_model(model, loader):
138
+ """
139
+ Evaluates the given PyTorch model on a validation/test set.
140
+
141
+ Args:
142
+ model (torch.nn.Module): The model to evaluate.
143
+ loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
144
+
145
+ Returns:
146
+ tuple: A tuple containing:
147
+ - reports (dict): Classification reports (dict format) for each label column.
148
+ - truths (list): List of true label arrays for each label column.
149
+ - predictions (list): List of predicted label arrays for each label column.
150
+ """
151
+ model.eval() # Set the model to evaluation mode (disables dropout, batch norm updates, etc.)
152
+ # Initialize lists to store predictions and true labels for each output head
153
+ predictions = [[] for _ in range(len(LABEL_COLUMNS))]
154
+ truths = [[] for _ in range(len(LABEL_COLUMNS))]
155
+
156
+ with torch.no_grad(): # Disable gradient calculations during evaluation for efficiency
157
+ for batch in tqdm(loader, desc="Evaluation"):
158
+ if len(batch) == 2:
159
+ inputs, labels = batch
160
+ input_ids = inputs['input_ids'].to(DEVICE)
161
+ attention_mask = inputs['attention_mask'].to(DEVICE)
162
+ labels = labels.to(DEVICE)
163
+ outputs = model(input_ids, attention_mask)
164
+ elif len(batch) == 3:
165
+ inputs, metadata, labels = batch
166
+ input_ids = inputs['input_ids'].to(DEVICE)
167
+ attention_mask = inputs['attention_mask'].to(DEVICE)
168
+ metadata = metadata.to(DEVICE)
169
+ labels = labels.to(DEVICE)
170
+ outputs = model(input_ids, attention_mask, metadata)
171
+ else:
172
+ raise ValueError("Unsupported batch format.")
173
+
174
+ for i, output_logits in enumerate(outputs):
175
+ # Get the predicted class by taking the argmax of the logits
176
+ preds = torch.argmax(output_logits, dim=1).cpu().numpy()
177
+ predictions[i].extend(preds)
178
+ # Get the true labels for the current output head
179
+ truths[i].extend(labels[:, i].cpu().numpy())
180
+
181
+ reports = {}
182
+ # Generate classification report for each label column
183
+ for i, col in enumerate(LABEL_COLUMNS):
184
+ try:
185
+ # `zero_division=0` handles cases where a class might have no true or predicted samples
186
+ reports[col] = classification_report(truths[i], predictions[i], output_dict=True, zero_division=0)
187
+ except ValueError:
188
+ # Handle cases where a label might not appear in the validation set,
189
+ # which could cause classification_report to fail.
190
+ print(f"Warning: Could not generate classification report for {col}. Skipping.")
191
+ reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
192
+ return reports, truths, predictions
193
+
194
+ def summarize_metrics(metrics):
195
+ """
196
+ Summarizes classification reports into a readable Pandas DataFrame.
197
+
198
+ Args:
199
+ metrics (dict): Dictionary of classification reports, as returned by `evaluate_model`.
200
+
201
+ Returns:
202
+ pd.DataFrame: A DataFrame summarizing precision, recall, f1-score, accuracy, and support for each field.
203
+ """
204
+ summary = []
205
+ for field, report in metrics.items():
206
+ # Safely get metrics, defaulting to 0 if not present (e.g., for empty reports)
207
+ precision = report['weighted avg']['precision'] if 'weighted avg' in report else 0
208
+ recall = report['weighted avg']['recall'] if 'weighted avg' in report else 0
209
+ f1 = report['weighted avg']['f1-score'] if 'weighted avg' in report else 0
210
+ support = report['weighted avg']['support'] if 'weighted avg' in report else 0
211
+ accuracy = report['accuracy'] if 'accuracy' in report else 0 # Accuracy is usually top-level
212
+ summary.append({
213
+ "Field": field,
214
+ "Precision": precision,
215
+ "Recall": recall,
216
+ "F1-Score": f1,
217
+ "Accuracy": accuracy,
218
+ "Support": support
219
+ })
220
+ return pd.DataFrame(summary)
221
+
222
+ def save_model(model, model_name, save_format='pth'):
223
+ """
224
+ Saves the state dictionary of a PyTorch model.
225
+
226
+ Args:
227
+ model (torch.nn.Module): The trained PyTorch model.
228
+ model_name (str): A descriptive name for the model (used for filename).
229
+ save_format (str): Format to save the model in ('pth' for PyTorch models, 'pickle' for traditional ML models).
230
+ """
231
+ # Construct the save path dynamically relative to the project root
232
+ if save_format == 'pth':
233
+ model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
234
+ torch.save(model.state_dict(), model_path)
235
+ elif save_format == 'pickle':
236
+ model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
237
+ joblib.dump(model, model_path)
238
+ else:
239
+ raise ValueError(f"Unsupported save format: {save_format}")
240
+
241
+ print(f"Model saved to {model_path}")
242
+
243
+ def load_model_state(model, model_name, model_class, num_labels, metadata_dim=0):
244
+ """
245
+ Loads the state dictionary into a PyTorch model.
246
+
247
+ Args:
248
+ model (torch.nn.Module): An initialized model instance (architecture).
249
+ model_name (str): The name of the model to load.
250
+ model_class (class): The class of the model (e.g., BertMultiOutputModel).
251
+ num_labels (list): List of number of classes for each label.
252
+ metadata_dim (int): Dimensionality of metadata features, if applicable (default 0 for text-only).
253
+
254
+ Returns:
255
+ torch.nn.Module: The model with loaded state_dict, moved to the correct device, and set to eval mode.
256
+ """
257
+ model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
258
+ if not os.path.exists(model_path):
259
+ print(f"Warning: Model file not found at {model_path}. Returning a newly initialized model instance.")
260
+ # Re-initialize the model if not found, to ensure it has the correct architecture
261
+ if metadata_dim > 0:
262
+ return model_class(num_labels, metadata_dim=metadata_dim).to(DEVICE)
263
+ else:
264
+ return model_class(num_labels).to(DEVICE)
265
+
266
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
267
+ model.to(DEVICE)
268
+ model.eval() # Set to evaluation mode after loading
269
+ print(f"Model loaded from {model_path}")
270
+ return model
271
+
272
+ def predict_probabilities(model, loader):
273
+ """
274
+ Generates prediction probabilities for each label for a given model.
275
+ This is used for confidence scoring and feeding into a voting ensemble.
276
+
277
+ Args:
278
+ model (torch.nn.Module): The trained PyTorch model.
279
+ loader (torch.utils.data.DataLoader): DataLoader for the data to predict on.
280
+
281
+ Returns:
282
+ list: A list of lists of numpy arrays. Each inner list corresponds to a label column,
283
+ containing the softmax probabilities for each sample for that label.
284
+ """
285
+ model.eval() # Set to evaluation mode
286
+ # List to store probabilities for each output head
287
+ all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]
288
+
289
+ with torch.no_grad():
290
+ for batch in tqdm(loader, desc="Predicting Probabilities"):
291
+ # Unpack batch, ignoring labels as we only need inputs
292
+ if len(batch) == 2:
293
+ inputs, _ = batch
294
+ input_ids = inputs['input_ids'].to(DEVICE)
295
+ attention_mask = inputs['attention_mask'].to(DEVICE)
296
+ outputs = model(input_ids, attention_mask)
297
+ elif len(batch) == 3:
298
+ inputs, metadata, _ = batch
299
+ input_ids = inputs['input_ids'].to(DEVICE)
300
+ attention_mask = inputs['attention_mask'].to(DEVICE)
301
+ metadata = metadata.to(DEVICE)
302
+ outputs = model(input_ids, attention_mask, metadata)
303
+ else:
304
+ raise ValueError("Unsupported batch format.")
305
+
306
+ for i, out_logits in enumerate(outputs):
307
+ # Apply softmax to logits to get probabilities
308
+ probs = torch.softmax(out_logits, dim=1).cpu().numpy()
309
+ all_probabilities[i].extend(probs)
310
+ return all_probabilities
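
Outside the API, the same utilities can be driven from a short script. The sketch below mirrors what app.py's train_model_task does; it assumes DATA_PATH from config.py points at a real CSV and, for brevity, evaluates on the training loader rather than a held-out split.

import pandas as pd
import torch
from torch.utils.data import DataLoader
from config import (DATA_PATH, TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, BATCH_SIZE,
                    NUM_EPOCHS, LEARNING_RATE, DEVICE)
from dataset_utils import load_and_preprocess_data, get_tokenizer, get_num_labels, ComplianceDataset
from models.bert_model import BertMultiOutputModel
from train_utils import (initialize_criterions, train_model, evaluate_model,
                         summarize_metrics, save_model)

# Encoded frame for the dataset, raw frame for class-weight computation.
data, label_encoders = load_and_preprocess_data(DATA_PATH)
raw = pd.read_csv(DATA_PATH).fillna("Unknown")

tokenizer = get_tokenizer("bert-base-uncased")
dataset = ComplianceDataset(data[TEXT_COLUMN].tolist(), data[LABEL_COLUMNS].values, tokenizer, MAX_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = BertMultiOutputModel(get_num_labels(label_encoders)).to(DEVICE)
criterions = initialize_criterions(raw, label_encoders)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    avg_loss = train_model(model, loader, optimizer, criterions, epoch)
    print(f"epoch {epoch + 1}: train loss {avg_loss:.4f}")

reports, truths, preds = evaluate_model(model, loader)  # ideally a held-out split
print(summarize_metrics(reports))
save_model(model, "bert")  # writes saved_models/bert_model.pth
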