|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
from transformers import AutoModel |
|
from sentence_transformers import SentenceTransformer, models |
|
import logging |
|
|
|
# Root logging at INFO so the startup and request-handling messages in this
# module are emitted. Note: logging.basicConfig does nothing if the root logger
# already has handlers configured.
logging.basicConfig(level=logging.INFO)

# Module logger shared by the startup code and the request handlers.
logger = logging.getLogger(name=__name__)
logger.info("Server Starting")
|
# Load the embedding model once, at import time. If loading fails the process
# still starts with ``model = None``: the "/" endpoint reports
# ``"model": False`` and "/embed" answers 503 instead of the server crashing.
logger.info("Loading model")

MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
device = "cpu"

try:
    # Only the constructor can fail (download, missing files, bad config), so
    # it is the only statement guarded here.
    model = SentenceTransformer(MODEL_NAME, device=device)
except Exception:
    # Deliberately broad: this is a top-level boundary and any failure should
    # degrade to "no model". logger.exception keeps the full traceback, which
    # the previous logger.error("... %s", e) discarded.
    logger.exception("Failed to load Model %s", MODEL_NAME)
    model = None
else:
    logger.info("Model Loaded")
|
|
|
# ASGI application instance; the route decorators in this module register on it.
app: FastAPI = FastAPI()
|
|
|
|
|
class EmbedRequest(BaseModel):
    # Request body for POST /embed.
    # (Kept as comments rather than a docstring: pydantic would copy a class
    # docstring into the generated JSON/OpenAPI schema.)

    # Raw text to embed.
    text: str
|
|
|
class EmbedResponse(BaseModel):
    # Response body for POST /embed.
    # (Comments instead of a docstring so the generated schema is unchanged.)

    # Embedding vector, converted from the model output with ``.tolist()``.
    embedding: list[float]
|
|
|
@app.get("/")
def root_status():
    # Health check: always reports "ok", plus whether the embedding model was
    # loaded at startup. (No docstring on purpose: FastAPI would publish it as
    # the endpoint description in the OpenAPI schema.)
    loaded = model is not None
    return dict(status="ok", model=loaded)
|
|
|
@app.post("/embed", response_model=EmbedResponse)
def get_embedding(request: EmbedRequest):
    """Return the embedding of ``request.text`` as a list of floats.

    Responds with 503 when the model could not be loaded at startup and with
    500 when encoding the text fails.
    """
    if model is None:
        # Startup loading failed and left ``model`` as None; report the service
        # as unavailable. (The detail previously had a stray leading "/".)
        raise HTTPException(status_code=503, detail="Model could not be loaded")

    try:
        embedding = model.encode(request.text).tolist()
    except Exception as exc:
        # Request boundary: log the full traceback server-side and return a
        # generic message so internals are not leaked to the client.
        logger.exception("Error during embedding generation")
        raise HTTPException(
            status_code=500, detail="Error generating embeddings"
        ) from exc

    return EmbedResponse(embedding=embedding)