Spaces:
Build error
Build error
from typing import List, Optional, Tuple, Union | |
from PIL import Image | |
import numpy as np | |
import torch | |
from diffusers import ( | |
ControlNetModel, | |
StableDiffusionControlNetInpaintPipeline, | |
UniPCMultistepScheduler, | |
) | |
from .controlnet import StableDiffusionControlNet, MODEL_DICT | |
class StableDiffusionControlNetInpaint(StableDiffusionControlNet): | |
"""StableDiffusion with ControlNet model for inpainting images based on prompts. | |
Args: | |
control_model_name (str): | |
Name of the controlnet processor. | |
sd_model_name (str): | |
Name of the StableDiffusion model. | |
""" | |
def __init__( | |
self, | |
control_model_name: str, | |
sd_model_name: Optional[str] = "runwayml/stable-diffusion-inpainting", | |
) -> None: | |
super().__init__( | |
control_model_name=control_model_name, | |
sd_model_name=sd_model_name, | |
) | |
def create_pipe( | |
self, sd_model_name: str, control_model_name: str | |
) -> StableDiffusionControlNetInpaintPipeline: | |
"""Create a StableDiffusionControlNetInpaintPipeline. | |
Args: | |
sd_model_name (str): StableDiffusion model name. | |
control_model_name (str): Name of the ControlNet module. | |
Returns: | |
StableDiffusionControlNetInpaintPipeline | |
""" | |
controlnet = ControlNetModel.from_pretrained( | |
MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16 | |
) | |
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( | |
sd_model_name, controlnet=controlnet, torch_dtype=torch.float16 | |
) | |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) | |
pipe.enable_model_cpu_offload() | |
pipe.enable_xformers_memory_efficient_attention() | |
return pipe | |
def process( | |
self, | |
images: List[Image.Image], | |
prompts: List[str], | |
mask_images: List[Image.Image], | |
control_images: Optional[List[Image.Image]] = None, | |
negative_prompt: Optional[str] = None, | |
n_outputs: Optional[int] = 1, | |
num_inference_steps: Optional[int] = 30, | |
strength: Optional[float] = 1.0, | |
guidance_scale: Optional[float] = 7.5, | |
eta: Optional[float] = 0.0, | |
) -> List[List[Image.Image]]: | |
"""Inpaint images based on `prompts` using `control_images` and `mask_images`. | |
Args: | |
images (List[Image.Image]): Input images. | |
prompts (List[str]): List of prompts. | |
mask_images (List[Image.Image]): List of mask images. | |
control_images (Optional[List[Image.Image]], optional): List of control images. Defaults to None. | |
negative_prompt (Optional[str], optional): Negative prompt. Defaults to None. | |
n_outputs (Optional[int], optional): Number of generated outputs. Defaults to 1. | |
num_inference_steps (Optional[int], optional): Number of inference iterations. Defaults to 30. | |
Returns: | |
List[List[Image.Image]] | |
""" | |
if control_images is None: | |
control_images = self.generate_control_images(images) | |
assert len(prompts) == len( | |
control_images | |
), "Number of prompts and input images must be equal." | |
if n_outputs > 1: | |
prompts = self._repeat(prompts, n=n_outputs) | |
images = self._repeat(images, n=n_outputs) | |
control_images = self._repeat(control_images, n=n_outputs) | |
mask_images = self._repeat(mask_images, n=n_outputs) | |
generator = [ | |
torch.Generator(device="cuda").manual_seed(int(i)) | |
for i in np.random.randint(max(len(prompts), 16), size=len(prompts)) | |
] | |
output = self.pipe( | |
prompts, | |
image=images, | |
control_image=control_images, | |
mask_image=mask_images, | |
negative_prompt=[negative_prompt] * len(prompts), | |
num_inference_steps=num_inference_steps, | |
generator=generator, | |
) | |
output_images = [ | |
output.images[idx * n_outputs : (idx + 1) * n_outputs] | |
for idx in range(len(images) // n_outputs) | |
] | |
return { | |
"output_images": output_images, | |
"control_images": control_images, | |
"mask_images": mask_images, | |
} | |