import torch
from transformers import pipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
import os
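
# Dependencies (assumed, not pinned in the source): fastapi, uvicorn,
# pydantic, torch, and transformers, e.g. `pip install fastapi uvicorn transformers torch`.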

# Initialize FastAPI app
app = FastAPI(
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)


# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 50
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 1.0
    do_sample: Optional[bool] = True


# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str


# Global variable to store the pipeline
generator = None


@app.on_event("startup")
async def load_model():
    global generator
    # Check for GPU
    if torch.cuda.is_available():
        print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
        device = 0  # Use GPU
    else:
        print("CUDA not available, using CPU.")
        device = -1  # Use CPU
    # Load the text generation pipeline
    try:
        generator = pipeline(
            'text-generation',
            model='EleutherAI/gpt-neo-2.7B',
            device=device,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
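
# Note: gpt-neo-2.7B needs roughly 5-6 GB of GPU memory in float16 (about
# twice that in float32 on CPU); for quick local testing, a smaller checkpoint
# such as 'EleutherAI/gpt-neo-125m' can be swapped in above.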
@app.get("/")
async def root():
return {
"message": "Text Generation API",
"status": "running",
"endpoints": {
"generate": "/generate",
"health": "/health",
"docs": "/docs"
}
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model_loaded": generator is not None,
"cuda_available": torch.cuda.is_available()
}


@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    try:
        # Generate text; the `or 50` guards against an explicit JSON null,
        # since max_length is declared Optional
        result = generator(
            request.prompt,
            max_length=min(request.max_length or 50, 200),  # Limit max length for safety
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        # Only the first sequence is returned, even if more were generated
        generated_text = result[0]['generated_text']
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
@app.get("/generate")
async def generate_text_get(
prompt: str,
max_length: int = 50,
temperature: float = 1.0
):
"""GET endpoint for simple text generation"""
if generator is None:
raise HTTPException(status_code=503, detail="Model not loaded yet")
try:
result = generator(
prompt,
max_length=min(max_length, 200),
num_return_sequences=1,
temperature=temperature,
do_sample=True,
pad_token_id=generator.tokenizer.eos_token_id
)
return {
"generated_text": result[0]['generated_text'],
"prompt": prompt
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
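
# Example requests against a local run (assuming the defaults above):
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Once upon a time", "max_length": 60}'
#   curl "http://localhost:7860/generate?prompt=Once%20upon%20a%20time&max_length=60"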