from typing import Callable, Literal
import torch
import torch.nn as nn
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock


def register_general_hook(pipe, position, hook, with_kwargs=False, is_pre_hook=False):
    """Register a forward hook on the module of the pipeline specified by `position`.

    Args:
        pipe: The diffusers pipeline whose submodule should be hooked.
        position (str): Dot-separated path to the target module inside the
            pipeline, e.g. "transformer.transformer_blocks.0".
        hook (Callable): Hook to register on the located module.
        with_kwargs (bool, optional): If True, the hook also receives the
            keyword arguments passed to the module's forward. Defaults to False.
        is_pre_hook (bool, optional): If True, register the hook to run before
            the forward pass instead of after it. Defaults to False.

    Returns:
        torch.utils.hooks.RemovableHandle: Handle that removes the hook via
            `handle.remove()`.
    """
    block: nn.Module = locate_block(pipe, position)
    if is_pre_hook:
        return block.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
    else:
        return block.register_forward_hook(hook, with_kwargs=with_kwargs)
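
# Usage sketch (the position string is hypothetical; the right path depends
# on the pipeline being instrumented):
#   handle = register_general_hook(
#       pipe, "transformer.single_transformer_blocks.5", fix_inf_values_hook)
#   ...  # run inference
#   handle.remove()  # detach the hook when done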


def locate_block(pipe, position: str) -> nn.Module:
    '''
    Locate the block at the specified position in the pipeline.
    '''
    block = pipe
    for step in position.split('.'):
        if step.isdigit():
            block = block[int(step)]
        else:
            block = getattr(block, step)
    return block
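
# Example: "transformer.transformer_blocks.3" resolves to
# pipe.transformer.transformer_blocks[3]; digit-only path steps index into
# container modules (e.g. nn.ModuleList), all other steps are attribute
# lookups.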


def _safe_clip(x: torch.Tensor):
    # Replace +/-inf with the largest finite float16 value
    # (torch.finfo(torch.float16).max == 65504) so infinities do not turn
    # into NaNs in subsequent layers.
    if x.dtype == torch.float16:
        x[torch.isposinf(x)] = 65504
        x[torch.isneginf(x)] = -65504
    return x


@torch.no_grad()
def fix_inf_values_hook(*args):
    """Forward hook that clips infinite float16 activations in a Flux block's output."""
    # Case 1: no kwargs are passed to the module
    if len(args) == 3:
        module, input, output = args
    # Case 2: kwargs are passed to the module as input
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        raise AssertionError(f'Unexpected len(args): {len(args)}')
    if isinstance(module, FluxTransformerBlock):
        # The double-stream block returns (encoder_hidden_states,
        # hidden_states), i.e. the text and image streams separately.
        return _safe_clip(output[0]), _safe_clip(output[1])
    elif isinstance(module, FluxSingleTransformerBlock):
        return _safe_clip(output)


@torch.no_grad()
def edit_streams_hook(*args,
                      recompute_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                      stream: Literal["text", "image", "both"]):
    """Forward hook that rewrites the text and/or image stream of a Flux block.

    `recompute_fn` receives the block's input tensor and output tensor for the
    selected stream and returns the modified output. The keyword-only
    arguments must be bound before registration (e.g. with `functools.partial`),
    and the hook must be registered with `with_kwargs=True` so that the
    module's keyword arguments are available here.
    """
    # Case 1: no kwargs are passed to the module
    if len(args) == 3:
        module, input, output = args
    # Case 2: kwargs are passed to the module as input
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        raise AssertionError(f'Unexpected len(args): {len(args)}')
    if isinstance(module, FluxTransformerBlock):
        # Double-stream block: output is (encoder_hidden_states, hidden_states),
        # i.e. (text stream, image stream).
        if stream == 'text':
            output_text = recompute_fn(kwinput["encoder_hidden_states"], output[0])
            output_image = output[1]
        elif stream == 'image':
            output_image = recompute_fn(kwinput["hidden_states"], output[1])
            output_text = output[0]
        else:
            raise AssertionError("Stream 'both' is not supported for this layer.")
        return _safe_clip(output_text), _safe_clip(output_image)
    elif isinstance(module, FluxSingleTransformerBlock):
        # Single-stream block: text and image tokens are concatenated in one
        # tensor, text first. The 512 boundary assumes Flux's default T5
        # max_sequence_length of 512 text tokens.
        if stream == 'text':
            output[:, :512] = recompute_fn(kwinput["hidden_states"][:, :512], output[:, :512])
        elif stream == 'image':
            output[:, 512:] = recompute_fn(kwinput["hidden_states"][:, 512:], output[:, 512:])
        else:
            output = recompute_fn(kwinput["hidden_states"], output)
        return _safe_clip(output)
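

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes a
    # Flux pipeline is available and uses a hypothetical recompute_fn that
    # simply dampens a block's update. Because PyTorch calls forward hooks as
    # hook(module, args, kwargs, output), the keyword-only arguments of
    # edit_streams_hook have to be bound up front, e.g. with functools.partial.
    from functools import partial

    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")

    def halve_update(stream_in: torch.Tensor, stream_out: torch.Tensor) -> torch.Tensor:
        # Hypothetical edit: keep the block's input, discard half of its update.
        return stream_in + 0.5 * (stream_out - stream_in)

    hook = partial(edit_streams_hook, recompute_fn=halve_update, stream="image")
    # with_kwargs=True is required here: the hook reads the block's keyword
    # arguments (kwinput) to recover its input streams.
    handle = register_general_hook(
        pipe, "transformer.transformer_blocks.0", hook, with_kwargs=True
    )
    image = pipe("a photo of a cat", num_inference_steps=4).images[0]
    handle.remove()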