imnim committed
Commit 4a402f7 · verified · 1 Parent(s): 31d2076

Delete app

Files changed (1):
  1. app/main.py +0 -67
app/main.py DELETED
@@ -1,67 +0,0 @@
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-from peft import PeftModel, PeftConfig
-from fastapi.middleware.cors import CORSMiddleware
-import torch
-from huggingface_hub import login
-from dotenv import load_dotenv
-import os
-
-load_dotenv()
-
-hf_token = os.getenv("HF_TOKEN")
-
-login(token=hf_token)
-
-app = FastAPI()
-
-# Allow frontend communication
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["http://localhost:3000"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# === Load Base + Adapter ===
-adapter_path = "C:/Users/nimes/Desktop/NLP Projects/Multi-label Email Classifier/checkpoint-711"
-
-try:
-    # Load PEFT config to get base model path
-    peft_config = PeftConfig.from_pretrained(adapter_path)
-
-    # Load base model and tokenizer (CPU-safe)
-    base_model = AutoModelForCausalLM.from_pretrained(
-        peft_config.base_model_name_or_path,
-        torch_dtype=torch.float32,
-        device_map={"": "cpu"}
-    )
-    tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
-
-    # Load LoRA adapter
-    model = PeftModel.from_pretrained(base_model, adapter_path, device_map={"": "cpu"})
-
-    # Build inference pipeline
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-except Exception as e:
-    raise RuntimeError(f"❌ Failed to load model + adapter: {str(e)}")
-
-# === Request Schema ===
-class EmailInput(BaseModel):
-    subject: str
-    body: str
-
-# === Endpoint ===
-@app.post("/classify")
-async def classify_email(data: EmailInput):
-    prompt = f"""### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"""
-    try:
-        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
-        full_text = result[0]["generated_text"]
-        label_section = full_text.split("### Labels:")[1].strip()
-        return {"label": label_section}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Model inference failed: {str(e)}")
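For context, the deleted service exposed a single POST /classify route that accepted a JSON body with "subject" and "body" fields and returned the model's generated label string. A minimal client sketch of how it could have been exercised; the localhost:8000 address (uvicorn's default) and the example payload are assumptions, as the diff does not show how the app was served:

import requests  # third-party HTTP client, assumed installed

# Hypothetical call against the now-deleted endpoint; host and port
# are assumptions -- nothing in this diff shows the serving setup.
resp = requests.post(
    "http://localhost:8000/classify",
    json={
        "subject": "Invoice overdue",
        "body": "Please pay invoice #123 by Friday.",
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # e.g. {"label": "..."}; labels depend on the fine-tuned adapter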