import glob
import logging
import os
import re
from functools import partial
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Dict, List, Union

import numpy as np
import torch
from controlnet_aux import LineartAnimeDetector
from controlnet_aux.processor import MODELS
from controlnet_aux.processor import Processor as ControlnetPreProcessor
from controlnet_aux.util import HWC3, ade_palette
from controlnet_aux.util import resize_image as aux_resize_image
from diffusers import (AutoencoderKL, ControlNetModel, DiffusionPipeline,
                       EulerDiscreteScheduler,
                       StableDiffusionControlNetImg2ImgPipeline,
                       StableDiffusionPipeline, StableDiffusionXLPipeline)
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS
from tqdm.rich import tqdm
from transformers import (AutoImageProcessor, CLIPImageProcessor,
                          CLIPTextConfig, CLIPTextModel,
                          CLIPTextModelWithProjection, CLIPTokenizer,
                          UperNetForSemanticSegmentation)

from animatediff import get_dir
from animatediff.dwpose import DWposeDetector
from animatediff.models.clip import CLIPSkipTextModel
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines import AnimationPipeline, load_text_embeddings
from animatediff.pipelines.lora import load_lcm_lora, load_lora_map
from animatediff.pipelines.pipeline_controlnet_img2img_reference import \
    StableDiffusionControlNetImg2ImgReferencePipeline
from animatediff.schedulers import DiffusionScheduler, get_scheduler
from animatediff.settings import InferenceConfig, ModelConfig
from animatediff.utils.control_net_lllite import (ControlNetLLLite,
                                                  load_controlnet_lllite)
from animatediff.utils.convert_from_ckpt import convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatediff.utils.model import (ensure_motion_modules,
                                     get_checkpoint_weights,
                                     get_checkpoint_weights_sdxl)
from animatediff.utils.util import (get_resized_image, get_resized_image2,
                                    get_resized_images,
                                    get_tensor_interpolation_method,
                                    prepare_dwpose, prepare_extra_controlnet,
                                    prepare_ip_adapter,
                                    prepare_ip_adapter_sdxl, prepare_lcm_lora,
                                    prepare_lllite, prepare_motion_module,
                                    save_frames, save_imgs, save_video)

# SD1.5 ControlNet checkpoints, keyed by controlnet type name.
# create_controlnet_model() interprets each value as one of:
#   [repo_id]                -> ControlNetModel.from_pretrained(repo_id)
#   [repo_id, subfolder]     -> ControlNetModel.from_pretrained(repo_id, subfolder=subfolder)
#   [None, local_file_path]  -> load_controlnet_from_file(local_file_path)
controlnet_address_table={
    "controlnet_tile" : ['lllyasviel/control_v11f1e_sd15_tile'],
    "controlnet_lineart_anime" : ['lllyasviel/control_v11p_sd15s2_lineart_anime'],
    "controlnet_ip2p" : ['lllyasviel/control_v11e_sd15_ip2p'],
    "controlnet_openpose" : ['lllyasviel/control_v11p_sd15_openpose'],
    "controlnet_softedge" : ['lllyasviel/control_v11p_sd15_softedge'],
    "controlnet_shuffle" : ['lllyasviel/control_v11e_sd15_shuffle'],
    "controlnet_depth" : ['lllyasviel/control_v11f1p_sd15_depth'],
    "controlnet_canny" : ['lllyasviel/control_v11p_sd15_canny'],
    "controlnet_inpaint" : ['lllyasviel/control_v11p_sd15_inpaint'],
    "controlnet_lineart" : ['lllyasviel/control_v11p_sd15_lineart'],
    "controlnet_mlsd" : ['lllyasviel/control_v11p_sd15_mlsd'],
    "controlnet_normalbae" : ['lllyasviel/control_v11p_sd15_normalbae'],
    "controlnet_scribble" : ['lllyasviel/control_v11p_sd15_scribble'],
    "controlnet_seg" : ['lllyasviel/control_v11p_sd15_seg'],
    "qr_code_monster_v1" : ['monster-labs/control_v1p_sd15_qrcode_monster'],
    "qr_code_monster_v2" : ['monster-labs/control_v1p_sd15_qrcode_monster', 'v2'],
    "controlnet_mediapipe_face" : ['CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"],
    # Local checkpoint. load_controlnet_from_file() passes this path straight to Path(),
    # so it is resolved relative to the current working directory, not data_dir.
    "animatediff_controlnet" : [None, "data/models/controlnet/animatediff_controlnet/controlnet_checkpoint.ckpt"]
}

# SDXL ControlNet checkpoints (Hugging Face repo ids). Values use the same
# [repo_id] / [repo_id, subfolder] layout as controlnet_address_table.
# Edit this table to use a different controlnet checkpoint; the commented-out
# entries are alternatives that are currently not registered.
controlnet_address_table_sdxl={
#    "controlnet_openpose" : ['thibaud/controlnet-openpose-sdxl-1.0'],
#    "controlnet_softedge" : ['SargeZT/controlnet-sd-xl-1.0-softedge-dexined'],
#    "controlnet_depth" : ['diffusers/controlnet-depth-sdxl-1.0-small'],
#    "controlnet_canny" : ['diffusers/controlnet-canny-sdxl-1.0-small'],
#    "controlnet_seg" : ['SargeZT/sdxl-controlnet-seg'],
    "qr_code_monster_v1" : ['monster-labs/control_v1p_sdxl_qrcode_monster'],
}

# SDXL ControlNet-LLLite checkpoints. create_controlnet_model() joins each path onto
# data_dir and loads it with load_controlnet_lllite(). For SDXL, a type that is also
# listed in controlnet_address_table_sdxl is loaded from that table instead.
# Edit this table to use a different lllite checkpoint.
lllite_address_table_sdxl={
    "controlnet_tile" : ['models/lllite/bdsqlsz_controlllite_xl_tile_anime_β.safetensors'],
    "controlnet_lineart_anime" : ['models/lllite/bdsqlsz_controlllite_xl_lineart_anime_denoise.safetensors'],
#    "controlnet_ip2p" : ('lllyasviel/control_v11e_sd15_ip2p'),
    "controlnet_openpose" : ['models/lllite/bdsqlsz_controlllite_xl_dw_openpose.safetensors'],
#    "controlnet_openpose" : ['models/lllite/controllllite_v01032064e_sdxl_pose_anime.safetensors'],
    "controlnet_softedge" : ['models/lllite/bdsqlsz_controlllite_xl_softedge.safetensors'],
    "controlnet_shuffle" : ['models/lllite/bdsqlsz_controlllite_xl_t2i-adapter_color_shuffle.safetensors'],
    "controlnet_depth" : ['models/lllite/bdsqlsz_controlllite_xl_depth.safetensors'],
    "controlnet_canny" : ['models/lllite/bdsqlsz_controlllite_xl_canny.safetensors'],
#    "controlnet_canny" : ['models/lllite/controllllite_v01032064e_sdxl_canny.safetensors'],
#    "controlnet_inpaint" : ('lllyasviel/control_v11p_sd15_inpaint'),
#    "controlnet_lineart" : ('lllyasviel/control_v11p_sd15_lineart'),
    "controlnet_mlsd" : ['models/lllite/bdsqlsz_controlllite_xl_mlsd_V2.safetensors'],
    "controlnet_normalbae" : ['models/lllite/bdsqlsz_controlllite_xl_normal.safetensors'],
    "controlnet_scribble" : ['models/lllite/bdsqlsz_controlllite_xl_sketch.safetensors'],
    "controlnet_seg" : ['models/lllite/bdsqlsz_controlllite_xl_segment_animeface_V2.safetensors'],
#    "qr_code_monster_v1" : ['monster-labs/control_v1p_sdxl_qrcode_monster'],
#    "qr_code_monster_v2" : ('monster-labs/control_v1p_sd15_qrcode_monster', 'v2'),
#    "controlnet_mediapipe_face" : ('CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"),
}





# Optional-dependency probe: the "dwpose" openpose preprocessor is only selected by
# default when onnxruntime can be imported (see default_preprocessor_table).
try:
    import onnxruntime  # noqa: F401  (imported only to test availability)
    onnxruntime_installed = True
except Exception:
    # Treat any import failure (package missing, broken native libraries, ...) as
    # "not installed", but let KeyboardInterrupt/SystemExit propagate. A bare
    # ``except:`` would swallow those too.
    onnxruntime_installed = False




logger = logging.getLogger(__name__)

# Root directory against which model, VAE, LoRA, motion-module, lllite and
# controlnet input-image paths are resolved in this module.
data_dir = get_dir("data")
# Default Diffusers base model used when no base_model is passed to create_pipeline().
default_base_path = data_dir.joinpath("models/huggingface/stable-diffusion-v1-5")

# Matches every character that is not a word character, '-', ',' or a space.
# NOTE(review): not referenced in this part of the file; presumably used to sanitise prompts.
re_clean_prompt = re.compile(r"[^\w\-, ]")

# Cache of preprocessor instances keyed by controlnet type name
# (filled by get_preprocessor, emptied by clear_controlnet_preprocessor).
controlnet_preprocessor = {}

def load_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    """Build a LoRA network from a ``.safetensors`` file and attach it with ``apply_to``.

    Unlike ``load_safetensors_lora2``, the weights are not merged (``merge_to`` is not
    called); the created ``LoRANetwork`` is returned to the caller.

    Args:
        text_encoder: text encoder the LoRA targets.
        unet: UNet the LoRA targets.
        lora_path: path to the ``.safetensors`` LoRA file.
        alpha: LoRA multiplier, used both at construction and in ``apply_to``.
        is_animatediff: forwarded to ``create_network_from_weights``.

    Returns:
        The constructed LoRANetwork.
    """
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    lora_weights = load_file(lora_path)

    print("create LoRA network")
    network: LoRANetwork = create_network_from_weights(
        text_encoder,
        unet,
        lora_weights,
        multiplier=alpha,
        is_animatediff=is_animatediff,
    )

    print("load LoRA network weights")
    network.load_state_dict(lora_weights, False)  # strict=False
    network.apply_to(alpha)
    return network

def load_safetensors_lora2(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    """Build a LoRA network from a ``.safetensors`` file and merge it with ``merge_to``.

    Same construction as ``load_safetensors_lora``, but it finishes with
    ``merge_to(alpha)`` instead of ``apply_to(alpha)`` and returns ``None``.

    Args:
        text_encoder: text encoder the LoRA targets.
        unet: UNet the LoRA targets.
        lora_path: path to the ``.safetensors`` LoRA file.
        alpha: LoRA multiplier, used both at construction and in ``merge_to``.
        is_animatediff: forwarded to ``create_network_from_weights``.
    """
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    lora_weights = load_file(lora_path)

    print("create LoRA network")
    network: LoRANetwork = create_network_from_weights(
        text_encoder,
        unet,
        lora_weights,
        multiplier=alpha,
        is_animatediff=is_animatediff,
    )

    print("load LoRA network weights")
    network.load_state_dict(lora_weights, False)  # strict=False
    network.merge_to(alpha)


def load_tensors(path:Path,framework="pt",device="cpu"):
    """Load a ``{name: tensor}`` dict from a weights file.

    ``.safetensors`` files (suffix matched case-insensitively) are read key by key
    with ``safetensors.safe_open``. Any other file goes through ``torch.load``, and a
    top-level ``"state_dict"`` entry is unwrapped if present.

    Args:
        path: weights file. A plain ``str`` is accepted and converted to ``Path``.
        framework: framework passed to ``safe_open`` (safetensors files only).
        device: device the tensors are loaded onto (``map_location`` for torch files).

    Returns:
        Dict mapping tensor names to tensors.
    """
    path = Path(path)

    if path.suffix.lower() == ".safetensors":
        from safetensors import safe_open

        tensors = {}
        with safe_open(path, framework=framework, device=device) as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)  # loads the full tensor given a key
        return tensors

    tensors = torch.load(path, map_location=device)
    # Some checkpoints nest the actual weights under "state_dict".
    if "state_dict" in tensors:
        tensors = tensors["state_dict"]
    return tensors

def load_motion_lora(unet, lora_path:Path, alpha=1.0):
    """Merge a motion-module LoRA into ``unet``'s weights in place.

    For every ``*.down.*`` tensor in the LoRA file, the matching ``*.up.*`` tensor is
    looked up. ``alpha * (up @ down)`` is then added to the ``weight`` of the
    corresponding ``unet`` layer. A layer path that cannot be resolved on ``unet`` is
    logged and skipped.

    Args:
        unet: module whose layer weights are patched in place.
        lora_path: LoRA weights file (any format ``load_tensors`` accepts).
        alpha: scale applied to the LoRA delta.

    Raises:
        KeyError: if a down tensor has no matching up tensor in the file.
    """
    state_dict = load_tensors(lora_path)

    # directly update weight in diffusers model
    for key in state_dict:
        # Only process LoRA down keys. The up key is derived from them below, so any
        # other key (up keys, or keys that are neither) must be skipped.
        if ".down." not in key:
            continue

        up_key = key.replace(".down.", ".up.")
        # e.g. "...processor.to_q_lora.down.weight" -> "...to_q.weight"
        model_key = (
            key.replace("processor.", "")
            .replace("_lora", "")
            .replace("down.", "")
            .replace("up.", "")
        )
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]  # drop the trailing "weight"

        curr_layer = unet
        try:
            for name in layer_infos:
                curr_layer = curr_layer.__getattr__(name)
        except AttributeError:
            logger.info(f"{model_key} not found")
            continue

        weight = curr_layer.weight.data
        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        # Compute the low-rank product in fp32, then cast to the target weight's
        # device/dtype. This avoids half-precision matmul on CPU and precision loss.
        delta = torch.mm(weight_up.float(), weight_down.float())
        weight += alpha * delta.to(device=weight.device, dtype=weight.dtype)


