|
|
|
""" |
|
Utils shared by different modes of quantization (eager/graph) |
|
""" |
|
import functools |
|
import warnings |
|
from collections import OrderedDict |
|
from inspect import getfullargspec, signature |
|
from typing import Any, Callable, Optional, Union |
|
|
|
import torch |
|
from torch.ao.quantization.quant_type import QuantType |
|
from torch.fx import Node |
|
from torch.nn.utils.parametrize import is_parametrized |
|
|
|
|
|
NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] |
|
NodePattern.__module__ = "torch.ao.quantization.utils" |
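
# Illustrative: a NodePattern is the fx.Node-level analogue of ``Pattern``
# below, e.g. (relu_node, conv_node) for a Conv -> ReLU sequence.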
|
|
|
|
|
|
|
|
|
|
|
|
|
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# Define separately to prevent circular imports.
# TODO(future PR): improve this.
QuantizerCls = Any
|
|
|
|
|
|
|
|
|
# Type for fusion patterns. In practice patterns can be more complicated than
# this union suggests, hence the trailing Any.
Pattern = Union[
    Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any
]
|
Pattern.__module__ = "torch.ao.quantization.utils" |
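
# Illustrative examples of fusion patterns (written in reverse order of
# execution, outermost op first):
#   torch.nn.Conv2d                                            # a single op
#   (torch.nn.ReLU, torch.nn.Conv2d)                           # Conv2d -> ReLU
#   (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))   # Conv2d -> BN -> ReLU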
|
|
|
|
|
|
|
class MatchAllNode: |
|
"""A node pattern that matches all nodes, used in defining |
|
fusion patterns in FX Graph Mode Quantization |
|
""" |
|
|
|
|
|
module_type_list = { |
|
torch.nn.ReLU, |
|
torch.nn.ReLU6, |
|
torch.nn.AdaptiveAvgPool1d, |
|
torch.nn.AdaptiveAvgPool2d, |
|
torch.nn.AdaptiveAvgPool3d, |
|
torch.nn.AvgPool1d, |
|
torch.nn.AvgPool2d, |
|
torch.nn.AvgPool3d, |
|
torch.nn.MaxPool1d, |
|
torch.nn.MaxPool2d, |
|
torch.nn.MaxPool3d, |
|
torch.nn.Identity, |
|
torch.nn.Hardsigmoid, |
|
torch.nn.Sigmoid, |
|
torch.nn.Tanh, |
|
} |
|
func_list = { |
|
torch.nn.functional.adaptive_avg_pool1d, |
|
torch.nn.functional.adaptive_avg_pool2d, |
|
torch.nn.functional.adaptive_avg_pool3d, |
|
torch.nn.functional.elu, |
|
torch.nn.functional.hardswish, |
|
torch.nn.functional.instance_norm, |
|
torch.nn.functional.layer_norm, |
|
torch.nn.functional.leaky_relu, |
|
torch.nn.functional.silu, |
|
torch.nn.functional.mish, |
|
torch.nn.functional.dropout, |
|
torch.nn.functional.max_pool1d, |
|
torch.nn.functional.max_pool2d, |
|
torch.nn.functional.max_pool3d, |
|
torch.nn.functional.relu, |
|
torch.nn.functional.hardtanh, |
|
torch.nn.functional.hardtanh_, |
|
torch.nn.functional.hardsigmoid, |
|
torch.nn.functional.sigmoid, |
|
torch.transpose, |
|
torch.repeat_interleave, |
|
torch.sigmoid, |
|
torch.squeeze, |
|
torch.stack, |
|
torch.sum, |
|
torch.tanh, |
|
torch.unsqueeze, |
|
torch.cat, |
|
} |
|
method_list = { |
|
torch.mean, |
|
"relu", |
|
"relu_", |
|
"contiguous", |
|
"detach", |
|
"detach_", |
|
"hardsigmoid", |
|
"hardsigmoid_", |
|
"permute", |
|
"repeat", |
|
"repeat_interleave", |
|
"reshape", |
|
"resize_", |
|
"shape", |
|
"sigmoid", |
|
"sigmoid_", |
|
"size", |
|
"squeeze", |
|
"squeeze_", |
|
"tanh", |
|
"tanh_", |
|
"transpose", |
|
"unsqueeze", |
|
"unsqueeze_", |
|
"view", |
|
} |
|
|
|
|
|
|
|
def check_node(node, modules):
    """Check whether an fx Node is a call to a known function, method, or
    module, i.e. whether its target appears in ``func_list``,
    ``method_list``, or ``module_type_list``. ``modules`` is a mapping from
    module path (``node.target``) to the module instance.

    Returns a tuple (is_call_function, is_call_method, is_call_module).
    """
    is_call_function = node.op == "call_function" and node.target in func_list
    is_call_method = node.op == "call_method" and node.target in method_list
|
is_call_module = ( |
|
node.op == "call_module" and type(modules[str(node.target)]) in module_type_list |
|
) |
|
return is_call_function, is_call_method, is_call_module |
|
|
|
|
|
def get_combined_dict(default_dict, additional_dict): |
|
""" |
|
Combines two dictionaries. |
|
|
|
This function takes two dictionaries as input and returns a new dictionary |
|
that contains all the key-value pairs from both input dictionaries. |
|
If there are any duplicate keys in the `additional_dict`, the values |
|
from the `additional_dict` will overwrite those in the `default_dict`. |
|
Args: |
|
default_dict (dict): The main dictionary that will be used as the base |
|
additional_dict (dict): The dictionary used to update `default_dict` |
|
|
|
Returns: |
|
dict: The resulting dictionary |
|
Example: |
|
>>> x = dict(a=1, b=1) |
|
>>> y = dict(b=2, c=3) |
|
>>> get_combined_dict(x, y) |
|
{'a': 1, 'b': 2, 'c': 3} |
|
""" |
|
d = default_dict.copy() |
|
d.update(additional_dict) |
|
return d |
|
|
|
|
|
def is_per_tensor(qscheme): |
|
return qscheme == torch.per_tensor_affine or qscheme == torch.per_tensor_symmetric |
|
|
|
|
|
def is_per_channel(qscheme): |
|
return qscheme in [ |
|
torch.per_channel_affine, |
|
torch.per_channel_affine_float_qparams, |
|
torch.per_channel_symmetric, |
|
] |
|
|
|
|
|
def getattr_from_fqn(obj: Any, fqn: str) -> Any: |
|
""" |
|
    Given an obj and a fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
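
    Example (illustrative)::

        >> m = torch.nn.Sequential(torch.nn.Linear(2, 2))
        >> getattr_from_fqn(m, "0.weight")  # equivalent to m[0].weight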
|
""" |
|
return functools.reduce(getattr, fqn.split("."), obj) |
|
|
|
|
|
def to_underlying_dtype(qdtype):
    """Map a quantized dtype to its underlying storage dtype, e.g.
    torch.quint8 -> torch.uint8; non-quantized dtypes map to themselves."""
|
DTYPE_MAPPING = { |
|
torch.quint8: torch.uint8, |
|
torch.qint8: torch.int8, |
|
torch.qint32: torch.int32, |
|
torch.quint4x2: torch.uint8, |
|
torch.quint2x4: torch.uint8, |
|
torch.uint8: torch.uint8, |
|
torch.int8: torch.int8, |
|
torch.uint16: torch.uint16, |
|
torch.int16: torch.int16, |
|
torch.int32: torch.int32, |
|
torch.float8_e5m2: torch.float8_e5m2, |
|
torch.float8_e4m3fn: torch.float8_e4m3fn, |
|
} |
|
assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype) |
|
return DTYPE_MAPPING[qdtype] |
|
|
|
|
|
def get_qparam_dict(observer_or_fake_quant):
    """Return a dict of quantization parameters (qscheme, dtype and, where
    applicable, scale, zero_point, axis, quant_min, quant_max) computed from
    an observer or fake-quantize module."""
    from torch.ao.quantization.observer import PlaceholderObserver
|
|
|
qscheme = getattr(observer_or_fake_quant, "qscheme", None) |
|
dtype = observer_or_fake_quant.dtype |
|
qparams = {"qscheme": qscheme, "dtype": dtype} |
|
|
|
if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver): |
|
return {"qscheme": None, "dtype": dtype} |
|
|
|
if is_per_tensor(qscheme): |
|
qscheme = torch.per_tensor_affine |
|
    elif is_per_channel(qscheme):
        # scale and zero_point are reported in affine form, so remap
        # symmetric qschemes to their affine counterpart
        if qscheme == torch.per_channel_symmetric:
            qscheme = torch.per_channel_affine
        qparams["axis"] = observer_or_fake_quant.ch_axis
|
else: |
|
raise RuntimeError(f"Unrecognized qscheme: {qscheme}") |
|
|
|
|
|
qparams["qscheme"] = qscheme |
|
|
|
scale, zero_point = observer_or_fake_quant.calculate_qparams() |
|
qparams["scale"] = scale |
|
qparams["zero_point"] = zero_point |
|
|
|
if hasattr(observer_or_fake_quant, "quant_min"): |
|
qparams["quant_min"] = observer_or_fake_quant.quant_min |
|
if hasattr(observer_or_fake_quant, "quant_max"): |
|
qparams["quant_max"] = observer_or_fake_quant.quant_max |
|
|
|
return qparams |
|
|
|
|
|
def get_swapped_custom_module_class( |
|
custom_module, custom_module_class_mapping, qconfig |
|
): |
|
"""Get the observed/quantized custom module class that we need |
|
to swap `custom_module` to |
|
Input: |
|
custom_module: input, can be an instance of either a float or observed custom module |
|
custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping |
|
qconfig: qconfig configured for the custom module |
|
|
|
Output: |
|
corresponding observed/quantized custom module class for input custom module instance |
|
""" |
|
quant_type = get_quant_type(qconfig) |
|
class_mapping = custom_module_class_mapping.get(quant_type, {}) |
|
assert type(custom_module) in class_mapping, ( |
|
"did not find corresponding observed " |
|
f"module class for {type(custom_module)} in mapping: {class_mapping}" |
|
) |
|
return class_mapping[type(custom_module)] |
|
|
|
|
|
def activation_dtype(qconfig): |
|
assert qconfig is not None |
|
activation = qconfig.activation() |
|
return activation.dtype |
|
|
|
|
|
def weight_dtype(qconfig): |
|
assert qconfig is not None |
|
weight = qconfig.weight() |
|
return weight.dtype |
|
|
|
|
|
def activation_is_statically_quantized(qconfig): |
|
"""Given a qconfig, decide if the activation needs to be |
|
quantized or not, this includes quantizing to quint8, qint8 and qint32 and float16 |
|
""" |
|
return activation_dtype(qconfig) in [ |
|
torch.quint8, |
|
torch.qint8, |
|
torch.qint32, |
|
torch.float16, |
|
torch.uint8, |
|
torch.int8, |
|
torch.int16, |
|
torch.int32, |
|
torch.float8_e5m2, |
|
torch.float8_e4m3fn, |
|
] and (not activation_is_dynamically_quantized(qconfig)) |
|
|
|
|
|
def activation_is_dynamically_quantized(qconfig): |
|
"""Given a qconfig, decide if the activation needs to be |
|
dynamically quantized or not, this includes dynamically quantizing to |
|
quint8, qint8 and float16 |
|
""" |
|
_activation_dtype, _, activation_is_dynamic = get_qconfig_dtypes(qconfig) |
|
return activation_is_dynamic |
|
|
|
|
|
def activation_is_int8_quantized(qconfig): |
|
"""Given a qconfig, decide if the activation needs to be |
|
quantized to int8 or not, this includes quantizing to quint8, qint8 |
|
""" |
|
return activation_dtype(qconfig) in [ |
|
torch.quint8, |
|
torch.qint8, |
|
torch.uint8, |
|
torch.int8, |
|
] |
|
|
|
|
|
def activation_is_int32_quantized(qconfig): |
|
"""Given a qconfig, decide if the activation needs to be |
|
quantized to int32 or not |
|
""" |
|
return activation_dtype(qconfig) in [torch.qint32, torch.int32] |
|
|
|
|
|
def weight_is_quantized(qconfig): |
|
"""Given a qconfig, decide if the weight needs to be |
|
quantized or not |
|
""" |
|
return weight_dtype(qconfig) in [ |
|
torch.quint8, |
|
torch.qint8, |
|
torch.float16, |
|
torch.quint4x2, |
|
torch.uint8, |
|
torch.int8, |
|
torch.int16, |
|
torch.int32, |
|
torch.float8_e5m2, |
|
torch.float8_e4m3fn, |
|
] |
|
|
|
|
|
def weight_is_statically_quantized(qconfig): |
|
"""Given a qconfig, decide if the weight needs to be statically |
|
quantized or not |
|
""" |
|
return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8] |
|
|
|
|
|
def op_is_int8_dynamically_quantized(qconfig) -> bool: |
|
"""Given a qconfig, returns True if this op is using int8 dynamic |
|
quantization |
|
""" |
|
activation_dtype, weight_dtype, activation_is_dynamic = get_qconfig_dtypes(qconfig) |
|
    return (
        activation_dtype in [torch.quint8, torch.uint8]
        and weight_dtype in [torch.qint8, torch.int8]
        and activation_is_dynamic
    )
|
|
|
|
|
def get_qconfig_dtypes(qconfig): |
|
r"""returns the qconfig tuple for qconfig: |
|
(activation_dtype, weight_dtype, activation_is_dynamic) |
|
""" |
|
assert qconfig is not None |
|
activation = qconfig.activation() |
|
weight = qconfig.weight() |
|
act_is_dynamic = getattr(activation, "is_dynamic", False) |
|
return (activation.dtype, weight.dtype, act_is_dynamic) |
|
|
|
|
|
def get_quant_type(qconfig):
    """Map a qconfig to its QuantType (DYNAMIC, STATIC, or WEIGHT_ONLY)
    based on the activation and weight dtypes and on whether the activation
    is dynamically quantized."""
    assert qconfig is not None
|
activation = qconfig.activation() |
|
weight = qconfig.weight() |
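    # Illustrative mapping (assuming default observer behavior):
    #   dynamic activation + qint8 weight    -> QuantType.DYNAMIC
    #   quint8 activation  + qint8 weight    -> QuantType.STATIC
    #   float32 activation + qint8 weight    -> QuantType.WEIGHT_ONLY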
|
static_dtypes = [ |
|
torch.quint8, |
|
torch.qint8, |
|
torch.quint4x2, |
|
torch.qint32, |
|
torch.uint8, |
|
torch.int8, |
|
torch.int16, |
|
torch.int32, |
|
torch.float8_e5m2, |
|
torch.float8_e4m3fn, |
|
] |
|
if weight.dtype in static_dtypes: |
|
if hasattr(activation, "is_dynamic") and activation.is_dynamic: |
|
return QuantType.DYNAMIC |
|
elif activation.dtype in static_dtypes: |
|
return QuantType.STATIC |
|
else: |
|
return QuantType.WEIGHT_ONLY |
|
|
|
if weight.dtype == torch.float16: |
|
if hasattr(activation, "is_dynamic") and activation.is_dynamic: |
|
return QuantType.DYNAMIC |
|
elif activation.dtype == torch.float16: |
|
return QuantType.STATIC |
|
|
|
    raise Exception(
        f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype}), "
        f"weight({weight.dtype})"
    )
|
|
|
|
|
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: |
|
"""Checks if the given minimum and maximum values are valid, meaning that |
|
they exist and the min value is less than the max value. |
|
""" |
|
if min_val.numel() == 0 or max_val.numel() == 0: |
|
warnings.warn( |
|
"must run observer before calling calculate_qparams. " |
|
+ "Returning default values." |
|
) |
|
return False |
|
|
|
if min_val.dim() == 0 or max_val.dim() == 0: |
|
if min_val == float("inf") and max_val == float("-inf"): |
|
warnings.warn( |
|
"must run observer before calling calculate_qparams. " |
|
+ "Returning default values." |
|
) |
|
|
|
return False |
|
|
|
        assert min_val <= max_val, (
            f"min {min_val} should be less than or equal to max {max_val}"
        )
    else:
        assert torch.all(
            min_val <= max_val
        ), f"min {min_val} should be less than or equal to max {max_val}"
|
|
|
return True |
|
|
|
|
|
def calculate_qmin_qmax( |
|
quant_min: int, |
|
quant_max: int, |
|
has_customized_qrange: bool, |
|
dtype: torch.dtype, |
|
reduce_range: bool, |
|
) -> tuple[int, int]: |
|
r"""Calculates actual qmin and qmax based on the quantization range, |
|
observer datatype and if range is reduced. |
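
    Example (illustrative)::

        >> calculate_qmin_qmax(0, 0, False, torch.quint8, reduce_range=False)
        (0, 255)
        >> calculate_qmin_qmax(0, 0, False, torch.quint8, reduce_range=True)
        (0, 127)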
|
""" |
|
|
|
    if has_customized_qrange:
        # Default to an 8-bit range (and a 32-bit range for (q)int32); any
        # user-provided quant_min/quant_max overrides these below.
        if dtype in [torch.qint32, torch.int32]:
            initial_quant_min, initial_quant_max = 0, 2**32 - 1
        else:
            initial_quant_min, initial_quant_max = 0, 255
|
|
|
|
|
custom_quant_min, custom_quant_max = quant_min, quant_max |
|
if custom_quant_min is not None and custom_quant_max is not None: |
|
initial_quant_min, initial_quant_max = ( |
|
custom_quant_min, |
|
custom_quant_max, |
|
) |
|
|
|
qrange_len = initial_quant_max - initial_quant_min + 1 |
|
if dtype in [torch.qint8, torch.int8]: |
|
assert ( |
|
0 < qrange_len <= 256 |
|
), "quantization range should be positive and not exceed the maximum bit range (=256)." |
|
elif dtype in [torch.qint32, torch.int32]: |
|
assert ( |
|
0 < qrange_len <= 2**32 |
|
), "quantization range should be positive and not exceed the maximum bit range (=4294967296)." |
|
if reduce_range: |
|
quant_min, quant_max = quant_min // 2, quant_max // 2 |
|
    else:
        # Fall back onto the default qmin and qmax for the observer dtype if a
        # customized range is not used.
        if dtype in [torch.qint8, torch.int8]:
|
if reduce_range: |
|
quant_min, quant_max = -64, 63 |
|
else: |
|
quant_min, quant_max = -128, 127 |
|
elif dtype in [torch.quint8, torch.uint8]: |
|
if reduce_range: |
|
quant_min, quant_max = 0, 127 |
|
else: |
|
quant_min, quant_max = 0, 255 |
|
elif dtype in [torch.qint32, torch.int32]: |
|
quant_min, quant_max = -1 * (2**31), (2**31) - 1 |
|
elif dtype in [torch.uint16]: |
|
quant_min, quant_max = 0, 2**16 - 1 |
|
elif dtype in [torch.int16]: |
|
quant_min, quant_max = -(2**15), 2**15 - 1 |
|
else: |
|
quant_min, quant_max = 0, 15 |
|
return quant_min, quant_max |
|
|
|
|
|
def _parent_name(target): |
|
""" |
|
Turn 'foo.bar' into ['foo', 'bar'] |
|
""" |
|
r = target.rsplit(".", 1) |
|
if len(r) == 1: |
|
return "", r[0] |
|
else: |
|
return r[0], r[1] |
|
|
|
|
|
def has_no_children_ignoring_parametrizations(module): |
|
""" |
|
    Checks if module._modules is empty or, if the module is parametrized,
    that module._modules only has the 'parametrizations' module
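
    Example (illustrative)::

        >> has_no_children_ignoring_parametrizations(torch.nn.Linear(2, 2))
        True
        >> has_no_children_ignoring_parametrizations(
               torch.nn.Sequential(torch.nn.Linear(2, 2)))
        False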
|
""" |
|
if len(module._modules) == 0: |
|
return True |
|
elif is_parametrized(module): |
|
return len(module._modules) == 1 and "parametrizations" in module._modules |
|
else: |
|
return False |
|
|
|
|
|
def _get_path_of_module( |
|
root: torch.nn.Module, submodule: torch.nn.Module |
|
) -> Optional[str]: |
|
"""Get the path (fully qualified name) of a submodule |
|
|
|
Example:: |
|
|
|
        >> class M(torch.nn.Module):
               def __init__(self) -> None:
                   super().__init__()
                   self.linear = torch.nn.Linear(5, 5)

               def forward(self, x):
                   return self.linear(x)
|
|
|
>> m = M() |
|
>> l = m.linear |
|
>> _get_path_of_module(m, l) |
|
"linear" |
|
""" |
|
for n, p in root.named_modules(): |
|
if submodule is p: |
|
return n |
|
return None |
|
|
|
|
|
def _get_signature_locals(f: Callable, loc: dict[str, Any]) -> dict[str, Any]: |
|
"""Get local keyword arguments |
|
|
|
Example:: |
|
|
|
>> def f(self, a, b=9): |
|
pass |
|
>> loc = {"a": 6, "c": 7} |
|
>> _get_signature_locals(f, loc) |
|
{"a": 6} |
|
""" |
|
return {k: v for k, v in loc.items() if k in signature(f).parameters} |
|
|
|
|
|
def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]": |
|
"""Get all default keyword arguments from function signature |
|
|
|
Example:: |
|
|
|
>> def f(self, a, b=9): |
|
pass |
|
>> _get_default_kwargs(f) |
|
{"b": 9} |
|
""" |
|
kwargs = {} |
|
for name, param in signature(f).parameters.items(): |
|
if param.default is not param.empty: |
|
kwargs[name] = param.default |
|
elif param.kind is param.VAR_POSITIONAL: |
|
kwargs[name] = () |
|
elif param.kind is param.VAR_KEYWORD: |
|
kwargs[name] = {} |
|
return OrderedDict(kwargs) |
|
|
|
|
|
def _normalize_kwargs(func: Callable, loc: dict[str, Any]) -> "OrderedDict[str, Any]": |
|
"""Given a function and local function arguments, normalize the keyword |
|
arguments by filling in default arguments from function signature |
|
|
|
Example:: |
|
|
|
>> def f(self, key1=3, key2=3): |
|
pass |
|
>> loc = {"key2": 6} |
|
>> _normalize_kwargs(f, loc) |
|
{"key1": 3, "key2": 6} |
|
""" |
|
default_kwargs = _get_default_kwargs(func) |
|
local_kwargs = _get_signature_locals(func, loc) |
|
normalized_kwargs = default_kwargs.copy() |
|
    for attr, val in local_kwargs.items():
        if attr in normalized_kwargs:
            # override the default keyword arguments with the local ones
            normalized_kwargs[attr] = val
|
return normalized_kwargs |
|
|
|
|
|
def validate_qmin_qmax(quant_min: int, quant_max: int) -> None: |
|
r"""Validates that the user-specified quantization range is properly initialized |
|
and within the given bound supported by the observer dtype. |
|
|
|
To accommodate lower-bit quantization with respect to the existing torch.qint8 and |
|
torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing |
|
in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax |
|
values are used to calculate static estimates of the scale and zero point for aggressive lower-bit |
|
fake quantization. These estimates are compared against parameters learned through backpropagation. |
|
    The related literature on learning scale and zero point via backpropagation is as follows:
|
|
|
Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS |
|
Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf |
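
    Example (illustrative)::

        >> validate_qmin_qmax(0, 15)  # ok, a 4-bit style range containing 0
        >> validate_qmin_qmax(1, 15)  # AssertionError: range must include 0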
|
""" |
|
|
|
|
|
    assert (
        quant_min <= 0 <= quant_max
    ), "User-specified quantization range must include 0."
|
assert ( |
|
quant_min < quant_max |
|
), "qmin must be strictly less than qmax for user-specified quantization range." |
|
|
|
|
|
|
|
|
|
|
|
|
|
def determine_qparams( |
|
min_val: torch.Tensor, |
|
max_val: torch.Tensor, |
|
quant_min: int, |
|
quant_max: int, |
|
dtype: torch.dtype, |
|
eps: torch.Tensor, |
|
has_customized_qrange: bool, |
|
qscheme: torch.qscheme = torch.per_tensor_affine, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
r"""Calculates the quantization parameters, given min and max |
|
value tensors. Works for both per tensor and per channel cases |
|
|
|
Args: |
|
min_val: Minimum values per channel |
|
max_val: Maximum values per channel |
|
|
|
Returns: |
|
scales: Scales tensor of shape (#channels,) |
|
zero_points: Zero points tensor of shape (#channels,) |
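
    Example (illustrative)::

        >> scale, zp = determine_qparams(
               torch.tensor(-1.0), torch.tensor(1.0), 0, 255,
               torch.quint8, torch.tensor(1e-7), False)
        >> # per-tensor affine: scale is ~2/255 and zp is 128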
|
""" |
|
if not check_min_max_valid(min_val, max_val): |
|
return torch.tensor([1.0], device=min_val.device.type), torch.tensor( |
|
[0], device=min_val.device.type |
|
) |
|
|
|
min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) |
|
max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) |
|
|
|
device = min_val_neg.device |
|
scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device) |
|
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) |
|
eps = eps.to(device) |
|
|
|
if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric: |
|
max_val_pos = torch.max(-min_val_neg, max_val_pos) |
|
scale = max_val_pos / (float(quant_max - quant_min) / 2) |
|
scale = torch.max(scale, eps) |
|
        if dtype in [torch.uint8, torch.quint8]:
            if has_customized_qrange:
                # When a customized quantization range is used, the
                # down-rounded midpoint of the range is chosen as zero point.
                zero_point = zero_point.new_full(
                    zero_point.size(), (quant_min + quant_max) // 2
                )
            else:
                zero_point = zero_point.new_full(zero_point.size(), 128)
|
    elif qscheme == torch.per_channel_affine_float_qparams:
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.where(scale > eps, scale, torch.ones_like(scale))
        # We use the quantize function
        # xq = Round(Xf * inv_scale + zero_point),
        # setting zero_point to (-1 * min * inv_scale) we get
        # Xq = Round((Xf - min) * inv_scale)
        zero_point = -1 * min_val / scale
|
else: |
|
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) |
|
scale = torch.max(scale, eps) |
|
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) |
|
zero_point = torch.clamp(zero_point, quant_min, quant_max) |
|
|
|
|
|
|
|
    # For scalar values, cast them to Tensors of size 1 to keep the shape
    # consistent with default values in FakeQuantize.
    if len(scale.shape) == 0:
        # TODO: switch to scale.item() once JIT supports it
        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
    if len(zero_point.shape) == 0:
        # TODO: switch to zero_point.item() once JIT supports it
        zero_point = torch.tensor(
            [int(zero_point)], dtype=zero_point.dtype, device=device
        )
        if qscheme == torch.per_channel_affine_float_qparams:
            zero_point = torch.tensor(
                [float(zero_point)], dtype=zero_point.dtype, device=device
            )
|
|
|
return scale.to(torch.double), zero_point.to(torch.int64) |
|
|
|
|
|
def _get_num_pos_args(f: Callable) -> int: |
|
"""Get number of positional args for a function |
|
|
|
Example:: |
|
|
|
>> def f(self, key1=3, key2=3): |
|
pass |
|
>> _get_num_pos_args(f) |
|
3 |
|
""" |
|
return len(getfullargspec(f).args) |
|
|
|
|
|
def get_fqn_to_example_inputs( |
|
model: torch.nn.Module, example_inputs: tuple[Any, ...] |
|
) -> dict[str, tuple[Any, ...]]: |
|
"""Given a model and its example inputs, return a dictionary from |
|
fully qualified name of submodules to example_inputs for that submodule, |
|
e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,), |
|
"sub.linear1": (tensor4,), ...} |
|
|
|
Used to make quantizing submodules easier now that FX Graph Mode Quantization requires |
|
example inputs. |
|
|
|
    This also works for keyword arguments with default values: we flatten
    keyword arguments into positional arguments and fill in the missing
    keyword args with default values. E.g. if we have a forward function:
|
def forward(self, x, key1=3, key2=3): |
|
... |
|
|
|
and we call it with self.submodule(x, key2=6) |
|
we'll get example_inputs: (x, 3, 6) |
|
|
|
    The user can also override `key1` with a positional argument:
|
for self.submodule(x, 5, key2=6) |
|
we'll get: (x, 5, 6) |
|
|
|
    Variable positional arguments and variable keyword arguments in the
    forward function are not currently supported, so please make sure no
    submodule is using them.
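
    Example (illustrative)::

        >> m = M().eval()  # M is any user-defined module
        >> fqn_to_example_inputs = get_fqn_to_example_inputs(m, (torch.randn(1, 5),))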
|
""" |
|
root = model |
|
fqn_to_example_inputs = {} |
|
|
|
def _patched_module_call(self, *args, **kwargs): |
|
submodule_example_inputs = list(args).copy() |
|
normalized_kwargs = _normalize_kwargs(self.forward, kwargs) |
|
|
|
        # minus 1 to skip the `self` argument
        num_args = _get_num_pos_args(self.forward) - 1
        num_to_pop = num_args - len(submodule_example_inputs)
        while num_to_pop and normalized_kwargs:
|
normalized_kwargs.popitem(last=False) |
|
num_to_pop -= 1 |
|
submodule_example_inputs.extend(normalized_kwargs.values()) |
|
submodule_example_inputs_tuple = tuple(submodule_example_inputs) |
|
fqn = _get_path_of_module(root, self) |
|
if fqn is not None: |
|
fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple |
|
return orig_module_call(self, *args, **kwargs) |
|
|
|
orig_module_call = torch.nn.Module.__call__ |
|
torch.nn.Module.__call__ = _patched_module_call |
|
try: |
|
model(*example_inputs) |
|
    finally:
        # restore the original __call__ even if the forward pass raises
        torch.nn.Module.__call__ = orig_module_call
|
return fqn_to_example_inputs |
|
|
|
|
|
def _assert_and_get_unique_device(module: torch.nn.Module) -> Any: |
|
""" |
|
Returns the unique device for a module, or None if no device is found. |
|
    Raises an error if multiple devices are detected.
|
""" |
|
devices = {p.device for p in module.parameters()} | { |
|
p.device for p in module.buffers() |
|
} |
|
""" |
|
As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564 |
|
""" |
|
if {torch.device("cpu"), torch.device("meta")} == devices: |
|
warnings.warn( |
|
"Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'." |
|
) |
|
devices = {torch.device("cpu")} |
|
"" |
|
assert len(devices) <= 1, ( |
|
"prepare only works with cpu or single-device CUDA modules, " |
|
f"but got devices {devices}" |
|
) |
|
device = next(iter(devices)) if len(devices) > 0 else None |
|
return device |
|
|
|
|
|
__all__ = [ |
|
"NodePattern", |
|
"Pattern", |
|
"MatchAllNode", |
|
"check_node", |
|
"get_combined_dict", |
|
"is_per_tensor", |
|
"is_per_channel", |
|
"getattr_from_fqn", |
|
"get_qparam_dict", |
|
"get_swapped_custom_module_class", |
|
"activation_dtype", |
|
"weight_dtype", |
|
"activation_is_statically_quantized", |
|
"activation_is_dynamically_quantized", |
|
"activation_is_int8_quantized", |
|
"activation_is_int32_quantized", |
|
"weight_is_quantized", |
|
"weight_is_statically_quantized", |
|
"op_is_int8_dynamically_quantized", |
|
"get_qconfig_dtypes", |
|
"get_quant_type", |
|
"check_min_max_valid", |
|
"calculate_qmin_qmax", |
|
"has_no_children_ignoring_parametrizations", |
|
"get_fqn_to_example_inputs", |
|
"to_underlying_dtype", |
|
"determine_qparams", |
|
"validate_qmin_qmax", |
|
] |
|
|