from typing import Callable, Literal
import torch
import torch.nn as nn
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
def register_general_hook(pipe, position, hook, with_kwargs=False, is_pre_hook=False):
"""Registers a forward hook in a module of the pipeline specified with 'position'
Args:
pipe (_type_): _description_
position (_type_): _description_
hook (_type_): _description_
with_kwargs (bool, optional): _description_. Defaults to False.
is_pre_hook (bool, optional): _description_. Defaults to False.
Returns:
_type_: _description_
"""
block: nn.Module = locate_block(pipe, position)
if is_pre_hook:
return block.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
else:
return block.register_forward_hook(hook, with_kwargs=with_kwargs)
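
# Example usage (a minimal sketch, assuming `pipe` is a loaded FluxPipeline and that
# the double-stream blocks live under `pipe.transformer.transformer_blocks`):
#
#     handle = register_general_hook(pipe, "transformer.transformer_blocks.0", fix_inf_values_hook)
#     _ = pipe(prompt="a photo of a cat", num_inference_steps=4)
#     handle.remove()
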
def locate_block(pipe, position: str) -> nn.Module:
    """
    Locate the sub-module at the specified `position` in the pipeline.

    `position` is a dot-separated path; integer segments index into module lists,
    e.g. "transformer.transformer_blocks.0.attn".
    """
block = pipe
for step in position.split('.'):
if step.isdigit():
step = int(step)
block = block[step]
else:
block = getattr(block, step)
return block
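
# Example (sketch; the path assumes the standard diffusers FluxPipeline layout):
#
#     block = locate_block(pipe, "transformer.transformer_blocks.0.attn")
#     # ...is equivalent to: pipe.transformer.transformer_blocks[0].attn
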
def _safe_clip(x: torch.Tensor):
    # Replace +/-inf with the largest finite float16 value (±65504) so that
    # overflow does not turn into NaNs further downstream.
    if x.dtype == torch.float16:
        x[torch.isposinf(x)] = 65504
        x[torch.isneginf(x)] = -65504
    return x
@torch.no_grad()
def fix_inf_values_hook(*args):
    # Case 1: no kwargs are passed to the module
    if len(args) == 3:
        module, input, output = args
    # Case 2: kwargs are passed to the module as input
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        raise AssertionError(f"Unexpected len(args): {len(args)}")
if isinstance(module, FluxTransformerBlock):
return _safe_clip(output[0]), _safe_clip(output[1])
elif isinstance(module, FluxSingleTransformerBlock):
return _safe_clip(output)
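
# Example (sketch; assumes `pipe.transformer` exposes `transformer_blocks` and
# `single_transformer_blocks`, as in the diffusers Flux transformer), clipping
# float16 infinities in every block:
#
#     handles = []
#     for i in range(len(pipe.transformer.transformer_blocks)):
#         handles.append(register_general_hook(
#             pipe, f"transformer.transformer_blocks.{i}", fix_inf_values_hook))
#     for i in range(len(pipe.transformer.single_transformer_blocks)):
#         handles.append(register_general_hook(
#             pipe, f"transformer.single_transformer_blocks.{i}", fix_inf_values_hook))
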
@torch.no_grad()
def edit_streams_hook(*args,
recompute_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
stream: Literal["text", "image", "both"]):
"""
recompute_fn will get as input the input tensor and the output tensor for such stream
and returns what should be the new modified output
"""
    # Case 1: no kwargs are passed to the module
    if len(args) == 3:
        module, input, output = args
    # Case 2: kwargs are passed to the module as input
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        raise AssertionError(f"Unexpected len(args): {len(args)}")
if isinstance(module, FluxTransformerBlock):
if stream == 'text':
output_text = recompute_fn(kwinput["encoder_hidden_states"], output[0])
output_image = output[1]
elif stream == 'image':
output_image = recompute_fn(kwinput["hidden_states"], output[1])
output_text = output[0]
        else:
            raise AssertionError(f"Stream '{stream}' is not supported for FluxTransformerBlock.")
return _safe_clip(output_text), _safe_clip(output_image)
elif isinstance(module, FluxSingleTransformerBlock):
        # In single-stream blocks the text and image tokens are concatenated along the
        # sequence dimension, with the text tokens (512 for the default Flux T5 length) first.
        if stream == 'text':
            output[:, :512] = recompute_fn(kwinput["hidden_states"][:, :512], output[:, :512])
        elif stream == 'image':
            output[:, 512:] = recompute_fn(kwinput["hidden_states"][:, 512:], output[:, 512:])
        else:
            output = recompute_fn(kwinput["hidden_states"], output)
return _safe_clip(output)
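
# Example (sketch; amplifies the image-stream update of one hypothetical double-stream
# block, index 10, using the module layout assumed above). `recompute_fn` and `stream`
# are keyword-only, so bind them with functools.partial, and register with
# `with_kwargs=True` so the hook can read the block's keyword inputs:
#
#     from functools import partial
#
#     def amplify(block_in, block_out):
#         return block_in + 1.5 * (block_out - block_in)
#
#     handle = register_general_hook(
#         pipe,
#         "transformer.transformer_blocks.10",
#         partial(edit_streams_hook, recompute_fn=amplify, stream="image"),
#         with_kwargs=True,
#     )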