imnim commited on
Commit
2bcbc24
·
verified ·
1 Parent(s): 4a402f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -4,51 +4,56 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
  from peft import PeftModel, PeftConfig
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import torch
7
- from dotenv import load_dotenv
8
- import os
9
-
10
-
11
- load_dotenv()
12
- hf_token = os.getenv("HF_TOKEN")
13
-
14
 
15
  app = FastAPI()
16
 
 
17
  app.add_middleware(
18
  CORSMiddleware,
19
- allow_origins=["*"],
20
  allow_credentials=True,
21
  allow_methods=["*"],
22
  allow_headers=["*"],
23
  )
24
 
25
- adapter_path = "./checkpoint-711"
 
26
 
27
  try:
28
- peft_config = PeftConfig.from_pretrained(adapter_path)
29
-
30
-
 
31
  base_model = AutoModelForCausalLM.from_pretrained(
32
  peft_config.base_model_name_or_path,
33
  torch_dtype=torch.float32,
34
- device_map={"": "cpu"}
 
 
 
 
 
35
  )
36
- tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
37
-
38
 
39
- model = PeftModel.from_pretrained(base_model, adapter_path, device_map={"": "cpu"})
 
 
 
 
 
40
 
41
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
42
 
43
  except Exception as e:
44
  raise RuntimeError(f"❌ Failed to load model + adapter: {str(e)}")
45
 
46
- # === Request Schema ===
47
  class EmailInput(BaseModel):
48
  subject: str
49
  body: str
50
 
51
- # === Endpoint ===
52
  @app.post("/classify")
53
  async def classify_email(data: EmailInput):
54
  prompt = f"""### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"""
 
4
  from peft import PeftModel, PeftConfig
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import torch
 
 
 
 
 
 
 
7
 
8
  app = FastAPI()
9
 
10
+ # Allow CORS for all origins (adjust this in production)
11
  app.add_middleware(
12
  CORSMiddleware,
13
+ allow_origins=["*"],
14
  allow_credentials=True,
15
  allow_methods=["*"],
16
  allow_headers=["*"],
17
  )
18
 
19
+ # Path to your HF Hub repo with full model + adapter
20
+ adapter_path = "imnim/multi-label-email-classifier"
21
 
22
  try:
23
+ # Load PEFT config to get base model path
24
+ peft_config = PeftConfig.from_pretrained(adapter_path, use_auth_token=True)
25
+
26
+ # Load base model and tokenizer with HF auth token
27
  base_model = AutoModelForCausalLM.from_pretrained(
28
  peft_config.base_model_name_or_path,
29
  torch_dtype=torch.float32,
30
+ device_map={"": "cpu"},
31
+ use_auth_token=True
32
+ )
33
+ tokenizer = AutoTokenizer.from_pretrained(
34
+ peft_config.base_model_name_or_path,
35
+ use_auth_token=True
36
  )
 
 
37
 
38
+ # Load adapter with HF auth token
39
+ model = PeftModel.from_pretrained(
40
+ base_model, adapter_path,
41
+ device_map={"": "cpu"},
42
+ use_auth_token=True
43
+ )
44
 
45
+ # Setup text-generation pipeline
46
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
47
 
48
  except Exception as e:
49
  raise RuntimeError(f"❌ Failed to load model + adapter: {str(e)}")
50
 
51
+ # Request schema
52
  class EmailInput(BaseModel):
53
  subject: str
54
  body: str
55
 
56
+ # POST /classify endpoint
57
  @app.post("/classify")
58
  async def classify_email(data: EmailInput):
59
  prompt = f"""### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"""