import torch
from transformers import pipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
import os

# Initialize FastAPI app
app = FastAPI(
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)

# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 50
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 1.0
    do_sample: Optional[bool] = True
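
# Example JSON body accepted by TextGenerationRequest (illustrative values,
# not from the original file):
#   {"prompt": "The meaning of life is", "max_length": 60,
#    "num_return_sequences": 1, "temperature": 0.9, "do_sample": true}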

# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str

# Global variable to store the pipeline
generator = None


@app.on_event("startup")
async def load_model():
    """Load the text-generation pipeline once when the server starts."""
    global generator

    # Check for a GPU
    if torch.cuda.is_available():
        print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
        device = 0  # use the first GPU
    else:
        print("CUDA not available, using CPU.")
        device = -1  # CPU

    # Load the text generation pipeline
    try:
        generator = pipeline(
            'text-generation',
            model='EleutherAI/gpt-neo-2.7B',
            device=device,
            # Half precision on GPU to reduce memory use; full precision on CPU
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise  # re-raise without losing the traceback so startup fails loudly
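
# Note (rough figure, not from the original file): gpt-neo-2.7B needs on the
# order of ~5 GB of GPU memory in float16. For quick CPU-only smoke tests,
# swapping in a smaller checkpoint such as 'EleutherAI/gpt-neo-125m' loads
# far faster; treat that exact model id as an assumption to verify on the Hub.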
@app.get("/")
async def root():
return {
"message": "Text Generation API",
"status": "running",
"endpoints": {
"generate": "/generate",
"health": "/health",
"docs": "/docs"
}
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model_loaded": generator is not None,
"cuda_available": torch.cuda.is_available()
}
@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
if generator is None:
raise HTTPException(status_code=503, detail="Model not loaded yet")
try:
# Generate text
result = generator(
request.prompt,
max_length=min(request.max_length, 200), # Limit max length for safety
num_return_sequences=request.num_return_sequences,
temperature=request.temperature,
do_sample=request.do_sample,
pad_token_id=generator.tokenizer.eos_token_id
)
generated_text = result[0]['generated_text']
return TextGenerationResponse(
generated_text=generated_text,
prompt=request.prompt
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
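
# Example POST /generate exchange (illustrative, matching the models above):
#   request:  {"prompt": "Once upon a time", "max_length": 60}
#   response: {"generated_text": "Once upon a time ...", "prompt": "Once upon a time"}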
@app.get("/generate")
async def generate_text_get(
prompt: str,
max_length: int = 50,
temperature: float = 1.0
):
"""GET endpoint for simple text generation"""
if generator is None:
raise HTTPException(status_code=503, detail="Model not loaded yet")
try:
result = generator(
prompt,
max_length=min(max_length, 200),
num_return_sequences=1,
temperature=temperature,
do_sample=True,
pad_token_id=generator.tokenizer.eos_token_id
)
return {
"generated_text": result[0]['generated_text'],
"prompt": prompt
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
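
# Example requests once the server is up (illustrative; assumes it is
# reachable at http://localhost:7860, so adjust host/port to your deployment):
#
#   curl http://localhost:7860/health
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time", "max_length": 60, "temperature": 0.8}'
#
#   curl "http://localhost:7860/generate?prompt=Hello%20world&max_length=40"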