# sae_flux/SDLens/cache_and_edit/activation_cache.py
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import List

import torch
from diffusers.models.transformers.transformer_flux import (
    FluxTransformerBlock,
    FluxSingleTransformerBlock,
)

from SDLens.cache_and_edit.hooks import fix_inf_values_hook, register_general_hook


class ModelActivationCache(ABC):
    """
    Cache for an inference pass of a Diffusion Transformer.
    Used to cache residual streams and activations.
    """

    def __init__(self):
        # Initialize caches for "double transformer" blocks using the
        # subclass-defined NUM_TRANSFORMER_BLOCKS.
        if hasattr(self, "NUM_TRANSFORMER_BLOCKS"):
            self.image_residual = []
            self.image_activation = []
            self.text_residual = []
            self.text_activation = []

        # Initialize caches for "single transformer" blocks if defined
        # (using NUM_SINGLE_TRANSFORMER_BLOCKS).
        if hasattr(self, "NUM_SINGLE_TRANSFORMER_BLOCKS"):
            self.text_image_residual = []
            self.text_image_activation = []

    def __str__(self):
        lines = [f"{self.__class__.__name__}:"]
        for attr_name, value in self.__dict__.items():
            # Guard against empty lists so a freshly constructed cache can be printed.
            if isinstance(value, list) and value and all(isinstance(v, torch.Tensor) for v in value):
                shapes = value[0].shape
                lines.append(f" {attr_name}: len={len(value)}, shapes={shapes}")
            else:
                lines.append(f" {attr_name}: {type(value)}")
        return "\n".join(lines)

    def _repr_pretty_(self, p, cycle):
        p.text(str(self))

    @abstractmethod
    def get_cache_info(self):
        """
        Return details about the cache configuration.
        Subclasses must implement this to provide info on their transformer block counts.
        """
        pass
class FluxActivationCache(ModelActivationCache):
    # Define number of blocks for double and single transformer caches
    NUM_TRANSFORMER_BLOCKS = 19
    NUM_SINGLE_TRANSFORMER_BLOCKS = 38

    def __init__(self):
        super().__init__()

    def get_cache_info(self):
        return {
            "transformer_blocks": self.NUM_TRANSFORMER_BLOCKS,
            "single_transformer_blocks": self.NUM_SINGLE_TRANSFORMER_BLOCKS,
        }

    def __getitem__(self, key):
        return getattr(self, key)
class PixartActivationCache(ModelActivationCache):
    # Define number of blocks for the double transformer cache only
    NUM_TRANSFORMER_BLOCKS = 28

    def __init__(self):
        super().__init__()

    def get_cache_info(self):
        return {
            "double_transformer_blocks": self.NUM_TRANSFORMER_BLOCKS,
        }
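
# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal illustration of how the cache classes above are meant to be used:
# `get_cache_info()` reports the block counts, and FluxActivationCache.__getitem__
# allows dict-style access to the tensor lists. The function name
# `_demo_cache_access` and the dummy tensor are assumptions made for this
# example only; real entries are appended by the forward hooks below.
def _demo_cache_access():
    cache = FluxActivationCache()
    print(cache.get_cache_info())
    # -> {'transformer_blocks': 19, 'single_transformer_blocks': 38}

    # The hooks append one tensor per block per denoising step; fake one here
    # with an arbitrary dummy shape.
    cache.image_activation.append(torch.zeros(1, 8, 16))
    print(cache["image_activation"][0].shape)  # dict-style access via __getitem__
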
class ActivationCacheHandler:
    """Manages the ModelActivationCache of a Diffusion Transformer."""

    def __init__(self, cache: ModelActivationCache, positions_to_cache: List[str] = None):
        """Constructor.

        Args:
            cache (ModelActivationCache): cache used to store tensors.
            positions_to_cache (List[str], optional): names of the modules to cache.
                If None, all modules specified by `cache.get_cache_info()` are cached.
                Defaults to None.
        """
        self.cache = cache
        self.positions_to_cache = positions_to_cache
    @torch.no_grad()
    def cache_residual_and_activation_hook(self, *args):
        """
        To be used as a forward hook on a transformer block.
        Caches both the residual stream and the activation
        (defined as output - residual_stream).
        """
        # Hooks registered with `with_kwargs=True` receive
        # (module, args, kwargs, output). The block inputs arrive as keyword
        # arguments, so the kwargs variant is required here.
        if len(args) == 3:
            # Plain forward hook: (module, args, output). Keyword inputs are
            # unavailable, so the residual stream cannot be recovered.
            raise ValueError(
                "cache_residual_and_activation_hook must be registered with "
                "`with_kwargs=True` so that keyword inputs are available."
            )
        elif len(args) == 4:
            module, input, kwinput, output = args
        else:
            raise ValueError(f"Unexpected number of hook arguments: {len(args)}")

        if isinstance(module, FluxTransformerBlock):
            encoder_hidden_states = output[0]
            hidden_states = output[1]
            self.cache.image_activation.append(hidden_states - kwinput["hidden_states"])
            self.cache.text_activation.append(encoder_hidden_states - kwinput["encoder_hidden_states"])
            self.cache.image_residual.append(kwinput["hidden_states"])
            self.cache.text_residual.append(kwinput["encoder_hidden_states"])
        elif isinstance(module, FluxSingleTransformerBlock):
            self.cache.text_image_activation.append(output - kwinput["hidden_states"])
            self.cache.text_image_residual.append(kwinput["hidden_states"])
        else:
            raise NotImplementedError(f"Caching not implemented for {type(module)}")
    @property
    def forward_hooks_dict(self):
        # Map each module name to the list of hooks to register on it.
        hooks = defaultdict(list)
        if self.positions_to_cache is None:
            # Hook every block listed by the cache configuration.
            for block_type, num_layers in self.cache.get_cache_info().items():
                for i in range(num_layers):
                    module_name: str = f"transformer.{block_type}.{i}"
                    hooks[module_name].append(fix_inf_values_hook)
                    hooks[module_name].append(self.cache_residual_and_activation_hook)
        else:
            # Hook only the explicitly requested modules.
            for module_name in self.positions_to_cache:
                hooks[module_name].append(fix_inf_values_hook)
                hooks[module_name].append(self.cache_residual_and_activation_hook)
        return hooks
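

# --- Hypothetical registration sketch (not part of the original module) ---
# `forward_hooks_dict` only builds the name -> hooks mapping; something else has
# to attach the hooks to the model. In this repo that is presumably done via
# `register_general_hook` from SDLens.cache_and_edit.hooks, whose exact
# signature is not shown in this file. The helper below is an assumed, minimal
# alternative using plain PyTorch (>= 2.0): it resolves each module name on a
# root nn.Module that exposes the Flux transformer under the attribute name
# `transformer` (e.g. a wrapper around the diffusers pipeline) and registers the
# hooks with `with_kwargs=True`, which `cache_residual_and_activation_hook`
# requires. It also assumes every hook in the dict (including
# `fix_inf_values_hook`) accepts the (module, args, kwargs, output) convention.
def register_cache_hooks(root_module: torch.nn.Module, handler: "ActivationCacheHandler"):
    """Attach the handler's hooks to `root_module` and return the hook handles."""
    handles = []
    for module_name, hook_fns in handler.forward_hooks_dict.items():
        submodule = root_module.get_submodule(module_name)
        for hook_fn in hook_fns:
            # with_kwargs=True makes PyTorch pass (module, args, kwargs, output).
            handles.append(submodule.register_forward_hook(hook_fn, with_kwargs=True))
    return handles


# Assumed usage (names are illustrative, not the repo's API):
#
#     cache = FluxActivationCache()
#     handler = ActivationCacheHandler(cache)
#     handles = register_cache_hooks(wrapped_pipe, handler)
#     # ...run one denoising pass, then inspect `cache`...
#     for h in handles:
#         h.remove()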