Update app.py
Browse files
app.py
CHANGED
@@ -180,37 +180,29 @@ class BatchPredictionResponse(BaseModel):
|
|
180 |
async def root():
|
181 |
return {"message": "BERT Compliance Predictor API"}
|
182 |
|
183 |
-
@app.get("/health")
|
184 |
async def health_check():
|
185 |
return {"status": "healthy"}
|
186 |
|
187 |
-
@app.get("/training-status")
|
188 |
async def get_training_status():
|
189 |
return training_status
|
190 |
|
191 |
-
@app.post("/
|
192 |
-
async def upload_file(file: UploadFile = File(...)):
|
193 |
-
"""Upload a CSV file for training or validation"""
|
194 |
-
if not file.filename.endswith('.csv'):
|
195 |
-
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
196 |
-
|
197 |
-
file_path = UPLOAD_DIR / file.filename
|
198 |
-
with file_path.open("wb") as buffer:
|
199 |
-
shutil.copyfileobj(file.file, buffer)
|
200 |
-
|
201 |
-
return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
|
202 |
-
|
203 |
-
@app.post("/bert/train", response_model=TrainingResponse)
|
204 |
async def start_training(
|
205 |
config: TrainingConfig,
|
206 |
background_tasks: BackgroundTasks,
|
207 |
-
|
208 |
):
|
209 |
if training_status["is_training"]:
|
210 |
raise HTTPException(status_code=400, detail="Training is already in progress")
|
211 |
|
212 |
-
if not
|
213 |
-
raise HTTPException(status_code=
|
|
|
|
|
|
|
|
|
214 |
|
215 |
training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
216 |
|
@@ -222,9 +214,9 @@ async def start_training(
|
|
222 |
"status": "starting"
|
223 |
})
|
224 |
|
225 |
-
background_tasks.add_task(train_model_task, config, file_path, training_id)
|
226 |
|
227 |
-
download_url = f"/bert/download-model/{training_id}"
|
228 |
|
229 |
return TrainingResponse(
|
230 |
message="Training started successfully",
|
@@ -233,7 +225,7 @@ async def start_training(
|
|
233 |
download_url=download_url
|
234 |
)
|
235 |
|
236 |
-
@app.post("/bert/validate")
|
237 |
async def validate_model(
|
238 |
file: UploadFile = File(...),
|
239 |
model_name: str = "BERT_model"
|
@@ -319,7 +311,7 @@ async def validate_model(
|
|
319 |
if os.path.exists(file_path):
|
320 |
os.remove(file_path)
|
321 |
|
322 |
-
@app.post("/bert/predict")
|
323 |
async def predict(
|
324 |
request: Optional[PredictionRequest] = None,
|
325 |
file: Optional[UploadFile] = File(None),
|
@@ -492,7 +484,7 @@ async def predict(
|
|
492 |
except Exception as e:
|
493 |
raise HTTPException(status_code=500, detail=str(e))
|
494 |
|
495 |
-
@app.get("/bert/download-model/{model_id}")
|
496 |
async def download_model(model_id: str):
|
497 |
"""Download a trained model"""
|
498 |
model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
|
|
|
180 |
async def root():
|
181 |
return {"message": "BERT Compliance Predictor API"}
|
182 |
|
183 |
+
@app.get("/v1/bert/health")
|
184 |
async def health_check():
|
185 |
return {"status": "healthy"}
|
186 |
|
187 |
+
@app.get("/v1/bert/training-status")
|
188 |
async def get_training_status():
|
189 |
return training_status
|
190 |
|
191 |
+
@app.post("/v1/bert/train", response_model=TrainingResponse)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
async def start_training(
|
193 |
config: TrainingConfig,
|
194 |
background_tasks: BackgroundTasks,
|
195 |
+
file: UploadFile = File(...)
|
196 |
):
|
197 |
if training_status["is_training"]:
|
198 |
raise HTTPException(status_code=400, detail="Training is already in progress")
|
199 |
|
200 |
+
if not file.filename.endswith('.csv'):
|
201 |
+
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
202 |
+
|
203 |
+
file_path = UPLOAD_DIR / file.filename
|
204 |
+
with file_path.open("wb") as buffer:
|
205 |
+
shutil.copyfileobj(file.file, buffer)
|
206 |
|
207 |
training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
208 |
|
|
|
214 |
"status": "starting"
|
215 |
})
|
216 |
|
217 |
+
background_tasks.add_task(train_model_task, config, str(file_path), training_id)
|
218 |
|
219 |
+
download_url = f"/v1/bert/download-model/{training_id}"
|
220 |
|
221 |
return TrainingResponse(
|
222 |
message="Training started successfully",
|
|
|
225 |
download_url=download_url
|
226 |
)
|
227 |
|
228 |
+
@app.post("/v1/bert/validate")
|
229 |
async def validate_model(
|
230 |
file: UploadFile = File(...),
|
231 |
model_name: str = "BERT_model"
|
|
|
311 |
if os.path.exists(file_path):
|
312 |
os.remove(file_path)
|
313 |
|
314 |
+
@app.post("/v1/bert/predict")
|
315 |
async def predict(
|
316 |
request: Optional[PredictionRequest] = None,
|
317 |
file: Optional[UploadFile] = File(None),
|
|
|
484 |
except Exception as e:
|
485 |
raise HTTPException(status_code=500, detail=str(e))
|
486 |
|
487 |
+
@app.get("/v1/bert/download-model/{model_id}")
|
488 |
async def download_model(model_id: str):
|
489 |
"""Download a trained model"""
|
490 |
model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
|