Spaces:
Runtime error
Runtime error
| import io | |
| import os | |
| import base64 | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision.transforms.functional import normalize | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Query | |
| from fastapi.responses import Response, JSONResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| from briarmbg import BriaRMBG | |
# Model handle and the torch device it runs on. Both stay None until the
# model loader below has populated them.
net: Optional[BriaRMBG] = None
_device: Optional[torch.device] = None

app = FastAPI(title="BRIA RMBG Service", version="0.1.0")

# Enable CORS for external API usage (adjust origins as needed).
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# very permissive -- confirm this is intended before exposing the service.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Serve the static frontend directory under /static. The directory is
# given relative to the process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
def _resize_image(image: Image.Image) -> Image.Image:
    """Return an RGB copy of *image* scaled to the model input size.

    The image is converted to RGB and resized to 1024x1024 with bilinear
    resampling; the aspect ratio is not preserved.
    """
    target_size = (1024, 1024)
    rgb = image.convert("RGB")
    return rgb.resize(target_size, Image.BILINEAR)
def remove_background(pil_image: Image.Image) -> Image.Image:
    """Return *pil_image* as RGBA, using the model's matte as the alpha channel.

    The input is converted to RGB, resized to the model input size, run
    through the globally loaded network, and the predicted matte is scaled
    back to the original size, min-max normalised to 0..255 and attached
    as alpha.

    Args:
        pil_image: Source image in any mode PIL can convert to RGB.

    Returns:
        An RGBA copy of the RGB-converted input at its original size.

    Raises:
        RuntimeError: If the model has not been loaded yet.
    """
    if net is None:
        raise RuntimeError("Model not loaded")

    orig_image = pil_image.convert("RGB")
    w, h = orig_image.size

    # HWC uint8 -> 1xCxHxW float in [0, 1], then subtract the 0.5 mean.
    im_np = np.array(_resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = im_tensor.unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    # Move the input to the device the model was loaded on; fall back to
    # CPU when no device has been recorded.
    device = _device if _device is not None else torch.device("cpu")
    im_tensor = im_tensor.to(device, non_blocking=device.type == "cuda")

    # Inference only: disabling autograd avoids building a graph and
    # holding on to activations for every request.
    with torch.no_grad():
        result = net(im_tensor)
        # Post-process first side output.
        # NOTE(review): assumes result[0][0] is a 1x1xHxW tensor -- confirm
        # against the BriaRMBG forward() return value.
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
        )
        ma = torch.max(result)
        mi = torch.min(result)
        # The epsilon guards against division by zero on a constant matte.
        result = (result - mi) / (ma - mi + 1e-8)
        result_array = (result * 255).clamp(0, 255).byte().cpu().numpy()

    # reshape (rather than np.squeeze) keeps the mask 2-D even when the
    # image is a single pixel wide or tall.
    pil_mask = Image.fromarray(result_array.reshape(h, w))

    # Compose RGBA image
    new_im = orig_image.copy()
    new_im.putalpha(pil_mask)
    return new_im
async def _load_model():
    """Load the RMBG-1.4 weights and record the device they run on.

    Populates the module-level ``net`` and ``_device``. CUDA is used when
    available, otherwise the CPU. The model is put into eval mode.

    NOTE(review): no startup-hook decorator is visible on this coroutine in
    this view -- confirm how it gets invoked.
    """
    global net, _device

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    _device = torch.device(device_name)

    # An unset or empty HF_TOKEN becomes None so that no bad auth header
    # is sent with the download request.
    token = os.environ.get("HF_TOKEN") or None

    net = BriaRMBG.from_pretrained("briaai/RMBG-1.4", token=token)
    net.to(_device)
    net.eval()
async def index():
    """Return a minimal HTML page that immediately redirects to the static UI.

    NOTE(review): no route decorator is visible on this handler in this
    view -- confirm where it is registered.
    """
    # Meta-refresh with a zero delay sends the browser to /static/index.html.
    redirect_page = """
<!doctype html>
<html><head><meta http-equiv=refresh content="0; url=/static/index.html"></head>
<body>Loading UI...</body></html>
"""
    return HTMLResponse(content=redirect_page)
