import io
import os
import base64
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import Response, JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

from briarmbg import BriaRMBG

app = FastAPI(title="BRIA RMBG Service", version="0.1.0")

# Mount static frontend directory
app.mount("/static", StaticFiles(directory="static"), name="static")

net: Optional[BriaRMBG] = None
_device: Optional[torch.device] = None

# Enable CORS for external API usage (adjust origins as needed)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _resize_image(image: Image.Image) -> Image.Image:
    """Resize the input to the model's expected 1024x1024 RGB input."""
    image = image.convert("RGB")
    model_input_size = (1024, 1024)
    return image.resize(model_input_size, Image.BILINEAR)


def remove_background(pil_image: Image.Image) -> Image.Image:
    """Apply the RMBG model and return the original image with an alpha channel mask."""
    if net is None:
        raise RuntimeError("Model not loaded")

    orig_image = pil_image.convert("RGB")
    w, h = orig_image.size

    # Preprocess: resize, scale to [0, 1], then shift by the per-channel mean
    image = _resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    if _device and _device.type == "cuda":
        im_tensor = im_tensor.to(_device, non_blocking=True)
    else:
        im_tensor = im_tensor.to("cpu")

    # Inference (no autograd graph needed)
    with torch.no_grad():
        result = net(im_tensor)

    # Post-process first side output: upsample to the original size and min-max normalize
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi + 1e-8)

    # Convert to PIL alpha mask
    result_array = (result * 255).clamp(0, 255).byte().cpu().numpy()
    pil_mask = Image.fromarray(np.squeeze(result_array))

    # Compose RGBA image
    new_im = orig_image.copy()
    new_im.putalpha(pil_mask)
    return new_im


@app.on_event("startup")
async def _load_model():
    global net, _device
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # If an HF_TOKEN env var is provided, use it; otherwise pass token=None to avoid sending bad headers
    hf_token = os.environ.get("HF_TOKEN") or None
    net = BriaRMBG.from_pretrained("briaai/RMBG-1.4", token=hf_token)
    net.to(_device)
    net.eval()


@app.get("/", response_class=HTMLResponse)
async def index():
    # Simple loader page; the actual UI is served from /static
    html = """Loading UI..."""
    return HTMLResponse(content=html)


@app.get("/health")
async def health():
    return {"status": "ok", "device": str(_device), "cuda": torch.cuda.is_available()}


@app.post("/api/remove_bg")
async def api_remove_bg(
    file: UploadFile = File(..., description="Image file to process"),
    output: str = Query("image", enum=["image", "json"], description="Return raw PNG image or JSON(base64)"),
):
    if file.content_type is None or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Uploaded file must be an image")

    data = await file.read()
    try:
        pil = Image.open(io.BytesIO(data))
    except Exception:
        raise HTTPException(status_code=400, detail="Failed to read image")

    try:
        out_img = remove_background(pil)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {e}")

    buf = io.BytesIO()
    out_img.save(buf, format="PNG")
    png_bytes = buf.getvalue()

    if output == "json":
        b64 = base64.b64encode(png_bytes).decode("utf-8")
        return JSONResponse({"image_base64": b64, "format": "PNG"})
    return Response(content=png_bytes, media_type="image/png")


# Optional endpoint to just return the alpha matte
@app.post("/api/matte")
async def api_matte(file: UploadFile = File(...)):
    data = await file.read()
    try:
        pil = Image.open(io.BytesIO(data))
    except Exception:
        raise HTTPException(status_code=400, detail="Failed to read image")

    if net is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    orig = pil.convert("RGB")
    w, h = orig.size

    image = _resize_image(orig)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    # Use torchvision's mean/std normalize (same preprocessing as remove_background),
    # not torch.nn.functional.normalize, which computes an Lp norm
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    with torch.no_grad():
        result = net(im_tensor)
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi + 1e-8)

    result_array = (result * 255).clamp(0, 255).byte().cpu().numpy()
    matte = Image.fromarray(np.squeeze(result_array))

    buf = io.BytesIO()
    matte.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")


# Dev entrypoint: uvicorn server:app --host 0.0.0.0 --port 7860