Commit
·
df5822b
1
Parent(s):
b5db444
loaded model using AutoModel and manual pooling
Browse files
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
from fastapi import FastAPI, HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
-
from
|
|
|
|
|
4 |
import logging
|
5 |
|
6 |
logging.basicConfig(level=logging.INFO)
|
@@ -9,11 +11,19 @@ logger=logging.getLogger(__name__)
|
|
9 |
logger.info("Server Starting")
|
10 |
try:
|
11 |
logger.info("Loading model")
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
logger.info("Model Loaded")
|
14 |
except:
|
15 |
logger.error("Failed to load Model")
|
16 |
model=None
|
|
|
17 |
|
18 |
app=FastAPI()
|
19 |
|
@@ -24,16 +34,33 @@ class EmbedRequest(BaseModel):
|
|
24 |
class EmbedResponse(BaseModel):
|
25 |
embedding: list[float]
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
@app.get("/")
|
28 |
def root_status():
|
29 |
return {"status":"ok","model":model is not None}
|
30 |
|
31 |
@app.post("/embed",response_model=EmbedResponse)
|
32 |
def get_embedding(request: EmbedRequest):
|
33 |
-
if model
|
34 |
-
HTTPException(status_code=503,detail="
|
35 |
try:
|
|
|
|
|
36 |
embedding=model.encode(request.text).tolist()
|
|
|
|
|
|
|
|
|
|
|
37 |
return EmbedResponse(embedding=embedding)
|
38 |
except Exception as e:
|
39 |
logger.error("Error during embedding generation %s",e)
|
|
|
1 |
from fastapi import FastAPI, HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
+
from transformers import AutoTokenizer, AutoModel
|
4 |
+
import torch
|
5 |
+
from torch.nn.functional import F
|
6 |
import logging
|
7 |
|
8 |
logging.basicConfig(level=logging.INFO)
|
|
|
logger.info("Server Starting")

# Hoisted out of the try-block so `device` is always defined even when model
# loading fails (the embed endpoint references it).
device = "cpu"

try:
    logger.info("Loading model")
    MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)

    model.to(device)
    # eval() disables dropout / batch-norm updates for inference.
    model.eval()

    logger.info("Model Loaded")
except Exception:
    # Catch Exception, not a bare `except:` — a bare except also swallows
    # KeyboardInterrupt/SystemExit. logger.exception records the traceback
    # so load failures are diagnosable from the logs.
    logger.exception("Failed to load Model")
    model = None
    tokenizer = None

app = FastAPI()
|
29 |
|
|
|
class EmbedResponse(BaseModel):
    # Response schema for /embed: one embedding vector as a flat list of floats.
    embedding: list[float]
|
36 |
|
def mean_pooling(model_output, attention_mask):
    """Collapse token-level embeddings into one sentence-level embedding.

    Averages the model's last hidden state over the sequence dimension,
    weighting by the attention mask so padding tokens contribute nothing.
    """
    hidden = model_output.last_hidden_state
    # Broadcast the (batch, seq) mask to (batch, seq, dim) so it can weight
    # every hidden-state component.
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    # Clamp avoids division by zero for an all-padding sequence.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
|
45 |
+
|
46 |
+
|
@app.get("/")
def root_status():
    """Liveness probe: reports service status and whether the model loaded."""
    model_ready = model is not None
    return {"status": "ok", "model": model_ready}
|
50 |
|
51 |
@app.post("/embed",response_model=EmbedResponse)
|
52 |
def get_embedding(request: EmbedRequest):
|
53 |
+
if not model or not tokenizer:
|
54 |
+
HTTPException(status_code=503,detail="/Tokenizer could not be loaded")
|
55 |
try:
|
56 |
+
encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt').to(device)
|
57 |
+
model_output = model(**encoded_input)
|
58 |
embedding=model.encode(request.text).tolist()
|
59 |
+
sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
|
60 |
+
|
61 |
+
normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
|
62 |
+
|
63 |
+
embedding_list = normalized_embedding[0].tolist()
|
64 |
return EmbedResponse(embedding=embedding)
|
65 |
except Exception as e:
|
66 |
logger.error("Error during embedding generation %s",e)
|