|
import inspect |
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence |
|
|
|
import wandb |
|
from wandb.util import get_module |
|
|
|
if TYPE_CHECKING: |
|
np_array = get_module("numpy.array") |
|
torch_float_tensor = get_module("torch.FloatTensor") |
|
|
|
|
|
def chunkify(input_list, chunk_size) -> List:
    """Split ``input_list`` into consecutive sublists of at most ``chunk_size`` items.

    A non-positive ``chunk_size`` is clamped to 1 so the slicing step is
    always valid. The final chunk may be shorter than ``chunk_size`` when
    the list length is not an exact multiple of it.
    """
    step = max(1, chunk_size)
    starts = range(0, len(input_list), step)
    return [input_list[start : start + step] for start in starts]
|
|
|
|
|
def get_updated_kwargs(
    pipeline: Any, args: Sequence[Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Normalize a pipeline call's arguments into a single kwargs dict for logging.

    Positional arguments are mapped onto the parameter names of the
    pipeline's ``__call__`` signature, and any parameter the caller omitted
    is backfilled with its declared default. A ``torch.Generator`` value is
    replaced by a serializable summary (seed, device, random state), and a
    non-``None`` ``ip_adapter_image`` is logged to wandb as an image.

    NOTE(review): parameters that declare no default are backfilled with the
    ``inspect.Parameter.empty`` sentinel — presumably tolerated downstream;
    confirm against the consumers of this dict.

    Args:
        pipeline: The diffusers pipeline whose ``__call__`` signature is inspected.
        args: Positional arguments the pipeline was called with.
        kwargs: Keyword arguments the pipeline was called with; mutated in place.

    Returns:
        The updated ``kwargs`` dictionary.
    """
    call_params = list(inspect.signature(pipeline.__call__).parameters.items())
    # Fold positional arguments into kwargs under their parameter names.
    for position, value in enumerate(args):
        kwargs[call_params[position][0]] = value
    # Backfill every parameter the caller did not supply with its default.
    for name, parameter in call_params:
        if name not in kwargs:
            kwargs[name] = parameter.default
    if "generator" in kwargs:
        generator = kwargs["generator"]
        if generator is None:
            kwargs["generator"] = None
        else:
            # Replace the live torch.Generator with a JSON-friendly summary.
            kwargs["generator"] = {
                "seed": generator.initial_seed(),
                "device": generator.device,
                "random_state": generator.get_state().cpu().numpy().tolist(),
            }
    if "ip_adapter_image" in kwargs:
        if kwargs["ip_adapter_image"] is not None:
            wandb.log({"IP-Adapter-Image": wandb.Image(kwargs["ip_adapter_image"])})
    return kwargs
|
|
|
|
|
def postprocess_pils_to_np(image: List) -> "np_array":
    """Convert a list of PIL images into one ``uint8`` NCHW numpy batch.

    Each image is converted to a ``uint8`` array, transposed from HWC to
    CHW, and the results are stacked along a new leading batch axis.
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    channel_first = [
        np.transpose(np.array(img).astype("uint8"), axes=(2, 0, 1)) for img in image
    ]
    return np.stack(channel_first, axis=0)
|
|
|
|
|
def postprocess_np_arrays_for_video(
    images: List["np_array"], normalize: Optional[bool] = False
) -> "np_array":
    """Stack a list of HWC frames into one NCHW numpy video batch.

    When ``normalize`` is truthy, each frame is scaled by 255 and cast to
    ``uint8`` first (assumes float frames in [0, 1] — TODO confirm with
    callers); otherwise frames are stacked as-is.
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    if normalize:
        images = [(frame * 255).astype("uint8") for frame in images]
    stacked = np.stack(images, axis=0)
    return np.transpose(stacked, axes=(0, 3, 1, 2))
|
|
|
|
|
def decode_sdxl_t2i_latents(pipeline: Any, latents: "torch_float_tensor") -> List:
    """Decode latents generated by [`diffusers.StableDiffusionXLPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#stable-diffusion-xl).

    Args:
        pipeline: (diffusers.DiffusionPipeline) The Diffusion Pipeline from
            [`diffusers`](https://huggingface.co/docs/diffusers).
        latents (torch.FloatTensor): The generated latents.

    Returns:
        List of `PIL` images corresponding to the generated latents.
    """
    torch = get_module(
        "torch",
        required="Please ensure PyTorch is installed. You can check out https://pytorch.org/get-started/locally/#start-locally for installation instructions.",
    )
    # Inference only — no gradients are needed for decoding.
    with torch.no_grad():
        # When the VAE runs in fp16 and its config requests upcasting,
        # temporarily decode in higher precision.
        needs_upcasting = (
            pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
        )
        if needs_upcasting:
            pipeline.upcast_vae()
            # Align the latents' dtype with the (now upcast) VAE parameters.
            latents = latents.to(
                next(iter(pipeline.vae.post_quant_conv.parameters())).dtype
            )
        # Undo the encode-time scaling factor before decoding to image space.
        images = pipeline.vae.decode(
            latents / pipeline.vae.config.scaling_factor, return_dict=False
        )[0]
        # Restore the VAE to fp16 so subsequent pipeline calls see it unchanged.
        if needs_upcasting:
            pipeline.vae.to(dtype=torch.float16)
        # SDXL pipelines may carry a watermarker; apply it when configured.
        if pipeline.watermark is not None:
            images = pipeline.watermark.apply_watermark(images)
        # Convert the decoded tensors to PIL images via the pipeline's processor.
        images = pipeline.image_processor.postprocess(images, output_type="pil")
        # Let accelerate release any model-offload hooks and free memory.
        pipeline.maybe_free_model_hooks()
        return images
|
|