# Captured from a Hugging Face Space page (Space status at capture time: Runtime error).
import functools

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from transformers import SamModel, SamProcessor
# Side length, in pixels, of the square working canvas: uploads are padded
# and resized to IMG_SIZE x IMG_SIZE, and results are resized back to it.
IMG_SIZE: int = 512

# Module-level selection state shared by the Gradio event handlers:
#   input_points -- [x, y] click coordinates collected on the uploaded image
#   input_image  -- copy of the image taken when the first point was clicked
input_points: list = []
input_image = None
@functools.lru_cache(maxsize=1)
def _load_sam(model_id="facebook/sam-vit-huge"):
    """Load the SAM model and processor on CPU once and reuse them afterwards."""
    sam_model = SamModel.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu")
    sam_model.eval()
    sam_processor = SamProcessor.from_pretrained(model_id)
    return sam_model, sam_processor


def generate_mask(image, points):
    """
    Segment the object under the clicked points with SAM and return its inverse.

    Args:
        image: PIL image the points were clicked on (converted to RGB here).
        points: Sequence of [x, y] pixel coordinates on ``image``. All points
            prompt the same single object.

    Returns:
        A 2-D boolean numpy array the size of ``image`` that is True outside
        the segmented object (background) and False on it, or None when no
        points are given or SAM returns no masks.
    """
    if not points:
        return None
    image = image.convert("RGB")

    # SamProcessor expects `input_points` nested as [image][point][xy]:
    # one image, with every clicked point prompting the same mask.
    prompt_points = [[[float(x), float(y)] for x, y in points]]

    sam_model, sam_processor = _load_sam()
    inputs = sam_processor(image, input_points=prompt_points, return_tensors="pt").to("cpu")
    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if len(masks) == 0:
        return None

    # For a single image and a single prompt group, SAM proposes several
    # candidate masks; keep the one with the highest predicted IoU.
    best_index = outputs.iou_scores[0, 0].argmax().item()
    object_mask = masks[0][0][best_index].numpy().astype(bool)

    # Invert *after* the bool cast so `~` is a logical NOT. On an int array it
    # would be a bitwise NOT, producing -1/-2 instead of a usable mask.
    return ~object_mask
@functools.lru_cache(maxsize=1)
def _load_inpaint_pipeline(model_id="stabilityai/stable-diffusion-2-inpainting"):
    """Load the Stable Diffusion inpainting pipeline on CPU once and reuse it."""
    return StableDiffusionInpaintPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
    ).to("cpu")


def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Inpaint the masked region of ``image`` according to ``prompt``.

    Args:
        image: PIL image to edit; converted to RGB before inpainting.
        mask: 2-D array the size of ``image``. Non-zero/True pixels are
            repainted; zero/False pixels are kept. ``None`` skips inpainting.
        prompt: Text describing what to paint into the masked region.
        negative_prompt: Optional text to steer away from; empty means none.
        seed: Integer seed for reproducible output; ``None`` uses a random one.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        The inpainted PIL image, or ``image`` unchanged when ``mask`` is None.

    Raises:
        RuntimeError: If the inpainting pipeline fails; the original error is
            chained as the cause so callers can surface it to the user.
    """
    if mask is None:
        return image

    inpaint_pipeline = _load_inpaint_pipeline()
    # The pipeline expects 3-channel input; uploads may be RGBA or palette.
    image = image.convert("RGB")
    # Any non-zero/True pixel becomes white (repaint), the rest black (keep).
    mask_array = np.asarray(mask).astype(bool)
    mask_image = Image.fromarray(mask_array.astype(np.uint8) * 255)
    # gr.Number may deliver a float or None; seed only when a value exists.
    generator = None if seed is None else torch.Generator("cpu").manual_seed(int(seed))

    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt or None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
    except Exception as e:
        print(f"Inpainting error: {e}")
        # Do not return the untouched image as if it were the result.
        raise RuntimeError(f"Inpainting failed: {e}") from e
    return result
def visualize_mask(image, mask):
    """
    Blend a translucent green highlight over ``image`` wherever ``mask == 1``.

    Returns ``image`` itself when ``mask`` is None; otherwise a new RGB image
    of the same size.
    """
    if mask is not None:
        highlight = np.zeros((*mask.shape, 4), dtype=np.uint8)
        highlight[mask == 1] = (0, 255, 0, 127)  # RGBA green, roughly half opaque
        highlight_layer = Image.fromarray(highlight)
        blended = Image.alpha_composite(image.convert("RGBA"), highlight_layer)
        return blended.convert("RGB")
    return image
def get_points(img, evt: gr.SelectData):
    """
    Record a clicked point, re-run SAM on all points so far, and refresh previews.

    The first click after a reset stores a copy of ``img`` in the module-level
    ``input_image``; every click appends its coordinates to ``input_points``.

    Args:
        img: Image currently shown in the upload component (PIL, because
            the component is declared with ``type="pil"``).
        evt: Gradio select event; ``evt.index`` holds the clicked (x, y).

    Returns:
        ``(mask_overlay, points_preview)``: the stored image with the selected
        object tinted green, and ``img`` with a green cross on every selected
        point. ``(None, None)`` when ``img`` is None.
    """
    global input_points
    global input_image
    if img is None:
        return None, None
    if not input_points:
        input_image = img.copy()

    click_x, click_y = evt.index[0], evt.index[1]
    input_points.append([click_x, click_y])

    mask = generate_mask(input_image, input_points)
    # generate_mask returns the inverse (`~`) of SAM's object mask; flip it back
    # so the overlay highlights the selected object, matching the
    # "Selected Object Mask Overlay" output it feeds.
    object_mask = None if mask is None else ~mask

    # Mark every selected point with a green cross.
    draw = ImageDraw.Draw(img)
    arm = 10  # half-length of each cross arm, in pixels
    for px, py in input_points:
        draw.line((px - arm, py, px + arm, py), fill="green", width=5)
        draw.line((px, py - arm, px, py + arm), fill="green", width=5)

    return visualize_mask(input_image, object_mask), img
def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Segment the stored image from the selected points and inpaint it.

    Args:
        prompt: Replacement prompt passed to the inpainting pipeline.
        negative_prompt: Optional negative prompt.
        cfg: Classifier-free guidance scale.
        seed: Random seed for the inpainting generator.
        invert: If true, inpaint the selected subject; otherwise inpaint the
            background (the region ``generate_mask`` returns).

    Returns:
        The inpainted image resized to IMG_SIZE x IMG_SIZE.

    Raises:
        gr.Error: If no points were selected, SAM yields no mask, or either
            model fails.
    """
    if input_image is None or len(input_points) == 0:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")

    try:
        mask = generate_mask(input_image, input_points)
    except Exception as e:
        raise gr.Error(str(e)) from e
    if mask is None:
        raise gr.Error("SAM did not return a mask for the selected points.")

    if invert:
        # Swap background for subject. Assumes generate_mask returns a boolean
        # array so `~` is a logical NOT (TODO confirm if that helper changes).
        mask = ~mask

    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        raise gr.Error(str(e)) from e
    return inpainted.resize((IMG_SIZE, IMG_SIZE))
