from typing import Any, List, Optional, Union

import torch
from torch import nn


def _replace_relu(module: nn.Module) -> None:
    reassign = {}
    for name, mod in module.named_children():
        _replace_relu(mod)
        # Check for the explicit type instead of using isinstance, so that
        # only plain nn.ReLU / nn.ReLU6 modules are replaced and subclasses
        # are left untouched. The replacement is always out-of-place, since
        # in-place activations can interfere with eager-mode quantization.
        if type(mod) is nn.ReLU or type(mod) is nn.ReLU6:
            reassign[name] = nn.ReLU(inplace=False)

    for key, value in reassign.items():
        module._modules[key] = value
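
# Usage sketch (illustrative, not part of the original module): _replace_relu
# is typically applied before observers are inserted. `resnet18` here stands
# in for any model built from plain nn.ReLU modules.
#
#   from torchvision.models import resnet18
#   model = resnet18()
#   _replace_relu(model)  # every nn.ReLU / nn.ReLU6 becomes ReLU(inplace=False)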


def quantize_model(model: nn.Module, backend: str) -> None:
    _dummy_input_data = torch.rand(1, 3, 299, 299)
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported")
    torch.backends.quantized.engine = backend
    model.eval()
    # Pick a weight qconfig that matches the backend: fbgemm (x86) supports
    # per-channel weight quantization, while qnnpack (ARM) uses per-tensor
    # weight quantization.
    if backend == "fbgemm":
        model.qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_per_channel_weight_observer,
        )
    elif backend == "qnnpack":
        model.qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_weight_observer,
        )

    # Assumes the model implements fuse_model(), as torchvision's quantizable
    # models do.
    model.fuse_model()
    torch.ao.quantization.prepare(model, inplace=True)
    # Run one dummy forward pass so the observers record activation ranges
    # before the model is converted to its quantized form.
    model(_dummy_input_data)
    torch.ao.quantization.convert(model, inplace=True)
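
# Usage sketch (illustrative): a hypothetical post-training quantization run
# on x86, using one of torchvision's quantizable models, which expose the
# fuse_model() method this function relies on.
#
#   from torchvision.models.quantization import resnet18
#   model = resnet18(quantize=False)
#   quantize_model(model, "fbgemm")  # prepares, calibrates, and converts in place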


def _fuse_modules(
    model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
) -> nn.Module:
    # When is_qat is not given, infer it from the model's training mode:
    # models in training mode get QAT fusion, models in eval mode get
    # ordinary eager fusion.
    if is_qat is None:
        is_qat = model.training
    method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
    return method(model, modules_to_fuse, **kwargs)
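

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): build a
    # toy Conv-BN-ReLU stack, swap its in-place ReLU, and fuse the three
    # modules. The names "0", "1", "2" are how nn.Sequential registers its
    # children.
    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(inplace=True))
    _replace_relu(toy)
    toy.eval()  # eval mode, so _fuse_modules falls back to eager fusion
    fused = _fuse_modules(toy, [["0", "1", "2"]], is_qat=None)
    print(fused)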