class SegPreProcessor:
    """ControlNet ``seg`` preprocessor backed by UperNet (``openmmlab/upernet-convnext-small``).

    Predicts a per-pixel class map and paints each class with its ``ade_palette()`` colour.
    """

    def __init__(self):
        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        self.processor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        """Return a colour-coded segmentation map of ``input_image`` as a PIL image.

        Args:
            input_image: PIL image or uint8 array (anything ``np.array`` can convert).
            detect_resolution: resolution given to ``resize_image`` before inference.
            image_resolution: resolution given to ``resize_image`` for the output map.
            output_type: accepted for interface compatibility; the result is always PIL.
            **kwargs: ignored.
        """
        input_array = HWC3(np.array(input_image, dtype=np.uint8))
        # The predicted map is resized back to the ORIGINAL (height, width). Reading it
        # from the array (instead of ``input_image.size[::-1]``) also works for ndarray
        # inputs, whose ``.size`` is an element count rather than (width, height).
        original_hw = tuple(input_array.shape[:2])
        input_array = aux_resize_image(input_array, detect_resolution)

        pixel_values = self.image_processor(input_array, return_tensors="pt").pixel_values

        with torch.no_grad():
            outputs = self.processor(pixel_values.to(self.processor.device))

        # Post-processing and palette painting happen on the CPU. Only move real tensors:
        # hidden_states/attentions may be tuples (or None) rather than a single tensor.
        for field in ("loss", "logits", "hidden_states", "attentions"):
            value = getattr(outputs, field, None)
            if isinstance(value, torch.Tensor):
                setattr(outputs, field, value.to("cpu"))

        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[original_hw])[0]
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3

        for label, color in enumerate(ade_palette()):
            color_seg[seg == label, :] = color

        color_seg = aux_resize_image(color_seg, image_resolution)
        return Image.fromarray(color_seg)

class NullPreProcessor:
    """Identity preprocessor: hands the control image back untouched.

    Used for the "none" preprocessor type; extra keyword arguments are ignored.
    """

    def __call__(self, input_image, **kwargs):
        unchanged = input_image
        return unchanged

class BlurPreProcessor:
    """Gaussian-blur preprocessor ("blur" type)."""

    def __call__(self, input_image, sigma=5.0, **kwargs):
        """Return ``input_image`` blurred with a Gaussian of std-dev ``sigma`` as a PIL image.

        Extra keyword arguments are ignored.
        """
        import cv2

        rgb = HWC3(np.array(input_image, dtype=np.uint8))
        # A (0, 0) kernel size tells OpenCV to derive the kernel size from sigma.
        blurred = cv2.GaussianBlur(rgb, (0, 0), sigma)
        return Image.fromarray(blurred)

class TileResamplePreProcessor:
    """"tile_resample" preprocessor: rescales the image by ``1 / down_sampling_rate``."""

    def resize(self, input_image, resolution):
        """Resize an (H, W, C) array so its shorter side becomes ``resolution``.

        Lanczos interpolation is used when enlarging and area interpolation otherwise.
        Target sizes are truncated to int.
        """
        import cv2

        height, width, _ = input_image.shape
        scale = float(resolution) / min(float(height), float(width))
        target_size = (int(float(width) * scale), int(float(height) * scale))  # cv2 wants (w, h)
        if scale > 1:
            interpolation = cv2.INTER_LANCZOS4
        else:
            interpolation = cv2.INTER_AREA
        return cv2.resize(input_image, target_size, interpolation=interpolation)

    def __call__(self, input_image, down_sampling_rate = 1.0, **kwargs):
        """Return ``input_image`` scaled by ``1 / down_sampling_rate`` as a PIL image."""
        rgb = HWC3(np.array(input_image, dtype=np.uint8))
        height, width, _ = rgb.shape
        shorter_side = min(height, width) / down_sampling_rate
        return Image.fromarray(self.resize(rgb, shorter_side))



def is_valid_controlnet_type(type_str, is_sdxl):
    """Return True if ``type_str`` has a registered checkpoint for the model family.

    SD1.5 consults ``controlnet_address_table``. SDXL accepts entries from either
    ``controlnet_address_table_sdxl`` or ``lllite_address_table_sdxl``.
    """
    if is_sdxl:
        return type_str in controlnet_address_table_sdxl or type_str in lllite_address_table_sdxl
    return type_str in controlnet_address_table

def load_controlnet_from_file(file_path, torch_dtype):
    """Build a ControlNetModel from a single local weights file.

    Args:
        file_path: ``.pth``/``.pt``/``.ckpt`` file (read with
            ``torch.load(..., weights_only=True)``) or ``.safetensors`` file. Files whose
            parent directory is named ``animatediff_controlnet`` get a ControlNetModel
            with ``cross_attention_dim=768``.
        torch_dtype: dtype the model is cast to before returning.

    Returns:
        The ControlNetModel, cast to ``torch_dtype``.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing file.
        RuntimeError: if the extension is not a supported weights format.
    """
    from safetensors.torch import load_file

    prepare_extra_controlnet()

    file_path = Path(file_path)

    if not (file_path.exists() and file_path.is_file()):
        raise FileNotFoundError(f"no controlnet weights found in {file_path}")

    suffix = file_path.suffix.lower()
    if suffix in (".pth", ".pt", ".ckpt"):
        controlnet_state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
    elif suffix == ".safetensors":
        controlnet_state_dict = load_file(file_path, device="cpu")
    else:
        raise RuntimeError(
            f"unknown file format for controlnet weights: {file_path.suffix}"
        )

    # Some .ckpt files nest the weights under "state_dict". Safetensors files (and plain
    # torch state dicts) are already flat, so indexing "state_dict" unconditionally
    # raised KeyError for them.
    if "state_dict" in controlnet_state_dict:
        controlnet_state_dict = controlnet_state_dict["state_dict"]

    if file_path.parent.name == "animatediff_controlnet":
        model = ControlNetModel(cross_attention_dim=768)
    else:
        model = ControlNetModel()

    missing, _ = model.load_state_dict(controlnet_state_dict, strict=False)
    if len(missing) > 0:
        logger.info(f"ControlNetModel has missing keys: {missing}")

    return model.to(dtype=torch_dtype)

def create_controlnet_model(pipe, type_str, is_sdxl):
    """Instantiate the fp16 ControlNet (or ControlNet-LLLite) registered for ``type_str``.

    SD1.5 types are looked up in ``controlnet_address_table``. SDXL types are looked up
    in ``controlnet_address_table_sdxl`` first, then in ``lllite_address_table_sdxl``
    (paths relative to ``data_dir``, loaded with ``load_controlnet_lllite`` on ``pipe``).

    Raises:
        ValueError: if ``type_str`` is not registered for the requested model family.
    """
    half = torch.float16

    def _from_hub(addr):
        # addr is [repo_id] or [repo_id, subfolder]
        if len(addr) == 1:
            return ControlNetModel.from_pretrained(addr[0], torch_dtype=half)
        return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=half)

    if is_sdxl:
        if type_str in controlnet_address_table_sdxl:
            return _from_hub(controlnet_address_table_sdxl[type_str])
        if type_str in lllite_address_table_sdxl:
            lllite_path = data_dir.joinpath(lllite_address_table_sdxl[type_str][0])
            return load_controlnet_lllite(lllite_path, pipe, torch_dtype=half)
        raise ValueError(f"unknown controlnet type {type_str}")

    if type_str not in controlnet_address_table:
        raise ValueError(f"unknown controlnet type {type_str}")

    addr = controlnet_address_table[type_str]
    if addr[0] is None:
        # [None, local_file_path]: load a checkpoint from disk instead of the hub
        return load_controlnet_from_file(addr[1], torch_dtype=half)
    return _from_hub(addr)



