# app.py
import os
import sys
from contextlib import asynccontextmanager
from typing import Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Optional: for gated models (e.g., Llama 3 from Meta), uncomment this import
# and configure HF_TOKEN as a Space Secret.
# from huggingface_hub import login

# --- Global variable to store the pipeline ---
generator = None

# Choose a model appropriate for the free tier (roughly 7B-8B parameters or smaller).
# For DeepSeek, DeepSeek-V2-Lite-Base (~7B) may be loadable; DeepSeek-V3 is far too large.
MODEL_NAME = "brendon-ai/gemma3-dolly-finetuned"
# Recommended fallback for the free tier: "openai-community/gpt2"

# --- Lifespan event handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown events for the FastAPI application.
    Loads the model on startup and performs optional cleanup on shutdown.
    """
    global generator
    try:
        # --- Optional: log in to the Hugging Face Hub for gated models ---
        # If you are using a gated model (e.g., meta-llama/Meta-Llama-3-8B-Instruct),
        # uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
        # hf_token = os.getenv("HF_TOKEN")
        # if hf_token:
        #     login(token=hf_token)
        #     print("Logged into Hugging Face Hub.")
        # else:
        #     print("HF_TOKEN not found. Set it as a Space Secret if using a gated model.")

        # --- Startup: load the model ---
        if torch.cuda.is_available():
            print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
            device = 0  # use the first GPU
            # For larger models, prefer device_map="auto" plus an explicit torch_dtype,
            # e.g. torch.bfloat16 (or torch.float16 on GPUs without bfloat16 support).
        else:
            print("CUDA not available, using CPU. Inference will be very slow at this model size.")
            device = -1  # use the CPU

        print(f"Attempting to load model '{MODEL_NAME}' on device: {'cuda' if device == 0 else 'cpu'}")

        # The pipeline handles AutoModel and AutoTokenizer automatically.
        generator = pipeline(
            "text-generation",
            model=MODEL_NAME,
            device=device,
            # token=os.getenv("HF_TOKEN"),  # uncomment if using a gated model
            # float16 is usually enough for a 7B model on a 16GB GPU; bfloat16 is
            # preferable on hardware that supports it.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            # For finer control and automatic device mapping across multiple GPUs:
            # model_kwargs={"device_map": "auto", "torch_dtype": torch.float16},
        )
        print("Model loaded successfully!")

        # 'yield' marks the end of startup; the application can now serve requests.
        yield
    except Exception as e:
        print(f"CRITICAL ERROR: Failed to load model during startup: {e}")
        # Exit with a non-zero code so the Space reports the failed startup.
        sys.exit(1)
    finally:
        # --- Shutdown (optional): clean up resources ---
        print("Application shutting down. Any cleanup can go here.")

# --- Initialize the FastAPI application with the lifespan handler ---
app = FastAPI(
    lifespan=lifespan,
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0",
)

# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: Optional[int] = 250  # max_new_tokens gives better control than max_length
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 0.7  # lower temperatures give more coherent output
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9  # nucleus sampling cutoff

# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    model_name: str
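
# An illustrative request body for the schema above (values are examples only):
#
#   {
#       "prompt": "Write a short poem about the sea.",
#       "max_new_tokens": 120,
#       "temperature": 0.7,
#       "top_p": 0.9
#   }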

@app.get("/")
async def root():
    return {
        "message": "Text Generation API",
        "status": "running",
        "endpoints": {
            "generate_post": "/generate",
            "generate_get": "/generate_simple",
            "health": "/health",
            "docs": "/docs",
        },
        "current_model": MODEL_NAME,
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy" if generator else "unhealthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available(),
        "model_name": MODEL_NAME,
    }

@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text_post(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
    try:
        # Generate text
        result = generator(
            request.prompt,
            max_new_tokens=request.max_new_tokens,
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            top_p=request.top_p,
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
            # Add stop sequences matching your instruction-tuned model's format, e.g.:
            # stop_sequences=["\nUser:", "\n###", "\n\n"]
        )
        generated_text = result[0]["generated_text"]
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            model_name=MODEL_NAME,
        )
    except Exception as e:
        print(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}. Check the Space logs for details.")
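
# A minimal client sketch for this endpoint, assuming the `requests` package and a
# server reachable on localhost:7860 (both are assumptions, not part of this app):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Once upon a time", "max_new_tokens": 50},
#   )
#   print(resp.json()["generated_text"])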

# GET variant under a different path to avoid clashing with the POST route
@app.get("/generate_simple")
async def generate_text_get(
    prompt: str,
    max_new_tokens: int = 250,  # max_new_tokens instead of max_length, as above
    temperature: float = 0.7,
):
    """GET endpoint for simple text generation."""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
    try:
        result = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,  # default top_p for the simple GET variant
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
        )
        return {
            "generated_text": result[0]["generated_text"],
            "prompt": prompt,
            "model_name": MODEL_NAME,
        }
    except Exception as e:
        print(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}. Check the Space logs for details.")
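
# The GET variant takes the same values as query parameters. Under the same
# assumptions as the POST sketch above:
#
#   resp = requests.get(
#       "http://localhost:7860/generate_simple",
#       params={"prompt": "Hello", "max_new_tokens": 50},
#   )
#   print(resp.json()["generated_text"])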

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces serves on port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
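
# To run locally (assuming fastapi, uvicorn, torch, and transformers are installed):
#
#   python app.py
#
# or equivalently via the uvicorn CLI:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860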