from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
import logging
import os
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Log cache directory
logger.info(f"TRANSFORMERS_CACHE set to: {os.getenv('TRANSFORMERS_CACHE', '/.cache')}")
app = FastAPI(
    title="LaMini-LM API",
    description="API for text generation using LaMini-GPT-774M",
    version="1.0.0",
)
# Define request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_length: int = 100
    temperature: float = 1.0
    top_p: float = 0.9
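
# Example request body (hypothetical values):
# {"prompt": "Explain what an API is.", "max_length": 120,
#  "temperature": 0.7, "top_p": 0.9}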

# Load model (cached after first load)
try:
    logger.info("Loading LaMini-GPT-774M model...")
    # device=-1 runs the pipeline on CPU
    generator = pipeline(
        'text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1)
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load model: {str(e)}")
    raise RuntimeError(f"Model loading failed: {str(e)}") from e
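
# Note: the first load downloads the 774M-parameter weights into the
# transformers cache directory; later restarts reuse the cached files.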

@app.post("/generate")
async def generate_text(request: TextGenerationRequest):
    """
    Generate text based on the input prompt using LaMini-GPT-774M.
    """
    try:
        # Validate inputs
        if not request.prompt.strip():
            raise HTTPException(
                status_code=400, detail="Prompt cannot be empty")
        if request.max_length < 10 or request.max_length > 500:
            raise HTTPException(
                status_code=400, detail="max_length must be between 10 and 500")
        if request.temperature <= 0 or request.temperature > 2:
            raise HTTPException(
                status_code=400, detail="temperature must be between 0 and 2")
        if request.top_p <= 0 or request.top_p > 1:
            raise HTTPException(
                status_code=400, detail="top_p must be between 0 and 1")
        # Generate text
        logger.info(f"Generating text for prompt: {request.prompt[:50]}...")
        wrapper = (
            "Instruction: You are a helpful assistant. Please respond to the "
            f"following prompt.\n\nPrompt: {request.prompt}\n\nResponse:"
        )
        outputs = generator(
            wrapper,
            max_length=request.max_length,
            temperature=request.temperature,
            top_p=request.top_p,
            num_return_sequences=1,
            do_sample=True
        )
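        # Note: in transformers pipelines, max_length counts the prompt tokens
        # plus the newly generated tokens; max_new_tokens would cap only the
        # completion itself.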
        generated_text = outputs[0]['generated_text'].replace(
            wrapper, "").strip()
        return {"generated_text": generated_text}
    except HTTPException:
        # Re-raise validation errors so they keep their 400 status codes
        raise
    except Exception as e:
        logger.error(f"Error during text generation: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Text generation failed: {str(e)}")
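
# Example call once the server is running (hypothetical prompt; assumes the
# server is listening on localhost:8000):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about the sea.", "max_length": 80}'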

@app.get("/")
async def root():
    """
    Root endpoint with basic info.
    """
    return {"message": "Welcome to the LaMini-LM API. Use POST /generate to generate text."}