# LaMini-LM-API / main.py
import os
import logging
import torch
import gc
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="LaMini-LM API", description="API for text generation using LaMini-GPT-774M", version="1.0.0")
# Add CORS middleware to allow UI requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust for production to specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
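# A production-safe variant would pin the allowed origins instead of "*"
# (hypothetical domain shown, a sketch only):
# app.add_middleware(CORSMiddleware, allow_origins=["https://ui.example.com"],
#                    allow_credentials=True, allow_methods=["*"], allow_headers=["*"])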

class TextGenerationRequest(BaseModel):
    instruction: str
    max_length: int = 100
    temperature: float = 1.0
    top_p: float = 0.9
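# Example request body (illustrative values; omitted fields fall back to the defaults above):
# {"instruction": "Summarize the plot of Hamlet", "max_length": 150, "temperature": 0.7, "top_p": 0.9}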
# Lazily initialized text-generation pipeline; populated by load_model().
generator = None

def load_model():
    global generator
    if generator is None:
        try:
            logger.info("Loading LaMini-GPT-774M model...")
            generator = pipeline(
                'text-generation',
                model='MBZUAI/LaMini-GPT-774M',
                device=-1,  # -1 pins the pipeline to CPU
                trust_remote_code=True,
            )
            logger.info("Model loaded successfully.")
            # Release any memory held over from a previous load attempt.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            generator = None
            raise HTTPException(status_code=503, detail=f"Model loading failed: {str(e)}")
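# Optional warm-up (a sketch, not part of the original file): eagerly load the
# model at startup so the first request does not pay the load cost.
# @app.on_event("startup")
# async def warm_up():
#     load_model()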
@app.get("/api/health")
async def health_check():
return {"status": "healthy"}
@app.get("/api")
async def root():
return {"message": "Welcome to the LaMini-LM API. Use POST /generate to generate text or visit /ui for the web interface."}
@app.post("/api/generate")
async def generate_text(request: TextGenerationRequest):
logger.info(f"Received request: {request.dict()}")
if generator is None:
load_model()
if generator is None:
raise HTTPException(status_code=503, detail="Model not loaded. Check server logs.")
try:
if not request.instruction.strip():
raise HTTPException(status_code=400, detail="Instruction cannot be empty")
if request.max_length < 10 or request.max_length > 500:
raise HTTPException(status_code=400, detail="max_length must be between 10 and 500")
if request.temperature <= 0 or request.temperature > 2:
raise HTTPException(status_code=400, detail="temperature must be between 0 and 2")
if request.top_p <= 0 or request.top_p > 1:
raise HTTPException(status_code=400, detail="top_p must be between 0 and 1")
logger.info(f"Generating text for instruction: {request.instruction[:50]}...")
wrapper = "Instruction: You are a helpful assistant. Please respond to the following instruction.\n\nInstruction: {}\n\nResponse:".format(
request.instruction)
outputs = generator(
wrapper,
max_length=request.max_length,
temperature=request.temperature,
top_p=request.top_p,
num_return_sequences=1,
do_sample=True,
truncation=True
)
generated_text = outputs[0]['generated_text'].replace(wrapper, "").strip()
return {"generated_text": generated_text}
except Exception as e:
logger.error(f"Error during text generation: {str(e)}")
raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")

# Mount static files at root. This must be registered last so the catch-all
# static mount does not shadow the /api routes defined above.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
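# Example call (assuming the server is running locally on the default port 7860):
#   curl -X POST http://localhost:7860/api/generate \
#     -H "Content-Type: application/json" \
#     -d '{"instruction": "Write a haiku about the sea"}'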