import random

import torch
import sys

from diffusers import Transformer2DModel, FluxTransformer2DModel
from torch import nn
from torch.nn import Parameter
from torch.nn.modules.module import T
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.zipper_resampler import ZipperResampler
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
from toolkit.util.inverse_cfg import inverse_classifier_guidance

from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
from collections import OrderedDict
from toolkit.util.ip_adapter_utils import AttnProcessor2_0, IPAttnProcessor2_0, ImageProjModel
from toolkit.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion

from transformers import (
    AutoImageProcessor,
    ConvNextV2ForImageClassification,
    ConvNextForImageClassification,
    ConvNextImageProcessor,
    ViTHybridImageProcessor,
    ViTHybridForImageClassification,
    ViTFeatureExtractor,
    ViTForImageClassification,
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel

import torch.nn.functional as F


class MLPProjModelClipFace(torch.nn.Module):
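    """Projects a face ID embedding into `num_tokens` cross-attention context tokens with a small MLP."""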
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
        self.norm = torch.nn.LayerNorm(id_embeddings_dim)

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        # Initialize the last linear layer weights near zero
        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
        torch.nn.init.zeros_(self.proj[2].bias)
        # # Custom initialization for LayerNorm to output near zero
        # torch.nn.init.constant_(self.norm.weight, 0.1)  # Small weights near zero
        # torch.nn.init.zeros_(self.norm.bias)  # Bias to zero

    def forward(self, x):
        x = self.norm(x)
        x = self.proj(x)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return x


class CustomIPAttentionProcessor(IPAttnProcessor2_0):
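    """IP attention processor that splits the trailing `num_tokens` image tokens off of
    encoder_hidden_states, attends to them through to_k_ip/to_v_ip, and adds the scaled
    result back onto the text-conditioned hidden states.
    """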
    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, full_token_scaler=False):
        super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.train_scaler = train_scaler
        if train_scaler:
            if full_token_scaler:
                self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
            else:
                self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
            # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
            self.ip_scaler.requires_grad_(True)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
    ):
        is_active = self.adapter_ref().is_active
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if is_active:
            # since we are removing tokens, we need to adjust the sequence length
            sequence_length = sequence_length - self.num_tokens

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        # will be None if the adapter is disabled
        if not is_active:
            ip_hidden_states = None
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # will be none if disabled
        if ip_hidden_states is not None:
            # apply scaler
            if self.train_scaler:
                weight = self.ip_scaler
                # reshape to (1, self.num_tokens, 1)
                weight = weight.view(1, -1, 1)
                ip_hidden_states = ip_hidden_states * weight

            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            scale = self.scale
            hidden_states = hidden_states + scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    # this would ensure that the ip_scaler is not cast/moved when the model is loaded
    # (kept for reference, currently unused)
    # def _apply(self, fn):
    #     if hasattr(self, "ip_scaler"):
    #         # Overriding the _apply method to prevent the special_parameter from changing dtype
    #         self.ip_scaler = fn(self.ip_scaler)
    #         # Temporarily set the special_parameter to None to exclude it from default _apply processing
    #         ip_scaler = self.ip_scaler
    #         self.ip_scaler = None
    #         super(CustomIPAttentionProcessor, self)._apply(fn)
    #         # Restore the special_parameter after the default _apply processing
    #         self.ip_scaler = ip_scaler
    #         return self
    #     else:
    #         return super(CustomIPAttentionProcessor, self)._apply(fn)


class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False,
                 full_token_scaler=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.train_scaler = train_scaler
        if train_scaler:
            if full_token_scaler:
                self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
            else:
                self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
            # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
            self.ip_scaler.requires_grad_(True)

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        is_active = self.adapter_ref().is_active
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # begin ip adapter
        if not is_active:
            ip_hidden_states = None
        else:
            # the projected image embeds are stored on the adapter by forward()
            ip_hidden_states = self.adapter_ref().last_conditional
            # prepend the unconditional batch if we are doing CFG
            if ip_hidden_states.shape[0] * 2 == batch_size:
                if self.adapter_ref().last_unconditional is None:
                    raise ValueError("Unconditional is None but should not be")
                ip_hidden_states = torch.cat([self.adapter_ref().last_unconditional, ip_hidden_states], dim=0)

        if ip_hidden_states is not None:
            # apply scaler
            if self.train_scaler:
                weight = self.ip_scaler
                # reshape to (1, self.num_tokens, 1)
                weight = weight.view(1, -1, 1)
                ip_hidden_states = ip_hidden_states * weight

            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            scale = self.scale
            hidden_states = hidden_states + scale * ip_hidden_states
        # end ip adapter

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


# loosely based on https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
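        """Builds the configured image encoder + processor and the image projection model,
        then installs the custom IP attention processors on the unet/transformer of `sd`."""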
        super().__init__()
        self.config = adapter_config
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.device = self.sd_ref().unet.device
        self.preprocessor: Optional[CLIPImagePreProcessor] = None
        self.input_size = 224
        self.clip_noise_zero = True
        self.unconditional: Optional[torch.Tensor] = None

        self.last_conditional: Optional[torch.Tensor] = None
        self.last_unconditional: Optional[torch.Tensor] = None

        self.additional_loss = None
        if self.config.image_encoder_arch.startswith("clip"):
            try:
                self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
                self.clip_image_processor = CLIPImageProcessor()
            self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                adapter_config.image_encoder_path,
                ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        elif self.config.image_encoder_arch == 'siglip':
            from transformers import SiglipImageProcessor, SiglipVisionModel
            try:
                self.clip_image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
                self.clip_image_processor = SiglipImageProcessor()
            self.image_encoder = SiglipVisionModel.from_pretrained(
                adapter_config.image_encoder_path,
                ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        elif self.config.image_encoder_arch == 'vit':
            try:
                self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
                self.clip_image_processor = ViTFeatureExtractor()
            self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
                self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        elif self.config.image_encoder_arch == 'safe':
            try:
                self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
                self.clip_image_processor = SAFEImageProcessor()
            self.image_encoder = SAFEVisionModel(
                in_channels=3,
                num_tokens=self.config.safe_tokens,
                num_vectors=sd.unet.config['cross_attention_dim'],
                reducer_channels=self.config.safe_reducer_channels,
                channels=self.config.safe_channels,
                downscale_factor=8
            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        elif self.config.image_encoder_arch == 'convnext':
            try:
                self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
                print(f"could not load image processor from {adapter_config.image_encoder_path}")
                self.clip_image_processor = ConvNextImageProcessor(
                    size=320,
                    image_mean=[0.48145466, 0.4578275, 0.40821073],
                    image_std=[0.26862954, 0.26130258, 0.27577711],
                )
            self.image_encoder = ConvNextForImageClassification.from_pretrained(
                adapter_config.image_encoder_path,
                use_safetensors=True,
            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        elif self.config.image_encoder_arch == 'convnextv2':
            try:
                self.clip_image_processor = AutoImageProcessor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
                print(f"could not load image processor from {adapter_config.image_encoder_path}")
                self.clip_image_processor = ConvNextImageProcessor(
                    size=512,
                    image_mean=[0.485, 0.456, 0.406],
                    image_std=[0.229, 0.224, 0.225],
                )
            self.image_encoder = ConvNextV2ForImageClassification.from_pretrained(
                adapter_config.image_encoder_path,
                use_safetensors=True,
            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        elif self.config.image_encoder_arch == 'vit-hybrid':
            try:
                self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
                print(f"could not load image processor from {adapter_config.image_encoder_path}")
                self.clip_image_processor = ViTHybridImageProcessor(
                    size=320,
                    image_mean=[0.48145466, 0.4578275, 0.40821073],
                    image_std=[0.26862954, 0.26130258, 0.27577711],
                )
            self.image_encoder = ViTHybridForImageClassification.from_pretrained(
                adapter_config.image_encoder_path,
                use_safetensors=True,
            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        else:
            raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")

        if not self.config.train_image_encoder:
            # compiling the frozen image encoder is currently disabled
            # torch.compile(self.image_encoder, fullgraph=True)
            pass

        self.input_size = self.image_encoder.config.image_size

        if self.config.quad_image:  # 2x2 grid of 4 images
            # each quadrant should match the encoder input size,
            # so images need to come in at 2x that size
            preprocessor_input_size = self.image_encoder.config.image_size * 2

            # update the preprocessor so images come in at the right size
            if 'height' in self.clip_image_processor.size:
                self.clip_image_processor.size['height'] = preprocessor_input_size
                self.clip_image_processor.size['width'] = preprocessor_input_size
            elif hasattr(self.clip_image_processor, 'crop_size'):
                self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
                self.clip_image_processor.crop_size['height'] = preprocessor_input_size
                self.clip_image_processor.crop_size['width'] = preprocessor_input_size

        if self.config.image_encoder_arch == 'clip+':
            # the preprocessor reduces images down to the CLIP input size,
            # so images need to come in at 4x that size
            preprocessor_input_size = self.image_encoder.config.image_size * 4

            # update the preprocessor so images come in at the right size
            self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
            self.clip_image_processor.crop_size['height'] = preprocessor_input_size
            self.clip_image_processor.crop_size['width'] = preprocessor_input_size

            self.preprocessor = CLIPImagePreProcessor(
                input_size=preprocessor_input_size,
                clip_input_size=self.image_encoder.config.image_size,
            )
        if self.config.image_encoder_arch != 'safe':
            if 'height' in self.clip_image_processor.size:
                self.input_size = self.clip_image_processor.size['height']
            elif hasattr(self.clip_image_processor, 'crop_size'):
                self.input_size = self.clip_image_processor.crop_size['height']
            elif 'shortest_edge' in self.clip_image_processor.size:
                self.input_size = self.clip_image_processor.size['shortest_edge']
            else:
                raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
        self.current_scale = 1.0
        self.is_active = True
        is_pixart = sd.is_pixart
        is_flux = sd.is_flux
        if adapter_config.type == 'ip':
            # ip-adapter
            image_proj_model = ImageProjModel(
                cross_attention_dim=sd.unet.config['cross_attention_dim'],
                clip_embeddings_dim=self.image_encoder.config.projection_dim,
                clip_extra_context_tokens=self.config.num_tokens,  # usually 4
            )
        elif adapter_config.type == 'ip_clip_face':
            cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
            image_proj_model = MLPProjModelClipFace(
                cross_attention_dim=cross_attn_dim,
                id_embeddings_dim=self.image_encoder.config.projection_dim,
                num_tokens=self.config.num_tokens,  # usually 4
            )
        elif adapter_config.type == 'ip+':
            heads = 12 if not sd.is_xl else 20
            if is_flux:
                dim = 1280
            else:
                dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
                'convnext') else \
                self.image_encoder.config.hidden_sizes[-1]

            image_encoder_state_dict = self.image_encoder.state_dict()
            # max_seq_len = CLIP tokens + CLS token
            max_seq_len = 257
            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
                # clip
                max_seq_len = int(
                    image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

            if is_pixart:
                heads = 20
                dim = 1280
                output_dim = 4096
            elif is_flux:
                heads = 20
                dim = 1280
                output_dim = 3072
            else:
                output_dim = sd.unet.config['cross_attention_dim']

            # ip-adapter-plus
            image_proj_model = Resampler(
                dim=dim,
                depth=4,
                dim_head=64,
                heads=heads,
                num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len,
                embedding_dim=embedding_dim,
                max_seq_len=max_seq_len,
                output_dim=output_dim,
                ff_mult=4
            )
        elif adapter_config.type == 'ipz':
            dim = sd.unet.config['cross_attention_dim']
            if hasattr(self.image_encoder.config, 'hidden_sizes'):
                embedding_dim = self.image_encoder.config.hidden_sizes[-1]
            else:
                embedding_dim = self.image_encoder.config.target_hidden_size

            image_encoder_state_dict = self.image_encoder.state_dict()
            # in_tokens = CLIP patch tokens + CLS token
            in_tokens = 257
            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
                # clip
                in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

            if self.config.image_encoder_arch.startswith('convnext'):
                in_tokens = 16 * 16
                embedding_dim = self.image_encoder.config.hidden_sizes[-1]

            is_conv_next = self.config.image_encoder_arch.startswith('convnext')

            out_tokens = self.config.num_tokens if self.config.num_tokens > 0 else in_tokens
            # ip-adapter-plus
            image_proj_model = ZipperResampler(
                in_size=embedding_dim,
                in_tokens=in_tokens,
                out_size=dim,
                out_tokens=out_tokens,
                hidden_size=embedding_dim,
                hidden_tokens=in_tokens,
                num_blocks=1 if not is_conv_next else 2,
                is_conv_input=is_conv_next
            )
        elif adapter_config.type == 'ilora':
            # we apply the clip encodings to the LoRA
            image_proj_model = None
        else:
            raise ValueError(f"unknown adapter type: {adapter_config.type}")

        # init adapter modules
        attn_procs = {}
        unet_sd = sd.unet.state_dict()
        attn_processor_keys = []
        if is_pixart:
            transformer: Transformer2DModel = sd.unet
            for i, module in transformer.transformer_blocks.named_children():
                attn_processor_keys.append(f"transformer_blocks.{i}.attn1")

                # cross attention
                attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
        elif is_flux:
            transformer: FluxTransformer2DModel = sd.unet
            for i, module in transformer.transformer_blocks.named_children():
                attn_processor_keys.append(f"transformer_blocks.{i}.attn")

            # single transformer blocks do not have cross attn, but we will do them anyway
            for i, module in transformer.single_transformer_blocks.named_children():
                attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
        else:
            attn_processor_keys = list(sd.unet.attn_processors.keys())

        attn_processor_names = []

        blocks = []
        transformer_blocks = []
        for name in attn_processor_keys:
            name_split = name.split(".")
            block_name = f"{name_split[0]}.{name_split[1]}"
            transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1
            if transformer_idx >= 0:
                transformer_name = ".".join(name_split[:2])
                transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
                if transformer_name not in transformer_blocks:
                    transformer_blocks.append(transformer_name)

            if block_name not in blocks:
                blocks.append(block_name)
            if is_flux:
                cross_attention_dim = None
            else:
                cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
                    sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = sd.unet.config['block_out_channels'][block_id]
            elif name.startswith("transformer") or name.startswith("single_transformer"):
                if is_flux:
                    hidden_size = 3072
                else:
                    hidden_size = sd.unet.config['cross_attention_dim']
            else:
                # the reference implementation didn't handle this case; without it hidden_size would be undefined below
                raise ValueError(f"unknown attn processor name: {name}")
            if cross_attention_dim is None and not is_flux:
                attn_procs[name] = AttnProcessor2_0()
            else:
                layer_name = name.split(".processor")[0]

                # if quantized, the original weights cannot be read directly, so
                # initialize the IP projections with small random weights instead
                if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux:

                    k_weight = torch.randn(hidden_size, hidden_size) * 0.01
                    v_weight = torch.randn(hidden_size, hidden_size) * 0.01
                    k_weight = k_weight.to(self.sd_ref().torch_dtype)
                    v_weight = v_weight.to(self.sd_ref().torch_dtype)
                else:
                    k_weight = unet_sd[layer_name + ".to_k.weight"]
                    v_weight = unet_sd[layer_name + ".to_v.weight"]

                weights = {
                    "to_k_ip.weight": k_weight,
                    "to_v_ip.weight": v_weight
                }

                if is_flux:
                    attn_procs[name] = CustomIPFluxAttnProcessor2_0(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.config.num_tokens,
                        adapter=self,
                        train_scaler=self.config.train_scaler or self.config.merge_scaler,
                        full_token_scaler=False
                    )
                else:
                    attn_procs[name] = CustomIPAttentionProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.config.num_tokens,
                        adapter=self,
                        train_scaler=self.config.train_scaler or self.config.merge_scaler,
                        # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler
                        full_token_scaler=False
                    )
                if self.sd_ref().is_pixart or self.sd_ref().is_flux:
                    # pixart is much more sensitive
                    weights = {
                        "to_k_ip.weight": weights["to_k_ip.weight"] * 0.01,
                        "to_v_ip.weight": weights["to_v_ip.weight"] * 0.01,
                    }

                attn_procs[name].load_state_dict(weights, strict=False)
                attn_processor_names.append(name)
        print(f"Attn Processors")
        print(attn_processor_names)
        if self.sd_ref().is_pixart:
            # we have to set them ourselves
            transformer: Transformer2DModel = sd.unet
            for i, module in transformer.transformer_blocks.named_children():
                module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
                module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
            self.adapter_modules = torch.nn.ModuleList(
                [
                    transformer.transformer_blocks[i].attn2.processor for i in
                    range(len(transformer.transformer_blocks))
                ])
        elif self.sd_ref().is_flux:
            # we have to set them ourselves
            transformer: FluxTransformer2DModel = sd.unet
            for i, module in transformer.transformer_blocks.named_children():
                module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]

            # do single blocks too even though they dont have cross attn
            for i, module in transformer.single_transformer_blocks.named_children():
                module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]

            self.adapter_modules = torch.nn.ModuleList(
                [
                    transformer.transformer_blocks[i].attn.processor for i in
                    range(len(transformer.transformer_blocks))
                ] + [
                    transformer.single_transformer_blocks[i].attn.processor for i in
                    range(len(transformer.single_transformer_blocks))
                ]
            )
        else:
            sd.unet.set_attn_processor(attn_procs)
            self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

        sd.adapter = self
        self.unet_ref: weakref.ref = weakref.ref(sd.unet)
        self.image_proj_model = image_proj_model
        # load the weights if we have some
        if self.config.name_or_path:
            loaded_state_dict = load_ip_adapter_model(
                self.config.name_or_path,
                device='cpu',
                dtype=sd.torch_dtype
            )
            self.load_state_dict(loaded_state_dict)

        self.set_scale(1.0)

        if self.config.train_image_encoder:
            self.image_encoder.train()
            self.image_encoder.requires_grad_(True)

        # premake an unconditional image input
        zerod = torch.zeros(1, 3, self.input_size, self.input_size, device=self.device, dtype=torch.float16)
        self.unconditional = self.clip_image_processor(
            images=zerod,
            return_tensors="pt",
            do_resize=True,
            do_rescale=False,
        ).pixel_values

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.image_encoder.to(*args, **kwargs)
        self.image_proj_model.to(*args, **kwargs)
        self.adapter_modules.to(*args, **kwargs)
        if self.preprocessor is not None:
            self.preprocessor.to(*args, **kwargs)
        return self

    # def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
    #     self.image_proj_model.load_state_dict(state_dict["image_proj"])
    #     ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
    #     ip_layers.load_state_dict(state_dict["ip_adapter"])
    #     if self.config.train_image_encoder and 'image_encoder' in state_dict:
    #         self.image_encoder.load_state_dict(state_dict["image_encoder"])
    #     if self.preprocessor is not None and 'preprocessor' in state_dict:
    #         self.preprocessor.load_state_dict(state_dict["preprocessor"])

    # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
    #     self.load_ip_adapter(state_dict)

    def state_dict(self) -> OrderedDict:
        state_dict = OrderedDict()
        if self.config.train_only_image_encoder:
            return self.image_encoder.state_dict()
        if self.config.train_scaler:
            state_dict["ip_scale"] = self.adapter_modules.state_dict()
            # remove items that are not scalers
            for key in list(state_dict["ip_scale"].keys()):
                if not key.endswith("ip_scaler"):
                    del state_dict["ip_scale"][key]
            return state_dict

        state_dict["image_proj"] = self.image_proj_model.state_dict()
        state_dict["ip_adapter"] = self.adapter_modules.state_dict()
        # handle merge scaler training
        if self.config.merge_scaler:
            for key in list(state_dict["ip_adapter"].keys()):
                if key.endswith("ip_scaler"):
                    # merge the scaler into the weights so we don't have to save it
                    # and the checkpoint stays compatible with other ip adapters
                    scale = state_dict["ip_adapter"][key].clone()

                    key_start = key.split(".")[-2]
                    # reshape to (1, 1)
                    scale = scale.view(1, 1)
                    del state_dict["ip_adapter"][key]
                    # find the to_k_ip and to_v_ip keys
                    for key2 in list(state_dict["ip_adapter"].keys()):
                        if key2.endswith(f"{key_start}.to_k_ip.weight"):
                            state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale
                        if key2.endswith(f"{key_start}.to_v_ip.weight"):
                            state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale

        if self.config.train_image_encoder:
            state_dict["image_encoder"] = self.image_encoder.state_dict()
        if self.preprocessor is not None:
            state_dict["preprocessor"] = self.preprocessor.state_dict()
        return state_dict

    def get_scale(self):
        return self.current_scale

    def set_scale(self, scale):
        self.current_scale = scale
        if not self.sd_ref().is_pixart and not self.sd_ref().is_flux:
            for attn_processor in self.sd_ref().unet.attn_processors.values():
                if isinstance(attn_processor, CustomIPAttentionProcessor):
                    attn_processor.scale = scale

    # @torch.no_grad()
    # def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
    #                                    drop=False) -> torch.Tensor:
    #     # todo: add support for sdxl
    #     if isinstance(pil_image, Image.Image):
    #         pil_image = [pil_image]
    #     clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
    #     clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
    #     if drop:
    #         clip_image = clip_image * 0
    #     clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
    #     return clip_image_embeds

    def parse_clip_image_embeds_from_cache(
            self,
            image_embeds_list: List[dict],  # has ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
            quad_count=4,
    ):
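        """Pulls the configured `clip_layer` out of cached image-encoder outputs and, for
        quad images, averages the quadrant embeddings back into a single batch."""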
        with torch.no_grad():
            device = self.sd_ref().unet.device
            clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0)

            if self.config.quad_image:
                # average the embeddings of the four quadrants
                chunks = clip_image_embeds.chunk(quad_count, dim=0)
                chunk_sum = torch.zeros_like(chunks[0])
                for chunk in chunks:
                    chunk_sum = chunk_sum + chunk
                clip_image_embeds = chunk_sum / quad_count

            clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
        return clip_image_embeds

    def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
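        """Builds a batch of scaled random-noise images, normalized with the image
        processor's mean/std, to serve as the unconditional CLIP input."""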
        with torch.no_grad():
            tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
            noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                     dtype=get_torch_dtype(self.sd_ref().dtype))
            tensors_0_1 = tensors_0_1 * noise_scale
            # tensors_0_1 = tensors_0_1 * 0
            mean = torch.tensor(self.clip_image_processor.image_mean).to(
                self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
            ).detach()
            std = torch.tensor(self.clip_image_processor.image_std).to(
                self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
            ).detach()
            tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
            clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
        return clip_image.detach()

    def get_clip_image_embeds_from_tensors(
            self,
            tensors_0_1: torch.Tensor,
            drop=False,
            is_training=False,
            has_been_preprocessed=False,
            quad_count=4,
            cfg_embed_strength=None, # perform CFG on embeds with unconditional as negative
    ) -> torch.Tensor:
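        """Encodes 0-1 image tensors into CLIP image embeddings. Optionally replaces the
        input with noise/zeros (`drop`) for the unconditional, splits quad images onto the
        batch, and applies inverse CFG on the embeds when `cfg_embed_strength` is set."""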
        if self.sd_ref().unet.device != self.device:
            self.to(self.sd_ref().unet.device)
        if self.sd_ref().unet.device != self.image_encoder.device:
            self.to(self.sd_ref().unet.device)
        if not self.config.train:
            is_training = False
        uncond_clip = None
        with torch.no_grad():
            # on training the clip image is created in the dataloader
            if not has_been_preprocessed:
                # tensors should be 0-1
                if tensors_0_1.ndim == 3:
                    tensors_0_1 = tensors_0_1.unsqueeze(0)
                # training tensors are 0 - 1
                tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)

                # if images are outside this range (with a small tolerance) throw an error
                if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                        tensors_0_1.min(), tensors_0_1.max()
                    ))
                # unconditional
                if drop:
                    if self.clip_noise_zero:
                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                                 dtype=get_torch_dtype(self.sd_ref().dtype))
                        tensors_0_1 = tensors_0_1 * noise_scale
                    else:
                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
                    # tensors_0_1 = tensors_0_1 * 0
                clip_image = self.clip_image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                ).pixel_values
            else:
                if drop:
                    # scale the noise down
                    if self.clip_noise_zero:
                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                                 dtype=get_torch_dtype(self.sd_ref().dtype))
                        tensors_0_1 = tensors_0_1 * noise_scale
                    else:
                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
                    # tensors_0_1 = tensors_0_1 * 0
                    mean = torch.tensor(self.clip_image_processor.image_mean).to(
                        self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
                    ).detach()
                    std = torch.tensor(self.clip_image_processor.image_std).to(
                        self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
                    ).detach()
                    tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
                    clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])

                else:
                    clip_image = tensors_0_1
            clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

            if self.config.quad_image:
                # split the 2x2 grid into quadrants and stack them on the batch dimension
                ci1, ci2 = clip_image.chunk(2, dim=2)
                ci1, ci3 = ci1.chunk(2, dim=3)
                ci2, ci4 = ci2.chunk(2, dim=3)
                to_cat = []
                for i, ci in enumerate([ci1, ci2, ci3, ci4]):
                    if i < quad_count:
                        to_cat.append(ci)
                    else:
                        break

                clip_image = torch.cat(to_cat, dim=0).detach()

            # if drop:
            #     clip_image = clip_image * 0
        with torch.set_grad_enabled(is_training):
            if is_training and self.config.train_image_encoder:
                self.image_encoder.train()
                clip_image = clip_image.requires_grad_(True)
                if self.preprocessor is not None:
                    clip_image = self.preprocessor(clip_image)
                clip_output = self.image_encoder(
                    clip_image,
                    output_hidden_states=True
                )
            else:
                self.image_encoder.eval()
                if self.preprocessor is not None:
                    clip_image = self.preprocessor(clip_image)
                clip_output = self.image_encoder(
                    clip_image, output_hidden_states=True
                )

            if self.config.clip_layer == 'penultimate_hidden_states':
                # they skip last layer for ip+
                # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
                clip_image_embeds = clip_output.hidden_states[-2]
            elif self.config.clip_layer == 'last_hidden_state':
                clip_image_embeds = clip_output.hidden_states[-1]
            else:
                clip_image_embeds = clip_output.image_embeds

                if self.config.adapter_type == "clip_face":
                    l2_norm = torch.norm(clip_image_embeds, p=2)
                    clip_image_embeds = clip_image_embeds / l2_norm

            if self.config.image_encoder_arch.startswith('convnext'):
                # flatten the width height layers to make the token space
                clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
                # rearrange to (batch, tokens, size)
                clip_image_embeds = clip_image_embeds.permute(0, 2, 1)

            # apply unconditional if doing cfg on embeds
            with torch.no_grad():
                if cfg_embed_strength is not None:
                    uncond_clip = self.get_empty_clip_image(tensors_0_1.shape[0]).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
                    if self.config.quad_image:
                        # split the 2x2 grid into quadrants and stack them on the batch dimension
                        ci1, ci2 = uncond_clip.chunk(2, dim=2)
                        ci1, ci3 = ci1.chunk(2, dim=3)
                        ci2, ci4 = ci2.chunk(2, dim=3)
                        to_cat = []
                        for i, ci in enumerate([ci1, ci2, ci3, ci4]):
                            if i < quad_count:
                                to_cat.append(ci)
                            else:
                                break

                        uncond_clip = torch.cat(to_cat, dim=0).detach()
                    uncond_clip_output = self.image_encoder(
                        uncond_clip, output_hidden_states=True
                    )

                    if self.config.clip_layer == 'penultimate_hidden_states':
                        uncond_clip_output_embeds = uncond_clip_output.hidden_states[-2]
                    elif self.config.clip_layer == 'last_hidden_state':
                        uncond_clip_output_embeds = uncond_clip_output.hidden_states[-1]
                    else:
                        uncond_clip_output_embeds = uncond_clip_output.image_embeds
                        if self.config.adapter_type == "clip_face":
                            l2_norm = torch.norm(uncond_clip_output_embeds, p=2)
                            uncond_clip_output_embeds = uncond_clip_output_embeds / l2_norm

                    uncond_clip_output_embeds = uncond_clip_output_embeds.detach()


                    # apply inverse cfg
                    clip_image_embeds = inverse_classifier_guidance(
                        clip_image_embeds,
                        uncond_clip_output_embeds,
                        cfg_embed_strength
                    )


            if self.config.quad_image:
                # get the embeddings of the four quadrants
                chunks = clip_image_embeds.chunk(quad_count, dim=0)
                if self.config.train_image_encoder and is_training:
                    # perform a loss across all chunks this will teach the vision encoder to
                    # identify similarities in our pairs of images and ignore things that do not make them similar
                    num_losses = 0
                    total_loss = None
                    for chunk in chunks:
                        for chunk2 in chunks:
                            if chunk is not chunk2:
                                loss = F.mse_loss(chunk, chunk2)
                                if total_loss is None:
                                    total_loss = loss
                                else:
                                    total_loss = total_loss + loss
                                num_losses += 1
                    if total_loss is not None:
                        total_loss = total_loss / num_losses
                        total_loss = total_loss * 1e-2
                        if self.additional_loss is not None:
                            total_loss = total_loss + self.additional_loss
                        self.additional_loss = total_loss

                chunk_sum = torch.zeros_like(chunks[0])
                for chunk in chunks:
                    chunk_sum = chunk_sum + chunk
                # take the mean of them
                clip_image_embeds = chunk_sum / quad_count

        if not is_training or not self.config.train_image_encoder:
            clip_image_embeds = clip_image_embeds.detach()

        return clip_image_embeds

    # use `drop` in get_clip_image_embeds_from_tensors for prompt dropout, or negatives
    def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor, is_unconditional=False) -> PromptEmbeds:
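        """Projects image embeds into prompt tokens. For flux the projected tokens are
        stashed on the adapter and picked up by the attn processors; otherwise they are
        concatenated onto the text embeds."""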
        clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        if self.sd_ref().is_flux:
            # do not attach to text embeds for flux, we will save and grab them as it messes
            # with the RoPE to have them in the same tensor
            if is_unconditional:
                self.last_unconditional = image_prompt_embeds
            else:
                self.last_conditional = image_prompt_embeds
        else:
            embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
        return embeddings
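
    # Typical conditioning flow (a sketch; variable names are illustrative):
    #   clip_embeds = adapter.get_clip_image_embeds_from_tensors(image_tensors_0_1)
    #   prompt_embeds = adapter(prompt_embeds, clip_embeds)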

    def train(self: T, mode: bool = True) -> T:
        if self.config.train_image_encoder:
            self.image_encoder.train(mode)
        if not self.config.train_only_image_encoder:
            for attn_processor in self.adapter_modules:
                attn_processor.train(mode)
        if self.image_proj_model is not None:
            self.image_proj_model.train(mode)
        return super().train(mode)

    def get_parameter_groups(self, adapter_lr):
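        """Returns optimizer param groups: the adapter params at `adapter_lr` and, when a
        scaler is trained or merged, the scaler params at `scaler_lr` (or `adapter_lr`)."""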
        param_groups = []
        # when training just scaler, we do not train anything else
        if not self.config.train_scaler:
            param_groups.append({
                "params": list(self.get_non_scaler_parameters()),
                "lr": adapter_lr,
            })
        if self.config.train_scaler or self.config.merge_scaler:
            scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
            param_groups.append({
                "params": list(self.get_scaler_parameters()),
                "lr": scaler_lr,
            })
        return param_groups

    def get_scaler_parameters(self):
        # only yield the scaler parameters from the adapter modules
        for attn_processor in self.adapter_modules:
            # check if the processor has an ip_scaler attribute
            if hasattr(attn_processor, "ip_scaler"):
                scaler_param = attn_processor.ip_scaler
                yield scaler_param

    def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        if self.config.train_only_image_encoder:
            if self.config.train_only_image_encoder_positional_embedding:
                yield from self.image_encoder.vision_model.embeddings.position_embedding.parameters(recurse)
            else:
                yield from self.image_encoder.parameters(recurse)
            return
        if self.config.train_scaler:
            # no params
            return

        for attn_processor in self.adapter_modules:
            if self.config.train_scaler or self.config.merge_scaler:
                # todo remove scaler
                if hasattr(attn_processor, "to_k_ip"):
                    # yield the linear layer
                    yield from attn_processor.to_k_ip.parameters(recurse)
                if hasattr(attn_processor, "to_v_ip"):
                    # yield the linear layer
                    yield from attn_processor.to_v_ip.parameters(recurse)
            else:
                yield from attn_processor.parameters(recurse)
        yield from self.image_proj_model.parameters(recurse)
        if self.config.train_image_encoder:
            yield from self.image_encoder.parameters(recurse)
        if self.preprocessor is not None:
            yield from self.preprocessor.parameters(recurse)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        yield from self.get_non_scaler_parameters(recurse)
        if self.config.train_scaler or self.config.merge_scaler:
            yield from self.get_scaler_parameters()

    def merge_in_weights(self, state_dict: Mapping[str, Any]):
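        """Load ``image_proj`` and ``ip_adapter`` weights whose shapes may not
        match the current modules, copying the overlapping region of each
        mismatched tensor and leaving the rest of the current values intact."""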
        def merge_overlap(current: torch.Tensor, new: torch.Tensor) -> None:
            # copy the overlapping region of `new` into `current`, leaving the
            # remaining values of `current` as they are; this handles any rank
            # and works whichever tensor is larger along each axis
            if current.dim() != new.dim():
                raise ValueError(f"cannot merge tensors of different rank: {current.shape} vs {new.shape}")
            overlap = tuple(slice(0, min(c, n)) for c, n in zip(current.shape, new.shape))
            current[overlap] = new[overlap]

        # merge in image_proj weights
        current_img_proj_state_dict = self.image_proj_model.state_dict()
        for key, value in state_dict["image_proj"].items():
            if key in current_img_proj_state_dict:
                current_shape = current_img_proj_state_dict[key].shape
                new_shape = value.shape
                if current_shape != new_shape:
                    # merge in what we can and leave the other values as they are
                    merge_overlap(current_img_proj_state_dict[key], value)
                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
                else:
                    current_img_proj_state_dict[key] = value
        self.image_proj_model.load_state_dict(current_img_proj_state_dict)

        # merge in ip adapter weights
        current_ip_adapter_state_dict = self.adapter_modules.state_dict()
        for key, value in state_dict["ip_adapter"].items():
            if key in current_ip_adapter_state_dict:
                current_shape = current_ip_adapter_state_dict[key].shape
                new_shape = value.shape
                if current_shape != new_shape:
                    # merge in what we can and leave the other values as they are
                    merge_overlap(current_ip_adapter_state_dict[key], value)
                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
                else:
                    current_ip_adapter_state_dict[key] = value
        self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
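        """Load adapter weights from a (possibly partial) checkpoint: scaler-only,
        ``image_proj`` + ``ip_adapter``, image encoder, preprocessor, or raw CLIP
        weights when only the image encoder is being trained."""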
        # checkpoints regularly contain only a subset of the modules, so always load non-strict
        strict = False
        if self.config.train_scaler and 'ip_scale' in state_dict:
            self.adapter_modules.load_state_dict(state_dict["ip_scale"], strict=False)
        if 'ip_adapter' in state_dict:
            try:
                self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
                self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
            except Exception as e:
                print(e)
                print("could not load ip adapter weights, trying to merge in weights")
                self.merge_in_weights(state_dict)
        if self.config.train_image_encoder and 'image_encoder' in state_dict:
            self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
        if self.preprocessor is not None and 'preprocessor' in state_dict:
            self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)

        if self.config.train_only_image_encoder and 'ip_adapter' not in state_dict:
            # the state dict contains only raw CLIP image encoder weights
            self.image_encoder.load_state_dict(state_dict, strict=strict)

    def enable_gradient_checkpointing(self):
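        # depending on the model class, gradient checkpointing is exposed either
        # as a method or as a plain attribute, so handle both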
        if hasattr(self.image_encoder, "enable_gradient_checkpointing"):
            self.image_encoder.enable_gradient_checkpointing()
        elif hasattr(self.image_encoder, 'gradient_checkpointing'):
            self.image_encoder.gradient_checkpointing = True