import torch
import math
from torch import nn
from torch.nn import init
import deepspeed.comm as dist
from .utils import TopKBinarizer, SymQuantizer, AsymQuantizer, TernaryQuantizer, BinaryQuantizer
from deepspeed.utils import logger

g_mpu = None

class QuantAct(nn.Module):
    """
    Class to quantize the given activations. Note that when using this module, the input activation
    quantization range is fixed for all tokens/images at inference time. This generally costs some
    accuracy but achieves better latency.

    Parameters:
    ----------
    act_range_momentum : float, default 0.95
        Momentum for updating the activation quantization range.
    quant_mode : str, default 'symmetric'
    """

    def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
        super(QuantAct, self).__init__()

        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        if quant_mode == 'symmetric':
            self.act_function = SymQuantizer.apply
        else:
            self.act_function = AsymQuantizer.apply

        # Running [min, max] of the activation; this becomes the fixed quantization range at inference time.
        self.register_buffer('x_min_max', torch.zeros(2))

    def forward(self, x, num_bits, *args):
        """
        x: the activation to quantize
        num_bits: the number of bits to quantize the activation to
        *args: extra arguments that are unused but keep the interface aligned with the other quantization functions
        """

        if self.training:
            x_min = x.data.min()
            x_max = x.data.max()

            # Initialize the range on the first call.
            if self.x_min_max[0] == self.x_min_max[1]:
                self.x_min_max[0] = x_min
                self.x_min_max[1] = x_max

            # Update the range with an exponential moving average of each batch's min/max.
            self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
            self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (1 - self.act_range_momentum)

        x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])

        return x_q
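
# Example (illustrative sketch, not part of this module's API): calibrating a static activation
# range with QuantAct. Names such as `calib_batches` are made up for the example.
#
#   act_q = QuantAct(act_range_momentum=0.95, quant_mode='symmetric')
#   act_q.train()
#   for x in calib_batches:        # x_min_max is updated with an EMA of each batch's min/max
#       _ = act_q(x, 8)
#   act_q.eval()                   # the range is now frozen; inference reuses x_min_max
#   y_q = act_q(torch.randn(4, 16), 8)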

class Embedding_Compress(nn.Embedding):

    def __init__(self, *kargs):
        super(Embedding_Compress, self).__init__(*kargs)
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False

    def extra_repr(self):
        return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
            self.num_embeddings, self.embedding_dim, self.weight.target_bits)

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider using the DS-FP16 optimizer ************"
            )
            if self.weight.target_bits >= 3:
                if quantization_type == 'symmetric':
                    self.weight_quantizer = SymQuantizer.apply
                else:
                    self.weight_quantizer = AsymQuantizer.apply
            elif self.weight.target_bits == 2:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
                self.weight_quantizer = TernaryQuantizer.apply
            elif self.weight.target_bits == 1:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
                self.weight_quantizer = BinaryQuantizer.apply
            # Embeddings are quantized with one group per embedding row.
            self.weight_quantize_num_groups = self.weight.size(0)

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def forward(self, input):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
        else:
            weight = self.weight

        out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
                                      self.scale_grad_by_freq, self.sparse)
        return out
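
# Example (illustrative sketch, not part of this module's API): enabling 8-bit symmetric weight
# quantization on an embedding and freezing it once training is done. The numeric values are
# placeholders, not recommended settings.
#
#   emb = Embedding_Compress(30522, 768)
#   emb.enable_weight_quantization(start_bits=12, target_bits=8, quantization_period=100,
#                                  weight_quantization_enabled_in_forward=True,
#                                  quantization_type='symmetric', num_groups=1)
#   emb.weight_quantization_enabled = True   # normally toggled by the compression scheduler
#   out = emb(torch.tensor([[1, 2, 3]]))     # forward uses the quantize-dequantize weight
#   emb.fix_weight_quantization()            # bake the quantized values into emb.weight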

class LinearLayer_Compress(nn.Linear):
    """
    Linear layer with compression.
    """

    def __init__(self, *kargs, bias=True):
        super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
        self.sparse_pruning_method = None
        self.row_pruning_method = None
        self.head_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False
        self.sparse_pruning_enabled = False
        self.row_pruning_enabled = False
        self.head_pruning_enabled = False
        self.activation_quantization_enabled = False

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, sparse pruning={}, row pruning={}, head pruning={}, activation quantization={}, weight_quantization={}'.format(
            self.in_features, self.out_features, self.bias is not None, self.sparse_pruning_method is not None,
            self.row_pruning_method is not None, self.head_pruning_method is not None,
            self.activation_quantization_method is not None, self.weight.target_bits)

    def enable_sparse_pruning(self, ratio, method):
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            # Static mask computed once from the current weight magnitudes.
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            # Learnable mask: scores are trained and binarized on the fly in get_mask().
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('sparse_pruning_mask', mask)
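
    # The two pruning modes above differ in when the mask is computed: 'l1' derives a fixed mask
    # from the current weight magnitudes, while 'topk' learns mask scores during training and
    # binarizes them on every call to get_mask(). Illustrative sketch (names are made up, not part
    # of this module's API):
    #
    #   layer = LinearLayer_Compress(1024, 1024)
    #   layer.enable_sparse_pruning(ratio=0.5, method='topk')
    #   layer.sparse_pruning_enabled = True      # normally toggled by the compression scheduler
    #   y = layer(torch.randn(8, 1024))          # forward applies the binarized score mask
    #   layer.fix_sparse_pruning_helper()        # bake the mask into layer.weight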

    def enable_row_pruning(self, ratio, method):
        self.row_pruning_ratio = ratio
        self.row_pruning_method = method

        if method == 'l1':
            # Compute an L1-norm importance score per output row and build a static mask.
            weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=1)
            mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False)
            mask = mask.view(-1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.row_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1))
            self.row_mask_scores.data = self.row_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.row_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('row_pruning_mask', mask)

    def enable_head_pruning(self, ratio, method, num_heads):
        self.num_heads = num_heads
        self.head_pruning_ratio = ratio
        self.head_pruning_method = method

        if method not in ['topk']:
            raise NotImplementedError
        else:
            self.head_pruning_ratio = ratio
            self.head_pruning_scores = nn.Parameter(torch.Tensor(1, self.num_heads))
            self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))

    def fix_sparse_pruning_helper(self):
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_row_col_pruning_helper(self, mask=None, dim_reduction=False):
        # This helper handles both row pruning (using this layer's own mask) and column pruning
        # (using a mask handed over from the preceding layer): when rows are removed from a layer
        # F1, the matching columns must be removed from the following layer F2.
        if mask is None:
            mask = self.get_mask(pruning_type='row').bool()
            if dim_reduction:
                start_bits = self.weight.start_bits
                target_bits = self.weight.target_bits
                q_period = self.weight.q_period
                self.weight = nn.Parameter(self.weight.data[mask.view(-1), :])
                self.weight.start_bits = start_bits
                self.weight.target_bits = target_bits
                self.weight.q_period = q_period
                if self.bias is not None:
                    self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                self.out_features = self.weight.size(0)
            else:
                self.weight.data = self.weight.data * mask.view(-1, 1)
                if self.bias is not None:
                    self.bias.data = self.bias.data * mask.view(-1)

            del self.row_pruning_mask
            if self.row_pruning_method == 'topk':
                del self.row_mask_scores
            self.row_pruning_method = None
        else:
            # Column pruning: drop the input features that the previous layer pruned away.
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            self.weight = nn.Parameter(self.weight.data[:, mask.view(-1)])
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            self.in_features = self.weight.size(1)
            mask = None
        self.row_pruning_enabled = False
        return mask

    def fix_head_pruning_helper(self, mask=None, num_heads=None, dim_reduction=False):
        num_heads = num_heads if num_heads else self.num_heads
        if mask is None:
            if self.head_pruning_method == 'topk':
                mask = self.get_mask(pruning_type='head').bool()
                if dim_reduction:
                    shape = self.weight.size(0)
                    start_bits = self.weight.start_bits
                    target_bits = self.weight.target_bits
                    q_period = self.weight.q_period
                    self.weight = nn.Parameter(
                        self.weight.data.t().reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape).t())
                    self.weight.start_bits = start_bits
                    self.weight.target_bits = target_bits
                    self.weight.q_period = q_period
                else:
                    shape = self.weight.size()
                    self.weight.data = (self.weight.data.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(
                        shape[1], shape[0]).t()

                if self.head_pruning_method == 'topk':
                    del self.head_pruning_scores
                self.head_pruning_method = None
            else:
                raise NotImplementedError
        else:
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            shape = self.weight.size(1)
            self.weight = nn.Parameter(self.weight.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape))
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            if self.bias is not None:
                self.bias = nn.Parameter(self.bias.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1))
        self.head_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='row'):
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            elif self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'row':
            if self.row_pruning_method == 'l1':
                return self.row_pruning_mask.to(self.weight.device)
            elif self.row_pruning_method == 'topk':
                return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'head':
            if self.head_pruning_method == 'topk':
                return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider using the DS-FP16 optimizer ************"
            )
            if self.weight.target_bits >= 3:
                if quantization_type == 'symmetric':
                    self.weight_quantizer = SymQuantizer.apply
                else:
                    self.weight_quantizer = AsymQuantizer.apply
            elif self.weight.target_bits == 2:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
                self.weight_quantizer = TernaryQuantizer.apply
            elif self.weight.target_bits == 1:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
                self.weight_quantizer = BinaryQuantizer.apply
            self.weight_quantize_num_groups = num_groups

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
        if range_calibration == 'static':
            self.activation_quantizer = QuantAct(quant_mode=quantization_type)
        else:
            if quantization_type == 'symmetric':
                self.activation_quantizer = SymQuantizer.apply
            else:
                self.activation_quantizer = AsymQuantizer.apply

    def head_pruning_reshape(self, w, mask):
        # Group the weight by attention head, zero the masked heads, and restore the original layout.
        shape = w.shape
        return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()
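
    # Shape sketch for head_pruning_reshape (illustrative numbers, assuming the layer is an
    # attention output projection whose in_features == num_heads * head_dim): with
    # w.shape == (out_features, in_features) == (768, 768) and num_heads == 12, w.t() is
    # (in, out) == (768, 768); reshaping to (12, 64 * 768) puts one head per row; the (12, 1)
    # mask zeroes the rows of pruned heads; the final reshape back to (768, 768) followed by
    # .t() restores the original (out_features, in_features) layout.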

    def forward(self, input, skip_bias_add=False):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
            bias = self.bias
        else:
            weight = self.weight
            bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.row_pruning_enabled and self.row_pruning_method:
            mask = self.get_mask(pruning_type='row')
            weight = weight * mask.view(-1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.head_pruning_enabled and self.head_pruning_method:
            mask = self.get_mask(pruning_type='head')
            weight = self.head_pruning_reshape(weight, mask)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                num_groups = input.numel() // input.size(-1)
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        if skip_bias_add:
            # Used by the tensor-parallel wrappers below, where the bias is added after the partial
            # outputs are combined.
            output = nn.functional.linear(input, weight, None)
            return output, bias
        else:
            output = nn.functional.linear(input, weight, bias)
            return output
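
# Illustrative sketch (names made up, not part of this module's API): combining dynamic 8-bit
# activation quantization with row pruning on a single layer.
#
#   layer = LinearLayer_Compress(1024, 4096)
#   layer.enable_row_pruning(ratio=0.25, method='l1')
#   layer.enable_activation_quantization(bits=8, quantization_type='symmetric',
#                                        range_calibration='dynamic')
#   layer.row_pruning_enabled = True                 # flags are normally set by the scheduler
#   layer.activation_quantization_enabled = True
#   y = layer(torch.randn(2, 16, 1024))              # masked weight, per-token quantized input
#   layer.fix_row_col_pruning_helper(dim_reduction=True)   # physically drop the pruned rows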

class Conv2dLayer_Compress(nn.Conv2d):
    """
    Conv2D layer with compression.
    """

    def __init__(self, *kargs):
        super(Conv2dLayer_Compress, self).__init__(*kargs)
        self.sparse_pruning_method = None
        self.channel_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        # Initialized here so forward() can always read it, mirroring LinearLayer_Compress.
        self.weight_quantization_enabled = False
        self.sparse_pruning_enabled = False
        self.channel_pruning_enabled = False
        self.activation_quantization_enabled = False

    def __repr__(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        output = s.format(**self.__dict__)

        return output + ', sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
            self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
            self.activation_quantization_method is not None, self.weight.target_bits)

    def enable_sparse_pruning(self, ratio, method):
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('sparse_pruning_mask', mask)

    def enable_channel_pruning(self, ratio, method):
        self.channel_pruning_ratio = ratio
        self.channel_pruning_method = method

        if method == 'l1':
            # Compute an L1-norm importance score per output channel and build a static mask.
            weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=[1, 2, 3])
            mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False)
            mask = mask.view(-1, 1, 1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
            self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('channel_pruning_mask', mask)

    def fix_sparse_pruning_helper(self):
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_channel_pruning_helper(self, mask=None, dim_reduction=False):
        if mask is None:
            if self.channel_pruning_method in ['l1', 'topk']:
                mask = self.get_mask(pruning_type='channel').bool()
                if dim_reduction:
                    start_bits = self.weight.start_bits
                    target_bits = self.weight.target_bits
                    q_period = self.weight.q_period
                    self.weight = nn.Parameter(self.weight.data[mask.view(-1), ...])
                    self.weight.start_bits = start_bits
                    self.weight.target_bits = target_bits
                    self.weight.q_period = q_period
                    if self.bias is not None:
                        self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                else:
                    self.weight.data = self.weight.data * mask.view(-1, 1, 1, 1)
                    if self.bias is not None:
                        self.bias.data = self.bias.data * mask.view(-1)
                del self.channel_pruning_mask
                if self.channel_pruning_method == 'topk':
                    del self.channel_mask_scores
                self.channel_pruning_method = None
            else:
                raise NotImplementedError
        else:
            # Input-channel pruning: drop the channels that the previous layer pruned away.
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            self.weight = nn.Parameter(self.weight.data[:, mask.view(-1), ...])
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            mask = None
        self.channel_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='sparse'):
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            elif self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'channel':
            if self.channel_pruning_method == 'l1':
                return self.channel_pruning_mask.to(self.weight.device)
            elif self.channel_pruning_method == 'topk':
                return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            assert self.weight.target_bits >= 4, 'Only >=4 bits weight quantization are supported during forward pass for now'
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider using the DS-FP16 optimizer ************"
            )
            if quantization_type == 'symmetric':
                self.weight_quantizer = SymQuantizer.apply
            else:
                self.weight_quantizer = AsymQuantizer.apply
            self.weight_quantize_num_groups = num_groups

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
        if range_calibration == 'static':
            self.activation_quantizer = QuantAct(quant_mode=quantization_type)
        else:
            if quantization_type == 'symmetric':
                self.activation_quantizer = SymQuantizer.apply
            else:
                self.activation_quantizer = AsymQuantizer.apply

    def forward(self, input):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
            bias = self.bias
        else:
            weight = self.weight
            bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.channel_pruning_enabled:
            mask = self.get_mask(pruning_type='channel')
            weight = weight * mask.view(-1, 1, 1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                num_groups = input.numel() // input[0].numel()
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        return nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
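
    # Note on the dynamic branch in forward() above: for a conv input of shape (N, C, H, W),
    # input.numel() // input[0].numel() == N, i.e. one quantization group per sample, whereas the
    # linear layer uses one group per token (numel() // size(-1)). The static branch passes a
    # single group because QuantAct carries its own calibrated global range.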

class BNLayer_Compress(nn.BatchNorm2d):

    def fix_channel_pruning_helper(self, mask, dim_reduction=True):
        self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
        self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
        self.running_mean = self.running_mean[mask.view(-1)]
        self.running_var = self.running_var[mask.view(-1)]
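
# Illustrative sketch (names made up): when a conv layer's output channels are physically removed,
# the same mask has to be propagated to the following BatchNorm and to the input channels of the
# next conv so that all shapes stay consistent.
#
#   conv1 = Conv2dLayer_Compress(64, 128, 3)
#   bn1 = BNLayer_Compress(128)
#   conv2 = Conv2dLayer_Compress(128, 256, 3)
#   conv1.enable_channel_pruning(ratio=0.5, method='l1')
#   mask = conv1.fix_channel_pruning_helper(dim_reduction=True)   # drop output channels, get mask
#   bn1.fix_channel_pruning_helper(mask)                          # shrink BN statistics to match
#   conv2.fix_channel_pruning_helper(mask)                        # drop the matching input channels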

def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    dist.all_reduce(input_, group=group)

    return input_


def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: if True, make each chunk contiguous in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    assert tensor.size()[last_dim] % num_partitions == 0
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list
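
# Example (illustrative): splitting a (4, 6) tensor into 3 partitions along the last dimension
# yields three (4, 2) views.
#
#   chunks = split_tensor_along_last_dim(torch.arange(24).reshape(4, 6), 3)
#   [c.shape for c in chunks]   # [torch.Size([4, 2]), torch.Size([4, 2]), torch.Size([4, 2])]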

def _split(input_):
    """Split the tensor along its last dimension and keep the corresponding slice."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # Split along the last dimension.
    world_size = dist.get_world_size(group=group)
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Keep only the slice that belongs to this rank.
    rank = dist.get_rank(group=group)
    output = input_list[rank].contiguous()

    return output


def _gather(input_):
    """Gather tensors and concatenate along the last dimension."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    dist.all_gather(tensor_list, input_, group=group)

    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output

class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to this rank."""

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)


# -----------------
# Helper functions.
# -----------------


def copy_to_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)
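
# These wrappers pair up as conjugate operations in the autograd graph: copy/reduce are identity
# in one direction and all-reduce in the other, while scatter/gather split and re-concatenate
# along the last dimension. Illustrative sketch (assumes a model parallel group has already been
# registered via g_mpu, as the parallel linear classes below do, and that `w_local` is this
# rank's weight shard):
#
#   x = copy_to_model_parallel_region(x)              # forward: identity, backward: all-reduce
#   y_local = x @ w_local                             # each rank computes its partition
#   y = gather_from_model_parallel_region(y_local)    # forward: all-gather, backward: split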

class ColumnParallelLinear_Compress(LinearLayer_Compress):

    def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
        # Keep the input parameters and remember the mpu for the parallel-region helpers above.
        global g_mpu
        g_mpu = mpu
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        # Divide the weight matrix along the output dimension.
        world_size = mpu.get_model_parallel_world_size()
        assert output_size % world_size == 0
        self.output_size_per_partition = output_size // world_size

        super(ColumnParallelLinear_Compress, self).__init__(self.input_size, self.output_size_per_partition, bias=bias)

    def forward(self, input_):
        # Set up backprop all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply on the local partition.
        if self.skip_bias_add:
            output_parallel, bias = super().forward(input_parallel, True)
        else:
            output_parallel = super().forward(input_parallel)
            bias = None
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output, bias

class RowParallelLinear_Compress(LinearLayer_Compress):

    def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
        # Keep the input parameters and remember the mpu for the parallel-region helpers above.
        global g_mpu
        g_mpu = mpu
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.skip_bias_add = skip_bias_add

        # Divide the weight matrix along the input dimension.
        world_size = mpu.get_model_parallel_world_size()
        assert input_size % world_size == 0
        self.input_size_per_partition = input_size // world_size

        super(RowParallelLinear_Compress, self).__init__(self.input_size_per_partition, self.output_size, bias=bias)

    def forward(self, input_):
        # Scatter the input if it is not already partitioned.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)
        # Matrix multiply on the local partition, always deferring the bias add.
        output_parallel, bias = super().forward(input_parallel, True)

        # All-reduce across all the partitions, then add the bias if requested.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            if bias is not None:
                output = output_ + bias
            else:
                output = output_
            output_bias = None
        else:
            output = output_
            output_bias = bias
        return output, output_bias
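
# Illustrative sketch (names made up): the usual pairing for tensor parallelism is a
# column-parallel layer feeding a row-parallel layer, so the intermediate activation stays
# sharded and only one all-reduce is needed at the end.
#
#   fc1 = ColumnParallelLinear_Compress(mpu, 1024, 4096, gather_output=False)
#   fc2 = RowParallelLinear_Compress(mpu, 4096, 1024, input_is_parallel=True)
#   h, _ = fc1(x)          # local shard of the hidden activation
#   y, _ = fc2(h)          # all-reduced output with the bias already added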