imnim committed
Commit 011aa0f · verified · 1 Parent(s): b9b5220

Updated with 16-bit instead of 32-bit params
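(For context on the commit message: parameter memory scales with dtype width, so bfloat16 halves the footprint relative to float32. A rough sketch, with a hypothetical 7B parameter count that is illustrative only, not the repo's actual base model size:

# Sketch only: parameter memory at 32-bit vs 16-bit precision.
# n_params is hypothetical, for illustration.
import torch

n_params = 7_000_000_000
for dtype in (torch.float32, torch.bfloat16):
    gib = n_params * torch.finfo(dtype).bits / 8 / 2**30
    print(f"{dtype}: ~{gib:.1f} GiB")
# torch.float32: ~26.1 GiB
# torch.bfloat16: ~13.0 GiB
)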

Files changed (1)
1. app.py +34 -21
app.py CHANGED
@@ -4,10 +4,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from peft import PeftModel, PeftConfig
 from fastapi.middleware.cors import CORSMiddleware
 import torch
+import os
 
 app = FastAPI()
 
-# Allow CORS for all origins (adjust this in production)
+# Allow CORS (customize in production)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -16,44 +17,56 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Path to your HF Hub repo with full model + adapter
+# Hugging Face access token (from env)
+hf_token = os.getenv("HF_TOKEN")
+
+# HF model repo (includes adapter + full model)
 adapter_path = "imnim/multi-label-email-classifier"
 
 try:
-    # Load PEFT config to get base model path
-    peft_config = PeftConfig.from_pretrained(adapter_path, use_auth_token=True)
-
-    # Load base model and tokenizer with HF auth token
-    base_model = AutoModelForCausalLM.from_pretrained(
-        peft_config.base_model_name_or_path,
-        torch_dtype=torch.bfloat16,
-        device_map={"": "cpu"},
-        use_auth_token=True
-    )
+    # Load PEFT adapter config
+    peft_config = PeftConfig.from_pretrained(adapter_path, token=hf_token)
+
+    # Try loading in bfloat16, fall back to float32
+    try:
+        base_model = AutoModelForCausalLM.from_pretrained(
+            peft_config.base_model_name_or_path,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            token=hf_token
+        )
+    except Exception:
+        base_model = AutoModelForCausalLM.from_pretrained(
+            peft_config.base_model_name_or_path,
+            torch_dtype=torch.float32,
+            device_map="auto",
+            token=hf_token
+        )
+
     tokenizer = AutoTokenizer.from_pretrained(
         peft_config.base_model_name_or_path,
-        use_auth_token=True
+        token=hf_token
     )
 
-    # Load adapter with HF auth token
+    # Load the adapter
     model = PeftModel.from_pretrained(
-        base_model, adapter_path,
-        device_map={"": "cpu"},
-        use_auth_token=True
+        base_model,
+        adapter_path,
+        token=hf_token
     )
 
-    # Setup text-generation pipeline
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
+    # Create the pipeline (device placement is handled by accelerate)
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 except Exception as e:
     raise RuntimeError(f"❌ Failed to load model + adapter: {str(e)}")
 
-# Request schema
+# === Request Schema ===
 class EmailInput(BaseModel):
     subject: str
     body: str
 
-# POST /classify endpoint
+# === Endpoint ===
 @app.post("/classify")
 async def classify_email(data: EmailInput):
     prompt = f"""### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"""
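
(For reference, a minimal client sketch for the /classify endpoint above. It assumes the app is served locally, e.g. with "uvicorn app:app --port 8000" and HF_TOKEN set in the environment; the URL, port, and sample email are illustrative, not part of the commit.)

# Minimal client sketch for POST /classify (illustrative; assumes the app
# runs locally via `uvicorn app:app --port 8000` with HF_TOKEN exported).
import requests

payload = {
    "subject": "Invoice #4521 overdue",
    "body": "Hi team, the attached invoice is two weeks past due. Please advise.",
}

resp = requests.post("http://localhost:8000/classify", json=payload)
resp.raise_for_status()
print(resp.json())  # model-generated labels for the email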