import torch
from transformers import pipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
import os

# Initialize FastAPI app
app = FastAPI(
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)

# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 50
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 1.0
    do_sample: Optional[bool] = True

# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str

# Global variable to store the pipeline
generator = None

@app.on_event("startup")
async def load_model():
    global generator

    # Check for GPU
    if torch.cuda.is_available():
        print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
        device = 0  # Use GPU
    else:
        print("CUDA not available, using CPU.")
        device = -1  # Use CPU

    # Load the text generation pipeline
    try:
        generator = pipeline(
            'text-generation',
            model='EleutherAI/gpt-neo-2.7B',
            device=device,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise e

@app.get("/")
async def root():
    return {
        "message": "Text Generation API",
        "status": "running",
        "endpoints": {
            "generate": "/generate",
            "health": "/health",
            "docs": "/docs"
        }
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available()
    }

@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        # Generate text
        result = generator(
            request.prompt,
            max_length=min(request.max_length, 200),  # Limit max length for safety
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        generated_text = result[0]['generated_text']
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

@app.get("/generate")
async def generate_text_get(
    prompt: str,
    max_length: int = 50,
    temperature: float = 1.0
):
    """GET endpoint for simple text generation"""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        result = generator(
            prompt,
            max_length=min(max_length, 200),
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        return {
            "generated_text": result[0]['generated_text'],
            "prompt": prompt
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
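
# --- Example usage (a minimal client sketch, not part of the server itself) ---
# Assumes the server is running locally on the default port 7860 and that the
# `requests` package is installed; adjust the base URL for a deployed Space.
#
# import requests
#
# # POST /generate with a JSON body matching TextGenerationRequest
# resp = requests.post(
#     "http://localhost:7860/generate",
#     json={"prompt": "Once upon a time", "max_length": 60, "temperature": 0.8},
# )
# resp.raise_for_status()
# print(resp.json()["generated_text"])
#
# # The GET variant takes the same parameters as a query string:
# # curl "http://localhost:7860/generate?prompt=Hello&max_length=60"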