import modal

MINUTES = 60  # seconds

MODEL_REPO_ID = "Qwen/Qwen3-Reranker-4B"

rerank_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "transformers==4.51.0",
        "huggingface_hub[hf_transfer]",
        "fastapi[standard]",
        "torch",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster Hub downloads
)

hf_cache_vol = modal.Volume.from_name(
    "mcp-datascientist-model-weights-vol", create_if_missing=True
)

with rerank_image.imports():
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

app = modal.App("qwen3-rerank-service")

@app.function(
    image=rerank_image,
    volumes={"/root/.cache/huggingface": hf_cache_vol},
    timeout=20 * MINUTES,  # assumed generous limit for the one-time download
)
def download_model():
    # Seed the shared cache volume with the model weights.
    from huggingface_hub import snapshot_download

    loc = snapshot_download(repo_id=MODEL_REPO_ID)
    print(f"Saved model to {loc}")

@app.cls(
    image=rerank_image,
    gpu="A100-40GB",
    volumes={"/root/.cache/huggingface": hf_cache_vol},
)
class RerankerService:
    @modal.enter()
    def load_model(self):
        # Runs once per container start; weights are read from the cache volume.
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_REPO_ID, padding_side="left"
        )
        # The model card suggests attn_implementation="flash_attention_2", but
        # flash-attn is not installed in the image above, so we rely on the
        # default attention implementation here.
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                MODEL_REPO_ID, torch_dtype=torch.float16
            )
            .cuda()
            .eval()
        )

    @modal.method()
    def rank(self, query, documents):
        """Score each document's relevance to the query as the probability
        that the model answers "yes"."""
        max_length = 8192
        # Qwen3-Reranker frames reranking as a yes/no judgment in a chat prompt.
        prefix = (
            "<|im_start|>system\nJudge whether the Table will be useful to "
            "create an SQL request to answer the Query. Note that the answer "
            'can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
        )
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
        token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

        def format_instruction(instruction, query, doc):
            if instruction is None:
                instruction = (
                    "Given a web search query, retrieve relevant passages "
                    "that answer the query"
                )
            return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

        def process_inputs(pairs):
            # Tokenize the pair bodies, reserving room for the fixed prefix
            # and suffix, then pad to a batch and move to the GPU.
            inputs = self.tokenizer(
                pairs,
                padding=False,
                truncation="longest_first",
                return_attention_mask=False,
                max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
            )
            for i, ele in enumerate(inputs["input_ids"]):
                inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
            inputs = self.tokenizer.pad(
                inputs, padding=True, return_tensors="pt", max_length=max_length
            )
            for key in inputs:
                inputs[key] = inputs[key].to(self.model.device)
            return inputs

        @torch.no_grad()
        def compute_logits(inputs):
            # Take the logits at the final position and compare the "yes" and
            # "no" tokens; softmax over that pair yields P("yes") per document.
            logits = self.model(**inputs).logits[:, -1, :]
            true_vector = logits[:, token_true_id]
            false_vector = logits[:, token_false_id]
            batch_scores = torch.stack([false_vector, true_vector], dim=1)
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            return batch_scores[:, 1].exp().tolist()

        instruction = (
            "Given a user query, find the useful tables in order to build "
            "an SQL request"
        )
        pairs = [format_instruction(instruction, query, doc) for doc in documents]

        inputs = process_inputs(pairs)
        scores = compute_logits(inputs)

        return scores
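
# Once deployed, other code can look the class up by name instead of importing
# this module; a minimal sketch, assuming the app has been deployed:
#
#     Reranker = modal.Cls.from_name("qwen3-rerank-service", "RerankerService")
#     scores = Reranker().rank.remote("my query", ["doc a", "doc b"])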

@app.function(
    image=modal.Image.debian_slim(python_version="3.12").pip_install(
        "fastapi[standard]==0.115.4"
    )
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    from fastapi import FastAPI, Request

    web_app = FastAPI()

    # The endpoint accepts a JSON body with a query and a list of candidate
    # documents, and returns one relevance score per document. The payload
    # shape ({"query": ..., "documents": [...]}) is a convention chosen here.

    @web_app.post("/predict")
    async def predict(request: Request):
        body = await request.json()
        scores = RerankerService().rank.remote(body["query"], body["documents"])
        return {"scores": scores}

    return web_app
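
# A minimal local smoke test, runnable with `modal run rerank_service.py`
# (file name assumed); the query and documents are illustrative only.
@app.local_entrypoint()
def main():
    query = "What is the capital of China?"
    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It "
        "gives weight to physical objects and is responsible for the movement "
        "of planets around the sun.",
    ]
    scores = RerankerService().rank.remote(query, documents)
    for score, doc in zip(scores, documents):
        print(f"{score:.4f}  {doc[:60]}")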