# File size: 3,384 Bytes
# 7316751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
import types
import json
from pydantic import validator
from endpoint_handler import EndpointHandler  # your handler file
import base64

# Application object; the startup hook and route decorators in this file register on it.
app = FastAPI()

# Module-level handler slot: None at import time, assigned by the startup hook.
handler = None

@app.on_event("startup")
async def load_handler():
    """Create the EndpointHandler and publish it as the module-level ``handler``.

    Registered on the application's "startup" event, so the handler is built
    when the app starts rather than when this module is imported.
    """
    global handler
    endpoint_handler = EndpointHandler()
    handler = endpoint_handler

class PredictInput(BaseModel):
    """Input fields for a single prediction request.

    Validation rules enforced on construction:
      - ``question`` must contain at least one non-whitespace character.
      - ``image`` must be strict base64 that decodes to at most 10 MB.
    """

    image: str       # base64-encoded image string
    question: str
    stream: bool = False

    @validator("question")
    def question_not_empty(cls, v):
        """Reject empty or whitespace-only questions; return ``v`` unchanged."""
        if not v.strip():
            raise ValueError("Question must not be empty")
        return v

    @validator("image")
    def valid_base64_and_size(cls, v):
        """Check that ``v`` is strict base64 decoding to at most 10 MB.

        Returns:
            The original (still encoded) string, unchanged.

        Raises:
            ValueError: If ``v`` is not valid base64 or is larger than the
                10 MB limit once decoded.
        """
        max_bytes = 10 * 1024 * 1024  # 10 MB limit on the decoded image
        too_large = "Image exceeds 10 MB after decoding"

        # Strict base64 uses exactly 4 characters per 3 decoded bytes (padded),
        # so any string longer than this bound cannot decode to <= max_bytes.
        # Rejecting it here avoids decoding an arbitrarily large payload.
        max_encoded_len = 4 * ((max_bytes + 2) // 3)
        if len(v) > max_encoded_len:
            raise ValueError(too_large)

        try:
            decoded = base64.b64decode(v, validate=True)
        except ValueError:
            # binascii.Error (bad alphabet/padding) is a ValueError subclass,
            # and non-ASCII input raises ValueError directly.
            raise ValueError("`image` must be valid base64") from None
        if len(decoded) > max_bytes:
            raise ValueError(too_large)
        return v

class PredictRequest(BaseModel):
    """Request body for ``/predict``: wraps the prediction fields under ``inputs``."""

    # Validated prediction fields (image, question, stream flag).
    inputs: PredictInput

@app.get("/")
async def root():
    """Return a fixed status message showing that the app is up."""
    status_payload = {"message": "FastAPI app is running on Hugging Face"}
    return status_payload

@app.post("/predict")
async def predict_endpoint(payload: PredictRequest):
    """Run a prediction and return it as JSON or as a server-sent event stream.

    Args:
        payload (PredictRequest): Request body whose ``inputs`` carry the
            base64 image, the question, and the stream flag.

    Returns:
        JSONResponse: ``{"error": <message>}`` with status 400 if
            ``handler.predict`` raises ``ValueError``; a generic error message
            with status 500 if it raises any other exception.
        StreamingResponse: If ``handler.predict`` returns a generator, a
            ``text/event-stream`` response that yields each chunk as a JSON
            ``data:`` event.
        Otherwise, the handler's result itself, which FastAPI serializes as
            the JSON response body.

    Notes:
        - Logs the received question for debugging purposes.
        - A final ``{"end": true}`` event marks the normal end of a stream;
          an ``{"error": ...}`` event is sent if iteration fails mid-stream.
    """
    print(f"[Request] Received question: {payload.inputs.question}")

    # Rebuild the plain-dict shape that is passed to handler.predict.
    data = {
        "inputs": {
            "image": payload.inputs.image,
            "question": payload.inputs.question,
            "stream": payload.inputs.stream,
        }
    }

    try:
        result = handler.predict(data)
    except ValueError as ve:
        return JSONResponse({"error": str(ve)}, status_code=400)
    except Exception as e:
        # Keep the client-facing message generic, but do not lose the cause.
        print(f"[Error] Prediction failed: {e!r}")
        return JSONResponse({"error": "Internal server error"}, status_code=500)

    # Non-streaming result: return it directly. Previously this path fell
    # through to the StreamingResponse below with `event_stream` undefined.
    if not isinstance(result, types.GeneratorType):
        return result

    def event_stream():
        try:
            for chunk in result:
                yield f"data: {json.dumps(chunk)}\n\n"
            # Return structured JSON to indicate end of stream
            yield f"data: {json.dumps({'end': True})}\n\n"
        except Exception as e:
            # Headers are already sent once streaming starts, so the error
            # is reported in-band as a structured JSON event.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")