from typing import Optional, Tuple

import numpy as np
import torch
import torchvision.transforms.functional as TF
from IPython.display import display
from PIL import Image

from cache_and_edit import CachedPipeline
from cache_and_edit.flux_pipeline import EditedFluxPipeline

def image2latent(pipe, image, latent_nudging_scalar: float = 1.15):
    """Encode a PIL image into packed Flux latents.

    The image is encoded with the pipeline's VAE, shifted and scaled according to the
    VAE config, multiplied by `latent_nudging_scalar` to stabilise inversion, and
    finally packed into the token layout expected by the Flux transformer.
    """
    image = pipe.image_processor.preprocess(image).type(pipe.vae.dtype).to("cuda")
    latents = pipe.vae.encode(image)["latent_dist"].mean
    latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
    latents = latents * latent_nudging_scalar

    latents = pipe._pack_latents(
        latents=latents,
        batch_size=1,
        num_channels_latents=16,
        height=image.size(2) // 8,
        width=image.size(3) // 8,
    )

    return latents
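# Shape note (illustrative, assuming a 512x512 RGB input): the VAE downsamples by 8 to a
# 16-channel 64x64 latent, and _pack_latents groups 2x2 latent patches into tokens, so the
# packed latents returned above have shape (1, (512 // 16) ** 2, 64) = (1, 1024, 64).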
def get_inverted_input_noise(pipe: CachedPipeline,
                             image,
                             prompt: str = "",
                             num_steps: int = 28,
                             latent_nudging_scalar: float = 1.15):
    """Invert an image back to its input noise by running the pipeline in reverse.

    Args:
        pipe (CachedPipeline): Wrapped pipeline used to run the inverse diffusion process.
        image: Input PIL image to invert.
        prompt (str, optional): Prompt used during inversion. Defaults to "".
        num_steps (int, optional): Number of inversion steps. Defaults to 28.
        latent_nudging_scalar (float, optional): Scaling applied to the encoded latents
            before inversion. Defaults to 1.15.

    Returns:
        If the wrapped pipeline is an EditedFluxPipeline, a list with the inverted latents
        collected at every step; otherwise, the final inverted noise tensor.
    """
    width, height = image.size
    inverted_latents_list = []

    if isinstance(pipe.pipe, EditedFluxPipeline):
        # The edited pipeline fills `inverted_latents_list` with the latents of every
        # inversion step; the returned image itself is discarded.
        _ = pipe.run(
            prompt,
            num_inference_steps=num_steps,
            seed=42,
            guidance_scale=1,
            output_type="latent",
            latents=image2latent(pipe.pipe, image, latent_nudging_scalar=latent_nudging_scalar),
            empty_clip_embeddings=False,
            inverse=True,
            width=width,
            height=height,
            is_inverted_generation=True,
            inverted_latents_list=inverted_latents_list,
        ).images[0]

        return inverted_latents_list

    else:
        noise = pipe.run(
            prompt,
            num_inference_steps=num_steps,
            seed=42,
            guidance_scale=1,
            output_type="latent",
            latents=image2latent(pipe.pipe, image, latent_nudging_scalar=latent_nudging_scalar),
            empty_clip_embeddings=False,
            inverse=True,
            width=width,
            height=height,
        ).images[0]

        return noise
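# Illustrative call (a sketch; assumes `cached_pipe` is a CachedPipeline around a Flux model
# and `img` is a PIL image whose side lengths are multiples of 16):
#   noise = get_inverted_input_noise(cached_pipe, img, num_steps=28)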
def resize_bounding_box(
    bb_mask: torch.Tensor,
    target_size: Tuple[int, int] = (64, 64),
) -> torch.Tensor:
    """
    Given a bounding box mask, patches it into a mask with the target size.
    The mask is a 2D tensor of shape (H, W) where each element is either 0 or 1.
    Any patch that contains at least one 1 in the original mask will be set to 1 in the output mask.

    Args:
        bb_mask (torch.Tensor): The bounding box mask as a boolean tensor of shape (H, W).
        target_size (Tuple[int, int]): The size of the target mask as a tuple (H, W).

    Returns:
        torch.Tensor: The resized bounding box mask as a boolean tensor of shape (H, W).
    """
    h_mask, w_mask = bb_mask.shape[-2:]
    h_target, w_target = target_size

    if h_mask % h_target != 0 or w_mask % w_target != 0:
        raise ValueError(
            f"Mask size {bb_mask.shape[-2:]} is not compatible with target size {target_size}"
        )

    patch_size = (h_mask // h_target, w_mask // w_target)

    out_mask = torch.zeros((h_target, w_target), dtype=bb_mask.dtype, device=bb_mask.device)
    for i in range(h_target):
        for j in range(w_target):
            patch = bb_mask[
                i * patch_size[0] : (i + 1) * patch_size[0],
                j * patch_size[1] : (j + 1) * patch_size[1],
            ]
            # A patch is marked as foreground if it contains at least one nonzero element.
            if torch.sum(patch) > 0:
                out_mask[i, j] = 1

    return out_mask
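# Example with illustrative values: downsample a 1024x1024 pixel mask to the 64x64 latent grid
# (patch size 16). The marked 128x128 block maps to an 8x8 block of latent patches.
#   px_mask = torch.zeros(1024, 1024, dtype=torch.bool)
#   px_mask[128:256, 512:640] = True
#   latent_mask = resize_bounding_box(px_mask, target_size=(64, 64))  # shape (64, 64), 64 ones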
def place_image_in_bounding_box(
    image_tensor_whc: torch.Tensor,
    mask_tensor_wh: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Resizes an input image to fit within a bounding box (from a mask)
    preserving aspect ratio, and places it centered on a new canvas.

    Args:
        image_tensor_whc: Input image tensor, shape [width, height, channels].
        mask_tensor_wh: Bounding box mask, shape [width, height]. Defines canvas size
            and contains a rectangle of 1s for the BB.

    Returns:
        A tuple:
        - output_image_whc (torch.Tensor): Canvas with the resized image placed.
          Shape [canvas_width, canvas_height, channels].
        - new_mask_wh (torch.Tensor): Mask showing the actual placement of the image.
          Shape [canvas_width, canvas_height].
    """
    if not (image_tensor_whc.ndim == 3 and image_tensor_whc.shape[0] > 0 and image_tensor_whc.shape[1] > 0):
        raise ValueError(
            "Input image_tensor_whc must be a 3D tensor [width, height, channels] "
            "with width > 0 and height > 0."
        )
    img_orig_w, img_orig_h, num_channels = image_tensor_whc.shape

    if not (mask_tensor_wh.ndim == 2):
        raise ValueError("Input mask_tensor_wh must be a 2D tensor [width, height].")
    canvas_w, canvas_h = mask_tensor_wh.shape

    empty_output_image = torch.zeros(
        canvas_w, canvas_h, num_channels,
        dtype=image_tensor_whc.dtype, device=image_tensor_whc.device
    )
    empty_new_mask = torch.zeros(
        canvas_w, canvas_h,
        dtype=mask_tensor_wh.dtype, device=mask_tensor_wh.device
    )

    # Find the bounding box of the foreground (nonzero) region of the mask.
    fg_coords = torch.nonzero(mask_tensor_wh, as_tuple=False)

    if fg_coords.numel() == 0:
        return empty_output_image, empty_new_mask

    x_min_bb, y_min_bb = fg_coords[:, 0].min(), fg_coords[:, 1].min()
    x_max_bb, y_max_bb = fg_coords[:, 0].max(), fg_coords[:, 1].max()

    bb_target_w = x_max_bb - x_min_bb + 1
    bb_target_h = y_max_bb - y_min_bb + 1

    if bb_target_w <= 0 or bb_target_h <= 0:
        return empty_output_image, empty_new_mask

    image_tensor_chw = image_tensor_whc.permute(2, 1, 0)

    # Uniform scale that fits the image inside the bounding box while preserving aspect ratio.
    scale_factor_w = bb_target_w / img_orig_w
    scale_factor_h = bb_target_h / img_orig_h
    scale = min(scale_factor_w, scale_factor_h)

    resized_img_w = int(img_orig_w * scale)
    resized_img_h = int(img_orig_h * scale)

    if resized_img_w == 0 or resized_img_h == 0:
        return empty_output_image, empty_new_mask

    try:
        resized_image_chw = TF.resize(image_tensor_chw, [resized_img_h, resized_img_w], antialias=True)
    except TypeError:
        # Older torchvision versions do not accept the `antialias` argument.
        resized_image_chw = TF.resize(image_tensor_chw, [resized_img_h, resized_img_w])

    resized_image_whc = resized_image_chw.permute(2, 1, 0)

    output_image_whc = torch.zeros(
        canvas_w, canvas_h, num_channels,
        dtype=image_tensor_whc.dtype, device=image_tensor_whc.device
    )

    # Center the resized image inside the bounding box.
    offset_x = (bb_target_w - resized_img_w) // 2
    offset_y = (bb_target_h - resized_img_h) // 2

    paste_x_start = x_min_bb + offset_x
    paste_y_start = y_min_bb + offset_y

    paste_x_end = paste_x_start + resized_img_w
    paste_y_end = paste_y_start + resized_img_h

    output_image_whc[paste_x_start:paste_x_end, paste_y_start:paste_y_end, :] = resized_image_whc

    new_mask_wh = torch.zeros(
        canvas_w, canvas_h,
        dtype=mask_tensor_wh.dtype, device=mask_tensor_wh.device
    )
    new_mask_wh[paste_x_start:paste_x_end, paste_y_start:paste_y_end] = 1

    return output_image_whc, new_mask_wh
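# Illustrative call (a sketch; `img_whc` is an assumed [W, H, C] uint8 tensor): place the image
# inside the rectangle of ones of a 512x512 canvas mask; it is rescaled to fit and centered.
#   canvas_mask = torch.zeros(512, 512, dtype=torch.bool)
#   canvas_mask[100:300, 200:400] = True
#   placed_img, placed_mask = place_image_in_bounding_box(img_whc, canvas_mask)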
def compose_noise_masks(cached_pipe,
                        foreground_image: Image.Image,
                        background_image: Image.Image,
                        target_mask: torch.Tensor,
                        foreground_mask: torch.Tensor,
                        option: str = "bg",
                        photoshop_fg_noise: bool = False,
                        num_inversion_steps: int = 100,
                        ):
    """
    Composes noise masks for image generation using different strategies.

    This function composes noise masks for diffusion inversion, with several composition strategies:
    - "bg": Uses only background noise
    - "bg_fg": Combines background and foreground noise using a target mask
    - "segmentation1": Uses the segmentation mask to compose foreground and background noise
    - "segmentation2": Advanced composition that adds fresh noise on the boundary region

    Parameters
    ----------
    cached_pipe : object
        The cached diffusion pipeline used for noise inversion
    foreground_image : PIL.Image
        The foreground image to be placed in the background
    background_image : PIL.Image
        The background image
    target_mask : torch.Tensor
        Target mask indicating the position where the foreground should be placed
    foreground_mask : torch.Tensor
        Segmentation mask of the foreground object
    option : str, default="bg"
        Composition strategy: "bg", "bg_fg", "segmentation1", or "segmentation2"
    photoshop_fg_noise : bool, default=False
        Whether to generate the foreground noise from a "photoshopped" composite of
        foreground and background
    num_inversion_steps : int, default=100
        Number of steps for the inversion process

    Returns
    -------
    dict
        A dictionary containing:
        - "noise": Dictionary of generated noises (composed_noise, foreground_noise, background_noise)
        - "latent_masks": Dictionary of latent masks used for composition
    """
    assert option in ["bg", "bg_fg", "segmentation1", "segmentation2"], f"Invalid option: {option}"

    # Spatial size of the packed latent token grid: one token per 16x16 pixel patch
    # (VAE downsampling by 8, then 2x2 latent patches packed per token).
    PATCH_SIZE = 16
    latent_size = background_image.size[0] // PATCH_SIZE
    latents = (latent_size, latent_size)
    if option == "bg":
        # Background-only strategy: the inverted background noise is used as-is.
        bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
        composed_noise = bg_noise

        all_noise = {
            "composed_noise": composed_noise,
            "background_noise": bg_noise,
        }
        all_latent_masks = {}

    elif option == "bg_fg":
        # Place the full foreground image inside the bounding box given by the target mask.
        reframed_fg_img, resized_mask = place_image_in_bounding_box(
            torch.from_numpy(np.array(foreground_image)),
            (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
        )

        reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())

        resized_mask_img = Image.fromarray((resized_mask.numpy() * 255).astype(np.uint8))

        if photoshop_fg_noise:
            # Build a "photoshopped" composite (background outside the placement mask,
            # reframed foreground inside it) and invert that composite instead.
            photoshop_img = Image.fromarray(
                (
                    torch.tensor(np.array(background_image)) * ~resized_mask.cpu().unsqueeze(-1)
                    + torch.tensor(np.array(reframed_fg_img)) * resized_mask.cpu().unsqueeze(-1)
                ).numpy()
            )

            fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
        else:
            fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
        bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)

        # Downsample the placement mask to the latent token grid and flatten it to match the
        # packed latent sequence.
        latent_mask = resize_bounding_box(
            resized_mask,
            target_size=latents,
        ).flatten().unsqueeze(-1).to("cuda")

        composed_noise = bg_noise * (~latent_mask) + fg_noise * latent_mask
        all_latent_masks = {
            "latent_mask": latent_mask,
        }
        all_noise = {
            "composed_noise": composed_noise,
            "foreground_noise": fg_noise,
            "background_noise": bg_noise,
        }
    elif option == "segmentation1":
        # Keep only the segmented foreground object (everything else is zeroed out).
        segmented_fg_image = torch.tensor(
            np.array(foreground_mask.resize(foreground_image.size))
        ).to(torch.bool).unsqueeze(-1) * torch.tensor(np.array(foreground_image))

        # Place the segmented object inside the bounding box given by the target mask.
        reframed_fg_img, resized_mask = place_image_in_bounding_box(
            segmented_fg_image,
            (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
        )

        reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())

        resized_mask_img = Image.fromarray((resized_mask.numpy() * 255).astype(np.uint8))

        # Place the segmentation mask itself in the same bounding box.
        foreground_mask = foreground_mask.convert("RGB")
        reframed_segmentation_mask, resized_mask = place_image_in_bounding_box(
            torch.from_numpy(np.array(foreground_mask)),
            (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
        )

        reframed_segmentation_mask = reframed_segmentation_mask.numpy()
        reframed_segmentation_mask_img = Image.fromarray(reframed_segmentation_mask)

        if photoshop_fg_noise:
            # Build a "photoshopped" composite of background and placed foreground,
            # and invert that composite instead of the reframed foreground alone.
            seg_mask_temp = torch.from_numpy(reframed_segmentation_mask).bool()
            bg_temp = torch.tensor(np.array(background_image))
            fg_temp = torch.tensor(np.array(reframed_fg_img))

            photoshop_img = Image.fromarray(
                (bg_temp * (~seg_mask_temp) + fg_temp * seg_mask_temp).numpy()
            ).convert("RGB")

            fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
        else:
            fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)

        bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
        # When the pipeline returns a list of per-step latents, use the last (fully inverted) entry.
        bg_noise_init = bg_noise[-1].squeeze(0) if isinstance(bg_noise, list) else bg_noise
        fg_noise_init = fg_noise[-1].squeeze(0) if isinstance(fg_noise, list) else fg_noise

        # Reduce the placed segmentation mask to a single boolean channel and downsample both
        # it and the placement mask to the latent token grid.
        reframed_segmentation_mask = reframed_segmentation_mask[:, :, 0]
        reframed_segmentation_mask = torch.from_numpy(reframed_segmentation_mask).to(dtype=bool)
        latent_mask = resize_bounding_box(
            reframed_segmentation_mask,
            target_size=latents,
        ).flatten().unsqueeze(-1).to("cuda")
        bb_mask = resize_bounding_box(
            resized_mask,
            target_size=latents,
        ).flatten().unsqueeze(-1).to("cuda")

        # Background noise outside the object, foreground noise on the object.
        composed_noise = bg_noise_init * (~latent_mask) + fg_noise_init * latent_mask

        all_latent_masks = {
            "latent_segmentation_mask": latent_mask,
            "bb_mask": bb_mask,
        }
        all_noise = {
            "composed_noise": composed_noise,
            "foreground_noise": fg_noise_init,
            "background_noise": bg_noise_init,
            "foreground_noise_list": fg_noise if isinstance(fg_noise, list) else None,
            "background_noise_list": bg_noise if isinstance(bg_noise, list) else None,
        }
    elif option == "segmentation2":
        # Keep only the segmented foreground object (everything else is zeroed out).
        segmented_fg_image = torch.tensor(
            np.array(foreground_mask.resize(foreground_image.size))
        ).to(torch.bool).unsqueeze(-1) * torch.tensor(np.array(foreground_image))

        # Place the segmented object inside the bounding box given by the target mask.
        reframed_fg_img, resized_mask = place_image_in_bounding_box(
            segmented_fg_image,
            (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
        )

        reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())

        # Place the segmentation mask itself in the same bounding box.
        foreground_mask = foreground_mask.convert("RGB")
        reframed_segmentation_mask, resized_mask = place_image_in_bounding_box(
            torch.from_numpy(np.array(foreground_mask)),
            (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
        )

        reframed_segmentation_mask = reframed_segmentation_mask.numpy()
        reframed_segmentation_mask_img = Image.fromarray(reframed_segmentation_mask)

        # XOR of the target mask and the placed segmentation mask: the part of the target
        # region that is not covered by the placed object.
        xor_mask = target_mask ^ np.array(reframed_segmentation_mask_img.convert("L"))

        if photoshop_fg_noise:
            # Build a "photoshopped" composite of background and placed foreground,
            # and invert that composite instead of the reframed foreground alone.
            seg_mask_temp = torch.from_numpy(reframed_segmentation_mask).bool()
            bg_temp = torch.tensor(np.array(background_image))
            fg_temp = torch.tensor(np.array(reframed_fg_img))

            photoshop_img = Image.fromarray(
                (bg_temp * (~seg_mask_temp) + fg_temp * seg_mask_temp).numpy()
            ).convert("RGB")

            fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
        else:
            fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
        bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)

        # Reduce the placed segmentation mask to a single boolean channel.
        reframed_segmentation_mask = reframed_segmentation_mask[:, :, 0]
        reframed_segmentation_mask = torch.from_numpy(reframed_segmentation_mask).to(dtype=bool)
        # Downsample the masks to the latent token grid and flatten them to match the
        # packed latent sequence.
        latent_seg_mask = resize_bounding_box(
            reframed_segmentation_mask,
            target_size=latents,
        ).flatten().unsqueeze(-1).to("cuda")

        latent_xor_mask = resize_bounding_box(
            torch.from_numpy(xor_mask),
            target_size=latents,
        ).flatten().unsqueeze(-1).to("cuda")

        latent_target_mask = resize_bounding_box(
            resized_mask,
            target_size=latents,
        ).flatten().unsqueeze(-1).to("cuda")

        # When the pipeline returns a list of per-step latents, use the last (fully inverted) entry.
        bg_noise_init = bg_noise[-1].squeeze(0) if isinstance(bg_noise, list) else bg_noise
        fg_noise_init = fg_noise[-1].squeeze(0) if isinstance(fg_noise, list) else fg_noise

        # Background noise outside the placement region, foreground noise on the object,
        # and freshly sampled noise on the uncovered boundary region.
        bg = bg_noise_init * (~latent_target_mask)
        fg = fg_noise_init * latent_seg_mask
        boundary = latent_xor_mask * torch.randn(latent_xor_mask.shape).to("cuda")
        composed_noise = bg + fg + boundary
        all_latent_masks = {
            "latent_target_mask": latent_target_mask,
            "latent_segmentation_mask": latent_seg_mask,
            "latent_xor_mask": latent_xor_mask,
        }
        all_noise = {
            "composed_noise": composed_noise,
            "foreground_noise": fg_noise_init,
            "background_noise": bg_noise_init,
            "foreground_noise_list": fg_noise if isinstance(fg_noise, list) else None,
            "background_noise_list": bg_noise if isinstance(bg_noise, list) else None,
        }

    # Common to every option: the bounding-box mask on the latent grid, derived from the
    # target mask resized to the background resolution.
    latent_bbox_mask = resize_bounding_box(
        torch.from_numpy(np.array(target_mask.resize(background_image.size))),
        target_size=latents,
    ).flatten().unsqueeze(-1).to("cuda")
    all_latent_masks["latent_bbox_mask"] = latent_bbox_mask

    # Also expose, for every option, the latent mask of the area actually occupied by the
    # (unsegmented) foreground image once placed in the bounding box. Note that this
    # overwrites any "latent_segmentation_mask" set by the branches above.
    reframed_fg_img, resized_mask = place_image_in_bounding_box(
        torch.from_numpy(np.array(foreground_image)),
        (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
    )
    bb_mask = resize_bounding_box(
        resized_mask,
        target_size=latents,
    ).flatten().unsqueeze(-1).to("cuda")
    all_latent_masks["latent_segmentation_mask"] = bb_mask

    return {
        "noise": all_noise,
        "latent_masks": all_latent_masks,
    }
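# Illustrative end-to-end usage (a sketch, not a tested invocation): `cached_pipe`, `fg_img`,
# `bg_img`, `bbox_mask`, and `seg_mask` are assumed to be set up by the caller, with images and
# masks at the same square resolution expected by the branches above.
#   composition = compose_noise_masks(
#       cached_pipe,
#       foreground_image=fg_img,
#       background_image=bg_img,
#       target_mask=bbox_mask,
#       foreground_mask=seg_mask,
#       option="segmentation1",
#       num_inversion_steps=28,
#   )
#   init_noise = composition["noise"]["composed_noise"]
#   latent_masks = composition["latent_masks"]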