from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
import asyncio
import json
import re
from typing import Dict, Any, Optional
import logging
import traceback

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Qwen3 API", description="OpenAI-compatible API for Qwen3 models", version="1.0.0")

# Global variables
models = {}
tokenizers = {}

MODEL_CONFIGS = {
    "qwen3-1.7b": "Qwen/Qwen3-1.7B",
    "qwen3-4b": "Qwen/Qwen3-4B"
}

def download_model_safely(model_name: str, max_retries: int = 3):
    """Download a model and tokenizer with retry logic."""
    for attempt in range(max_retries):
        try:
            logger.info(f"Downloading {model_name} (attempt {attempt + 1}/{max_retries})...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            logger.info(f"Successfully loaded {model_name}")
            return tokenizer, model
        except Exception as e:
            logger.error(f"Download failed (attempt {attempt + 1}): {str(e)}")
            if attempt == max_retries - 1:
                raise
            time.sleep(30)

def load_model_on_demand(model_key: str):
    """Load a model on demand, unloading any previously loaded model first."""
    if model_key not in models:
        if model_key not in MODEL_CONFIGS:
            raise ValueError(f"Unknown model key: {model_key}")

        model_name = MODEL_CONFIGS[model_key]
        logger.info(f"Loading {model_name} on demand...")

        # Free memory by unloading any currently loaded models
        if len(models) >= 1:
            for key in list(models.keys()):
                logger.info(f"Unloading {key} to free memory...")
                del models[key]
                del tokenizers[key]
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            import gc
            gc.collect()

        tokenizer, model = download_model_safely(model_name)
        tokenizers[model_key] = tokenizer
        models[model_key] = model
        logger.info(f"{model_name} loaded successfully!")

def extract_json_from_response(text: str) -> str:
    """Extract a JSON object from the response text."""
    # Remove thinking tags completely
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    text = text.strip()

    # Try to find a JSON object
    json_match = re.search(r'\{[^{}]*\}', text)
    if json_match:
        return json_match.group(0)

    # If no JSON found, return the cleaned text
    return text

def format_structured_prompt(messages: list, json_schema: dict) -> list:
    """Format messages with JSON schema instructions."""
    # Extract schema properties for clear instructions
    schema_info = json_schema.get('schema', {})
    properties = schema_info.get('properties', {})
    required = schema_info.get('required', [])

    # Create clear JSON format instructions
    json_instructions = f"""
You must respond with a valid JSON object only. No explanations, no markdown, no additional text.
Required JSON format:
{json.dumps(schema_info, indent=2)}
Example response format: {{"type": "examschedule"}}
"""

    # Build the conversation
    formatted_messages = []
    for msg in messages:
        if msg["role"] == "system":
            # Append JSON instructions to the system message
            content = msg["content"] + "\n" + json_instructions
            formatted_messages.append({"role": "system", "content": content})
        else:
            formatted_messages.append(msg)

    return formatted_messages

@app.on_event("startup")  # assumed: the default model is loaded when the server starts
async def load_models():
    """Load the default model at startup."""
    try:
        logger.info("Loading default model: Qwen3-1.7B...")
        tokenizer, model = download_model_safely("Qwen/Qwen3-1.7B")
        tokenizers["qwen3-1.7b"] = tokenizer
        models["qwen3-1.7b"] = model
        logger.info("Default model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load default model: {str(e)}")
        logger.info("Server will continue running, models will be loaded on demand")

@app.get("/")  # assumed route: root info/health endpoint
def health_check():
    """Health check endpoint."""
    return {
        "status": "API is running",
        "available_models": list(MODEL_CONFIGS.keys()),
        "loaded_models": list(models.keys()),
        "version": "1.0.0",
        "message": "Qwen3 API Service - OpenAI Compatible with Structured Output"
    }

@app.get("/models")  # assumed route: list available and loaded models
def list_models():
    """List available models."""
    return {
        "available_models": MODEL_CONFIGS,
        "loaded_models": list(models.keys()),
        "total_available": len(MODEL_CONFIGS),
        "total_loaded": len(models)
    }

@app.post("/v1/chat/completions")  # assumed route: standard OpenAI-compatible path
async def chat_completions(request: Dict[str, Any]):
    """OpenAI-compatible chat completions endpoint with structured output support."""
    try:
        logger.info("=== CHAT COMPLETIONS REQUEST START ===")
        logger.info(f"Request payload: {json.dumps(request, ensure_ascii=False, indent=2)}")

        # Parse request parameters
        model_name = request.get("model", "qwen3-1.7b")
        messages = request.get("messages", [])
        temperature = request.get("temperature", 0.7)
        max_tokens = request.get("max_tokens", 200)
        response_format = request.get("response_format", None)

        logger.info(f"Model: {model_name}, Temperature: {temperature}, Max tokens: {max_tokens}")
        logger.info(f"Response format: {response_format}")

        # Validate input
        if not messages:
            logger.error("Messages is empty")
            raise HTTPException(status_code=400, detail="Messages cannot be empty")

        # Determine model key
        if "4b" in model_name.lower() or "4" in model_name.lower():
            model_key = "qwen3-4b"
        else:
            model_key = "qwen3-1.7b"
        logger.info(f"Using model key: {model_key}")

        # Load model if needed
        if model_key not in models:
            logger.info(f"Model {model_key} not loaded, loading on demand...")
            load_model_on_demand(model_key)

        # Get model and tokenizer
        tokenizer = tokenizers[model_key]
        model = models[model_key]
        logger.info(f"Got tokenizer and model for {model_key}")

        # Handle structured output
        if response_format and response_format.get("type") == "json_schema":
            json_schema = response_format.get("json_schema", {})
            logger.info("Structured output requested, formatting messages with JSON schema")
            messages = format_structured_prompt(messages, json_schema)

        # Format messages - force-disable thinking mode
        logger.info("Formatting messages with apply_chat_template...")
        try:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False  # CRITICAL: force-disable thinking
            )

            # Aggressively remove any remaining thinking tags
            if "<think>" in text or "think>" in text:
                logger.warning("Found thinking tags in formatted text, removing...")
                text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
                text = re.sub(r'<think>\s*</think>', '', text)
                text = text.replace("<think>", "").replace("</think>", "")

            logger.info(f"Formatted text (first 300 chars): {text[:300]}...")
        except Exception as e:
            logger.error(f"Error in apply_chat_template: {str(e)}")
            # Fallback to simple ChatML formatting WITHOUT thinking tags
            text = ""
            for msg in messages:
                if msg["role"] == "system":
                    text += f"<|im_start|>system\n{msg['content']}<|im_end|>\n"
                elif msg["role"] == "user":
                    text += f"<|im_start|>user\n{msg['content']}<|im_end|>\n"
                elif msg["role"] == "assistant":
                    text += f"<|im_start|>assistant\n{msg['content']}<|im_end|>\n"
            text += "<|im_start|>assistant\n"  # no thinking tags
            logger.info("Using fallback formatting")

        # Tokenize input
        logger.info("Tokenizing input...")
        model_inputs = tokenizer([text], return_tensors="pt")
        logger.info(f"Input tokens shape: {model_inputs.input_ids.shape}")

        # Move to device
        if hasattr(model, 'device'):
            logger.info(f"Moving inputs to device: {model.device}")
            model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}

        # Generate response with a timeout
        logger.info("Starting generation...")
        start_time = time.time()

        try:
            # Use an asyncio timeout. Note: generate() blocks the event loop, so the
            # timeout cannot interrupt a generation that is already running; it only
            # takes effect once control returns to the loop.
            async def generate_with_timeout():
                with torch.no_grad():
                    generated_ids = model.generate(
                        **model_inputs,
                        max_new_tokens=min(max_tokens, 200),
                        temperature=temperature,
                        do_sample=True if temperature > 0 else False,
                        pad_token_id=tokenizer.eos_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        repetition_penalty=1.1,
                        top_p=0.9 if temperature > 0 else None,
                        use_cache=True
                    )
                return generated_ids

            # 30 second timeout
            generated_ids = await asyncio.wait_for(generate_with_timeout(), timeout=30.0)
            generation_time = time.time() - start_time
            logger.info(f"Generation completed in {generation_time:.2f} seconds")
        except asyncio.TimeoutError:
            logger.error("Generation timeout after 30 seconds")
            return {
                "choices": [{
                    "message": {
                        "content": "Generation timeout. Please try a shorter prompt.",
                        "role": "assistant"
                    },
                    "finish_reason": "timeout",
                    "index": 0
                }],
                "error": "timeout",
                "model": model_key
            }
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return {
                "choices": [{
                    "message": {
                        "content": f"Generation error: {str(e)}",
                        "role": "assistant"
                    },
                    "finish_reason": "error",
                    "index": 0
                }],
                "error": str(e),
                "model": model_key
            }

        # Extract response
        logger.info("Extracting response...")
        try:
            # Get the input length (handles both BatchEncoding and plain dict inputs)
            if hasattr(model_inputs, 'input_ids'):
                input_length = model_inputs.input_ids.shape[1]
            elif isinstance(model_inputs, dict) and 'input_ids' in model_inputs:
                input_length = model_inputs['input_ids'].shape[1]
            else:
                input_length = 0

            # Keep only the newly generated tokens
            output_ids = generated_ids[0][input_length:].tolist()
            response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()

            # Handle structured output
            if response_format and response_format.get("type") == "json_schema":
                response = extract_json_from_response(response)
                logger.info(f"Extracted JSON response: {response}")

                # Validate JSON
                try:
                    json.loads(response)
                except json.JSONDecodeError:
                    logger.warning("Generated response is not valid JSON, attempting to fix...")
                    # Try to extract just the JSON part
                    json_match = re.search(r'\{.*\}', response)
                    if json_match:
                        response = json_match.group(0)
                    else:
                        response = '{"type": "other"}'  # fallback

            logger.info(f"Final response: {response}")
        except Exception as e:
            logger.error(f"Error extracting response: {str(e)}")
            response = "Error extracting response"

        # Clean up response
        if not response:
            response = "I apologize, but I couldn't generate a proper response. Please try again."

        # Format response - compatible with AiService
        result = {
            "choices": [{
                "message": {
                    "content": response,
                    "role": "assistant"
                },
                "finish_reason": "stop",
                "index": 0
            }],
            "model": model_key,
            "usage": {
                "prompt_tokens": input_length if 'input_length' in locals() else 0,
                "completion_tokens": len(output_ids) if 'output_ids' in locals() else 0,
                "total_tokens": (input_length if 'input_length' in locals() else 0) + (len(output_ids) if 'output_ids' in locals() else 0)
            },
            "object": "chat.completion",
            "created": int(time.time())
        }

        logger.info("=== CHAT COMPLETIONS REQUEST END ===")
        return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error in chat_completions: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return {
            "choices": [{
                "message": {
                    "content": f"Unexpected error: {str(e)}",
                    "role": "assistant"
                },
                "finish_reason": "error",
                "index": 0
            }],
            "error": str(e),
            "model": "qwen3-1.7b"
        }

@app.get("/health")  # assumed route: simple liveness check
def health():
    """Simple health check."""
    return {
        "status": "healthy",
        "timestamp": int(time.time()),
        "models_loaded": len(models)
    }

# Error handlers
async def not_found_handler(request, exc):
    return JSONResponse(
        status_code=404,
        content={
            "error": {
                "message": "Endpoint not found",
                "type": "not_found_error",
                "code": 404
            }
        }
    )

async def internal_error_handler(request, exc):
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "message": "Internal server error",
                "type": "internal_server_error",
                "code": 500
            }
        }
    )
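
# The original listing does not show these handlers being registered anywhere;
# presumably they are meant to be attached to the app, e.g. via Starlette's
# add_exception_handler (status-code based registration shown here as a sketch):
app.add_exception_handler(404, not_found_handler)
app.add_exception_handler(500, internal_error_handler)
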
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
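
# Example request (illustrative only; run from a separate client script, not this file).
# The route path and port follow the assumptions above; the schema fields mirror the
# shape that format_structured_prompt() reads ("schema" -> "properties"/"required"):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={
#           "model": "qwen3-1.7b",
#           "messages": [
#               {"role": "system", "content": "Classify the user's request."},
#               {"role": "user", "content": "When is my final exam?"}
#           ],
#           "temperature": 0.0,
#           "max_tokens": 50,
#           "response_format": {
#               "type": "json_schema",
#               "json_schema": {
#                   "schema": {
#                       "type": "object",
#                       "properties": {"type": {"type": "string"}},
#                       "required": ["type"]
#                   }
#               }
#           }
#       }
#   )
#   print(resp.json()["choices"][0]["message"]["content"])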