from collections import defaultdict
from functools import partial
import gc
from typing import Callable, Dict, List, Literal, Optional, Type, Union

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
from diffusers.pipelines import FluxPipeline

from SDLens.cache_and_edit.activation_cache import FluxActivationCache, ModelActivationCache, PixartActivationCache, ActivationCacheHandler
from SDLens.cache_and_edit.hooks import locate_block, register_general_hook, fix_inf_values_hook, edit_streams_hook
from SDLens.cache_and_edit.qkv_cache import QKVCacheFluxHandler, QKVCache, CachedFluxAttnProcessor3_0
from SDLens.cache_and_edit.scheduler_inversion import FlowMatchEulerDiscreteSchedulerForInversion
from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline


class CachedPipeline:

    def __init__(self, pipe: Union[EditedFluxPipeline, FluxPipeline], text_seq_length: int = 512):
        assert isinstance(pipe, (EditedFluxPipeline, FluxPipeline)), \
            "Use the EditedFluxPipeline class from `cache_and_edit/flux_pipeline.py`"
        self.pipe = pipe
        self.text_seq_length = text_seq_length

        self.activation_cache_handler = None
        self.qkv_cache_handler = None

        self.registered_hooks = []

    def setup_cache(self,
                    use_activation_cache: bool = True,
                    use_qkv_cache: bool = False,
                    positions_to_cache: Optional[List[str]] = None,
                    positions_to_cache_foreground: Optional[List[str]] = None,
                    qkv_to_inject: Optional[QKVCache] = None,
                    inject_kv_mode: Optional[Literal["image", "text", "both"]] = None,
                    q_mask=None,
                    processor_class: Optional[Type] = CachedFluxAttnProcessor3_0
                    ) -> None:
        """
        Sets up the activation cache and/or the QKV cache, registering the required hooks.

        If positions_to_cache is None, all transformer blocks are cached.
        If inject_kv_mode is None, the QKV cache is stored; otherwise qkv_to_inject is injected.
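
        Example (illustrative sketch; assumes `cached` is a CachedPipeline wrapping a Flux pipeline):
            cached.setup_cache(use_activation_cache=True)
            # or: cache Q/K/V only for the first double-stream block
            cached.setup_cache(use_activation_cache=False, use_qkv_cache=True,
                               positions_to_cache=["transformer.transformer_blocks.0"])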
        """

        if use_activation_cache:
            if isinstance(self.pipe, (EditedFluxPipeline, FluxPipeline)):
                activation_cache = FluxActivationCache()
            else:
                raise AssertionError(f"activation cache not implemented for {type(self.pipe)}")

            self.activation_cache_handler = ActivationCacheHandler(activation_cache, positions_to_cache)

            self._set_hooks(position_hook_dict=self.activation_cache_handler.forward_hooks_dict,
                            with_kwargs=True)

        if use_qkv_cache:
            if isinstance(self.pipe, (EditedFluxPipeline, FluxPipeline)):
                self.qkv_cache_handler = QKVCacheFluxHandler(self.pipe,
                                                             positions_to_cache,
                                                             positions_to_cache_foreground,
                                                             inject_kv=inject_kv_mode,
                                                             text_seq_length=self.text_seq_length,
                                                             q_mask=q_mask,
                                                             processor_class=processor_class)
            else:
                raise AssertionError(f"QKV cache not implemented for {type(self.pipe)}")

    @property
    def activation_cache(self) -> Optional[ModelActivationCache]:
        return self.activation_cache_handler.cache if getattr(self, "activation_cache_handler", None) else None

    @property
    def qkv_cache(self) -> Optional[QKVCache]:
        return self.qkv_cache_handler.cache if getattr(self, "qkv_cache_handler", None) else None

    @torch.no_grad
    def run(self,
            prompt: Union[str, List[str]],
            num_inference_steps: int = 1,
            seed: int = 42,
            width: int = 1024,
            height: int = 1024,
            cache_activations: bool = False,
            cache_qkv: bool = False,
            guidance_scale: float = 0.0,
            positions_to_cache: Optional[List[str]] = None,
            empty_clip_embeddings: bool = True,
            inverse: bool = False,
            **kwargs):
        """Run the pipeline, optionally caching activations and/or queries, keys and values (QKV).

        Args:
            prompt (str or List[str]): Prompt(s) to run the pipeline (NOTE: for Flux, the prompt is
                passed as `prompt_2`, and `prompt` is left empty unless `empty_clip_embeddings` is False).
            num_inference_steps (int, optional): Number of inference steps. Defaults to 1.
            seed (int, optional): Seed for the generators. Defaults to 42.
            cache_activations (bool, optional): Whether to cache activations. Defaults to False.
            cache_qkv (bool, optional): Whether to cache queries, keys and values. Defaults to False.
            positions_to_cache (List[str], optional): List of blocks to cache.
                If None, all transformer blocks are cached. Defaults to None.
            inverse (bool, optional): Whether to run the scheduler in inversion mode. Defaults to False.

        Returns:
            Same output as the wrapped pipeline.
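
        Example (illustrative sketch; assumes `cached` is a CachedPipeline instance):
            out = cached.run("a photo of a corgi", num_inference_steps=4, cache_activations=True)
            image = out.images[0]
            acts = cached.activation_cache  # populated because cache_activations=True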
        """

        self.clear_all_hooks()

        if self.activation_cache or self.qkv_cache:

            if self.activation_cache:
                del self.activation_cache_handler.cache
                del self.activation_cache_handler

            if self.qkv_cache:
                self.qkv_cache_handler.clear_cache()
                del self.qkv_cache_handler

            gc.collect()
            torch.cuda.empty_cache()

        self.setup_cache(cache_activations, cache_qkv, positions_to_cache, inject_kv_mode=None)

        assert isinstance(seed, int)

        if isinstance(prompt, str):
            empty_prompt = [""]
            prompt = [prompt]
        else:
            empty_prompt = [""] * len(prompt)

        gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]

        if inverse:
            if not hasattr(self, "inversion_scheduler"):
                self.inversion_scheduler = FlowMatchEulerDiscreteSchedulerForInversion.from_config(
                    self.pipe.scheduler.config,
                    inverse=True
                )
                self.og_scheduler = self.pipe.scheduler

            self.pipe.scheduler = self.inversion_scheduler

        output = self.pipe(
            prompt=empty_prompt if empty_clip_embeddings else prompt,
            prompt_2=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=gen,
            width=width,
            height=height,
            **kwargs
        )

        if inverse:
            self.pipe.scheduler = self.og_scheduler

        return output

    @torch.no_grad
    def run_inject_qkv(self,
                       prompt: Union[str, List[str]],
                       positions_to_inject: Optional[List[str]] = None,
                       positions_to_inject_foreground: Optional[List[str]] = None,
                       inject_kv_mode: Literal["image", "text", "both"] = "image",
                       num_inference_steps: int = 1,
                       guidance_scale: float = 0.0,
                       seed: int = 42,
                       empty_clip_embeddings: bool = True,
                       q_mask=None,
                       width: int = 1024,
                       height: int = 1024,
                       processor_class: Optional[Type] = CachedFluxAttnProcessor3_0,
                       **kwargs):
        """Run the pipeline while injecting keys/values from the QKV cache into the attention processors.

        Args:
            prompt (str or List[str]): Prompt(s) to run the pipeline (NOTE: for Flux, the prompt is
                passed as `prompt_2`, and `prompt` is left empty unless `empty_clip_embeddings` is False).
            positions_to_inject (List[str], optional): List of blocks to inject into.
                If None, all transformer blocks are used. Defaults to None.
            positions_to_inject_foreground (List[str], optional): Blocks to inject for the foreground,
                forwarded to QKVCacheFluxHandler.
            inject_kv_mode (str, optional): Which stream to inject ("image", "text" or "both").
                Defaults to "image".
            num_inference_steps (int, optional): Number of inference steps. Defaults to 1.
            seed (int, optional): Seed for the generators. Defaults to 42.
            q_mask (optional): Optional query mask forwarded to the attention processors.

        Returns:
            Same output as the wrapped pipeline.
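
        Example (illustrative sketch; how the injected keys/values are produced is determined by
        QKVCacheFluxHandler and the chosen processor_class):
            edited = cached.run_inject_qkv("a corgi wearing a hat",
                                           inject_kv_mode="image",
                                           num_inference_steps=4)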
        """

        self.clear_all_hooks()

        if hasattr(self, "qkv_cache_handler") and self.qkv_cache_handler is not None:
            self.qkv_cache_handler.clear_cache()
            del self.qkv_cache_handler
            gc.collect()
            torch.cuda.empty_cache()

        self.setup_cache(use_activation_cache=False,
                         use_qkv_cache=True,
                         positions_to_cache=positions_to_inject,
                         positions_to_cache_foreground=positions_to_inject_foreground,
                         inject_kv_mode=inject_kv_mode,
                         q_mask=q_mask,
                         processor_class=processor_class)

        assert isinstance(seed, int)

        if isinstance(prompt, str):
            empty_prompt = [""]
            prompt = [prompt]
        else:
            empty_prompt = [""] * len(prompt)

        gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]

        output = self.pipe(
            prompt=empty_prompt if empty_clip_embeddings else prompt,
            prompt_2=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=gen,
            width=width,
            height=height,
            **kwargs
        )

        return output

    def clear_all_hooks(self):
        # Remove all hooks registered through this class.
        for hook in self.registered_hooks:
            hook.remove()
        self.registered_hooks = []

        # Defensively clear any forward hooks still attached to the transformer blocks
        # (e.g. hooks registered by cache handlers outside self.registered_hooks).
        for i in range(len(locate_block(self.pipe, "transformer.transformer_blocks"))):
            locate_block(self.pipe, f"transformer.transformer_blocks.{i}")._forward_hooks.clear()

        for i in range(len(locate_block(self.pipe, "transformer.single_transformer_blocks"))):
            locate_block(self.pipe, f"transformer.single_transformer_blocks.{i}")._forward_hooks.clear()

    def _set_hooks(self,
                   position_hook_dict: Optional[Dict[str, List[Callable]]] = None,
                   position_pre_hook_dict: Optional[Dict[str, List[Callable]]] = None,
                   with_kwargs: bool = False
                   ):
        '''
        Set hooks at the specified positions and keep track of them in self.registered_hooks.

        Args:
            position_hook_dict: A dictionary mapping positions to forward hooks.
                The keys are positions in the pipeline (e.g. "transformer.transformer_blocks.0")
                where the hooks should be registered; the values are lists of hooks to register there.
                Each hook should be a callable taking (module, input, output), plus the module's
                keyword arguments when `with_kwargs` is True.
            position_pre_hook_dict: Same as position_hook_dict, but registered as forward pre-hooks.
            with_kwargs: Whether the hooks should also receive the module's keyword arguments.
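
        Example (illustrative; `print_module_hook` is a hypothetical hook):
            def print_module_hook(module, args, kwargs, output):
                print(type(module).__name__)
                return output

            self._set_hooks(position_hook_dict={"transformer.transformer_blocks.0": [print_module_hook]},
                            with_kwargs=True)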
        '''

        for is_pre_hook, hook_dict in [(True, position_pre_hook_dict or {}), (False, position_hook_dict or {})]:
            for position, hooks in hook_dict.items():
                assert isinstance(hooks, list)
                for h in hooks:
                    self.registered_hooks.append(register_general_hook(self.pipe, position, h, with_kwargs, is_pre_hook))

    def run_with_edit(self,
                      prompt: Union[str, List[str]],
                      edit_fn: Callable,
                      layers_for_edit_fn: List[int],
                      stream: Literal['text', 'image', 'both'],
                      guidance_scale: float = 0.0,
                      seed: int = 42,
                      num_inference_steps: int = 1,
                      empty_clip_embeddings: bool = True,
                      width: int = 1024,
                      height: int = 1024,
                      **kwargs,
                      ):
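        """Run the pipeline while applying `edit_fn` to the output of the selected transformer blocks.

        Args:
            prompt (str or List[str]): Prompt(s) to run the pipeline (passed as `prompt_2` for Flux).
            edit_fn (Callable): Recompute function forwarded to `edit_streams_hook`.
            layers_for_edit_fn (List[int]): Global block indices to edit. Indices below 19 address
                `transformer.transformer_blocks`; larger indices address
                `transformer.single_transformer_blocks` (offset by the 19 double-stream blocks).
            stream (str): Which stream to edit: 'text', 'image' or 'both'.

        Returns:
            Same output as the wrapped pipeline.

        Example (illustrative; `scale_image_stream` is a hypothetical edit function whose
        signature must match what `edit_streams_hook` expects):
            out = cached.run_with_edit("a photo of a fox", scale_image_stream,
                                       layers_for_edit_fn=[5, 25], stream='image')
        """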
        assert isinstance(seed, int)

        self.clear_all_hooks()

        # Blocks 0-18 are the Flux double-stream blocks; indices >= 19 are mapped onto the
        # single-stream blocks by subtracting the 19 double-stream blocks.
        edit_fn_hooks = {f"transformer.transformer_blocks.{layer}":
                             [lambda *args: edit_streams_hook(*args, recompute_fn=edit_fn, stream=stream)]
                         for layer in layers_for_edit_fn if layer < 19}
        edit_fn_hooks.update({f"transformer.single_transformer_blocks.{layer - 19}":
                                  [lambda *args: edit_streams_hook(*args, recompute_fn=edit_fn, stream=stream)]
                              for layer in layers_for_edit_fn if layer >= 19})

        self._set_hooks(position_hook_dict=edit_fn_hooks, with_kwargs=True)

        if isinstance(prompt, str):
            empty_prompt = [""]
            prompt = [prompt]
        else:
            empty_prompt = [""] * len(prompt)

        gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]

        with torch.no_grad():
            output = self.pipe(
                prompt=empty_prompt if empty_clip_embeddings else prompt,
                prompt_2=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=gen,
                width=width,
                height=height,
                **kwargs
            )

        return output
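

# Minimal usage sketch (illustrative only; not executed on import). The checkpoint id, dtype and
# device are assumptions, and EditedFluxPipeline is assumed to expose diffusers' `from_pretrained`.
if __name__ == "__main__":
    pipe = EditedFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")
    cached = CachedPipeline(pipe)
    out = cached.run("a photo of a corgi", num_inference_steps=4, cache_activations=True)
    out.images[0].save("corgi.png")       # assuming the standard diffusers Flux pipeline output
    print(type(cached.activation_cache))  # FluxActivationCache when caching is enabled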