File size: 2,892 Bytes
919f56e
 
 
 
 
 
 
 
 
bd77c32
 
 
919f56e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import logging
import os

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Log the TRANSFORMERS_CACHE environment variable at start-up.
# This lookup needs `os` to be imported at the top of the file; without that
# import the module fails with NameError as soon as it is loaded.
# '/.cache' is only the placeholder reported when the variable is unset.
logger.info(
    "TRANSFORMERS_CACHE set to: %s",
    os.getenv('TRANSFORMERS_CACHE', '/.cache'),
)

# Application instance that the route decorators in this module register on.
app = FastAPI(
    title="LaMini-LM API",
    description="API for text generation using LaMini-GPT-774M",
    version="1.0.0",
)

# Request body schema for POST /generate.
#
# NOTE: the value ranges mentioned on the fields below are not enforced by
# this model itself; they are checked by the generate_text handler.


class TextGenerationRequest(BaseModel):
    # Text the model should respond to; the handler rejects blank or
    # whitespace-only prompts.
    prompt: str
    # Forwarded to the generation pipeline as max_length; the handler
    # accepts values from 10 to 500 inclusive.
    max_length: int = 100
    # Forwarded as the sampling temperature; the handler accepts values
    # greater than 0 and at most 2.
    temperature: float = 1.0
    # Forwarded as top_p; the handler accepts values greater than 0 and
    # at most 1.
    top_p: float = 0.9


# Build the text-generation pipeline once, at import time, so every request
# handled by this module reuses the same `generator` object.
# (Original note: "cached after first load" — presumably refers to the
# transformers download cache; confirm against the deployment.)
try:
    logger.info("Loading LaMini-GPT-774M model...")
    # device=-1 for CPU
    generator = pipeline(
        'text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1)
    logger.info("Model loaded successfully.")
except Exception as e:
    # logger.exception keeps the same message but also records the traceback,
    # which is what is needed to diagnose a failed start-up.
    logger.exception("Failed to load model: %s", e)
    # The service cannot work without the model, so abort the import.
    # RuntimeError is still caught by `except Exception`, and `from e`
    # makes the original failure the explicit cause.
    raise RuntimeError(f"Model loading failed: {str(e)}") from e


def _validate_generation_request(request: TextGenerationRequest) -> None:
    """Raise HTTPException(400) if any field of *request* is out of range."""
    if not request.prompt.strip():
        raise HTTPException(
            status_code=400, detail="Prompt cannot be empty")
    # The chained comparisons are written as "not (in range)" so that a NaN
    # value, which fails every comparison, is rejected as well.
    if not (10 <= request.max_length <= 500):
        raise HTTPException(
            status_code=400, detail="max_length must be between 10 and 500")
    if not (0 < request.temperature <= 2):
        raise HTTPException(
            status_code=400, detail="temperature must be between 0 and 2")
    if not (0 < request.top_p <= 1):
        raise HTTPException(
            status_code=400, detail="top_p must be between 0 and 1")


@app.post("/generate")
async def generate_text(request: TextGenerationRequest):
    """
    Generate text based on the input prompt using LaMini-GPT-774M.

    The prompt is wrapped in a fixed instruction template, sampled once from
    the model, and the template text is removed from the model output.

    Returns a JSON object of the form {"generated_text": "..."}.

    Errors:
    - 400 if the prompt is blank, or max_length / temperature / top_p is
      outside its accepted range.
    - 500 if the model call itself fails.
    """
    # Validate outside the try block below: that block turns every exception
    # into a 500, which would otherwise mask these 400 responses.
    _validate_generation_request(request)

    try:
        logger.info("Generating text for prompt: %s...", request.prompt[:50])
        wrapper = "Instruction: You are a helpful assistant. Please respond to the following prompt.\n\nPrompt: {}\n\nResponse:".format(
            request.prompt)
        # NOTE(review): per the transformers generation docs, max_length
        # counts prompt tokens as well as new tokens, and the wrapper above is
        # part of the prompt — small max_length values may leave no room for
        # output. Consider max_new_tokens; not changed here to keep the API
        # semantics as they are.
        # NOTE(review): this is a blocking call inside an async handler.
        outputs = generator(
            wrapper,
            max_length=request.max_length,
            temperature=request.temperature,
            top_p=request.top_p,
            num_return_sequences=1,
            do_sample=True
        )
        # The result contains the prompt text as well; strip the wrapper so
        # only the model's response is returned.
        generated_text = outputs[0]['generated_text'].replace(
            wrapper, "").strip()
    except HTTPException:
        # Already a well-formed HTTP error; pass it through untouched.
        raise
    except Exception as e:
        logger.exception("Error during text generation: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Text generation failed: {str(e)}") from e

    return {"generated_text": generated_text}


@app.get("/")
async def root():
    """Root endpoint with basic info."""
    # Static greeting that tells clients which endpoint does the real work.
    welcome = "Welcome to the LaMini-LM API. Use POST /generate to generate text."
    return {"message": welcome}