# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import is_dataclass
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
    """
    Optimized version of nn.Linear that adds features such as:
      * LoRA with base weight sharding
      * FP [6,8,12] quantization

    Arguments:
        input_dim: Required: size of each input sample
        output_dim: Required: size of each output sample
        bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
        lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
        quantization_config: Optional: QuantizationConfig defining quantization features
        dtype: Optional: parameter dtype, only supports bfloat16 currently

    Returns:
        Returns a new nn.Module depending on the input config. Either native
        torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear.
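
    Example (illustrative sketch only; assumes LoRAConfig is a dataclass accepting the fields
    referenced in this module and that deepspeed.linear re-exports these names):

        >>> from deepspeed.linear import OptimizedLinear, LoRAConfig
        >>> lora_cfg = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=1, target_mods=['qkv_proj'])
        >>> layer = OptimizedLinear(input_dim=4096, output_dim=4096, lora_config=lora_cfg)
        >>> # with a LoRAConfig present, the returned module is a LoRAOptimizedLinear

    Because __new__ dispatches on the configs, the returned module may be a plain torch.nn.Linear,
    a QuantizedLinear, or a LoRAOptimizedLinear rather than an OptimizedLinear instance.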
"""
def __new__(self,
input_dim: int,
output_dim: int,
bias: bool = False,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
device=None,
dtype=torch.bfloat16,
linear_cls=nn.Linear):
if quantization_config is not None and not is_dataclass(quantization_config):
raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
if lora_config is not None and not is_dataclass(lora_config):
raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")
if lora_config is None and quantization_config is None:
# Everything disabled, fall back to normal nn.Linear
self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device)
        elif lora_config:
            # lora enabled, quantization may or may not be
            self = LoRAOptimizedLinear(input_dim=input_dim,
                                       output_dim=output_dim,
                                       bias=bias,
                                       lora_config=lora_config,
                                       quantization_config=quantization_config,
                                       dtype=dtype,
                                       device=device,
                                       linear_cls=linear_cls)

        elif quantization_config:
            # only quantization enabled, no lora
            self = QuantizedLinear(input_dim=input_dim,
                                   output_dim=output_dim,
                                   bias=bias,
                                   quantization_config=quantization_config,
                                   dtype=dtype)
        return self


class LoRAOptimizedLinear(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 lora_config: LoRAConfig = None,
                 quantization_config: QuantizationConfig = None,
                 device=None,
                 dtype=torch.bfloat16,
                 linear_cls=nn.Linear):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.lora_config = lora_config
        self.quantization_config = quantization_config
        self.device = get_accelerator().current_device_name() if device is None else device
        self.linear_cls = linear_cls
        self.dtype = dtype
        assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"
        assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear"
        self.zero_shards = self.lora_config.base_weight_sharding
        self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
        if self.zero_shards > 1:
            assert self.zero_shards == dist.get_world_size(), \
                "base weight sharding is only supported when the sharding degree equals the world size"
            w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype),
                                   requires_grad=False)
        else:
            w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False)
        torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim))
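        # Sharding note: with base_weight_sharding > 1 each rank owns a flattened slice of
        # output_dim * (input_dim // world_size) base-weight elements; full_weight() below
        # all-gathers these slices and reshapes them back to (output_dim, input_dim) at forward time.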

        if self.quantization_config is not None:
            assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
            self.weight = QuantizedParameter(w, quantization_config=quantization_config)
        else:
            self.weight = w

        self.disabled = False
        self._initialized = False
        if not self.lora_config.delay_lora_init:
            self.init_lora()

    def disable(self):
        self.disabled = True
        self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype),
                                         requires_grad=False)

    def init_lora(self):
        if self.disabled:
            return

        if self.quantization_config is not None:
            # ensure quant-param wasn't stripped, in some cases transformers will do this during model init
            if not isinstance(self.weight, QuantizedParameter):
                self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config)

        self._initialized = True
        self.weight.requires_grad = False

        # Mark base weight to prevent broadcast and ensure proper offload behavior
        self.weight.ds_optim_param = True

        self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r
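        # Standard LoRA convention: the low-rank update B(A(x)) is scaled by alpha / r before
        # being added to the frozen base projection in forward().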

        # Keeping lora weights in bf16 precision for ease of training.
        self.lora_weight_1 = self.linear_cls(self.input_dim,
                                             self.lora_config.lora_r,
                                             bias=self.bias,
                                             device=self.device,
                                             dtype=self.dtype)
        self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r,
                                             self.output_dim,
                                             bias=self.bias,
                                             device=self.device,
                                             dtype=self.dtype)

        # initialize "A" with kaiming uniform and "B" with zeros following this
        # https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155
        nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_weight_2.weight)
        self.lora_weight_1.weight.requires_grad = True
        self.lora_weight_2.weight.requires_grad = True

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        if not any([target in prefix for target in self.lora_config.target_mods]):
            # module does not match any target_mods, we must revert to normal nn.Linear via disable
            self.disable()
            return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
                                                 unexpected_keys, error_msgs)

        if self.zero_shards > 1:
            if not dist.is_initialized():
                raise RuntimeError(
                    "attempting to use optimized linear base weight sharding but torch-distributed is not initialized, please init first."
                )
            rank = dist.get_rank()
            shape_local = self.output_dim * self.sharded_weight_size
            base_weight_name = f"{prefix}weight"
            incoming_param = state_dict[base_weight_name]
            state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local)
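            # Each rank keeps only its contiguous 1/world_size slice of the flattened full weight,
            # matching the local shard allocated in __init__.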

        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                             error_msgs)

    def full_weight(self):
        base_weight = self.weight
        if getattr(base_weight, 'ds_offload', False):
            # move to gpu so we can dequant and all-gather
            assert base_weight.device == torch.device('cpu'), \
                f"expected base weight on cpu but found {base_weight.device}"
            base_weight.offload(revert=True)
            local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight
            base_weight.offload()
        else:
            local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight

        tensor_out = torch.empty(self.output_dim * self.input_dim,
                                 dtype=local_weight.dtype,
                                 device=local_weight.device)
        dist.all_gather_into_tensor(tensor_out, local_weight)
        return tensor_out.reshape(self.output_dim, self.input_dim)
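
    # Note: unlike F.linear, the helper below multiplies by `weight` directly, so it expects a
    # weight laid out as (in_features, out_features); it is not called by forward() in this module.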
    def linear_without_F_linear(self, input, weight):
        output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
        output = output.view(*input.shape[:-1], weight.shape[1])
        return output

    def forward(self, input_tensor):
        if self.disabled:
            return F.linear(input_tensor, self.weight)
        assert self._initialized, "init_lora was never called, please initialize before proceeding"

        # Gather the sharded base weight
        if self.zero_shards > 1:
            with torch.no_grad():
                base_weight = self.full_weight()
        elif self.quantization_config:
            base_weight = self.weight.dequantized()
        else:
            base_weight = self.weight

        base_weight_output = F.linear(input_tensor, base_weight)
        lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
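        # Final output: y = x @ W_base.T + (lora_alpha / lora_r) * lora_weight_2(lora_weight_1(x))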
        return base_weight_output + self.lora_scaling_factor * lora_output