def reset_points_func():
    """
    Forget the collected click points and the stored working image.

    Returns:
        Three Nones, one per output wired to the Reset button (mask overlay,
        point preview, augmented image), which clears those components.
    """
    global input_points, input_image
    input_image = None
    input_points = []
    return (None,) * 3
def preprocess(input_img):
    """
    Pad an uploaded image to a square on white and resize it to IMG_SIZE.

    The image is composited onto a white RGBA canvas before conversion, so
    transparent pixels become white like the padding rather than black. The
    result is always RGB regardless of the upload's mode (RGBA, palette, ...).

    Args:
        input_img: PIL image from the upload component, or None.

    Returns:
        An RGB image of size IMG_SIZE x IMG_SIZE, or None if ``input_img`` is None.
    """
    if input_img is None:
        return None
    width, height = input_img.size
    side = max(width, height)
    canvas = Image.new("RGBA", (side, side), (255, 255, 255, 255))
    # Center the original on the square canvas (no offset when already square).
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.alpha_composite(input_img.convert("RGBA"), dest=offset)
    return canvas.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Object Replacement App
        Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
        **Instructions:**
        1. **Upload Image:** Click on the first image box to upload your image.
        2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
        3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
        4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
        5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
        6. **Reset:** Click the "Reset" button to clear selections and start over.
        """)
    with gr.Row():
        with gr.Column():
            # Image upload and point selection
            upload_image = gr.Image(
                label="Upload Image",
                type="pil",
                interactive=True,
                height=IMG_SIZE,
                width=IMG_SIZE
            )
            mask_visualization = gr.Image(
                label="Selected Object Mask Overlay",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE
            )
            selected_image = gr.Image(
                label="Image with Selected Points",
                type="pil",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            # Capture points using the select event
            upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
            # Pad/resize every new image (uploads and loaded examples) to the working size
            upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
            # Text inputs and settings
            prompt = gr.Textbox(
                label="Replacement Prompt",
                placeholder="e.g., a red sports car",
                lines=2
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="e.g., blurry, low quality",
                lines=2
            )
            cfg = gr.Slider(
                label="Classifier-Free Guidance Scale",
                minimum=1.0,
                maximum=20.0,
                value=7.5,
                step=0.5
            )
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0
            )
            invert = gr.Checkbox(
                label="Infill subject instead of background"
            )
            # Buttons
            replace_button = gr.Button("Replace Object")
            reset_button = gr.Button("Reset")
        with gr.Column():
            # Output images
            augmented_image = gr.Image(
                label="Augmented Image",
                type="pil",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
    # Define button actions
    replace_button.click(
        fn=run_inpaint,
        inputs=[prompt, negative_prompt, cfg, seed, invert],
        outputs=[augmented_image]
    )
    reset_button.click(
        fn=reset_points_func,
        inputs=[],
        outputs=[mask_visualization, selected_image, augmented_image]
    )
    # A new upload or a cleared image starts a fresh selection. Otherwise
    # get_points keeps the previously stored image and appends new clicks to it.
    upload_image.upload(
        fn=reset_points_func,
        inputs=[],
        outputs=[mask_visualization, selected_image, augmented_image]
    )
    upload_image.clear(
        fn=reset_points_func,
        inputs=[],
        outputs=[mask_visualization, selected_image, augmented_image]
    )
    # Examples (optional)
    gr.Markdown(
        """
        ## EXAMPLES
        Click on an example to load it. Then, follow the instructions above.
        """)
    with gr.Row():
        examples = gr.Examples(
            examples=[
                [
                    "car.png",
                    "a red sports car",
                    "blurry, low quality",
                    42
                ],
                [
                    "monalisa.png",
                    "a rockstar",
                    "dark, overexposed",
                    123
                ],
            ],
            inputs=[
                upload_image,
                prompt,
                negative_prompt,
                seed
            ],
            label="Click to load examples",
            # These examples only fill inputs (no fn/outputs given), so there
            # is nothing to pre-compute; caching would require fn and outputs.
            cache_examples=False
        )
# Start the server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.queue(max_size=10).launch()