File size: 10,474 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import is_dataclass
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
    """
    Optimized version of nn.Linear that adds features such as:
      * LoRA w. base weight sharding
      * FP [6,8,12] quantization

    Arguments:
        input_dim: Required: size of each input sample
        output_dim: Required: size of each output sample
        bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
        lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
        quantization_config: Optional: QuantizationConfig defining quantization features
        dtype: Optional: parameter dtype, only supports bfloat16 currently

    Returns:
        Returns a new nn.Module depending on the input config. Either native
        torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear.
    """

    def __new__(self,
                input_dim: int,
                output_dim: int,
                bias: bool = False,
                lora_config: LoRAConfig = None,
                quantization_config: QuantizationConfig = None,
                device=None,
                dtype=torch.bfloat16,
                linear_cls=nn.Linear):

        if quantization_config is not None and not is_dataclass(quantization_config):
            raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
        if lora_config is not None and not is_dataclass(lora_config):
            raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")
        if lora_config is None and quantization_config is None:
            # Everything disabled, fall back to normal nn.Linear
            self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device)

        elif lora_config:
            # lora enabled, quantization may or may not be
            self = LoRAOptimizedLinear(input_dim=input_dim,
                                       output_dim=output_dim,
                                       bias=bias,
                                       lora_config=lora_config,
                                       quantization_config=quantization_config,
                                       dtype=dtype,
                                       device=device,
                                       linear_cls=linear_cls)

        elif quantization_config:
            # only quantization enabled, no lora
            self = QuantizedLinear(input_dim=input_dim,
                                   output_dim=output_dim,
                                   bias=bias,
                                   quantization_config=quantization_config,
                                   dtype=dtype)
        return self


class LoRAOptimizedLinear(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 lora_config: LoRAConfig = None,
                 quantization_config: QuantizationConfig = None,
                 device=None,
                 dtype=torch.bfloat16,
                 linear_cls=nn.Linear):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.lora_config = lora_config
        self.quantization_config = quantization_config
        self.device = get_accelerator().current_device_name() if device is None else device
        self.linear_cls = linear_cls
        self.dtype = dtype
        assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"
        assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear"
        self.zero_shards = self.lora_config.base_weight_sharding
        self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
        if self.zero_shards > 1:
            assert self.zero_shards == dist.get_world_size(
            ), "base weight sharding is only supported across world size"
            w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype),
                                   requires_grad=False)
        else:
            w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False)
        torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim))

        if self.quantization_config is not None:
            assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
            self.weight = QuantizedParameter(w, quantization_config=quantization_config)
        else:
            self.weight = w

        self.disabled = False
        self._initialized = False
        if not self.lora_config.delay_lora_init:
            self.init_lora()

    def disable(self):
        self.disabled = True
        self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype),
                                         requires_grad=False)

    def init_lora(self):
        if self.disabled:
            return

        if self.quantization_config is not None:
            # ensure quant-param wasn't stripped, in some cases transformers will do this during model init
            if not isinstance(self.weight, QuantizedParameter):
                self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config)

        self._initialized = True
        self.weight.requires_grad = False

        # Mark base weight to prevent broadcast and ensure proper offload behavior
        self.weight.ds_optim_param = True

        self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r

        # Keeping lora weights in bf16 precision for ease of training.
        self.lora_weight_1 = self.linear_cls(self.input_dim,
                                             self.lora_config.lora_r,
                                             bias=self.bias,
                                             device=self.device,
                                             dtype=self.dtype)
        self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r,
                                             self.output_dim,
                                             bias=self.bias,
                                             device=self.device,
                                             dtype=self.dtype)

        # initialize "A" with kaiming uniform and "B" with zeros following this
        # https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155
        nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_weight_2.weight)
        self.lora_weight_1.weight.requires_grad = True
        self.lora_weight_2.weight.requires_grad = True

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        if not any([target in prefix for target in self.lora_config.target_mods]):
            # module does not match any target_mods, we must revert to normal nn.Linear via disable
            self.disable()
            return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
                                                 unexpected_keys, error_msgs)

        if self.zero_shards > 1:
            if not dist.is_initialized():
                raise RuntimeError(
                    "attempting to use optimized linear base weight sharding but torch-distributed is not initialized, please init first."
                )
            rank = dist.get_rank()
            shape_local = self.output_dim * self.sharded_weight_size
            base_weight_name = f"{prefix}weight"
            incoming_param = state_dict[base_weight_name]
            state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local)

        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                             error_msgs)

    def full_weight(self):
        base_weight = self.weight
        if getattr(base_weight, 'ds_offload', False):
            # move to gpu so we can dequant and all-gather
            assert base_weight.device == torch.device('cpu'), \
                f"expected base weight on cpu but found {base_weight.device}"
            base_weight.offload(revert=True)
            local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight
            base_weight.offload()
        else:
            local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight

        tensor_out = torch.empty(self.output_dim * self.input_dim,
                                 dtype=local_weight.dtype,
                                 device=local_weight.device)
        dist.all_gather_into_tensor(tensor_out, local_weight)
        return tensor_out.reshape(self.output_dim, self.input_dim)

    def linear_without_F_linear(self, input, weight):
        output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
        output = output.view(*input.shape[:-1], weight.shape[1])
        return output

    def forward(self, input_tensor):
        if self.disabled:
            return F.linear(input_tensor, self.weight)
        assert self._initialized, "init_lora was never called, please initialize before proceeding"

        # Gather the sharded base weight
        if self.zero_shards > 1:
            with torch.no_grad():
                base_weight = self.full_weight()
        elif self.quantization_config:
            base_weight = self.weight.dequantized()
        else:
            base_weight = self.weight

        base_weight_output = F.linear(input_tensor, base_weight)
        lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
        return base_weight_output + self.lora_scaling_factor * lora_output