|
|
|
|
|
import importlib |
|
import inspect |
|
|
|
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 |
|
from torch.onnx._internal import jit_utils, registration |
|
|
|
|
|
def register_quantized_ops(domain: str, version: int):
    """Register this module's symbolic functions for `domain` at `version`.

    Every function defined in torch.onnx.symbolic_caffe2 is registered under
    ``{domain}::{name}``; names listed in ``aten_q_ops`` are additionally
    registered as custom ``aten::`` symbolics so they shadow the defaults.
    """
    # Discover symbolic functions reflectively from this very module.
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    # aten ops that have a quantized flavor handled by this module.
    aten_q_ops = {
        "relu",
        "_empty_affine_quantized",
        "dequantize",
        "quantize_per_tensor",
        "upsample_nearest2d",
        "avg_pool2d",
        "reshape",
        "slice",
        "cat",
        "max_pool2d",
        "sigmoid",
    }
    for op_name, symbolic_fn in inspect.getmembers(module, inspect.isfunction):
        qualified_name = f"{domain}::{op_name}"
        if registration.registry.is_registered_op(qualified_name, version):
            continue
        if op_name in aten_q_ops:
            # Override the aten symbolic with the quantized-aware one.
            registration.registry.register(
                f"aten::{op_name}", version, symbolic_fn, custom=True
            )
        registration.registry.register(qualified_name, version, symbolic_fn)
