import os
from typing import Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

app = FastAPI(
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)


class TextGenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 50
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 1.0
    do_sample: Optional[bool] = True


class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str


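# Illustrative request/response shapes for POST /generate (the values are
# made up; the fields follow the two models above):
#   request:  {"prompt": "Once upon a time", "max_length": 80, "temperature": 0.9}
#   response: {"generated_text": "Once upon a time ...", "prompt": "Once upon a time"}
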
# Populated at startup by load_model(); stays None until the model is ready.
generator = None


@app.on_event("startup")
async def load_model():
    global generator

    if torch.cuda.is_available():
        print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
        device = 0
    else:
        print("CUDA not available, using CPU.")
        device = -1

    try:
        # Half precision on GPU roughly halves memory use; CPU stays in fp32.
        generator = pipeline(
            'text-generation',
            model='EleutherAI/gpt-neo-2.7B',
            device=device,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


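# Note: @app.on_event("startup") still works but is deprecated in recent
# FastAPI releases. A minimal sketch of the lifespan-based equivalent
# (assuming FastAPI >= 0.93), kept as a comment so behavior is unchanged:
#
#   from contextlib import asynccontextmanager
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       await load_model()  # load the pipeline once before serving traffic
#       yield               # nothing to clean up on shutdown
#
#   app = FastAPI(lifespan=lifespan)

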
@app.get("/")
async def root():
    return {
        "message": "Text Generation API",
        "status": "running",
        "endpoints": {
            "generate": "/generate",
            "health": "/health",
            "docs": "/docs"
        }
    }


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available()
    }


@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        result = generator(
            request.prompt,
            # Cap max_length at 200 tokens to keep request latency bounded.
            max_length=min(request.max_length, 200),
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            pad_token_id=generator.tokenizer.eos_token_id
        )

        # Only the first sequence is returned, even if num_return_sequences > 1.
        generated_text = result[0]['generated_text']

        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


@app.get("/generate")
async def generate_text_get(
    prompt: str,
    max_length: int = 50,
    temperature: float = 1.0
):
    """GET endpoint for simple text generation."""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        result = generator(
            prompt,
            max_length=min(max_length, 200),
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id
        )

        return {
            "generated_text": result[0]['generated_text'],
            "prompt": prompt
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


if __name__ == "__main__":
    # Default to port 7860 (the Hugging Face Spaces convention) unless PORT is set.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
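
# Example usage once the server is running (port 7860 assumed, as above):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time", "max_length": 100, "temperature": 0.8}'
#
#   curl "http://localhost:7860/generate?prompt=Hello%20world&max_length=60"
#
# FastAPI also serves interactive Swagger docs at http://localhost:7860/docs.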