namanpenguin committed
Commit b886555 · verified · 1 Parent(s): 8fe168f

Update app.py

Files changed (1)
  1. app.py +22 -51
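
In summary: the route paths are shortened (/v1/bert/training-status → /training-status, /v1/bert/upload → /upload, and /v1/bert/train, /v1/bert/validate, /v1/bert/predict drop the /v1 prefix), and train_model_task is simplified: the train/validation split, per-epoch evaluation, and best-checkpoint selection are removed, so the model now trains on the full dataset and the checkpoint is saved after every epoch.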
app.py CHANGED
@@ -184,11 +184,11 @@ async def root():
 async def health_check():
     return {"status": "healthy"}
 
-@app.get("/v1/bert/training-status")
+@app.get("/training-status")
 async def get_training_status():
     return training_status
 
-@app.post("/v1/bert/upload")
+@app.post("/upload")
 async def upload_file(file: UploadFile = File(...)):
     """Upload a CSV file for training or validation"""
     if not file.filename.endswith('.csv'):
@@ -200,7 +200,7 @@ async def upload_file(file: UploadFile = File(...)):
 
     return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
 
-@app.post("/v1/bert/train", response_model=TrainingResponse)
+@app.post("/bert/train", response_model=TrainingResponse)
 async def start_training(
     config: TrainingConfig,
     background_tasks: BackgroundTasks,
@@ -224,7 +224,7 @@ async def start_training(
 
     background_tasks.add_task(train_model_task, config, file_path, training_id)
 
-    download_url = f"/v1/bert/download-model/{training_id}"
+    download_url = f"/bert/download-model/{training_id}"
 
     return TrainingResponse(
         message="Training started successfully",
@@ -233,7 +233,7 @@ async def start_training(
         download_url=download_url
     )
 
-@app.post("/v1/bert/validate")
+@app.post("/bert/validate")
 async def validate_model(
     file: UploadFile = File(...),
     model_name: str = "BERT_model"
@@ -319,7 +319,7 @@ async def validate_model(
         if os.path.exists(file_path):
             os.remove(file_path)
 
-@app.post("/v1/bert/predict")
+@app.post("/bert/predict")
 async def predict(
     request: Optional[PredictionRequest] = None,
     file: Optional[UploadFile] = File(None),
@@ -510,80 +510,51 @@ async def train_model_task(config: TrainingConfig, file_path: str, training_id:
         data_df_original, label_encoders = load_and_preprocess_data(file_path)
         save_label_encoders(label_encoders)
 
-        train_df, val_df = train_test_split(
-            data_df_original,
-            test_size=config.test_size,
-            random_state=config.random_state,
-            stratify=data_df_original[LABEL_COLUMNS[0]]
-        )
-
-        train_texts = train_df[TEXT_COLUMN]
-        val_texts = val_df[TEXT_COLUMN]
-        train_labels_array = train_df[LABEL_COLUMNS].values
-        val_labels_array = val_df[LABEL_COLUMNS].values
+        texts = data_df_original[TEXT_COLUMN]
+        labels_array = data_df_original[LABEL_COLUMNS].values
 
-        train_metadata_df = train_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in train_df.columns for col in METADATA_COLUMNS) else None
-        val_metadata_df = val_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in val_df.columns for col in METADATA_COLUMNS) else None
+        metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
 
         num_labels_list = get_num_labels(label_encoders)
         tokenizer = get_tokenizer(config.model_name)
 
-        if train_metadata_df is not None and val_metadata_df is not None:
-            metadata_dim = train_metadata_df.shape[1]
-            train_dataset = ComplianceDatasetWithMetadata(
-                train_texts.tolist(),
-                train_metadata_df.values,
-                train_labels_array,
-                tokenizer,
-                config.max_length
-            )
-            val_dataset = ComplianceDatasetWithMetadata(
-                val_texts.tolist(),
-                val_metadata_df.values,
-                val_labels_array,
+        if metadata_df is not None:
+            metadata_dim = metadata_df.shape[1]
+            dataset = ComplianceDatasetWithMetadata(
+                texts.tolist(),
+                metadata_df.values,
+                labels_array,
                 tokenizer,
                 config.max_length
             )
             model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
         else:
-            train_dataset = ComplianceDataset(
-                train_texts.tolist(),
-                train_labels_array,
-                tokenizer,
-                config.max_length
-            )
-            val_dataset = ComplianceDataset(
-                val_texts.tolist(),
-                val_labels_array,
+            dataset = ComplianceDataset(
+                texts.tolist(),
+                labels_array,
                 tokenizer,
                 config.max_length
             )
             model = BertMultiOutputModel(num_labels_list).to(DEVICE)
 
-        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
-        val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
+        train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
 
         criterions = initialize_criterions(num_labels_list)
         optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
 
-        best_val_loss = float('inf')
         for epoch in range(config.num_epochs):
             training_status["current_epoch"] = epoch + 1
 
             train_loss = train_model(model, train_loader, criterions, optimizer)
-            val_metrics, _, _ = evaluate_model(model, val_loader)
-
             training_status["current_loss"] = train_loss
 
-            if val_metrics["loss"] < best_val_loss:
-                best_val_loss = val_metrics["loss"]
-                save_model(model, training_id)
+            # Save model after each epoch
+            save_model(model, training_id)
 
         training_status.update({
             "is_training": False,
             "end_time": datetime.now().isoformat(),
-            "status": "completed",
-            "metrics": summarize_metrics(val_metrics).to_dict()
+            "status": "completed"
         })
 
     except Exception as e:
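
For reference, a minimal client sketch against the renamed routes. This is an untested sketch: the base URL, the CSV filename, and the exact /bert/train request body are assumptions, since the hunk cuts off start_training's signature; only model_name, batch_size, num_epochs, learning_rate, and max_length are visible as config fields in the diff.

import time

import requests

BASE = "http://localhost:8000"  # assumed local dev server

# The status route moved from /v1/bert/training-status to /training-status.
print(requests.get(f"{BASE}/training-status").json())

# Upload moved from /v1/bert/upload to /upload; the handler still rejects
# any filename that does not end in .csv.
with open("train.csv", "rb") as fh:
    uploaded = requests.post(
        f"{BASE}/upload",
        files={"file": ("train.csv", fh, "text/csv")},
    ).json()
print(uploaded["file_path"])

# Training moved from /v1/bert/train to /bert/train; it schedules
# train_model_task as a background task and returns a download_url of the
# form /bert/download-model/<training_id>.
config = {
    "model_name": "bert-base-uncased",  # assumed default
    "batch_size": 16,                   # hypothetical values
    "num_epochs": 3,
    "learning_rate": 2e-5,
    "max_length": 128,
}
job = requests.post(f"{BASE}/bert/train", json=config).json()
print(job["download_url"])

# Poll /training-status until the background task flips is_training to False.
while requests.get(f"{BASE}/training-status").json().get("is_training"):
    time.sleep(5)

One behavioral change is worth flagging for callers: with the validation split gone, save_model now runs after every epoch, so the artifact behind download_url is the weights from the last completed epoch rather than the best-validation-loss checkpoint, and the completed status payload no longer carries a metrics entry.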