from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Server starting")

try:
    logger.info("Loading model")
    MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
    device = "cpu"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()
    logger.info("Model loaded")
except Exception as e:
    logger.error("Failed to load model: %s", e)
    model = None
    tokenizer = None

app = FastAPI()

# Request and response Pydantic models
class EmbedRequest(BaseModel):
    text: str

class EmbedResponse(BaseModel):
    embedding: list[float]

def mean_pooling(model_output, attention_mask):
    """
    Performs mean pooling on the last hidden state of the model.
    This turns token-level embeddings into a single sentence-level embedding.
    """
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
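
# Shape note (illustrative, not from the original file): for batch size B, sequence
# length T, and hidden size H, last_hidden_state is (B, T, H) and attention_mask is
# (B, T); mean_pooling masks out padding positions and returns a (B, H) tensor.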

@app.get("/")
def root_status():
    return {"status": "ok", "model": model is not None}

@app.post("/embed", response_model=EmbedResponse)
def get_embedding(request: EmbedRequest):
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model/Tokenizer could not be loaded")
    try:
        encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt').to(device)
        # Run the encoder without gradient tracking, then pool and L2-normalize
        with torch.no_grad():
            model_output = model(**encoded_input)
        sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
        normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        embedding_list = normalized_embedding[0].tolist()
        return EmbedResponse(embedding=embedding_list)
    except Exception as e:
        logger.error("Error during embedding generation: %s", e)
        raise HTTPException(status_code=500, detail="Error generating embeddings")
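

# Optional local entry point: a minimal sketch assuming uvicorn is installed
# (uvicorn is not imported in the original file). Running this module directly
# starts the server; the host, port, and example payload below are illustrative.
# Example request once the server is up:
#   curl -X POST http://127.0.0.1:8000/embed \
#        -H "Content-Type: application/json" \
#        -d '{"text": "def two_sum(nums, target): ..."}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)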