# mcp-data-analyst / modal/rerank_service_vllm.py
# (uploaded by Jacqkues — commit d25ee4b)
import modal
import logging

# Modal application serving the Qwen3 reranker behind vLLM.
app = modal.App("qwen-reranker-vllm")

# Persisted volumes: one for Hugging Face model weights, one for vLLM's
# compile/kernel cache, so warm containers skip re-download and re-compile.
hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
vllm_cache_vol = modal.Volume.from_name("vllm-cache")

MINUTES = 60  # seconds

# Container image for the GPU worker: pinned vLLM plus its runtime deps.
vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.8.5",
        "transformers",
        "torch",
        "fastapi[all]",
        "pydantic",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster HF weight downloads
)

# These imports only resolve inside the vllm_image container, not locally.
with vllm_image.imports():
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt
    import torch
    import math
@app.cls(
    image=vllm_image,
    gpu="A100-40GB",
    scaledown_window=15 * MINUTES,  # how long to stay warm with no requests
    timeout=10 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
)
class Reranker:
    """Scores (query, document) pairs with Qwen/Qwen3-Reranker-4B via vLLM.

    Reranking is framed as a forced yes/no judgement: the model is asked
    whether a document (a table description) is useful for building an SQL
    query, and the score is the normalized probability of the "yes" token
    against the "no" token on the single generated position.
    """

    @modal.enter()
    def load_reranker(self):
        """Load tokenizer, vLLM engine and scoring constants once per container."""
        logging.info("in the rank function")
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = LLM(
            model="Qwen/Qwen3-Reranker-4B",
            tensor_parallel_size=torch.cuda.device_count(),
            max_model_len=10000,
            enable_prefix_caching=True,
            gpu_memory_utilization=0.8,
        )
        # Forced assistant prefix with an empty <think> block so the model
        # answers immediately with yes/no instead of emitting reasoning tokens.
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        # Per-prompt truncation budget (tokens); kept below max_model_len so
        # the appended suffix still fits.
        self.max_length = 8192
        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
        # Greedy single-token decode restricted to the yes/no vocabulary;
        # logprobs=20 returns enough alternatives to read both tokens' scores.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
        )

    def format_instruction(self, instruction, query, doc):
        """Build the chat messages asking for a yes/no relevance judgement."""
        return [
            {"role": "system", "content": "Judge whether the Table will be usefull to create an sql request to answer the Query. Note that the answer can only be \"yes\" or \"no\""},
            {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"}
        ]

    def process_inputs(self, pairs, instruction):
        """Tokenize each (query, doc) pair and append the forced-answer suffix.

        Returns a list of TokensPrompt ready for LLM.generate. Each prompt is
        truncated to self.max_length tokens before the suffix is appended.
        """
        messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        messages = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=False, enable_thinking=False
        )
        messages = [ele[:self.max_length] + self.suffix_tokens for ele in messages]
        return [TokensPrompt(prompt_token_ids=ele) for ele in messages]

    def compute_logits(self, messages):
        """Generate one token per prompt and return P(yes) for each.

        A token missing from the returned top-k logprobs falls back to a
        logprob of -10 (small but finite, so the softmax stays well-defined).
        """
        outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
        scores = []
        for output in outputs:
            final_logits = output.outputs[0].logprobs[-1]
            if self.true_token in final_logits:
                true_logit = final_logits[self.true_token].logprob
            else:
                true_logit = -10
            if self.false_token in final_logits:
                false_logit = final_logits[self.false_token].logprob
            else:
                false_logit = -10
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            # Two-way softmax over {yes, no}.
            scores.append(true_score / (true_score + false_score))
        return scores

    @modal.method()
    def rerank(self, query, documents, task):
        """Score every document against the query.

        Returns a list of {"score": float, "content": str} dicts, one per
        document, in the same order as the input documents.
        """
        pairs = [(query, doc) for doc in documents]
        inputs = self.process_inputs(pairs, task)
        scores = self.compute_logits(inputs)
        return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]
@app.function(
    image=modal.Image.debian_slim(python_version="3.12")
    .pip_install("fastapi[standard]==0.115.4", "pydantic")
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    """ASGI entrypoint: thin FastAPI front-end that proxies /rank to the GPU class."""
    from pydantic import BaseModel
    from fastapi import FastAPI
    from typing import List

    web_app = FastAPI()
    reranker = Reranker()

    class ScoringResult(BaseModel):
        # One scored document returned to the caller.
        score: float
        content: str

    class RankingRequest(BaseModel):
        # Task instruction, user query, and candidate documents to rerank.
        task: str
        query: str
        documents: List[str]

    @web_app.post("/rank", response_model=List[ScoringResult])
    async def predict(payload: RankingRequest):
        """Rerank payload.documents against payload.query on the remote GPU worker."""
        logging.info("call the rank function")
        # .remote() dispatches to the Reranker container and blocks for the result.
        output_data = reranker.rerank.remote(payload.query, payload.documents, payload.task)
        # Return plain data so FastAPI actually validates/serializes it against
        # response_model (the original wrapped it in JSONResponse, which
        # silently bypassed the declared response schema).
        return output_data

    return web_app