sreejith8100 committed
Commit 7316751 · verified · 1 Parent(s): e598ee7

Upload 4 files

Files changed (4)
  1. Dockerfile +22 -0
  2. endpoint_handler.py +95 -0
  3. main.py +94 -0
  4. requirements.txt +19 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
+ FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
+
+ RUN apt-get update && apt-get install -y wget
+ RUN useradd -m -u 1000 user
+
+ USER user
+ WORKDIR /app
+
+ ENV PATH="/home/user/.local/bin:$PATH"
+ ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface
+ ENV TORCH_CUDA_ARCH_LIST="8.0+PTX"
+
+ RUN wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.4/flash_attn-2.7.3+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+ RUN pip install ./flash_attn-2.7.3+cu121torch2.3-cp310-cp310-linux_x86_64.whl && rm flash_attn-2.7.3+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+
+ COPY --chown=user requirements.txt .
+ RUN pip install --upgrade pip setuptools wheel
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY --chown=user . .
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
endpoint_handler.py ADDED
@@ -0,0 +1,95 @@
+ import torch
+ from PIL import Image
+ from transformers import AutoModel, AutoTokenizer
+ from io import BytesIO
+ import base64
+ import os
+ from huggingface_hub import login
+
+ class EndpointHandler:
+     def __init__(self, model_dir=None):
+         print("[Init] Initializing EndpointHandler...")
+         self.load_model()
+
+     def load_model(self):
+         hf_token = os.getenv("HF_TOKEN")
+         model_path = "openbmb/MiniCPM-o-2_6"  # use the model repo name directly
+
+         if hf_token:
+             print("[Auth] Logging into Hugging Face Hub with token...")
+             login(token=hf_token)
+
+         print(f"[Model Load] Loading model from: {model_path}")
+         try:
+             self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+             self.model = AutoModel.from_pretrained(
+                 model_path,
+                 trust_remote_code=True,
+                 attn_implementation='sdpa',
+                 torch_dtype='auto',  # safer on Spaces
+                 init_vision=True,
+                 init_audio=False,
+                 init_tts=False
+             ).eval().cuda()
+             print("[Model Load] Model successfully loaded and moved to CUDA.")
+         except Exception as e:
+             print(f"[Model Load Error] {e}")
+             raise RuntimeError(f"Failed to load model: {e}")
+
+     def load_image(self, image_base64):
+         try:
+             print("[Image Load] Decoding base64 image...")
+             image_bytes = base64.b64decode(image_base64)
+             image = Image.open(BytesIO(image_bytes)).convert("RGB")
+             print("[Image Load] Image successfully decoded and converted to RGB.")
+             return image
+         except Exception as e:
+             print(f"[Image Load Error] {e}")
+             raise ValueError(f"Failed to open image from base64 string: {e}")
+
+     def predict(self, request):
+         print(f"[Predict] Received request: {request}")
+
+         inputs = request.get("inputs", {})
+         image_base64 = inputs.get("image")
+         question = inputs.get("question")
+         stream = inputs.get("stream", False)
+
+         if not image_base64 or not question:
+             print("[Predict Error] Missing 'image' or 'question' in the request.")
+             return {"error": "Missing 'image' or 'question' in inputs."}
+
+         try:
+             image = self.load_image(image_base64)
+             msgs = [{"role": "user", "content": [image, question]}]
+
+             print(f"[Predict] Asking model with question: {question}")
+             print("[Predict] Starting chat inference...")
+
+             res = self.model.chat(
+                 image=None,
+                 msgs=msgs,
+                 tokenizer=self.tokenizer,
+                 sampling=True,
+                 stream=stream
+             )
+
+             if stream:
+                 # Nested generator keeps predict() a plain function; a bare yield
+                 # here would make every call return a generator, even with stream=False.
+                 def token_stream():
+                     for new_text in res:
+                         yield {"output": new_text}
+                 return token_stream()
+
+             generated_text = "".join(res)
+             print("[Predict] Inference complete.")
+             return {"output": generated_text}
+
+         except Exception as e:
+             print(f"[Predict Error] {e}")
+             return {"error": str(e)}
+
+     def __call__(self, data):
+         print("[__call__] Invoked handler with data.")
+         return self.predict(data)
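
For a quick local smoke test of the handler outside FastAPI, base64-encode any image and call the instance directly. A minimal sketch, assuming a GPU, access to the model weights, and a hypothetical example.jpg:

    # Hypothetical local smoke test; example.jpg is an assumption.
    import base64
    from endpoint_handler import EndpointHandler

    handler = EndpointHandler()  # downloads openbmb/MiniCPM-o-2_6 on first run
    with open("example.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    result = handler({"inputs": {"image": image_b64, "question": "What is in this image?"}})
    print(result)  # {"output": "..."} on success, {"error": "..."} on failure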
main.py ADDED
@@ -0,0 +1,94 @@
+ from fastapi import FastAPI
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from pydantic import BaseModel, validator  # Pydantic v1-style API; use field_validator on v2
+ import types
+ import json
+ import base64
+
+ from endpoint_handler import EndpointHandler  # the handler defined above
+
+ app = FastAPI()
+
+ handler = None
+
+ @app.on_event("startup")
+ async def load_handler():
+     global handler
+     handler = EndpointHandler()
+
+ class PredictInput(BaseModel):
+     image: str  # base64-encoded image string
+     question: str
+     stream: bool = False
+
+     @validator("question")
+     def question_not_empty(cls, v):
+         if not v.strip():
+             raise ValueError("Question must not be empty")
+         return v
+
+     @validator("image")
+     def valid_base64_and_size(cls, v):
+         try:
+             decoded = base64.b64decode(v, validate=True)
+         except Exception:
+             raise ValueError("`image` must be valid base64")
+         if len(decoded) > 10 * 1024 * 1024:  # 10 MB limit
+             raise ValueError("Image exceeds 10 MB after decoding")
+         return v
+
+ class PredictRequest(BaseModel):
+     inputs: PredictInput
+
+ @app.get("/")
+ async def root():
+     return {"message": "FastAPI app is running on Hugging Face"}
+
+ @app.post("/predict")
+ async def predict_endpoint(payload: PredictRequest):
+     """
+     Handle prediction requests by passing the validated payload to the handler.
+
+     Args:
+         payload (PredictRequest): Request body carrying the image, question, and stream flag.
+
+     Returns:
+         JSONResponse: The prediction result, or an error with status 400 (ValueError)
+             or 500 (any other exception).
+         StreamingResponse: If the handler returns a generator (streaming), an
+             event-stream response yielding prediction chunks as JSON; structured
+             JSON messages mark the end of the stream or errors during streaming.
+     """
+     print(f"[Request] Received question: {payload.inputs.question}")
+
+     data = {
+         "inputs": {
+             "image": payload.inputs.image,
+             "question": payload.inputs.question,
+             "stream": payload.inputs.stream
+         }
+     }
+
+     try:
+         result = handler.predict(data)
+     except ValueError as ve:
+         return JSONResponse({"error": str(ve)}, status_code=400)
+     except Exception:
+         return JSONResponse({"error": "Internal server error"}, status_code=500)
+
+     if isinstance(result, types.GeneratorType):
+         def event_stream():
+             try:
+                 for chunk in result:
+                     yield f"data: {json.dumps(chunk)}\n\n"
+                 # Structured JSON to indicate end of stream
+                 yield f"data: {json.dumps({'end': True})}\n\n"
+             except Exception as e:
+                 # Structured JSON to indicate an error during streaming
+                 yield f"data: {json.dumps({'error': str(e)})}\n\n"
+         return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+     # Non-streaming path: return the handler's dict result directly
+     if isinstance(result, dict) and "error" in result:
+         return JSONResponse(result, status_code=400)
+     return JSONResponse(result)
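
A client then talks to /predict with a JSON body matching PredictRequest; for streaming it reads the text/event-stream response line by line. A minimal sketch using requests; the host/port and example.jpg are assumptions:

    # Hypothetical client for the /predict endpoint.
    import base64
    import json
    import requests

    with open("example.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    payload = {"inputs": {"image": image_b64, "question": "Describe this image.", "stream": True}}
    with requests.post("http://localhost:7860/predict", json=payload, stream=True) as resp:
        for line in resp.iter_lines():
            if not line:
                continue  # SSE events are separated by blank lines
            event = json.loads(line.decode("utf-8").removeprefix("data: "))
            if event.get("end"):
                break
            print(event.get("output", event), end="", flush=True)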
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ Pillow==10.1.0
+ torch==2.3.1
+ torchaudio==2.3.1
+ torchvision==0.18.1
+ transformers==4.44.2
+ librosa==0.9.0
+ soundfile==0.12.1
+ vector-quantize-pytorch==1.18.5
+ vocos==0.1.0
+ decord
+ moviepy
+ einops
+ accelerate
+ openbmb
+ fastapi
+ uvicorn[standard]
+ timm>=0.6.13
+ sentencepiece>=0.1.99
+ python-multipart