# Preprocessor name used for each controlnet type when the caller gives no explicit
# preprocessor "type". Types missing here fall back to "none" in create_default_preprocessor().
default_preprocessor_table = {
    "controlnet_lineart_anime": "lineart_anime",
    # "dwpose" needs onnxruntime; without it, fall back to the "openpose_full" preprocessor.
    "controlnet_openpose": "dwpose" if onnxruntime_installed else "openpose_full",
    "controlnet_softedge": "softedge_hedsafe",
    "controlnet_shuffle": "shuffle",
    "controlnet_depth": "depth_midas",
    "controlnet_canny": "canny",
    "controlnet_lineart": "lineart_realistic",
    "controlnet_mlsd": "mlsd",
    "controlnet_normalbae": "normal_bae",
    "controlnet_scribble": "scribble_pidsafe",
    "controlnet_seg": "upernet_seg",
    "controlnet_mediapipe_face": "mediapipe_face",
    "qr_code_monster_v1": "depth_midas",
    "qr_code_monster_v2": "depth_midas",
}

def create_preprocessor_from_name(pre_type):
    """Instantiate the preprocessor named ``pre_type``.

    The local names "dwpose", "upernet_seg", "blur", "tile_resample" and "none" are
    checked first. Any other name must be a key of controlnet_aux's ``MODELS``.

    Raises:
        ValueError: if ``pre_type`` is not a known preprocessor name.
    """
    def _make_dwpose():
        # DWPose needs its assets prepared before the detector is constructed.
        prepare_dwpose()
        return DWposeDetector()

    local_factories = {
        "dwpose": _make_dwpose,
        "upernet_seg": SegPreProcessor,
        "blur": BlurPreProcessor,
        "tile_resample": TileResamplePreProcessor,
        "none": NullPreProcessor,
    }

    factory = local_factories.get(pre_type)
    if factory is not None:
        return factory()
    if pre_type in MODELS:
        return ControlnetPreProcessor(pre_type)
    raise ValueError(f"unknown controlnet preprocessor type {pre_type}")


def create_default_preprocessor(type_str):
    """Instantiate the default preprocessor for controlnet ``type_str``.

    The name is looked up in ``default_preprocessor_table``; unlisted types get "none".
    """
    preprocessor_name = default_preprocessor_table.get(type_str, "none")
    return create_preprocessor_from_name(preprocessor_name)


def get_preprocessor(type_str, device_str, preprocessor_map):
    """Return the cached preprocessor for controlnet ``type_str``, creating it on first use.

    A new preprocessor is built from ``preprocessor_map["type"]`` when the map provides
    one. Otherwise ``create_default_preprocessor(type_str)`` is used. When ``device_str``
    is set, the new preprocessor's ``processor`` attribute (or the preprocessor itself,
    if it has no ``processor``) is moved there, provided it has a ``.to`` method.

    Args:
        type_str: controlnet type name; also the cache key.
        device_str: device for the preprocessor model; falsy leaves it where it is.
        preprocessor_map: optional dict that may contain a ``"type"`` key.
    """
    preprocessor = controlnet_preprocessor.get(type_str)
    # A None entry counts as "not cached", so a cleared slot is rebuilt rather than
    # returned (and later called) as None.
    if preprocessor is not None:
        return preprocessor

    pre_type = preprocessor_map.get("type") if preprocessor_map else None
    if pre_type is not None:
        preprocessor = create_preprocessor_from_name(pre_type)
    else:
        # Also covers maps that only carry "param": use the type's default preprocessor.
        preprocessor = create_default_preprocessor(type_str)
    controlnet_preprocessor[type_str] = preprocessor

    if device_str:
        if hasattr(preprocessor, "processor"):
            if hasattr(preprocessor.processor, "to"):
                preprocessor.processor.to(device_str)
        elif hasattr(preprocessor, "to"):
            preprocessor.to(device_str)

    return preprocessor

def clear_controlnet_preprocessor(type_str = None):
    """Drop cached preprocessors and release cached CUDA memory.

    Args:
        type_str: controlnet type whose preprocessor should be dropped. ``None`` (the
            default) drops every cached preprocessor.
    """
    if type_str is None:
        # Empty the shared cache in place so every holder of the dict sees the reset.
        controlnet_preprocessor.clear()
    else:
        # Remove the entry instead of storing a None placeholder. A later lookup then
        # sees the type as uncached and rebuilds it, rather than returning None.
        controlnet_preprocessor.pop(type_str, None)
    torch.cuda.empty_cache()


def get_preprocessed_img(type_str, img, use_preprocessor, device_str, preprocessor_map):
    """Run the controlnet preprocessor for ``type_str`` on ``img``, if enabled.

    When ``use_preprocessor`` is falsy, ``img`` is returned unchanged. Otherwise the
    cached preprocessor from ``get_preprocessor`` is called with ``img`` and the keyword
    arguments in ``preprocessor_map["param"]`` (none if the map or key is absent).
    """
    if not use_preprocessor:
        return img

    call_kwargs = preprocessor_map.get("param", {}) if preprocessor_map else {}
    preprocessor = get_preprocessor(type_str, device_str, preprocessor_map)
    return preprocessor(img, **call_kwargs)


def create_pipeline_sdxl(
    base_model: Union[str, PathLike] = default_base_path,
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    video_length: int = 16,
    motion_module_path = ...,
):
    """Create the SDXL AnimationPipeline.

    Loads the SDXL tokenizers, text encoders, VAE and 3D UNet from ``base_model``.
    Weights are optionally overridden from ``model_config.path`` (single checkpoint file
    or Diffusers directory) and ``model_config.vae_path``. Motion LoRAs, the optional LCM
    LoRA and ``model_config.lora_map`` are then applied, and textual-inversion
    embeddings are loaded.

    Args:
        base_model: Diffusers SDXL model providing tokenizers, encoders, VAE and UNet.
        model_config: scheduler, checkpoint, VAE and LoRA settings.
        infer_config: supplies ``unet_additional_kwargs`` and ``noise_scheduler_kwargs``.
        use_xformers: enable xformers memory-efficient attention on the UNet.
        video_length: forwarded to ``load_lora_map``.
        motion_module_path: motion module weights passed to the 3D UNet loader.

    Returns:
        The AnimationPipeline with UNet and both text encoders in fp16; the text
        encoders are left on the CPU.

    Raises:
        FileNotFoundError: if ``model_config.path`` is neither a file nor a directory.
        ValueError: on unexpected UNet keys, missing text-encoder/VAE keys, or a missing
            motion LoRA file.
    """
    from animatediff.pipelines.sdxl_animation import AnimationPipeline
    from animatediff.sdxl_models.unet import UNet3DConditionModel

    logger.info("Loading tokenizer...")
    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    logger.info("Loading text encoder...")
    text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder", torch_dtype=torch.float16)
    logger.info("Loading VAE...")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    logger.info("Loading tokenizer two...")
    tokenizer_two = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer_2")
    logger.info("Loading text encoder two...")
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_model, subfolder="text_encoder_2", torch_dtype=torch.float16)

    logger.info("Loading UNet...")
    unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path=base_model,
        motion_module_path=motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )

    # set up scheduler
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    if model_config.gradual_latent_hires_fix_map:
        if "enable" in model_config.gradual_latent_hires_fix_map:
            if model_config.gradual_latent_hires_fix_map["enable"]:
                if model_config.scheduler not in (DiffusionScheduler.euler_a, DiffusionScheduler.lcm):
                    # Logger.warn is a deprecated alias of Logger.warning.
                    logger.warning("gradual_latent_hires_fix enable")
                    logger.warning(f"{model_config.scheduler=}")
                    logger.warning("If you are forced to exit with an error, change to euler_a or lcm")

    # Load the checkpoint weights into the pipeline
    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():
            logger.debug("Loading from single checkpoint file")
            unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict = get_checkpoint_weights_sdxl(model_path)
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            temp_pipeline = StableDiffusionXLPipeline.from_pretrained(model_path)
            unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict = (
                temp_pipeline.unet.state_dict(),
                temp_pipeline.text_encoder.state_dict(),
                temp_pipeline.text_encoder_2.state_dict(),
                temp_pipeline.vae.state_dict(),
            )
            del temp_pipeline
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")

        # Load into the unet, TE, and VAE
        logger.info("Merging weights into UNet...")
        _, unet_unex = unet.load_state_dict(unet_state_dict, strict=False)
        if len(unet_unex) > 0:
            raise ValueError(f"UNet has unexpected keys: {unet_unex}")
        tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False)
        if len(tenc_missing) > 0:
            raise ValueError(f"TextEncoder has missing keys: {tenc_missing}")
        tenc2_missing, _ = text_encoder_two.load_state_dict(tenc2_state_dict, strict=False)
        if len(tenc2_missing) > 0:
            raise ValueError(f"TextEncoder2 has missing keys: {tenc2_missing}")
        vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False)
        if len(vae_missing) > 0:
            raise ValueError(f"VAE has missing keys: {vae_missing}")

        # Free the source state dicts now that they have been copied into the modules.
        # They only exist on this branch: deleting them unconditionally (after the
        # if/else) raised UnboundLocalError when model_config.path was None.
        del unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict
    else:
        logger.info("Using base model weights (no checkpoint/LoRA)")

    if model_config.vae_path:
        vae_path = data_dir.joinpath(model_config.vae_path)
        logger.info(f"Loading vae from {vae_path}")

        if vae_path.is_dir():
            vae = AutoencoderKL.from_pretrained(vae_path)
        else:
            tensors = load_tensors(vae_path)
            tensors = convert_ldm_vae_checkpoint(tensors, vae.config)
            vae.load_state_dict(tensors)

    unet.to(torch.float16)
    text_encoder.to(torch.float16)
    text_encoder_two.to(torch.float16)

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        unet.enable_xformers_memory_efficient_attention()

    # motion lora
    for l in model_config.motion_lora_map:
        lora_path = data_dir.joinpath(l)
        logger.info(f"loading motion lora {lora_path=}")
        if lora_path.is_file():
            logger.info(f"Loading motion lora {lora_path}")
            logger.info(f"alpha = {model_config.motion_lora_map[l]}")
            load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l])
        else:
            raise ValueError(f"{lora_path=} not found")

    logger.info("Creating AnimationPipeline...")
    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_two,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_two,
        unet=unet,
        scheduler=scheduler,
        controlnet_map=None,
    )

    # The pipeline now owns these components; drop the local references.
    del vae
    del text_encoder
    del text_encoder_two
    del tokenizer
    del tokenizer_two
    del unet

    torch.cuda.empty_cache()

    pipeline.lcm = None
    if model_config.lcm_map:
        if model_config.lcm_map["enable"]:
            prepare_lcm_lora()
            load_lcm_lora(pipeline, model_config.lcm_map, is_sdxl=True)

    load_lora_map(pipeline, model_config.lora_map, video_length, is_sdxl=True)

    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()
    pipeline.text_encoder_2 = pipeline.text_encoder_2.half()

    # Load TI embeddings (text encoders are moved to CUDA only for this step)
    pipeline.text_encoder = pipeline.text_encoder.to("cuda")
    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")

    load_text_embeddings(pipeline, is_sdxl=True)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")
    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cpu")

    return pipeline


