Sid-the-sloth committed on
Commit
649b044
·
1 Parent(s): ea47ed4

Attempt using Sentence Transformer

Browse files
Files changed (1) hide show
  1. app.py +14 -29
app.py CHANGED
@@ -1,8 +1,7 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModel
4
- import torch
5
- import torch.nn.functional as F
6
  import logging
7
 
8
  logging.basicConfig(level=logging.INFO)
@@ -13,17 +12,19 @@ try:
13
  logger.info("Loading model")
14
  MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
15
  device="cpu"
16
- tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME)
17
- model=AutoModel.from_pretrained(MODEL_NAME)
 
18
 
19
- model.to(device)
20
- model.eval()
 
 
21
 
22
  logger.info("Model Loaded")
23
  except Exception as e:
24
  logger.error("Failed to load Model %s",e)
25
  model=None
26
- tokenizer=None
27
 
28
  app=FastAPI()
29
 
@@ -34,34 +35,18 @@ class EmbedRequest(BaseModel):
34
  class EmbedResponse(BaseModel):
35
  embedding: list[float]
36
 
37
- def mean_pooling(model_output, attention_mask):
38
- """
39
- Performs mean pooling on the last hidden state of the model.
40
- This turns token-level embeddings into a single sentence-level embedding.
41
- """
42
- token_embeddings = model_output.last_hidden_state
43
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
44
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
45
-
46
-
47
  @app.get("/")
48
  def root_status():
49
  return {"status":"ok","model":model is not None}
50
 
51
  @app.post("/embed",response_model=EmbedResponse)
52
  def get_embedding(request: EmbedRequest):
53
- if not model or not tokenizer:
54
- HTTPException(status_code=503,detail="/Tokenizer could not be loaded")
55
  try:
56
- encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt').to(device)
57
- model_output = model(**encoded_input)
58
- # embedding=model.encode(request.text).tolist()
59
- sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
60
-
61
- normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
62
 
63
- embedding_list = normalized_embedding[0].tolist()
64
- return EmbedResponse(embedding=embedding_list)
65
  except Exception as e:
66
  logger.error("Error during embedding generation %s",e)
67
- return HTTPException(status_code=500,detail="Error generating embeddings")
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import AutoModel
4
+ from sentence_transformers import SentenceTransformer, models
 
5
  import logging
6
 
7
  logging.basicConfig(level=logging.INFO)
 
12
# Load the fine-tuned UniXcoder checkpoint and wrap it as a SentenceTransformer
# (transformer encoder + mean pooling). On any failure, `model` is left as None
# and the /embed endpoint responds 503.
MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
device = "cpu"
model = None
try:
    logger.info("Loading model")
    # BUG FIX: the checkpoint must be wrapped with sentence_transformers'
    # models.Transformer. A raw transformers AutoModel has no
    # get_word_embedding_dimension() and is not a valid SentenceTransformer
    # module, so the previous code always raised here and left model=None.
    embedding_model = models.Transformer(MODEL_NAME)
    # Pooling turns per-token hidden states into one fixed-size sentence vector
    # (mean pooling by default), matching the old manual mean_pooling helper.
    pooling_model = models.Pooling(embedding_model.get_word_embedding_dimension())
    model = SentenceTransformer(
        modules=[embedding_model, pooling_model],
        device=device,
    )
    logger.info("Model Loaded")
except Exception as e:
    # Broad catch is deliberate at this startup boundary: the app should still
    # boot and report its degraded state via "/" and /embed's 503.
    logger.error("Failed to load Model %s", e)
    model = None
 
28
 
29
  app=FastAPI()
30
 
 
35
class EmbedResponse(BaseModel):
    """Response body for /embed: a single sentence embedding vector."""
    embedding: list[float]
37
 
 
 
 
 
 
 
 
 
 
 
38
@app.get("/")
def root_status():
    """Health check: report service liveness and whether the model loaded."""
    model_ready = model is not None
    return {"status": "ok", "model": model_ready}
41
 
42
@app.post("/embed", response_model=EmbedResponse)
def get_embedding(request: EmbedRequest):
    """Compute a sentence embedding for the request text.

    Raises:
        HTTPException: 503 if the model failed to load at startup;
            500 if embedding generation fails.
    """
    if model is None:
        # Startup load failed (see module-level try/except); report as
        # service-unavailable. Fixed stray leading "/" in the old detail text.
        raise HTTPException(status_code=503, detail="Model could not be loaded")
    try:
        # SentenceTransformer.encode returns a numpy array; convert to a
        # plain list so pydantic can serialize it.
        embedding = model.encode(request.text).tolist()
        return EmbedResponse(embedding=embedding)
    except Exception as e:
        logger.error("Error during embedding generation %s", e)
        raise HTTPException(status_code=500, detail="Error generating embeddings")