# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import is_dataclass
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist
from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
"""
Optimized version of nn.Linear that adds features such as:
* LoRA w. base weight sharding
* FP [6,8,12] quantization
Arguments:
input_dim: Required: size of each input sample
output_dim: Required: size of each output sample
        bias: Optional: If set to True, the layer will learn an additive bias. Default: False
lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
quantization_config: Optional: QuantizationConfig defining quantization features
dtype: Optional: parameter dtype, only supports bfloat16 currently
Returns:
Returns a new nn.Module depending on the input config. Either native
        torch.nn.Linear, QuantizedLinear, or the full-featured LoRAOptimizedLinear.
"""
def __new__(self,
input_dim: int,
output_dim: int,
bias: bool = False,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
device=None,
dtype=torch.bfloat16,
linear_cls=nn.Linear):
if quantization_config is not None and not is_dataclass(quantization_config):
raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
if lora_config is not None and not is_dataclass(lora_config):
raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")
if lora_config is None and quantization_config is None:
# Everything disabled, fall back to normal nn.Linear
self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device)
elif lora_config:
# lora enabled, quantization may or may not be
self = LoRAOptimizedLinear(input_dim=input_dim,
output_dim=output_dim,
bias=bias,
lora_config=lora_config,
quantization_config=quantization_config,
dtype=dtype,
device=device,
linear_cls=linear_cls)
elif quantization_config:
# only quantization enabled, no lora
self = QuantizedLinear(input_dim=input_dim,
output_dim=output_dim,
bias=bias,
quantization_config=quantization_config,
dtype=dtype)
return self


class LoRAOptimizedLinear(nn.Module):
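    """
    LoRA linear layer with a frozen base weight that may optionally be quantized
    (via QuantizationConfig) and/or sharded across the distributed world size
    (via LoRAConfig.base_weight_sharding). The low-rank adapters lora_weight_1 and
    lora_weight_2 remain trainable; typically constructed through OptimizedLinear above.
    """
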
def __init__(self,
input_dim: int,
output_dim: int,
bias: bool = False,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
device=None,
dtype=torch.bfloat16,
linear_cls=nn.Linear):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.bias = bias
self.lora_config = lora_config
self.quantization_config = quantization_config
self.device = get_accelerator().current_device_name() if device is None else device
self.linear_cls = linear_cls
self.dtype = dtype
        assert self.lora_config is not None, "LoRAOptimizedLinear requires a LoRA config"
assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear"
self.zero_shards = self.lora_config.base_weight_sharding
self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
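        # e.g. input_dim=4096 with base_weight_sharding=4 gives sharded_weight_size=1024,
        # so each rank stores a flattened slice of output_dim * 1024 base-weight elements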
if self.zero_shards > 1:
            assert self.zero_shards == dist.get_world_size(
            ), "base weight sharding degree must match the distributed world size"
w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype),
requires_grad=False)
else:
w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False)
torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim))
if self.quantization_config is not None:
assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
self.weight = QuantizedParameter(w, quantization_config=quantization_config)
else:
self.weight = w
self.disabled = False
self._initialized = False
if not self.lora_config.delay_lora_init:
self.init_lora()
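
    # Fall back to a plain frozen, full-size weight so this module behaves like a vanilla
    # nn.Linear (used when the module is not one of the configured LoRA target_mods).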
def disable(self):
self.disabled = True
self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype),
requires_grad=False)
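
    # Build the trainable low-rank adapters: freeze (and, if configured, re-wrap as a
    # QuantizedParameter) the base weight, then create lora_weight_1 ("A", kaiming-uniform)
    # and lora_weight_2 ("B", zeros) whose product is scaled by lora_alpha / lora_r.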
def init_lora(self):
if self.disabled:
return
if self.quantization_config is not None:
            # ensure the quantized param wasn't stripped; in some cases transformers will do this during model init
if not isinstance(self.weight, QuantizedParameter):
self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config)
self._initialized = True
self.weight.requires_grad = False
# Mark base weight to prevent broadcast and ensure proper offload behavior
self.weight.ds_optim_param = True
self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r
# Keeping lora weights in bf16 precision for ease of training.
self.lora_weight_1 = self.linear_cls(self.input_dim,
self.lora_config.lora_r,
bias=self.bias,
device=self.device,
dtype=self.dtype)
self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r,
self.output_dim,
bias=self.bias,
device=self.device,
dtype=self.dtype)
# initialize "A" with kaiming uniform and "B" with zeros following this
# https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155
nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_weight_2.weight)
self.lora_weight_1.weight.requires_grad = True
self.lora_weight_2.weight.requires_grad = True
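
    # Checkpoint loading: if this module is not listed in target_mods, disable LoRA and load
    # as a plain linear; otherwise, when the base weight is sharded, narrow the incoming full
    # weight down to this rank's flattened shard before deferring to the default loader.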
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
if not any([target in prefix for target in self.lora_config.target_mods]):
# module does not match any target_mods, we must revert to normal nn.Linear via disable
self.disable()
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs)
if self.zero_shards > 1:
if not dist.is_initialized():
raise RuntimeError(
"attempting to use optimized linear base weight sharding but torch-distributed is not initialized, please init first."
)
rank = dist.get_rank()
shape_local = self.output_dim * self.sharded_weight_size
base_weight_name = f"{prefix}weight"
incoming_param = state_dict[base_weight_name]
state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local)
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
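
    # Reassemble the full (output_dim, input_dim) base weight on this rank: temporarily
    # un-offload and dequantize the local shard if needed, then all-gather across ranks.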
def full_weight(self):
base_weight = self.weight
if getattr(base_weight, 'ds_offload', False):
# move to gpu so we can dequant and all-gather
assert base_weight.device == torch.device('cpu'), \
f"expected base weight on cpu but found {base_weight.device}"
base_weight.offload(revert=True)
local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight
base_weight.offload()
else:
local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight
tensor_out = torch.empty(self.output_dim * self.input_dim,
dtype=local_weight.dtype,
device=local_weight.device)
dist.all_gather_into_tensor(tensor_out, local_weight)
return tensor_out.reshape(self.output_dim, self.input_dim)
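
    # Matmul fallback that computes input @ weight directly, i.e. it expects the weight
    # pre-transposed relative to F.linear's (out_features, in_features) convention.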
def linear_without_F_linear(self, input, weight):
output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
output = output.view(*input.shape[:-1], weight.shape[1])
return output
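
    # y = x @ W_base^T + (lora_alpha / lora_r) * lora_weight_2(lora_weight_1(x)); the frozen
    # base weight is first gathered (if sharded) or dequantized (if quantized).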
def forward(self, input_tensor):
if self.disabled:
return F.linear(input_tensor, self.weight)
assert self._initialized, "init_lora was never called, please initialize before proceeding"
# Gather the sharded base weight
if self.zero_shards > 1:
with torch.no_grad():
base_weight = self.full_weight()
elif self.quantization_config:
base_weight = self.weight.dequantized()
else:
base_weight = self.weight
base_weight_output = F.linear(input_tensor, base_weight)
lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
return base_weight_output + self.lora_scaling_factor * lora_output
|