# NOTE(review): the lines that preceded this point were web-scrape artifacts
# (a file-size banner, git short hashes, and the hosting page's own gutter
# line numbers 1-98). They are not part of the program and made the file
# syntactically invalid, so they have been removed.
import os
import logging
import torch
import gc
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline

# Configure root logging once at import time; module-level logger for this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Single ASGI application instance; routes are registered below and the static
# UI is mounted at "/" at the bottom of the file.
app = FastAPI(title="LaMini-LM API", description="API for text generation using LaMini-GPT-774M", version="1.0.0")

# Add CORS middleware to allow UI requests
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the CORS spec forbids a
# wildcard origin with credentials) — confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust for production to specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class TextGenerationRequest(BaseModel):
    """Request body schema for POST /api/generate.

    Value bounds are enforced manually in the endpoint handler, not by
    pydantic validators.
    """

    instruction: str  # prompt text; handler rejects it if empty after strip()
    max_length: int = 100  # total generation length; handler enforces 10..500
    temperature: float = 1.0  # sampling temperature; handler enforces (0, 2]
    top_p: float = 0.9  # nucleus-sampling threshold; handler enforces (0, 1]

# Lazily-initialized Hugging Face text-generation pipeline, shared by requests.
generator = None

def load_model():
    """Load the LaMini-GPT-774M pipeline into the module-global `generator`.

    No-op when the model is already loaded. On failure the global is reset to
    None and an HTTPException(503) is raised so the caller can surface it.
    """
    global generator
    if generator is not None:
        return  # already loaded — nothing to do
    try:
        logger.info("Loading LaMini-GPT-774M model...")
        # device=-1 forces CPU inference; trust_remote_code is required by
        # some hub repos that ship custom modeling code.
        generator = pipeline('text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1, trust_remote_code=True)
        logger.info("Model loaded successfully.")
        # Reclaim temporary allocations made during model load.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        generator = None
        raise HTTPException(status_code=503, detail=f"Model loading failed: {str(e)}")

@app.get("/api/health")
async def health_check():
    """Liveness probe; reports healthy whenever the process is serving."""
    payload = {"status": "healthy"}
    return payload

@app.get("/api")
async def root():
    """API landing route: points clients at the generate endpoint and the UI."""
    welcome = {"message": "Welcome to the LaMini-LM API. Use POST /generate to generate text or visit /ui for the web interface."}
    return welcome

@app.post("/api/generate")
async def generate_text(request: TextGenerationRequest):
    """Generate text for an instruction using the LaMini pipeline.

    Validates parameter bounds, lazily loads the model on first use, wraps the
    instruction in a fixed prompt template, and returns the model output with
    the template stripped.

    Returns:
        dict: {"generated_text": str}

    Raises:
        HTTPException 400: a request parameter is out of bounds or empty.
        HTTPException 503: the model is not loaded and could not be loaded.
        HTTPException 500: unexpected failure during generation.
    """
    logger.info(f"Received request: {request.dict()}")
    if generator is None:
        load_model()
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check server logs.")
    try:
        # Manual bounds validation (the pydantic model carries no constraints).
        if not request.instruction.strip():
            raise HTTPException(status_code=400, detail="Instruction cannot be empty")
        if request.max_length < 10 or request.max_length > 500:
            raise HTTPException(status_code=400, detail="max_length must be between 10 and 500")
        if request.temperature <= 0 or request.temperature > 2:
            raise HTTPException(status_code=400, detail="temperature must be between 0 and 2")
        if request.top_p <= 0 or request.top_p > 1:
            raise HTTPException(status_code=400, detail="top_p must be between 0 and 1")

        logger.info(f"Generating text for instruction: {request.instruction[:50]}...")
        # NOTE(review): max_length counts prompt tokens as well, so this long
        # wrapper eats into the response budget — confirm whether
        # max_new_tokens was intended instead.
        wrapper = "Instruction: You are a helpful assistant. Please respond to the following instruction.\n\nInstruction: {}\n\nResponse:".format(
            request.instruction)
        outputs = generator(
            wrapper,
            max_length=request.max_length,
            temperature=request.temperature,
            top_p=request.top_p,
            num_return_sequences=1,
            do_sample=True,
            truncation=True
        )
        # Strip the prompt template so only the model's response is returned.
        generated_text = outputs[0]['generated_text'].replace(wrapper, "").strip()
        return {"generated_text": generated_text}
    except HTTPException:
        # Bug fix: without this re-raise, the broad handler below converted
        # the deliberate 400 validation errors above into opaque 500s.
        raise
    except Exception as e:
        logger.error(f"Error during text generation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")

# Mount static files at root (this must be last: a "/" mount registered
# earlier would shadow the /api routes defined above).
app.mount("/", StaticFiles(directory="static", html=True), name="static")

if __name__ == "__main__":
    # Imported here so merely importing this module never pulls in the server.
    import uvicorn
    # PORT is read from the environment; 7860 is presumably the Hugging Face
    # Spaces default — confirm against the deployment target.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)