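"""Text Generation API.

A FastAPI service that wraps a Hugging Face text-generation pipeline
(EleutherAI/gpt-neo-2.7B) and serves it over HTTP. Written to run on
Hugging Face Spaces (port 7860 by default) or locally.
"""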
import torch
from transformers import pipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import os

# Initialize FastAPI app
app = FastAPI(
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)

# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    # Plain types with defaults rather than Optional[...]: an explicit
    # null from a client would otherwise reach min() and the pipeline
    # call below and crash the handler.
    max_length: int = 50
    num_return_sequences: int = 1
    temperature: float = 1.0
    do_sample: bool = True

# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str

# Global variable to store the pipeline
generator = None

# Load the model once at startup. (on_event is deprecated in newer
# FastAPI releases in favor of lifespan handlers, but it still works.)
@app.on_event("startup")
async def load_model():
    global generator
    
    # Check for GPU
    if torch.cuda.is_available():
        print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
        device = 0  # Use GPU
    else:
        print("CUDA not available, using CPU.")
        device = -1  # Use CPU
    
    # Load the text generation pipeline
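    # Note: the gpt-neo-2.7B checkpoint is roughly 10 GB in float32
    # (about half that in float16), so the first startup downloads and
    # loads several gigabytes of weights.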
    try:
        generator = pipeline(
            'text-generation', 
            model='EleutherAI/gpt-neo-2.7B', 
            device=device,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise e

@app.get("/")
async def root():
    return {
        "message": "Text Generation API", 
        "status": "running",
        "endpoints": {
            "generate": "/generate",
            "health": "/health",
            "docs": "/docs"
        }
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available()
    }

@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    
    try:
        # Generate text
        result = generator(
            request.prompt,
            max_length=min(request.max_length, 200),  # Limit max length for safety
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        
        # Only the first sequence is returned, even if
        # num_return_sequences > 1 was requested
        generated_text = result[0]['generated_text']
        
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt
        )
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

@app.get("/generate")
async def generate_text_get(
    prompt: str,
    max_length: int = 50,
    temperature: float = 1.0
):
    """GET endpoint for simple text generation"""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    
    try:
        result = generator(
            prompt,
            max_length=min(max_length, 200),
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id
        )
        
        return {
            "generated_text": result[0]['generated_text'],
            "prompt": prompt
        }
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
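
# Example requests (a sketch, assuming the server is running locally on
# the default port 7860; adjust host/port for your deployment):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time", "max_length": 60}'
#
#   curl "http://localhost:7860/generate?prompt=Once%20upon%20a%20time&max_length=60"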