imnim committed on
Commit
d4e90c0
·
verified ·
1 Parent(s): 98c21cc

Upload 3 files

Browse files
Files changed (3) hide show
  1. Spacefile +4 -0
  2. app/main.py +67 -0
  3. requirements.txt +11 -0
Spacefile ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ sdk: custom
2
+ python_version: 3.9
3
+ app_file: app/main.py
4
+ command: uvicorn app.main:app --host 0.0.0.0 --port 7860
app/main.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
+ from peft import PeftModel, PeftConfig
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ import torch
7
+ from huggingface_hub import login
8
+ from dotenv import load_dotenv
9
+ import os
10
+
11
# Read settings such as HF_TOKEN from a local .env file when one is present.
load_dotenv()

hf_token = os.getenv("HF_TOKEN")

# Authenticate only when a token is configured: login(token=None) falls back
# to an interactive prompt, which would stall a headless server at startup.
if hf_token:
    login(token=hf_token)
16
+
17
app = FastAPI()

# Allow the frontend at http://localhost:3000 to call this API from the browser.
_cors_options = {
    "allow_origins": ["http://localhost:3000"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
27
+
28
# === Load Base + Adapter ===
# Location of the fine-tuned LoRA checkpoint. The ADAPTER_PATH environment
# variable overrides the original developer-machine path, so the app can start
# on hosts where that Windows path does not exist.
adapter_path = os.getenv(
    "ADAPTER_PATH",
    "C:/Users/nimes/Desktop/NLP Projects/Multi-label Email Classifier/checkpoint-711",
)

try:
    # The adapter config names the base model the adapter was built on.
    peft_config = PeftConfig.from_pretrained(adapter_path)

    # Load base model and tokenizer on CPU in float32.
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
    )
    tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)

    # Attach the LoRA adapter on top of the base model.
    model = PeftModel.from_pretrained(base_model, adapter_path, device_map={"": "cpu"})

    # Text-generation pipeline used by the /classify endpoint.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

except Exception as e:
    # Chain the cause so the underlying traceback is not lost.
    raise RuntimeError(f"❌ Failed to load model + adapter: {str(e)}") from e
51
+
52
# === Request Schema ===
class EmailInput(BaseModel):
    # JSON body accepted by the /classify endpoint; both fields are
    # interpolated into the model prompt.
    subject: str
    body: str
56
+
57
# === Endpoint ===
@app.post("/classify")
async def classify_email(data: EmailInput):
    """Generate labels for an email with the text-generation pipeline.

    Args:
        data: Request body carrying the email ``subject`` and ``body``.

    Returns:
        A dict ``{"label": <text>}`` holding the stripped text the model
        produced after the prompt's final ``### Labels:`` marker, cut off at
        any further ``### Labels:`` marker the model itself generates.

    Raises:
        HTTPException: status 500 if generation or output parsing fails.
    """
    marker = "### Labels:"
    prompt = f"""### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"""
    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is blocked while generation runs.
        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
        full_text = result[0]["generated_text"]

        # The subject/body may themselves contain the marker, so skip as many
        # markers as the prompt holds instead of assuming there is exactly one.
        # An IndexError here (marker missing from the output) is reported as a
        # 500 by the handler below.
        markers_in_prompt = prompt.count(marker)
        completion = full_text.split(marker, markers_in_prompt)[markers_in_prompt]

        # Stop at a repeated marker in the generated text.
        label_section = completion.split(marker, 1)[0].strip()
        return {"label": label_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {str(e)}") from e
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai
2
+ tqdm
3
+ pydantic
4
+ python-dotenv
5
+ requests
6
+ datasets
7
+ fastapi
8
+ checkpoints
9
+ torch
10
+ peft
11
+ auto-gptq