# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Needed utilities for torchao FP8 training.
"""

from functools import partial
from typing import TYPE_CHECKING, Callable, Optional

import torch

from .imports import is_torchao_available, torchao_required


if TYPE_CHECKING:
    if is_torchao_available():
        from torchao.float8.float8_linear import Float8LinearConfig


def find_first_last_linear_layers(model: torch.nn.Module):
    """
    Finds the first and last linear layer names in a model.

    This is needed during FP8 training: keeping the first and last layers unquantized avoids instability issues.

    Ref: https://x.com/xariusrke/status/1826669142604141052
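
    Example (a minimal sketch using a toy `nn.Sequential`; the returned names depend on the model's structure):

    ```python
    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 16),
    )
    first, last = find_first_last_linear_layers(model)
    # first == "0", last == "2" (the child module names assigned by `nn.Sequential`)
    ```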
    """
    first_linear, last_linear = None, None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if first_linear is None:
                first_linear = name
            last_linear = name
    return first_linear, last_linear


def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool:
    """
    A function which checks whether `module`:

    - is a `torch.nn.Linear` layer
    - has `in_features` and `out_features` divisible by 16
    - is not part of `layers_to_filter`

    Args:
        module (`torch.nn.Module`):
            The module to check.
        fqn (`str`):
            The fully qualified name of the layer.
        layers_to_filter (`list[str]`):
            The list of fully qualified layer names to filter out (i.e. keep unquantized).
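
    Example (illustrative values only):

    ```python
    import torch

    layer = torch.nn.Linear(32, 32)
    filter_linear_layers(layer, "model.layers.0", layers_to_filter=[])  # True
    filter_linear_layers(layer, "model.layers.0", layers_to_filter=["model.layers.0"])  # False
    filter_linear_layers(torch.nn.Linear(10, 32), "model.head", layers_to_filter=[])  # False: 10 % 16 != 0
    ```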
    """
    if isinstance(module, torch.nn.Linear):
        if module.in_features % 16 != 0 or module.out_features % 16 != 0:
            return False
    if fqn in layers_to_filter:
        return False
    return True


def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
    """
    A filter function which filters out the first and last linear layers, keeping all other linear layers eligible
    for FP8 conversion.

    <Tip>

        For stability reasons, we skip the first and last linear layers. Quantizing them can otherwise lead to the
        model not training or converging properly.

    </Tip>

    Args:
        module (`torch.nn.Module`):
            The module to check.
        fqn (`str`):
            The fully qualified name of the layer.
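
    Example (illustrative; note that `module` must contain all the linear layers so that the first and last can be
    located):

    ```python
    import torch

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 32), torch.nn.Linear(32, 16))
    filter_first_and_last_linear_layers(model, "0")  # False: "0" is the first linear layer
    filter_first_and_last_linear_layers(model, "1")  # True
    filter_first_and_last_linear_layers(model, "2")  # False: "2" is the last linear layer
    ```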
    """
    first_linear, last_linear = find_first_last_linear_layers(module)
    return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear])


@torchao_required
def has_ao_layers(model: torch.nn.Module):
    """
    Checks whether any submodule of `model` is a torchao `Float8Linear` layer.
    """
    from torchao.float8.float8_linear import Float8Linear

    for module in model.modules():
        if isinstance(module, Float8Linear):
            return True
    return False


@torchao_required
def convert_model_to_fp8_ao(
    model: torch.nn.Module,
    config: Optional["Float8LinearConfig"] = None,
    module_filter_func: Optional[Callable] = None,
):
    """
    Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layers in place.

    Args:
        model (`torch.nn.Module`):
            The model to convert.
        config (`torchao.float8.Float8LinearConfig`, *optional*):
            The configuration for FP8 training. It is recommended to generate this with
            `torchao.float8.recipe_name_to_linear_config`. In general, the default config (used when this is `None`)
            should be sufficient.
        module_filter_func (`Callable`, *optional*, defaults to `None`):
            Optional function that takes a module and its fully qualified name and returns a boolean indicating
            whether the module should be converted to FP8. When `None`, `filter_linear_layers` is used with the
            first and last linear layers filtered out. See it for an example.

    Example:

    ```python
    from accelerate.utils.ao import convert_model_to_fp8_ao

    model = MyModel()
    model.to("cuda")
    convert_model_to_fp8_ao(model)

    model.train()
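
    # Optionally, pass an explicit config. `recipe_name_to_linear_config` is the helper recommended
    # above (the exact recipe names depend on your torchao version, so "rowwise" is illustrative):
    # from torchao.float8 import recipe_name_to_linear_config
    # config = recipe_name_to_linear_config("rowwise")
    # convert_model_to_fp8_ao(model, config=config)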
    ```
    """
    from torchao.float8 import convert_to_float8_training

    first_linear, last_linear = find_first_last_linear_layers(model)
    if module_filter_func is None:
        module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])
    convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)