# Make the parent directory importable so that the SDLens package can be found.
import gc
import os
import sys
from collections import defaultdict
from pathlib import Path

parent_dir = Path.cwd().parent.resolve()
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

from typing import Dict, List, Literal, Optional, Type, TypedDict, Union

import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from diffusers.models.attention_processor import Attention, FluxAttnProcessor2_0
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers import FluxTransformer2DModel

from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
from SDLens.cache_and_edit.hooks import locate_block
class QKVCache(TypedDict):
query: List[torch.Tensor]
key: List[torch.Tensor]
value: List[torch.Tensor]
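
# A minimal sketch of the cache layout the processors below expect: one growing list per
# projection, appended to on every attention call when caching is enabled. The helper name
# is illustrative only.
def _example_empty_cache() -> QKVCache:
    """Build an empty cache compatible with the processors defined in this module."""
    return {"query": [], "key": [], "value": []}
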
class CachedFluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
    def __init__(self, external_cache: QKVCache,
                 inject_kv: Optional[Literal["image", "text", "both"]] = None,
                 text_seq_length: int = 512):
        """Constructor for the cached attention processor.

        Args:
            external_cache (QKVCache): cache used to store (or inject) queries, keys and values.
            inject_kv (Literal["image", "text", "both"], optional): which stream's keys/values to inject
                from the first batch element into the remaining batch elements. If None, no injection is
                performed and the full Q/K/V cache is stored instead. Defaults to None.
            text_seq_length (int, optional): number of text tokens at the start of the joint sequence.
                Defaults to 512.
        """
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.cache = external_cache
self.inject_kv = inject_kv
self.text_seq_length = text_seq_length
assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
        # Inject or cache Q, K, V
        if self.inject_kv == "image":
            # NOTE: keys and values are replaced only for the image stream.
            # NOTE: in the default setting, encoder_hidden_states_key_proj.shape[2] == 512 (the text sequence length).
            # The first batch element is the reference image whose keys and values are injected into all other batch elements.
key[1:, :, self.text_seq_length:] = key[:1, :, self.text_seq_length:]
value[1:, :, self.text_seq_length:] = value[:1, :, self.text_seq_length:]
elif self.inject_kv == "text":
key[1:, :, :self.text_seq_length] = key[:1, :, :self.text_seq_length]
value[1:, :, :self.text_seq_length] = value[:1, :, :self.text_seq_length]
elif self.inject_kv == "both":
key[1:] = key[:1]
value[1:] = value[:1]
else: # Don't inject, store cache!
self.cache["query"].append(query)
self.cache["key"].append(key)
self.cache["value"].append(value)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
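
# A minimal usage sketch for CachedFluxAttnProcessor2_0, assuming `pipe` is an already-loaded
# FluxPipeline (e.g. via FluxPipeline.from_pretrained). It attaches the processor to a single
# double-stream block and collects that block's Q/K/V over a short generation; the block index,
# prompt and step count are illustrative assumptions, not part of the original code.
def _example_cache_single_block(pipe: FluxPipeline) -> QKVCache:
    cache: QKVCache = {"query": [], "key": [], "value": []}
    # Attach the caching processor (no injection) to the first double-stream block.
    pipe.transformer.transformer_blocks[0].attn.set_processor(
        CachedFluxAttnProcessor2_0(external_cache=cache)
    )
    # Each attention call of that block appends one query/key/value tensor to the cache.
    pipe("a photo of a cat", num_inference_steps=4, guidance_scale=0.0)
    return cache
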
class CachedFluxAttnProcessor3_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
    def __init__(self, external_cache: QKVCache,
                 inject_kv: Optional[Literal["image", "text", "both"]] = None,
                 inject_kv_foreground: bool = False,
                 text_seq_length: int = 512,
                 q_mask: Optional[torch.Tensor] = None):
        """Constructor for the cached attention processor.

        Args:
            external_cache (QKVCache): cache used to store (or inject) queries, keys and values.
            inject_kv (Literal["image", "text", "both"], optional): which stream's keys/values to inject.
                Only "image" (or None) is currently supported. Defaults to None.
            inject_kv_foreground (bool, optional): if True, inject the foreground K/V inside the mask
                instead of keeping the composition's own values. Defaults to False.
            text_seq_length (int, optional): number of text tokens at the start of the joint sequence.
                Defaults to 512.
            q_mask (torch.Tensor, optional): boolean mask of shape (num_image_tokens, 1) selecting the
                foreground region of the image tokens.
        """
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CachedFluxAttnProcessor3_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
self.cache = external_cache
self.inject_kv = inject_kv
self.inject_kv_foreground = inject_kv_foreground
self.text_seq_length = text_seq_length
self.q_mask = q_mask
assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        # Unlike CachedFluxAttnProcessor2_0, this processor does not store Q/K/V in the cache;
        # it only injects keys/values according to the query mask below.

        # Broadcast the query mask to match the key/value layout.
        # self.q_mask has shape (num_image_tokens, 1).
        mask = self.q_mask.permute(1, 0).unsqueeze(0).unsqueeze(-1)  # shape: (1, 1, num_image_tokens, 1)
        # Move the mask to the same device as the keys/values.
        mask = mask.to(key.device)
        # Injection is only implemented for the image stream.
        if self.inject_kv is not None and self.inject_kv != "image":
            raise NotImplementedError("Injection is implemented only for the image stream.")
        # Tensor layout after concatenation is (batch, heads, text + image tokens, head_dim).
        # Batch element 0 is the background image, element 1 is the foreground image, and
        # elements 2+ are the compositions that receive the injected keys and values.
        # In the branches below, background K/V (element 0) are injected where the mask is False;
        # where the mask is True, either the foreground K/V (element 1) or the composition's own
        # K/V are used, depending on `inject_kv` / `inject_kv_foreground`.
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# Get the index range after the text tokens
start_idx = self.text_seq_length
if self.inject_kv_foreground and self.inject_kv == "image":
key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[:1, :, start_idx:])
value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[:1, :, start_idx:])
elif self.inject_kv == "image" and not self.inject_kv_foreground:
key[2:, :, start_idx:] = torch.where(mask, key[2:, :, start_idx:], key[:1, :, start_idx:])
value[2:, :, start_idx:] = torch.where(mask, value[2:, :, start_idx:], value[:1, :, start_idx:])
elif self.inject_kv is None and self.inject_kv_foreground:
key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[2:, :, start_idx:])
value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[2:, :, start_idx:])
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
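
# A toy sketch of the masked injection rule used above, on small random tensors. Batch layout:
# 0 = background, 1 = foreground, 2 = composition. This mirrors the `inject_kv == "image"` branch
# without foreground injection: background keys are injected where the mask is False, and the
# composition keeps its own keys where the mask is True. Shapes are illustrative only.
def _example_masked_injection() -> torch.Tensor:
    heads, n_img_tokens, head_dim = 2, 8, 4
    key = torch.randn(3, heads, n_img_tokens, head_dim)      # image tokens only (no text prefix)
    q_mask = torch.zeros(n_img_tokens, 1, dtype=torch.bool)
    q_mask[:4] = True                                         # first half marked as "foreground"
    mask = q_mask.permute(1, 0).unsqueeze(0).unsqueeze(-1)    # (1, 1, n_img_tokens, 1)
    key[2:] = torch.where(mask, key[2:], key[:1])             # background injected outside the mask
    return key
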
class QKVCacheFluxHandler:
"""Used to cache queries, keys and values of a FluxPipeline.
"""
def __init__(self, pipe: Union[FluxPipeline, EditedFluxPipeline],
positions_to_cache: List[str] = None,
positions_to_cache_foreground: List[str] = None,
inject_kv: Literal["image", "text", "both"] = None,
text_seq_length: int = 512,
q_mask: Optional[torch.Tensor] = None,
processor_class: Optional[Type] = CachedFluxAttnProcessor3_0
):
        if not isinstance(pipe, (FluxPipeline, EditedFluxPipeline)):
            raise NotImplementedError(f"QKVCache not yet implemented for {type(pipe)}.")
self.pipe = pipe
if positions_to_cache is not None:
self.positions_to_cache = positions_to_cache
        else:
            # No injection layers specified; processors are still attached to every transformer block below.
            self.positions_to_cache = []
if positions_to_cache_foreground is not None:
self.positions_to_cache_foreground = positions_to_cache_foreground
else:
self.positions_to_cache_foreground = []
self._cache = {"query": [], "key": [], "value": []}
        # Attach the cached processor to every transformer block to perform caching/editing.
        all_layers = [f"transformer.transformer_blocks.{i}" for i in range(19)] + \
                     [f"transformer.single_transformer_blocks.{i}" for i in range(38)]
        for module_name in all_layers:
            # NOTE: the constructor's `inject_kv` argument is overridden here; image K/V injection
            # is enabled only for the blocks listed in `positions_to_cache`.
            inject_kv = "image" if module_name in self.positions_to_cache else None
            inject_kv_foreground = module_name in self.positions_to_cache_foreground
module = locate_block(pipe, module_name)
module.attn.set_processor(processor_class(external_cache=self._cache,
inject_kv=inject_kv,
inject_kv_foreground=inject_kv_foreground,
text_seq_length=text_seq_length,
q_mask=q_mask,
))
@property
def cache(self) -> QKVCache:
"""Returns a dictionary initialized as {"query": [], "key": [], "value": []}.
After calling a forward pass for pipe, queries, keys and values will be
appended in the respective list for each layer.
Returns:
Dict[str, List[torch.Tensor]]: cache dictionary containing 'query', 'key' and 'value'
"""
return self._cache
def clear_cache(self) -> None:
# TODO: check if we have to force clean GPU memory
        del self._cache
gc.collect() # force Python to clean up unreachable objects
torch.cuda.empty_cache() # tell PyTorch to release unused GPU memory from its cache
self._cache = {"query": [], "key": [], "value": []}
        # Restore the default processor on every block that received a cached processor.
        all_layers = [f"transformer.transformer_blocks.{i}" for i in range(19)] + \
                     [f"transformer.single_transformer_blocks.{i}" for i in range(38)]
        for module_name in all_layers:
            module = locate_block(self.pipe, module_name)
            module.attn.set_processor(FluxAttnProcessor2_0())
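
# A minimal end-to-end sketch for QKVCacheFluxHandler, assuming `pipe` is a loaded FluxPipeline
# generating at the default 1024x1024 resolution (4096 packed image tokens). The block list,
# mask, prompts and step count are illustrative assumptions.
def _example_handler_roundtrip(pipe: FluxPipeline) -> None:
    # Foreground mask over the packed 64x64 latent grid, flattened row-major.
    q_mask = torch.zeros(4096, 1, dtype=torch.bool)
    q_mask[:2048] = True  # top half of the image, purely as an illustration
    handler = QKVCacheFluxHandler(
        pipe,
        positions_to_cache=[f"transformer.transformer_blocks.{i}" for i in range(5)],
        q_mask=q_mask,
    )
    # Batch layout expected by the injecting processor: background (0), foreground (1), composition(s) (2+).
    pipe(["background prompt", "foreground prompt", "composition prompt"],
         num_inference_steps=4, guidance_scale=0.0)
    handler.clear_cache()  # restore the default FluxAttnProcessor2_0
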
class TFICONAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
    def __init__(self,
                 external_cache: QKVCache,
                 inject_kv: Optional[Literal["image", "text", "both"]] = None,
                 inject_kv_foreground: bool = False,
                 text_seq_length: int = 512,
                 q_mask: Optional[torch.Tensor] = None,
                 call_max_times: Optional[int] = None,
                 inject_q: bool = True,
                 inject_k: bool = True,
                 inject_v: bool = True,
                 ):
        """Constructor for the TF-ICON-style cached attention processor.

        Args:
            external_cache (QKVCache): cache used to store (or inject) queries, keys and values.
            inject_kv (Literal["image", "text", "both"], optional): which stream's keys/values to inject.
                Only "image" (or None) is currently supported. Defaults to None.
            inject_kv_foreground (bool, optional): if True, inject the foreground Q/K/V inside the mask
                instead of keeping the composition's own values. Defaults to False.
            text_seq_length (int, optional): number of text tokens at the start of the joint sequence.
                Defaults to 512.
            q_mask (torch.Tensor, optional): boolean mask of shape (num_image_tokens, 1) selecting the
                foreground region of the image tokens.
            call_max_times (int, optional): if set, injection is performed only during the first
                `call_max_times` calls of this processor. Defaults to None (always inject).
            inject_q, inject_k, inject_v (bool, optional): which projections to inject. Default True.
        """
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("TFICONAttnProcessor requires PyTorch 2.0; please upgrade PyTorch to use it.")
self.cache = external_cache
self.inject_kv = inject_kv
self.inject_kv_foreground = inject_kv_foreground
self.text_seq_length = text_seq_length
self.q_mask = q_mask
self.inject_q = inject_q
self.inject_k = inject_k
self.inject_v = inject_v
self.call_max_times = call_max_times
        # Remaining number of calls for which injection is performed (None = no limit).
        self.num_calls = call_max_times
assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# hidden states are the image patches (B, 4096, hidden_dim)
# encoder_hidden_states are the text tokens (B, 512, hidden_dim)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# concat inputs for attention -> (B, num_heads, 512 + 4096, head_dim)
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        # TODO: try first without the mask
        # Broadcast the query mask to match the key/value layout.
        # self.q_mask has shape (num_image_tokens, 1).
        mask = self.q_mask.permute(1, 0).unsqueeze(0).unsqueeze(-1)  # shape: (1, 1, num_image_tokens, 1)
        # Move the mask to the same device as the keys/values.
        mask = mask.to(key.device)
        # Injection is only implemented for the image stream.
        if self.inject_kv is not None and self.inject_kv != "image":
            raise NotImplementedError("Injection is implemented only for the image stream.")
        # Tensor layout after concatenation is (batch, heads, text + image tokens, head_dim).
        # Batch element 0 is the background image, element 1 is the foreground image, and
        # elements 2+ are the compositions that receive the injected queries, keys and values.
        # Background values (element 0) are injected where the mask is False; where the mask is True,
        # either the foreground (element 1) or the composition's own values are used, depending on the flags.
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# Get the index range after the text tokens
start_idx = self.text_seq_length
        # The batch is formed as follows:
        # - background image (0)
        # - foreground image (1)
        # - composition(s) (2, 3, ...)
        # Build the combined Q/K/V for the composition(s): take the background image's values outside the
        # mask and, inside the mask, the foreground image's values (or the composition's own, depending on
        # `inject_kv_foreground`).
if self.num_calls is None or self.num_calls > 0:
if self.inject_kv_foreground:
if self.inject_k:
key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[0:1, :, start_idx:])
if self.inject_q:
query[2:, :, start_idx:] = torch.where(mask, query[1:2, :, start_idx:], query[0:1, :, start_idx:])
if self.inject_v:
value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[0:1, :, start_idx:])
else:
if self.inject_k:
key[2:, :, start_idx:] = torch.where(mask, key[2:, :, start_idx:], key[0:1, :, start_idx:])
if self.inject_q:
query[2:, :, start_idx:] = torch.where(mask, query[2:, :, start_idx:], query[0:1, :, start_idx:])
if self.inject_v:
value[2:, :, start_idx:] = torch.where(mask, value[2:, :, start_idx:], value[0:1, :, start_idx:])
if self.num_calls is not None:
self.num_calls -= 1
        # Compute attention with the combined Q, K and V.
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
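
# A hedged sketch of attaching TFICONAttnProcessor directly to the double-stream blocks, assuming a
# loaded FluxPipeline `pipe` and a 64x64 packed latent grid (4096 image tokens, i.e. a 1024x1024 image).
# The rectangle coordinates, block selection and call budget are illustrative assumptions.
def _example_attach_tficon(pipe: FluxPipeline) -> None:
    # Rectangular foreground mask over the 64x64 token grid, flattened row-major.
    grid = torch.zeros(64, 64, dtype=torch.bool)
    grid[16:48, 16:48] = True  # central square marks the foreground region
    q_mask = grid.reshape(-1, 1)  # (num_image_tokens, 1), as expected by the processor
    cache: QKVCache = {"query": [], "key": [], "value": []}
    for name in [f"transformer.transformer_blocks.{i}" for i in range(19)]:
        locate_block(pipe, name).attn.set_processor(
            TFICONAttnProcessor(
                external_cache=cache,
                inject_kv="image",
                inject_kv_foreground=True,
                q_mask=q_mask,
                call_max_times=10,  # inject only during each block's first 10 attention calls
            )
        )
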