|
|
|
|
|
from torch import nn |
|
|
|
|
|
class QuantStub(nn.Module): |
|
r"""Quantize stub module, before calibration, this is same as an observer, |
|
it will be swapped as `nnq.Quantize` in `convert`. |
|
|
|
Args: |
|
qconfig: quantization configuration for the tensor, |
|
if qconfig is not provided, we will get qconfig from parent modules |
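
    Example (a minimal sketch of typical eager-mode placement; ``MyModel``
    is a hypothetical module used for illustration, not part of this file)::

        >>> import torch
        >>> from torch import nn
        >>> class MyModel(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.quant = QuantStub()
        ...         self.conv = nn.Conv2d(1, 1, 1)
        ...         self.dequant = DeQuantStub()
        ...     def forward(self, x):
        ...         # observer before `convert`, `nnq.Quantize` after
        ...         x = self.quant(x)
        ...         x = self.conv(x)
        ...         # identity before `convert`, `nnq.DeQuantize` after
        ...         return self.dequant(x)
        >>> model = MyModel()
        >>> out = model(torch.randn(1, 1, 4, 4))  # plain float pass for now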
|
""" |
|
|
|
    def __init__(self, qconfig=None):
        super().__init__()
        if qconfig:
            self.qconfig = qconfig
|
|
|
    def forward(self, x):
        # Identity pass-through; `convert` swaps this module for
        # `nnq.Quantize`, which does the actual quantization.
        return x
|
|
|
|
|
class DeQuantStub(nn.Module): |
|
r"""Dequantize stub module, before calibration, this is same as identity, |
|
this will be swapped as `nnq.DeQuantize` in `convert`. |
|
|
|
Args: |
|
qconfig: quantization configuration for the tensor, |
|
if qconfig is not provided, we will get qconfig from parent modules |
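
    See :class:`QuantStub` for a sketch of typical placement; `DeQuantStub`
    marks the point where tensors are converted back to floating point
    after `convert`.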
|
""" |
|
|
|
    def __init__(self, qconfig=None):
        super().__init__()
        if qconfig:
            self.qconfig = qconfig
|
|
|
    def forward(self, x):
        # Identity pass-through; `convert` swaps this module for
        # `nnq.DeQuantize`, which does the actual dequantization.
        return x
|
|
|
|
|
class QuantWrapper(nn.Module): |
|
r"""A wrapper class that wraps the input module, adds QuantStub and |
|
DeQuantStub and surround the call to module with call to quant and dequant |
|
modules. |
|
|
|
This is used by the `quantization` utility functions to add the quant and |
|
dequant modules, before `convert` function `QuantStub` will just be observer, |
|
it observes the input tensor, after `convert`, `QuantStub` |
|
will be swapped to `nnq.Quantize` which does actual quantization. Similarly |
|
for `DeQuantStub`. |
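
    Example (a minimal end-to-end sketch; ``float_model`` and the choice of
    the ``fbgemm`` qconfig are illustrative assumptions)::

        >>> import torch
        >>> from torch import nn
        >>> import torch.ao.quantization as tq
        >>> float_model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU())
        >>> float_model.qconfig = tq.get_default_qconfig("fbgemm")
        >>> wrapped = QuantWrapper(float_model)    # picks up module.qconfig
        >>> prepared = tq.prepare(wrapped)         # stubs become observers
        >>> _ = prepared(torch.randn(1, 1, 4, 4))  # calibration pass
        >>> quantized = tq.convert(prepared)       # stubs -> nnq.(De)Quantize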
|
""" |
|
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module
|
|
|
    def __init__(self, module):
        super().__init__()
        qconfig = getattr(module, "qconfig", None)
        self.add_module("quant", QuantStub(qconfig))
        self.add_module("dequant", DeQuantStub(qconfig))
        self.add_module("module", module)
        # Keep the wrapper's training mode in sync with the wrapped module.
        self.train(module.training)
|
|
|
    def forward(self, X):
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)
|
|