from typing import Callable, Literal
import torch
import torch.nn as nn
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock


def register_general_hook(pipe, position, hook, with_kwargs=False, is_pre_hook=False):
    """Registers a forward hook in a module of the pipeline specified with 'position'

    Args:
        pipe (_type_): _description_
        position (_type_): _description_
        hook (_type_): _description_
        with_kwargs (bool, optional): _description_. Defaults to False.
        is_pre_hook (bool, optional): _description_. Defaults to False.

    Returns:
        _type_: _description_
    """

    block: nn.Module = locate_block(pipe, position)

    if is_pre_hook:
        return block.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
    else:
        return block.register_forward_hook(hook, with_kwargs=with_kwargs)
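
# A minimal usage sketch (illustrative, not part of the original API surface):
# `pipe` is assumed to be an already-loaded diffusers FluxPipeline, and the
# position string is a dotted path resolved by `locate_block` below.
#
#   handle = register_general_hook(
#       pipe, "transformer.transformer_blocks.0", fix_inf_values_hook, with_kwargs=True
#   )
#   ...                 # run the pipeline with the hook attached
#   handle.remove()     # detach the hook again via the returned handle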


def locate_block(pipe, position: str) -> nn.Module:
    '''
    Locate the block at the specified position in the pipeline.
    '''
    block = pipe
    for step in position.split('.'):
        if step.isdigit():
            step = int(step)
            block = block[step]
        else:
            block = getattr(block, step)
    return block
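
# Illustrative example of the path syntax (assuming the standard diffusers
# FluxPipeline layout): numeric path components index into nn.ModuleList containers.
#
#   block = locate_block(pipe, "transformer.single_transformer_blocks.3")
#   # -> pipe.transformer.single_transformer_blocks[3], a FluxSingleTransformerBlock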


def _safe_clip(x: torch.Tensor):
    """Replace +/-inf entries with the largest finite float16 value (+/-65504)."""
    if x.dtype == torch.float16:
        x[torch.isposinf(x)] = 65504
        x[torch.isneginf(x)] = -65504
    return x
    

@torch.no_grad()
def fix_inf_values_hook(*args):
    """Forward hook that clips +/-inf values in the output of a Flux transformer block."""

    # Case 1: no kwargs are passed to the module
    if len(args) == 3:
        module, input, output = args
    # Case 2: kwargs are passed to the module as input
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        raise AssertionError(f'Unexpected number of hook arguments: {len(args)}')

    if isinstance(module, FluxTransformerBlock):
        # Double-stream block: output is (encoder_hidden_states, hidden_states).
        return _safe_clip(output[0]), _safe_clip(output[1])

    elif isinstance(module, FluxSingleTransformerBlock):
        # Single-stream block: output is one tensor of concatenated text/image tokens.
        return _safe_clip(output)
    

@torch.no_grad()
def edit_streams_hook(*args,
                      recompute_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                      stream: Literal["text", "image", "both"]):
    """Forward hook that lets `recompute_fn` rewrite the output of the selected stream.

    `recompute_fn` receives the block's input tensor and output tensor for the chosen
    stream and returns the new, modified output. Because the hook reads the module's
    keyword arguments, it must be registered with `with_kwargs=True`.
    """

    # Case 1: no kwargs are passed to the module -- unsupported, since the
    # original inputs are needed to recompute the streams.
    if len(args) == 3:
        raise AssertionError("edit_streams_hook must be registered with with_kwargs=True")
    # Case 2: kwargs are passed to the module as input
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        raise AssertionError(f'Unexpected number of hook arguments: {len(args)}')

    if isinstance(module, FluxTransformerBlock):
        # Double-stream block: output is (encoder_hidden_states, hidden_states),
        # i.e. (text tokens, image tokens).
        if stream == 'text':
            output_text = recompute_fn(kwinput["encoder_hidden_states"], output[0])
            output_image = output[1]
        elif stream == 'image':
            output_image = recompute_fn(kwinput["hidden_states"], output[1])
            output_text = output[0]
        else:
            raise AssertionError("stream='both' is not supported for FluxTransformerBlock; edit the text and image streams separately.")

        return _safe_clip(output_text), _safe_clip(output_image)

    elif isinstance(module, FluxSingleTransformerBlock):
        # Single-stream block: hidden_states concatenates text and image tokens;
        # the first 512 positions are assumed to be the text tokens, matching
        # Flux's default T5 max_sequence_length of 512.
        if stream == 'text':
            output[:, :512] = recompute_fn(kwinput["hidden_states"][:, :512], output[:, :512])
        elif stream == 'image':
            output[:, 512:] = recompute_fn(kwinput["hidden_states"][:, 512:], output[:, 512:])
        else:
            output = recompute_fn(kwinput["hidden_states"], output)

        return _safe_clip(output)
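

# Hedged end-to-end sketch (not part of the original module): one way these hooks
# could be wired into a Flux pipeline. The model id, the `damp_update` function and
# the use of functools.partial are illustrative assumptions, not the author's setup.
if __name__ == "__main__":
    from functools import partial

    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")  # assumes a CUDA device is available

    handles = []

    # Clamp infinities in every double-stream block (useful for low-precision runs).
    for i in range(len(pipe.transformer.transformer_blocks)):
        handles.append(
            register_general_hook(
                pipe,
                f"transformer.transformer_blocks.{i}",
                fix_inf_values_hook,
                with_kwargs=True,
            )
        )

    # Edit only the image stream of the first double-stream block: this toy
    # recompute_fn simply damps the block's residual update by half.
    def damp_update(block_input: torch.Tensor, block_output: torch.Tensor) -> torch.Tensor:
        return block_input + 0.5 * (block_output - block_input)

    edit_hook = partial(edit_streams_hook, recompute_fn=damp_update, stream="image")
    handles.append(
        register_general_hook(
            pipe, "transformer.transformer_blocks.0", edit_hook, with_kwargs=True
        )
    )

    image = pipe("a photo of a cat").images[0]

    # Detach all hooks when done.
    for h in handles:
        h.remove()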