namanpenguin commited on
Commit
6a1e1ff
·
verified ·
1 Parent(s): b886555

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -23
app.py CHANGED
@@ -180,37 +180,29 @@ class BatchPredictionResponse(BaseModel):
180
  async def root():
181
  return {"message": "BERT Compliance Predictor API"}
182
 
183
- @app.get("/health")
184
  async def health_check():
185
  return {"status": "healthy"}
186
 
187
- @app.get("/training-status")
188
  async def get_training_status():
189
  return training_status
190
 
191
- @app.post("/upload")
192
- async def upload_file(file: UploadFile = File(...)):
193
- """Upload a CSV file for training or validation"""
194
- if not file.filename.endswith('.csv'):
195
- raise HTTPException(status_code=400, detail="Only CSV files are allowed")
196
-
197
- file_path = UPLOAD_DIR / file.filename
198
- with file_path.open("wb") as buffer:
199
- shutil.copyfileobj(file.file, buffer)
200
-
201
- return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
202
-
203
- @app.post("/bert/train", response_model=TrainingResponse)
204
  async def start_training(
205
  config: TrainingConfig,
206
  background_tasks: BackgroundTasks,
207
- file_path: str
208
  ):
209
  if training_status["is_training"]:
210
  raise HTTPException(status_code=400, detail="Training is already in progress")
211
 
212
- if not os.path.exists(file_path):
213
- raise HTTPException(status_code=404, detail="Training file not found")
 
 
 
 
214
 
215
  training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
216
 
@@ -222,9 +214,9 @@ async def start_training(
222
  "status": "starting"
223
  })
224
 
225
- background_tasks.add_task(train_model_task, config, file_path, training_id)
226
 
227
- download_url = f"/bert/download-model/{training_id}"
228
 
229
  return TrainingResponse(
230
  message="Training started successfully",
@@ -233,7 +225,7 @@ async def start_training(
233
  download_url=download_url
234
  )
235
 
236
- @app.post("/bert/validate")
237
  async def validate_model(
238
  file: UploadFile = File(...),
239
  model_name: str = "BERT_model"
@@ -319,7 +311,7 @@ async def validate_model(
319
  if os.path.exists(file_path):
320
  os.remove(file_path)
321
 
322
- @app.post("/bert/predict")
323
  async def predict(
324
  request: Optional[PredictionRequest] = None,
325
  file: Optional[UploadFile] = File(None),
@@ -492,7 +484,7 @@ async def predict(
492
  except Exception as e:
493
  raise HTTPException(status_code=500, detail=str(e))
494
 
495
- @app.get("/bert/download-model/{model_id}")
496
  async def download_model(model_id: str):
497
  """Download a trained model"""
498
  model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
 
180
  async def root():
181
  return {"message": "BERT Compliance Predictor API"}
182
 
183
+ @app.get("/v1/bert/health")
184
  async def health_check():
185
  return {"status": "healthy"}
186
 
187
+ @app.get("/v1/bert/training-status")
188
  async def get_training_status():
189
  return training_status
190
 
191
+ @app.post("/v1/bert/train", response_model=TrainingResponse)
 
 
 
 
 
 
 
 
 
 
 
 
192
  async def start_training(
193
  config: TrainingConfig,
194
  background_tasks: BackgroundTasks,
195
+ file: UploadFile = File(...)
196
  ):
197
  if training_status["is_training"]:
198
  raise HTTPException(status_code=400, detail="Training is already in progress")
199
 
200
+ if not file.filename.endswith('.csv'):
201
+ raise HTTPException(status_code=400, detail="Only CSV files are allowed")
202
+
203
+ file_path = UPLOAD_DIR / file.filename
204
+ with file_path.open("wb") as buffer:
205
+ shutil.copyfileobj(file.file, buffer)
206
 
207
  training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
208
 
 
214
  "status": "starting"
215
  })
216
 
217
+ background_tasks.add_task(train_model_task, config, str(file_path), training_id)
218
 
219
+ download_url = f"/v1/bert/download-model/{training_id}"
220
 
221
  return TrainingResponse(
222
  message="Training started successfully",
 
225
  download_url=download_url
226
  )
227
 
228
+ @app.post("/v1/bert/validate")
229
  async def validate_model(
230
  file: UploadFile = File(...),
231
  model_name: str = "BERT_model"
 
311
  if os.path.exists(file_path):
312
  os.remove(file_path)
313
 
314
+ @app.post("/v1/bert/predict")
315
  async def predict(
316
  request: Optional[PredictionRequest] = None,
317
  file: Optional[UploadFile] = File(None),
 
484
  except Exception as e:
485
  raise HTTPException(status_code=500, detail=str(e))
486
 
487
+ @app.get("/v1/bert/download-model/{model_id}")
488
  async def download_model(model_id: str):
489
  """Download a trained model"""
490
  model_path = MODEL_SAVE_DIR / f"{model_id}.pth"