async def health():
    """Report service status, the selected device, and CUDA availability.

    ``device`` is the string form of the module-level ``_device``, which is
    ``"None"`` until the model loader has run.
    """
    device_name = str(_device)
    cuda_available = torch.cuda.is_available()
    payload = {"status": "ok", "device": device_name, "cuda": cuda_available}
    return payload
async def api_remove_bg(
    file: UploadFile = File(..., description="Image file to process"),
    output: str = Query("image", enum=["image", "json"], description="Return raw PNG image or JSON(base64)"),
):
    """Remove the background from an uploaded image and return it as PNG.

    Args:
        file: Uploaded image; its content type must start with ``image/``.
        output: ``"json"`` returns the PNG base64-encoded inside a JSON
            body; any other value returns the raw PNG bytes.

    Returns:
        A ``image/png`` response, or a JSON object with ``image_base64``
        and ``format`` keys when ``output == "json"``.

    Raises:
        HTTPException: 400 if the upload is not an image or cannot be
            decoded, 503 if the model is not loaded, 500 if processing
            fails.

    NOTE(review): no route decorator is visible on this handler in this
    view -- confirm where it is registered.
    """
    if file.content_type is None or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Uploaded file must be an image")

    # A missing model is an availability problem, not a processing failure,
    # so report it as 503 instead of letting it surface as a generic 500.
    if net is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    data = await file.read()
    try:
        pil = Image.open(io.BytesIO(data))
        # Image.open is lazy; force the decode here so that truncated or
        # corrupt uploads are reported as a client error.
        pil.load()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Failed to read image") from exc

    try:
        out_img = remove_background(pil)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc

    buf = io.BytesIO()
    out_img.save(buf, format="PNG")
    png_bytes = buf.getvalue()

    if output == "json":
        b64 = base64.b64encode(png_bytes).decode("utf-8")
        return JSONResponse({"image_base64": b64, "format": "PNG"})
    return Response(content=png_bytes, media_type="image/png")
# Optional endpoint to just return the alpha matte
async def api_matte(file: UploadFile = File(...)):
    """Return the predicted alpha matte for an uploaded image as a PNG.

    Args:
        file: Uploaded image in any format PIL can decode.

    Returns:
        A ``image/png`` response holding the single-channel matte at the
        original image size, min-max normalised to 0..255.

    Raises:
        HTTPException: 400 if the upload cannot be decoded, 503 if the
            model is not loaded.

    NOTE(review): no route decorator is visible on this handler in this
    view -- confirm where it is registered.
    """
    data = await file.read()
    try:
        pil = Image.open(io.BytesIO(data))
        # Image.open is lazy; converting inside the try forces the decode
        # so corrupt uploads are reported as a client error.
        orig = pil.convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Failed to read image") from exc

    if net is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    w, h = orig.size

    # HWC uint8 -> 1xCxHxW float in [0, 1].
    im_np = np.array(_resize_image(orig))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    # Per-channel mean/std normalisation needs torchvision's normalize;
    # torch.nn.functional.normalize is Lp normalisation and rejects the
    # mean/std lists.
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    # Follow the device the model was actually moved to.
    if _device is not None:
        im_tensor = im_tensor.to(_device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        result = net(im_tensor)
        # NOTE(review): assumes result[0][0] is a 1x1xHxW tensor -- confirm
        # against the BriaRMBG forward() return value.
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
        )
        ma = torch.max(result)
        mi = torch.min(result)
        # The epsilon guards against division by zero on a constant matte.
        result = (result - mi) / (ma - mi + 1e-8)
        result_array = (result * 255).clamp(0, 255).byte().cpu().numpy()

    # reshape (rather than np.squeeze) keeps the matte 2-D even when the
    # image is a single pixel wide or tall.
    matte = Image.fromarray(result_array.reshape(h, w))

    buf = io.BytesIO()
    matte.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")
| # Dev entrypoint: uvicorn server:app --host 0.0.0.0 --port 7860 | |