# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
import uvicorn
from transformers import pipeline
import os
from contextlib import asynccontextmanager  # Needed for the lifespan handler
import sys  # Used for sys.exit() if model loading fails
# Optional: For gated models like Llama 3 from Meta, uncomment and configure HF_TOKEN
# from huggingface_hub import login
# --- Global variable to store the pipeline ---
generator = None
# Choose a model appropriate for free tier (e.g., 7B-8B parameters)
# For DeepSeek, DeepSeek-V2-Lite-Base (7B) might be loadable, but DeepSeek-V3 is too big.
MODEL_NAME = "brendon-ai/gemma3-dolly-finetuned"
# Alternative recommended for the free tier: "openai-community/gpt2"
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown events for the FastAPI application.
    Loads the model on startup and can optionally clean up on shutdown.
    """
    global generator
    try:
        # --- Optional: Login to Hugging Face Hub for gated models ---
        # If you are using a gated model (e.g., meta-llama/Llama-3-8B-Instruct),
        # uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
        # hf_token = os.getenv("HF_TOKEN")
        # if hf_token:
        #     login(token=hf_token)
        #     print("Logged into Hugging Face Hub.")
        # else:
        #     print("HF_TOKEN not found. Make sure it's set as a Space Secret if using a gated model.")

        # --- Startup Code: Load the model ---
        if torch.cuda.is_available():
            print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
            device = 0  # Use GPU
            # For larger models, use device_map="auto" and torch_dtype
            # device_map = "auto"
            # torch_dtype = torch.bfloat16  # or torch.float16 for GPUs that support it
        else:
            print("CUDA not available, using CPU. Inference will be very slow for this model size.")
            device = -1  # Use CPU
            # device_map = None
            # torch_dtype = torch.float32  # Default for CPU

        print(f"Attempting to load model '{MODEL_NAME}' on device: {'cuda' if device == 0 else 'cpu'}")

        # The pipeline automatically handles AutoModel and AutoTokenizer.
        # For better memory management with larger models, load the model directly
        # with model_kwargs (see the commented sketch after this function):
        generator = pipeline(
            'text-generation',
            model=MODEL_NAME,
            device=device,
            # Pass your HF token to the model loading for gated models
            # token=os.getenv("HF_TOKEN"),  # Uncomment if using a gated model
            # For 7B models on a 16 GB GPU, float16 is usually enough; bfloat16 is better if supported
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            # For more fine-grained control and automatic device mapping across multiple GPUs:
            # model_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
        )
        print("Model loaded successfully!")

        # 'yield' signifies that the startup code has completed, and the application
        # can now start processing requests.
        yield
    except Exception as e:
        print(f"CRITICAL ERROR: Failed to load model during startup: {e}")
        # Exit with a non-zero code to indicate failure if model loading fails
        sys.exit(1)
    finally:
        # --- Shutdown Code (Optional): Clean up resources ---
        print("Application shutting down. Any cleanup can go here.")
# --- Initialize FastAPI application with the lifespan handler ---
app = FastAPI(
    lifespan=lifespan,  # Use the lifespan context manager
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)
# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: Optional[int] = 250  # Changed from max_length for better control
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 0.7  # A lower temperature gives more coherent output
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9  # Added top_p for more control
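
# Example request body for POST /generate (illustrative values; every field except
# "prompt" falls back to the defaults above):
# {
#   "prompt": "Once upon a time",
#   "max_new_tokens": 100,
#   "temperature": 0.7,
#   "do_sample": true,
#   "top_p": 0.9
# }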
# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    model_name: str
@app.get("/")
async def root():
return {
"message": "Text Generation API",
"status": "running",
"endpoints": {
"generate_post": "/generate", # Renamed for clarity
"generate_get": "/generate_simple", # Renamed for clarity
"health": "/health",
"docs": "/docs"
},
"current_model": MODEL_NAME
}
@app.get("/health")
async def health_check():
return {
"status": "healthy" if generator else "unhealthy",
"model_loaded": generator is not None,
"cuda_available": torch.cuda.is_available(),
"model_name": MODEL_NAME
}
@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text_post(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
    try:
        # Generate text
        result = generator(
            request.prompt,
            max_new_tokens=request.max_new_tokens,  # Use max_new_tokens
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            top_p=request.top_p,  # Pass top_p
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
            # Add stop sequences relevant to your instruction-tuned model format
            # stop_sequences=["\nUser:", "\n###", "\n\n"]
        )
        generated_text = result[0]['generated_text']
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            model_name=MODEL_NAME
        )
    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")
@app.get("/generate_simple") # Changed endpoint name to avoid conflict with POST
async def generate_text_get(
prompt: str,
max_new_tokens: int = 250, # Changed from max_length
temperature: float = 0.7
):
"""GET endpoint for simple text generation"""
if generator is None:
raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
try:
result = generator(
prompt,
max_new_tokens=max_new_tokens,
num_return_sequences=1,
temperature=temperature,
do_sample=True,
top_p=0.9, # Default top_p for simple GET
pad_token_id=generator.tokenizer.eos_token_id,
eos_token_id=generator.tokenizer.eos_token_id,
)
return {
"generated_text": result[0]['generated_text'],
"prompt": prompt,
"model_name": MODEL_NAME
}
except Exception as e:
print(f"Generation failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
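
# Example requests against a locally running instance. The host and port (7860) are
# assumptions based on the defaults above; adjust them for your deployment:
#
#   curl -X POST "http://localhost:7860/generate" \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Once upon a time", "max_new_tokens": 50}'
#
#   curl "http://localhost:7860/generate_simple?prompt=Hello&max_new_tokens=50&temperature=0.7"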