from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import torch
import os

app = FastAPI()

# Allow CORS (customize in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
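
# A minimal sketch of a production-safe alternative: restrict origins to the
# frontends that actually call this API (the URL below is a placeholder):
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://your-frontend.example.com"],
#     allow_credentials=True,
#     allow_methods=["POST"],
#     allow_headers=["*"],
# )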

# Hugging Face access token (from env)
hf_token = os.getenv("HF_TOKEN")

# HF model repo (includes adapter + full model)
adapter_path = "imnim/multi-label-email-classifier"
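
# HF_TOKEN is expected as an environment variable, e.g. a Space secret or,
# locally: export HF_TOKEN=hf_xxx (placeholder value).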

try:
    # Load PEFT adapter config
    peft_config = PeftConfig.from_pretrained(adapter_path, token=hf_token)

    # Try loading in bfloat16, fall back to float32
    try:
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=hf_token
        )
    except Exception:
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=torch.float32,
            device_map="auto",
            token=hf_token
        )

    tokenizer = AutoTokenizer.from_pretrained(
        peft_config.base_model_name_or_path,
        token=hf_token
    )

    # Load the adapter on top of the base model
    model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        token=hf_token
    )

    # Create the pipeline (no device argument; placement is handled by accelerate)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    raise RuntimeError(f"Failed to load model + adapter: {str(e)}")
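
# Optional speed-up (a sketch, assuming the adapter is LoRA-style): merging the
# adapter weights into the base model removes the PEFT indirection at inference
# time. merge_and_unload() is a PEFT API for mergeable adapters.
# model = model.merge_and_unload()
# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)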

# === Request Schema ===
class EmailInput(BaseModel):
    subject: str
    body: str
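
# Example request body (illustrative values only):
# {"subject": "Team offsite next week", "body": "Hi all, please confirm attendance by Friday."}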

# === Endpoint ===
@app.post("/classify")  # the "/classify" path is assumed; adjust to match your client
async def classify_email(data: EmailInput):
    prompt = f"""### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"""
    try:
        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
        full_text = result[0]["generated_text"]
        # Keep only the text the model generated after the "### Labels:" marker
        label_section = full_text.split("### Labels:")[1].strip()
        return {"label": label_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {str(e)}")
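
# A minimal way to run the app locally (a sketch; port 7860 is the Hugging Face
# Spaces convention, but any free port works):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call (assumes the "/classify" path above):
# curl -X POST http://localhost:7860/classify \
#   -H "Content-Type: application/json" \
#   -d '{"subject": "Invoice overdue", "body": "Please settle invoice #123."}'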