import torch
from transformers import pipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
import os

# Initialize FastAPI app
app = FastAPI(
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)

# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 50
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 1.0
    do_sample: Optional[bool] = True
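
# Illustrative request body for this model (values are hypothetical examples,
# not taken from the original source):
#   {"prompt": "Once upon a time", "max_length": 50, "temperature": 0.8}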

# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str

# Global variable to store the pipeline
generator = None

# Load the model once when the application starts
@app.on_event("startup")
async def load_model():
    global generator
    # Check for GPU
    if torch.cuda.is_available():
        print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
        device = 0  # Use GPU
    else:
        print("CUDA not available, using CPU.")
        device = -1  # Use CPU
    # Load the text generation pipeline
    try:
        generator = pipeline(
            'text-generation',
            model='distilgpt2',
            device=device,
            # Half precision on GPU to save memory; full precision on CPU
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise e

@app.get("/")
async def root():
    return {
        "message": "Text Generation API",
        "status": "running",
        "endpoints": {
            "generate": "/generate",
            "health": "/health",
            "docs": "/docs"
        }
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available()
    }

@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    try:
        # Generate text
        result = generator(
            request.prompt,
            max_length=min(request.max_length, 200),  # Limit max length for safety
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        generated_text = result[0]['generated_text']
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

@app.get("/generate")
async def generate_text_get(
    prompt: str,
    max_length: int = 50,
    temperature: float = 1.0
):
    """GET endpoint for simple text generation"""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    try:
        result = generator(
            prompt,
            max_length=min(max_length, 200),
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        return {
            "generated_text": result[0]['generated_text'],
            "prompt": prompt
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
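
# Example requests (a sketch, assuming the server is running locally on the
# default port 7860; prompts and parameter values are illustrative only):
#
#   POST endpoint:
#     curl -X POST http://localhost:7860/generate \
#          -H "Content-Type: application/json" \
#          -d '{"prompt": "Once upon a time", "max_length": 50}'
#
#   GET endpoint:
#     curl "http://localhost:7860/generate?prompt=Once%20upon%20a%20time&max_length=50"
#
#   Health check:
#     curl http://localhost:7860/health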