imnim committed
Commit 881b63b · verified · 1 parent: d4e90c0

Create app.py

Files changed (1): app.py (+67 -0)
app.py ADDED
@@ -0,0 +1,67 @@
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from peft import PeftModel, PeftConfig
+ from huggingface_hub import login
+ from dotenv import load_dotenv
+ import torch
+ import os
+
+ load_dotenv()
+
+ hf_token = os.getenv("HF_TOKEN")
+
+ login(token=hf_token)
+
+ app = FastAPI()
+
+ # Allow frontend communication
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["http://localhost:3000"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # === Load Base + Adapter ===
+ adapter_path = "C:/Users/nimes/Desktop/NLP Projects/Multi-label Email Classifier/checkpoint-711"
+
+ try:
+     # Load PEFT config to get base model path
+     peft_config = PeftConfig.from_pretrained(adapter_path)
+
+     # Load base model and tokenizer (CPU-safe)
+     base_model = AutoModelForCausalLM.from_pretrained(
+         peft_config.base_model_name_or_path,
+         torch_dtype=torch.float32,
+         device_map={"": "cpu"}
+     )
+     tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
+
+     # Load LoRA adapter
+     model = PeftModel.from_pretrained(base_model, adapter_path, device_map={"": "cpu"})
+
+     # Build inference pipeline
+     pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+ except Exception as e:
+     raise RuntimeError(f"❌ Failed to load model + adapter: {e}") from e
+
+ # === Request Schema ===
+ class EmailInput(BaseModel):
+     subject: str
+     body: str
+
+ # === Endpoint ===
+ @app.post("/classify")
+ async def classify_email(data: EmailInput):
+     prompt = f"### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"
+     try:
+         result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
+         full_text = result[0]["generated_text"]
+         label_section = full_text.split("### Labels:")[1].strip()
+         return {"label": label_section}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Model inference failed: {e}") from e
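
For quick verification of the new /classify endpoint, a minimal client sketch follows. It assumes the app is served locally with `uvicorn app:app --port 8000`; the host, port, and sample payload are assumptions for illustration, not part of this commit.

# Minimal client sketch — assumes the server was started with
# `uvicorn app:app --port 8000`; URL and payload are illustrative only.
import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={"subject": "Invoice overdue", "body": "Please pay invoice #123 by Friday."},
    timeout=120,  # CPU-only generation can be slow
)
resp.raise_for_status()
print(resp.json())  # e.g. {"label": "..."} — the text generated after "### Labels:"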
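Two usage notes on the new file. The split on "### Labels:" relies on the text-generation pipeline returning the prompt together with the completion, which is its default; passing `return_full_text=False` in the pipeline call would return only the completion and make the split unnecessary. The app also expects `HF_TOKEN` to be supplied via the environment or an `.env` file; if it is missing, `login(token=None)` falls back to an interactive token prompt, which is likely to fail in a headless deployment.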