from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Server Starting")
try:
    logger.info("Loading model")
    MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
    device="cpu"
    tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME)
    model=AutoModel.from_pretrained(MODEL_NAME)

    model.to(device)
    model.eval()

    logger.info("Model Loaded")
except:
    logger.error("Failed to load Model")
    model=None
    tokenizer=None

app = FastAPI()

# Request and response Pydantic models
class EmbedRequest(BaseModel):
    text: str

class EmbedResponse(BaseModel):
    embedding: list[float]

def mean_pooling(model_output, attention_mask):
    """
    Performs mean pooling on the last hidden state of the model.
    This turns token-level embeddings into a single sentence-level embedding.
    """
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
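
# Shape sketch for mean_pooling (illustrative sizes, not read from the model
# config; UniXcoder-style encoders typically use a 768-dim hidden state):
#   model_output.last_hidden_state -> (batch, seq_len, hidden), e.g. (1, 128, 768)
#   attention_mask                 -> (batch, seq_len),         e.g. (1, 128)
#   returned pooled embedding      -> (batch, hidden),          e.g. (1, 768)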


@app.get("/")
def root_status():
    return {"status":"ok","model":model is not None}

@app.post("/embed",response_model=EmbedResponse)
def get_embedding(request: EmbedRequest):
    if not model or not tokenizer:
        HTTPException(status_code=503,detail="/Tokenizer could not be loaded")
    try:
        encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt').to(device)
        model_output = model(**encoded_input)
        # embedding=model.encode(request.text).tolist()
        sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
        
        normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        
        embedding_list = normalized_embedding[0].tolist()
        return EmbedResponse(embedding=embedding_list)
    except Exception as e:
        logger.error("Error during embedding generation %s",e)
        return HTTPException(status_code=500,detail="Error generating embeddings")
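
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the server itself). Assuming this
# file is saved as main.py and served locally with `uvicorn main:app --port 8000`
# (module name and port are assumptions, not taken from this file), the
# endpoint can be called like so:
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/embed",
#         json={"text": "def two_sum(nums, target): ..."},
#     )
#     resp.raise_for_status()
#     vector = resp.json()["embedding"]  # L2-normalized list[float]
# ---------------------------------------------------------------------------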