|
import modal |
|
# Seconds per minute, so durations can be written as e.g. `10 * MINUTES`.
MINUTES: int = 60

# Hugging Face Hub repository that holds the reranker weights.
MODEL_REPO_ID: str = "Qwen/Qwen3-Reranker-4B"
|
# Container image for the GPU reranker and the weight-download job.
# NOTE: flash-attn is not in this package list, so the "flash_attention_2"
# attention backend is not available inside this image.
rerank_image = modal.Image.debian_slim(python_version="3.12").pip_install(
    "transformers==4.51.0",
    "huggingface_hub[hf_transfer]",
    "fastapi[standard]",
    "torch",
)
# The hf_transfer extra installed above only takes effect when this flag is set.
rerank_image = rerank_image.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
|
|
# Persistent volume holding the Hugging Face cache (model weights). It is mounted
# at /root/.cache/huggingface by both the download job and the GPU service.
# create_if_missing=True lets the first `modal deploy` / `modal run` succeed
# instead of failing because the volume has not been created by hand yet.
hf_cache_vol = modal.Volume.from_name(
    "mcp-datascientist-model-weights-vol", create_if_missing=True
)
|
|
|
with rerank_image.imports(): |
|
import torch |
|
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM |
|
|
|
# The Modal application that every function and class below is registered on.
app = modal.App(name="qwen3-rerank-service")
|
|
|
@app.function(
    image=rerank_image,
    volumes={"/root/.cache/huggingface": hf_cache_vol},
    # The weights are several GB. Modal's default function timeout (5 minutes)
    # is too tight to reliably finish the download.
    timeout=30 * MINUTES,
)
def download_model():
    """Download the reranker weights into the shared Hugging Face cache volume.

    The volume is mounted at /root/.cache/huggingface, which is the Hugging Face
    cache directory for the root user. Any function that mounts the same volume
    at the same path can then load MODEL_REPO_ID without re-downloading it.
    """
    from huggingface_hub import snapshot_download

    loc = snapshot_download(repo_id=MODEL_REPO_ID)
    print(f"Saved model to {loc}")
|
|
|
# Token budget for one complete prompt (template prefix + formatted pair + template suffix).
_RERANK_MAX_LENGTH = 8192

# Chat-template pieces wrapped around every formatted pair. They are fed to the
# model verbatim (spelling included), so editing them changes the scores.
_RERANK_PREFIX = "<|im_start|>system\nJudge whether the Table will be usefull to create an sql request to answer the Query. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
_RERANK_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

# Task instruction embedded in every (query, document) pair.
_RERANK_INSTRUCTION = "Given a user query find the usefull tables in order to build an sql request"


@app.cls(
    image=rerank_image,
    gpu="A100-40GB",
    volumes={"/root/.cache/huggingface": hf_cache_vol},
)
class RerankerService:
    """Qwen3 reranker on a GPU that scores how useful each table description is for a query.

    Weights are read through the Hugging Face cache volume mounted at
    /root/.cache/huggingface. Request-independent state is prepared once per
    container in ``load_model``: the tokenizer, the fp16 model, and the token
    ids of the prompt template and of the "yes"/"no" answers.
    """

    @modal.enter()
    def load_model(self):
        """Load the tokenizer and model onto the GPU and pre-tokenize the prompt template."""
        import importlib.util

        # Left padding keeps every sequence's last real token at position -1.
        # That is the position _score reads the yes/no logits from.
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID, padding_side="left")

        # transformers raises ImportError for "flash_attention_2" when the optional
        # flash-attn package is missing, and rerank_image does not install it.
        # Request it only when it is importable; otherwise use PyTorch SDPA.
        if importlib.util.find_spec("flash_attn") is not None:
            attn_implementation = "flash_attention_2"
        else:
            attn_implementation = "sdpa"

        self.model = (
            AutoModelForCausalLM.from_pretrained(
                MODEL_REPO_ID,
                torch_dtype=torch.float16,
                attn_implementation=attn_implementation,
            )
            .cuda()
            .eval()
        )

        # These depend only on the tokenizer, so compute them once, not per request.
        self.prefix_tokens = self.tokenizer.encode(_RERANK_PREFIX, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(_RERANK_SUFFIX, add_special_tokens=False)
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

    @modal.method()
    def rank(self, query, documents):
        """Score each document's usefulness for answering ``query``.

        Args:
            query: The user's natural-language question.
            documents: Iterable of document (table description) strings.

        Returns:
            A list of floats in [0, 1], one per document, in input order. Each
            value is the model's probability of answering "yes" rather than "no".
            An empty ``documents`` yields ``[]``.
        """
        documents = list(documents)
        if not documents:
            # An empty batch cannot be tokenized/forwarded, and there is nothing to score.
            return []

        pairs = [self._format_pair(query, doc) for doc in documents]
        inputs = self._tokenize(pairs)
        return self._score(inputs)

    def _format_pair(self, query, doc):
        """Render one (instruction, query, document) triple in the reranker's input format."""
        return f"<Instruct>: {_RERANK_INSTRUCTION}\n<Query>: {query}\n<Document>: {doc}"

    def _tokenize(self, pairs):
        """Tokenize pairs, wrap each in the template tokens, and left-pad into a GPU batch."""
        # Truncate the pair itself so that the template still fits inside the budget.
        pair_budget = _RERANK_MAX_LENGTH - len(self.prefix_tokens) - len(self.suffix_tokens)
        encoded = self.tokenizer(
            pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            max_length=pair_budget,
        )
        for i, ids in enumerate(encoded["input_ids"]):
            encoded["input_ids"][i] = self.prefix_tokens + ids + self.suffix_tokens

        batch = self.tokenizer.pad(
            encoded, padding=True, return_tensors="pt", max_length=_RERANK_MAX_LENGTH
        )
        return {key: tensor.to(self.model.device) for key, tensor in batch.items()}

    def _score(self, inputs):
        """Return P("yes") for each row, from a 2-way softmax over the last-position logits."""
        with torch.no_grad():
            last_logits = self.model(**inputs).logits[:, -1, :]
            # Column 0 = "no", column 1 = "yes".
            yes_no = torch.stack(
                [last_logits[:, self.token_false_id], last_logits[:, self.token_true_id]],
                dim=1,
            )
            log_probs = torch.nn.functional.log_softmax(yes_no, dim=1)
            return log_probs[:, 1].exp().tolist()
|
|
|
@app.function(
    image=modal.Image.debian_slim(python_version="3.12").pip_install("fastapi[standard]==0.115.4")
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    """Build the public FastAPI app that forwards rerank requests to RerankerService.

    POST /predict expects a JSON body ``{"query": str, "documents": [str, ...]}``
    and responds with ``{"scores": [float, ...]}``, one score per document in
    input order. Malformed bodies are rejected by FastAPI with a 422 response.
    """
    from fastapi import FastAPI
    from pydantic import BaseModel

    class RerankRequest(BaseModel):
        """JSON body accepted by POST /predict."""

        query: str
        documents: list[str]

    web_app = FastAPI()

    @web_app.post("/predict")
    async def predict(payload: RerankRequest):
        """Score every document in the payload against its query."""
        # .remote.aio awaits the GPU call without blocking this async event loop.
        scores = await RerankerService().rank.remote.aio(payload.query, payload.documents)
        # The scores are a list of floats. Returning a dict lets FastAPI JSON-encode
        # it; passing the list to Response(content=...) would raise.
        return {"scores": scores}

    return web_app