from typing import Any, List, Optional, Tuple, Union import itertools from PIL import Image import numpy as np import torch from controlnet_aux import MLSDdetector, PidiNetDetector, HEDdetector from diffusers import ( ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler, ) MODEL_DICT = { "mlsd": { "name": "lllyasviel/Annotators", "detector": MLSDdetector, "model": "lllyasviel/control_v11p_sd15_mlsd", }, "soft_edge": { "name": "lllyasviel/Annotators", "detector": PidiNetDetector, "model": "lllyasviel/control_v11p_sd15_softedge", }, "hed": { "name": "lllyasviel/Annotators", "detector": HEDdetector, "model": "lllyasviel/sd-controlnet-hed", }, "scribble": { "name": "lllyasviel/Annotators", "detector": HEDdetector, "model": "lllyasviel/control_v11p_sd15_scribble", }, } class StableDiffusionControlNet: """ControlNet pipeline for generating images from prompts. Args: control_model_name (str): Name of the controlnet processor. sd_model_name (str): Name of the StableDiffusion model. """ def __init__( self, control_model_name: str, sd_model_name: Optional[str] = "runwayml/stable-diffusion-v1-5", ) -> None: self.processor = MODEL_DICT[control_model_name]["detector"].from_pretrained( MODEL_DICT[control_model_name]["name"] ) self.pipe = self.create_pipe( sd_model_name=sd_model_name, control_model_name=control_model_name ) def _repeat(self, items: List[Any], n: int) -> List[Any]: """Repeat items in a list n times. Args: items (List[Any]): List of items to be repeated. n (int): Number of repetitions. Returns: List[Any]: List of repeated items. """ return list( itertools.chain.from_iterable(itertools.repeat(item, n) for item in items) ) def generate_control_images(self, images: List[Image.Image]) -> List[Image.Image]: """Generate control images from input images. Args: images (List[Image.Image]): Input images. Returns: List[Image.Image]: Control images. """ return [self.processor(image) for image in images] def create_pipe( self, sd_model_name: str, control_model_name: str ) -> StableDiffusionControlNetPipeline: """Create a StableDiffusionControlNetPipeline. Args: sd_model_name (str): StableDiffusion model name. control_model_name (str): Name of the ControlNet module. Returns: StableDiffusionControlNetPipeline """ controlnet = ControlNetModel.from_pretrained( MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16 ) pipe = StableDiffusionControlNetPipeline.from_pretrained( sd_model_name, controlnet=controlnet, torch_dtype=torch.float16 ) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload() pipe.enable_xformers_memory_efficient_attention() return pipe def process( self, images: List[Image.Image], prompts: List[str], negative_prompt: Optional[str] = None, n_outputs: Optional[int] = 1, num_inference_steps: Optional[int] = 30, ) -> List[List[Image.Image]]: """Generate images from `prompts` using `control_images` and `negative_prompt`. Args: images (List[Image.Image]): Input images. prompts (List[str]): List of prompts. negative_prompt (Optional[str], optional): Negative prompt. Defaults to None. n_outputs (Optional[int], optional): Number of generated outputs. Defaults to 1. num_inference_steps (Optional[int], optional): Number of inference iterations. Defaults to 30. Returns: List[List[Image.Image]] """ control_images = self.generate_control_images(images) assert len(prompts) == len( control_images ), "Number of prompts and input images must be equal." if n_outputs > 1: prompts = self._repeat(prompts, n=n_outputs) control_images = self._repeat(control_images, n=n_outputs) generator = [ torch.Generator(device="cuda").manual_seed(int(i)) for i in np.random.randint(len(prompts), size=len(prompts)) ] output = self.pipe( prompts, image=control_images, negative_prompt=[negative_prompt] * len(prompts), num_inference_steps=num_inference_steps, generator=generator, ) output_images = [ output.images[idx * n_outputs : (idx + 1) * n_outputs] for idx in range(len(images)) ] return output_images