import math
from dataclasses import is_dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
    """
    Optimized version of nn.Linear that adds features such as:
      * LoRA with base weight sharding
      * FP [6,8,12] quantization

    Arguments:
        input_dim: Required: size of each input sample
        output_dim: Required: size of each output sample
        bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
        lora_config: Optional: LoRAConfig defining LoRA features and base-weight-sharding degree
        quantization_config: Optional: QuantizationConfig defining quantization features
        device: Optional: device to place the parameters on
        dtype: Optional: parameter dtype, only bfloat16 is currently supported
        linear_cls: Optional: linear class used for the base and LoRA projections. Default: nn.Linear

    Returns:
        A new nn.Module whose type depends on the given config: either a native
        torch.nn.Linear, a QuantizedLinear, or the full-featured LoRAOptimizedLinear.
    """

    def __new__(self,
                input_dim: int,
                output_dim: int,
                bias: bool = False,
                lora_config: LoRAConfig = None,
                quantization_config: QuantizationConfig = None,
                device=None,
                dtype=torch.bfloat16,
                linear_cls=nn.Linear):
        # __new__ acts as a factory: it returns a different nn.Module type
        # depending on which configs are provided.
        if quantization_config is not None and not is_dataclass(quantization_config):
            raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
        if lora_config is not None and not is_dataclass(lora_config):
            raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")

        if lora_config is None and quantization_config is None:
            # Neither feature is enabled, fall back to linear_cls (nn.Linear by default)
            self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device)
        elif lora_config:
            # LoRA is enabled, quantization may or may not be
            self = LoRAOptimizedLinear(input_dim=input_dim,
                                       output_dim=output_dim,
                                       bias=bias,
                                       lora_config=lora_config,
                                       quantization_config=quantization_config,
                                       dtype=dtype,
                                       device=device,
                                       linear_cls=linear_cls)
        elif quantization_config:
            # Only quantization is enabled, no LoRA
            self = QuantizedLinear(input_dim=input_dim,
                                   output_dim=output_dim,
                                   bias=bias,
                                   quantization_config=quantization_config,
                                   dtype=dtype)
        return self

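
# Usage sketch (illustrative only, not executed here). The LoRAConfig fields shown
# (lora_r, lora_alpha, base_weight_sharding, target_mods, delay_lora_init) are the ones
# referenced elsewhere in this module; the constructor arguments and values below are
# assumptions, not the authoritative API:
#
#   lora_cfg = LoRAConfig(lora_r=16, lora_alpha=32, base_weight_sharding=1)
#   layer = OptimizedLinear(input_dim=4096, output_dim=4096, lora_config=lora_cfg)
#
# With both configs left as None, OptimizedLinear(...) simply returns a torch.nn.Linear;
# with only a QuantizationConfig it returns a QuantizedLinear.

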
class LoRAOptimizedLinear(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 lora_config: LoRAConfig = None,
                 quantization_config: QuantizationConfig = None,
                 device=None,
                 dtype=torch.bfloat16,
                 linear_cls=nn.Linear):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.lora_config = lora_config
        self.quantization_config = quantization_config
        self.device = get_accelerator().current_device_name() if device is None else device
        self.linear_cls = linear_cls
        self.dtype = dtype
        assert self.lora_config is not None, "LoRAOptimizedLinear requires a LoRA config"
        assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear"

        # Degree of base-weight sharding; each rank stores 1/zero_shards of the input dimension.
        self.zero_shards = self.lora_config.base_weight_sharding
        self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
        if self.zero_shards > 1:
            assert self.zero_shards == dist.get_world_size(), \
                "base weight sharding degree must equal the distributed world size"
            # Each rank holds a flat shard of the frozen base weight.
            w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype),
                                   requires_grad=False)
        else:
            w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False)
        # Xavier-init a reshaped view of the base weight; valid for both the sharded
        # (flat) and unsharded (2-D) layouts.
        torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim))

        if self.quantization_config is not None:
            assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
            self.weight = QuantizedParameter(w, quantization_config=quantization_config)
        else:
            self.weight = w

        self.disabled = False
        self._initialized = False
        if not self.lora_config.delay_lora_init:
            self.init_lora()

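    # Sharding arithmetic sketch (illustrative numbers, an assumption for this comment):
    # with input_dim=4096, output_dim=4096, and base_weight_sharding equal to a world
    # size of 4, sharded_weight_size = 4096 // 4 = 1024, so each rank stores a flat
    # shard of 4096 * 1024 bf16 values (~8 MiB) instead of the full 4096 * 4096 matrix
    # (~32 MiB). full_weight() below all-gathers these shards back into the full matrix.
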
    def disable(self):
        # Replace the (possibly quantized/sharded) weight with a plain full-size frozen
        # parameter and bypass the LoRA path entirely in forward().
        self.disabled = True
        self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype),
                                         requires_grad=False)

    def init_lora(self):
        if self.disabled:
            return

        if self.quantization_config is not None:
            # If the base weight has not already been wrapped (e.g. it was replaced
            # after __init__), wrap it as a QuantizedParameter now.
            if not isinstance(self.weight, QuantizedParameter):
                self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config)

        self._initialized = True
        self.weight.requires_grad = False

        # Flag the frozen base weight as a DeepSpeed optimized-linear parameter.
        self.weight.ds_optim_param = True

        self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r

        # Two low-rank projections: input_dim -> lora_r and lora_r -> output_dim.
        self.lora_weight_1 = self.linear_cls(self.input_dim,
                                             self.lora_config.lora_r,
                                             bias=self.bias,
                                             device=self.device,
                                             dtype=self.dtype)
        self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r,
                                             self.output_dim,
                                             bias=self.bias,
                                             device=self.device,
                                             dtype=self.dtype)

        # Standard LoRA initialization: the first projection gets kaiming-uniform, the
        # second is zeroed so the LoRA branch contributes nothing until it is trained.
        nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_weight_2.weight)
        self.lora_weight_1.weight.requires_grad = True
        self.lora_weight_2.weight.requires_grad = True

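    # Shape/math sketch for the LoRA branch set up above: lora_weight_1 holds A with
    # shape [lora_r, input_dim] and lora_weight_2 holds B with shape [output_dim, lora_r],
    # so forward computes x @ W_base^T + (lora_alpha / lora_r) * ((x @ A^T) @ B^T),
    # i.e. an effective weight of W_base + (lora_alpha / lora_r) * (B @ A). Because B is
    # zero-initialized, the module starts out numerically identical to the frozen base layer.
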
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        if not any(target in prefix for target in self.lora_config.target_mods):
            # This module is not one of the LoRA target modules: fall back to a plain
            # frozen linear layer and load the full weight as-is.
            self.disable()
            return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
                                                 unexpected_keys, error_msgs)

        if self.zero_shards > 1:
            if not dist.is_initialized():
                raise RuntimeError(
                    "attempting to use optimized linear base weight sharding but torch.distributed is not initialized, please init first."
                )
            # Keep only this rank's shard of the incoming full base weight.
            rank = dist.get_rank()
            shape_local = self.output_dim * self.sharded_weight_size
            base_weight_name = f"{prefix}weight"
            incoming_param = state_dict[base_weight_name]
            state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local)

        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                             error_msgs)

    def full_weight(self):
        # Reconstruct the full [output_dim, input_dim] base weight from the per-rank
        # shards via all-gather; used when the base weight is sharded across ranks.
        base_weight = self.weight
        if getattr(base_weight, 'ds_offload', False):
            # Weight is offloaded to CPU: bring it back, dequantize if needed, then offload again.
            assert base_weight.device == torch.device('cpu'), \
                f"expected base weight on cpu but found {base_weight.device}"
            base_weight.offload(revert=True)
            local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight
            base_weight.offload()
        else:
            local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight

        tensor_out = torch.empty(self.output_dim * self.input_dim,
                                 dtype=local_weight.dtype,
                                 device=local_weight.device)
        dist.all_gather_into_tensor(tensor_out, local_weight)
        return tensor_out.reshape(self.output_dim, self.input_dim)

    def linear_without_F_linear(self, input, weight):
        # Equivalent to F.linear with a pre-transposed weight: flatten all leading
        # dimensions, do a single matmul, then restore the original leading shape.
        output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
        output = output.view(*input.shape[:-1], weight.shape[1])
        return output

    def forward(self, input_tensor):
        if self.disabled:
            # LoRA is disabled for this module: behave like a frozen nn.Linear.
            return F.linear(input_tensor, self.weight)
        assert self._initialized, "init_lora was never called, please initialize before proceeding"

        # Gather or dequantize the frozen base weight as needed.
        if self.zero_shards > 1:
            with torch.no_grad():
                base_weight = self.full_weight()
        elif self.quantization_config:
            base_weight = self.weight.dequantized()
        else:
            base_weight = self.weight

        # Frozen base projection plus the scaled low-rank LoRA correction.
        base_weight_output = F.linear(input_tensor, base_weight)
        lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
        return base_weight_output + self.lora_scaling_factor * lora_output
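

# Minimal smoke-test sketch (an assumption, not part of the library's public tooling):
# it exercises only the plain nn.Linear fallback path of OptimizedLinear, so it needs no
# GPU, quantization config, or torch.distributed. Because of the relative imports above,
# run it as a module, e.g. `python -m <package>.<this_module>`.
if __name__ == "__main__":
    # With no LoRA or quantization config, OptimizedLinear returns a plain nn.Linear.
    layer = OptimizedLinear(input_dim=8, output_dim=4)
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    out = layer(x)
    print(type(layer).__name__, tuple(out.shape))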