from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import torch
from huggingface_hub import login
from dotenv import load_dotenv
import os

# Authenticate with the Hugging Face Hub using the token from .env
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    # Only log in when a token is present; login(token=None) would
    # fall back to an interactive prompt and hang the server.
    login(token=hf_token)

app = FastAPI()

# Allow frontend communication
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# === Load Base + Adapter ===
adapter_path = "C:/Users/nimes/Desktop/NLP Projects/Multi-label Email Classifier/checkpoint-711"

try:
    # Load PEFT config to get the base model path
    peft_config = PeftConfig.from_pretrained(adapter_path)

    # Load base model and tokenizer (CPU-safe)
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
    )
    tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)

    # Load LoRA adapter on top of the base model
    model = PeftModel.from_pretrained(base_model, adapter_path, device_map={"": "cpu"})

    # Build inference pipeline
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    raise RuntimeError(f"❌ Failed to load model + adapter: {e}")


# === Request Schema ===
class EmailInput(BaseModel):
    subject: str
    body: str


# === Endpoint ===
@app.post("/classify")
async def classify_email(data: EmailInput):
    prompt = f"### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"
    try:
        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
        full_text = result[0]["generated_text"]
        # The pipeline echoes the prompt, so keep only what comes after the
        # final "### Labels:" marker; splitting on the last occurrence guards
        # against the email body itself containing that string.
        label_section = full_text.split("### Labels:")[-1].strip()
        return {"label": label_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")
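
# --- Running the API: a minimal sketch, assuming this file is saved as
# main.py and uvicorn is installed alongside fastapi. ---
#
# Example request once the server is up (subject/body values are hypothetical):
#   curl -X POST http://localhost:8000/classify \
#     -H "Content-Type: application/json" \
#     -d '{"subject": "Invoice overdue", "body": "Please pay invoice #123 by Friday."}'

if __name__ == "__main__":
    import uvicorn

    # Bind to localhost; the CORS config above expects the frontend
    # at http://localhost:3000.
    uvicorn.run(app, host="127.0.0.1", port=8000)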