# cover-removal/server.py
import io
import os
import base64
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import Response, JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from briarmbg import BriaRMBG
app = FastAPI(title="BRIA RMBG Service", version="0.1.0")
# Mount static frontend directory
app.mount("/static", StaticFiles(directory="static"), name="static")
net: Optional[BriaRMBG] = None
_device: Optional[torch.device] = None
# Enable CORS for external API usage (adjust origins as needed)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def _resize_image(image: Image.Image) -> Image.Image:
    """Resize an image to the model's fixed 1024x1024 RGB input."""
    image = image.convert("RGB")
model_input_size = (1024, 1024)
image = image.resize(model_input_size, Image.BILINEAR)
return image
def remove_background(pil_image: Image.Image) -> Image.Image:
"""Apply RMBG model to add alpha channel mask to the original image."""
if net is None:
raise RuntimeError("Model not loaded")
orig_image = pil_image.convert("RGB")
w, h = orig_image.size
image = _resize_image(orig_image)
    # Preprocess: HWC uint8 -> NCHW float in [-0.5, 0.5]
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
if _device and _device.type == "cuda":
im_tensor = im_tensor.to(_device, non_blocking=True)
else:
im_tensor = im_tensor.to("cpu")
    # Inference (no gradient tracking needed at serving time)
    with torch.no_grad():
        result = net(im_tensor)
# Post-process first side output
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
ma = torch.max(result)
mi = torch.min(result)
result = (result - mi) / (ma - mi + 1e-8)
# Convert to PIL alpha mask
result_array = (result * 255).clamp(0, 255).byte().cpu().numpy()
pil_mask = Image.fromarray(np.squeeze(result_array))
# Compose RGBA image
new_im = orig_image.copy()
new_im.putalpha(pil_mask)
return new_im
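# Example (sketch) of calling remove_background directly once the model is loaded;
# the file names below are hypothetical:
#   cut_out = remove_background(Image.open("photo.jpg"))
#   cut_out.save("photo_cutout.png")  # RGBA PNG with the predicted matte as alpha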
@app.on_event("startup")
async def _load_model():
global net, _device
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# If HF_TOKEN env is provided, use it; otherwise pass token=None to avoid sending bad headers
hf_token = os.environ.get("HF_TOKEN") or None
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4", token=hf_token)
net.to(_device)
net.eval()
@app.get("/", response_class=HTMLResponse)
async def index():
# Simple redirect loader to static UI
html = (
"""
<!doctype html>
<html><head><meta http-equiv=refresh content="0; url=/static/index.html"></head>
<body>Loading UI...</body></html>
"""
)
return HTMLResponse(content=html)
@app.get("/health")
async def health():
return {"status": "ok", "device": str(_device), "cuda": torch.cuda.is_available()}
@app.post("/api/remove_bg")
async def api_remove_bg(
file: UploadFile = File(..., description="Image file to process"),
output: str = Query("image", enum=["image", "json"], description="Return raw PNG image or JSON(base64)")
):
if file.content_type is None or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="Uploaded file must be an image")
data = await file.read()
try:
pil = Image.open(io.BytesIO(data))
except Exception:
raise HTTPException(status_code=400, detail="Failed to read image")
try:
out_img = remove_background(pil)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Processing error: {e}")
buf = io.BytesIO()
out_img.save(buf, format="PNG")
png_bytes = buf.getvalue()
if output == "json":
b64 = base64.b64encode(png_bytes).decode("utf-8")
return JSONResponse({"image_base64": b64, "format": "PNG"})
return Response(content=png_bytes, media_type="image/png")
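# Example client calls (sketch; assumes the service is reachable on localhost:7860 as in the
# dev entrypoint at the bottom of this file):
#   curl -F "file=@photo.jpg" "http://localhost:7860/api/remove_bg" -o cutout.png
#   curl -F "file=@photo.jpg" "http://localhost:7860/api/remove_bg?output=json"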
# Optional endpoint to just return the alpha matte
@app.post("/api/matte")
async def api_matte(file: UploadFile = File(...)):
data = await file.read()
try:
pil = Image.open(io.BytesIO(data))
except Exception:
raise HTTPException(status_code=400, detail="Failed to read image")
if net is None:
raise HTTPException(status_code=503, detail="Model not loaded")
orig = pil.convert("RGB")
w, h = orig.size
image = _resize_image(orig)
im_np = np.array(image)
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    # Use torchvision's normalize (same preprocessing as remove_background);
    # torch.nn.functional.normalize has a different signature and would fail here.
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Move the input to the same device as the model, and run inference without gradients
    im_tensor = im_tensor.to(_device if _device is not None else "cpu")
    with torch.no_grad():
        result = net(im_tensor)
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
ma = torch.max(result)
mi = torch.min(result)
result = (result - mi) / (ma - mi + 1e-8)
result_array = (result * 255).clamp(0, 255).byte().cpu().numpy()
matte = Image.fromarray(np.squeeze(result_array))
buf = io.BytesIO()
matte.save(buf, format="PNG")
return Response(content=buf.getvalue(), media_type="image/png")
# Dev entrypoint: uvicorn server:app --host 0.0.0.0 --port 7860
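# A minimal in-process equivalent of the command above (sketch; assumes uvicorn is installed,
# which serving this app already requires):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)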