import modal
from fastapi import HTTPException
from pydantic import BaseModel
from typing import Optional

# Define the image with all required dependencies
image = (
    modal.Image.debian_slim()
    .pip_install([
        "torch",
        "transformers>=4.51.0",
        "fastapi[standard]",
        "accelerate",
        "tokenizers"
    ])
)

app = modal.App("qwen-api", image=image)

# Request model for the API; defaults favor long outputs
class ChatRequest(BaseModel):
    message: str
    max_tokens: Optional[int] = 16384  # Upper bound on newly generated tokens; defaults high for long outputs
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    strip_thinking: Optional[bool] = False  # Option to strip <think> tags to save tokens

class ChatResponse(BaseModel):
    response: str
    tokens_used: Optional[int] = None  # Number of newly generated tokens (if available)
    input_tokens: Optional[int] = None  # Track input tokens
    model_name: str = "Qwen/Qwen3-4B"  # Include model info
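
# Example JSON payloads for these models (illustrative sketch; the values below are
# made up and only show the expected shapes):
#   request:  {"message": "Hello", "max_tokens": 512, "temperature": 0.7,
#              "top_p": 0.9, "strip_thinking": true}
#   response: {"response": "...", "tokens_used": 42, "input_tokens": 9,
#              "model_name": "Qwen/Qwen3-4B"}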

# Modal class that loads the model once and serves inference on a GPU container
@app.cls(
    image=image,
    gpu="A10G",  # Use a single A10G GPU
    scaledown_window=300,  # Keep container alive for 5 minutes after last use
    timeout=3600,  # 1 hour timeout for long running requests
    enable_memory_snapshot=True,  # Enable memory snapshots for faster cold starts
)
class QwenModel:
    # Load model and tokenizer once per container start via @modal.enter()
    @modal.enter()
    def setup(self):
        print("Loading Qwen/Qwen3-4B model...")
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        
        model_name = "Qwen/Qwen3-4B"
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        
        # Load model with GPU support - use float16 for more efficient memory usage
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        
        print("Model loaded successfully!")
    
    def _strip_thinking_tags(self, text: str) -> str:
        """Strip <think> sections from the response to save tokens"""
        import re
        # Find and remove content between <think> and </think> or end of string
        return re.sub(r'<think>.*?(?:</think>|$)', '', text, flags=re.DOTALL)
    
    @modal.method()
    def generate_response(self, message: str, max_tokens: int = 16384, 
                         temperature: float = 0.7, top_p: float = 0.9,
                         strip_thinking: bool = False):
        """Generate a response using the Qwen model"""
        try:
            import torch
            
            # Format the message for chat
            messages = [
                {"role": "user", "content": message}
            ]
            
            # Apply chat template
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            
            # Tokenize input
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
            input_token_count = len(model_inputs.input_ids[0])
            
            # Build generation kwargs; max_new_tokens defaults high to allow long completions
            generation_kwargs = {
                **model_inputs,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
                "max_new_tokens": max_tokens if max_tokens is not None else 16384,
                "repetition_penalty": 1.0,
            }
                
            print(f"Generating with settings: max_new_tokens={generation_kwargs.get('max_new_tokens')}")
            print(f"Input token count: {input_token_count}")
            
            # Generate response
            with torch.no_grad():
                generated_ids = self.model.generate(**generation_kwargs)
            
            # Decode the response (excluding the input tokens)
            input_length = model_inputs.input_ids.shape[1]
            response_ids = generated_ids[0][input_length:]
            response = self.tokenizer.decode(response_ids, skip_special_tokens=True)
            
            # Optionally strip thinking tags
            if strip_thinking:
                response = self._strip_thinking_tags(response)
            
            output_token_count = len(response_ids)
            print(f"Generated response with {output_token_count} tokens")
            
            return {
                "response": response.strip(),
                "tokens_used": output_token_count,
                "input_tokens": input_token_count,
                "model_name": "Qwen/Qwen3-4B"
            }
            
        except Exception as e:
            # Log and re-raise as a plain error; the web endpoint maps it to an HTTP 500
            print(f"Error during generation: {e}")
            raise RuntimeError(f"Generation failed: {e}") from e

# Web endpoint: a single chat route that forwards requests to the GPU class
@app.function(image=image, timeout=180)  # The web function itself times out after 180 seconds
@modal.fastapi_endpoint(method="POST")
def chat(request: ChatRequest):
    """
    Chat endpoint for Qwen3-4B model
    
    Example usage:
    curl -X POST "https://your-modal-url/" \
         -H "Content-Type: application/json" \
         -d '{"message": "Hello, how are you?"}'
    """
    try:
        print(f"Received request: message length={len(request.message)}, max_tokens={request.max_tokens}, strip_thinking={request.strip_thinking}")
        
        # Get a handle on the model class (Modal reuses a warm container when one is available)
        model = QwenModel()

        # Run generation remotely on the GPU container
        result = model.generate_response.remote(
            message=request.message,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            strip_thinking=request.strip_thinking
        )
        
        print(f"Returning response: length={len(result['response'])}, output_tokens={result.get('tokens_used')}, input_tokens={result.get('input_tokens')}")
        
        return ChatResponse(
            response=result["response"],
            tokens_used=result["tokens_used"],
            input_tokens=result["input_tokens"],
            model_name=result["model_name"]
        )
        
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
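
# Optional smoke test (a sketch): `modal run qwen.py` invokes this entrypoint and
# calls the GPU class directly, bypassing the web layer. The prompt and token limit
# below are illustrative.
@app.local_entrypoint()
def smoke_test():
    result = QwenModel().generate_response.remote(
        message="Say hello in one short sentence.",
        max_tokens=64,
    )
    print(result["response"])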

# Usage hints when this file is run directly with Python
if __name__ == "__main__":
    print("To deploy this app, run:")
    print("modal deploy qwen.py")
    print("\nTo run in development mode, run:")
    print("modal serve qwen.py")