from typing import Dict, List, Optional
from PIL import Image
import numpy as np
import torch
from diffusers import (
ControlNetModel,
StableDiffusionControlNetInpaintPipeline,
UniPCMultistepScheduler,
)
from .controlnet import StableDiffusionControlNet, MODEL_DICT
class StableDiffusionControlNetInpaint(StableDiffusionControlNet):
    """StableDiffusion with ControlNet for inpainting images based on text prompts.

    Args:
        control_model_name (str):
            Name of the ControlNet processor; must be a key of `MODEL_DICT`.
        sd_model_name (str):
            Name of the StableDiffusion inpainting model.
    """
def __init__(
self,
control_model_name: str,
        sd_model_name: str = "runwayml/stable-diffusion-inpainting",
) -> None:
super().__init__(
control_model_name=control_model_name,
sd_model_name=sd_model_name,
)
def create_pipe(
self, sd_model_name: str, control_model_name: str
) -> StableDiffusionControlNetInpaintPipeline:
"""Create a StableDiffusionControlNetInpaintPipeline.
Args:
sd_model_name (str): StableDiffusion model name.
control_model_name (str): Name of the ControlNet module.
Returns:
StableDiffusionControlNetInpaintPipeline
"""
        # Load the ControlNet weights matching the requested processor.
        controlnet = ControlNetModel.from_pretrained(
            MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, torch_dtype=torch.float16
        )
        # UniPC reaches good samples in fewer steps than the default scheduler.
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        # Reduce GPU memory usage: offload idle submodules to CPU and use
        # xformers memory-efficient attention kernels.
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        return pipe
    def process(
        self,
        images: List[Image.Image],
        prompts: List[str],
        mask_images: List[Image.Image],
        control_images: Optional[List[Image.Image]] = None,
        negative_prompt: Optional[str] = None,
        n_outputs: int = 1,
        num_inference_steps: int = 30,
        strength: float = 1.0,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
    ) -> Dict[str, List]:
"""Inpaint images based on `prompts` using `control_images` and `mask_images`.
Args:
images (List[Image.Image]): Input images.
prompts (List[str]): List of prompts.
mask_images (List[Image.Image]): List of mask images.
control_images (Optional[List[Image.Image]], optional): List of control images. Defaults to None.
negative_prompt (Optional[str], optional): Negative prompt. Defaults to None.
n_outputs (Optional[int], optional): Number of generated outputs. Defaults to 1.
num_inference_steps (Optional[int], optional): Number of inference iterations. Defaults to 30.
Returns:
List[List[Image.Image]]
"""
        # Derive control images (e.g. edge or depth maps) from the inputs
        # when none are provided.
        if control_images is None:
            control_images = self.generate_control_images(images)
        assert len(prompts) == len(
            control_images
        ), "Number of prompts and control images must be equal."
if n_outputs > 1:
prompts = self._repeat(prompts, n=n_outputs)
images = self._repeat(images, n=n_outputs)
control_images = self._repeat(control_images, n=n_outputs)
mask_images = self._repeat(mask_images, n=n_outputs)
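        # Seed one torch.Generator per prompt with random seeds so the
        # duplicated inputs vary across the batch.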
generator = [
torch.Generator(device="cuda").manual_seed(int(i))
for i in np.random.randint(max(len(prompts), 16), size=len(prompts))
]
        output = self.pipe(
            prompts,
            image=images,
            control_image=control_images,
            mask_image=mask_images,
            negative_prompt=(
                [negative_prompt] * len(prompts) if negative_prompt is not None else None
            ),
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=guidance_scale,
            eta=eta,
            generator=generator,
        )
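        # Regroup the flat output batch into `n_outputs` images per original input.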
output_images = [
output.images[idx * n_outputs : (idx + 1) * n_outputs]
for idx in range(len(images) // n_outputs)
]
return {
"output_images": output_images,
"control_images": control_images,
"mask_images": mask_images,
}
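

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes that
    # "canny" is a valid `MODEL_DICT` key and that the example image and mask
    # files exist; adjust both to your setup.
    model = StableDiffusionControlNetInpaint(control_model_name="canny")
    image = Image.open("example.png").convert("RGB")
    mask = Image.open("example_mask.png").convert("RGB")  # white = region to repaint
    results = model.process(
        images=[image],
        prompts=["a red brick wall"],
        mask_images=[mask],
        negative_prompt="blurry, low quality",
        n_outputs=2,
    )
    # `output_images` holds `n_outputs` candidates per input image.
    results["output_images"][0][0].save("inpainted.png")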