# MG / app.py — Hugging Face Space by Mkg09
# Last change: commit 64482a4 ("Update app.py", verified)
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import gradio as gr
# ------------------------------
# Gaussian Blur Setup & Function
# ------------------------------
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# SegFormer (ADE20K) semantic-segmentation model used by apply_gaussian_blur
# to locate people, which are kept sharp while the background is blurred.
gaussian_model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
gaussian_processor = SegformerImageProcessor.from_pretrained(gaussian_model_name)
# nn.Module.eval() returns the module itself, so load + switch to inference mode in one go.
gaussian_model = SegformerForSemanticSegmentation.from_pretrained(
    gaussian_model_name
).eval()
def apply_gaussian_blur(input_image, blur_radius=15):
    """Blur the background while keeping detected people sharp.

    The image is segmented with the module-level SegFormer model
    (``gaussian_processor`` / ``gaussian_model``). Pixels predicted as the
    "person" class are taken from the original image; all other pixels come
    from a Gaussian-blurred copy.

    Args:
        input_image: PIL image to process. EXIF orientation is applied and
            the image is converted to RGB before segmentation.
        blur_radius: Radius passed to ``ImageFilter.GaussianBlur`` for the
            background. Defaults to 15 (the previously hard-coded value).

    Returns:
        A new RGB PIL image with the same size as the orientation-corrected
        input.

    Raises:
        ValueError: If the model's ``id2label`` mapping has no "person" label.
    """
    # Honour EXIF orientation and normalise the mode (e.g. RGBA / L / P
    # uploads) so the processor and Image.composite see a 3-channel image,
    # matching what apply_lens_blur does.
    image = ImageOps.exif_transpose(input_image).convert("RGB")

    # Preprocess image and perform segmentation
    inputs = gaussian_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = gaussian_model(**inputs).logits

    # Upsample class scores to the image resolution before taking argmax.
    # PIL size is (W, H); F.interpolate expects (H, W).
    upscaled_logits = F.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    predicted = upscaled_logits.argmax(dim=1)[0].cpu().numpy()

    # Find the label ID for "person" (case-insensitive) in the model config.
    person_label = next(
        (
            int(key)
            for key, label in gaussian_model.config.id2label.items()
            if label.lower() == "person"
        ),
        None,
    )
    if person_label is None:
        raise ValueError("No 'person' label found in the model's label mapping.")

    # 255 where a person was predicted (keep original), 0 elsewhere (blurred).
    mask = np.where(predicted == person_label, 255, 0).astype(np.uint8)
    # A 2-D uint8 array becomes a mode "L" image; the explicit mode= argument
    # to fromarray is deprecated in recent Pillow.
    mask_image = Image.fromarray(mask)
    blurred_image = image.filter(ImageFilter.GaussianBlur(blur_radius))
    # composite(a, b, mask): take `a` where mask is 255, `b` where it is 0.
    return Image.composite(image, blurred_image, mask_image)
# ------------------------------
# Lens Blur Setup & Function
# ------------------------------
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
# Load the depth estimation model for the Lens blur effect.
# DPTImageProcessor replaces DPTFeatureExtractor, which transformers has
# deprecated; it is called the same way (images=..., return_tensors="pt").
from transformers import DPTImageProcessor

lens_model_name = "Intel/dpt-large"
# Variable name kept as `lens_feature_extractor` for existing callers.
lens_feature_extractor = DPTImageProcessor.from_pretrained(lens_model_name)
lens_model = DPTForDepthEstimation.from_pretrained(lens_model_name)
lens_model.eval()  # switch to inference mode

