imnim committed
Commit 040c190 · verified · 1 Parent(s): a20582d

Update app.py

Files changed (1)
  1. app.py +33 -67
app.py CHANGED
@@ -1,88 +1,54 @@
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from peft import PeftModel, PeftConfig
-from fastapi.middleware.cors import CORSMiddleware
 import torch
 import os
 
-app = FastAPI()
-
-# Allow CORS (customize in production)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Hugging Face access token (from env)
+# Hugging Face access token (stored in HF Spaces secrets)
 hf_token = os.getenv("HF_TOKEN")
 
-# HF model repo (includes adapter + full model)
 adapter_path = "imnim/multi-label-email-classifier"
 
-try:
-    # Load PEFT adapter config
-    peft_config = PeftConfig.from_pretrained(adapter_path, token=hf_token)
-
-    # Try loading in bfloat16, fallback to float32
-    try:
-        base_model = AutoModelForCausalLM.from_pretrained(
-            peft_config.base_model_name_or_path,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            token=hf_token
-        )
-    except Exception:
-        base_model = AutoModelForCausalLM.from_pretrained(
-            peft_config.base_model_name_or_path,
-            torch_dtype=torch.float32,
-            device_map="auto",
-            token=hf_token
-        )
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        peft_config.base_model_name_or_path,
-        token=hf_token
-    )
-
-    # Load the adapter
-    model = PeftModel.from_pretrained(
-        base_model,
-        adapter_path,
-        token=hf_token
-    )
-
-    # Create the pipeline; device placement is handled by accelerate
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-except Exception as e:
-    raise RuntimeError(f"❌ Failed to load model + adapter: {str(e)}")
-
-# === Request Schema ===
-class EmailInput(BaseModel):
-    subject: str
-    body: str
-
-# === Endpoint ===
-@app.post("/classify")
-async def classify_email(data: EmailInput):
-    prompt = f"""### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"""
-    try:
-        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
-        full_text = result[0]["generated_text"]
-        label_section = full_text.split("### Labels:")[1].strip()
-        return {"label": label_section}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Model inference failed: {str(e)}")
-
-import uvicorn
-
-if __name__ == "__main__":
-    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
+# Load PEFT config
+peft_config = PeftConfig.from_pretrained(adapter_path, token=hf_token)
+
+# Load base model (fall back to float32 if bfloat16 fails)
+try:
+    base_model = AutoModelForCausalLM.from_pretrained(
+        peft_config.base_model_name_or_path,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        token=hf_token
+    )
+except Exception:
+    base_model = AutoModelForCausalLM.from_pretrained(
+        peft_config.base_model_name_or_path,
+        torch_dtype=torch.float32,
+        device_map="auto",
+        token=hf_token
+    )
+
+tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path, token=hf_token)
+
+model = PeftModel.from_pretrained(base_model, adapter_path, token=hf_token)
+
+pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+# Classification function used by the Gradio UI
+def classify_email(subject, body):
+    prompt = f"""### Subject:\n{subject}\n\n### Body:\n{body}\n\n### Labels:"""
+    result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
+    full_text = result[0]["generated_text"]
+    label_section = full_text.split("### Labels:")[1].strip()
+    return label_section
+
+# Gradio UI
+demo = gr.Interface(
+    fn=classify_email,
+    inputs=["text", "text"],
+    outputs="text",
+    title="Multi-label Email Classifier",
+    description="Enter a subject and body to get label predictions"
+)
+
+demo.launch()
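
A note on the new classify_email: splitting the generated text on "### Labels:" assumes the pipeline echoes the prompt back, and it raises an IndexError if the marker is missing. A minimal hardening sketch (not part of this commit) uses the text-generation pipeline's return_full_text=False option, which returns only the newly generated tokens:

# Sketch, assuming the same `pipe` and `prompt` as in app.py above.
# With return_full_text=False the pipeline returns only newly generated
# text, so no string splitting on "### Labels:" is needed.
result = pipe(prompt, max_new_tokens=50, do_sample=True,
              top_k=50, top_p=0.95, return_full_text=False)
label_section = result[0]["generated_text"].strip()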
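
Since the app moved from FastAPI to Gradio, callers of the old /classify route need updating. A Gradio Space can be queried programmatically with gradio_client; a sketch, assuming the Space is published under the same id as the adapter repo (substitute the real Space name if it differs):

# Sketch: exercising the Space's default gr.Interface endpoint.
# The Space id "imnim/multi-label-email-classifier" is an assumption.
from gradio_client import Client

client = Client("imnim/multi-label-email-classifier")
label = client.predict(
    "Invoice overdue",                        # subject
    "Please settle invoice #123 by Friday.",  # body
    api_name="/predict"                       # default endpoint for gr.Interface
)
print(label)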