import copy
import warnings
from collections import namedtuple
from typing import Any, Optional, Union

from typing_extensions import deprecated

import torch
import torch.nn as nn
from torch.ao.quantization.fake_quantize import (
    default_dynamic_fake_quant,
    default_embedding_fake_quant,
    default_embedding_fake_quant_4bit,
    default_fake_quant,
    default_fused_act_fake_quant,
    default_fused_per_channel_wt_fake_quant,
    default_fused_wt_fake_quant,
    default_per_channel_weight_fake_quant,
    default_weight_fake_quant,
    FakeQuantize,
    FakeQuantizeBase,
    fused_per_channel_wt_fake_quant_range_neg_127_to_127,
    fused_wt_fake_quant_range_neg_127_to_127,
    FusedMovingAvgObsFakeQuantize,
)

from .observer import (
    _PartialWrapper,
    default_debug_observer,
    default_dynamic_quant_observer,
    default_float_qparams_observer,
    default_float_qparams_observer_4bit,
    default_observer,
    default_per_channel_weight_observer,
    default_placeholder_observer,
    default_reuse_input_observer,
    default_weight_observer,
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    NoopObserver,
    ObserverBase,
    per_channel_weight_observer_range_neg_127_to_127,
    PlaceholderObserver,
    ReuseInputObserver,
    weight_observer_range_neg_127_to_127,
)


__all__ = [
    "QConfig",
    "QConfigDynamic",
    "default_qconfig",
    "default_debug_qconfig",
    "default_per_channel_qconfig",
    "default_dynamic_qconfig",
    "float16_dynamic_qconfig",
    "float16_static_qconfig",
    "per_channel_dynamic_qconfig",
    "float_qparams_weight_only_qconfig",
    "float_qparams_weight_only_qconfig_4bit",
    "default_quint8_weight_qconfig",
    "default_qat_qconfig",
    "default_dynamic_qat_qconfig",
    "default_weight_only_qconfig",
    "default_activation_only_qconfig",
    "default_qat_qconfig_v2",
    "default_reuse_input_qconfig",
    "default_symmetric_qnnpack_qconfig",
    "default_per_channel_symmetric_qnnpack_qconfig",
    "default_symmetric_qnnpack_qat_qconfig",
    "default_per_channel_symmetric_qnnpack_qat_qconfig",
    "default_embedding_qat_qconfig",
    "default_embedding_qat_qconfig_4bit",
    "get_default_qconfig",
    "get_default_qat_qconfig",
    "get_default_qconfig_dict",
    "get_default_qat_qconfig_dict",
    "QConfigAny",
    "qconfig_equals",
]


class QConfig(namedtuple("QConfig", ["activation", "weight"])):
    """
    Describes how to quantize a layer or a part of the network by providing
    settings (observer classes) for activations and weights respectively.

    Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
    instances on invocation, not the concrete observer instances themselves.
    The quantization preparation function will instantiate observers multiple times for each of the layers.

    Observer classes usually have reasonable default arguments, but they can be overwritten with the `with_args`
    method (which behaves like functools.partial)::

      my_qconfig = QConfig(
          activation=MinMaxObserver.with_args(dtype=torch.qint8),
          weight=default_observer.with_args(dtype=torch.qint8))

    """

    __slots__ = ()

    def __new__(cls, activation, weight):
        # catch the common mistake of passing observer instances instead of classes
        if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
            raise ValueError(
                "QConfig received observer instance, please pass observer class instead. "
                + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed"
            )
        return super().__new__(cls, activation, weight)
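

# Example (illustrative sketch, not executed at import): building a custom QConfig from
# observer classes via ``with_args`` and attaching it to a module for eager-mode prepare.
# The name ``my_model`` is hypothetical.
#
#   my_qconfig = QConfig(
#       activation=MovingAverageMinMaxObserver.with_args(dtype=torch.quint8),
#       weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
#   )
#   my_model.qconfig = my_qconfig
#   prepared = torch.ao.quantization.prepare(my_model)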


@deprecated(
    "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
    category=FutureWarning,
)
class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
    """
    Describes how to dynamically quantize a layer or a part of the network by providing
    settings (observer classes) for weights.

    It's like QConfig, but for dynamic quantization.

    Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
    instances on invocation, not the concrete observer instances themselves.
    The quantization function will instantiate observers multiple times for each of the layers.

    Observer classes usually have reasonable default arguments, but they can be overwritten with the `with_args`
    method (which behaves like functools.partial)::

      my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
    """

    __slots__ = ()

    def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
        # catch the common mistake of passing an observer instance instead of a class
        if isinstance(weight, nn.Module):
            raise ValueError(
                "QConfigDynamic received observer instance, please pass observer class instead. "
                + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed"
            )
        return super().__new__(cls, activation, weight)


default_qconfig = QConfig(activation=default_observer, weight=default_weight_observer)
"""
Default qconfig configuration.
"""

default_debug_qconfig = QConfig(
    weight=default_weight_observer, activation=default_debug_observer
)
"""
Default qconfig configuration for debugging.
"""

default_per_channel_qconfig = QConfig(
    activation=default_observer, weight=default_per_channel_weight_observer
)
"""
Default qconfig configuration for per channel weight quantization.
"""

default_dynamic_qconfig = QConfig(
    activation=default_dynamic_quant_observer, weight=default_weight_observer
)
"""
Default dynamic qconfig.
"""

float16_dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True),
    weight=PlaceholderObserver.with_args(dtype=torch.float16),
)
"""
Dynamic qconfig with weights quantized to `torch.float16`.
"""

float16_static_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float16),
    weight=PlaceholderObserver.with_args(dtype=torch.float16),
)
"""
Static qconfig with both activations and weights quantized to `torch.float16`.
"""

per_channel_dynamic_qconfig = QConfig(
    activation=default_dynamic_quant_observer,
    weight=default_per_channel_weight_observer,
)
"""
Dynamic qconfig with weights quantized per channel.
"""

float_qparams_weight_only_qconfig = QConfig(
    activation=default_placeholder_observer, weight=default_float_qparams_observer
)
"""
Dynamic qconfig with weights quantized with a floating point zero_point.
"""

float_qparams_weight_only_qconfig_4bit = QConfig(
    activation=default_placeholder_observer, weight=default_float_qparams_observer_4bit
)

default_qat_qconfig = QConfig(
    activation=default_fake_quant, weight=default_weight_fake_quant
)
"""
Default qconfig for QAT.
"""

default_dynamic_qat_qconfig = QConfig(
    activation=default_dynamic_fake_quant, weight=default_weight_fake_quant
)
"""
Default qconfig for dynamic QAT.
"""

default_weight_only_qconfig = QConfig(
    activation=torch.nn.Identity, weight=default_weight_fake_quant
)
"""
Default qconfig for quantizing weights only.
"""

default_activation_only_qconfig = QConfig(
    activation=default_fake_quant, weight=torch.nn.Identity
)
"""
Default qconfig for quantizing activations only.
"""

default_qat_qconfig_v2 = QConfig(
    activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant
)
"""
Fused version of `default_qat_qconfig`, has performance benefits.
"""

default_reuse_input_qconfig = QConfig(
    activation=default_reuse_input_observer, weight=NoopObserver
)
"""
Default qconfig for operators that reuse the observers from the input Tensor, e.g. reshape.
"""


def get_default_qconfig(backend="x86", version=0):
    """
    Returns the default PTQ qconfig for the specified backend.

    Args:
      * `backend` (str): a string representing the target backend. Currently supports
        `x86` (default), `fbgemm`, `qnnpack` and `onednn`.

    Return:
        qconfig
    """
    supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
    if backend not in supported_backends:
        raise AssertionError(
            "backend: "
            + str(backend)
            + f" not supported. backend must be one of {supported_backends}"
        )

    if version == 0:
        if backend == "fbgemm":
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=True),
                weight=default_per_channel_weight_observer,
            )
        elif backend == "qnnpack":
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )
        elif backend == "onednn":
            if not torch.cpu._is_vnni_supported():
                warnings.warn(
                    "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues "
                    "on CPU without Vector Neural Network Instruction support."
                )
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_per_channel_weight_observer,
            )
        elif backend == "x86":
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=True),
                weight=default_per_channel_weight_observer,
            )
        else:
            qconfig = default_qconfig
    else:
        raise AssertionError(
            "Version number: "
            + str(version)
            + " in get_default_qconfig is not supported. Version number must be 0"
        )

    return qconfig
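

# Example (illustrative sketch): fetching the backend default and using it for eager-mode
# PTQ. The model name ``my_model`` is hypothetical.
#
#   torch.backends.quantized.engine = "x86"
#   my_model.qconfig = get_default_qconfig("x86")
#   prepared = torch.ao.quantization.prepare(my_model.eval())
#   # ... run representative calibration batches through `prepared` ...
#   quantized = torch.ao.quantization.convert(prepared)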


"""
Default, symmetric PTQ qconfig for the specified backend. And a per_channel
variant of the same.

Symmetric here applies to signed weights with zero point = 0, and additional
value restrictions. The activations are also signed 8-bit integers with this
qconfig.

* Once this change is merged [as of 3/17/22], with backend or qengine =
  'qnnpack', some quantized operators with this symmetric qconfig may use
  operators from the xnnpack library.

  ** Support for using xnnpack ops with the `qnnpack` backend for asymmetric
     qconfig (returned by get_default_qconfig()) is not available yet.

* This qconfig uses signed activations and weights. Weights have added
  restrictions such as zero point is forced to be 0, making the weights
  symmetric, hence the name. And the 8-bit quantized values are
  restricted to [-127, +127], excluding -128.

* xnnpack has a requantization scale value restriction, 0x1p-32 <=
  requantization_scale < 256.0 where, `requantization_scale = (input_scale
  * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value
  of 256) is to prevent requantization_scale from going below the xnnpack
  lower threshold.
"""
default_symmetric_qnnpack_qconfig = QConfig(
    activation=HistogramObserver.with_args(
        dtype=torch.qint8, reduce_range=False, eps=2**-12
    ),
    weight=weight_observer_range_neg_127_to_127,
)

default_per_channel_symmetric_qnnpack_qconfig = QConfig(
    activation=HistogramObserver.with_args(
        dtype=torch.qint8, reduce_range=False, eps=2**-12
    ),
    weight=per_channel_weight_observer_range_neg_127_to_127,
)
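

# Worked example of the eps choice above (illustrative): with eps = 2**-12 both
# input_scale and kernel_scale are clamped to >= 2**-12, so with the assumed
# maximum output_scale of 256 the smallest possible requantization scale is
#   (2**-12 * 2**-12) / 256 == 2**-32 == 0x1p-32,
# which sits exactly at xnnpack's lower bound.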


default_embedding_qat_qconfig = QConfig(
    activation=NoopObserver.with_args(dtype=torch.float32),
    weight=default_embedding_fake_quant,
)

default_embedding_qat_qconfig_4bit = QConfig(
    activation=NoopObserver.with_args(dtype=torch.float32),
    weight=default_embedding_fake_quant_4bit,
)

default_quint8_weight_qconfig = QConfig(
    activation=HistogramObserver, weight=MinMaxObserver
)


def get_default_qat_qconfig(backend="x86", version=1):
    """
    Returns the default QAT qconfig for the specified backend.

    Args:
      * `backend` (str): a string representing the target backend. Currently supports
        `x86` (default), `fbgemm`, `qnnpack` and `onednn`.
      * `version`: version, for backwards compatibility. Can be `0` or `1`.

    Return:
        qconfig
    """
    supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
    if backend not in supported_backends:
        raise AssertionError(
            "backend: "
            + str(backend)
            + f" not supported. backend must be one of {supported_backends}"
        )

    if version == 0:
        # version 0 uses unfused FakeQuantize modules
        if backend == "fbgemm":
            qconfig = QConfig(
                activation=FakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=0,
                    quant_max=255,
                    reduce_range=True,
                ),
                weight=default_per_channel_weight_fake_quant,
            )
        elif backend == "qnnpack":
            qconfig = QConfig(
                activation=FakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=0,
                    quant_max=255,
                    reduce_range=False,
                ),
                weight=default_weight_fake_quant,
            )
        elif backend == "onednn":
            qconfig = QConfig(
                activation=FakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255
                ),
                weight=default_per_channel_weight_fake_quant,
            )
        elif backend == "x86":
            qconfig = QConfig(
                activation=FakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=0,
                    quant_max=255,
                    reduce_range=True,
                ),
                weight=default_per_channel_weight_fake_quant,
            )
        else:
            qconfig = default_qat_qconfig
    elif version == 1:
        # version 1 uses the fused observer + fake_quant modules
        if backend == "fbgemm":
            qconfig = QConfig(
                activation=FusedMovingAvgObsFakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=0,
                    quant_max=255,
                    reduce_range=True,
                ),
                weight=default_fused_per_channel_wt_fake_quant,
            )
        elif backend == "qnnpack":
            qconfig = QConfig(
                activation=FusedMovingAvgObsFakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=0,
                    quant_max=255,
                    reduce_range=False,
                ),
                weight=default_fused_wt_fake_quant,
            )
        elif backend == "onednn":
            qconfig = QConfig(
                activation=FusedMovingAvgObsFakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255
                ),
                weight=default_fused_per_channel_wt_fake_quant,
            )
        elif backend == "x86":
            qconfig = QConfig(
                activation=FusedMovingAvgObsFakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=0,
                    quant_max=255,
                    reduce_range=True,
                ),
                weight=default_fused_per_channel_wt_fake_quant,
            )
        else:
            qconfig = default_qat_qconfig_v2
    else:
        raise AssertionError(
            "Version number: "
            + str(version)
            + " in get_default_qat_qconfig is not supported. Version number must be 0 or 1"
        )

    return qconfig
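

# Example (illustrative sketch): preparing a model for QAT with the backend default.
# The model name ``my_model`` is hypothetical.
#
#   my_model.train()
#   my_model.qconfig = get_default_qat_qconfig("x86")
#   prepared = torch.ao.quantization.prepare_qat(my_model)
#   # ... fine-tune `prepared` ...
#   quantized = torch.ao.quantization.convert(prepared.eval())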


"""
Default symmetric QAT qconfig for qnnpack. And its per channel weight variant.
"""
default_symmetric_qnnpack_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        reduce_range=False,
        eps=2**-12,
    ),
    weight=fused_wt_fake_quant_range_neg_127_to_127,
)

default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        reduce_range=False,
        eps=2**-12,
    ),
    weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127,
)

_default_fp32_placeholder_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float32),
    weight=PlaceholderObserver.with_args(dtype=torch.float32),
)

_default_quint8_placeholder_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.quint8),
    # operators using this qconfig do not have weights
    weight=None,
)


@deprecated(
    "`torch.ao.quantization.get_default_qconfig_dict` is deprecated and will be removed in "
    "a future version. Please use `torch.ao.quantization.get_default_qconfig_mapping` instead.",
    category=FutureWarning,
)
def get_default_qconfig_dict(backend="x86", version=0):
    return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()


@deprecated(
    "`torch.ao.quantization.get_default_qat_qconfig_dict` is deprecated and will be removed in "
    "a future version. Please use `torch.ao.quantization.get_default_qat_qconfig_mapping` instead.",
    category=FutureWarning,
)
def get_default_qat_qconfig_dict(backend="x86", version=1):
    return torch.ao.quantization.get_default_qat_qconfig_mapping(
        backend, version
    ).to_dict()


def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> None:
    """
    Verifies that this `qconfig` is valid.
    """
    if qconfig is None:
        return
    is_conv_transpose_mod = isinstance(
        mod,
        (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d),
    )
    if is_conv_transpose_mod:
        if qconfig.weight is None:
            # a ConvTranspose qconfig without a weight observer is accepted as-is
            return
        example_observer = qconfig.weight()
        is_per_channel = isinstance(
            example_observer,
            (
                torch.ao.quantization.PerChannelMinMaxObserver,
                torch.ao.quantization.MovingAveragePerChannelMinMaxObserver,
            ),
        )
        assert (
            not is_per_channel
        ), "Per channel weight observer is not supported yet for ConvTranspose{n}d."


QConfigAny = Optional[QConfig]
QConfigAny.__module__ = "torch.ao.quantization.qconfig"


def _add_module_to_qconfig_obs_ctr(
    qconfig: QConfigAny, module: Optional[nn.Module]
) -> Any:
    r"""This is a helper function for use in quantization prepare that updates a qconfig so that
    the constructors stored in the qconfig will create observers on the same device that
    'module' is on. This is intended to be used when the qconfigs are propagated to each
    module in order to avoid potential device alignment issues.

    Args:
        qconfig: QConfig with obs constructors stored in activation and weight
        module: module which the qconfig is related to

    Return:
        qconfig: configured so that obs constructors are set to construct on the same device as module
    """
    if module is None or qconfig is None or qconfig._fields != ("activation", "weight"):
        return qconfig

    def get_factory_kwargs_based_on_module_device():
        assert isinstance(module, torch.nn.Module)
        devices = {p.device for p in module.parameters()} | {
            p.device for p in module.buffers()
        }
        device = next(iter(devices)) if len(devices) > 0 else None
        return None if device is None else {"device": device}

    def configure_constructor_to_put_obs_on_module_device(original_constructor):
        try:
            # check whether the constructor accepts a factory_kwargs argument
            check = original_constructor.with_args(factory_kwargs=None)
            check()
            return original_constructor.with_callable_args(
                factory_kwargs=get_factory_kwargs_based_on_module_device
            )
        except AttributeError:
            # the constructor has no with_args/with_callable_args (e.g. torch.nn.Identity)
            return original_constructor
        except TypeError:
            # the observer type does not accept a factory_kwargs argument
            return original_constructor

    activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation)
    weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight)

    return QConfig(activation, weight)
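

# Example (illustrative sketch): given a module that lives on CUDA, the returned
# qconfig constructs its observers on that same device.
#
#   linear = nn.Linear(4, 4).to("cuda")
#   qconfig = _add_module_to_qconfig_obs_ctr(default_qconfig, linear)
#   obs = qconfig.activation()  # observer buffers land on the module's device ("cuda")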


_ObserverOrFakeQuantizeConstructor = Union[
    _PartialWrapper, type[ObserverBase], type[FakeQuantizeBase]
]


def _obs_or_fq_ctr_equals(
    obs_or_fq1: _ObserverOrFakeQuantizeConstructor,
    obs_or_fq2: _ObserverOrFakeQuantizeConstructor,
):
    if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(
        obs_or_fq2, _PartialWrapper
    ):
        return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2)
    return obs_or_fq1 == obs_or_fq2


def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper):
    """
    Return whether the two partial wrappers are equal.
    """
    obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords)
    obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords)
    keywords_equal = True
    # compare the "observer" keyword with _obs_or_fq_ctr_equals, since it may itself
    # be a partial wrapper that would not compare equal with '=='
    if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords:
        keywords_equal = keywords_equal and _obs_or_fq_ctr_equals(
            obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"]
        )
        obs_or_fq1_keywords.pop("observer")
        obs_or_fq2_keywords.pop("observer")
    keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords
    return (
        obs_or_fq1.p.func == obs_or_fq2.p.func
        and obs_or_fq1.p.args == obs_or_fq2.p.args
        and keywords_equal
    )


def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
    """
    Returns `True` if `q1` equals `q2`, and `False` otherwise.
    """
    if q1 is None or q2 is None:
        return q1 == q2
    else:
        assert q1 is not None and q2 is not None
        try:
            # activation and weight can each be an observer/fake-quant class or a
            # _PartialWrapper produced by with_args; compare them structurally
            activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation)
            weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight)
            return activation_same and weight_same
        except AttributeError:
            return q1 == q2
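

# Example (illustrative): two qconfigs built from the same constructors and arguments
# compare equal even though ``with_args`` returns distinct wrapper objects.
#
#   a = QConfig(activation=MinMaxObserver.with_args(dtype=torch.qint8),
#               weight=default_weight_observer)
#   b = QConfig(activation=MinMaxObserver.with_args(dtype=torch.qint8),
#               weight=default_weight_observer)
#   assert qconfig_equals(a, b)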


def _activation_is_memoryless(qconfig: QConfig):
    """
    Return whether the observer for activations defined in the given QConfig is memoryless.
    This means a MovingAverage observer with averaging constant equal to 1.
    """

    def _is_memoryless(observer):
        return (
            hasattr(observer, "averaging_constant") and observer.averaging_constant == 1
        )

    act = qconfig.activation()
    if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"):
        return _is_memoryless(act.activation_post_process)
    else:
        return _is_memoryless(act)
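

# Example (illustrative): an activation observer that qualifies as memoryless.
#
#   memoryless_qconfig = QConfig(
#       activation=MovingAverageMinMaxObserver.with_args(averaging_constant=1),
#       weight=default_weight_observer,
#   )
#   assert _activation_is_memoryless(memoryless_qconfig)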


def _is_reuse_input_qconfig(qconfig: Optional[QConfig]):
    return (
        qconfig is not None
        and isinstance(qconfig.activation(), ReuseInputObserver)
        and isinstance(qconfig.weight(), NoopObserver)
    )