import logging

import modal

# Emit INFO-level logs from this module; the root logger defaults to WARNING otherwise.
logging.basicConfig(level=logging.INFO)

app = modal.App("qwen-reranker-vllm")

# Persistent volumes for Hugging Face model weights and the vLLM compilation cache.
hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
vllm_cache_vol = modal.Volume.from_name("vllm-cache")

MINUTES = 60  # seconds

# Container image for the reranker: vLLM plus the libraries needed at import time.
vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.8.5",
        "transformers",
        "torch",
        "fastapi[all]",
        "pydantic",
        "huggingface_hub[hf_transfer]",  # needed when HF_HUB_ENABLE_HF_TRANSFER is set
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster model downloads from the Hub
)


# These heavy dependencies only exist inside vllm_image, so import them in its context.
with vllm_image.imports():
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt
    import torch
    import math

@app.cls(
    image=vllm_image,
    gpu="A100-40GB",
    scaledown_window=15 * MINUTES,  # how long to stay warm with no requests
    timeout=10 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
)
class Reranker:
    @modal.enter()
    def load_reranker(self):
        """Load the tokenizer and the vLLM engine once per container."""
        logging.info("Loading Qwen3-Reranker-4B")
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = LLM(
            model="Qwen/Qwen3-Reranker-4B",
            tensor_parallel_size=torch.cuda.device_count(),
            max_model_len=10000,
            enable_prefix_caching=True,
            gpu_memory_utilization=0.8,
        )
        # Every prompt ends with an empty "thinking" block so the next generated token is the yes/no verdict.
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        self.max_length = 8192
        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
        # Greedy, single-token generation restricted to the "yes"/"no" token ids.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
        )

    def format_instruction(self, instruction, query, doc):
        """Build the chat messages asking whether a table is useful for answering the query."""
        return [
            {"role": "system", "content": "Judge whether the Table will be useful to write an SQL query that answers the Query. Note that the answer can only be \"yes\" or \"no\"."},
            {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"},
        ]

    def process_inputs(self, pairs, instruction):
        """Tokenize (query, document) pairs and append the fixed assistant suffix."""
        messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        messages = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=False, enable_thinking=False
        )
        # Truncate each prompt to max_length tokens, then append the suffix tokens.
        messages = [ele[:self.max_length] + self.suffix_tokens for ele in messages]
        return [TokensPrompt(prompt_token_ids=ele) for ele in messages]

    def compute_logits(self, messages):
        """Score each prompt as P("yes") / (P("yes") + P("no")) from the single generated token."""
        outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
        scores = []
        for output in outputs:
            final_logits = output.outputs[0].logprobs[-1]
            # If a token did not make it into the returned logprobs, treat it as very unlikely.
            if self.true_token not in final_logits:
                true_logit = -10
            else:
                true_logit = final_logits[self.true_token].logprob
            if self.false_token not in final_logits:
                false_logit = -10
            else:
                false_logit = final_logits[self.false_token].logprob
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            scores.append(true_score / (true_score + false_score))
        return scores

    @modal.method()
    def rerank(self, query, documents, task):
        """Score every document against the query and return score/content pairs."""
        pairs = [(query, doc) for doc in documents]
        inputs = self.process_inputs(pairs, task)
        scores = self.compute_logits(inputs)
        return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]

@app.function(
    image=modal.Image.debian_slim(python_version="3.12").pip_install(
        "fastapi[standard]==0.115.4",
        "pydantic",
    )
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    from typing import List

    from fastapi import FastAPI
    from pydantic import BaseModel

    web_app = FastAPI()
    reranker = Reranker()

    class ScoringResult(BaseModel):
        score: float
        content: str

    class RankingRequest(BaseModel):
        task: str
        query: str
        documents: List[str]

    @web_app.post("/rank", response_model=List[ScoringResult])
    async def predict(payload: RankingRequest):
        logging.info("Calling the rerank function")
        output_data = reranker.rerank.remote(payload.query, payload.documents, payload.task)
        return output_data

    return web_app
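

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the service itself).
# Assumptions: the app has been deployed with `modal deploy`; the HTTP endpoint
# URL depends on your Modal workspace (something like
# https://<workspace>--rerank-endpoint.modal.run/rank) and expects a JSON body
# {"task": ..., "query": ..., "documents": [...]}.
#
# For a quick smoke test without HTTP, `modal run <this file>` executes the
# local entrypoint below, which calls the Reranker class directly. The task,
# query, and documents are illustrative placeholders.
@app.local_entrypoint()
def main():
    task = "Given a user question, retrieve the tables that are useful for writing the SQL query."
    query = "Total revenue per customer in 2023"
    documents = [
        "Table orders(order_id, customer_id, order_date, amount)",
        "Table customers(customer_id, name, signup_date)",
        "Table employees(employee_id, name, hire_date)",
    ]
    results = Reranker().rerank.remote(query, documents, task)
    # Print the documents from most to least relevant.
    for result in sorted(results, key=lambda r: r["score"], reverse=True):
        print(f"{result['score']:.3f}  {result['content']}")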