import json
from typing import Any

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


@register_quantization_config("custom")
class CustomConfig(QuantizationConfigMixin):
    def __init__(self):
        self.quant_method = "custom"
        self.bits = 8

    def to_dict(self) -> dict[str, Any]:
        output = {
            "num_bits": self.bits,
        }
        return output

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> dict[str, Any]:
        config_dict = self.to_dict()

        # Compare against a freshly constructed config so that only
        # non-default values end up in the serialized output.
        default_config_dict = CustomConfig().to_dict()

        serializable_config_dict = {}

        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict
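
# Sanity check (illustrative, not required): a freshly constructed config holds
# only default values, so its diff dict is empty and nothing extra is serialized.
assert CustomConfig().to_diff_dict() == {}
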

@register_quantizer("custom")
class CustomQuantizer(HfQuantizer):
    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config
        self.scale_map = {}
        self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.torch_dtype = kwargs.get("torch_dtype", torch.float32)

    def _process_model_before_weight_loading(self, model, **kwargs):
        # A real backend would replace modules (e.g. nn.Linear) with quantized
        # counterparts here, before the checkpoint weights are loaded.
        return True

    def _process_model_after_weight_loading(self, model, **kwargs):
        # A real backend would finish quantization here, e.g. populate
        # self.scale_map with per-layer scales.
        return True

    def is_serializable(self) -> bool:
        return True

    def is_trainable(self) -> bool:
        return False
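
# Smoke test (illustrative; from_pretrained normally constructs the quantizer
# via the "custom" registry key, so direct construction is not required):
quantizer = CustomQuantizer(CustomConfig())
assert quantizer.quantization_config.bits == 8
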

model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto"
)
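
# Illustrative inspection: from_pretrained attaches the quantization config to
# the model config, so printing it exercises the __repr__ defined above.
print(model_8bit.config.quantization_config)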

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
input_text = "once there is"
inputs = tokenizer(input_text, return_tensors="pt")
output = model_8bit.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
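
# Because is_serializable() returns True, the quantized model can be saved
# along with its config; "opt-350m-custom-quant" is a hypothetical output path.
model_8bit.save_pretrained("opt-350m-custom-quant")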