from typing import Callable, Literal

import torch
import torch.nn as nn

from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock,
    FluxTransformerBlock,
)


def register_general_hook(pipe, position, hook, with_kwargs=False, is_pre_hook=False):
    """Registers a forward hook on the module of the pipeline specified by `position`.

    Args:
        pipe: Diffusers pipeline containing the target module.
        position (str): Dot-separated path to the module inside the pipeline,
            e.g. 'transformer.transformer_blocks.0'.
        hook (Callable): Hook to register on the located module.
        with_kwargs (bool, optional): Also pass the module's keyword arguments
            to the hook. Defaults to False.
        is_pre_hook (bool, optional): Register the hook to run before the
            forward pass instead of after it. Defaults to False.

    Returns:
        torch.utils.hooks.RemovableHandle: Handle whose `remove()` method
        detaches the hook.
    """
    block: nn.Module = locate_block(pipe, position)

    if is_pre_hook:
        return block.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
    else:
        return block.register_forward_hook(hook, with_kwargs=with_kwargs)
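
# A minimal usage sketch, assuming a loaded FluxPipeline named `pipe`; the
# position string is illustrative and depends on the model's module tree:
#
#     handle = register_general_hook(
#         pipe, "transformer.transformer_blocks.0", fix_inf_values_hook
#     )
#     image = pipe("a photo of a cat").images[0]
#     handle.remove()  # detach the hook when done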


def locate_block(pipe, position: str) -> nn.Module:
    """Locate the block at the specified dot-separated position in the pipeline."""
    block = pipe
    for step in position.split('.'):
        if step.isdigit():
            # Numeric steps index into sequential containers (e.g. a ModuleList).
            block = block[int(step)]
        else:
            block = getattr(block, step)
    return block
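
# For example (hypothetical path; the real one depends on the pipeline),
# locate_block(pipe, "transformer.transformer_blocks.3.attn") resolves to
# pipe.transformer.transformer_blocks[3].attn.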


def _safe_clip(x: torch.Tensor):
    # Clamp +/-inf in-place to the largest finite float16 value (65504) so the
    # infinities do not turn into NaNs in subsequent layers.
    if x.dtype == torch.float16:
        x[torch.isposinf(x)] = 65504
        x[torch.isneginf(x)] = -65504
    return x
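
# E.g. _safe_clip(torch.tensor([float("inf")], dtype=torch.float16)) yields
# tensor([65504.], dtype=torch.float16); non-float16 tensors pass through
# unchanged.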


@torch.no_grad()
def fix_inf_values_hook(*args):
    """Forward hook that clips infinite values in the outputs of Flux blocks.

    Accepts both hook signatures: (module, input, output) and, when registered
    with `with_kwargs=True`, (module, input, kwinput, output).
    """
    if len(args) == 3:
        module, input, output = args
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        raise AssertionError(f'Unexpected len(args): {len(args)}')

    if isinstance(module, FluxTransformerBlock):
        # Dual-stream block: output is (text stream, image stream).
        return _safe_clip(output[0]), _safe_clip(output[1])
    elif isinstance(module, FluxSingleTransformerBlock):
        # Single-stream block: output is a single concatenated tensor.
        return _safe_clip(output)


@torch.no_grad()
def edit_streams_hook(*args,
                      recompute_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                      stream: Literal["text", "image", "both"]):
    """Forward hook that lets `recompute_fn` rewrite a block's output streams.

    `recompute_fn` receives the input tensor and the output tensor of the
    selected stream and returns the new, modified output. Because the hook
    reads the block's keyword inputs, it must be registered with
    `with_kwargs=True`. Bind `recompute_fn` and `stream` with
    `functools.partial` before registering (see the sketch below).
    """
    if len(args) == 3:
        module, input, output = args
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        raise AssertionError(f'Unexpected len(args): {len(args)}')

    if isinstance(module, FluxTransformerBlock):
        # Dual-stream block: output is (text stream, image stream), fed by the
        # `encoder_hidden_states` and `hidden_states` keyword inputs.
        if stream == 'text':
            output_text = recompute_fn(kwinput["encoder_hidden_states"], output[0])
            output_image = output[1]
        elif stream == 'image':
            output_image = recompute_fn(kwinput["hidden_states"], output[1])
            output_text = output[0]
        else:
            raise AssertionError("Stream 'both' is not supported for this layer.")

        return _safe_clip(output_text), _safe_clip(output_image)

    elif isinstance(module, FluxSingleTransformerBlock):
        # Single-stream block: text and image tokens are concatenated along the
        # sequence dimension, with the first 512 positions holding the text
        # tokens (Flux's default T5 sequence length).
        if stream == 'text':
            output[:, :512] = recompute_fn(kwinput["hidden_states"][:, :512], output[:, :512])
        elif stream == 'image':
            output[:, 512:] = recompute_fn(kwinput["hidden_states"][:, 512:], output[:, 512:])
        else:
            output = recompute_fn(kwinput["hidden_states"], output)

        return _safe_clip(output)
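
# A sketch of how these pieces fit together, assuming a loaded FluxPipeline
# named `pipe`; the position string and `zero_text_update` are illustrative:
#
#     import functools
#
#     def zero_text_update(stream_in, stream_out):
#         # Example recompute_fn: discard this block's text-stream update.
#         return torch.zeros_like(stream_out)
#
#     handle = register_general_hook(
#         pipe,
#         "transformer.transformer_blocks.0",
#         functools.partial(edit_streams_hook,
#                           recompute_fn=zero_text_update, stream="text"),
#         with_kwargs=True,  # required: the hook reads the block's kwargs
#     )
#     image = pipe("a photo of a cat").images[0]
#     handle.remove()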