# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
import uvicorn
from transformers import pipeline
import os
from contextlib import asynccontextmanager  # Needed for the lifespan handler
import sys  # Import sys for sys.exit()

# Optional: For gated models like Llama 3 from Meta, uncomment and configure HF_TOKEN
# from huggingface_hub import login

# --- Global variable to store the pipeline ---
generator = None

# Choose a model appropriate for the free tier (e.g., 7B-8B parameters).
# For DeepSeek, DeepSeek-V2-Lite-Base (7B) might be loadable, but DeepSeek-V3 is too big.
MODEL_NAME = "brendon-ai/gemma3-dolly-finetuned"  # "openai-community/gpt2" is recommended for the free tier


# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown events for the FastAPI application.
    Loads the model on startup and can optionally clean up on shutdown.
    """
    global generator
    try:
        # --- Optional: Login to Hugging Face Hub for gated models ---
        # If you are using a gated model (e.g., meta-llama/Llama-3-8B-Instruct),
        # uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
        # hf_token = os.getenv("HF_TOKEN")
        # if hf_token:
        #     login(token=hf_token)
        #     print("Logged into Hugging Face Hub.")
        # else:
        #     print("HF_TOKEN not found. Make sure it's set as a Space Secret if using a gated model.")

        # --- Startup Code: Load the model ---
        if torch.cuda.is_available():
            print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
            device = 0  # Use GPU
            # For larger models, use device_map="auto" and torch_dtype
            # device_map = "auto"
            # torch_dtype = torch.bfloat16  # or torch.float16 for GPUs that support it
        else:
            print("CUDA not available, using CPU. Inference will be very slow for this model size.")
            device = -1  # Use CPU
            # device_map = None
            # torch_dtype = torch.float32  # Default for CPU

        print(f"Attempting to load model '{MODEL_NAME}' on device: {'cuda' if device == 0 else 'cpu'}")

        # The pipeline automatically handles AutoModel and AutoTokenizer.
        # For better memory management with larger models, load with model_kwargs directly:
        generator = pipeline(
            'text-generation',
            model=MODEL_NAME,
            device=device,
            # Pass your HF token to the model loading for gated models
            # token=os.getenv("HF_TOKEN"),  # Uncomment if using a gated model
            # For 7B models on a 16GB GPU, float16 is usually enough, but bfloat16 is better if supported
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            # For more fine-grained control and automatic device mapping across multiple GPUs:
            # model_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
        )
        print("Model loaded successfully!")

        # 'yield' signifies that the startup code has completed, and the application
        # can now start processing requests.
        yield
    except Exception as e:
        print(f"CRITICAL ERROR: Failed to load model during startup: {e}")
        # Exit with a non-zero code to indicate failure if model loading fails
        sys.exit(1)
    finally:
        # --- Shutdown Code (Optional): Clean up resources ---
        print("Application shutting down. Any cleanup can go here.")


# --- Initialize FastAPI application with the lifespan handler ---
app = FastAPI(
    lifespan=lifespan,  # Use the lifespan context manager
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0",
)


# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: Optional[int] = 250  # Changed from max_length for better control
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 0.7  # Lower temperatures give more coherent output
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9  # Added top_p for more control


# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    model_name: str


@app.get("/")
async def root():
    return {
        "message": "Text Generation API",
        "status": "running",
        "endpoints": {
            "generate_post": "/generate",  # Renamed for clarity
            "generate_get": "/generate_simple",  # Renamed for clarity
            "health": "/health",
            "docs": "/docs",
        },
        "current_model": MODEL_NAME,
    }


@app.get("/health")
async def health_check():
    return {
        "status": "healthy" if generator else "unhealthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available(),
        "model_name": MODEL_NAME,
    }


@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text_post(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
    try:
        # Generate text
        result = generator(
            request.prompt,
            max_new_tokens=request.max_new_tokens,  # Use max_new_tokens
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            top_p=request.top_p,  # Pass top_p
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
            # Add stop sequences relevant to your instruction-tuned model format
            # stop_sequences=["\nUser:", "\n###", "\n\n"]
        )
        # Only the first generated sequence is returned in the response
        generated_text = result[0]['generated_text']
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            model_name=MODEL_NAME,
        )
    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")


@app.get("/generate_simple")  # Changed endpoint name to avoid conflict with the POST route
async def generate_text_get(
    prompt: str,
    max_new_tokens: int = 250,  # Changed from max_length
    temperature: float = 0.7,
):
    """GET endpoint for simple text generation."""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
    try:
        result = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,  # Default top_p for the simple GET endpoint
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
        )
        return {
            "generated_text": result[0]['generated_text'],
            "prompt": prompt,
            "model_name": MODEL_NAME,
        }
    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
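

# --- Example client calls (a minimal sketch; assumes the server is running locally on port 7860) ---
# With curl against the POST endpoint:
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Once upon a time", "max_new_tokens": 50}'
#
# Or from Python with the `requests` library:
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Once upon a time", "max_new_tokens": 50, "temperature": 0.7},
#   )
#   print(resp.json()["generated_text"])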