import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load environment variables and authenticate with the Hugging Face Hub
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise RuntimeError("HF_TOKEN is not set; add it to a .env file or the environment.")
login(token=hf_token)
app = FastAPI()
# Allow frontend communication
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load Base + Adapter ===
adapter_path = "C:/Users/nimes/Desktop/NLP Projects/Multi-label Email Classifier/checkpoint-711"

try:
    # Load the PEFT config to discover which base model the adapter was trained on
    peft_config = PeftConfig.from_pretrained(adapter_path)

    # Load the base model and tokenizer (CPU-safe: float32 weights pinned to CPU)
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
    )
    tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)

    # Apply the LoRA adapter weights on top of the base model
    model = PeftModel.from_pretrained(base_model, adapter_path, device_map={"": "cpu"})

    # Build the text-generation inference pipeline
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    raise RuntimeError(f"❌ Failed to load model + adapter: {e}") from e
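
# Optional local smoke test (a sketch, not part of the API): uncomment to verify
# that the adapter loaded and the pipeline generates. The prompt format mirrors
# the /classify endpoint below; the subject/body values are illustrative only.
# print(pipe("### Subject:\nTest\n\n### Body:\nHello\n\n### Labels:", max_new_tokens=10))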
# === Request Schema ===
class EmailInput(BaseModel):
    subject: str
    body: str
# === Endpoint ===
@app.post("/classify")
async def classify_email(data: EmailInput):
    prompt = f"### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"
    try:
        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
        full_text = result[0]["generated_text"]
        # The pipeline echoes the prompt, so take everything after the first
        # "### Labels:" marker as the model's predicted labels
        label_section = full_text.split("### Labels:", 1)[1].strip()
        return {"label": label_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")
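
# --- Usage sketch ---
# Assuming this file is saved as main.py and uvicorn is installed, start the server with:
#   uvicorn main:app --reload --port 8000
# Then classify an email (illustrative payload; field names match the EmailInput schema):
#   curl -X POST http://localhost:8000/classify \
#        -H "Content-Type: application/json" \
#        -d '{"subject": "Invoice overdue", "body": "Please pay invoice #123 by Friday."}'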