Update app.py
Browse files
app.py
CHANGED
@@ -184,11 +184,11 @@ async def root():
|
|
184 |
async def health_check():
|
185 |
return {"status": "healthy"}
|
186 |
|
187 |
-
@app.get("/
|
188 |
async def get_training_status():
|
189 |
return training_status
|
190 |
|
191 |
-
@app.post("/
|
192 |
async def upload_file(file: UploadFile = File(...)):
|
193 |
"""Upload a CSV file for training or validation"""
|
194 |
if not file.filename.endswith('.csv'):
|
@@ -200,7 +200,7 @@ async def upload_file(file: UploadFile = File(...)):
|
|
200 |
|
201 |
return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
|
202 |
|
203 |
-
@app.post("/
|
204 |
async def start_training(
|
205 |
config: TrainingConfig,
|
206 |
background_tasks: BackgroundTasks,
|
@@ -224,7 +224,7 @@ async def start_training(
|
|
224 |
|
225 |
background_tasks.add_task(train_model_task, config, file_path, training_id)
|
226 |
|
227 |
-
download_url = f"/
|
228 |
|
229 |
return TrainingResponse(
|
230 |
message="Training started successfully",
|
@@ -233,7 +233,7 @@ async def start_training(
|
|
233 |
download_url=download_url
|
234 |
)
|
235 |
|
236 |
-
@app.post("/
|
237 |
async def validate_model(
|
238 |
file: UploadFile = File(...),
|
239 |
model_name: str = "BERT_model"
|
@@ -319,7 +319,7 @@ async def validate_model(
|
|
319 |
if os.path.exists(file_path):
|
320 |
os.remove(file_path)
|
321 |
|
322 |
-
@app.post("/
|
323 |
async def predict(
|
324 |
request: Optional[PredictionRequest] = None,
|
325 |
file: Optional[UploadFile] = File(None),
|
@@ -510,80 +510,51 @@ async def train_model_task(config: TrainingConfig, file_path: str, training_id:
|
|
510 |
data_df_original, label_encoders = load_and_preprocess_data(file_path)
|
511 |
save_label_encoders(label_encoders)
|
512 |
|
513 |
-
|
514 |
-
|
515 |
-
test_size=config.test_size,
|
516 |
-
random_state=config.random_state,
|
517 |
-
stratify=data_df_original[LABEL_COLUMNS[0]]
|
518 |
-
)
|
519 |
-
|
520 |
-
train_texts = train_df[TEXT_COLUMN]
|
521 |
-
val_texts = val_df[TEXT_COLUMN]
|
522 |
-
train_labels_array = train_df[LABEL_COLUMNS].values
|
523 |
-
val_labels_array = val_df[LABEL_COLUMNS].values
|
524 |
|
525 |
-
|
526 |
-
val_metadata_df = val_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in val_df.columns for col in METADATA_COLUMNS) else None
|
527 |
|
528 |
num_labels_list = get_num_labels(label_encoders)
|
529 |
tokenizer = get_tokenizer(config.model_name)
|
530 |
|
531 |
-
if
|
532 |
-
metadata_dim =
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
tokenizer,
|
538 |
-
config.max_length
|
539 |
-
)
|
540 |
-
val_dataset = ComplianceDatasetWithMetadata(
|
541 |
-
val_texts.tolist(),
|
542 |
-
val_metadata_df.values,
|
543 |
-
val_labels_array,
|
544 |
tokenizer,
|
545 |
config.max_length
|
546 |
)
|
547 |
model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
|
548 |
else:
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
tokenizer,
|
553 |
-
config.max_length
|
554 |
-
)
|
555 |
-
val_dataset = ComplianceDataset(
|
556 |
-
val_texts.tolist(),
|
557 |
-
val_labels_array,
|
558 |
tokenizer,
|
559 |
config.max_length
|
560 |
)
|
561 |
model = BertMultiOutputModel(num_labels_list).to(DEVICE)
|
562 |
|
563 |
-
train_loader = DataLoader(
|
564 |
-
val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
|
565 |
|
566 |
criterions = initialize_criterions(num_labels_list)
|
567 |
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
568 |
|
569 |
-
best_val_loss = float('inf')
|
570 |
for epoch in range(config.num_epochs):
|
571 |
training_status["current_epoch"] = epoch + 1
|
572 |
|
573 |
train_loss = train_model(model, train_loader, criterions, optimizer)
|
574 |
-
val_metrics, _, _ = evaluate_model(model, val_loader)
|
575 |
-
|
576 |
training_status["current_loss"] = train_loss
|
577 |
|
578 |
-
|
579 |
-
|
580 |
-
save_model(model, training_id)
|
581 |
|
582 |
training_status.update({
|
583 |
"is_training": False,
|
584 |
"end_time": datetime.now().isoformat(),
|
585 |
-
"status": "completed"
|
586 |
-
"metrics": summarize_metrics(val_metrics).to_dict()
|
587 |
})
|
588 |
|
589 |
except Exception as e:
|
|
|
184 |
async def health_check():
|
185 |
return {"status": "healthy"}
|
186 |
|
187 |
+
@app.get("/training-status")
|
188 |
async def get_training_status():
|
189 |
return training_status
|
190 |
|
191 |
+
@app.post("/upload")
|
192 |
async def upload_file(file: UploadFile = File(...)):
|
193 |
"""Upload a CSV file for training or validation"""
|
194 |
if not file.filename.endswith('.csv'):
|
|
|
200 |
|
201 |
return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
|
202 |
|
203 |
+
@app.post("/bert/train", response_model=TrainingResponse)
|
204 |
async def start_training(
|
205 |
config: TrainingConfig,
|
206 |
background_tasks: BackgroundTasks,
|
|
|
224 |
|
225 |
background_tasks.add_task(train_model_task, config, file_path, training_id)
|
226 |
|
227 |
+
download_url = f"/bert/download-model/{training_id}"
|
228 |
|
229 |
return TrainingResponse(
|
230 |
message="Training started successfully",
|
|
|
233 |
download_url=download_url
|
234 |
)
|
235 |
|
236 |
+
@app.post("/bert/validate")
|
237 |
async def validate_model(
|
238 |
file: UploadFile = File(...),
|
239 |
model_name: str = "BERT_model"
|
|
|
319 |
if os.path.exists(file_path):
|
320 |
os.remove(file_path)
|
321 |
|
322 |
+
@app.post("/bert/predict")
|
323 |
async def predict(
|
324 |
request: Optional[PredictionRequest] = None,
|
325 |
file: Optional[UploadFile] = File(None),
|
|
|
510 |
data_df_original, label_encoders = load_and_preprocess_data(file_path)
|
511 |
save_label_encoders(label_encoders)
|
512 |
|
513 |
+
texts = data_df_original[TEXT_COLUMN]
|
514 |
+
labels_array = data_df_original[LABEL_COLUMNS].values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
515 |
|
516 |
+
metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
|
|
|
517 |
|
518 |
num_labels_list = get_num_labels(label_encoders)
|
519 |
tokenizer = get_tokenizer(config.model_name)
|
520 |
|
521 |
+
if metadata_df is not None:
|
522 |
+
metadata_dim = metadata_df.shape[1]
|
523 |
+
dataset = ComplianceDatasetWithMetadata(
|
524 |
+
texts.tolist(),
|
525 |
+
metadata_df.values,
|
526 |
+
labels_array,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
tokenizer,
|
528 |
config.max_length
|
529 |
)
|
530 |
model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
|
531 |
else:
|
532 |
+
dataset = ComplianceDataset(
|
533 |
+
texts.tolist(),
|
534 |
+
labels_array,
|
|
|
|
|
|
|
|
|
|
|
|
|
535 |
tokenizer,
|
536 |
config.max_length
|
537 |
)
|
538 |
model = BertMultiOutputModel(num_labels_list).to(DEVICE)
|
539 |
|
540 |
+
train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
|
|
|
541 |
|
542 |
criterions = initialize_criterions(num_labels_list)
|
543 |
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
544 |
|
|
|
545 |
for epoch in range(config.num_epochs):
|
546 |
training_status["current_epoch"] = epoch + 1
|
547 |
|
548 |
train_loss = train_model(model, train_loader, criterions, optimizer)
|
|
|
|
|
549 |
training_status["current_loss"] = train_loss
|
550 |
|
551 |
+
# Save model after each epoch
|
552 |
+
save_model(model, training_id)
|
|
|
553 |
|
554 |
training_status.update({
|
555 |
"is_training": False,
|
556 |
"end_time": datetime.now().isoformat(),
|
557 |
+
"status": "completed"
|
|
|
558 |
})
|
559 |
|
560 |
except Exception as e:
|