import logging

import modal

app = modal.App("qwen-reranker-vllm")

# Persistent volumes: one for HuggingFace model weights, one for vLLM's
# compilation/torch cache, so cold starts skip re-download and re-compile.
hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
vllm_cache_vol = modal.Volume.from_name("vllm-cache")

MINUTES = 60  # seconds

vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.8.5",
        "transformers",
        "torch",
        "fastapi[all]",
        "pydantic",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

# These imports only exist inside the vllm_image container, not locally.
with vllm_image.imports():
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt
    import torch
    import math


@app.cls(
    image=vllm_image,
    gpu="A100-40GB",
    scaledown_window=15 * MINUTES,  # how long should we stay up with no requests?
    timeout=10 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
)
class Reranker:
    """Qwen3-Reranker-4B served via vLLM.

    Scores (query, document) pairs by prompting the model for a one-token
    "yes"/"no" relevance judgement and converting the two logprobs into
    a probability P(yes) in [0, 1].
    """

    @modal.enter()
    def load_reranker(self):
        """Load tokenizer and model once per container; precompute prompt constants."""
        logging.info("in the rank function")
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = LLM(
            model="Qwen/Qwen3-Reranker-4B",
            tensor_parallel_size=torch.cuda.device_count(),
            max_model_len=10000,
            enable_prefix_caching=True,
            gpu_memory_utilization=0.8,
        )
        # Chat-template tail re-appended after truncation so the model's next
        # generated token is its yes/no verdict.
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        self.max_length = 8192  # token budget for the templated prompt, before the suffix
        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
        )

    def format_instruction(self, instruction, query, doc):
        """Build the chat messages asking for a yes/no relevance judgement.

        Returns a two-message list (system + user) in the format expected by
        ``tokenizer.apply_chat_template``.
        """
        return [
            {
                "role": "system",
                "content": (
                    "Judge whether the Table will be usefull to create an sql "
                    "request to answer the Query. Note that the answer can only "
                    'be "yes" or "no"'
                ),
            },
            {"role": "user", "content": f": {instruction}\n\n: {query}\n\n: {doc}"},
        ]

    def process_inputs(self, pairs, instruction):
        """Tokenize (query, doc) pairs into vLLM ``TokensPrompt`` objects.

        Each templated prompt is truncated to ``max_length`` tokens, then the
        assistant suffix is appended so generation starts at the verdict token.
        """
        messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        messages = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            enable_thinking=False,
        )
        messages = [ele[: self.max_length] + self.suffix_tokens for ele in messages]
        return [TokensPrompt(prompt_token_ids=ele) for ele in messages]

    def compute_logits(self, messages):
        """Return P(yes | yes-or-no) for each prompt, from the final-token logprobs.

        A token absent from the returned top-k logprobs is assigned a large
        negative logit (-10) so it contributes a near-zero probability.
        """
        outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
        scores = []
        for output in outputs:
            final_logits = output.outputs[0].logprobs[-1]
            true_logit = (
                final_logits[self.true_token].logprob
                if self.true_token in final_logits
                else -10
            )
            false_logit = (
                final_logits[self.false_token].logprob
                if self.false_token in final_logits
                else -10
            )
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            scores.append(true_score / (true_score + false_score))
        return scores

    @modal.method()
    def rerank(self, query, documents, task):
        """Score every document against the query.

        Returns a list of ``{"score": float, "content": str}`` dicts, one per
        input document, in the same order as ``documents``.
        """
        # task example: 'Given a web search query, retrieve relevant passages that answer the query'
        pairs = [(query, doc) for doc in documents]
        inputs = self.process_inputs(pairs, task)
        scores = self.compute_logits(inputs)
        return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]


@app.function(
    image=modal.Image.debian_slim(python_version="3.12").pip_install(
        "fastapi[standard]==0.115.4", "pydantic"
    )
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    """Lightweight FastAPI front-end exposing POST /rank.

    Delegates the actual scoring to the GPU-backed ``Reranker`` class via a
    remote Modal call.
    """
    from typing import List

    from fastapi import FastAPI
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel

    web_app = FastAPI()
    reranker = Reranker()

    class ScoringResult(BaseModel):
        score: float
        content: str

    class RankingRequest(BaseModel):
        task: str
        query: str
        documents: List[str]

    @web_app.post("/rank", response_model=List[ScoringResult])
    async def predict(payload: RankingRequest):
        logging.info("call the rank function")
        output_data = reranker.rerank.remote(payload.query, payload.documents, payload.task)
        return JSONResponse(content=output_data)

    return web_app