# Gmail-AI-Agent / llm.py
import modal
from fastapi import HTTPException
from pydantic import BaseModel, Field
from typing import Optional, Union, List, Dict, Any
# Define the image with all required dependencies
image = (
modal.Image.debian_slim()
.pip_install([
"torch",
"transformers>=4.51.0",
"fastapi[standard]",
"accelerate",
"tokenizers"
])
)
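# Note: the transformers>=4.51.0 pin matters; earlier releases do not recognize
# the Qwen3 architecture and cannot load "Qwen/Qwen3-4B".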
app = modal.App("qwen-api", image=image)
# Request model for the API; defaults allow very long outputs
class ChatRequest(BaseModel):
message: str
max_tokens: Optional[int] = 16384 # Greatly increased token limit
temperature: Optional[float] = 0.7
top_p: Optional[float] = 0.9
strip_thinking: Optional[bool] = False # Option to strip <think> tags to save tokens
class ChatResponse(BaseModel):
response: str
    tokens_used: Optional[int] = None  # Output token count (optional)
input_tokens: Optional[int] = None # Track input tokens
model_name: str = "Qwen/Qwen3-4B" # Include model info
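# Illustrative payloads for the request/response models above (example values only,
# not from a real run):
#   request:  {"message": "Summarize this email thread", "max_tokens": 512, "strip_thinking": true}
#   response: {"response": "...", "tokens_used": 180, "input_tokens": 42, "model_name": "Qwen/Qwen3-4B"}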
# Modal class that handles model loading and inference
@app.cls(
image=image,
gpu="A10G", # Use a single A10G GPU
scaledown_window=300, # Keep container alive for 5 minutes after last use
timeout=3600, # 1 hour timeout for long running requests
enable_memory_snapshot=True, # Enable memory snapshots for faster cold starts
)
class QwenModel:
    # Use modal.enter() (rather than __init__) so the model loads once when the container starts
@modal.enter()
def setup(self):
print("Loading Qwen/Qwen3-4B model...")
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "Qwen/Qwen3-4B"
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True
)
# Load model with GPU support - use float16 for more efficient memory usage
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
print("Model loaded successfully!")
def _strip_thinking_tags(self, text: str) -> str:
"""Strip <think> sections from the response to save tokens"""
import re
# Find and remove content between <think> and </think> or end of string
return re.sub(r'<think>.*?(?:</think>|$)', '', text, flags=re.DOTALL)
@modal.method()
def generate_response(self, message: str, max_tokens: int = 16384,
temperature: float = 0.7, top_p: float = 0.9,
strip_thinking: bool = False):
"""Generate a response using the Qwen model"""
try:
import torch
# Format the message for chat
messages = [
{"role": "user", "content": message}
]
# Apply chat template
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Tokenize input
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
input_token_count = len(model_inputs.input_ids[0])
# Set parameters with very high token limits for 4B model
generation_kwargs = {
**model_inputs,
"temperature": temperature,
"top_p": top_p,
"do_sample": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": max_tokens if max_tokens is not None else 16384,
"repetition_penalty": 1.0,
}
print(f"Generating with settings: max_new_tokens={generation_kwargs.get('max_new_tokens')}")
print(f"Input token count: {input_token_count}")
# Generate response
with torch.no_grad():
generated_ids = self.model.generate(**generation_kwargs)
# Decode the response (excluding the input tokens)
input_length = model_inputs.input_ids.shape[1]
response_ids = generated_ids[0][input_length:]
response = self.tokenizer.decode(response_ids, skip_special_tokens=True)
# Optionally strip thinking tags
if strip_thinking:
response = self._strip_thinking_tags(response)
output_token_count = len(response_ids)
print(f"Generated response with {output_token_count} tokens")
return {
"response": response.strip(),
"tokens_used": output_token_count,
"input_tokens": input_token_count,
"model_name": "Qwen/Qwen3-4B"
}
except Exception as e:
print(f"Error during generation: {str(e)}")
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
# Web endpoint - single LLM interaction endpoint
@app.function(image=image, timeout=180) # Set the function timeout to 180 seconds
@modal.fastapi_endpoint(method="POST")
def chat(request: ChatRequest):
"""
Chat endpoint for Qwen3-4B model
Example usage:
curl -X POST "https://your-modal-url/" \
-H "Content-Type: application/json" \
-d '{"message": "Hello, how are you?"}'
"""
try:
print(f"Received request: message length={len(request.message)}, max_tokens={request.max_tokens}, strip_thinking={request.strip_thinking}")
# Initialize the model (this will reuse existing instance if available)
model = QwenModel()
        # Generate the response (the overall request timeout is set on the @app.function decorator above)
result = model.generate_response.remote(
message=request.message,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
strip_thinking=request.strip_thinking
)
print(f"Returning response: length={len(result['response'])}, output_tokens={result.get('tokens_used')}, input_tokens={result.get('input_tokens')}")
return ChatResponse(
response=result["response"],
tokens_used=result["tokens_used"],
input_tokens=result["input_tokens"],
model_name=result["model_name"]
)
except Exception as e:
print(f"Error in chat endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
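# Minimal smoke test for local development (a sketch, not part of the deployed API;
# the prompt and parameter values below are illustrative). Run with: modal run llm.py
@app.local_entrypoint()
def smoke_test():
    result = QwenModel().generate_response.remote(
        message="Reply to this email with a one-line acknowledgement.",
        max_tokens=128,
        strip_thinking=True,
    )
    print(result["response"])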
# Usage instructions when this file is run directly
if __name__ == "__main__":
    print("To deploy this app, run:")
    print("modal deploy llm.py")
    print("\nTo run in development mode, run:")
    print("modal serve llm.py")