# naderasadi's picture
# Initial commit
# 5b2ab1c
from typing import Any, List, Optional, Tuple, Union
import itertools
from PIL import Image
import numpy as np
import torch
from controlnet_aux import MLSDdetector, PidiNetDetector, HEDdetector
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
)
# Registry of supported control types. Each entry holds:
#   "name":     repository id handed to the detector's ``from_pretrained``
#   "detector": annotator class that turns an input image into a control image
#   "model":    ControlNet checkpoint id loaded for that control type
MODEL_DICT = {
    # Straight-line detection.
    "mlsd": {
        "name": "lllyasviel/Annotators",
        "detector": MLSDdetector,
        "model": "lllyasviel/control_v11p_sd15_mlsd",
    },
    # Soft edges.
    "soft_edge": {
        "name": "lllyasviel/Annotators",
        "detector": PidiNetDetector,
        "model": "lllyasviel/control_v11p_sd15_softedge",
    },
    # HED edges.
    "hed": {
        "name": "lllyasviel/Annotators",
        "detector": HEDdetector,
        "model": "lllyasviel/sd-controlnet-hed",
    },
    # Scribble conditioning; shares the HED detector class with "hed".
    "scribble": {
        "name": "lllyasviel/Annotators",
        "detector": HEDdetector,
        "model": "lllyasviel/control_v11p_sd15_scribble",
    },
}
class StableDiffusionControlNet:
    """ControlNet pipeline for generating images from prompts.

    On construction, the detector registered under ``control_model_name`` in
    ``MODEL_DICT`` is loaded and a ``StableDiffusionControlNetPipeline`` is
    built around the matching ControlNet checkpoint.

    Args:
        control_model_name (str):
            Name of the controlnet processor. Must be a key of ``MODEL_DICT``.
        sd_model_name (str):
            Name of the StableDiffusion model.

    Raises:
        KeyError: If ``control_model_name`` is not a key of ``MODEL_DICT``.
    """

    # Exclusive upper bound for randomly drawn generator seeds. Kept within
    # the int32 range so it is valid for NumPy's default integer dtype on
    # every platform.
    _SEED_UPPER_BOUND = 2**31 - 1

    def __init__(
        self,
        control_model_name: str,
        sd_model_name: Optional[str] = "runwayml/stable-diffusion-v1-5",
    ) -> None:
        config = MODEL_DICT[control_model_name]
        # Annotator that converts input images into control images.
        self.processor = config["detector"].from_pretrained(config["name"])
        self.pipe = self.create_pipe(
            sd_model_name=sd_model_name, control_model_name=control_model_name
        )

    def _repeat(self, items: List[Any], n: int) -> List[Any]:
        """Repeat every item of a list n times, keeping repeats adjacent.

        Example: ``_repeat(["a", "b"], 2)`` gives ``["a", "a", "b", "b"]``.

        Args:
            items (List[Any]): List of items to be repeated.
            n (int): Number of repetitions.

        Returns:
            List[Any]: List of repeated items.
        """
        return list(
            itertools.chain.from_iterable(itertools.repeat(item, n) for item in items)
        )

    def generate_control_images(
        self, images: "List[Image.Image]"
    ) -> "List[Image.Image]":
        """Generate control images from input images.

        Each image is passed through ``self.processor`` independently.

        Args:
            images (List[Image.Image]): Input images.

        Returns:
            List[Image.Image]: Control images, one per input image, in order.
        """
        return [self.processor(image) for image in images]

    def create_pipe(
        self, sd_model_name: str, control_model_name: str
    ) -> "StableDiffusionControlNetPipeline":
        """Create a StableDiffusionControlNetPipeline.

        The ControlNet and the base model are both loaded in float16, the
        scheduler is replaced with ``UniPCMultistepScheduler`` (built from the
        pipeline's existing scheduler config), and model CPU offload plus
        xformers memory-efficient attention are enabled.

        Args:
            sd_model_name (str): StableDiffusion model name.
            control_model_name (str): Name of the ControlNet module.

        Returns:
            StableDiffusionControlNetPipeline
        """
        controlnet = ControlNetModel.from_pretrained(
            MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        return pipe

    def process(
        self,
        images: "List[Image.Image]",
        prompts: List[str],
        negative_prompt: Optional[str] = None,
        n_outputs: Optional[int] = 1,
        num_inference_steps: Optional[int] = 30,
    ) -> "List[List[Image.Image]]":
        """Generate images from `prompts`, conditioned on `images`.

        Every input image is converted to a control image and paired with the
        prompt at the same index. Each (prompt, control image) pair is
        generated ``n_outputs`` times, each time with its own random seed.

        Args:
            images (List[Image.Image]): Input images.
            prompts (List[str]): List of prompts, one per input image.
            negative_prompt (Optional[str], optional): Negative prompt applied
                to every sample. Defaults to None (no negative prompt).
            n_outputs (Optional[int], optional): Number of generated outputs
                per input. None is treated as 1. Defaults to 1.
            num_inference_steps (Optional[int], optional): Number of inference
                iterations. None is treated as 30. Defaults to 30.

        Returns:
            List[List[Image.Image]]: One list of ``n_outputs`` images per
            input image, in input order. Empty if ``images`` is empty.

        Raises:
            ValueError: If ``prompts`` and ``images`` differ in length, or if
                ``n_outputs`` is smaller than 1.
        """
        # Validate before running the detector so bad input fails cheaply.
        if len(prompts) != len(images):
            raise ValueError("Number of prompts and input images must be equal.")
        if n_outputs is None:
            n_outputs = 1
        if n_outputs < 1:
            raise ValueError("n_outputs must be at least 1.")
        if num_inference_steps is None:
            num_inference_steps = 30
        if not images:
            return []

        control_images = self.generate_control_images(images)
        if n_outputs > 1:
            prompts = self._repeat(prompts, n=n_outputs)
            control_images = self._repeat(control_images, n=n_outputs)
        batch_size = len(prompts)

        # Draw seeds from a wide range. Seeding from [0, batch_size) made
        # repeated samples of one prompt frequently share a seed (and a single
        # prompt always use seed 0), producing identical images.
        seeds = np.random.randint(self._SEED_UPPER_BOUND, size=batch_size)
        generator = [
            torch.Generator(device="cuda").manual_seed(int(seed)) for seed in seeds
        ]

        # Forward None untouched: a list of None entries is not a valid batch
        # of negative prompts for the pipeline.
        negative_prompts = (
            None if negative_prompt is None else [negative_prompt] * batch_size
        )

        output = self.pipe(
            prompts,
            image=control_images,
            negative_prompt=negative_prompts,
            num_inference_steps=num_inference_steps,
            generator=generator,
        )
        # Regroup the flat result: samples of the same input are adjacent.
        return [
            output.images[idx * n_outputs : (idx + 1) * n_outputs]
            for idx in range(len(images))
        ]