|
import modal |
|
import logging |
|
# Modal application hosting the Qwen3 reranker service.
app = modal.App("qwen-reranker-vllm")

# Persistent volume holding downloaded Hugging Face model weights,
# mounted at the HF cache path so weights survive container restarts.
hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")

# Persistent volume for vLLM's compilation/kernel cache.
vllm_cache_vol = modal.Volume.from_name("vllm-cache")

# Seconds per minute — used to express Modal timeouts readably.
MINUTES = 60

# Container image for the GPU worker: pinned vLLM plus its runtime deps.
# HF_HUB_ENABLE_HF_TRANSFER=1 enables the accelerated weight downloader.
vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.8.5",
        "transformers",
        "torch",
        "fastapi[all]",
        "pydantic"
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

# These imports only exist inside containers built from vllm_image;
# the local client never needs (or has) vllm/torch installed.
with vllm_image.imports():
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt
    import torch
    import math
|
|
|
@app.cls(
    image=vllm_image,
    gpu="A100-40GB",
    scaledown_window=15 * MINUTES,
    timeout=10 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
)
class Reranker:
    """Modal GPU class serving Qwen3-Reranker-4B through vLLM.

    Each (query, document) pair is rendered into a yes/no relevance
    judgement prompt; the document's score is the model's probability
    of answering "yes" on the first generated token.
    """

    @modal.enter()
    def load_reranker(self):
        """Load the tokenizer and vLLM engine once per container start."""
        logging.info("loading Qwen3-Reranker-4B tokenizer and vLLM engine")
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = LLM(
            model="Qwen/Qwen3-Reranker-4B",
            tensor_parallel_size=torch.cuda.device_count(),
            max_model_len=10000,
            # The shared query prefix is reused across every document's prompt.
            enable_prefix_caching=True,
            gpu_memory_utilization=0.8,
        )
        # Fixed assistant preamble (empty <think> block) appended after the
        # possibly-truncated chat tokens so generation starts at the verdict.
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        # Truncation cap for the templated prompt tokens (model max is 10000).
        self.max_length = 8192
        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,  # single-token yes/no classification
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
        )

    def format_instruction(self, instruction, query, doc):
        """Build the chat messages asking whether *doc* is useful for *query*."""
        return [
            {"role": "system", "content": "Judge whether the Table will be usefull to create an sql request to answer the Query. Note that the answer can only be \"yes\" or \"no\""},
            {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"}
        ]

    def process_inputs(self, pairs, instruction):
        """Tokenize (query, doc) pairs into vLLM ``TokensPrompt`` objects.

        Each templated prompt is truncated to ``self.max_length`` tokens,
        then the fixed assistant suffix is appended so the first generated
        token is the yes/no verdict.
        """
        messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        messages = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=False, enable_thinking=False
        )
        messages = [ele[:self.max_length] + self.suffix_tokens for ele in messages]
        return [TokensPrompt(prompt_token_ids=ele) for ele in messages]

    def compute_logits(self, messages):
        """Return P("yes") for each prompt from its first generated token.

        A token absent from the returned logprobs gets a floor logprob of
        -10 so the two-way softmax stays well-defined.
        """
        outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
        scores = []
        # Iterate outputs directly (no index bookkeeping, no unused locals).
        for output in outputs:
            final_logits = output.outputs[0].logprobs[-1]
            if self.true_token in final_logits:
                true_logit = final_logits[self.true_token].logprob
            else:
                true_logit = -10
            if self.false_token in final_logits:
                false_logit = final_logits[self.false_token].logprob
            else:
                false_logit = -10
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            scores.append(true_score / (true_score + false_score))
        return scores

    @modal.method()
    def rerank(self, query, documents, task):
        """Score every document in *documents* against *query*.

        Returns a list of ``{"score": float, "content": str}`` dicts in the
        same order as *documents* (not sorted by score).
        """
        pairs = [(query, doc) for doc in documents]
        inputs = self.process_inputs(pairs, task)
        scores = self.compute_logits(inputs)
        return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]
|
|
|
@app.function(
    image=modal.Image.debian_slim(python_version="3.12")
    .pip_install("fastapi[standard]==0.115.4", "pydantic")
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    """Build the ASGI app exposing the reranker as ``POST /rank``.

    The endpoint forwards the request to the GPU-backed :class:`Reranker`
    via a remote Modal call and returns its scored documents as JSON.
    """
    from typing import List

    # Unused Request/Response imports removed; only what the routes need.
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel

    web_app = FastAPI()
    reranker = Reranker()

    class ScoringResult(BaseModel):
        # One document with its relevance score, echoed back as "content".
        score: float
        content: str

    class RankingRequest(BaseModel):
        task: str
        query: str
        documents: List[str]

    @web_app.post("/rank", response_model=List[ScoringResult])
    async def predict(payload: RankingRequest):
        """Rerank ``payload.documents`` against ``payload.query`` remotely."""
        logging.info("call the rank function")
        query = payload.query
        documents = payload.documents
        task = payload.task
        output_data = reranker.rerank.remote(query, documents, task)
        return JSONResponse(content=output_data)

    return web_app