def create_pipeline(
    base_model: Union[str, PathLike] = default_base_path,
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    video_length: int = 16,
    is_sdxl:bool = False,
) -> DiffusionPipeline:
    """Create an AnimationPipeline from a pretrained model.

    Uses the base_model argument to load or download the pretrained reference pipeline model.

    The motion module ``model_config.motion_module`` (relative to ``data_dir``) is
    resolved first. ``prepare_motion_module()`` is tried, then a ``.safetensors``
    sibling, then ``ensure_motion_modules()``. SDXL requests are then delegated to
    ``create_pipeline_sdxl``.

    Args:
        base_model: Diffusers SD1.5 model providing tokenizer, text encoder, VAE, UNet
            config and feature extractor.
        model_config: motion module, scheduler, checkpoint, VAE and LoRA settings.
        infer_config: supplies ``unet_additional_kwargs`` and ``noise_scheduler_kwargs``.
        use_xformers: enable xformers memory-efficient attention on the UNet.
        video_length: forwarded to ``load_lora_map``.
        is_sdxl: build the SDXL pipeline instead.

    Returns:
        The AnimationPipeline with UNet and text encoder in fp16; the text encoder is
        left on the CPU.

    Raises:
        FileNotFoundError: if the motion module or ``model_config.path`` cannot be found.
        ValueError: on unexpected UNet keys, missing text-encoder/VAE keys, or a missing
            motion LoRA file.
    """
    # make sure motion_module is a Path and exists
    logger.info("Checking motion module...")
    motion_module = data_dir.joinpath(model_config.motion_module)
    if not (motion_module.exists() and motion_module.is_file()):
        prepare_motion_module()
        if not (motion_module.exists() and motion_module.is_file()):
            # check for safetensors version
            motion_module = motion_module.with_suffix(".safetensors")
            if not (motion_module.exists() and motion_module.is_file()):
                # download from HuggingFace Hub if not found
                ensure_motion_modules()
            if not (motion_module.exists() and motion_module.is_file()):
                # this should never happen, but just in case...
                raise FileNotFoundError(f"Motion module {motion_module} does not exist or is not a file!")

    if is_sdxl:
        return create_pipeline_sdxl(
            base_model=base_model,
            model_config=model_config,
            infer_config=infer_config,
            use_xformers=use_xformers,
            video_length=video_length,
            motion_module_path=motion_module,
        )

    logger.info("Loading tokenizer...")
    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    logger.info("Loading text encoder...")
    text_encoder: CLIPSkipTextModel = CLIPSkipTextModel.from_pretrained(base_model, subfolder="text_encoder")
    logger.info("Loading VAE...")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    logger.info("Loading UNet...")
    unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path=base_model,
        motion_module_path=motion_module,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )
    feature_extractor = CLIPImageProcessor.from_pretrained(base_model, subfolder="feature_extractor")

    # set up scheduler
    if model_config.gradual_latent_hires_fix_map:
        if "enable" in model_config.gradual_latent_hires_fix_map:
            if model_config.gradual_latent_hires_fix_map["enable"]:
                if model_config.scheduler not in (DiffusionScheduler.euler_a, DiffusionScheduler.lcm):
                    # Logger.warn is a deprecated alias of Logger.warning.
                    logger.warning("gradual_latent_hires_fix enable")
                    logger.warning(f"{model_config.scheduler=}")
                    logger.warning("If you are forced to exit with an error, change to euler_a or lcm")

    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    # Load the checkpoint weights into the pipeline
    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():
            logger.debug("Loading from single checkpoint file")
            unet_state_dict, tenc_state_dict, vae_state_dict = get_checkpoint_weights(model_path)
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            temp_pipeline = StableDiffusionPipeline.from_pretrained(model_path)
            unet_state_dict, tenc_state_dict, vae_state_dict = (
                temp_pipeline.unet.state_dict(),
                temp_pipeline.text_encoder.state_dict(),
                temp_pipeline.vae.state_dict(),
            )
            del temp_pipeline
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")

        # Load into the unet, TE, and VAE
        logger.info("Merging weights into UNet...")
        _, unet_unex = unet.load_state_dict(unet_state_dict, strict=False)
        if len(unet_unex) > 0:
            raise ValueError(f"UNet has unexpected keys: {unet_unex}")
        tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False)
        if len(tenc_missing) > 0:
            raise ValueError(f"TextEncoder has missing keys: {tenc_missing}")
        vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False)
        if len(vae_missing) > 0:
            raise ValueError(f"VAE has missing keys: {vae_missing}")

        # The weights now live in the modules; release the source copies early
        # (as the SDXL path does) instead of holding them until the function returns.
        del unet_state_dict, tenc_state_dict, vae_state_dict
    else:
        logger.info("Using base model weights (no checkpoint/LoRA)")

    if model_config.vae_path:
        vae_path = data_dir.joinpath(model_config.vae_path)
        logger.info(f"Loading vae from {vae_path}")

        if vae_path.is_dir():
            vae = AutoencoderKL.from_pretrained(vae_path)
        else:
            tensors = load_tensors(vae_path)
            tensors = convert_ldm_vae_checkpoint(tensors, vae.config)
            vae.load_state_dict(tensors)

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        unet.enable_xformers_memory_efficient_attention()

    # Regular LoRAs are applied later through load_lora_map(); the dead
    # "if False:" load_safetensors_lora() loop that used to sit here was removed.

    # motion lora
    for l in model_config.motion_lora_map:
        lora_path = data_dir.joinpath(l)
        logger.info(f"loading motion lora {lora_path=}")
        if lora_path.is_file():
            logger.info(f"Loading motion lora {lora_path}")
            logger.info(f"alpha = {model_config.motion_lora_map[l]}")
            load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l])
        else:
            raise ValueError(f"{lora_path=} not found")

    logger.info("Creating AnimationPipeline...")
    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        controlnet_map=None,
    )

    pipeline.lcm = None
    if model_config.lcm_map:
        if model_config.lcm_map["enable"]:
            prepare_lcm_lora()
            load_lcm_lora(pipeline, model_config.lcm_map, is_sdxl=False)

    load_lora_map(pipeline, model_config.lora_map, video_length)

    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()

    # Load TI embeddings (the text encoder is moved to CUDA only for this step)
    pipeline.text_encoder = pipeline.text_encoder.to("cuda")

    load_text_embeddings(pipeline)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")

    return pipeline

def load_controlnet_models(pipe: DiffusionPipeline, model_config: ModelConfig = ..., is_sdxl:bool = False):
    """Create the enabled ControlNet models and attach them as ``pipe.controlnet_map``.

    A ControlNet type listed in ``model_config.controlnet_map`` is loaded only
    when all of the following hold:

    * its entry is a plain ``dict`` whose ``enable`` equals ``True``;
    * ``is_valid_controlnet_type`` accepts it for the model family; and
    * ``<input_image_dir>/<type>/`` (relative to ``data_dir``) contains at
      least one ``[0-9]*.png`` file.

    For SDXL, ``prepare_lllite()`` is called first.

    After the call, ``pipe.controlnet_map`` is ``{type: model}``, or ``None``
    when nothing was loaded.
    """
    if is_sdxl:
        prepare_lllite()

    loaded = {}
    settings = model_config.controlnet_map
    if settings:
        image_root = data_dir.joinpath(settings["input_image_dir"])

        for c in settings:
            entry = settings[c]
            # Non-dict entries are global options (e.g. input_image_dir), not ControlNet types.
            if not (type(entry) is dict and entry["enable"] == True):
                continue

            if not is_valid_controlnet_type(c, is_sdxl):
                logger.info(f"invalid controlnet type for {'sdxl' if is_sdxl else 'sd15'} : {c}")
                continue

            # Only load a model when there are conditioning images for it.
            pattern = os.path.join(image_root.joinpath(c), "[0-9]*.png")
            if glob.glob(pattern, recursive=False):
                logger.info(f"loading {c=} model")
                loaded[c] = create_controlnet_model(pipe, c, is_sdxl)

    pipe.controlnet_map = loaded or None

def unload_controlnet_models(pipe: AnimationPipeline):
    """Drop every ControlNet model attached to *pipe* and release cached CUDA memory.

    ``ControlNetLLLite`` instances have ``unapply_to()`` called on them before
    the map is discarded. Presumably this undoes whatever they attached to the
    pipeline; confirm in ``control_net_lllite``.

    Afterwards ``pipe.controlnet_map`` is ``None`` and
    ``torch.cuda.empty_cache()`` has been called.
    """
    if pipe.controlnet_map:
        for controlnet in pipe.controlnet_map.values():
            if isinstance(controlnet, ControlNetLLLite):
                controlnet.unapply_to()

    # Dropping the map releases the last references held by the pipeline.
    pipe.controlnet_map = None
    torch.cuda.empty_cache()


