from typing import Callable, Dict, List, Union

import torch
from diffusers import IFPipeline, StableDiffusionXLPipeline

from .hooked_scheduler import HookedNoiseScheduler


def retrieve(io):
    if isinstance(io, tuple):
        if len(io) == 1:
            return io[0]
        else:
            raise ValueError("Expected a tuple of length 1")
    elif isinstance(io, torch.Tensor):
        return io
    else:
        raise ValueError("Input/output must be a tensor or a 1-element tuple")


class HookedDiffusionAbstractPipeline:
    parent_cls = None
    pipe = None

    def __init__(self, pipe, use_hooked_scheduler: bool = False):
        if use_hooked_scheduler:
            pipe.scheduler = HookedNoiseScheduler(pipe.scheduler)
        # Write through __dict__ to bypass the overridden __setattr__,
        # which would otherwise store these attributes on the wrapped pipeline.
        self.__dict__['pipe'] = pipe
        self.__dict__['use_hooked_scheduler'] = use_hooked_scheduler

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Pop our own keyword argument so it is not forwarded to diffusers.
        use_hooked_scheduler = kwargs.pop('use_hooked_scheduler', False)
        return cls(cls.parent_cls.from_pretrained(*args, **kwargs), use_hooked_scheduler)

    def run_with_hooks(
        self,
        *args,
        position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
        **kwargs,
    ):
        '''
        Run the pipeline with hooks registered at the specified positions and
        return the final output. Hooks are always removed afterwards, even if
        the pipeline call raises.

        Args:
            *args: Arguments to pass to the pipeline.
            position_hook_dict: A dictionary mapping positions to hooks.
                The keys are positions in the pipeline where the hooks should be registered.
                The values are either a single hook or a list of hooks to be registered at the specified position.
                Each hook should be a callable that takes three arguments: (module, input, output).
            **kwargs: Keyword arguments to pass to the pipeline.
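
        Example (illustrative sketch; the position string and hook body are
        hypothetical and depend on the loaded pipeline's module tree):

            def print_shape(module, input, output):
                out = output[0] if isinstance(output, tuple) else output
                print(out.shape)

            image = hooked_pipe.run_with_hooks(
                prompt,
                position_hook_dict={'unet.mid_block': print_shape},
            )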
        '''
        hooks = []
        for position, hook in position_hook_dict.items():
            if isinstance(hook, list):
                for h in hook:
                    hooks.append(self._register_general_hook(position, h))
            else:
                hooks.append(self._register_general_hook(position, hook))

        # Scheduler hooks are tracked by the scheduler itself, so
        # _register_general_hook returns None for them; filter those out.
        hooks = [hook for hook in hooks if hook is not None]

        try:
            output = self.pipe(*args, **kwargs)
        finally:
            for hook in hooks:
                hook.remove()
            if self.use_hooked_scheduler:
                self.pipe.scheduler.pre_hooks = []
                self.pipe.scheduler.post_hooks = []

        return output

    def run_with_cache(
        self,
        *args,
        positions_to_cache: List[str],
        save_input: bool = False,
        save_output: bool = True,
        **kwargs,
    ):
        '''
        Run the pipeline while caching intermediate values at the specified positions.

        This method caches the intermediate inputs and/or outputs of the pipeline
        at certain positions. The final output of the pipeline and a dictionary of
        cached values are returned.

        Args:
            *args: Arguments to pass to the pipeline.
            positions_to_cache (List[str]): A list of positions in the pipeline where intermediate
                inputs/outputs should be cached.
            save_input (bool, optional): If True, caches the input at each specified position.
                Defaults to False.
            save_output (bool, optional): If True, caches the output at each specified position.
                Defaults to True.
            **kwargs: Keyword arguments to pass to the pipeline.

        Returns:
            final_output: The final output of the pipeline after execution.
            cache_dict (Dict[str, Dict[str, torch.Tensor]]): A dictionary whose keys are 'input'
                and/or 'output' (depending on the `save_input` and `save_output` flags), each
                mapping a position to its cached values stacked along dimension 1.
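
        Example (illustrative sketch; the position string is hypothetical and
        depends on the loaded pipeline's module tree):

            image, cache = hooked_pipe.run_with_cache(
                prompt,
                positions_to_cache=['unet.mid_block'],
            )
            # cache['output']['unet.mid_block'] stacks one entry per forward
            # call of the block, along dimension 1.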
        '''
        cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
        hooks = [
            self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
        ]
        hooks = [hook for hook in hooks if hook is not None]
        try:
            output = self.pipe(*args, **kwargs)
        finally:
            for hook in hooks:
                hook.remove()
            if self.use_hooked_scheduler:
                self.pipe.scheduler.pre_hooks = []
                self.pipe.scheduler.post_hooks = []

        cache_dict = {}
        if save_input:
            for position, block in cache_input.items():
                # Stack the per-call tensors along a new dimension 1.
                cache_input[position] = torch.stack(block, dim=1)
            cache_dict['input'] = cache_input

        if save_output:
            for position, block in cache_output.items():
                cache_output[position] = torch.stack(block, dim=1)
            cache_dict['output'] = cache_output
        return output, cache_dict

    def run_with_hooks_and_cache(
        self,
        *args,
        position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
        positions_to_cache: List[str] = [],
        save_input: bool = False,
        save_output: bool = True,
        **kwargs,
    ):
        '''
        Run the pipeline with hooks and caching at specified positions.

        This method registers hooks at certain positions in the pipeline and
        caches intermediate inputs and/or outputs at others. Hooks can be used
        to inspect or modify the pipeline's execution, while caching stores
        intermediate values for later inspection or use.

        Args:
            *args: Arguments to pass to the pipeline.
            position_hook_dict (Dict[str, Union[Callable, List[Callable]]]):
                A dictionary where the keys are the positions in the pipeline, and the values
                are hooks (either a single hook or a list of hooks) to be registered at those positions.
                Each hook should be a callable that accepts three arguments: (module, input, output).
            positions_to_cache (List[str], optional): A list of positions in the pipeline where
                intermediate inputs/outputs should be cached. Defaults to an empty list.
            save_input (bool, optional): If True, caches the input at each specified position.
                Defaults to False.
            save_output (bool, optional): If True, caches the output at each specified position.
                Defaults to True.
            **kwargs: Additional keyword arguments to pass to the pipeline.

        Returns:
            final_output: The final output of the pipeline after execution.
            cache_dict (Dict[str, Dict[str, torch.Tensor]]): A dictionary whose keys are 'input'
                and/or 'output' (depending on the `save_input` and `save_output` flags), each
                mapping a position to its cached values stacked along dimension 1.
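
        Example (illustrative sketch; the position strings and the hook are
        hypothetical and depend on the loaded pipeline's module tree):

            def zero_output(module, input, output):
                # A forward hook that returns a non-None value replaces the output.
                return output * 0 if isinstance(output, torch.Tensor) else output

            image, cache = hooked_pipe.run_with_hooks_and_cache(
                prompt,
                position_hook_dict={'unet.mid_block': zero_output},
                positions_to_cache=['unet.mid_block'],
            )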
        '''
        cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
        hooks = [
            self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
        ]

        for position, hook in position_hook_dict.items():
            if isinstance(hook, list):
                for h in hook:
                    hooks.append(self._register_general_hook(position, h))
            else:
                hooks.append(self._register_general_hook(position, hook))

        hooks = [hook for hook in hooks if hook is not None]
        try:
            output = self.pipe(*args, **kwargs)
        finally:
            for hook in hooks:
                hook.remove()
            if self.use_hooked_scheduler:
                self.pipe.scheduler.pre_hooks = []
                self.pipe.scheduler.post_hooks = []

        cache_dict = {}
        if save_input:
            for position, block in cache_input.items():
                cache_input[position] = torch.stack(block, dim=1)
            cache_dict['input'] = cache_input

        if save_output:
            for position, block in cache_output.items():
                cache_output[position] = torch.stack(block, dim=1)
            cache_dict['output'] = cache_output

        return output, cache_dict

    def _locate_block(self, position: str):
        '''
        Locate the block at the specified position in the pipeline.
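
        Positions are dot-separated attribute paths, with integer components
        treated as indices, e.g. 'unet.down_blocks.2.attentions.1' (illustrative;
        valid paths depend on the loaded pipeline's module tree).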
        '''
        block = self.pipe
        for step in position.split('.'):
            if step.isdigit():
                block = block[int(step)]
            else:
                block = getattr(block, step)
        return block

    def _register_cache_hook(self, position: str, cache_input: Dict, cache_output: Dict):
        if position.endswith('$self_attention') or position.endswith('$cross_attention'):
            return self._register_cache_attention_hook(position, cache_output)

        if position == 'noise':
            def hook(model_output, timestep, sample, generator):
                if position not in cache_output:
                    cache_output[position] = []
                cache_output[position].append(sample)

            if self.use_hooked_scheduler:
                self.pipe.scheduler.post_hooks.append(hook)
            else:
                raise ValueError('Cannot cache noise without using the hooked scheduler')
            # Scheduler hooks have no handle to return; the caller filters out None.
            return

        block = self._locate_block(position)

        def hook(module, input, kwargs, output):
            if cache_input is not None:
                if position not in cache_input:
                    cache_input[position] = []
                cache_input[position].append(retrieve(input))

            if cache_output is not None:
                if position not in cache_output:
                    cache_output[position] = []
                cache_output[position].append(retrieve(output))

        return block.register_forward_hook(hook, with_kwargs=True)

    def _register_cache_attention_hook(self, position, cache):
        attn_block = self._locate_block(position.split('$')[0])
        if position.endswith('$self_attention'):
            attn_block = attn_block.attn1
        elif position.endswith('$cross_attention'):
            attn_block = attn_block.attn2
        else:
            raise ValueError('Wrong attention type')

        def hook(module, args, kwargs, output):
            # Recompute the attention probabilities from the block's own
            # projections, since the fused attention forward does not expose them.
            hidden_states = args[0]
            encoder_hidden_states = kwargs.get('encoder_hidden_states', None)
            attention_mask = kwargs.get('attention_mask', None)
            batch_size, sequence_length, _ = hidden_states.shape
            attention_mask = attn_block.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            query = attn_block.to_q(hidden_states)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn_block.norm_cross is not None:
                encoder_hidden_states = attn_block.norm_cross(encoder_hidden_states)

            key = attn_block.to_k(encoder_hidden_states)
            value = attn_block.to_v(encoder_hidden_states)

            query = attn_block.head_to_batch_dim(query)
            key = attn_block.head_to_batch_dim(key)
            value = attn_block.head_to_batch_dim(value)

            attention_probs = attn_block.get_attention_scores(query, key, attention_mask)
            # Unflatten the heads: (batch * heads, query_len, key_len) -> (batch, heads, query_len, key_len).
            attention_probs = attention_probs.view(
                batch_size,
                attention_probs.shape[0] // batch_size,
                attention_probs.shape[1],
                attention_probs.shape[2]
            )
            if position not in cache:
                cache[position] = []
            cache[position].append(attention_probs)

        return attn_block.register_forward_hook(hook, with_kwargs=True)

    def _register_general_hook(self, position, hook):
        if position == 'scheduler_pre':
            if not self.use_hooked_scheduler:
                raise ValueError('Cannot register hooks on the scheduler without using the hooked scheduler')
            self.pipe.scheduler.pre_hooks.append(hook)
            return
        elif position == 'scheduler_post':
            if not self.use_hooked_scheduler:
                raise ValueError('Cannot register hooks on the scheduler without using the hooked scheduler')
            self.pipe.scheduler.post_hooks.append(hook)
            return

        block = self._locate_block(position)
        return block.register_forward_hook(hook)

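    # Note: 'scheduler_pre'/'scheduler_post' hooks are called with
    # (model_output, timestep, sample, generator), matching the noise-caching
    # hook above (assuming HookedNoiseScheduler invokes them with that signature).
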
    def to(self, *args, **kwargs):
        # Write through __dict__: the overridden __setattr__ would otherwise
        # set a 'pipe' attribute on the wrapped pipeline instead of on self.
        self.__dict__['pipe'] = self.pipe.to(*args, **kwargs)
        return self

    def __getattr__(self, name):
        # Delegate attribute access to the wrapped pipeline.
        return getattr(self.pipe, name)

    def __setattr__(self, name, value):
        # Delegate attribute assignment to the wrapped pipeline.
        return setattr(self.pipe, name, value)

    def __call__(self, *args, **kwargs):
        return self.pipe(*args, **kwargs)


class HookedStableDiffusionXLPipeline(HookedDiffusionAbstractPipeline):
    parent_cls = StableDiffusionXLPipeline


class HookedIFPipeline(HookedDiffusionAbstractPipeline):
    parent_cls = IFPipeline
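

# Example usage (a minimal sketch; the model id, prompt, and position strings
# are illustrative and depend on the checkpoint and its module tree):
#
#     pipe = HookedStableDiffusionXLPipeline.from_pretrained(
#         'stabilityai/stable-diffusion-xl-base-1.0', use_hooked_scheduler=True
#     )
#     pipe.to('cuda')
#     out, cache = pipe.run_with_cache(
#         'a photo of an astronaut',
#         positions_to_cache=['unet.mid_block', 'noise'],
#     )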