import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Authenticate with the Hugging Face Hub using HF_TOKEN from a local .env file
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)
app = FastAPI()

# Allow frontend communication
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load Base + Adapter ===
adapter_path = "C:/Users/nimes/Desktop/NLP Projects/Multi-label Email Classifier/checkpoint-711"

try:
    # Load PEFT config to get base model path
    peft_config = PeftConfig.from_pretrained(adapter_path)

    # Load base model and tokenizer (CPU-safe)
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
    )
    tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)

    # Load LoRA adapter weights on top of the base model
    model = PeftModel.from_pretrained(base_model, adapter_path, device_map={"": "cpu"})

    # Build inference pipeline
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    raise RuntimeError(f"Failed to load model + adapter: {e}") from e
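# Quick smoke test of the loaded pipeline (illustrative only; the prompt format
# mirrors the endpoint below, and the generated labels will vary):
#   print(pipe("### Subject:\nTest\n\n### Body:\nHello\n\n### Labels:",
#              max_new_tokens=20)[0]["generated_text"])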
# === Request Schema ===
class EmailInput(BaseModel):
    subject: str
    body: str
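# Example request body (illustrative values):
#   {"subject": "Invoice overdue", "body": "Please pay invoice #123 by Friday."}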
# === Endpoint ===
@app.post("/classify")  # no route decorator was present in the source; the path is an assumption
async def classify_email(data: EmailInput):
    prompt = f"### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"
    try:
        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
        full_text = result[0]["generated_text"]
        # Keep only the text generated after the "### Labels:" marker
        label_section = full_text.split("### Labels:")[1].strip()
        return {"label": label_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")
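# A minimal sketch of how to run and exercise this service locally (the module
# name "main" and the "/classify" route above are assumptions):
#   uvicorn main:app --reload
#   curl -X POST http://localhost:8000/classify \
#        -H "Content-Type: application/json" \
#        -d '{"subject": "Invoice overdue", "body": "Please pay by Friday."}'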