import modal

MINUTES = 60  # seconds

MODEL_REPO_ID = "Qwen/Qwen3-Reranker-4B"

rerank_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "transformers==4.51.0",
        "huggingface_hub[hf_transfer]",
        "fastapi[standard]",
        "torch",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster downloads from the Hugging Face Hub
)

hf_cache_vol = modal.Volume.from_name(
    "mcp-datascientist-model-weights-vol", create_if_missing=True
)

with rerank_image.imports():
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

app = modal.App("qwen3-rerank-service")


@app.function(
    image=rerank_image,
    volumes={"/root/.cache/huggingface": hf_cache_vol},
    timeout=10 * MINUTES,  # the 4B checkpoint is a multi-GB download
)
def download_model():
    from huggingface_hub import snapshot_download

    loc = snapshot_download(repo_id=MODEL_REPO_ID)
    print(f"Saved model to {loc}")


@app.cls(
    image=rerank_image,
    gpu="A100-40GB",
    volumes={"/root/.cache/huggingface": hf_cache_vol},
)
class RerankerService:
    @modal.enter()
    def load_model(self):
        # Runs once per container start; the model is reused across requests.
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_REPO_ID, padding_side="left"
        )
        # flash_attention_2 would require the flash-attn package, which is not
        # installed in the image, so we load with the default attention implementation.
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                MODEL_REPO_ID, torch_dtype=torch.float16
            )
            .cuda()
            .eval()
        )

    @modal.method()
    def rank(self, query, documents):
        max_length = 8192
        prefix = '<|im_start|>system\nJudge whether the Table will be useful to create a SQL request to answer the Query. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
        token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

        def format_instruction(instruction, query, doc):
            if instruction is None:
                instruction = "Given a web search query, retrieve relevant passages that answer the query"
            return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

        def process_inputs(pairs):
            # Tokenize the pairs, leaving room for the chat-template prefix and suffix.
            inputs = self.tokenizer(
                pairs,
                padding=False,
                truncation="longest_first",
                return_attention_mask=False,
                max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
            )
            for i, ele in enumerate(inputs["input_ids"]):
                inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
            inputs = self.tokenizer.pad(
                inputs, padding=True, return_tensors="pt", max_length=max_length
            )
            for key in inputs:
                inputs[key] = inputs[key].to(self.model.device)
            return inputs

        @torch.no_grad()
        def compute_logits(inputs):
            # The relevance score is P("yes") under a softmax over the
            # "yes"/"no" logits at the final position.
            logits = self.model(**inputs).logits[:, -1, :]
            true_vector = logits[:, token_true_id]
            false_vector = logits[:, token_false_id]
            batch_scores = torch.stack([false_vector, true_vector], dim=1)
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            return batch_scores[:, 1].exp().tolist()

        instruction = "Given a user query, find the useful tables for building a SQL request"
        pairs = [format_instruction(instruction, query, doc) for doc in documents]
        inputs = process_inputs(pairs)
        scores = compute_logits(inputs)
        return scores


@app.function(
    image=modal.Image.debian_slim(python_version="3.12").pip_install(
        "fastapi[standard]==0.115.4"
    )
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    from fastapi import FastAPI, Request

    web_app = FastAPI()

    # The endpoint takes a JSON body of the form
    # {"query": str, "documents": [str, ...]} and returns one relevance
    # score per document.
    @web_app.post("/predict")
    async def predict(request: Request):
        body = await request.json()
        scores = RerankerService().rank.remote(body["query"], body["documents"])
        return {"scores": scores}

    return web_app
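

# A minimal local test harness, offered as a sketch only: the query and table
# snippets below are made-up examples, and the commands assume this file is
# saved as qwen3_rerank_service.py. First fetch the weights with
#   modal run qwen3_rerank_service.py::download_model
# then exercise the reranker with
#   modal run qwen3_rerank_service.py
@app.local_entrypoint()
def main():
    query = "How many orders did each customer place last month?"
    tables = [
        "Table orders: order_id, customer_id, order_date, total_amount",
        "Table customers: customer_id, name, signup_date",
        "Table suppliers: supplier_id, name, country",
    ]
    scores = RerankerService().rank.remote(query, tables)
    for table, score in zip(tables, scores):
        print(f"{score:.4f}  {table}")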