# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import abc
from abc import ABC
import gc

from deepspeed.ops.op_builder import FPQuantizerBuilder
from deepspeed.accelerator import get_accelerator

fp_quant_module = None


class Quantizer(ABC):
    """
    Abstract Quantizer class that implements quantize/dequantize methods.

    Arguments:
        group_size (int, optional): number of values or elements that are grouped
            together for the quantization process.
    """

    def __init__(self, group_size=512) -> None:
        self.group_size = group_size

    @abc.abstractmethod
    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        ...

    @abc.abstractmethod
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        ...


class FP_Quantize(Quantizer):

    def __init__(self, quantization_config) -> None:
        global fp_quant_module
        super().__init__(group_size=quantization_config.group_size)
        if fp_quant_module is None:
            fp_quant_module = FPQuantizerBuilder().load()
        self.cuda_impl = getattr(fp_quant_module, "CUDA_IMPL", True)
        self.q_config = quantization_config

        self.orig_dtype = None
        self.num_groups = None
        self.input_q = None
        self.scale = None

    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        assert input.dtype == torch.bfloat16, "only support bf16 for now"
        if return_meta_tensor:
            assert q_bits == 8, "meta tensor is only supported with q_bit=8"

        self.orig_dtype = input.dtype
        self.orig_shape = input.shape

        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"

        self.num_groups = input.numel() // self.group_size
        self.input_q = torch.ones(self.num_groups,
                                  int(self.group_size * q_bits) // 8 + 4,
                                  dtype=torch.uint8,
                                  device=input.device)
        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
        if return_meta_tensor:
            data, self.scale = out.split(self.group_size, dim=-1)
            data = data.contiguous().reshape(input.shape)
            self.scale = self.scale.contiguous()
            del self.input_q
            del out
            gc.collect()
            get_accelerator().empty_cache()
            return data, self.scale

        return out

    def to(self, *args, **kwargs):
        # Intermediate tensors may need to be moved to different devices
        if hasattr(self, 'input_q'):
            self.input_q = self.input_q.to(*args, **kwargs)
        if hasattr(self, 'scale'):
            self.scale = self.scale.to(*args, **kwargs)

    def get_scales(self):
        return fp_quant_module.get_scales(self.scale, self.num_groups)

    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out

    def selective_dequantize(self,
                             input_q,
                             indexes,
                             fp_out=None,
                             q_bits=8,
                             q_mantisa_bits=3,
                             scale=None) -> torch.Tensor:
        assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \
            "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function."
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty((indexes.shape[0], *self.orig_shape[1:]),
                             dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                             q_bits - q_mantisa_bits - 1)
        return fp_out
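

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library's public surface).
# It assumes a DeepSpeed-supported accelerator is available and that the
# FPQuantizer op builds/loads on this machine. The SimpleNamespace below is a
# stand-in for the real quantization config object; the only field this file
# reads from it is `group_size`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    device = get_accelerator().current_device_name()
    quantizer = FP_Quantize(SimpleNamespace(group_size=512))

    # A bf16 tensor whose element count is a multiple of group_size.
    weight = torch.randn(1024, 1024, dtype=torch.bfloat16, device=device)

    # Pack to 8-bit floating point; with return_meta_tensor=True the packed
    # payload and the per-group scales come back as separate tensors.
    data, scale = quantizer.quantize(weight, q_bits=8, return_meta_tensor=True)

    # Reconstruct bf16 values, passing the scales back so the kernel can
    # re-attach them to each quantization group.
    restored = quantizer.dequantize(data, q_bits=8, scale=scale)
    print(f"max abs error: {(restored - weight).abs().max().item():.6f}")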