from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
from fastapi.middleware.cors import CORSMiddleware
import torch
import os

app = FastAPI()

# Allow CORS (customize in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
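
# NOTE: the wildcard CORS settings above are convenient for development. A
# minimal production-style sketch would pin an explicit allowlist instead
# (hypothetical origin, adjust to your deployment):
#   allow_origins=["https://your-frontend.example.com"]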

# Hugging Face access token (from env)
hf_token = os.getenv("HF_TOKEN")

# HF model repo (includes adapter + full model)
adapter_path = "imnim/multi-label-email-classifier"

try:
    # Load PEFT adapter config
    peft_config = PeftConfig.from_pretrained(adapter_path, token=hf_token)

    # Try loading in bfloat16, fallback to float32
    try:
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=hf_token,
        )
    except Exception:
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=torch.float32,
            device_map="auto",
            token=hf_token,
        )

    tokenizer = AutoTokenizer.from_pretrained(
        peft_config.base_model_name_or_path,
        token=hf_token,
    )
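
    # Some causal LM tokenizers ship without a pad token. As a safeguard (an
    # assumption added here, not part of the original setup), reuse the EOS
    # token so padding during generation does not error out.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token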

    # Load the adapter
    model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        token=hf_token,
    )

    # Create the pipeline (no device argument; placement is handled by accelerate)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    raise RuntimeError(f"❌ Failed to load model + adapter: {str(e)}")

# === Request Schema ===
class EmailInput(BaseModel):
    subject: str
    body: str

# === Endpoint ===
@app.post("/classify")
async def classify_email(data: EmailInput):
    prompt = f"### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"
    try:
        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
        full_text = result[0]["generated_text"]
        label_section = full_text.split("### Labels:")[1].strip()
        return {"label": label_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {str(e)}")
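
# Optional local entrypoint: a minimal sketch, assuming this file is saved as
# app.py and uvicorn is installed. In a deployment you would more likely run
# `uvicorn app:app --host 0.0.0.0 --port 8000` instead.
#
# Example request once the server is running (hypothetical payload):
#   curl -X POST http://localhost:8000/classify \
#     -H "Content-Type: application/json" \
#     -d '{"subject": "Invoice overdue", "body": "Please settle by Friday."}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)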