# mcp-data-analyst/modal/rerank_service.py
# Uploaded by Jacqkues in commit d25ee4b ("Upload 11 files").
import modal
MINUTES = 60  # seconds
MODEL_REPO_ID = "Qwen/Qwen3-Reranker-4B"

# Container image shared by the weight-download job and the GPU reranker.
rerank_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "transformers==4.51.0",
        "huggingface_hub[hf_transfer]",
        "fastapi[standard]",
        "torch",
    )
    # Enable hf_transfer for faster Hugging Face Hub downloads.
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

# Volume mounted as the Hugging Face cache (see the `volumes=` mappings on the
# functions below) so weights are fetched once and reused across containers.
# create_if_missing=True lets the very first deploy / download run create the
# volume instead of failing the lookup.
hf_cache_vol = modal.Volume.from_name(
    "mcp-datascientist-model-weights-vol", create_if_missing=True
)

# Per Modal's Image.imports(), failures of these imports are only raised inside
# rerank_image containers; locally these names may be undefined.
with rerank_image.imports():
    import torch
    from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

app = modal.App("qwen3-rerank-service")
@app.function(
    image=rerank_image,
    volumes={"/root/.cache/huggingface": hf_cache_vol},
    # Fetching a 4B-parameter checkpoint can outlast Modal's default function
    # timeout on slow Hub throughput, so give the download an explicit budget.
    timeout=30 * MINUTES,
)
def download_model():
    """Download the reranker weights into the Hugging Face cache volume.

    The volume is mounted at the default Hugging Face cache location, so
    ``snapshot_download`` writes the checkpoint there. ``RerankerService``
    mounts the same volume at the same path and can load from it.
    """
    from huggingface_hub import snapshot_download

    loc = snapshot_download(repo_id=MODEL_REPO_ID)
    print(f"Saved model to {loc}")
# Total token budget per sequence (template prefix + pair + template suffix).
_MAX_LENGTH = 8192
# Chat-template pieces wrapped around every formatted (query, document) pair.
_PROMPT_PREFIX = "<|im_start|>system\nJudge whether the Table will be usefull to create an sql request to answer the Query. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
_PROMPT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Task instruction embedded in each pair.
_RERANK_INSTRUCTION = "Given a user query find the usefull tables in order to build an sql request"


@app.cls(image=rerank_image, gpu="A100-40GB", volumes={
    "/root/.cache/huggingface": hf_cache_vol,
})
class RerankerService:
    """Qwen3 reranker that scores documents (table descriptions) against a query.

    The score of a document is the model's probability of answering "yes"
    rather than "no" to whether the table is useful for building an SQL query
    that answers the user's query.
    """

    @modal.enter()
    def load_model(self):
        """Load tokenizer and model once per container and pre-tokenize the prompt template."""
        import importlib.util

        # Left padding keeps every sequence's last real token at index -1,
        # which is where _compute_scores reads the answer logits.
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID, padding_side="left")

        # flash-attn is not part of rerank_image, and requesting
        # "flash_attention_2" without it makes from_pretrained raise.
        # Use it only when the package is present; otherwise use PyTorch SDPA.
        attn_impl = (
            "flash_attention_2"
            if importlib.util.find_spec("flash_attn") is not None
            else "sdpa"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_REPO_ID,
            torch_dtype=torch.float16,
            attn_implementation=attn_impl,
        ).cuda().eval()

        # None of these depend on the request, so compute them once here
        # instead of on every rank() call.
        self.prefix_tokens = self.tokenizer.encode(_PROMPT_PREFIX, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(_PROMPT_SUFFIX, add_special_tokens=False)
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

    @modal.method()
    def rank(self, query, documents, batch_size=None):
        """Score how useful each document is for answering ``query`` with SQL.

        Args:
            query: The user's natural-language question.
            documents: Document strings (e.g. table descriptions) to score.
            batch_size: Optional maximum number of documents per forward pass.
                ``None`` (default) scores all documents in a single batch;
                set it to bound GPU memory for long document lists.

        Returns:
            A list of floats in [0, 1], one per document, in input order.

        Raises:
            ValueError: If ``batch_size`` is given and is less than 1.
        """
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        documents = list(documents)
        if not documents:
            # Tokenizing/padding an empty batch fails; nothing to score anyway.
            return []

        pairs = [
            self._format_instruction(_RERANK_INSTRUCTION, query, doc)
            for doc in documents
        ]
        step = batch_size if batch_size is not None else len(pairs)
        scores = []
        for start in range(0, len(pairs), step):
            inputs = self._process_inputs(pairs[start:start + step])
            scores.extend(self._compute_scores(inputs))
        return scores

    @staticmethod
    def _format_instruction(instruction, query, doc):
        """Render one (query, document) pair in the reranker's input format."""
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def _process_inputs(self, pairs):
        """Tokenize pairs, wrap them in the chat template, pad, and move them to the model's device."""
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            # Reserve room for the template tokens added just below.
            max_length=_MAX_LENGTH - len(self.prefix_tokens) - len(self.suffix_tokens),
        )
        for i, ids in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self.prefix_tokens + ids + self.suffix_tokens
        inputs = self.tokenizer.pad(
            inputs, padding=True, return_tensors="pt", max_length=_MAX_LENGTH
        )
        return inputs.to(self.model.device)

    def _compute_scores(self, inputs):
        """Return P("yes") from a softmax over the "no"/"yes" logits at each sequence's last position."""
        # no_grad is used as a context manager (not a decorator) so that
        # `torch` is only touched at call time inside the container.
        with torch.no_grad():
            logits = self.model(**inputs).logits[:, -1, :]
            pair_logits = torch.stack(
                [logits[:, self.token_false_id], logits[:, self.token_true_id]],
                dim=1,
            )
            log_probs = torch.nn.functional.log_softmax(pair_logits, dim=1)
            return log_probs[:, 1].exp().tolist()
@app.function(
    image=modal.Image.debian_slim(python_version="3.12")
    .pip_install("fastapi[standard]==0.115.4")
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    """Build the public HTTP front end that forwards rerank requests to the GPU class.

    Exposes ``POST /predict`` with a JSON body
    ``{"query": str, "documents": [str, ...]}`` and responds with
    ``{"scores": [float, ...]}`` (one score per document, in input order).
    Malformed bodies are rejected by FastAPI with a 422 response.
    """
    from fastapi import FastAPI
    from pydantic import BaseModel

    class RerankRequest(BaseModel):
        query: str
        documents: list[str]

    web_app = FastAPI()

    @web_app.post("/predict")
    async def predict(payload: RerankRequest):
        # .remote.aio awaits the GPU call without blocking this container's
        # event loop, so concurrent requests are not serialized behind it.
        scores = await RerankerService().rank.remote.aio(
            payload.query, payload.documents
        )
        # Returning a dict lets FastAPI JSON-encode the list of floats;
        # a raw Response(content=<list>) cannot be rendered.
        return {"scores": scores}

    return web_app