Sid-the-sloth's picture
loaded model using AutoModel and manual pooling
ea47ed4
raw
history blame
2.3 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import logging
# --- Application bootstrap: logging, model/tokenizer load, FastAPI app ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Server Starting")

# Load the embedding model once at import time; on failure the globals are
# left as None so the endpoints can report a 503 instead of crashing startup.
MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
device = "cpu"
try:
    logger.info("Loading model")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME).to(device)
    model.eval()
    logger.info("Model Loaded")
except Exception as e:
    logger.error("Failed to load Model %s", e)
    model = None
    tokenizer = None

app = FastAPI()
#Req and Response Pydantic models
class EmbedRequest(BaseModel):
    """Request payload for POST /embed: the raw text to be embedded."""

    text: str
class EmbedResponse(BaseModel):
    """Response payload for POST /embed: one sentence embedding vector."""

    embedding: list[float]
def mean_pooling(model_output, attention_mask):
    """Collapse token-level embeddings into one sentence-level embedding.

    Averages the model's last hidden state over the sequence dimension,
    weighting each position by its attention mask so that padding tokens
    contribute nothing to the mean.
    """
    hidden = model_output.last_hidden_state
    # Broadcast the (batch, seq) mask up to (batch, seq, dim) as floats.
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    # Clamp avoids division by zero for an all-padding sequence.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
@app.get("/")
def root_status():
    """Health check: reports server status and whether the model loaded."""
    loaded = model is not None
    return {"status": "ok", "model": loaded}
@app.post("/embed", response_model=EmbedResponse)
def get_embedding(request: EmbedRequest):
    """Embed ``request.text`` into a single L2-normalized sentence vector.

    Returns:
        EmbedResponse whose ``embedding`` is the vector as a list of floats.

    Raises:
        HTTPException: 503 if the model/tokenizer failed to load at startup,
            500 if embedding generation fails.
    """
    if model is None or tokenizer is None:
        # BUG FIX: the original built this HTTPException but never raised it,
        # so the guard was a no-op and requests fell through to a None model.
        raise HTTPException(status_code=503, detail="Model/Tokenizer could not be loaded")
    try:
        encoded_input = tokenizer(
            request.text, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        # Inference only: disable autograd so no gradient state is tracked.
        with torch.no_grad():
            model_output = model(**encoded_input)
        sentence_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
        # L2-normalize so cosine similarity downstream reduces to a dot product.
        normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        return EmbedResponse(embedding=normalized_embedding[0].tolist())
    except HTTPException:
        # Let deliberate HTTP errors propagate untouched.
        raise
    except Exception as e:
        logger.error("Error during embedding generation %s", e)
        # BUG FIX: the original *returned* the exception, which fails
        # response_model validation; it must be raised instead.
        raise HTTPException(status_code=500, detail="Error generating embeddings")