namanpenguin committed on
Commit
ce0b5cd
·
verified ·
1 Parent(s): 6a1e1ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -23
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
2
  from fastapi.responses import FileResponse
3
  from pydantic import BaseModel
4
  from typing import Optional, Dict, Any, List
@@ -18,6 +18,7 @@ import zipfile
18
  import io
19
  import numpy as np
20
  import sys
 
21
 
22
 
23
  # Import existing utilities
@@ -88,7 +89,6 @@ class TrainingConfig(BaseModel):
88
  learning_rate: float = 2e-5
89
  num_epochs: int = 2
90
  max_length: int = 128
91
- test_size: float = 0.2
92
  random_state: int = 42
93
 
94
  class TrainingResponse(BaseModel):
@@ -190,8 +190,8 @@ async def get_training_status():
190
 
191
  @app.post("/v1/bert/train", response_model=TrainingResponse)
192
  async def start_training(
193
- config: TrainingConfig,
194
- background_tasks: BackgroundTasks,
195
  file: UploadFile = File(...)
196
  ):
197
  if training_status["is_training"]:
@@ -200,6 +200,15 @@ async def start_training(
200
  if not file.filename.endswith('.csv'):
201
  raise HTTPException(status_code=400, detail="Only CSV files are allowed")
202
 
 
 
 
 
 
 
 
 
 
203
  file_path = UPLOAD_DIR / file.filename
204
  with file_path.open("wb") as buffer:
205
  shutil.copyfileobj(file.file, buffer)
@@ -209,12 +218,12 @@ async def start_training(
209
  training_status.update({
210
  "is_training": True,
211
  "current_epoch": 0,
212
- "total_epochs": config.num_epochs,
213
  "start_time": datetime.now().isoformat(),
214
  "status": "starting"
215
  })
216
 
217
- background_tasks.add_task(train_model_task, config, str(file_path), training_id)
218
 
219
  download_url = f"/v1/bert/download-model/{training_id}"
220
 
@@ -314,7 +323,7 @@ async def validate_model(
314
  @app.post("/v1/bert/predict")
315
  async def predict(
316
  request: Optional[PredictionRequest] = None,
317
- file: Optional[UploadFile] = File(None),
318
  model_name: str = "BERT_model"
319
  ):
320
  """
@@ -325,6 +334,7 @@ async def predict(
325
  2. Upload a CSV file with multiple transactions
326
 
327
  Parameters:
 
328
  - model_name: Name of the model to use for prediction (default: "BERT_model")
329
  """
330
  try:
@@ -339,7 +349,7 @@ async def predict(
339
  model.eval()
340
 
341
  # Handle batch prediction from CSV
342
- if file is not None and file.filename:
343
  if not file.filename.endswith('.csv'):
344
  raise HTTPException(status_code=400, detail="Only CSV files are allowed")
345
 
@@ -398,7 +408,7 @@ async def predict(
398
  os.remove(file_path)
399
 
400
  # Handle single prediction
401
- elif request is not None and request.transaction_data:
402
  input_data = pd.DataFrame([request.transaction_data.dict()])
403
 
404
  text_input = f"""
@@ -430,20 +440,6 @@ async def predict(
430
  Sanction Description: {input_data['Sanction_Description'].iloc[0]}
431
  Checker Notes: {input_data['Checker_Notes'].iloc[0]}
432
  Sanction Context: {input_data['Sanction_Context'].iloc[0]}
433
- Maker Action: {input_data['Maker_Action'].iloc[0]}
434
- Customer Type: {input_data['Customer_Type'].iloc[0]}
435
- Industry: {input_data['Industry'].iloc[0]}
436
- Transaction Type: {input_data['Transaction_Type'].iloc[0]}
437
- Transaction Channel: {input_data['Transaction_Channel'].iloc[0]}
438
- Geographic Origin: {input_data['Geographic_Origin'].iloc[0]}
439
- Geographic Destination: {input_data['Geographic_Destination'].iloc[0]}
440
- Risk Category: {input_data['Risk_Category'].iloc[0]}
441
- Risk Drivers: {input_data['Risk_Drivers'].iloc[0]}
442
- Alert Status: {input_data['Alert_Status'].iloc[0]}
443
- Investigation Outcome: {input_data['Investigation_Outcome'].iloc[0]}
444
- Source of Funds: {input_data['Source_Of_Funds'].iloc[0]}
445
- Purpose of Transaction: {input_data['Purpose_Of_Transaction'].iloc[0]}
446
- Beneficial Owner: {input_data['Beneficial_Owner'].iloc[0]}
447
  """
448
 
449
  dataset = ComplianceDataset(
 
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
2
  from fastapi.responses import FileResponse
3
  from pydantic import BaseModel
4
  from typing import Optional, Dict, Any, List
 
18
  import io
19
  import numpy as np
20
  import sys
21
+ import json
22
 
23
 
24
  # Import existing utilities
 
89
  learning_rate: float = 2e-5
90
  num_epochs: int = 2
91
  max_length: int = 128
 
92
  random_state: int = 42
93
 
94
  class TrainingResponse(BaseModel):
 
190
 
191
  @app.post("/v1/bert/train", response_model=TrainingResponse)
192
  async def start_training(
193
+ config: str = Form(...),
194
+ background_tasks: BackgroundTasks = None,
195
  file: UploadFile = File(...)
196
  ):
197
  if training_status["is_training"]:
 
200
  if not file.filename.endswith('.csv'):
201
  raise HTTPException(status_code=400, detail="Only CSV files are allowed")
202
 
203
+ try:
204
+ # Parse the config JSON string into a TrainingConfig object
205
+ config_dict = json.loads(config)
206
+ training_config = TrainingConfig(**config_dict)
207
+ except json.JSONDecodeError:
208
+ raise HTTPException(status_code=400, detail="Invalid config JSON format")
209
+ except Exception as e:
210
+ raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
211
+
212
  file_path = UPLOAD_DIR / file.filename
213
  with file_path.open("wb") as buffer:
214
  shutil.copyfileobj(file.file, buffer)
 
218
  training_status.update({
219
  "is_training": True,
220
  "current_epoch": 0,
221
+ "total_epochs": training_config.num_epochs,
222
  "start_time": datetime.now().isoformat(),
223
  "status": "starting"
224
  })
225
 
226
+ background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
227
 
228
  download_url = f"/v1/bert/download-model/{training_id}"
229
 
 
323
  @app.post("/v1/bert/predict")
324
  async def predict(
325
  request: Optional[PredictionRequest] = None,
326
+ file: UploadFile = File(None),
327
  model_name: str = "BERT_model"
328
  ):
329
  """
 
334
  2. Upload a CSV file with multiple transactions
335
 
336
  Parameters:
337
+ - file: CSV file containing transactions for batch prediction
338
  - model_name: Name of the model to use for prediction (default: "BERT_model")
339
  """
340
  try:
 
349
  model.eval()
350
 
351
  # Handle batch prediction from CSV
352
+ if file and file.filename:
353
  if not file.filename.endswith('.csv'):
354
  raise HTTPException(status_code=400, detail="Only CSV files are allowed")
355
 
 
408
  os.remove(file_path)
409
 
410
  # Handle single prediction
411
+ elif request and request.transaction_data:
412
  input_data = pd.DataFrame([request.transaction_data.dict()])
413
 
414
  text_input = f"""
 
440
  Sanction Description: {input_data['Sanction_Description'].iloc[0]}
441
  Checker Notes: {input_data['Checker_Notes'].iloc[0]}
442
  Sanction Context: {input_data['Sanction_Context'].iloc[0]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  """
444
 
445
  dataset = ComplianceDataset(