Sid-the-sloth committed on
Commit
df5822b
·
1 Parent(s): b5db444

loaded model using AutoModel and manual pooling

Browse files
Files changed (1) hide show
  1. app.py +31 -4
app.py CHANGED
@@ -1,6 +1,8 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from sentence_transformers import SentenceTransformer
 
 
4
  import logging
5
 
6
  logging.basicConfig(level=logging.INFO)
@@ -9,11 +11,19 @@ logger=logging.getLogger(__name__)
9
  logger.info("Server Starting")
10
  try:
11
  logger.info("Loading model")
12
- model=SentenceTransformer("Sid-the-sloth/leetcode_unixcoder_final")
 
 
 
 
 
 
 
13
  logger.info("Model Loaded")
14
  except:
15
  logger.error("Failed to load Model")
16
  model=None
 
17
 
18
  app=FastAPI()
19
 
@@ -24,16 +34,33 @@ class EmbedRequest(BaseModel):
24
  class EmbedResponse(BaseModel):
25
  embedding: list[float]
26
 
 
 
 
 
 
 
 
 
 
 
27
  @app.get("/")
28
  def root_status():
29
  return {"status":"ok","model":model is not None}
30
 
31
  @app.post("/embed",response_model=EmbedResponse)
32
  def get_embedding(request: EmbedRequest):
33
- if model is None:
34
- HTTPException(status_code=503,detail="Model could not be loaded")
35
  try:
 
 
36
  embedding=model.encode(request.text).tolist()
 
 
 
 
 
37
  return EmbedResponse(embedding=embedding)
38
  except Exception as e:
39
  logger.error("Error during embedding generation %s",e)
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import torch
5
+ import torch.nn.functional as F
6
  import logging
7
 
8
  logging.basicConfig(level=logging.INFO)
 
11
logger.info("Server Starting")

# Load the tokenizer and encoder once at startup; endpoints check for None
# to report a 503 instead of crashing when loading failed.
try:
    logger.info("Loading model")
    MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
    device = "cpu"  # CPU-only deployment; change to "cuda" if a GPU is available
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()  # disable dropout etc. — this process only does inference
    logger.info("Model Loaded")
except Exception:
    # Bug fix: the bare `except:` also swallowed SystemExit/KeyboardInterrupt
    # and discarded the traceback. `logger.exception` keeps the stack trace.
    logger.exception("Failed to load Model")
    model = None
    tokenizer = None
27
 
28
  app=FastAPI()
29
 
 
34
  class EmbedResponse(BaseModel):
35
  embedding: list[float]
36
 
37
def mean_pooling(model_output, attention_mask):
    """Collapse token-level embeddings into one sentence-level embedding.

    Averages the model's last hidden state over the sequence dimension,
    counting only positions where ``attention_mask`` is 1 so padding
    tokens do not dilute the result.
    """
    hidden = model_output.last_hidden_state
    # Broadcast the (batch, seq) mask to (batch, seq, dim) so it can weight
    # each token vector.
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = torch.sum(hidden * mask, 1)
    # Clamp avoids division by zero for an all-padding (empty) input.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
45
+
46
+
47
@app.get("/")
def root_status():
    """Health check: reports service liveness and whether the model loaded."""
    model_loaded = model is not None
    return {"status": "ok", "model": model_loaded}
50
 
51
@app.post("/embed", response_model=EmbedResponse)
def get_embedding(request: EmbedRequest):
    """Embed ``request.text`` and return the L2-normalized sentence vector.

    Raises:
        HTTPException 503: the model/tokenizer failed to load at startup.
        HTTPException 500: tokenization or inference failed.
    """
    if model is None or tokenizer is None:
        # Bug fix: the original constructed the HTTPException but never
        # raised it, so a failed model load fell through into the try block.
        raise HTTPException(status_code=503, detail="Model/Tokenizer could not be loaded")
    try:
        encoded_input = tokenizer(
            request.text, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            model_output = model(**encoded_input)
        # Bug fix: removed the leftover `model.encode(...)` line from the
        # SentenceTransformer version — AutoModel has no `.encode`, so it
        # raised AttributeError on every request.
        sentence_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
        normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        # Bug fix: return the freshly computed vector; the original returned
        # the stale `embedding` variable instead of the pooled result.
        return EmbedResponse(embedding=normalized_embedding[0].tolist())
    except Exception as e:
        logger.error("Error during embedding generation %s", e)
        # Bug fix: surface an explicit 500 instead of implicitly returning
        # None, which failed FastAPI's response-model validation opaquely.
        raise HTTPException(status_code=500, detail="Embedding generation failed") from e