import base64
import json
import types

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, validator

from endpoint_handler import EndpointHandler  # your handler file


app = FastAPI()
handler = None


@app.on_event("startup")
async def load_handler():
    # Load the (potentially heavy) model handler once at startup
    # rather than on every request.
    global handler
    handler = EndpointHandler()


class PredictInput(BaseModel):
    image: str  # base64-encoded image string
    question: str
    stream: bool = False

    @validator("question")
    def question_not_empty(cls, v):
        if not v.strip():
            raise ValueError("Question must not be empty")
        return v

    @validator("image")
    def valid_base64_and_size(cls, v):
        try:
            decoded = base64.b64decode(v, validate=True)
        except Exception:
            raise ValueError("`image` must be valid base64")
        if len(decoded) > 10 * 1024 * 1024:  # 10 MB limit
            raise ValueError("Image exceeds 10 MB after decoding")
        return v


class PredictRequest(BaseModel):
    inputs: PredictInput
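
# A minimal sketch of a request body the models above accept, assuming a
# local image file "cat.png" (hypothetical) and the `requests` library:
#
#     import base64, requests
#
#     with open("cat.png", "rb") as f:
#         b64 = base64.b64encode(f.read()).decode("utf-8")
#
#     payload = {
#         "inputs": {
#             "image": b64,                          # must decode to <= 10 MB
#             "question": "What is in this image?",  # must be non-empty
#             "stream": False,
#         }
#     }
#     resp = requests.post("http://localhost:7860/predict", json=payload)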


@app.get("/")
async def root():
    return {"message": "FastAPI app is running on Hugging Face"}


@app.post("/predict")
async def predict_endpoint(payload: PredictRequest):
    """
    Handle a prediction request by passing the validated payload to the
    handler and returning the result.

    Args:
        payload (PredictRequest): The request payload, containing the
            base64-encoded image, the question, and the stream flag.

    Returns:
        JSONResponse: The prediction result for a standard (non-streaming)
            request.
        StreamingResponse: If the handler returns a generator (streaming),
            an event-stream response yielding prediction chunks as JSON.
        JSONResponse: An error message with status code 400 on a
            ValueError, or 500 on any other exception.

    Notes:
        - Logs the received question for debugging purposes.
        - Structured JSON messages indicate the end of the stream or
          errors during streaming.
    """
print(f"[Request] Received question: {payload.inputs.question}")
data = {
"inputs": {
"image": payload.inputs.image,
"question": payload.inputs.question,
"stream": payload.inputs.stream
}
}
    try:
        result = handler.predict(data)
    except ValueError as ve:
        return JSONResponse({"error": str(ve)}, status_code=400)
    except Exception as e:
        # Log details server-side; return only a generic message to the client.
        print(f"[Error] Prediction failed: {e}")
        return JSONResponse({"error": "Internal server error"}, status_code=500)
    if isinstance(result, types.GeneratorType):
        def event_stream():
            try:
                for chunk in result:
                    yield f"data: {json.dumps(chunk)}\n\n"
                # Structured JSON to signal the end of the stream
                yield f"data: {json.dumps({'end': True})}\n\n"
            except Exception as e:
                # Structured JSON to signal an error during streaming
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # Standard (non-streaming) result, assumed to be JSON-serializable
    return JSONResponse(result)
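
# A minimal sketch of consuming the streaming endpoint with `requests`
# (hypothetical client; assumes the app is served on localhost:7860, the
# default port for a Hugging Face Space, and `b64` built as shown above):
#
#     import json, requests
#
#     resp = requests.post(
#         "http://localhost:7860/predict",
#         json={"inputs": {"image": b64, "question": "Describe it", "stream": True}},
#         stream=True,
#     )
#     for line in resp.iter_lines():
#         if line.startswith(b"data: "):
#             event = json.loads(line[len(b"data: "):])
#             if event.get("end") or event.get("error"):
#                 break
#             print(event)


if __name__ == "__main__":
    # Local development entry point; on a Space the container typically
    # launches uvicorn itself, so this guard is an optional convenience.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)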