Upload 15 files
- BERT_model.pth +3 -0
- Dockerfile +54 -0
- app.py +504 -0
- config.py +69 -0
- dataset_utils.py +165 -0
- docker-compose.yml +18 -0
- label_encoders.pkl +3 -0
- models/__pycache__/bert_model.cpython-311.pyc +0 -0
- models/__pycache__/deberta_model.cpython-311.pyc +0 -0
- models/__pycache__/parallel_bert_deberta.cpython-311.pyc +0 -0
- models/__pycache__/roberta_model.cpython-311.pyc +0 -0
- models/__pycache__/text_and_metadata_model.cpython-311.pyc +0 -0
- models/bert_model.py +59 -0
- requirements.txt +13 -0
- train_utils.py +310 -0
BERT_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7050d02ac599ef72d7b0410a79a72537fb44d4ac66eb8a1dc719329c8c4b07b
size 438239057
Dockerfile
ADDED
@@ -0,0 +1,54 @@
# Use Python 3.9 as base image
FROM python:3.9-slim

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    software-properties-common \
    git \
    && rm -rf /var/lib/apt/lists/*

# Create a non-root user
RUN useradd -m -u 1000 appuser

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Create necessary directories with proper permissions
RUN mkdir -p /app/uploads \
    /app/saved_models/bert \
    /app/predictions \
    /app/tokenizer \
    && chmod -R 777 /app/uploads \
    /app/saved_models \
    /app/predictions \
    /app/tokenizer

# Switch to non-root user
USER appuser

# Copy the application code and utilities.
# A single copy of the build context is sufficient; COPY cannot reference
# paths outside the build context (e.g. ../dataset_utils.py), so everything
# the app needs must live inside it.
COPY . /app/

# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
ENV PORT=7860

# Expose the port the app runs on
EXPOSE 7860

# Command to run the application
CMD ["python", "app.py"]
app.py
ADDED
@@ -0,0 +1,504 @@
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import uvicorn
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader
import logging
import os
import asyncio
import pandas as pd
from datetime import datetime
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split
import zipfile
import io
import numpy as np
import sys


# Import existing utilities
from dataset_utils import (
    ComplianceDataset,
    ComplianceDatasetWithMetadata,
    load_and_preprocess_data,
    get_tokenizer,
    save_label_encoders,
    get_num_labels,
    load_label_encoders
)
from train_utils import (
    initialize_criterions,
    train_model,
    evaluate_model,
    save_model,
    summarize_metrics,
    predict_probabilities
)
from models.bert_model import BertMultiOutputModel
from config import (
    TEXT_COLUMN,
    LABEL_COLUMNS,
    DEVICE,
    NUM_EPOCHS,
    LEARNING_RATE,
    MAX_LEN,
    BATCH_SIZE,
    METADATA_COLUMNS
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="BERT Compliance Predictor API")

# Create necessary directories
UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Global variables to track training status
training_status = {
    "is_training": False,
    "current_epoch": 0,
    "total_epochs": 0,
    "current_loss": 0.0,
    "start_time": None,
    "end_time": None,
    "status": "idle",
    "metrics": None
}

# Load the model and tokenizer for prediction
model_path = "BERT_model.pth"
tokenizer = get_tokenizer('bert-base-uncased')
model = BertMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()

class TrainingConfig(BaseModel):
    model_name: str = "bert-base-uncased"
    batch_size: int = 8
    learning_rate: float = 2e-5
    num_epochs: int = 2
    max_length: int = 128
    test_size: float = 0.2
    random_state: int = 42

class TrainingResponse(BaseModel):
    message: str
    training_id: str
    status: str
    download_url: Optional[str] = None

class ValidationResponse(BaseModel):
    message: str
    metrics: Dict[str, Any]
    predictions: List[Dict[str, Any]]

class TransactionData(BaseModel):
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool

class PredictionRequest(BaseModel):
    transaction_data: TransactionData

@app.get("/")
async def root():
    return {"message": "BERT Compliance Predictor API"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

@app.get("/training-status")
async def get_training_status():
    return training_status

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Upload a CSV file for training or validation"""
    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    file_path = UPLOAD_DIR / file.filename
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}

@app.post("/bert/train", response_model=TrainingResponse)
async def start_training(
    config: TrainingConfig,
    background_tasks: BackgroundTasks,
    file_path: str
):
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")

    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Training file not found")

    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")

    training_status.update({
        "is_training": True,
        "current_epoch": 0,
        "total_epochs": config.num_epochs,
        "start_time": datetime.now().isoformat(),
        "status": "starting"
    })

    background_tasks.add_task(train_model_task, config, file_path, training_id)

    download_url = f"/bert/download-model/{training_id}"

    return TrainingResponse(
        message="Training started successfully",
        training_id=training_id,
        status="started",
        download_url=download_url
    )

@app.post("/bert/validate")
async def validate_model(
    file: UploadFile = File(...),
    model_name: str = "bert_model_latest"
):
    """Validate a BERT model on uploaded data"""
    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        file_path = UPLOAD_DIR / file.filename
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        data_df, label_encoders = load_and_preprocess_data(str(file_path))

        model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail="BERT model file not found")

        num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
        metadata_df = data_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df.columns for col in METADATA_COLUMNS) else None

        if metadata_df is not None:
            # NOTE: this branch assumes a metadata-aware model variant that accepts
            # metadata_dim; BertMultiOutputModel in models/bert_model.py is text-only
            # and this path is only reached when METADATA_COLUMNS is configured.
            metadata_dim = metadata_df.shape[1]
            model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
        else:
            model = BertMultiOutputModel(num_labels_list).to(DEVICE)

        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()

        texts = data_df[TEXT_COLUMN]
        labels_array = data_df[LABEL_COLUMNS].values
        tokenizer = get_tokenizer("bert-base-uncased")

        if metadata_df is not None:
            dataset = ComplianceDatasetWithMetadata(
                texts.tolist(),
                metadata_df.values,
                labels_array,
                tokenizer,
                MAX_LEN
            )
        else:
            dataset = ComplianceDataset(
                texts.tolist(),
                labels_array,
                tokenizer,
                MAX_LEN
            )

        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
        metrics, y_true_list, y_pred_list = evaluate_model(model, dataloader)
        summary_metrics = summarize_metrics(metrics).to_dict()

        all_probs = predict_probabilities(model, dataloader)

        predictions = []
        for i, (true_labels, pred_labels) in enumerate(zip(y_true_list, y_pred_list)):
            field = LABEL_COLUMNS[i]
            label_encoder = label_encoders[field]
            true_labels_orig = label_encoder.inverse_transform(true_labels)
            pred_labels_orig = label_encoder.inverse_transform(pred_labels)

            for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
                predictions.append({
                    "field": field,
                    "true_label": true,
                    "predicted_label": pred,
                    "probabilities": probs.tolist()
                })

        return ValidationResponse(
            message="Validation completed successfully",
            metrics=summary_metrics,
            predictions=predictions
        )

    except Exception as e:
        logger.error(f"Validation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
    finally:
        if os.path.exists(file_path):
            os.remove(file_path)

@app.post("/bert/predict")
async def predict(request: PredictionRequest):
    """Make predictions on a single transaction"""
    try:
        input_data = pd.DataFrame([request.transaction_data.dict()])

        text_input = f"""
        Transaction ID: {input_data['Transaction_Id'].iloc[0]}
        Origin: {input_data['Origin'].iloc[0]}
        Designation: {input_data['Designation'].iloc[0]}
        Keywords: {input_data['Keywords'].iloc[0]}
        Name: {input_data['Name'].iloc[0]}
        SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]}
        Currency: {input_data['Currency'].iloc[0]}
        Entity: {input_data['Entity'].iloc[0]}
        Message: {input_data['Message'].iloc[0]}
        City: {input_data['City'].iloc[0]}
        Country: {input_data['Country'].iloc[0]}
        State: {input_data['State'].iloc[0]}
        Hit Type: {input_data['Hit_Type'].iloc[0]}
        Record Matching String: {input_data['Record_Matching_String'].iloc[0]}
        WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]}
        Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]}
        Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]}
        Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]}
        Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]}
        Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]}
        Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]}
        Risk Level: {input_data['Risk_Level'].iloc[0]}
        Risk Score: {input_data['Risk_Score'].iloc[0]}
        CDD Level: {input_data['CDD_Level'].iloc[0]}
        PEP Status: {input_data['PEP_Status'].iloc[0]}
        Sanction Description: {input_data['Sanction_Description'].iloc[0]}
        Checker Notes: {input_data['Checker_Notes'].iloc[0]}
        Sanction Context: {input_data['Sanction_Context'].iloc[0]}
        Maker Action: {input_data['Maker_Action'].iloc[0]}
        Customer Type: {input_data['Customer_Type'].iloc[0]}
        Industry: {input_data['Industry'].iloc[0]}
        Transaction Type: {input_data['Transaction_Type'].iloc[0]}
        Transaction Channel: {input_data['Transaction_Channel'].iloc[0]}
        Geographic Origin: {input_data['Geographic_Origin'].iloc[0]}
        Geographic Destination: {input_data['Geographic_Destination'].iloc[0]}
        Risk Category: {input_data['Risk_Category'].iloc[0]}
        Risk Drivers: {input_data['Risk_Drivers'].iloc[0]}
        Alert Status: {input_data['Alert_Status'].iloc[0]}
        Investigation Outcome: {input_data['Investigation_Outcome'].iloc[0]}
        Source of Funds: {input_data['Source_Of_Funds'].iloc[0]}
        Purpose of Transaction: {input_data['Purpose_Of_Transaction'].iloc[0]}
        Beneficial Owner: {input_data['Beneficial_Owner'].iloc[0]}
        """

        dataset = ComplianceDataset(
            texts=[text_input],
            labels=[[0] * len(LABEL_COLUMNS)],
            tokenizer=tokenizer,
            max_len=MAX_LEN
        )

        loader = DataLoader(dataset, batch_size=1, shuffle=False)
        all_probabilities = predict_probabilities(model, loader)

        label_encoders = load_label_encoders()

        response = {}
        for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
            pred = np.argmax(probs[0])
            decoded_pred = label_encoders[col].inverse_transform([pred])[0]

            class_probs = {
                label: float(probs[0][j])
                for j, label in enumerate(label_encoders[col].classes_)
            }

            response[col] = {
                "prediction": decoded_pred,
                "probabilities": class_probs
            }

        return response

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/bert/download-model/{model_id}")
async def download_model(model_id: str):
    """Download a trained model"""
    # save_model() in train_utils.py writes checkpoints as <name>_model.pth
    model_path = MODEL_SAVE_DIR / f"{model_id}_model.pth"
    if not model_path.exists():
        raise HTTPException(status_code=404, detail="Model not found")

    return FileResponse(
        path=model_path,
        filename=f"bert_model_{model_id}.pth",
        media_type="application/octet-stream"
    )

async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    try:
        data_df_original, label_encoders = load_and_preprocess_data(file_path)
        save_label_encoders(label_encoders)

        train_df, val_df = train_test_split(
            data_df_original,
            test_size=config.test_size,
            random_state=config.random_state,
            stratify=data_df_original[LABEL_COLUMNS[0]]
        )

        train_texts = train_df[TEXT_COLUMN]
        val_texts = val_df[TEXT_COLUMN]
        train_labels_array = train_df[LABEL_COLUMNS].values
        val_labels_array = val_df[LABEL_COLUMNS].values

        train_metadata_df = train_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in train_df.columns for col in METADATA_COLUMNS) else None
        val_metadata_df = val_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in val_df.columns for col in METADATA_COLUMNS) else None

        num_labels_list = get_num_labels(label_encoders)
        tokenizer = get_tokenizer(config.model_name)

        if train_metadata_df is not None and val_metadata_df is not None:
            metadata_dim = train_metadata_df.shape[1]
            train_dataset = ComplianceDatasetWithMetadata(
                train_texts.tolist(),
                train_metadata_df.values,
                train_labels_array,
                tokenizer,
                config.max_length
            )
            val_dataset = ComplianceDatasetWithMetadata(
                val_texts.tolist(),
                val_metadata_df.values,
                val_labels_array,
                tokenizer,
                config.max_length
            )
            # NOTE: requires a metadata-aware model variant; the text-only
            # BertMultiOutputModel does not accept metadata_dim.
            model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
        else:
            train_dataset = ComplianceDataset(
                train_texts.tolist(),
                train_labels_array,
                tokenizer,
                config.max_length
            )
            val_dataset = ComplianceDataset(
                val_texts.tolist(),
                val_labels_array,
                tokenizer,
                config.max_length
            )
            model = BertMultiOutputModel(num_labels_list).to(DEVICE)

        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

        # initialize_criterions expects the DataFrame and the label encoders (see train_utils.py)
        criterions = initialize_criterions(data_df_original, label_encoders)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

        # evaluate_model returns per-field classification reports rather than a scalar
        # loss, so the best checkpoint is selected by mean weighted F1-score.
        best_val_f1 = float('-inf')
        val_metrics = {}
        for epoch in range(config.num_epochs):
            training_status["current_epoch"] = epoch + 1

            # Argument order follows train_utils.train_model(model, loader, optimizer, field_criterions, epoch)
            train_loss = train_model(model, train_loader, optimizer, criterions, epoch)
            val_metrics, _, _ = evaluate_model(model, val_loader)

            training_status["current_loss"] = train_loss

            val_f1 = summarize_metrics(val_metrics)["F1-Score"].mean()
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                save_model(model, training_id)

        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed",
            "metrics": summarize_metrics(val_metrics).to_dict()
        })

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
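
For reference, a minimal client sketch against the endpoints above. It is an illustration only: it assumes the service is reachable on localhost:7860 (the port exposed by the Dockerfile), uses the third-party requests package (not listed in requirements.txt), and the CSV filename is a placeholder.

# client_example.py - hypothetical client sketch for the API above (assumes `pip install requests`)
import requests

BASE_URL = "http://localhost:7860"  # assumed host/port from the Dockerfile's EXPOSE 7860

# 1. Upload a training CSV (placeholder filename)
with open("transactions.csv", "rb") as f:
    upload = requests.post(f"{BASE_URL}/upload",
                           files={"file": ("transactions.csv", f, "text/csv")})
file_path = upload.json()["file_path"]

# 2. Start background training: TrainingConfig fields go in the JSON body,
#    file_path is a query parameter of /bert/train
train = requests.post(
    f"{BASE_URL}/bert/train",
    params={"file_path": file_path},
    json={"num_epochs": 2, "batch_size": 8},
)
print(train.json())  # includes training_id and the model download URL

# 3. Poll the training status endpoint
print(requests.get(f"{BASE_URL}/training-status").json())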
config.py
ADDED
@@ -0,0 +1,69 @@
# config.py

import torch
import os

# --- Paths ---
# Adjust DATA_PATH to your actual data location
DATA_PATH = './data/synthetic_transactions_samples_5000.csv'
TOKENIZER_PATH = './tokenizer/'
LABEL_ENCODERS_PATH = './label_encoders.pkl'
MODEL_SAVE_DIR = './saved_models/'
PREDICTIONS_SAVE_DIR = './predictions/'  # To save predictions for voting ensemble

# --- Data Columns ---
TEXT_COLUMN = "Sanction_Context"
# Define all your target label columns
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome"
]
# Example metadata columns. Add actual numerical/categorical metadata if available in your CSV.
# For now, it's an empty list. If you add metadata, ensure these columns exist and are numeric or can be encoded.
METADATA_COLUMNS = []  # e.g., ["Risk_Score", "Transaction_Amount"]

# --- Model Hyperparameters ---
MAX_LEN = 128         # Maximum sequence length for transformer tokenizers
BATCH_SIZE = 16       # Batch size for training and evaluation
LEARNING_RATE = 2e-5  # Learning rate for AdamW optimizer
NUM_EPOCHS = 3        # Number of training epochs. Adjust based on convergence.
DROPOUT_RATE = 0.3    # Dropout rate for regularization

# --- Device Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Specific Model Configurations ---
BERT_MODEL_NAME = 'bert-base-uncased'
ROBERTA_MODEL_NAME = 'roberta-base'
DEBERTA_MODEL_NAME = 'microsoft/deberta-base'

# TF-IDF
TFIDF_MAX_FEATURES = 5000  # Max features for TF-IDF vectorizer

# --- Field-Specific Strategy (Conceptual) ---
# This dictionary provides conceptual strategies for enhancing specific fields.
# Actual implementation requires adapting the models (e.g., custom loss functions, metadata integration).
FIELD_STRATEGIES = {
    "Maker_Action": {
        "loss": "focal_loss",  # Requires custom Focal Loss implementation
        "enhancements": ["action_templates", "context_prompt_tuning"]  # Advanced NLP concepts
    },
    "Risk_Category": {
        "enhancements": ["numerical_metadata", "transaction_patterns"]  # Integrate METADATA_COLUMNS
    },
    "Escalation_Level": {
        "enhancements": ["class_balancing", "policy_keyword_patterns"]  # Handled by class weights/metadata
    },
    "Investigation_Outcome": {
        "type": "classification_or_generation"  # If generation, T5/BART would be needed.
    }
}

# Ensure model save and predictions directories exist
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(PREDICTIONS_SAVE_DIR, exist_ok=True)
os.makedirs(TOKENIZER_PATH, exist_ok=True)
dataset_utils.py
ADDED
@@ -0,0 +1,165 @@
# dataset_utils.py

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
import pickle
import os

from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS

class ComplianceDataset(Dataset):
    """
    Custom Dataset class for handling text and multi-output labels for PyTorch models.
    """
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Retrieves a sample from the dataset at the given index.
        Tokenizes the text and converts labels to a PyTorch tensor.
        """
        text = str(self.texts[idx])
        # Tokenize the text, padding to max_length and truncating if longer.
        # return_tensors="pt" ensures PyTorch tensors are returned.
        inputs = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        # Squeeze removes the batch dimension (which is 1 here because we process one sample at a time)
        inputs = {key: val.squeeze(0) for key, val in inputs.items()}
        # Convert labels to a PyTorch long tensor
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        return inputs, labels

class ComplianceDatasetWithMetadata(Dataset):
    """
    Custom Dataset class for handling text, additional numerical metadata, and multi-output labels.
    Used for hybrid models combining text and tabular features.
    """
    def __init__(self, texts, metadata, labels, tokenizer, max_len):
        self.texts = texts
        self.metadata = metadata  # Expects metadata as a NumPy array or list of lists
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Retrieves a sample, its metadata, and labels from the dataset at the given index.
        Tokenizes text, converts metadata and labels to PyTorch tensors.
        """
        text = str(self.texts[idx])
        inputs = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        inputs = {key: val.squeeze(0) for key, val in inputs.items()}
        # Convert metadata for the current sample to a float tensor
        metadata = torch.tensor(self.metadata[idx], dtype=torch.float)
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        return inputs, metadata, labels

def load_and_preprocess_data(data_path):
    """
    Loads data from a CSV, fills missing values, and encodes categorical labels.
    Also handles converting specified METADATA_COLUMNS to numeric.

    Args:
        data_path (str): Path to the CSV data file.

    Returns:
        tuple: A tuple containing:
            - data (pd.DataFrame): The preprocessed DataFrame.
            - label_encoders (dict): A dictionary of LabelEncoder objects for each label column.
    """
    data = pd.read_csv(data_path)
    data.fillna("Unknown", inplace=True)  # Fill any missing text values with "Unknown"

    # Convert metadata columns to numeric, coercing errors and filling NaNs with 0
    # This ensures metadata is suitable for neural networks.
    for col in METADATA_COLUMNS:
        if col in data.columns:
            data[col] = pd.to_numeric(data[col], errors='coerce').fillna(0)  # Fill NaN with 0 or a suitable value

    label_encoders = {col: LabelEncoder() for col in LABEL_COLUMNS}
    for col in LABEL_COLUMNS:
        # Fit and transform each label column using its respective LabelEncoder
        data[col] = label_encoders[col].fit_transform(data[col])
    return data, label_encoders

def get_tokenizer(model_name):
    """
    Returns the appropriate Hugging Face tokenizer based on the model name.

    Args:
        model_name (str): The name of the pre-trained model (e.g., 'bert-base-uncased').

    Returns:
        transformers.PreTrainedTokenizer: The initialized tokenizer.
    """
    # Check the more specific names first: "roberta" and "deberta" both contain the
    # substring "bert", so the plain "bert" check must come last.
    if "roberta" in model_name.lower():
        return RobertaTokenizer.from_pretrained(model_name)
    elif "deberta" in model_name.lower():
        return DebertaTokenizer.from_pretrained(model_name)
    elif "bert" in model_name.lower():
        return BertTokenizer.from_pretrained(model_name)
    else:
        raise ValueError(f"Unsupported tokenizer for model: {model_name}")

def save_label_encoders(label_encoders):
    """
    Saves a dictionary of label encoders to a pickle file.
    This is crucial for decoding predictions back to original labels.

    Args:
        label_encoders (dict): Dictionary of LabelEncoder objects.
    """
    with open(LABEL_ENCODERS_PATH, "wb") as f:
        pickle.dump(label_encoders, f)
    print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")

def load_label_encoders():
    """
    Loads a dictionary of label encoders from a pickle file.

    Returns:
        dict: Loaded dictionary of LabelEncoder objects.
    """
    with open(LABEL_ENCODERS_PATH, "rb") as f:
        label_encoders = pickle.load(f)
    print(f"Label encoders loaded from {LABEL_ENCODERS_PATH}")
    return label_encoders


def get_num_labels(label_encoders):
    """
    Returns a list containing the number of unique classes for each label column.
    This list is used to define the output dimensions of the model's classification heads.

    Args:
        label_encoders (dict): Dictionary of LabelEncoder objects.

    Returns:
        list: A list of integers, where each integer is the number of classes for a label.
    """
    return [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
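
A minimal usage sketch of the helpers above, assuming a labelled CSV exists at the DATA_PATH configured in config.py and that the dataset has at least BATCH_SIZE rows; it builds a ComplianceDataset and inspects one batch.

# Hypothetical end-to-end check of the dataset helpers (DATA_PATH from config.py is assumed to exist).
from torch.utils.data import DataLoader

from config import DATA_PATH, TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, BATCH_SIZE, BERT_MODEL_NAME
from dataset_utils import load_and_preprocess_data, get_tokenizer, ComplianceDataset, save_label_encoders

data, label_encoders = load_and_preprocess_data(DATA_PATH)   # encodes every LABEL_COLUMN in place
save_label_encoders(label_encoders)                          # needed later to decode predictions

tokenizer = get_tokenizer(BERT_MODEL_NAME)
dataset = ComplianceDataset(
    data[TEXT_COLUMN].tolist(),
    data[LABEL_COLUMNS].values,
    tokenizer,
    MAX_LEN,
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

inputs, labels = next(iter(loader))
print(inputs["input_ids"].shape)   # (BATCH_SIZE, MAX_LEN)
print(labels.shape)                # (BATCH_SIZE, len(LABEL_COLUMNS))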
docker-compose.yml
ADDED
@@ -0,0 +1,18 @@
version: '3.8'

services:
  bert-api:
    build: .
    ports:
      - "7860:7860"
    volumes:
      - ../saved_models:/app/saved_models
      - ../tokenizer:/app/tokenizer
      - ../predictions:/app/predictions
      - ../label_encoders.pkl:/app/label_encoders.pkl
      - ../.cache:/app/.cache
    environment:
      - PYTHONUNBUFFERED=1
      - TRANSFORMERS_CACHE=/app/.cache
      - PORT=7860
    restart: unless-stopped
label_encoders.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c336fd07858af76d40c7200de1a769099abeec25d4f48b999351318680d4e4d6
size 2047
models/__pycache__/bert_model.cpython-311.pyc
ADDED
Binary file (3.29 kB).
models/__pycache__/deberta_model.cpython-311.pyc
ADDED
Binary file (3.15 kB).
models/__pycache__/parallel_bert_deberta.cpython-311.pyc
ADDED
Binary file (6.45 kB).
models/__pycache__/roberta_model.cpython-311.pyc
ADDED
Binary file (3.18 kB).
models/__pycache__/text_and_metadata_model.cpython-311.pyc
ADDED
Binary file (4.09 kB).
models/bert_model.py
ADDED
@@ -0,0 +1,59 @@
# models/bert_model.py

import torch
import torch.nn as nn
from transformers import BertModel
from config import DROPOUT_RATE, BERT_MODEL_NAME  # Import BERT_MODEL_NAME from config

class BertMultiOutputModel(nn.Module):
    """
    BERT-based model for multi-output classification.
    It uses a pre-trained BERT model as its backbone and adds a dropout layer
    followed by separate linear classification heads for each target label.
    """
    # Statically set tokenizer name for easy access in main.py
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels):
        """
        Initializes the BertMultiOutputModel.

        Args:
            num_labels (list): A list where each element is the number of classes
                               for a corresponding label column.
        """
        super(BertMultiOutputModel, self).__init__()
        # Load the pre-trained BERT model.
        # BertModel provides contextual embeddings and a pooled output for classification.
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout layer for regularization

        # Create a list of classification heads, one for each label column.
        # Each head is a linear layer mapping BERT's pooled output size to the number of classes for that label.
        self.classifiers = nn.ModuleList([
            nn.Linear(self.bert.config.hidden_size, n_classes) for n_classes in num_labels
        ])

    def forward(self, input_ids, attention_mask):
        """
        Performs the forward pass of the model.

        Args:
            input_ids (torch.Tensor): Tensor of token IDs (from tokenizer).
            attention_mask (torch.Tensor): Tensor indicating attention (from tokenizer).

        Returns:
            list: A list of logit tensors, one for each classification head.
                  Each tensor has shape (batch_size, num_classes_for_that_label).
        """
        # Pass input_ids and attention_mask through BERT.
        # .pooler_output typically represents the hidden state of the [CLS] token,
        # processed through a linear layer and tanh activation, often used for classification.
        pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output

        # Apply dropout for regularization
        pooled_output = self.dropout(pooled_output)

        # Pass the pooled output through each classification head.
        # The result is a list of logits (raw scores before softmax/sigmoid) for each label.
        return [classifier(pooled_output) for classifier in self.classifiers]
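
A small smoke-test sketch of the model's multi-head output; the label counts below are made-up placeholders (in practice they come from get_num_labels) and the sample sentence is illustrative only.

# Hypothetical smoke test for BertMultiOutputModel (label counts are placeholders).
import torch

from config import DEVICE, MAX_LEN
from dataset_utils import get_tokenizer
from models.bert_model import BertMultiOutputModel

num_labels = [4, 3, 5]  # placeholder: number of classes per label column
model = BertMultiOutputModel(num_labels).to(DEVICE)
tokenizer = get_tokenizer(BertMultiOutputModel.tokenizer_name)

enc = tokenizer(
    "Payment flagged for sanctions screening review.",
    padding="max_length",
    truncation=True,
    max_length=MAX_LEN,
    return_tensors="pt",
)

with torch.no_grad():
    logits_per_head = model(enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE))

for n_classes, logits in zip(num_labels, logits_per_head):
    print(logits.shape)  # (1, n_classes) -- one logit vector per classification head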
requirements.txt
ADDED
@@ -0,0 +1,13 @@
fastapi==0.104.1
uvicorn==0.24.0
pydantic==2.4.2
torch==2.1.0
transformers==4.35.0
pandas==2.1.2
numpy==1.24.3
scikit-learn==1.3.2
python-multipart==0.0.6
python-jose==3.3.0
passlib==1.7.4
bcrypt==4.0.1
python-dotenv==1.0.0
tqdm  # used directly by train_utils.py (also pulled in transitively by transformers)
train_utils.py
ADDED
@@ -0,0 +1,310 @@
# train_utils.py

import torch
import torch.nn as nn
from torch.optim import AdamW
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
from tqdm import tqdm
import pandas as pd
import os
import joblib

from config import DEVICE, LABEL_COLUMNS, NUM_EPOCHS, LEARNING_RATE, MODEL_SAVE_DIR

def get_class_weights(data_df, field, label_encoder):
    """
    Computes balanced class weights for a given target field.
    These weights can be used in the loss function to mitigate class imbalance.

    Args:
        data_df (pd.DataFrame): The DataFrame containing the original (unencoded) label data.
        field (str): The name of the label column for which to compute weights.
        label_encoder (sklearn.preprocessing.LabelEncoder): The label encoder fitted for this field.

    Returns:
        torch.Tensor: A tensor of class weights for the specified field.
    """
    # Get the original labels for the specified field
    y = data_df[field].values
    # Use label_encoder.transform directly - it will handle unseen labels
    try:
        y_encoded = label_encoder.transform(y)
    except ValueError as e:
        print(f"Warning: {e}")
        print(f"Using only seen labels for class weights calculation")
        # Filter out unseen labels
        seen_labels = set(label_encoder.classes_)
        y_filtered = [label for label in y if label in seen_labels]
        y_encoded = label_encoder.transform(y_filtered)

    # Ensure y_encoded is integer type
    y_encoded = y_encoded.astype(int)

    # Initialize counts for all possible classes
    n_classes = len(label_encoder.classes_)
    class_counts = np.zeros(n_classes, dtype=int)

    # Count occurrences of each class
    for i in range(n_classes):
        class_counts[i] = np.sum(y_encoded == i)

    # Calculate weights for all classes
    total_samples = len(y_encoded)
    class_weights = np.ones(n_classes)  # Default weight of 1 for unseen classes
    seen_classes = class_counts > 0
    if np.any(seen_classes):
        class_weights[seen_classes] = total_samples / (np.sum(seen_classes) * class_counts[seen_classes])

    return torch.tensor(class_weights, dtype=torch.float)

def initialize_criterions(data_df, label_encoders):
    """
    Initializes CrossEntropyLoss criteria for each label column, applying class weights.

    Args:
        data_df (pd.DataFrame): The original (unencoded) DataFrame. Used to compute class weights.
        label_encoders (dict): Dictionary of LabelEncoder objects.

    Returns:
        dict: A dictionary where keys are label column names and values are
              initialized `torch.nn.CrossEntropyLoss` objects.
    """
    field_criterions = {}
    for field in LABEL_COLUMNS:
        # Get class weights for the current field
        weights = get_class_weights(data_df, field, label_encoders[field])
        # Initialize CrossEntropyLoss with the computed weights and move to the device
        field_criterions[field] = torch.nn.CrossEntropyLoss(weight=weights.to(DEVICE))
    return field_criterions

def train_model(model, loader, optimizer, field_criterions, epoch):
    """
    Trains the given PyTorch model for one epoch.

    Args:
        model (torch.nn.Module): The model to train.
        loader (torch.utils.data.DataLoader): DataLoader for training data.
        optimizer (torch.optim.Optimizer): Optimizer for model parameters.
        field_criterions (dict): Dictionary of loss functions for each label.
        epoch (int): Current epoch number (for progress bar description).

    Returns:
        float: Average training loss for the epoch.
    """
    model.train()  # Set the model to training mode
    total_loss = 0
    # Use tqdm for a progress bar during training
    tqdm_loader = tqdm(loader, desc=f"Epoch {epoch + 1} Training")

    for batch in tqdm_loader:
        # Unpack batch based on whether it contains metadata
        if len(batch) == 2:  # Text-only models (inputs, labels)
            inputs, labels = batch
            input_ids = inputs['input_ids'].to(DEVICE)
            attention_mask = inputs['attention_mask'].to(DEVICE)
            labels = labels.to(DEVICE)
            # Forward pass through the model
            outputs = model(input_ids, attention_mask)
        elif len(batch) == 3:  # Text + Metadata models (inputs, metadata, labels)
            inputs, metadata, labels = batch
            input_ids = inputs['input_ids'].to(DEVICE)
            attention_mask = inputs['attention_mask'].to(DEVICE)
            metadata = metadata.to(DEVICE)
            labels = labels.to(DEVICE)
            # Forward pass through the hybrid model
            outputs = model(input_ids, attention_mask, metadata)
        else:
            raise ValueError("Unsupported batch format. Expected 2 or 3 items in batch.")

        loss = 0
        # Calculate total loss by summing loss for each label column
        # `outputs` is a list of logits, one for each label column
        for i, output_logits in enumerate(outputs):
            # `labels[:, i]` gets the true labels for the i-th label column
            # `field_criterions[LABEL_COLUMNS[i]]` selects the appropriate loss function
            loss += field_criterions[LABEL_COLUMNS[i]](output_logits, labels[:, i])

        optimizer.zero_grad()  # Clear previous gradients
        loss.backward()  # Backpropagation
        optimizer.step()  # Update model parameters
        total_loss += loss.item()  # Accumulate loss
        tqdm_loader.set_postfix(loss=loss.item())  # Update progress bar with current batch loss

    return total_loss / len(loader)  # Return average loss for the epoch

def evaluate_model(model, loader):
    """
    Evaluates the given PyTorch model on a validation/test set.

    Args:
        model (torch.nn.Module): The model to evaluate.
        loader (torch.utils.data.DataLoader): DataLoader for evaluation data.

    Returns:
        tuple: A tuple containing:
            - reports (dict): Classification reports (dict format) for each label column.
            - truths (list): List of true label arrays for each label column.
            - predictions (list): List of predicted label arrays for each label column.
    """
    model.eval()  # Set the model to evaluation mode (disables dropout, batch norm updates, etc.)
    # Initialize lists to store predictions and true labels for each output head
    predictions = [[] for _ in range(len(LABEL_COLUMNS))]
    truths = [[] for _ in range(len(LABEL_COLUMNS))]

    with torch.no_grad():  # Disable gradient calculations during evaluation for efficiency
        for batch in tqdm(loader, desc="Evaluation"):
            if len(batch) == 2:
                inputs, labels = batch
                input_ids = inputs['input_ids'].to(DEVICE)
                attention_mask = inputs['attention_mask'].to(DEVICE)
                labels = labels.to(DEVICE)
                outputs = model(input_ids, attention_mask)
            elif len(batch) == 3:
                inputs, metadata, labels = batch
                input_ids = inputs['input_ids'].to(DEVICE)
                attention_mask = inputs['attention_mask'].to(DEVICE)
                metadata = metadata.to(DEVICE)
                labels = labels.to(DEVICE)
                outputs = model(input_ids, attention_mask, metadata)
            else:
                raise ValueError("Unsupported batch format.")

            for i, output_logits in enumerate(outputs):
                # Get the predicted class by taking the argmax of the logits
                preds = torch.argmax(output_logits, dim=1).cpu().numpy()
                predictions[i].extend(preds)
                # Get the true labels for the current output head
                truths[i].extend(labels[:, i].cpu().numpy())

    reports = {}
    # Generate classification report for each label column
    for i, col in enumerate(LABEL_COLUMNS):
        try:
            # `zero_division=0` handles cases where a class might have no true or predicted samples
            reports[col] = classification_report(truths[i], predictions[i], output_dict=True, zero_division=0)
        except ValueError:
            # Handle cases where a label might not appear in the validation set,
            # which could cause classification_report to fail.
            print(f"Warning: Could not generate classification report for {col}. Skipping.")
            reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
    return reports, truths, predictions

def summarize_metrics(metrics):
    """
    Summarizes classification reports into a readable Pandas DataFrame.

    Args:
        metrics (dict): Dictionary of classification reports, as returned by `evaluate_model`.

    Returns:
        pd.DataFrame: A DataFrame summarizing precision, recall, f1-score, accuracy, and support for each field.
    """
    summary = []
    for field, report in metrics.items():
        # Safely get metrics, defaulting to 0 if not present (e.g., for empty reports)
        precision = report['weighted avg']['precision'] if 'weighted avg' in report else 0
        recall = report['weighted avg']['recall'] if 'weighted avg' in report else 0
        f1 = report['weighted avg']['f1-score'] if 'weighted avg' in report else 0
        support = report['weighted avg']['support'] if 'weighted avg' in report else 0
        accuracy = report['accuracy'] if 'accuracy' in report else 0  # Accuracy is usually top-level
        summary.append({
            "Field": field,
            "Precision": precision,
            "Recall": recall,
            "F1-Score": f1,
            "Accuracy": accuracy,
            "Support": support
        })
    return pd.DataFrame(summary)

def save_model(model, model_name, save_format='pth'):
    """
    Saves the state dictionary of a PyTorch model.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        model_name (str): A descriptive name for the model (used for filename).
        save_format (str): Format to save the model in ('pth' for PyTorch models, 'pickle' for traditional ML models).
    """
    # Construct the save path dynamically relative to the project root
    if save_format == 'pth':
        model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
        torch.save(model.state_dict(), model_path)
    elif save_format == 'pickle':
        model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
        joblib.dump(model, model_path)
    else:
        raise ValueError(f"Unsupported save format: {save_format}")

    print(f"Model saved to {model_path}")

def load_model_state(model, model_name, model_class, num_labels, metadata_dim=0):
    """
    Loads the state dictionary into a PyTorch model.

    Args:
        model (torch.nn.Module): An initialized model instance (architecture).
        model_name (str): The name of the model to load.
        model_class (class): The class of the model (e.g., BertMultiOutputModel).
        num_labels (list): List of number of classes for each label.
        metadata_dim (int): Dimensionality of metadata features, if applicable (default 0 for text-only).

    Returns:
        torch.nn.Module: The model with loaded state_dict, moved to the correct device, and set to eval mode.
    """
    model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
    if not os.path.exists(model_path):
        print(f"Warning: Model file not found at {model_path}. Returning a newly initialized model instance.")
        # Re-initialize the model if not found, to ensure it has the correct architecture
        if metadata_dim > 0:
            return model_class(num_labels, metadata_dim=metadata_dim).to(DEVICE)
        else:
            return model_class(num_labels).to(DEVICE)

    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()  # Set to evaluation mode after loading
    print(f"Model loaded from {model_path}")
    return model

def predict_probabilities(model, loader):
    """
    Generates prediction probabilities for each label for a given model.
    This is used for confidence scoring and feeding into a voting ensemble.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        loader (torch.utils.data.DataLoader): DataLoader for the data to predict on.

    Returns:
        list: A list of lists of numpy arrays. Each inner list corresponds to a label column,
              containing the softmax probabilities for each sample for that label.
    """
    model.eval()  # Set to evaluation mode
    # List to store probabilities for each output head
    all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]

    with torch.no_grad():
        for batch in tqdm(loader, desc="Predicting Probabilities"):
            # Unpack batch, ignoring labels as we only need inputs
            if len(batch) == 2:
                inputs, _ = batch
                input_ids = inputs['input_ids'].to(DEVICE)
                attention_mask = inputs['attention_mask'].to(DEVICE)
                outputs = model(input_ids, attention_mask)
            elif len(batch) == 3:
                inputs, metadata, _ = batch
                input_ids = inputs['input_ids'].to(DEVICE)
                attention_mask = inputs['attention_mask'].to(DEVICE)
                metadata = metadata.to(DEVICE)
                outputs = model(input_ids, attention_mask, metadata)
            else:
                raise ValueError("Unsupported batch format.")

            for i, out_logits in enumerate(outputs):
                # Apply softmax to logits to get probabilities
                probs = torch.softmax(out_logits, dim=1).cpu().numpy()
                all_probabilities[i].extend(probs)
    return all_probabilities
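
Putting the helpers above together, a condensed offline training sketch that mirrors what app.py's train_model_task does; it assumes a labelled CSV at the DATA_PATH from config.py and uses the call signatures defined in this file.

# Hypothetical offline training run wiring the helpers above together (DATA_PATH from config.py is assumed).
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from config import DATA_PATH, TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS, DEVICE, BERT_MODEL_NAME
from dataset_utils import load_and_preprocess_data, get_tokenizer, get_num_labels, save_label_encoders, ComplianceDataset
from models.bert_model import BertMultiOutputModel
from train_utils import initialize_criterions, train_model, evaluate_model, summarize_metrics, save_model

data, label_encoders = load_and_preprocess_data(DATA_PATH)
save_label_encoders(label_encoders)
train_df, val_df = train_test_split(data, test_size=0.2, random_state=42)

tokenizer = get_tokenizer(BERT_MODEL_NAME)
train_ds = ComplianceDataset(train_df[TEXT_COLUMN].tolist(), train_df[LABEL_COLUMNS].values, tokenizer, MAX_LEN)
val_ds = ComplianceDataset(val_df[TEXT_COLUMN].tolist(), val_df[LABEL_COLUMNS].values, tokenizer, MAX_LEN)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

model = BertMultiOutputModel(get_num_labels(label_encoders)).to(DEVICE)
criterions = initialize_criterions(data, label_encoders)   # one CrossEntropyLoss per label column
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    avg_loss = train_model(model, train_loader, optimizer, criterions, epoch)
    reports, _, _ = evaluate_model(model, val_loader)
    print(f"epoch {epoch + 1}: train loss {avg_loss:.4f}")
    print(summarize_metrics(reports))

save_model(model, "bert")  # written to saved_models/bert_model.pth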