|
|
|
|
|
def _permute_helper(g: jit_utils.GraphContext, input, axes):
    """Emit a Caffe2 Int8Transpose over `axes`, propagating the input's
    Y_scale / Y_zero_point quantization attributes to the new node."""
    source_node = input.node()
    transposed = g.op(
        "_caffe2::Int8Transpose",
        input,
        axes_i=axes,
        Y_scale_f=symbolic_helper._node_get(source_node, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(source_node, "Y_zero_point"),
    )
    # Track the result so downstream symbolics know it is quantized.
    symbolic_helper._quantized_ops.add(transposed)
    return transposed
|
|
|
|
|
def nchw2nhwc(g: jit_utils.GraphContext, input):
    """Permute a quantized tensor from NCHW to NHWC layout."""
    return _permute_helper(g, input, [0, 2, 3, 1])
|
|
|
|
|
def nhwc2nchw(g: jit_utils.GraphContext, input):
    """Permute a quantized tensor from NHWC back to NCHW layout."""
    return _permute_helper(g, input, [0, 3, 1, 2])
|
|
|
|
|
def linear_prepack(g: jit_utils.GraphContext, weight, bias):
    """Map quantized linear prepack to a placeholder WeightPrepack node.

    NOTE(review): presumably the original weight/bias are looked up again
    during the ONNX -> Caffe2 conversion — confirm against the converter.
    """
    packed = g.op("_caffe2::WeightPrepack", weight, bias)
    symbolic_helper._quantized_ops.add(packed)
    return packed
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "f", "i")
def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point):
    """Lower quantized linear to Caffe2 Int8FC with the given output
    quantization parameters (`scale`, `zero_point`)."""
    fc = g.op(
        "_caffe2::Int8FC",
        input,
        weight,
        bias,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(fc)
    return fc
|
|
|
|
|
def conv_prepack(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Map quantized conv prepack to a placeholder WeightPrepack node.

    `stride`/`padding`/`dilation`/`groups` are accepted for signature
    compatibility but are not encoded in the emitted node.
    NOTE(review): presumably the conv params are recovered during the
    ONNX -> Caffe2 conversion — confirm against the converter.
    """
    packed = g.op("_caffe2::WeightPrepack", input, weight, bias)
    symbolic_helper._quantized_ops.add(packed)
    return packed
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    """Lower quantized conv2d to Caffe2 Int8Conv (NHWC order).

    Fix: read the weight's "shape" attribute via
    ``symbolic_helper._node_get`` — consistent with every other attribute
    read in this module — instead of subscripting the node, which depends
    on a monkey-patched ``Node.__getitem__``.
    """
    # Spatial kernel dims; assumes the packed weight's spatial extents sit at
    # shape indices 1:3 — TODO confirm against the weight-packing layout.
    kernel_size = symbolic_helper._node_get(weight.node(), "shape")[1:3]
    kwargs = {
        "strides_i": stride,
        # Caffe2 wants pads repeated for begin and end of each spatial dim.
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d_relu(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    """Lower fused quantized conv2d+relu to Caffe2 Int8ConvRelu (NHWC order).

    Fix: read the weight's "shape" attribute via
    ``symbolic_helper._node_get`` — consistent with every other attribute
    read in this module — instead of subscripting the node, which depends
    on a monkey-patched ``Node.__getitem__``.
    """
    # Spatial kernel dims; assumes the packed weight's spatial extents sit at
    # shape indices 1:3 — TODO confirm against the weight-packing layout.
    kernel_size = symbolic_helper._node_get(weight.node(), "shape")[1:3]
    kwargs = {
        "strides_i": stride,
        # Caffe2 wants pads repeated for begin and end of each spatial dim.
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "f", "i")
def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point):
    """Lower quantized elementwise add to Caffe2 Int8Add with the given
    output quantization parameters."""
    result = g.op(
        "_caffe2::Int8Add",
        input_a,
        input_b,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(result)
    return result
|
|
|
|
|
@symbolic_helper.parse_args("v")
def relu(g: jit_utils.GraphContext, input):
    """Quantized relu -> Caffe2 Int8Relu; non-quantized inputs fall back to
    the standard opset9 relu."""
    if input not in symbolic_helper._quantized_ops:
        return opset9.relu(g, input)
    source_node = input.node()
    result = g.op(
        "_caffe2::Int8Relu",
        input,
        # Output reuses the input's quantization parameters.
        Y_scale_f=symbolic_helper._node_get(source_node, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(source_node, "Y_zero_point"),
    )
    symbolic_helper._quantized_ops.add(result)
    return result
|
|
|
|
|
@symbolic_helper.parse_args("v", "f", "i", "t")
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Lower quantize_per_tensor to Caffe2 Int8Quantize.

    `dtype` is accepted for signature compatibility but is not used.
    """
    quantized = g.op(
        "_caffe2::Int8Quantize",
        input,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(quantized)
    return quantized
|
|
|
|
|
@symbolic_helper.parse_args("v")
def dequantize(g: jit_utils.GraphContext, input):
    """Lower dequantize to Caffe2 Int8Dequantize; the result is a regular
    (non-quantized) value, so it is not tracked in _quantized_ops."""
    dequantized = g.op("_caffe2::Int8Dequantize", input)
    return dequantized
|
|
|
|
|
@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t")
def _empty_affine_quantized(
    g: jit_utils.GraphContext,
    input,
    shape,
    scale,
    zero_point,
    dtype,
    pin_memory,
    memory_format,
    layout,
):
    """Identity for export purposes: the allocation is a no-op in the graph,
    so the incoming value is passed straight through; all other arguments
    are ignored."""
    return input
|
|
|
|
|
def upsample_nearest2d(
    g: jit_utils.GraphContext,
    input,
    output_size,
    align_corners=None,
    scales_h=None,
    scales_w=None,
):
    """Quantized nearest-neighbor upsample via Caffe2 Int8ResizeNearest.

    Non-quantized inputs fall back to opset9. The Caffe2 op runs in NHWC, so
    the input is permuted in and the output permuted back. `scales_h` and
    `scales_w` are accepted for signature compatibility but unused.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.upsample_nearest2d(g, input, output_size, align_corners)

    source_node = input.node()
    attrs = {
        "output_size_i": symbolic_helper._parse_arg(output_size, "is"),
        # Output reuses the input's quantization parameters.
        "Y_scale_f": symbolic_helper._node_get(source_node, "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(source_node, "Y_zero_point"),
    }
    nhwc_input = nchw2nhwc(g, input)
    resized = g.op("_caffe2::Int8ResizeNearest", nhwc_input, **attrs)
    result = nhwc2nchw(g, resized)
    symbolic_helper._quantized_ops.add(result)
    return result
|
|
|
|
|
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def max_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Quantized max_pool2d via Caffe2 Int8MaxPool (NHWC order).

    Non-quantized inputs fall back to the standard opset9 max_pool2d.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.max_pool2d(
            g, input, kernel_size, stride, padding, dilation, ceil_mode
        )
    source_node = input.node()
    attrs = {
        "strides_i": stride,
        # Caffe2 wants pads repeated for begin and end of each spatial dim.
        "pads_i": padding + padding,
        # Only the first kernel dim is forwarded — assumes a square kernel;
        # TODO confirm against the Caffe2 Int8MaxPool contract.
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        # Output reuses the input's quantization parameters.
        "Y_scale_f": symbolic_helper._node_get(source_node, "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(source_node, "Y_zero_point"),
    }
    nhwc_input = nchw2nhwc(g, input)
    pooled = g.op("_caffe2::Int8MaxPool", nhwc_input, **attrs)
    result = nhwc2nchw(g, pooled)
    symbolic_helper._quantized_ops.add(result)
    return result
|
|
|
|
|
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def avg_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    """Quantized avg_pool2d via Caffe2 Int8AveragePool (NHWC order).

    Non-quantized inputs fall back to the standard opset9 avg_pool2d.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.avg_pool2d(
            g,
            input,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
    source_node = input.node()
    attrs = {
        "strides_i": stride,
        # Caffe2 wants pads repeated for begin and end of each spatial dim.
        "pads_i": padding + padding,
        # Only the first kernel dim is forwarded — assumes a square kernel;
        # TODO confirm against the Caffe2 Int8AveragePool contract.
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        # Output reuses the input's quantization parameters.
        "Y_scale_f": symbolic_helper._node_get(source_node, "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(source_node, "Y_zero_point"),
    }
    nhwc_input = nchw2nhwc(g, input)
    pooled = g.op("_caffe2::Int8AveragePool", nhwc_input, **attrs)
    result = nhwc2nchw(g, pooled)
    symbolic_helper._quantized_ops.add(result)
    return result
|
|
|
|
|
def reshape(g: jit_utils.GraphContext, input, shape):
    """Quantized reshape via Caffe2 Int8Reshape; non-quantized inputs fall
    back to the standard opset9 reshape."""
    if input not in symbolic_helper._quantized_ops:
        return opset9.reshape(g, input, shape)

    source_node = input.node()
    reshaped = g.op(
        "_caffe2::Int8Reshape",
        input,
        shape,
        # Output reuses the input's quantization parameters.
        Y_scale_f=symbolic_helper._node_get(source_node, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(source_node, "Y_zero_point"),
    )
    symbolic_helper._quantized_ops.add(reshaped)
    return reshaped
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def slice(g: jit_utils.GraphContext, input, dim, start, end, step):
    """Quantized slice via Caffe2 Int8Slice; non-quantized inputs fall back
    to the standard opset9 slice.

    Raises:
        RuntimeError: if `step` is anything other than 1.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.slice(g, input, dim, start, end, step)

    if step != 1:
        raise RuntimeError("ONNX quantized slice export only works for step 1.")
    source_node = input.node()
    sliced = g.op(
        "_caffe2::Int8Slice",
        input,
        start_idx_i=symbolic_helper._parse_arg(start, "i"),
        end_idx_i=symbolic_helper._parse_arg(end, "i"),
        dim_i=symbolic_helper._parse_arg(dim, "i"),
        # Output reuses the input's quantization parameters.
        Y_scale_f=symbolic_helper._node_get(source_node, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(source_node, "Y_zero_point"),
    )
    symbolic_helper._quantized_ops.add(sliced)
    return sliced
|
|
|
|
|
def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None):
    """Quantized concat via Caffe2 Int8Concat along `dim`.

    Falls back to opset9 cat when the first input is not quantized. The
    output reuses the first input's quantization parameters; `scale` and
    `zero_point` are accepted for signature compatibility but ignored
    (unchanged from the original behavior).

    Fix: read Y_scale / Y_zero_point via ``symbolic_helper._node_get`` —
    consistent with the rest of this module — instead of subscripting the
    node, which depends on a monkey-patched ``Node.__getitem__``. Also reuse
    the already-bound `input` rather than re-indexing `tensors[0]`.
    """
    tensors = symbolic_helper._unpack_list(tensor_list)
    input = tensors[0]
    if input not in symbolic_helper._quantized_ops:
        return opset9.cat(g, tensor_list, dim)

    dim = symbolic_helper._parse_arg(dim, "i")
    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output
|
|
|
|
|
@symbolic_helper.parse_args("v")
def sigmoid(g: jit_utils.GraphContext, input):
    """Quantized sigmoid via Caffe2 Int8Sigmoid; non-quantized inputs fall
    back to the standard opset9 sigmoid.

    The output quantization is fixed: scale 1/256, zero point 0 — presumably
    the Int8Sigmoid contract, since sigmoid's range is (0, 1); confirm
    against the Caffe2 operator definition.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.sigmoid(g, input)

    result = g.op(
        "_caffe2::Int8Sigmoid",
        input,
        Y_scale_f=1.0 / 256,
        Y_zero_point_i=0,
    )
    symbolic_helper._quantized_ops.add(result)
    return result
|
|