def create_us_pipeline(
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    use_controlnet_ref: bool = False,
    use_controlnet_tile: bool = False,
    use_controlnet_line_anime: bool = False,
    use_controlnet_ip2p: bool = False,
) -> DiffusionPipeline:
    """Create the SD1.5 ControlNet img2img pipeline used for upscaling.

    Args:
        model_config: Provides the following:
            * ``scheduler``;
            * ``path``: a single checkpoint file or a Diffusers model
              directory, relative to ``data_dir``;
            * ``lora_map``: LoRA file to weight.
        infer_config: Provides ``noise_scheduler_kwargs``.
        use_xformers: Enable xformers memory-efficient attention.
        use_controlnet_ref: Build
            ``StableDiffusionControlNetImg2ImgReferencePipeline`` instead of
            ``StableDiffusionControlNetImg2ImgPipeline``.
        use_controlnet_tile, use_controlnet_line_anime, use_controlnet_ip2p:
            Load the corresponding lllyasviel v1.1 ControlNet. With several
            enabled, they are passed as a list in this order.

    A single checkpoint file is converted once into a Diffusers directory at
    ``models/huggingface/<stem>_<file size>``. Later calls reuse that
    directory whenever it is non-empty.

    Returns:
        The pipeline with UNet and text encoder in fp16. The text encoder is
        moved to CUDA to load textual-inversion embeddings and back to CPU.

    Raises:
        ValueError: ``model_config.path`` is ``None``.
        FileNotFoundError: ``model_config.path`` is neither a file nor a directory.
    """
    # set up scheduler
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    # ControlNets to load, in the order the pipeline receives them.
    controlnet_ids = [
        repo_id
        for enabled, repo_id in (
            (use_controlnet_tile, 'lllyasviel/control_v11f1e_sd15_tile'),
            (use_controlnet_line_anime, 'lllyasviel/control_v11p_sd15s2_lineart_anime'),
            (use_controlnet_ip2p, 'lllyasviel/control_v11e_sd15_ip2p'),
        )
        if enabled
    ]
    controlnets = [ControlNetModel.from_pretrained(repo_id) for repo_id in controlnet_ids]
    if not controlnets:
        controlnet = None
    elif len(controlnets) == 1:
        controlnet = controlnets[0]
    else:
        controlnet = controlnets  # a list is handed to the pipeline as-is (multi-ControlNet)

    if model_config.path is None:
        raise ValueError("model_config.path is invalid")

    model_path = data_dir.joinpath(model_config.path)
    logger.info(f"Loading weights from {model_path}")

    if model_path.is_file():
        # Cache directory for the converted checkpoint, keyed by file stem and size.
        save_path = data_dir.joinpath("models/huggingface/" + model_path.stem + "_" + str(model_path.stat().st_size))
        # parents=True: models/huggingface may not exist yet on a fresh setup.
        save_path.mkdir(parents=True, exist_ok=True)
        if not any(save_path.iterdir()):
            # StableDiffusionControlNetImg2ImgPipeline.from_single_file does not exist in version 18.2
            logger.debug("Loading from single checkpoint file")
            tmp_pipeline = StableDiffusionPipeline.from_single_file(
                pretrained_model_link_or_path=str(model_path.absolute())
            )
            tmp_pipeline.save_pretrained(save_path, safe_serialization=True)
            del tmp_pipeline
        pretrained_path, local_files_only = save_path, False
    elif model_path.is_dir():
        logger.debug("Loading from Diffusers model directory")
        pretrained_path, local_files_only = model_path, True
    else:
        raise FileNotFoundError(f"model_path {model_path} is not a file or directory")

    pipeline_cls = (
        StableDiffusionControlNetImg2ImgReferencePipeline
        if use_controlnet_ref
        else StableDiffusionControlNetImg2ImgPipeline
    )
    pipeline: DiffusionPipeline = pipeline_cls.from_pretrained(
        pretrained_path,
        controlnet=controlnet,
        local_files_only=local_files_only,
        load_safety_checker=False,
        safety_checker=None,
    )

    pipeline.scheduler = scheduler

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        pipeline.enable_xformers_memory_efficient_attention()

    # lora: entries whose file does not exist are skipped silently
    for lora_name, alpha in model_config.lora_map.items():
        lora_path = data_dir.joinpath(lora_name)
        if not lora_path.is_file():
            continue
        if isinstance(alpha, dict):
            # Dict-valued entries (presumably the richer per-LoRA config used
            # elsewhere) are applied here with a fixed weight.
            alpha = 0.75

        logger.info(f"Loading lora {lora_path}")
        logger.info(f"alpha = {alpha}")
        load_safetensors_lora2(pipeline.text_encoder, pipeline.unet, lora_path, alpha=alpha, is_animatediff=False)

    # Load TI embeddings
    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()

    pipeline.text_encoder = pipeline.text_encoder.to("cuda")

    load_text_embeddings(pipeline)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")

    return pipeline