# Options for controlling the blur effect
INVERT_DEPTH = True  # Raw model output is larger for nearer pixels; invert so far -> 1 (most blur)
CLAMP_NEAR = True  # Keep the nearest region completely sharp
NEAR_THRESHOLD = 0.2  # Normalized depth below which no blur is applied (when CLAMP_NEAR)
N = 10  # Number of discrete blur levels
MAX_BLUR = 15  # Gaussian blur radius used for the farthest pixels
def apply_lens_blur(input_image):
    """Simulate a shallow depth of field: near regions sharp, far ones blurred.

    Depth is estimated with the module-level DPT model
    (``lens_feature_extractor`` / ``lens_model``), normalised to [0, 1], and
    optionally inverted/clamped according to INVERT_DEPTH, CLAMP_NEAR and
    NEAR_THRESHOLD. Each output pixel is then copied from one of N copies of
    the image blurred with radii evenly spaced from 0 to MAX_BLUR, chosen by
    that pixel's normalised depth.

    Args:
        input_image: PIL image. EXIF orientation is applied and the image is
            converted to RGB.

    Returns:
        An RGB PIL image scaled to fit inside 512x512 with the input's aspect
        ratio preserved. A plain resize((512, 512)) would stretch
        non-square photos.
    """
    image = ImageOps.exif_transpose(input_image).convert("RGB")
    # Fit inside a 512x512 box without distortion; MAX_BLUR is tuned for
    # images of roughly this size.
    image = ImageOps.contain(image, (512, 512))
    width, height = image.size

    # Run depth estimation
    inputs = lens_feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        predicted_depth = lens_model(**inputs).predicted_depth
    if predicted_depth.dim() == 3:
        # (B, H, W) -> (B, 1, H, W): F.interpolate needs a channel axis.
        predicted_depth = predicted_depth.unsqueeze(1)
    predicted_depth = F.interpolate(
        predicted_depth, size=(height, width), mode="bicubic", align_corners=False
    )
    # Select batch 0 / channel 0 explicitly; squeeze() would also drop a
    # spatial axis of length 1.
    depth = predicted_depth[0, 0].cpu().numpy()

    # Normalize depth map to [0, 1]; the epsilon guards against a flat map.
    min_val, max_val = depth.min(), depth.max()
    depth_normalized = (depth - min_val) / (max_val - min_val + 1e-8)
    if INVERT_DEPTH:
        depth_normalized = 1.0 - depth_normalized
    if CLAMP_NEAR:
        # Values below NEAR_THRESHOLD map to 0 (no blur); the rest are
        # stretched back over [0, 1].
        depth_normalized = np.clip(
            (depth_normalized - NEAR_THRESHOLD) / (1.0 - NEAR_THRESHOLD),
            0.0,
            1.0,
        )

    # Progressively blurred copies, stacked to shape (N, height, width, 3).
    steps = max(N - 1, 1)  # avoid division by zero if N == 1
    blurred_stack = np.stack(
        [
            np.asarray(image.filter(ImageFilter.GaussianBlur(i / steps * MAX_BLUR)))
            for i in range(N)
        ]
    )

    # Per-pixel blur level; clip keeps indices valid even for out-of-range
    # depth values.
    depth_indices = np.clip(
        (depth_normalized * (N - 1)).astype(np.int32), 0, N - 1
    )
    # Vectorised gather instead of a per-pixel Python loop:
    # final[y, x] = blurred_stack[depth_indices[y, x], y, x]
    rows, cols = np.ogrid[:height, :width]
    final_image_np = blurred_stack[depth_indices, rows, cols]
    return Image.fromarray(final_image_np)
# ------------------------------
# Gradio App Function
# ------------------------------
def process_image(input_image, effect):
    """Process the uploaded image using the selected blur effect.

    Args:
        input_image: PIL image from the upload widget, or None when nothing
            has been uploaded.
        effect: "Gaussian Blur" or "Lens Blur". Any other value (including
            None) returns the input unchanged.

    Returns:
        The processed PIL image, or ``input_image`` for an unknown effect.

    Raises:
        gr.Error: If no image was uploaded. The user then sees a readable
            message instead of an error raised deep inside the blur code.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    if effect == "Gaussian Blur":
        return apply_gaussian_blur(input_image)
    if effect == "Lens Blur":
        return apply_lens_blur(input_image)
    return input_image
# Create a Gradio interface with image upload and effect selection
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        # Pre-select an effect so submitting without choosing one does not
        # silently return the input unchanged.
        gr.Radio(
            choices=["Gaussian Blur", "Lens Blur"],
            value="Gaussian Blur",
            label="Select Effect",
        ),
    ],
    outputs=gr.Image(label="Output Image"),
    title="Blur Effects App",
    description="Apply Gaussian Blur (with segmentation) or Depth-based Lens Blur to your image.",
)

# Launch the app only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()