# Modal deployment of the Qwen3-Reranker-4B model served with vLLM,
# exposed as a FastAPI /rank endpoint for scoring (query, document) pairs.
import logging

import modal

app = modal.App("qwen-reranker-vllm")

# Persistent volumes: one for Hugging Face model weights, one for vLLM's cache.
hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
vllm_cache_vol = modal.Volume.from_name("vllm-cache")

MINUTES = 60  # seconds
vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.8.5",
        "transformers",
        "torch",
        "fastapi[all]",
        "pydantic",
        "hf_transfer",  # required when HF_HUB_ENABLE_HF_TRANSFER is set
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster weight downloads from the Hub
)
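
# These packages exist only inside vllm_image; imports() defers the imports to
# container runtime, so deploying from a machine without vLLM still works.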
with vllm_image.imports():
    import math

    import torch
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt


@app.cls(
    image=vllm_image,
    gpu="A100-40GB",
    scaledown_window=15 * MINUTES,  # how long should we stay up with no requests?
    timeout=10 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
)
class Reranker:
    @modal.enter()
    def load_reranker(self):
        """Load the tokenizer and vLLM engine once, when the container starts."""
        logging.info("loading the reranker model")
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = LLM(
            model="Qwen/Qwen3-Reranker-4B",
            tensor_parallel_size=torch.cuda.device_count(),
            max_model_len=10000,
            enable_prefix_caching=True,
            gpu_memory_utilization=0.8,
        )
        # Suffix that closes the user turn and opens an empty <think> block,
        # so the model's next token is the yes/no judgment itself.
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        self.max_length = 8192
        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
        # Greedy, single-token decoding restricted to the "yes"/"no" token ids.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
        )

    def format_instruction(self, instruction, query, doc):
        return [
            {"role": "system", "content": "Judge whether the Table will be useful to create an SQL request to answer the Query. Note that the answer can only be \"yes\" or \"no\"."},
            {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"},
        ]

    def process_inputs(self, pairs, instruction):
        messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        messages = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=False, enable_thinking=False
        )
        # Truncate each prompt to max_length tokens, then append the fixed suffix.
        messages = [ele[: self.max_length] + self.suffix_tokens for ele in messages]
        return [TokensPrompt(prompt_token_ids=ele) for ele in messages]

    def compute_logits(self, messages):
        outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
        scores = []
        for output in outputs:
            # Logprobs of the single generated token; a missing entry gets a
            # floor of -10 so the normalization below stays well-defined.
            final_logits = output.outputs[0].logprobs[-1]
            if self.true_token in final_logits:
                true_logit = final_logits[self.true_token].logprob
            else:
                true_logit = -10
            if self.false_token in final_logits:
                false_logit = final_logits[self.false_token].logprob
            else:
                false_logit = -10
            # Two-way softmax over the "yes"/"no" logprobs gives P(yes).
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            scores.append(true_score / (true_score + false_score))
        return scores

    @modal.method()
    def rerank(self, query, documents, task):
        # Example task: "Given a web search query, retrieve relevant passages that answer the query"
        pairs = [(query, doc) for doc in documents]
        inputs = self.process_inputs(pairs, task)
        scores = self.compute_logits(inputs)
        return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]


@app.function(
    image=modal.Image.debian_slim(python_version="3.12").pip_install(
        "fastapi[standard]==0.115.4", "pydantic"
    )
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    from typing import List

    from fastapi import FastAPI
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel

    web_app = FastAPI()
    reranker = Reranker()

    class ScoringResult(BaseModel):
        score: float
        content: str

    class RankingRequest(BaseModel):
        task: str
        query: str
        documents: List[str]

    @web_app.post("/rank", response_model=List[ScoringResult])
    async def predict(payload: RankingRequest):
        logging.info("rank endpoint called")
        query = payload.query
        documents = payload.documents
        task = payload.task
        output_data = reranker.rerank.remote(query, documents, task)
        return JSONResponse(content=output_data)

    return web_app
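
# Example request against the deployed endpoint (a sketch: the workspace name
# in the URL below is a placeholder for your own Modal workspace subdomain):
#
#   curl -X POST "https://<workspace>--rerank-endpoint.modal.run/rank" \
#     -H "Content-Type: application/json" \
#     -d '{"task": "Retrieve tables useful for the SQL query",
#          "query": "total revenue per customer",
#          "documents": ["Table orders(...)", "Table employees(...)"]}'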