Spaces:
Build error
Build error
File size: 4,367 Bytes
5b2ab1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from typing import List, Optional, Tuple, Union
from PIL import Image
import numpy as np
import torch
from diffusers import (
ControlNetModel,
StableDiffusionControlNetInpaintPipeline,
UniPCMultistepScheduler,
)
from .controlnet import StableDiffusionControlNet, MODEL_DICT
class StableDiffusionControlNetInpaint(StableDiffusionControlNet):
    """StableDiffusion with ControlNet model for inpainting images based on prompts.

    Args:
        control_model_name (str):
            Name of the controlnet processor; used as a key into ``MODEL_DICT``.
        sd_model_name (str):
            Name of the StableDiffusion inpainting model.
    """

    def __init__(
        self,
        control_model_name: str,
        sd_model_name: Optional[str] = "runwayml/stable-diffusion-inpainting",
    ) -> None:
        # All setup is delegated to the base class; only the default
        # StableDiffusion checkpoint differs (an inpainting model).
        super().__init__(
            control_model_name=control_model_name,
            sd_model_name=sd_model_name,
        )

    def create_pipe(
        self, sd_model_name: str, control_model_name: str
    ) -> StableDiffusionControlNetInpaintPipeline:
        """Create a StableDiffusionControlNetInpaintPipeline.

        The ControlNet weights are looked up via
        ``MODEL_DICT[control_model_name]["model"]``. Both models are loaded in
        float16. The pipeline's scheduler is replaced with UniPC, and model CPU
        offload and xformers attention are enabled.

        Args:
            sd_model_name (str): StableDiffusion model name.
            control_model_name (str): Name of the ControlNet module
                (a key of ``MODEL_DICT``).

        Returns:
            StableDiffusionControlNetInpaintPipeline: The configured pipeline.
        """
        controlnet = ControlNetModel.from_pretrained(
            MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, torch_dtype=torch.float16
        )
        # Swap in UniPC while reusing the checkpoint's scheduler configuration.
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        return pipe

    def process(
        self,
        images: List[Image.Image],
        prompts: List[str],
        mask_images: List[Image.Image],
        control_images: Optional[List[Image.Image]] = None,
        negative_prompt: Optional[str] = None,
        n_outputs: Optional[int] = 1,
        num_inference_steps: Optional[int] = 30,
        strength: Optional[float] = 1.0,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
    ) -> dict:
        """Inpaint images based on `prompts` using `control_images` and `mask_images`.

        Args:
            images (List[Image.Image]): Input images, one per prompt.
            prompts (List[str]): List of prompts.
            mask_images (List[Image.Image]): List of mask images, one per prompt.
            control_images (Optional[List[Image.Image]], optional): List of
                control images. If None, they are produced by
                ``self.generate_control_images(images)``. Defaults to None.
            negative_prompt (Optional[str], optional): Negative prompt applied
                to every prompt. Defaults to None.
            n_outputs (Optional[int], optional): Number of generated outputs
                per prompt. Defaults to 1.
            num_inference_steps (Optional[int], optional): Number of inference
                iterations. Defaults to 30.
            strength (Optional[float], optional): Forwarded to the pipeline's
                ``strength`` argument. Defaults to 1.0.
            guidance_scale (Optional[float], optional): Forwarded to the
                pipeline's ``guidance_scale`` argument. Defaults to 7.5.
            eta (Optional[float], optional): Forwarded to the pipeline's
                ``eta`` argument. Defaults to 0.0.

        Returns:
            dict: A dict with the following keys:
                ``"output_images"``: a list with one entry per input prompt,
                    each a list of ``n_outputs`` generated images.
                ``"control_images"``: the control images actually fed to the
                    pipeline (after ``_repeat`` expansion when ``n_outputs > 1``).
                ``"mask_images"``: the mask images actually fed to the pipeline
                    (same expansion as above).

        Raises:
            AssertionError: If the numbers of prompts, images, mask images and
                control images differ.
        """
        if control_images is None:
            control_images = self.generate_control_images(images)

        # Remember the batch size before any repetition; outputs are grouped by it.
        n_inputs = len(prompts)
        assert n_inputs == len(images), "Number of prompts and input images must be equal."
        assert n_inputs == len(mask_images), "Number of prompts and mask images must be equal."
        assert n_inputs == len(control_images), "Number of prompts and control images must be equal."

        if n_outputs > 1:
            prompts = self._repeat(prompts, n=n_outputs)
            images = self._repeat(images, n=n_outputs)
            control_images = self._repeat(control_images, n=n_outputs)
            mask_images = self._repeat(mask_images, n=n_outputs)

        # One seed per generated image, drawn from a large range. A tiny range
        # makes repeated prompts likely to share a seed and yield duplicate
        # images. The global NumPy RNG is used, so seeding NumPy still controls
        # which seeds are chosen.
        seeds = np.random.randint(0, 2**31 - 1, size=len(prompts))
        generator = [torch.Generator(device="cuda").manual_seed(int(s)) for s in seeds]

        # diffusers accepts None or a list of strings for negative_prompt; a
        # list of None entries would be handed to the tokenizer and fail.
        negative_prompts = (
            None if negative_prompt is None else [negative_prompt] * len(prompts)
        )

        output = self.pipe(
            prompts,
            image=images,
            control_image=control_images,
            mask_image=mask_images,
            negative_prompt=negative_prompts,
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=guidance_scale,
            eta=eta,
            generator=generator,
        )

        # NOTE(review): this grouping assumes `_repeat` places the copies of
        # each item consecutively (a, a, b, b, ...) — confirm in the base class.
        output_images = [
            output.images[idx * n_outputs : (idx + 1) * n_outputs]
            for idx in range(n_inputs)
        ]
        return {
            "output_images": output_images,
            "control_images": control_images,
            "mask_images": mask_images,
        }
|