import modal
from fastapi import HTTPException
from pydantic import BaseModel, Field
from typing import Optional, Union, List, Dict, Any

# Define the image with all required dependencies
image = (
    modal.Image.debian_slim()
    .pip_install([
        "torch",
        "transformers>=4.51.0",
        "fastapi[standard]",
        "accelerate",
        "tokenizers"
    ])
)

app = modal.App("qwen-api", image=image)


# Request model for the API - maximizing token output
class ChatRequest(BaseModel):
    message: str
    max_tokens: Optional[int] = 16384  # Greatly increased token limit
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    strip_thinking: Optional[bool] = False  # Option to strip <think> tags to save tokens


class ChatResponse(BaseModel):
    response: str
    tokens_used: Optional[int] = None  # Output token count (optional)
    input_tokens: Optional[int] = None  # Input token count (optional)
    model_name: str = "Qwen/Qwen3-4B"  # Include model info
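
# Illustrative request/response bodies for the chat endpoint below (field names come from
# the models above; the values are made-up examples, not output from a real run):
#   request:  {"message": "Hello", "max_tokens": 256, "strip_thinking": true}
#   response: {"response": "Hi there!", "tokens_used": 5, "input_tokens": 12,
#              "model_name": "Qwen/Qwen3-4B"}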


# Modal class that handles model loading and inference (current Modal class syntax)
@app.cls(
    image=image,
    gpu="A10G",  # Use a single A10G GPU
    scaledown_window=300,  # Keep the container alive for 5 minutes after the last request
    timeout=3600,  # 1-hour timeout for long-running requests
    enable_memory_snapshot=True,  # Enable memory snapshots for faster cold starts
)
class QwenModel:
    # Setup runs in modal.enter() rather than __init__ so it executes once per container start
    @modal.enter()
    def setup(self):
        print("Loading Qwen/Qwen3-4B model...")
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        model_name = "Qwen/Qwen3-4B"

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )

        # Load model onto the GPU - float16 halves memory use compared to float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        print("Model loaded successfully!")

    def _strip_thinking_tags(self, text: str) -> str:
        """Strip <think> sections from the response to save tokens."""
        import re
        # Remove everything between <think> and </think>, or to the end of the string
        # when the closing tag is missing (e.g. generation was cut off by the token limit)
        return re.sub(r'<think>.*?(?:</think>|$)', '', text, flags=re.DOTALL)
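
    # Illustrative behaviour (example strings, not real model output): with stripping enabled,
    #   "<think>Let me reason this through...</think>The answer is 4."  ->  "The answer is 4."
    # and an unterminated "<think>..." block is removed entirely.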

    @modal.method()
    def generate_response(self, message: str, max_tokens: int = 16384,
                          temperature: float = 0.7, top_p: float = 0.9,
                          strip_thinking: bool = False):
        """Generate a response using the Qwen model."""
        try:
            import torch

            # Format the message for chat
            messages = [
                {"role": "user", "content": message}
            ]

            # Apply the chat template
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Tokenize the input
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
            input_token_count = len(model_inputs.input_ids[0])

            # Generation parameters with a very high token ceiling for the 4B model
            generation_kwargs = {
                **model_inputs,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
                "max_new_tokens": max_tokens if max_tokens is not None else 16384,
                "repetition_penalty": 1.0,
            }
            print(f"Generating with settings: max_new_tokens={generation_kwargs.get('max_new_tokens')}")
            print(f"Input token count: {input_token_count}")

            # Generate the response
            with torch.no_grad():
                generated_ids = self.model.generate(**generation_kwargs)

            # Decode the response (excluding the input tokens)
            input_length = model_inputs.input_ids.shape[1]
            response_ids = generated_ids[0][input_length:]
            response = self.tokenizer.decode(response_ids, skip_special_tokens=True)

            # Optionally strip thinking tags
            if strip_thinking:
                response = self._strip_thinking_tags(response)

            output_token_count = len(response_ids)
            print(f"Generated response with {output_token_count} tokens")

            return {
                "response": response.strip(),
                "tokens_used": output_token_count,
                "input_tokens": input_token_count,
                "model_name": "Qwen/Qwen3-4B"
            }
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


# Web endpoint - single LLM interaction endpoint
@app.function(image=image, timeout=180)  # Function-level timeout of 180 seconds
@modal.fastapi_endpoint(method="POST")
def chat(request: ChatRequest):
    """
    Chat endpoint for the Qwen3-4B model.

    Example usage:
        curl -X POST "https://your-modal-url/" \
            -H "Content-Type: application/json" \
            -d '{"message": "Hello, how are you?"}'
    """
    try:
        print(f"Received request: message length={len(request.message)}, max_tokens={request.max_tokens}, strip_thinking={request.strip_thinking}")

        # Instantiate the model class (Modal reuses a warm container when one is available)
        model = QwenModel()

        # Generate the response; long generations are covered by the timeouts set on the
        # model class and on @app.function above
        result = model.generate_response.remote(
            message=request.message,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            strip_thinking=request.strip_thinking
        )
        print(f"Returning response: length={len(result['response'])}, output_tokens={result.get('tokens_used')}, input_tokens={result.get('input_tokens')}")

        return ChatResponse(
            response=result["response"],
            tokens_used=result["tokens_used"],
            input_tokens=result["input_tokens"],
            model_name=result["model_name"]
        )
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
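

# Optional smoke test (a minimal sketch added for illustration, not part of the HTTP API):
# `modal run qwen.py` executes this local entrypoint, which calls the model class directly
# and bypasses the web endpoint. The prompt and settings here are illustrative placeholders.
@app.local_entrypoint()
def smoke_test():
    result = QwenModel().generate_response.remote(
        message="Reply with a one-sentence greeting.",
        max_tokens=128,
        strip_thinking=True,
    )
    print(result["response"])
    print(f"input_tokens={result['input_tokens']}, tokens_used={result['tokens_used']}")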


# Print usage instructions when the file is executed directly (deploy or serve with Modal)
if __name__ == "__main__":
    print("To deploy this app, run:")
    print("modal deploy qwen.py")
    print("\nTo run in development mode, run:")
    print("modal serve qwen.py")