def seed_everything(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA devices).

    NumPy only accepts 32-bit seeds, so it receives ``seed % 2**32``.
    """
    import random

    random.seed(seed)
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def controlnet_preprocess(
        controlnet_map: Dict[str, Any] = None,
        width: int = 512,
        height: int = 512,
        duration: int = 16,
        out_dir: PathLike = ...,
        device_str:str=None,
        is_sdxl:bool = False,
        ):
    """Preprocess ControlNet conditioning images for every enabled type.

    Each entry of *controlnet_map* that passes all of the following checks
    reads ``<input_image_dir>/<type>/[0-9]*.png`` (relative to ``data_dir``):

    * the entry is a plain ``dict``;
    * its ``enable`` equals ``True``;
    * ``is_valid_controlnet_type`` accepts it for the model family.

    Files whose numeric stem is below *duration* are resized with
    ``get_resized_image2(path, 512)`` and passed through
    ``get_preprocessed_img``. Unless ``save_detectmap`` is false, the results
    are also saved under ``out_dir/00_detectmap/<type>/``. The
    ``controlnet_ref`` entry is handled separately to build the reference map.

    Note: *width* and *height* are currently unused; images use a fixed 512.

    Returns:
        ``(controlnet_image_map, controlnet_type_map, controlnet_ref_map,
        controlnet_no_shrink)``, or four ``None`` values when *controlnet_map*
        is empty.
    """
    if not controlnet_map:
        return None, None, None, None

    out_dir = Path(out_dir)  # ensure out_dir is a Path

    # frame number -> {controlnet type -> preprocessed image}
    controlnet_image_map = {}
    # controlnet type -> conditioning parameters forwarded to the pipeline
    controlnet_type_map = {}

    c_image_dir = data_dir.joinpath(controlnet_map["input_image_dir"])
    save_detectmap = controlnet_map.get("save_detectmap", True)
    if not controlnet_map.get("preprocess_on_gpu", True):
        device_str = None  # run the preprocessors on CPU

    model_family = 'sdxl' if is_sdxl else 'sd15'

    for c, item in controlnet_map.items():
        if c == "controlnet_ref":
            continue  # reference image is handled after this loop

        processed = False

        if type(item) is dict and item["enable"] == True:
            if not is_valid_controlnet_type(c, is_sdxl):
                logger.info(f"invalid controlnet type for {model_family} : {c}")
            else:
                preprocessor_map = item.get("preprocessor", {})
                img_dir = c_image_dir.joinpath(c)
                cond_imgs = sorted(glob.glob(os.path.join(img_dir, "[0-9]*.png"), recursive=False))

                if cond_imgs:
                    controlnet_type_map[c] = {
                        "controlnet_conditioning_scale" : item["controlnet_conditioning_scale"],
                        "control_guidance_start" : item["control_guidance_start"],
                        "control_guidance_end" : item["control_guidance_end"],
                        "control_scale_list" : item["control_scale_list"],
                        "guess_mode" : item.get("guess_mode", False),
                        "control_region_list" : item.get("control_region_list", []),
                    }
                    use_preprocessor = item.get("use_preprocessor", True)

                    for img_path in tqdm(cond_imgs, desc=f"Preprocessing images ({c})"):
                        frame_no = int(Path(img_path).stem)
                        if frame_no >= duration:
                            continue
                        frame_imgs = controlnet_image_map.setdefault(frame_no, {})
                        resized = get_resized_image2(img_path, 512)
                        frame_imgs[c] = get_preprocessed_img(c, resized, use_preprocessor, device_str, preprocessor_map)
                        processed = True

        if save_detectmap and processed:
            det_dir = out_dir.joinpath(f"{0:02d}_detectmap/{c}")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(controlnet_image_map, desc=f"Saving Preprocessed images ({c})"):
                frame_imgs = controlnet_image_map[frame_no]
                if c in frame_imgs:
                    frame_imgs[c].save(det_dir.joinpath(f"{frame_no:08d}.png"))

        # Called for every non-ref key, including global option keys.
        clear_controlnet_preprocessor(c)

    clear_controlnet_preprocessor()

    controlnet_ref_map = None

    if "controlnet_ref" in controlnet_map and controlnet_map["controlnet_ref"]["enable"] == True:
        ref_cfg = controlnet_map["controlnet_ref"]
        ref_path = data_dir.joinpath(ref_cfg["ref_image"])
        ref_image = get_resized_image2(ref_path, 512)

        if ref_image is not None:
            controlnet_ref_map = {"ref_image": ref_image}
            for key in (
                "style_fidelity",
                "attention_auto_machine_weight",
                "gn_auto_machine_weight",
                "reference_attn",
                "reference_adain",
                "scale_pattern",
            ):
                controlnet_ref_map[key] = ref_cfg[key]

            if save_detectmap:
                det_dir = out_dir.joinpath(f"{0:02d}_detectmap/controlnet_ref")
                det_dir.mkdir(parents=True, exist_ok=True)
                ref_image.save(det_dir.joinpath(f"{ref_path.stem}.png"))

    # Types not in this list presumably get shrunk downstream; the config may override it.
    controlnet_no_shrink = controlnet_map.get(
        "no_shrink_list",
        ["controlnet_tile","animatediff_controlnet","controlnet_canny","controlnet_normalbae","controlnet_depth","controlnet_lineart","controlnet_lineart_anime","controlnet_scribble","controlnet_seg","controlnet_softedge","controlnet_mlsd"],
    )

    return controlnet_image_map, controlnet_type_map, controlnet_ref_map, controlnet_no_shrink


def ip_adapter_preprocess(
        ip_adapter_config_map: Dict[str, Any] = None,
        width: int = 512,
        height: int = 512,
        duration: int = 16,
        out_dir: PathLike = ...,
        is_sdxl: bool = False,
        ):
    """Load IP-Adapter reference images keyed by frame number.

    Images are read from ``<input_image_dir>/[0-9]*<ext>`` (relative to
    ``data_dir``) for every extension in ``IMG_EXTENSIONS``. Each image whose
    numeric stem is below *duration* is resized to 256: square when
    ``resized_to_square`` is set, otherwise via ``get_resized_image2``.

    Each key frame ``k0`` is also copied to
    ``k0 + round((k1 - k0) * prompt_fixed_ratio)``, where ``k1`` is the next
    key frame (or *duration*). The copy is pulled back to ``k1 - 1`` if it
    lands on ``k1``.

    Side effect: ``ip_adapter_config_map["prompt_fixed_ratio"]`` is clamped to
    ``[0, 1]`` in place when at least one image is loaded.

    *width* and *height* are currently unused.

    Returns:
        ``{"images": {frame_no: image}}``, or ``None`` when disabled or when
        no image below *duration* was found.

    Raises:
        KeyError: ``save_input_image`` is missing from an enabled config.
    """
    if not ip_adapter_config_map or not (ip_adapter_config_map["enable"] == True):
        return None

    resized_to_square = ip_adapter_config_map.get("resized_to_square", False)
    image_dir = data_dir.joinpath(ip_adapter_config_map["input_image_dir"])
    imgs = sorted(chain.from_iterable(
        glob.glob(os.path.join(image_dir, f"[0-9]*{ext}")) for ext in IMG_EXTENSIONS
    ))

    images = {}
    if imgs:
        # Presumably fetches/sets up the IP-Adapter model files; only done when needed.
        if is_sdxl:
            prepare_ip_adapter_sdxl()
        else:
            prepare_ip_adapter()

        for img_path in tqdm(imgs, desc=f"Preprocessing images (ip_adapter)"):
            frame_no = int(Path(img_path).stem)
            if frame_no >= duration:
                continue
            if resized_to_square:
                images[frame_no] = get_resized_image(img_path, 256, 256)
            else:
                images[frame_no] = get_resized_image2(img_path, 256)

    if images:
        ratio = max(min(1.0, ip_adapter_config_map["prompt_fixed_ratio"]), 0)
        ip_adapter_config_map["prompt_fixed_ratio"] = ratio

        images = dict(sorted(images.items()))
        key_frames = list(images)
        for k0, k1 in zip(key_frames, key_frames[1:] + [duration]):
            k05 = k0 + round((k1 - k0) * ratio)
            if k05 == k1:
                k05 -= 1
            if k05 != k0:
                images[k05] = images[k0]

    if ip_adapter_config_map["save_input_image"] == True and images:
        # out_dir is only PathLike-annotated; coerce so str paths work too.
        det_dir = Path(out_dir).joinpath(f"{0:02d}_ip_adapter/")
        det_dir.mkdir(parents=True, exist_ok=True)
        for frame_no in tqdm(images, desc=f"Saving Preprocessed images (ip_adapter)"):
            images[frame_no].save(det_dir.joinpath(f"{frame_no:08d}.png"))

    return {"images": images} if images else None

def prompt_preprocess(
        prompt_config_map: Dict[str, Any],
        head_prompt: str,
        tail_prompt: str,
        prompt_fixed_ratio: float,
        video_length: int,
):
    """Build a frame-indexed prompt map with head/tail text and hold frames.

    Keys of *prompt_config_map* are frame numbers, as ``int`` or numeric
    strings. Entries at or beyond *video_length* are dropped. A non-empty
    *head_prompt* or *tail_prompt* is joined to each prompt with ``","``.

    Each key frame ``k0`` is followed by ``k1`` (the next key frame, or
    *video_length* for the last one). Its prompt is also copied to
    ``k0 + round((k1 - k0) * ratio)``; the copy is pulled back to ``k1 - 1``
    if it lands on ``k1`` and skipped if it equals ``k0``. Presumably this
    holds the prompt fixed for that fraction of the interval before any
    downstream interpolation.

    *prompt_fixed_ratio* is clamped to ``[0, 1]``. Out-of-range values would
    otherwise place the copy past ``k1``, overwriting later key frames or
    creating frames outside ``[0, video_length)``.

    Returns:
        ``{frame_no: prompt}``.
    """
    ratio = max(0.0, min(1.0, prompt_fixed_ratio))

    prompt_map = {}
    for key, prompt in prompt_config_map.items():
        frame = int(key)
        if frame >= video_length:
            continue
        if head_prompt:
            prompt = head_prompt + "," + prompt
        if tail_prompt:
            prompt = prompt + "," + tail_prompt
        prompt_map[frame] = prompt

    prompt_map = dict(sorted(prompt_map.items()))
    key_frames = list(prompt_map)
    for k0, k1 in zip(key_frames, key_frames[1:] + [video_length]):
        k05 = k0 + round((k1 - k0) * ratio)
        if k05 == k1:
            k05 -= 1
        if k05 != k0:
            prompt_map[k05] = prompt_map[k0]

    return prompt_map


def region_preprocess(
        model_config: ModelConfig = ...,
        width: int = 512,
        height: int = 512,
        duration: int = 16,
        out_dir: PathLike = ...,
        is_init_img_exist: bool = False,
        is_sdxl:bool = False,
        ):
    """Build the per-region conditions, masks and IP-Adapter settings.

    The background always occupies ``region_list[0]``. Its ``src`` is the
    index of its entry in ``region_condi_list``, or ``-1`` when it reuses the
    init image: ``region_map["background"]["is_init_img"]`` is set and
    *is_init_img_exist* is true.

    Every other enabled region in ``model_config.region_map`` with at least one
    mask frame is added as well. Its ``src`` is either:

    * the index of a new condition built from its ``condition`` block; or
    * ``-1`` when ``is_init_img`` is set. Such regions are dropped with a
      warning when no init image exists.

    Regions without their own IP-Adapter images are filled with the last
    non-empty IP-Adapter map seen.

    Returns:
        A tuple ``(region_condi_list, region_list, ip_adapter_config_map,
        region2index)`` where:

        * ``region_condi_list`` holds ``{"prompt_map", "ip_adapter_map"}`` dicts;
        * ``region_list`` holds ``{"mask_images", "src",
          "crop_generation_rate"}`` dicts;
        * ``ip_adapter_config_map`` is ``None`` unless some IP-Adapter images
          were loaded;
        * ``region2index`` maps region name to ``src``.

    Raises:
        ValueError: no region ended up with a condition of its own.
    """
    region_map = model_config.region_map

    is_bg_init_img = False
    if is_init_img_exist and region_map and "background" in region_map:
        is_bg_init_img = region_map["background"]["is_init_img"]

    region_condi_list = []
    region2index = {}
    condi_index = 0
    # Last non-empty IP-Adapter map; used to fill conditions that have none.
    prev_ip_map = None

    if not is_bg_init_img:
        ip_map = ip_adapter_preprocess(
            model_config.ip_adapter_map,
            width,
            height,
            duration,
            out_dir,
            is_sdxl,
        )
        if ip_map:
            prev_ip_map = ip_map

        region_condi_list.append({
            "prompt_map": prompt_preprocess(
                model_config.prompt_map,
                model_config.head_prompt,
                model_config.tail_prompt,
                model_config.prompt_fixed_ratio,
                duration,
            ),
            "ip_adapter_map": ip_map,
        })

        bg_src = condi_index
        condi_index += 1
    else:
        bg_src = -1

    region_list = [
        {
            "mask_images": None,
            "src": bg_src,
            "crop_generation_rate": 0,
        }
    ]
    region2index["background"] = bg_src

    if region_map:
        for r in region_map:
            if r == "background":
                continue
            region_cfg = region_map[r]
            if region_cfg["enable"] != True:
                continue

            # out_dir is only PathLike-annotated; coerce so str paths work too.
            region_dir = Path(out_dir).joinpath(f"region_{int(r):05d}/")
            region_dir.mkdir(parents=True, exist_ok=True)

            mask_map = mask_preprocess(
                region_cfg,
                width,
                height,
                duration,
                region_dir,
            )
            if not mask_map:
                continue

            if region_cfg["is_init_img"] == False:
                cond = region_cfg["condition"]
                ip_map = ip_adapter_preprocess(
                    cond["ip_adapter_map"],
                    width,
                    height,
                    duration,
                    region_dir,
                    is_sdxl,
                )
                if ip_map:
                    prev_ip_map = ip_map

                region_condi_list.append({
                    "prompt_map": prompt_preprocess(
                        cond["prompt_map"],
                        cond["head_prompt"],
                        cond["tail_prompt"],
                        cond["prompt_fixed_ratio"],
                        duration,
                    ),
                    "ip_adapter_map": ip_map,
                })

                src = condi_index
                condi_index += 1
            else:
                if not is_init_img_exist:
                    logger.warning("'is_init_img' : true / BUT init_img is not exist -> ignore region")
                    continue
                src = -1

            region_list.append(
                {
                    "mask_images": mask_map,
                    "src": src,
                    "crop_generation_rate": region_cfg.get("crop_generation_rate", 0),
                }
            )
            region2index[r] = src

    ip_adapter_config_map = None

    if prev_ip_map is not None:
        # NOTE(review): adapter settings always come from the top-level
        # ip_adapter_map, even when only a region supplied images.
        base_cfg = model_config.ip_adapter_map
        ip_adapter_config_map = {
            "scale": base_cfg["scale"],
            "is_plus": base_cfg["is_plus"],
            "is_plus_face": base_cfg.get("is_plus_face", False),
            "is_light": base_cfg.get("is_light", False),
            "is_full_face": base_cfg.get("is_full_face", False),
        }
        for condi in region_condi_list:
            if condi["ip_adapter_map"] is None:
                logger.info(f"fill map")
                condi["ip_adapter_map"] = prev_ip_map

    if not region_condi_list:
        raise ValueError("erro! There is not a single valid region")

    return region_condi_list, region_list, ip_adapter_config_map, region2index

def img2img_preprocess(
        img2img_config_map: Dict[str, Any] = None,
        width: int = 512,
        height: int = 512,
        duration: int = 16,
        out_dir: PathLike = ...,
        ):
    """Load img2img init images keyed by frame number.

    Images are read from ``<init_img_dir>/[0-9]*.png`` (relative to
    ``data_dir``). Files whose numeric stem is below *duration* are resized to
    *width* x *height*. When ``save_init_image`` is true, the loaded images
    are also written to ``out_dir/00_img2img_init_img/``.

    Returns:
        ``{"images": {frame_no: image}, "denoising_strength": ...}``, or
        ``None`` when disabled or when no image below *duration* was found.
    """
    if not img2img_config_map or not (img2img_config_map["enable"] == True):
        return None

    init_dir = data_dir.joinpath(img2img_config_map["init_img_dir"])
    init_paths = sorted(glob.glob(os.path.join(init_dir, "[0-9]*.png"), recursive=False))

    frames = {}
    if init_paths:
        denoising_strength = img2img_config_map["denoising_strength"]
        for img_path in tqdm(init_paths, desc=f"Preprocessing images (img2img)"):
            frame_no = int(Path(img_path).stem)
            if frame_no < duration:
                frames[frame_no] = get_resized_image(img_path, width, height)

    if img2img_config_map["save_init_image"] == True and frames:
        det_dir = out_dir.joinpath(f"{0:02d}_img2img_init_img/")
        det_dir.mkdir(parents=True, exist_ok=True)
        for frame_no in tqdm(frames, desc=f"Saving Preprocessed images (img2img)"):
            frames[frame_no].save(det_dir.joinpath(f"{frame_no:08d}.png"))

    if not frames:
        return None
    return {"images": frames, "denoising_strength": denoising_strength}

def mask_preprocess(
        region_config_map: Dict[str, Any] = None,
        width: int = 512,
        height: int = 512,
        duration: int = 16,
        out_dir: PathLike = ...,
        ):
    """Load a region's mask images and expand them to one mask per frame.

    Masks are read from ``<mask_dir>/[0-9]*.png`` (relative to ``data_dir``).
    Files whose numeric stem is below *duration* are resized to
    *width* x *height*.

    Every frame in ``range(duration)`` without its own file reuses the most
    recent earlier mask. Frames before the first mask get a zero-filled image
    with the mode and size of the first mask loaded. When ``save_mask`` is
    true, the result is written to ``out_dir/mask/<frame:08d>.png``.

    Returns:
        ``{frame_no: image}`` covering every frame, or ``None`` when no mask
        below *duration* was loaded.
    """
    if not region_config_map:
        return None

    mask_dir = data_dir.joinpath(region_config_map["mask_dir"])
    mask_paths = sorted(glob.glob(os.path.join(mask_dir, "[0-9]*.png"), recursive=False))

    masks = {}
    first_mask = None
    if mask_paths:
        for img_path in tqdm(mask_paths, desc=f"Preprocessing images (mask)"):
            frame_no = int(Path(img_path).stem)
            if frame_no < duration:
                masks[frame_no] = get_resized_image(img_path, width, height)
                if first_mask is None:
                    first_mask = masks[frame_no]

    if masks:
        if 0 in masks:
            carry = masks[0]
        else:
            carry = Image.new(first_mask.mode, first_mask.size, color=0)

        # Forward-fill: each missing frame reuses the latest preceding mask.
        for frame_no in range(duration):
            if frame_no in masks:
                carry = masks[frame_no]
            else:
                masks[frame_no] = carry

    if region_config_map["save_mask"] == True and masks:
        det_dir = out_dir.joinpath("mask/")
        det_dir.mkdir(parents=True, exist_ok=True)
        for frame_no in tqdm(masks, desc=f"Saving Preprocessed images (mask)"):
            masks[frame_no].save(det_dir.joinpath(f"{frame_no:08d}.png"))

    return masks or None

def wild_card_conversion(model_config: ModelConfig = ...,):
    """Expand wildcards in every prompt of *model_config*, in place.

    ``replace_wild_card`` is applied, using the ``wildcards`` directory from
    ``get_dir``, to:

    * the global ``prompt_map`` and the non-empty head/tail prompts;
    * each non-background region's ``condition["prompt_map"]``;
    * any ``head_prompt`` / ``tail_prompt`` keys present in that condition.

    ``prompt_fixed_ratio`` is clamped to ``[0, 1]`` globally, and per region
    when the key is present.
    """
    from animatediff.utils.wild_card import replace_wild_card

    wildcard_root = get_dir("wildcards")

    def expand(text):
        return replace_wild_card(text, wildcard_root)

    prompts = model_config.prompt_map
    for key in prompts.keys():
        prompts[key] = expand(prompts[key])

    if model_config.head_prompt:
        model_config.head_prompt = expand(model_config.head_prompt)
    if model_config.tail_prompt:
        model_config.tail_prompt = expand(model_config.tail_prompt)

    model_config.prompt_fixed_ratio = max(min(1.0, model_config.prompt_fixed_ratio),0)

    if not model_config.region_map:
        return

    for region_name, region in model_config.region_map.items():
        if region_name == "background" or "condition" not in region:
            continue

        cond = region["condition"]
        cond_prompts = cond["prompt_map"]
        for key in cond_prompts.keys():
            cond_prompts[key] = expand(cond_prompts[key])

        for field in ("head_prompt", "tail_prompt"):
            if field in cond:
                cond[field] = expand(cond[field])

        if "prompt_fixed_ratio" in cond:
            cond["prompt_fixed_ratio"] = max(min(1.0, cond["prompt_fixed_ratio"]),0)

def save_output(
        pipeline_output,
        frame_dir:str,
        out_file:str,
        output_map : Dict[str,Any] = None,
        no_frames : bool = False,
        save_frames=save_frames,
        save_video=None,
):
    """Write generated frames as a GIF or as an ffmpeg-encoded video.

    Args:
        pipeline_output: The generated frames. It is passed unchanged to
            *save_frames* and *save_video*. Without *save_video*, the GIF path
            calls ``pipeline_output[0].save(...)``, so the items must then be
            PIL images.
        frame_dir: Directory handed to *save_frames*. For video formats, the
            ffmpeg encoder also reads the frames back from it.
        out_file: Output path; its suffix is replaced according to the format.
            ``str`` and ``Path`` are both accepted.
        output_map: Optional settings:
            * ``format``: default ``"gif"``; ``"mp4"`` is treated as ``"h264"``;
            * ``fps``: default 8;
            * ``encode_param``: extra encoder parameters.
        no_frames: Only affects the GIF path, where nothing at all is written.
            Video formats always write frames because the encoder needs them.
        save_frames: Callable ``(frames, frame_dir)``, or falsy to skip it.
        save_video: Optional callable ``(frames, out_file, fps)`` used for GIF
            output instead of PIL.
    """
    output_map = output_map or {}  # None default avoids a shared mutable {}
    out_file = Path(out_file)  # annotated str, but .with_suffix needs a Path

    output_format = output_map.get("format", "gif")
    output_fps = output_map.get("fps", 8)
    if output_format == "mp4":
        output_format = "h264"

    if output_format == "gif":
        out_file = out_file.with_suffix(".gif")
        if no_frames:
            return

        if save_frames:
            save_frames(pipeline_output, frame_dir)

        # generate the output filename and save the video
        if save_video:
            save_video(pipeline_output, out_file, output_fps)
        else:
            pipeline_output[0].save(
                fp=out_file, format="GIF", append_images=pipeline_output[1:], save_all=True, duration=(1 / output_fps * 1000), loop=0
            )
        return

    if save_frames:
        save_frames(pipeline_output, frame_dir)

    from animatediff.rife.ffmpeg import (FfmpegEncoder, VideoCodec,
                                         codec_extn)

    out_file = out_file.with_suffix(f".{codec_extn(output_format)}")

    logger.info("Creating ffmpeg encoder...")
    encoder = FfmpegEncoder(
        frames_dir=frame_dir,
        out_file=out_file,
        codec=output_format,
        in_fps=output_fps,
        out_fps=output_fps,
        lossless=False,
        param=output_map.get("encode_param", {}),
    )
    logger.info("Encoding interpolated frames with ffmpeg...")
    result = encoder.encode()
    logger.debug(f"ffmpeg result: {result}")



def run_inference(
    pipeline: DiffusionPipeline,
    n_prompt: str = ...,
    seed: int = -1,
    steps: int = 25,
    guidance_scale: float = 7.5,
    unet_batch_size: int = 1,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    idx: int = 0,
    out_dir: PathLike = ...,
    context_frames: int = -1,
    context_stride: int = 3,
    context_overlap: int = 4,
    context_schedule: str = "uniform",
    clip_skip: int = 1,
    controlnet_map: Dict[str, Any] = None,
    controlnet_image_map: Dict[str,Any] = None,
    controlnet_type_map: Dict[str,Any] = None,
    controlnet_ref_map: Dict[str,Any] = None,
    controlnet_no_shrink:List[str]=None,
    no_frames :bool = False,
    img2img_map: Dict[str,Any] = None,
    ip_adapter_config_map: Dict[str,Any] = None,
    region_list: List[Any] = None,
    region_condi_list: List[Any] = None,
    output_map: Dict[str,Any] = None,
    is_single_prompt_mode: bool = False,
    is_sdxl:bool=False,
    apply_lcm_lora:bool=False,
    gradual_latent_map: Dict[str,Any] = None,
):
    """Generate one sample with *pipeline* and save it under *out_dir*.

    The output name is ``<idx:02d>_<seed>_<prompt tags>``. The tags are the
    first six comma-separated tags of the first prompt in
    ``region_condi_list[0]["prompt_map"]``, cleaned by ``re_clean_prompt``
    and truncated to 50 characters. Frames go to ``<idx:02d>-<seed>/``.
    Previews are saved through the same ``save_output`` path every
    ``output_map["preview_steps"]`` steps.

    *controlnet_map* and *output_map* may be ``None``; they are then treated
    as empty, so the ControlNet VRAM limits and loop flag take their defaults.
    *region_condi_list* is required despite its ``None`` default.

    Returns:
        The raw pipeline output (called with ``return_dict=False``).
    """
    out_dir = Path(out_dir)  # ensure out_dir is a Path
    # Both default to None but are read with `in` / `.get` below.
    controlnet_map = controlnet_map or {}
    output_map = output_map or {}

    # Trim and clean up the prompt for filename use
    prompt_map = region_condi_list[0]["prompt_map"]
    first_prompt = prompt_map[next(iter(prompt_map))]
    prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in first_prompt.split(",")]
    prompt_str = "_".join((prompt_tags[:6]))[:50]
    frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}")
    out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}")

    def preview_callback(i: int, video: torch.Tensor, save_fn: Callable[[torch.Tensor], None], out_file: str) -> None:
        save_fn(video, out_file=Path(f"{out_file}_preview@{i}"))

    save_fn = partial(
        save_output,
        frame_dir=frame_dir,
        output_map=output_map,
        no_frames=no_frames,
        save_frames=partial(save_frames, show_progress=False),
        save_video=save_video
    )
    callback = partial(preview_callback, save_fn=save_fn, out_file=out_file)

    seed_everything(seed)

    logger.info(f"{len( region_condi_list )=}")
    logger.info(f"{len( region_list )=}")

    pipeline_output = pipeline(
        negative_prompt=n_prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        unet_batch_size=unet_batch_size,
        width=width,
        height=height,
        video_length=duration,
        return_dict=False,
        context_frames=context_frames,
        # NOTE(review): stride is passed as +1; presumably the pipeline's
        # convention differs from the config's by one.
        context_stride=context_stride + 1,
        context_overlap=context_overlap,
        context_schedule=context_schedule,
        clip_skip=clip_skip,
        controlnet_type_map=controlnet_type_map,
        controlnet_image_map=controlnet_image_map,
        controlnet_ref_map=controlnet_ref_map,
        controlnet_no_shrink=controlnet_no_shrink,
        controlnet_max_samples_on_vram=controlnet_map.get("max_samples_on_vram", 999),
        controlnet_max_models_on_vram=controlnet_map.get("max_models_on_vram", 99),
        controlnet_is_loop=controlnet_map.get("is_loop", True),
        img2img_map=img2img_map,
        ip_adapter_config_map=ip_adapter_config_map,
        region_list=region_list,
        region_condi_list=region_condi_list,
        interpolation_factor=1,
        is_single_prompt_mode=is_single_prompt_mode,
        apply_lcm_lora=apply_lcm_lora,
        gradual_latent_map=gradual_latent_map,
        callback=callback,
        callback_steps=output_map.get("preview_steps"),
    )
    logger.info("Generation complete, saving...")

    save_fn(pipeline_output, out_file=out_file)

    logger.info(f"Saved sample to {out_file}")
    return pipeline_output


def run_upscale(
    org_imgs: List[str],
    pipeline: DiffusionPipeline,
    prompt_map: Dict[int, str] = None,
    n_prompt: str = ...,
    seed: int = -1,
    steps: int = 25,
    strength: float = 0.5,
    guidance_scale: float = 7.5,
    clip_skip: int = 1,
    us_width: int = 512,
    us_height: int = 512,
    idx: int = 0,
    out_dir: PathLike = ...,
    upscale_config:Dict[str, Any]=None,
    use_controlnet_ref: bool = False,
    use_controlnet_tile: bool = False,
    use_controlnet_line_anime: bool = False,
    use_controlnet_ip2p: bool = False,
    no_frames:bool = False,
    output_map: Dict[str,Any] = None,
):
    """Refine/upscale a sequence of frames one by one with a ControlNet img2img pipeline.

    Each input frame is resized to ``(us_width, us_height)``. It is then run
    through ``pipeline`` with prompt embeddings interpolated between the
    surrounding key frames of ``prompt_map``. The resulting frames are written
    via ``save_output``.

    Args:
        org_imgs: Input frames, passed to ``get_resized_images`` (presumably
            file paths — confirm against callers).
        pipeline: ControlNet img2img pipeline (optionally the reference-only
            variant when ``use_controlnet_ref`` is set).
        prompt_map: Mapping of key-frame index -> prompt text. It must be
            non-empty.
        n_prompt: Negative prompt.
        seed: Seed for ``seed_everything`` and the torch generator.
        steps, strength, guidance_scale, clip_skip: Defaults that are
            overridden by the same-named keys in ``upscale_config`` when
            present.
        us_width, us_height: Target frame size.
        idx: Index used as the output filename prefix.
        out_dir: Output directory (must support ``joinpath``).
        upscale_config: Upscale settings, including the per-ControlNet
            sections ``controlnet_tile`` / ``controlnet_line_anime`` /
            ``controlnet_ip2p`` / ``controlnet_ref``. ``None`` is treated as
            an empty config.
        use_controlnet_*: Which ControlNets / reference mode are enabled.
        no_frames: Forwarded to ``save_output``.
        output_map: Output settings forwarded to ``save_output``.

    Returns:
        The list of output images, one per input frame.
    """
    from animatediff.utils.lpw_stable_diffusion import lpw_encode_prompt

    if upscale_config is None:
        upscale_config = {}

    pipeline.set_progress_bar_config(disable=True)

    images = get_resized_images(org_imgs, us_width, us_height)

    # Settings in the upscale config take precedence over the arguments.
    # NOTE(review): upscale_config["scheduler"] is not applied here; the
    # pipeline runs with whatever scheduler it already has.
    steps = upscale_config.get("steps", steps)
    guidance_scale = upscale_config.get("guidance_scale", guidance_scale)
    # NOTE(review): clip_skip is resolved but not forwarded to the pipeline below.
    clip_skip = upscale_config.get("clip_skip", clip_skip)
    strength = upscale_config.get("strength", strength)

    # Enabled ControlNets in a fixed order (tile, line_anime, ip2p). The same
    # order is used when building each frame's condition images below, so the
    # per-ControlNet parameter lists line up with the images.
    enabled_controlnets = [
        section
        for section, enabled in (
            ("controlnet_tile", use_controlnet_tile),
            ("controlnet_line_anime", use_controlnet_line_anime),
            ("controlnet_ip2p", use_controlnet_ip2p),
        )
        if enabled
    ]
    controlnet_conditioning_scale = [upscale_config[s]["controlnet_conditioning_scale"] for s in enabled_controlnets]
    guess_mode = [upscale_config[s]["guess_mode"] for s in enabled_controlnets]
    control_guidance_start = [upscale_config[s]["control_guidance_start"] for s in enabled_controlnets]
    control_guidance_end = [upscale_config[s]["control_guidance_end"] for s in enabled_controlnets]

    # Reference-only ControlNet: a fixed reference image is loaded from
    # data_dir unless the frames themselves are used as the reference.
    ref_image = None
    ref_cfg = upscale_config["controlnet_ref"] if use_controlnet_ref else None
    if use_controlnet_ref:
        if not ref_cfg["use_frame_as_ref_image"] and not ref_cfg["use_1st_frame_as_ref_image"]:
            ref_image = get_resized_images([ data_dir.joinpath( ref_cfg["ref_image"] ) ], us_width, us_height)[0]

    generator = torch.manual_seed(seed)

    seed_everything(seed)

    prompt_embeds_map = {}
    prompt_map = dict(sorted(prompt_map.items()))

    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_list = list(prompt_map.values())

    prompt_embeds, neg_embeds = lpw_encode_prompt(
        pipe=pipeline,
        prompt=prompt_list,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=n_prompt,
    )

    # Split the batched embeddings into one chunk per key-frame prompt.
    positive = prompt_embeds.chunk(prompt_embeds.shape[0], 0)
    if do_classifier_free_guidance:
        negative = neg_embeds.chunk(neg_embeds.shape[0], 0)
    else:
        negative = [None]

    for i, key_frame in enumerate(prompt_map):
        prompt_embeds_map[key_frame] = positive[i]

    key_first = list(prompt_map.keys())[0]
    key_last = list(prompt_map.keys())[-1]

    def get_current_prompt_embeds(
            center_frame: int = 0,
            video_length : int = 0
            ):
        """Interpolate the prompt embedding for ``center_frame``.

        The interpolation is between the nearest key frames before and after
        it, wrapping around the end of the sequence.
        """
        key_prev = key_last
        key_next = key_first

        for p in prompt_map.keys():
            if p > center_frame:
                key_next = p
                break
            key_prev = p

        # Distances wrap around video_length so the sequence is treated as a loop.
        dist_prev = center_frame - key_prev
        if dist_prev < 0:
            dist_prev += video_length
        dist_next = key_next - center_frame
        if dist_next < 0:
            dist_next += video_length

        if key_prev == key_next or dist_prev + dist_next == 0:
            return prompt_embeds_map[key_prev]

        rate = dist_prev / (dist_prev + dist_next)

        return get_tensor_interpolation_method()(prompt_embeds_map[key_prev],prompt_embeds_map[key_next], rate)

    # Only load (and possibly download) the lineart detector when it is used.
    line_anime_processor = (
        LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
        if use_controlnet_line_anime
        else None
    )

    def _single_or_list(values):
        # With one ControlNet the pipeline is given a scalar; with several, a list.
        return values if len(values) > 1 else values[0]

    out_images = []

    logger.info(f"{use_controlnet_tile=}")
    logger.info(f"{use_controlnet_line_anime=}")
    logger.info(f"{use_controlnet_ip2p=}")

    logger.info(f"{controlnet_conditioning_scale=}")
    logger.info(f"{guess_mode=}")
    logger.info(f"{control_guidance_start=}")
    logger.info(f"{control_guidance_end=}")

    for i, org_image in enumerate(tqdm(images, desc="Upscaling...")):

        cur_positive = get_current_prompt_embeds(i, len(images))

        # One condition image per enabled ControlNet, in enabled_controlnets order.
        condition_image = []
        if use_controlnet_tile:
            condition_image.append( org_image )
        if use_controlnet_line_anime:
            condition_image.append( line_anime_processor(org_image) )
        if use_controlnet_ip2p:
            condition_image.append( org_image )

        pipeline_kwargs = dict(
            prompt_embeds=cur_positive,
            negative_prompt_embeds=negative[0],
            image=org_image,
            control_image=condition_image,
            width=org_image.size[0],
            height=org_image.size[1],
            strength=strength,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
            controlnet_conditioning_scale=_single_or_list(controlnet_conditioning_scale),
            # Only the first enabled ControlNet's guess_mode is used.
            guess_mode=guess_mode[0],
        )

        if not use_controlnet_ref:
            pipeline_kwargs.update(
                control_guidance_start=_single_or_list(control_guidance_start),
                control_guidance_end=_single_or_list(control_guidance_end),
            )
        else:
            # Choose the reference image: the first frame (kept for all later
            # frames), the current frame, or the fixed image loaded above.
            if ref_cfg["use_1st_frame_as_ref_image"]:
                if i == 0:
                    ref_image = org_image
            elif ref_cfg["use_frame_as_ref_image"]:
                ref_image = org_image

            # control_guidance_start/end are deliberately not passed in reference mode.
            pipeline_kwargs.update(
                ref_image=ref_image,
                attention_auto_machine_weight=ref_cfg["attention_auto_machine_weight"],
                gn_auto_machine_weight=ref_cfg["gn_auto_machine_weight"],
                style_fidelity=ref_cfg["style_fidelity"],
                reference_attn=ref_cfg["reference_attn"],
                reference_adain=ref_cfg["reference_adain"],
            )

        out_image = pipeline(**pipeline_kwargs).images[0]

        out_images.append(out_image)

    # Trim and clean up the first key-frame prompt for filename use
    prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[key_first].split(",")]
    prompt_str = "_".join(prompt_tags[:6])[:50]

    # generate the output filename and save the video
    out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}")

    frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}-upscaled")

    save_output( out_images, frame_dir, out_file, output_map, no_frames, save_imgs, None )

    logger.info(f"Saved sample to {out_file}")

    return out_images