Commit
·
649b044
1
Parent(s):
ea47ed4
Attempt using Sentence Transformer
Browse files
app.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1 |
from fastapi import FastAPI, HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
-
from transformers import
|
4 |
-
import
|
5 |
-
import torch.nn.functional as F
|
6 |
import logging
|
7 |
|
8 |
logging.basicConfig(level=logging.INFO)
|
@@ -13,17 +12,19 @@ try:
|
|
13 |
logger.info("Loading model")
|
14 |
MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
|
15 |
device="cpu"
|
16 |
-
tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME)
|
17 |
-
|
|
|
18 |
|
19 |
-
model
|
20 |
-
|
|
|
|
|
21 |
|
22 |
logger.info("Model Loaded")
|
23 |
except Exception as e:
|
24 |
logger.error("Failed to load Model %s",e)
|
25 |
model=None
|
26 |
-
tokenizer=None
|
27 |
|
28 |
app=FastAPI()
|
29 |
|
@@ -34,34 +35,18 @@ class EmbedRequest(BaseModel):
|
|
34 |
class EmbedResponse(BaseModel):
|
35 |
embedding: list[float]
|
36 |
|
37 |
-
def mean_pooling(model_output, attention_mask):
|
38 |
-
"""
|
39 |
-
Performs mean pooling on the last hidden state of the model.
|
40 |
-
This turns token-level embeddings into a single sentence-level embedding.
|
41 |
-
"""
|
42 |
-
token_embeddings = model_output.last_hidden_state
|
43 |
-
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
44 |
-
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
45 |
-
|
46 |
-
|
47 |
@app.get("/")
|
48 |
def root_status():
|
49 |
return {"status":"ok","model":model is not None}
|
50 |
|
51 |
@app.post("/embed",response_model=EmbedResponse)
|
52 |
def get_embedding(request: EmbedRequest):
|
53 |
-
if
|
54 |
-
HTTPException(status_code=503,detail="/
|
55 |
try:
|
56 |
-
encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt').to(device)
|
57 |
-
model_output = model(**encoded_input)
|
58 |
-
# embedding=model.encode(request.text).tolist()
|
59 |
-
sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
|
60 |
-
|
61 |
-
normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
|
62 |
|
63 |
-
|
64 |
-
return EmbedResponse(embedding=
|
65 |
except Exception as e:
|
66 |
logger.error("Error during embedding generation %s",e)
|
67 |
-
|
|
|
1 |
from fastapi import FastAPI, HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
+
from transformers import AutoModel
|
4 |
+
from sentence_transformers import SentenceTransformer, models
|
|
|
5 |
import logging
|
6 |
|
7 |
logging.basicConfig(level=logging.INFO)
|
|
|
12 |
logger.info("Loading model")
|
13 |
MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
|
14 |
device="cpu"
|
15 |
+
# tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME)
|
16 |
+
embedding_model=AutoModel.from_pretrained(MODEL_NAME)
|
17 |
+
pooling_model=models.Pooling(embedding_model.get_word_embedding_dimension())
|
18 |
|
19 |
+
model=SentenceTransformer(
|
20 |
+
modules=[embedding_model,pooling_model],
|
21 |
+
device=device
|
22 |
+
)
|
23 |
|
24 |
logger.info("Model Loaded")
|
25 |
except Exception as e:
|
26 |
logger.error("Failed to load Model %s",e)
|
27 |
model=None
|
|
|
28 |
|
29 |
app=FastAPI()
|
30 |
|
|
|
35 |
class EmbedResponse(BaseModel):
|
36 |
embedding: list[float]
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
@app.get("/")
|
39 |
def root_status():
|
40 |
return {"status":"ok","model":model is not None}
|
41 |
|
42 |
@app.post("/embed",response_model=EmbedResponse)
|
43 |
def get_embedding(request: EmbedRequest):
|
44 |
+
if model is None:
|
45 |
+
raise HTTPException(status_code=503,detail="/Model could not be loaded")
|
46 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
+
embedding = model.encode(request.text).tolist()
|
49 |
+
return EmbedResponse(embedding=embedding)
|
50 |
except Exception as e:
|
51 |
logger.error("Error during embedding generation %s",e)
|
52 |
+
raise HTTPException(status_code=500,detail="Error generating embeddings")
|