"""FastAPI service that labels emails with a PEFT-adapted causal language model.

The base model, tokenizer and adapter are loaded once at import time. The
``/classify`` endpoint prompts the model with an email's subject and body and
returns the text the model generates after the ``### Labels:`` marker.
"""

import logging
import os
import threading

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from peft import PeftConfig, PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)

app = FastAPI()

# Allow CORS (customize in production).
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins to known front-ends before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Hugging Face access token (from env); None when HF_TOKEN is not set.
hf_token = os.getenv("HF_TOKEN")

# HF model repo (includes adapter + full model)
adapter_path = "imnim/multi-label-email-classifier"

# Marker that ends the prompt; the model's continuation after it is the answer.
_LABEL_MARKER = "### Labels:"

# Serializes calls into the pipeline so that moving inference off the event
# loop does not start running several generations on the model at once.
_inference_lock = threading.Lock()


def _load_base_model(model_name, dtype):
    """Load the base causal LM named *model_name* with the given torch dtype."""
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="auto",
        token=hf_token,
    )


try:
    # Load PEFT adapter config (tells us which base model the adapter targets).
    peft_config = PeftConfig.from_pretrained(adapter_path, token=hf_token)

    # Try loading in bfloat16, fall back to float32 if that fails.
    try:
        base_model = _load_base_model(
            peft_config.base_model_name_or_path, torch.bfloat16
        )
    except Exception:
        # Keep the best-effort fallback, but leave a trace of why it happened.
        logger.warning(
            "Loading base model in bfloat16 failed; retrying in float32",
            exc_info=True,
        )
        base_model = _load_base_model(
            peft_config.base_model_name_or_path, torch.float32
        )

    tokenizer = AutoTokenizer.from_pretrained(
        peft_config.base_model_name_or_path,
        token=hf_token,
    )

    # Load the adapter on top of the base model.
    model = PeftModel.from_pretrained(base_model, adapter_path, token=hf_token)

    # Create the pipeline — no device argument (handled by accelerate).
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    # Chain the cause so the original traceback is preserved.
    raise RuntimeError(f"❌ Failed to load model + adapter: {str(e)}") from e


# === Request Schema ===
class EmailInput(BaseModel):
    """Request body for ``/classify``: the email's subject and body text."""

    subject: str
    body: str


def _build_prompt(subject, body):
    """Return the generation prompt for an email, ending with the label marker."""
    return f"""### Subject:\n{subject}\n\n### Body:\n{body}\n\n### Labels:"""


def _generate(prompt):
    """Run the text-generation pipeline on *prompt* and return its full text."""
    with _inference_lock:
        result = pipe(
            prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95
        )
    return result[0]["generated_text"]


def _extract_labels(prompt, full_text):
    """Return the model's continuation of *prompt* within *full_text*.

    Stripping the prompt prefix (rather than splitting on the first
    ``### Labels:``) keeps the result correct when the email text itself
    contains the marker.

    Raises:
        ValueError: if *full_text* neither starts with *prompt* nor contains
            the label marker.
    """
    if full_text.startswith(prompt):
        return full_text[len(prompt):].strip()
    # Fallback for output that does not echo the prompt verbatim: take the
    # text after the last marker occurrence.
    _, marker, tail = full_text.rpartition(_LABEL_MARKER)
    if not marker:
        raise ValueError("generated text does not contain the label marker")
    return tail.strip()


# === Endpoint ===
@app.post("/classify")
async def classify_email(data: EmailInput):
    """Classify an email and return the generated label text.

    Args:
        data: The email subject and body.

    Returns:
        A dict ``{"label": <text>}`` where the text is the stripped model
        output following the ``### Labels:`` marker.

    Raises:
        HTTPException: with status 500 if generation or label extraction fails.
    """
    prompt = _build_prompt(data.subject, data.body)
    try:
        # Generation is blocking; run it in a worker thread so the event loop
        # stays free to serve other requests while the model is busy.
        full_text = await run_in_threadpool(_generate, prompt)
        label_section = _extract_labels(prompt, full_text)
    except Exception as e:
        logger.exception("Model inference failed")
        raise HTTPException(
            status_code=500, detail=f"Model inference failed: {str(e)}"
        ) from e
    return {"label": label_section}