Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
|
2 |
from fastapi.responses import FileResponse
|
3 |
from pydantic import BaseModel
|
4 |
from typing import Optional, Dict, Any, List
|
@@ -18,6 +18,7 @@ import zipfile
|
|
18 |
import io
|
19 |
import numpy as np
|
20 |
import sys
|
|
|
21 |
|
22 |
|
23 |
# Import existing utilities
|
@@ -88,7 +89,6 @@ class TrainingConfig(BaseModel):
|
|
88 |
learning_rate: float = 2e-5
|
89 |
num_epochs: int = 2
|
90 |
max_length: int = 128
|
91 |
-
test_size: float = 0.2
|
92 |
random_state: int = 42
|
93 |
|
94 |
class TrainingResponse(BaseModel):
|
@@ -190,8 +190,8 @@ async def get_training_status():
|
|
190 |
|
191 |
@app.post("/v1/bert/train", response_model=TrainingResponse)
|
192 |
async def start_training(
|
193 |
-
config:
|
194 |
-
background_tasks: BackgroundTasks,
|
195 |
file: UploadFile = File(...)
|
196 |
):
|
197 |
if training_status["is_training"]:
|
@@ -200,6 +200,15 @@ async def start_training(
|
|
200 |
if not file.filename.endswith('.csv'):
|
201 |
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
file_path = UPLOAD_DIR / file.filename
|
204 |
with file_path.open("wb") as buffer:
|
205 |
shutil.copyfileobj(file.file, buffer)
|
@@ -209,12 +218,12 @@ async def start_training(
|
|
209 |
training_status.update({
|
210 |
"is_training": True,
|
211 |
"current_epoch": 0,
|
212 |
-
"total_epochs":
|
213 |
"start_time": datetime.now().isoformat(),
|
214 |
"status": "starting"
|
215 |
})
|
216 |
|
217 |
-
background_tasks.add_task(train_model_task,
|
218 |
|
219 |
download_url = f"/v1/bert/download-model/{training_id}"
|
220 |
|
@@ -314,7 +323,7 @@ async def validate_model(
|
|
314 |
@app.post("/v1/bert/predict")
|
315 |
async def predict(
|
316 |
request: Optional[PredictionRequest] = None,
|
317 |
-
file:
|
318 |
model_name: str = "BERT_model"
|
319 |
):
|
320 |
"""
|
@@ -325,6 +334,7 @@ async def predict(
|
|
325 |
2. Upload a CSV file with multiple transactions
|
326 |
|
327 |
Parameters:
|
|
|
328 |
- model_name: Name of the model to use for prediction (default: "BERT_model")
|
329 |
"""
|
330 |
try:
|
@@ -339,7 +349,7 @@ async def predict(
|
|
339 |
model.eval()
|
340 |
|
341 |
# Handle batch prediction from CSV
|
342 |
-
if file
|
343 |
if not file.filename.endswith('.csv'):
|
344 |
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
345 |
|
@@ -398,7 +408,7 @@ async def predict(
|
|
398 |
os.remove(file_path)
|
399 |
|
400 |
# Handle single prediction
|
401 |
-
elif request
|
402 |
input_data = pd.DataFrame([request.transaction_data.dict()])
|
403 |
|
404 |
text_input = f"""
|
@@ -430,20 +440,6 @@ async def predict(
|
|
430 |
Sanction Description: {input_data['Sanction_Description'].iloc[0]}
|
431 |
Checker Notes: {input_data['Checker_Notes'].iloc[0]}
|
432 |
Sanction Context: {input_data['Sanction_Context'].iloc[0]}
|
433 |
-
Maker Action: {input_data['Maker_Action'].iloc[0]}
|
434 |
-
Customer Type: {input_data['Customer_Type'].iloc[0]}
|
435 |
-
Industry: {input_data['Industry'].iloc[0]}
|
436 |
-
Transaction Type: {input_data['Transaction_Type'].iloc[0]}
|
437 |
-
Transaction Channel: {input_data['Transaction_Channel'].iloc[0]}
|
438 |
-
Geographic Origin: {input_data['Geographic_Origin'].iloc[0]}
|
439 |
-
Geographic Destination: {input_data['Geographic_Destination'].iloc[0]}
|
440 |
-
Risk Category: {input_data['Risk_Category'].iloc[0]}
|
441 |
-
Risk Drivers: {input_data['Risk_Drivers'].iloc[0]}
|
442 |
-
Alert Status: {input_data['Alert_Status'].iloc[0]}
|
443 |
-
Investigation Outcome: {input_data['Investigation_Outcome'].iloc[0]}
|
444 |
-
Source of Funds: {input_data['Source_Of_Funds'].iloc[0]}
|
445 |
-
Purpose of Transaction: {input_data['Purpose_Of_Transaction'].iloc[0]}
|
446 |
-
Beneficial Owner: {input_data['Beneficial_Owner'].iloc[0]}
|
447 |
"""
|
448 |
|
449 |
dataset = ComplianceDataset(
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
|
2 |
from fastapi.responses import FileResponse
|
3 |
from pydantic import BaseModel
|
4 |
from typing import Optional, Dict, Any, List
|
|
|
18 |
import io
|
19 |
import numpy as np
|
20 |
import sys
|
21 |
+
import json
|
22 |
|
23 |
|
24 |
# Import existing utilities
|
|
|
89 |
learning_rate: float = 2e-5
|
90 |
num_epochs: int = 2
|
91 |
max_length: int = 128
|
|
|
92 |
random_state: int = 42
|
93 |
|
94 |
class TrainingResponse(BaseModel):
|
|
|
190 |
|
191 |
@app.post("/v1/bert/train", response_model=TrainingResponse)
|
192 |
async def start_training(
|
193 |
+
config: str = Form(...),
|
194 |
+
background_tasks: BackgroundTasks = None,
|
195 |
file: UploadFile = File(...)
|
196 |
):
|
197 |
if training_status["is_training"]:
|
|
|
200 |
if not file.filename.endswith('.csv'):
|
201 |
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
202 |
|
203 |
+
try:
|
204 |
+
# Parse the config JSON string into a TrainingConfig object
|
205 |
+
config_dict = json.loads(config)
|
206 |
+
training_config = TrainingConfig(**config_dict)
|
207 |
+
except json.JSONDecodeError:
|
208 |
+
raise HTTPException(status_code=400, detail="Invalid config JSON format")
|
209 |
+
except Exception as e:
|
210 |
+
raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
|
211 |
+
|
212 |
file_path = UPLOAD_DIR / file.filename
|
213 |
with file_path.open("wb") as buffer:
|
214 |
shutil.copyfileobj(file.file, buffer)
|
|
|
218 |
training_status.update({
|
219 |
"is_training": True,
|
220 |
"current_epoch": 0,
|
221 |
+
"total_epochs": training_config.num_epochs,
|
222 |
"start_time": datetime.now().isoformat(),
|
223 |
"status": "starting"
|
224 |
})
|
225 |
|
226 |
+
background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
|
227 |
|
228 |
download_url = f"/v1/bert/download-model/{training_id}"
|
229 |
|
|
|
323 |
@app.post("/v1/bert/predict")
|
324 |
async def predict(
|
325 |
request: Optional[PredictionRequest] = None,
|
326 |
+
file: UploadFile = File(None),
|
327 |
model_name: str = "BERT_model"
|
328 |
):
|
329 |
"""
|
|
|
334 |
2. Upload a CSV file with multiple transactions
|
335 |
|
336 |
Parameters:
|
337 |
+
- file: CSV file containing transactions for batch prediction
|
338 |
- model_name: Name of the model to use for prediction (default: "BERT_model")
|
339 |
"""
|
340 |
try:
|
|
|
349 |
model.eval()
|
350 |
|
351 |
# Handle batch prediction from CSV
|
352 |
+
if file and file.filename:
|
353 |
if not file.filename.endswith('.csv'):
|
354 |
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
355 |
|
|
|
408 |
os.remove(file_path)
|
409 |
|
410 |
# Handle single prediction
|
411 |
+
elif request and request.transaction_data:
|
412 |
input_data = pd.DataFrame([request.transaction_data.dict()])
|
413 |
|
414 |
text_input = f"""
|
|
|
440 |
Sanction Description: {input_data['Sanction_Description'].iloc[0]}
|
441 |
Checker Notes: {input_data['Checker_Notes'].iloc[0]}
|
442 |
Sanction Context: {input_data['Sanction_Context'].iloc[0]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
"""
|
444 |
|
445 |
dataset = ComplianceDataset(
|