"""This file exports ONNX ops for opset 16. |
|
|
|
Note [ONNX Operators that are added/updated in opset 16] |
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
|
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set |
|
New operators: |
|
GridSample https://github.com/onnx/onnx/pull/3557 |
|
|
|
Updated operators: |
|
Identity |
|
If |
|
LeakyRelu |
|
Loop |
|
PRelu |
|
RoiAlign |
|
Scan |
|
ScatterElements |
|
ScatterND |
|
Where |
|
GreaterOrEqual |
|
LessOrEqual |
|
""" |
|
import functools

import torch
from torch.nn.functional import (
    GRID_SAMPLE_INTERPOLATION_MODES,
    GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import jit_utils, registration

# Registration helper: every symbolic function below is registered for opset 16.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
|
|
@_onnx_symbolic("aten::grid_sampler") |
|
@symbolic_helper.parse_args("v", "v", "i", "i", "b") |
|
def grid_sampler( |
|
g: jit_utils.GraphContext, |
|
input, |
|
grid, |
|
mode_enum, |
|
padding_mode_enum, |
|
align_corners, |
|
): |
|
|
|
if symbolic_helper._get_tensor_rank(input) == 5: |
|
return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input") |
|
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] |
|
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ |
|
padding_mode_enum |
|
] |
|
return g.op( |
|
"GridSample", |
|
input, |
|
grid, |
|
align_corners_i=int(align_corners), |
|
mode_s=mode_s, |
|
padding_mode_s=padding_mode_s, |
|
) |
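# Usage sketch (illustrative only, not executed here): exporting a tiny module
# that calls `torch.nn.functional.grid_sample` at opset 16 exercises the
# `grid_sampler` symbolic above and emits an ONNX GridSample node. The module
# name, tensor shapes, and output filename below are arbitrary choices for the
# example, not part of this file's API.
#
#   import torch
#
#   class GridSampleDemo(torch.nn.Module):
#       def forward(self, x, grid):
#           return torch.nn.functional.grid_sample(
#               x, grid, mode="bilinear", padding_mode="zeros", align_corners=False
#           )
#
#   torch.onnx.export(
#       GridSampleDemo(),
#       (torch.randn(1, 3, 8, 8), torch.rand(1, 4, 4, 2) * 2 - 1),
#       "grid_sample_demo.onnx",
#       opset_version=16,
#   )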
|
|
@_onnx_symbolic("aten::scatter_add") |
|
@symbolic_helper.parse_args("v", "i", "v", "v") |
|
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): |
|
src_type = _type_utils.JitScalarType.from_value( |
|
src, _type_utils.JitScalarType.UNDEFINED |
|
) |
|
src_sizes = symbolic_helper._get_tensor_sizes(src) |
|
index_sizes = symbolic_helper._get_tensor_sizes(index) |
|
|
|
if len(src_sizes) != len(index_sizes): |
|
return symbolic_helper._unimplemented( |
|
"scatter_add", |
|
f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})", |
|
) |
|
|
|
|
|
|
|
|
|
|
|
if src_sizes != index_sizes or None in index_sizes: |
|
adjusted_shape = g.op("Shape", index) |
|
starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes))) |
|
src = g.op("Slice", src, starts, adjusted_shape) |
|
|
|
src = symbolic_helper._maybe_get_scalar(src) |
|
if symbolic_helper._is_value(src): |
|
return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add") |
|
else: |
|
|
|
|
|
if _type_utils.JitScalarType.from_value(self) != src_type: |
|
src = g.op( |
|
"Cast", |
|
src, |
|
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), |
|
) |
|
|
|
return g.op( |
|
"ScatterElements", |
|
self, |
|
index, |
|
src, |
|
axis_i=dim, |
|
reduction_s="add", |
|
) |
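# Usage sketch (illustrative only, not executed here): `Tensor.scatter_add`
# lowers to ScatterElements with reduction="add" under opset 16. The module
# name, example values, and output filename are assumptions made for the
# example.
#
#   import torch
#
#   class ScatterAddDemo(torch.nn.Module):
#       def forward(self, base, index, src):
#           return base.scatter_add(0, index, src)
#
#   base = torch.zeros(3)
#   index = torch.tensor([0, 1, 0])
#   src = torch.tensor([1.0, 2.0, 3.0])
#   # Eager result: tensor([4., 2., 0.]) -- src[0] and src[2] both accumulate into slot 0.
#   torch.onnx.export(
#       ScatterAddDemo(), (base, index, src), "scatter_add_demo.onnx", opset_version=16
#   )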
|
|
@_onnx_symbolic("aten::scatter_reduce") |
|
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b") |
|
def scatter_reduce( |
|
g: jit_utils.GraphContext, |
|
self: torch._C.Value, |
|
dim: int, |
|
index: torch._C.Value, |
|
src: torch._C.Value, |
|
reduce: str, |
|
include_self: bool, |
|
): |
|
if reduce == "mean": |
|
raise errors.OnnxExporterError( |
|
"ONNX does not support mean reduction for scatter_reduce" |
|
) |
|
if not include_self: |
|
raise errors.OnnxExporterError( |
|
"ONNX does not support include_self=False for scatter_reduce" |
|
) |
|
|
|
reduce_mode = { |
|
"mean": "none", |
|
"sum": "add", |
|
"prod": "mul", |
|
"amin": "min", |
|
"amax": "max", |
|
} |
|
onnx_reduce = reduce_mode[reduce] |
|
|
|
self_rank = g.op("Size", g.op("Shape", self)) |
|
|
|
|
|
self_rank_is_zero = g.op( |
|
"Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) |
|
) |
|
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( |
|
g, "If", self_rank_is_zero, n_blocks=2, outputs=3 |
|
) |
|
neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) |
|
|
|
self_reshape = if_context.op("Reshape", self, neg_1) |
|
utils._add_output_to_block(if_context.block, self_reshape) |
|
index_reshape = if_context.op("Reshape", index, neg_1) |
|
utils._add_output_to_block(if_context.block, index_reshape) |
|
src_reshape = if_context.op("Reshape", src, neg_1) |
|
utils._add_output_to_block(if_context.block, src_reshape) |
|
|
|
self_identity = else_context.op("Identity", self) |
|
utils._add_output_to_block(else_context.block, self_identity) |
|
index_identitye = else_context.op("Identity", index) |
|
utils._add_output_to_block(else_context.block, index_identitye) |
|
src_identity = else_context.op("Identity", src) |
|
utils._add_output_to_block(else_context.block, src_identity) |
|
|
|
result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce) |
|
|
|
|
|
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( |
|
g, "If", self_rank_is_zero, n_blocks=2, outputs=1 |
|
) |
|
result_squeezed = if_context.op("Squeeze", result) |
|
utils._add_output_to_block(if_context.block, result_squeezed) |
|
result_identity = else_context.op("Identity", result) |
|
utils._add_output_to_block(else_context.block, result_identity) |
|
result_final = if_op.node().output() |
|
|
|
return result_final |
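# Usage sketch (illustrative only, not executed here): `Tensor.scatter_reduce`
# with reduce="amax" and include_self=True lowers to ScatterElements with
# reduction="max" under opset 16. The module name, example values, and output
# filename are assumptions made for the example.
#
#   import torch
#
#   class ScatterAmaxDemo(torch.nn.Module):
#       def forward(self, base, index, src):
#           return base.scatter_reduce(0, index, src, reduce="amax", include_self=True)
#
#   base = torch.tensor([1.0, 5.0, 0.0])
#   index = torch.tensor([0, 0, 2])
#   src = torch.tensor([3.0, 2.0, 4.0])
#   # Eager result: tensor([3., 5., 4.]) -- per-slot maxima, including `base` itself.
#   torch.onnx.export(
#       ScatterAmaxDemo(), (base, index, src), "scatter_reduce_demo.onnx", opset_version=16
#   )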
|
|