Update app.py
app.py
CHANGED
The previous version of app.py loaded the model in an @app.on_event("startup") hook: an async load_model() populated a module-level generator by calling pipeline('text-generation', model='microsoft/Phi-3-mini-4k-instruct', device=device, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32), selecting device 0 when CUDA was available and -1 otherwise, and re-raising any load error. TextGenerationRequest exposed prompt, num_return_sequences, temperature, and do_sample; TextGenerationResponse carried only generated_text and prompt; and the root, /health, POST /generate, and GET generation endpoints reported nothing about the loaded model. This commit replaces the startup hook with a lifespan context manager, swaps in a different model, adds max_new_tokens and top_p to the request schema, includes the model name in responses, and tightens error handling around loading and generation. The updated app.py:
# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
from transformers import pipeline
import os
import uvicorn  # used by the __main__ launcher at the bottom of the file
from contextlib import asynccontextmanager  # Import this!
import sys  # Import sys for sys.exit()

# Optional: For gated models like Llama 3 from Meta, uncomment and configure HF_TOKEN
# from huggingface_hub import login
# --- Global variable to store the pipeline ---
generator = None
# Choose a model appropriate for free tier (e.g., 7B-8B parameters)
# For DeepSeek, DeepSeek-V2-Lite-Base (7B) might be loadable, but DeepSeek-V3 is too big.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # Recommended for free tier

# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown events for the FastAPI application.
    Loads the model on startup and can optionally clean up on shutdown.
    """
    global generator
    try:
        # --- Optional: Login to Hugging Face Hub for gated models ---
        # If you are using a gated model (e.g., meta-llama/Llama-3-8B-Instruct),
        # uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
        # hf_token = os.getenv("HF_TOKEN")
        # if hf_token:
        #     login(token=hf_token)
        #     print("Logged into Hugging Face Hub.")
        # else:
        #     print("HF_TOKEN not found. Make sure it's set as a Space Secret if using a gated model.")

        # --- Startup Code: Load the model ---
        if torch.cuda.is_available():
            print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
            device = 0  # Use GPU
            # For larger models, use device_map="auto" and torch_dtype
            # device_map = "auto"
            # torch_dtype = torch.bfloat16  # or torch.float16 for GPUs that support it
        else:
            print("CUDA not available, using CPU. Inference will be very slow for this model size.")
            device = -1  # Use CPU
            # device_map = None
            # torch_dtype = torch.float32  # Default for CPU

        print(f"Attempting to load model '{MODEL_NAME}' on device: {'cuda' if device == 0 else 'cpu'}")

        # The pipeline automatically handles AutoModel and AutoTokenizer.
        # For better memory management with larger models, directly load with model_kwargs:
        generator = pipeline(
            'text-generation',
            model=MODEL_NAME,
            device=device,
            # Pass your HF token to the model loading for gated models
            # token=os.getenv("HF_TOKEN"),  # Uncomment if using a gated model
            # For 7B models on a 16GB GPU, float16 is usually enough, but bfloat16 is better if supported
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            # For more fine-grained control and auto device mapping for multiple GPUs:
            # model_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
        )
        print("Model loaded successfully!")

        # 'yield' signifies that the startup code has completed, and the application
        # can now start processing requests.
        yield

    except Exception as e:
        print(f"CRITICAL ERROR: Failed to load model during startup: {e}")
        # Exit with a non-zero code to indicate failure if model loading fails
        sys.exit(1)

    finally:
        # --- Shutdown Code (Optional): Clean up resources ---
        print("Application shutting down. Any cleanup can go here.")

# --- Initialize FastAPI application with the lifespan handler ---
app = FastAPI(
    lifespan=lifespan,  # Use the lifespan context manager
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)

# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: Optional[int] = 250  # Changed from max_length for better control
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 0.7  # Recommend lower temp for more coherent output
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9  # Added top_p for more control

# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    model_name: str

@app.get("/")
async def root():
    return {
        "message": "Text Generation API",
        "status": "running",
        "endpoints": {
            "generate_post": "/generate",  # Renamed for clarity
            "generate_get": "/generate_simple",  # Renamed for clarity
            "health": "/health",
            "docs": "/docs"
        },
        "current_model": MODEL_NAME
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy" if generator else "unhealthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available(),
        "model_name": MODEL_NAME
    }

@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text_post(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")

    try:
        # Generate text
        result = generator(
            request.prompt,
            max_new_tokens=request.max_new_tokens,  # Use max_new_tokens
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            top_p=request.top_p,  # Pass top_p
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
            # Add stop sequences relevant to your instruction-tuned model format
            # stop_sequences=["\nUser:", "\n###", "\n\n"]
        )

        generated_text = result[0]['generated_text']

        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            model_name=MODEL_NAME
        )

    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")

@app.get("/generate_simple")  # Changed endpoint name to avoid conflict with POST
async def generate_text_get(
    prompt: str,
    max_new_tokens: int = 250,  # Changed from max_length
    temperature: float = 0.7
):
    """GET endpoint for simple text generation"""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")

    try:
        result = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,  # Default top_p for simple GET
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
        )

        return {
            "generated_text": result[0]['generated_text'],
            "prompt": prompt,
            "model_name": MODEL_NAME
        }

    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
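
For reference, a minimal client sketch against the two endpoints. It is illustrative only: the base URL, prompts, and parameter values below are assumptions (a deployed Space is reached at its own URL rather than localhost), while the paths and field names come from the file above.

# client_example.py -- illustrative sketch, not part of the commit
import requests

BASE_URL = "http://localhost:7860"  # assumption: replace with your Space URL

# POST /generate with the full request schema defined above
payload = {
    "prompt": "Write a haiku about GPUs.",
    "max_new_tokens": 64,
    "num_return_sequences": 1,
    "temperature": 0.7,
    "do_sample": True,
    "top_p": 0.9,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["generated_text"])

# GET /generate_simple with query parameters
resp = requests.get(
    f"{BASE_URL}/generate_simple",
    params={"prompt": "Explain FastAPI lifespan events in one sentence.",
            "max_new_tokens": 80, "temperature": 0.7},
    timeout=300,
)
print(resp.json())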
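
One caveat worth noting: mistralai/Mistral-7B-Instruct-v0.2 is an instruction-tuned model, and the endpoints above pass the prompt to the pipeline verbatim. A bare prompt will still generate text, but results are usually better when the prompt is wrapped in the model's chat format. Below is a small sketch of how a caller (or the endpoint itself) could do that; the [INST] template is the commonly documented Mistral-Instruct format and should be verified against the model card, and the tokenizer's chat template is the more robust route.

# prompt_format_sketch.py -- illustrative only, not part of the commit
def format_instruct_prompt(user_prompt: str) -> str:
    # Assumed Mistral-Instruct template; verify against the model card.
    return f"[INST] {user_prompt} [/INST]"

# More robust alternative once `generator` is loaded: let the tokenizer build the prompt.
# messages = [{"role": "user", "content": user_prompt}]
# prompt = generator.tokenizer.apply_chat_template(
#     messages, tokenize=False, add_generation_prompt=True
# )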