brendon-ai committed on
Commit 88208e2 · verified · 1 Parent(s): af21bfe

Update app.py

Files changed (1)
  1. app.py +119 -58
app.py CHANGED
@@ -1,13 +1,87 @@
- import torch
- from transformers import pipeline
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from typing import Optional
- import uvicorn
  import os

- # Initialize FastAPI app
- app = FastAPI(
      title="Text Generation API",
      description="A simple text generation API using Hugging Face transformers",
      version="1.0.0"
@@ -16,43 +90,17 @@ app = FastAPI(
  # Request model
  class TextGenerationRequest(BaseModel):
      prompt: str
-     max_length: Optional[int] = 50
      num_return_sequences: Optional[int] = 1
-     temperature: Optional[float] = 1.0
      do_sample: Optional[bool] = True

  # Response model
  class TextGenerationResponse(BaseModel):
      generated_text: str
      prompt: str
-
- # Global variable to store the pipeline
- generator = None
-
- @app.on_event("startup")
- async def load_model():
-     global generator
-
-     # Check for GPU
-     if torch.cuda.is_available():
-         print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
-         device = 0  # Use GPU
-     else:
-         print("CUDA not available, using CPU.")
-         device = -1  # Use CPU
-
-     # Load the text generation pipeline
-     try:
-         generator = pipeline(
-             'text-generation',
-             model='microsoft/Phi-3-mini-4k-instruct',
-             device=device,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-         )
-         print("Model loaded successfully!")
-     except Exception as e:
-         print(f"Error loading model: {e}")
-         raise e

  @app.get("/")
  async def root():
@@ -60,74 +108,87 @@ async def root():
          "message": "Text Generation API",
          "status": "running",
          "endpoints": {
-             "generate": "/generate",
              "health": "/health",
              "docs": "/docs"
-         }
      }

  @app.get("/health")
  async def health_check():
      return {
-         "status": "healthy",
          "model_loaded": generator is not None,
-         "cuda_available": torch.cuda.is_available()
      }

  @app.post("/generate", response_model=TextGenerationResponse)
- async def generate_text(request: TextGenerationRequest):
      if generator is None:
-         raise HTTPException(status_code=503, detail="Model not loaded yet")
-
      try:
          # Generate text
          result = generator(
              request.prompt,
-             max_length=min(request.max_length, 200),  # Limit max length for safety
              num_return_sequences=request.num_return_sequences,
              temperature=request.temperature,
              do_sample=request.do_sample,
-             pad_token_id=generator.tokenizer.eos_token_id
          )

          generated_text = result[0]['generated_text']

          return TextGenerationResponse(
              generated_text=generated_text,
-             prompt=request.prompt
          )
-
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

- @app.get("/generate")
  async def generate_text_get(
      prompt: str,
-     max_length: int = 50,
-     temperature: float = 1.0
  ):
      """GET endpoint for simple text generation"""
      if generator is None:
-         raise HTTPException(status_code=503, detail="Model not loaded yet")
-
      try:
          result = generator(
              prompt,
-             max_length=min(max_length, 200),
              num_return_sequences=1,
              temperature=temperature,
              do_sample=True,
-             pad_token_id=generator.tokenizer.eos_token_id
          )

          return {
              "generated_text": result[0]['generated_text'],
-             "prompt": prompt
          }
-
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

  if __name__ == "__main__":
-     port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
      uvicorn.run(app, host="0.0.0.0", port=port)
 
+ # app.py
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from typing import Optional
+ import torch
+ from transformers import pipeline
  import os
+ import uvicorn  # Still needed for uvicorn.run() in the __main__ block below
+ from contextlib import asynccontextmanager  # For the lifespan handler
+ import sys  # For sys.exit() on startup failure
+
+ # Optional: for gated models like Llama 3 from Meta, uncomment and configure HF_TOKEN
+ # from huggingface_hub import login
+
+ # --- Global variable to store the pipeline ---
+ generator = None
+ # Choose a model appropriate for the free tier (e.g., 7B-8B parameters).
+ # For DeepSeek, DeepSeek-V2-Lite-Base (7B) might be loadable, but DeepSeek-V3 is too big.
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # Recommended for the free tier
+
+ # --- Lifespan event handler ---
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """
+     Handles startup and shutdown events for the FastAPI application.
+     Loads the model on startup and can optionally clean up on shutdown.
+     """
+     global generator
+     try:
+         # --- Optional: log in to the Hugging Face Hub for gated models ---
+         # If you are using a gated model (e.g., meta-llama/Llama-3-8B-Instruct),
+         # uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
+         # hf_token = os.getenv("HF_TOKEN")
+         # if hf_token:
+         #     login(token=hf_token)
+         #     print("Logged into Hugging Face Hub.")
+         # else:
+         #     print("HF_TOKEN not found. Make sure it's set as a Space Secret if using a gated model.")
+
+         # --- Startup: load the model ---
+         if torch.cuda.is_available():
+             print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
+             device = 0  # Use GPU
+             # For larger models, use device_map="auto" and an explicit torch_dtype:
+             # device_map = "auto"
+             # torch_dtype = torch.bfloat16  # or torch.float16 on GPUs that support it
+         else:
+             print("CUDA not available, using CPU. Inference will be very slow for this model size.")
+             device = -1  # Use CPU
+             # device_map = None
+             # torch_dtype = torch.float32  # Default for CPU
+
+         print(f"Attempting to load model '{MODEL_NAME}' on device: {'cuda' if device == 0 else 'cpu'}")
+
+         # The pipeline automatically handles AutoModel and AutoTokenizer.
+         # For better memory management with larger models, pass model_kwargs directly:
+         generator = pipeline(
+             'text-generation',
+             model=MODEL_NAME,
+             device=device,
+             # Pass your HF token to the model loading for gated models:
+             # token=os.getenv("HF_TOKEN"),
+             # For 7B models on a 16 GB GPU, float16 is usually enough; bfloat16 is better if supported.
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             # For more fine-grained control and automatic device mapping across multiple GPUs:
+             # model_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
+         )
+         print("Model loaded successfully!")
+
+         # 'yield' signifies that the startup code has completed and the application
+         # can now start processing requests.
+         yield
+
+     except Exception as e:
+         print(f"CRITICAL ERROR: Failed to load model during startup: {e}")
+         # Exit with a non-zero code to indicate failure if model loading fails.
+         sys.exit(1)
+
+     finally:
+         # --- Shutdown (optional): clean up resources ---
+         print("Application shutting down. Any cleanup can go here.")

+
+ # --- Initialize the FastAPI application with the lifespan handler ---
+ app = FastAPI(lifespan=lifespan,  # Use the lifespan context manager
      title="Text Generation API",
      description="A simple text generation API using Hugging Face transformers",
      version="1.0.0"
 
  # Request model
  class TextGenerationRequest(BaseModel):
      prompt: str
+     max_new_tokens: Optional[int] = 250  # Changed from max_length for better control
      num_return_sequences: Optional[int] = 1
+     temperature: Optional[float] = 0.7  # A lower temperature gives more coherent output
      do_sample: Optional[bool] = True
+     top_p: Optional[float] = 0.9  # Added top_p for more control

  # Response model
  class TextGenerationResponse(BaseModel):
      generated_text: str
      prompt: str
+     model_name: str

  @app.get("/")
  async def root():
 
          "message": "Text Generation API",
          "status": "running",
          "endpoints": {
+             "generate_post": "/generate",  # Renamed for clarity
+             "generate_get": "/generate_simple",  # Renamed for clarity
              "health": "/health",
              "docs": "/docs"
+         },
+         "current_model": MODEL_NAME
      }

  @app.get("/health")
  async def health_check():
      return {
+         "status": "healthy" if generator else "unhealthy",
          "model_loaded": generator is not None,
+         "cuda_available": torch.cuda.is_available(),
+         "model_name": MODEL_NAME
      }

  @app.post("/generate", response_model=TextGenerationResponse)
+ async def generate_text_post(request: TextGenerationRequest):
      if generator is None:
+         raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
+
      try:
          # Generate text
          result = generator(
              request.prompt,
+             max_new_tokens=request.max_new_tokens,  # Use max_new_tokens
              num_return_sequences=request.num_return_sequences,
              temperature=request.temperature,
              do_sample=request.do_sample,
+             top_p=request.top_p,  # Pass top_p
+             pad_token_id=generator.tokenizer.eos_token_id,
+             eos_token_id=generator.tokenizer.eos_token_id,
+             # Add stop sequences relevant to your instruction-tuned model's format, e.g.:
+             # stop_sequences=["\nUser:", "\n###", "\n\n"]
          )

          generated_text = result[0]['generated_text']

          return TextGenerationResponse(
              generated_text=generated_text,
+             prompt=request.prompt,
+             model_name=MODEL_NAME
          )
+
      except Exception as e:
+         print(f"Generation failed: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")

+ @app.get("/generate_simple")  # Changed endpoint name to avoid conflict with the POST route
  async def generate_text_get(
      prompt: str,
+     max_new_tokens: int = 250,  # Changed from max_length
+     temperature: float = 0.7
  ):
      """GET endpoint for simple text generation"""
      if generator is None:
+         raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
+
      try:
          result = generator(
              prompt,
+             max_new_tokens=max_new_tokens,
              num_return_sequences=1,
              temperature=temperature,
              do_sample=True,
+             top_p=0.9,  # Default top_p for the simple GET endpoint
+             pad_token_id=generator.tokenizer.eos_token_id,
+             eos_token_id=generator.tokenizer.eos_token_id,
          )

          return {
              "generated_text": result[0]['generated_text'],
+             "prompt": prompt,
+             "model_name": MODEL_NAME
          }
+
      except Exception as e:
+         print(f"Generation failed: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")

  if __name__ == "__main__":
+     port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
      uvicorn.run(app, host="0.0.0.0", port=port)
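
After this revision, the Space exposes /health, POST /generate, and GET /generate_simple. A minimal client sketch, assuming the `requests` package is installed and using a placeholder base URL (YOUR-SPACE-URL is hypothetical; substitute the real Space hostname):

# client_example.py - minimal, hypothetical client for the API defined above
import requests

BASE_URL = "https://YOUR-SPACE-URL"  # placeholder, not a real endpoint

# Check whether the lifespan handler has finished loading the model
print(requests.get(f"{BASE_URL}/health", timeout=30).json())

# POST /generate with the fields defined in TextGenerationRequest
payload = {
    "prompt": "Write a haiku about GPUs.",
    "max_new_tokens": 64,
    "temperature": 0.7,
    "top_p": 0.9,
    "do_sample": True,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=120)
print(resp.json()["generated_text"])

# GET /generate_simple for quick tests via query parameters
resp = requests.get(
    f"{BASE_URL}/generate_simple",
    params={"prompt": "Hello", "max_new_tokens": 32, "temperature": 0.7},
    timeout=120,
)
print(resp.json())

Until the model has loaded, both generation endpoints return HTTP 503, so polling /health first is a reasonable pattern.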