# Uploaded by jamtur01 via huggingface_hub (commit 9c6594c, verified).
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
import wandb
from wandb.util import get_module
if TYPE_CHECKING:
np_array = get_module("numpy.array")
torch_float_tensor = get_module("torch.FloatTensor")
def chunkify(input_list, chunk_size) -> List:
chunk_size = max(1, chunk_size)
return [
input_list[i : i + chunk_size] for i in range(0, len(input_list), chunk_size)
]
def _serialize_generator(generator: Any) -> Optional[Dict[str, Any]]:
    """Describe one generator as a loggable dict; ``None`` passes through unchanged."""
    if generator is None:
        return None
    return {
        "seed": generator.initial_seed(),
        "device": generator.device,
        "random_state": generator.get_state().cpu().numpy().tolist(),
    }


def get_updated_kwargs(
    pipeline: Any, args: Sequence[Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Build a complete, loggable keyword-argument dict for a pipeline call.

    Positional ``args`` are attributed to the parameter names of
    ``pipeline.__call__`` by position. Every parameter not supplied explicitly
    is filled in with its signature default. Parameters without a default
    (including ``*args``/``**kwargs`` entries) receive
    ``inspect.Parameter.empty``.

    A ``generator`` value is replaced by a dict with its seed, device and RNG
    state. When it is a list or tuple of generators, each element is
    serialized and a list is returned. If ``ip_adapter_image`` is set, it is
    logged to wandb as an image.

    Args:
        pipeline: The pipeline object whose ``__call__`` was invoked.
        args: Positional arguments passed to ``__call__`` (excluding ``self``).
        kwargs: Keyword arguments passed to ``__call__``. Not modified.

    Returns:
        A new dict mapping every ``__call__`` parameter name to its value.

    Raises:
        IndexError: If ``args`` has more entries than ``__call__`` has parameters.
    """
    # Copy so the caller's dict keeps its original (live) generator objects.
    kwargs = dict(kwargs)
    pipeline_call_parameters = list(
        inspect.signature(pipeline.__call__).parameters.items()
    )
    for idx, arg in enumerate(args):
        kwargs[pipeline_call_parameters[idx][0]] = arg
    for name, parameter in pipeline_call_parameters:
        if name not in kwargs:
            kwargs[name] = parameter.default
    if "generator" in kwargs:
        generator = kwargs["generator"]
        # diffusers accepts either a single generator or one per image.
        if isinstance(generator, (list, tuple)):
            kwargs["generator"] = [_serialize_generator(gen) for gen in generator]
        else:
            kwargs["generator"] = _serialize_generator(generator)
    if kwargs.get("ip_adapter_image") is not None:
        wandb.log({"IP-Adapter-Image": wandb.Image(kwargs["ip_adapter_image"])})
    return kwargs
def postprocess_pils_to_np(image: List) -> "np_array":
    """Convert a list of PIL images into one channels-first ``uint8`` array.

    Each image is turned into a NumPy array, cast to ``uint8`` and transposed
    with axes ``(2, 0, 1)``, moving the last axis to the front. The per-image
    arrays are then stacked along a new leading axis.

    Assumes each image converts to a 3-D ``(H, W, C)`` array (e.g. RGB).
    TODO confirm for other PIL modes; 2-D (grayscale) arrays would fail the
    transpose.

    Args:
        image: A list of PIL images of identical size and mode.

    Returns:
        A NumPy array laid out as ``(N, C, H, W)``.
    """
    numpy_module = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    channels_first = []
    for pil_image in image:
        pixels = numpy_module.array(pil_image).astype("uint8")
        channels_first.append(numpy_module.transpose(pixels, axes=(2, 0, 1)))
    return numpy_module.stack(channels_first, axis=0)
def postprocess_np_arrays_for_video(
    images: List["np_array"], normalize: Optional[bool] = False
) -> "np_array":
    """Stack per-frame arrays into one channels-first video array.

    The frames are stacked along a new leading axis, and the result is
    transposed with axes ``(0, 3, 1, 2)``, which moves the last axis to
    position 1. Assumes each frame is ``(H, W, C)`` (TODO confirm with
    callers), giving an ``(F, C, H, W)`` result.

    Args:
        images: Sequence of frame arrays of identical shape.
        normalize: If true, frames are treated as floats in ``[0, 1]``. They
            are scaled by 255, saturated to ``[0, 255]`` and cast to ``uint8``.

    Returns:
        The stacked and transposed NumPy array.
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    if normalize:
        # Clip before casting. Out-of-range floats would otherwise wrap or give
        # platform-dependent values in the float -> uint8 conversion.
        images = [np.clip(img * 255, 0, 255).astype("uint8") for img in images]
    return np.transpose(np.stack(images, axis=0), axes=(0, 3, 1, 2))
def decode_sdxl_t2i_latents(pipeline: Any, latents: "torch_float_tensor") -> List:
    """Decode latents generated by [`diffusers.StableDiffusionXLPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#stable-diffusion-xl).

    Runs without gradient tracking and performs these steps in order:

    1. Upcast a float16 VAE whose config sets ``force_upcast`` (restoring
       float16 afterwards, even if decoding fails).
    2. Decode ``latents / scaling_factor``.
    3. Apply the pipeline's watermarker, if it has one.
    4. Post-process the result to PIL images.
    5. Call ``maybe_free_model_hooks``.

    Args:
        pipeline: (diffusers.DiffusionPipeline) The Diffusion Pipeline from
            [`diffusers`](https://huggingface.co/docs/diffusers).
        latents (torch.FloatTensor): The generated latents.

    Returns:
        List of `PIL` images corresponding to the generated latents.
    """
    torch = get_module(
        "torch",
        required="Please ensure PyTorch is installed. You can check out https://pytorch.org/get-started/locally/#start-locally for installation instructions.",
    )
    with torch.no_grad():
        needs_upcasting = (
            pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
        )
        try:
            if needs_upcasting:
                pipeline.upcast_vae()
                # Match the latents' dtype to the (now upcast) VAE weights.
                latents = latents.to(
                    next(iter(pipeline.vae.post_quant_conv.parameters())).dtype
                )
            images = pipeline.vae.decode(
                latents / pipeline.vae.config.scaling_factor, return_dict=False
            )[0]
        finally:
            # Always restore float16. Otherwise a failed decode would leave the
            # caller's pipeline VAE in the upcast dtype for later calls.
            if needs_upcasting:
                pipeline.vae.to(dtype=torch.float16)
        # Not every SDXL-compatible pipeline defines a `watermark` attribute.
        watermark = getattr(pipeline, "watermark", None)
        if watermark is not None:
            images = watermark.apply_watermark(images)
        images = pipeline.image_processor.postprocess(images, output_type="pil")
        pipeline.maybe_free_model_hooks()
    return images