import modal

from fastapi import HTTPException
from pydantic import BaseModel
from typing import Optional

image = (
    modal.Image.debian_slim()
    .pip_install([
        "torch",
        "transformers>=4.51.0",
        "fastapi[standard]",
        "accelerate",
        "tokenizers",
    ])
)

app = modal.App("qwen-api", image=image)


class ChatRequest(BaseModel):
    message: str
    max_tokens: Optional[int] = 16384
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    strip_thinking: Optional[bool] = False


class ChatResponse(BaseModel):
    response: str
    tokens_used: Optional[int] = None
    input_tokens: Optional[int] = None
    model_name: str = "Qwen/Qwen3-4B"
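

# Illustrative request/response shapes for the endpoint below (all values
# here are hypothetical; field defaults come from the models above):
#
#   POST {"message": "Hello!", "max_tokens": 1024, "strip_thinking": true}
#   ->   {"response": "...", "tokens_used": 187, "input_tokens": 12,
#         "model_name": "Qwen/Qwen3-4B"}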


@app.cls(
    image=image,
    gpu="A10G",  # a single A10G (24 GB) comfortably fits the 4B model in fp16
    scaledown_window=300,  # idle seconds before the container scales down
    timeout=3600,
    enable_memory_snapshot=True,  # snapshot CPU state to speed up cold starts
)
class QwenModel:

    @modal.enter()
    def setup(self):
        print("Loading Qwen/Qwen3-4B model...")
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        model_name = "Qwen/Qwen3-4B"

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        print("Model loaded successfully!")

    def _strip_thinking_tags(self, text: str) -> str:
        """Strip <think>...</think> sections from the response to save tokens.

        The `|$` alternative also removes a <think> block that was cut off
        by the max_new_tokens limit before its closing tag was emitted.
        """
        import re

        return re.sub(r'<think>.*?(?:</think>|$)', '', text, flags=re.DOTALL)
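
    # For example (illustrative strings, not real model output):
    #   "<think>plan the answer</think>Hi there!"  ->  "Hi there!"
    #   "<think>truncated mid-thought"             ->  ""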

    @modal.method()
    def generate_response(self, message: str, max_tokens: int = 16384,
                          temperature: float = 0.7, top_p: float = 0.9,
                          strip_thinking: bool = False):
        """Generate a response using the Qwen model"""
        try:
            import torch

            messages = [
                {"role": "user", "content": message}
            ]

            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
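            # Note: per the Qwen3 model card, apply_chat_template also accepts
            # enable_thinking=False to suppress <think> blocks at the template
            # level; this code leaves it at the default and strips them after
            # generation instead (see strip_thinking below).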

            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
            input_token_count = len(model_inputs.input_ids[0])

            generation_kwargs = {
                **model_inputs,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
                "max_new_tokens": max_tokens if max_tokens is not None else 16384,
                "repetition_penalty": 1.0,
            }

            print(f"Generating with settings: max_new_tokens={generation_kwargs.get('max_new_tokens')}")
            print(f"Input token count: {input_token_count}")

            with torch.no_grad():
                generated_ids = self.model.generate(**generation_kwargs)

            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = model_inputs.input_ids.shape[1]
            response_ids = generated_ids[0][input_length:]
            response = self.tokenizer.decode(response_ids, skip_special_tokens=True)

            if strip_thinking:
                response = self._strip_thinking_tags(response)

            output_token_count = len(response_ids)
            print(f"Generated response with {output_token_count} tokens")

            return {
                "response": response.strip(),
                "tokens_used": output_token_count,
                "input_tokens": input_token_count,
                "model_name": "Qwen/Qwen3-4B",
            }

        except Exception as e:
            print(f"Error during generation: {str(e)}")
            # Raise a plain exception here: it has to cross the Modal call
            # boundary, and the endpoint below translates it to an HTTP error.
            raise RuntimeError(f"Generation failed: {str(e)}") from e


@app.function(image=image, timeout=180)
@modal.fastapi_endpoint(method="POST")
def chat(request: ChatRequest):
    """
    Chat endpoint for the Qwen3-4B model.

    Example usage:
    curl -X POST "https://your-modal-url/" \
        -H "Content-Type: application/json" \
        -d '{"message": "Hello, how are you?"}'
    """
    try:
        print(f"Received request: message length={len(request.message)}, max_tokens={request.max_tokens}, strip_thinking={request.strip_thinking}")

        model = QwenModel()

        result = model.generate_response.remote(
            message=request.message,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            strip_thinking=request.strip_thinking,
        )

        print(f"Returning response: length={len(result['response'])}, output_tokens={result.get('tokens_used')}, input_tokens={result.get('input_tokens')}")

        return ChatResponse(
            response=result["response"],
            tokens_used=result["tokens_used"],
            input_tokens=result["input_tokens"],
            model_name=result["model_name"],
        )

    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
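

# A minimal sketch for ad-hoc testing with `modal run qwen.py` (the entrypoint
# name and defaults here are illustrative, not part of the deployed API):
@app.local_entrypoint()
def main(message: str = "Hello, how are you?"):
    result = QwenModel().generate_response.remote(message=message, max_tokens=512)
    print(result["response"])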


if __name__ == "__main__":
    print("To deploy this app, run:")
    print("modal deploy qwen.py")
    print("\nTo run in development mode, run:")
    print("modal serve qwen.py")