|
|
|
"""This file exports ONNX ops for opset 20. |
|
|
|
Note [ONNX Operators that are added/updated in opset 20]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set

New or updated operators:
    AffineGrid
    ConstantOfShape
    DFT
    Gelu
    GridSample
    ImageDecoder
    IsInf
    IsNaN
    ReduceMax
    ReduceMin
    RegexFullMatch
    StringConcat
    StringSplit
|
""" |
|
|
|
import functools |
|
|
|
import torch.nn.functional as F |
|
from torch import _C |
|
from torch.onnx import symbolic_helper |
|
from torch.onnx._internal import jit_utils, registration |
|
|
|
|
|
|
|
|
|
|
|
# Names exported by this module.
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
|
|
|
|
|
def convert_grid_sample_mode(mode_s: str) -> str:
    """Translate a grid-sample interpolation mode name.

    ``"bilinear"`` is mapped to ``"linear"`` and ``"bicubic"`` to ``"cubic"``;
    every other value is returned unchanged.

    Args:
        mode_s: Interpolation mode name to translate.

    Returns:
        The translated mode name, or ``mode_s`` itself when no translation
        applies.
    """
    if mode_s == "bilinear":
        return "linear"
    if mode_s == "bicubic":
        return "cubic"
    # Anything else (e.g. "nearest") passes through untouched.
    return mode_s
|
|
|
|
|
# Registration decorator with the opset pre-bound to 20, so each symbolic
# function below only has to name the aten op it handles.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
|
|
|
|
|
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def _grid_sampler(
    g: jit_utils.GraphContext,
    input: _C.Value,
    grid: _C.Value,
    mode_enum: int,
    padding_mode_enum: int,
    align_corners: bool,
):
    """Emit a ``GridSample`` node for ``aten::grid_sampler``.

    Args:
        g: Graph context used to create the node.
        input: First operand passed to ``GridSample``.
        grid: Second operand passed to ``GridSample``.
        mode_enum: Integer value looked up among the values of
            ``F.GRID_SAMPLE_INTERPOLATION_MODES`` to recover the mode name.
        padding_mode_enum: Integer value looked up among the values of
            ``F.GRID_SAMPLE_PADDING_MODES`` to recover the padding mode name.
        align_corners: Flag emitted as the integer ``align_corners_i``
            argument.

    Returns:
        Whatever ``g.op("GridSample", ...)`` returns.

    Raises:
        KeyError: If ``mode_enum`` or ``padding_mode_enum`` is not a value of
            the corresponding lookup table.
    """
    # The tables in torch.nn.functional map name -> enum value; invert them to
    # go from the enum value received here back to the name.
    mode_name_by_enum = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}
    # Rename "bilinear"/"bicubic" to "linear"/"cubic" before emitting the node.
    mode_s = convert_grid_sample_mode(mode_name_by_enum[mode_enum])

    padding_name_by_enum = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}
    padding_mode_s = padding_name_by_enum[padding_mode_enum]

    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
|
|
|
|
|
@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
def _affine_grid_generator(
    g: jit_utils.GraphContext,
    theta: _C.Value,
    size: _C.Value,
    align_corners: bool,
):
    """Emit an ``AffineGrid`` node for ``aten::affine_grid_generator``.

    Args:
        g: Graph context used to create the node.
        theta: First operand passed to ``AffineGrid``.
        size: Second operand passed to ``AffineGrid``.
        align_corners: Flag emitted as the integer ``align_corners_i``
            argument.

    Returns:
        Whatever ``g.op("AffineGrid", ...)`` returns.
    """
    return g.op(
        "AffineGrid",
        theta,
        size,
        align_corners_i=int(align_corners),
    )
|
|
|
|
|
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
    """Emit a ``Gelu`` node for ``aten::gelu``.

    Args:
        g: Graph context used to create the node.
        self: Operand passed to ``Gelu``.
        approximate: String passed through unchanged as the ``approximate_s``
            argument; defaults to ``"none"``.

    Returns:
        Whatever ``g.op("Gelu", ...)`` returns.
    """
    return g.op("Gelu", self, approximate_s=approximate)
|
|