# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import abc
from abc import ABC
import gc

from deepspeed.ops.op_builder import FPQuantizerBuilder
from deepspeed.accelerator import get_accelerator

fp_quant_module = None
class Quantizer(ABC):
    """
    Abstract Quantizer class that implements quantize/dequantize methods.

    Arguments:
        group_size (int, optional): number of values or elements that are grouped
            together for the quantization process. For example, with the default
            group_size of 512, a tensor of 4096 elements is quantized as 8
            independent groups, each with its own scale.
    """

    def __init__(self, group_size=512) -> None:
        self.group_size = group_size

    @abc.abstractmethod
    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        ...

    @abc.abstractmethod
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        ...
class FP_Quantize(Quantizer):
    """
    Floating-point quantizer backed by the compiled FPQuantizerBuilder kernel.
    Values are quantized in groups of ``group_size`` elements, with one scale
    stored per group.
    """

    def __init__(self, quantization_config) -> None:
        global fp_quant_module
        super().__init__(group_size=quantization_config.group_size)
        if fp_quant_module is None:
            # Build/load the fused quantization kernels lazily, once per process.
            fp_quant_module = FPQuantizerBuilder().load()
        self.cuda_impl = getattr(fp_quant_module, "CUDA_IMPL", True)
        self.q_config = quantization_config

        self.orig_dtype = None
        self.num_groups = None
        self.input_q = None
        self.scale = None
    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        assert input.dtype == torch.bfloat16, "only support bf16 for now"
        if return_meta_tensor:
            assert q_bits == 8, "meta tensor is only supported with q_bit=8"

        self.orig_dtype = input.dtype
        self.orig_shape = input.shape

        # Each supported bit-width maps to a fixed mantissa width; the remaining
        # bits (minus the sign bit) are used for the exponent.
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"

        # Packed layout: per group, (group_size * q_bits / 8) bytes of quantized
        # data followed by 4 bytes holding the group's scale.
        self.num_groups = input.numel() // self.group_size
        self.input_q = torch.ones(self.num_groups,
                                  int(self.group_size * q_bits) // 8 + 4,
                                  dtype=torch.uint8,
                                  device=input.device)
        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
        if return_meta_tensor:
            # Split the packed buffer into the quantized data and the per-group
            # scales so the caller can keep them as separate tensors.
            data, self.scale = out.split(self.group_size, dim=-1)
            data = data.contiguous().reshape(input.shape)
            self.scale = self.scale.contiguous()
            del self.input_q
            del out
            gc.collect()
            get_accelerator().empty_cache()
            return data, self.scale

        return out
    def to(self, *args, **kwargs):
        # Intermediate tensors may need to be moved to different devices.
        # Guard against None: these attributes are only populated by quantize().
        if getattr(self, 'input_q', None) is not None:
            self.input_q = self.input_q.to(*args, **kwargs)
        if getattr(self, 'scale', None) is not None:
            self.scale = self.scale.to(*args, **kwargs)

    def get_scales(self):
        return fp_quant_module.get_scales(self.scale, self.num_groups)
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            # Re-attach the externally held scales to the quantized data so the
            # kernel sees the same packed layout produced by quantize().
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out
    def selective_dequantize(self,
                             input_q,
                             indexes,
                             fp_out=None,
                             q_bits=8,
                             q_mantisa_bits=3,
                             scale=None) -> torch.Tensor:
        assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \
            "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function."
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        # Only the rows selected by `indexes` (along the first dimension of the
        # original 3D shape) are dequantized.
        fp_out = torch.empty(
            (indexes.shape[0],
             *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            # Same re-packing as in dequantize(): append the per-group scales to
            # the quantized data before calling the kernel.
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                             q_bits - q_mantisa_bits - 1)
        return fp_out
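

# --------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream module). It assumes
# the FPQuantizerBuilder op can be compiled for the current accelerator, that
# bf16 tensors are supported there, and that a SimpleNamespace exposing
# `group_size` is an acceptable stand-in for the real quantization config
# object used by DeepSpeed.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(group_size=512)  # hypothetical minimal config
    quantizer = FP_Quantize(cfg)

    weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device=get_accelerator().device_name())

    # FP8 round trip: quantize to the packed uint8 layout, then reconstruct.
    packed = quantizer.quantize(weight, q_bits=8, q_mantisa_bits=3)
    restored = quantizer.dequantize(packed, q_bits=8, q_mantisa_bits=3)

    # The reconstruction matches the original only up to quantization error.
    print((restored - weight).abs().max())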