import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from peft import PeftConfig, PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Read the Hugging Face token from a local .env file; it is forwarded to
# from_pretrained below in case the base model is gated or private.
load_dotenv()
hf_token = os.getenv("HF_TOKEN")

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # accept requests from any origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
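# If the API is only ever called from a known frontend, the wildcard above
# could be narrowed, e.g. allow_origins=["https://my-frontend.example"]
# (hypothetical origin).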
# Path to the fine-tuned LoRA adapter checkpoint.
adapter_path = "./checkpoint-711"

try:
    # Resolve the base model name from the adapter's PEFT config, then load it on CPU.
    peft_config = PeftConfig.from_pretrained(adapter_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
        token=hf_token,  # no-op when HF_TOKEN is unset
    )
    tokenizer = AutoTokenizer.from_pretrained(
        peft_config.base_model_name_or_path, token=hf_token
    )
    # Attach the LoRA adapter weights on top of the base model.
    model = PeftModel.from_pretrained(base_model, adapter_path, device_map={"": "cpu"})
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    raise RuntimeError(f"❌ Failed to load model + adapter: {e}") from e
# === Request Schema ===
class EmailInput(BaseModel):
    subject: str
    body: str
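# Example request body (hypothetical values):
#   {"subject": "Invoice overdue", "body": "Please pay invoice #123 by Friday."}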
# === Endpoint ===
@app.post("/classify")
async def classify_email(data: EmailInput):
    # Build the prompt; the trailing "### Labels:" marker is where the model
    # appends its prediction.
    prompt = f"### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"
    try:
        # Sampling makes output non-deterministic; use do_sample=False for repeatable labels.
        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
        full_text = result[0]["generated_text"]
        # generated_text echoes the prompt, so keep only the text after the marker.
        label_section = full_text.split("### Labels:")[1].strip()
        return {"label": label_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}") from e
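

# Minimal local entry point (a sketch; assumes uvicorn is installed). Example call:
#   curl -X POST http://127.0.0.1:8000/classify \
#        -H "Content-Type: application/json" \
#        -d '{"subject": "Hello", "body": "Just checking in."}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)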