|
|
|
|
|
|
|
"""This file exports ONNX ops for opset 9. |
|
|
|
Opset 9 is supported by ONNX release 1.4.1,
released on 01/23/19.
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import builtins |
|
import functools |
|
import math |
|
import sys |
|
import warnings |
|
from typing import Callable, TYPE_CHECKING |
|
from typing_extensions import deprecated |
|
|
|
import torch |
|
import torch._C._onnx as _C_onnx |
|
import torch.nn.modules.utils |
|
import torch.onnx |
|
from torch import _C |
|
|
|
|
|
from torch.onnx import _constants, _type_utils, errors, symbolic_helper |
|
from torch.onnx._globals import GLOBALS |
|
from torch.onnx._internal import jit_utils, registration |
|
|
|
|
|
if TYPE_CHECKING: |
|
from collections.abc import Sequence |
|
|
|
from torch.types import Number |
|
|
|
|
|
|
|
|
|
# Public names of this module. ``_export`` (below) appends to this list when
# it publishes additional functions at decoration time.
__all__ = [
    "abs",
    "acos",
    "add",
    "addcmul",
    "addmm",
    "alias",
    "amax",
    "amin",
    "aminmax",
    "arange",
    "argmax",
    "argmin",
    "as_strided",
    "as_tensor",
    "asin",
    "atan",
    "atan2",
    "baddbmm",
    "batch_norm",
    "bernoulli",
    "bitwise_not",
    "bitwise_or",
    "bmm",
    "broadcast_tensors",
    "broadcast_to",
    "bucketize",
    "cat",
    "cdist",
    "ceil",
    "clamp_max",
    "clamp_min",
    "clamp",
    "clone",
    "constant_pad_nd",
    "contiguous",
    "conv_tbc",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    "conv1d",
    "conv2d",
    "conv3d",
    "convert_element_type",
    "convolution",
    "cos",
    "cosine_similarity",
    "cross",
    "cumsum",
    "detach",
    "dim",
    "div",
    "dot",
    "dropout",
    "elu",
    "embedding_bag",
    "embedding",
    "empty_like",
    "empty",
    "eq",
    "erf",
    "exp",
    "expand_as",
    "expand",
    "eye",
    "fill",
    "flatten",
    "floor_divide",
    "floor",
    "floordiv",
    "frobenius_norm",
    "full_like",
    "full",
    "gather",
    "ge",
    "gelu",
    "get_pool_ceil_padding",
    "glu",
    "group_norm",
    "gt",
    "hann_window",
    "hardshrink",
    "hardsigmoid",
    "hardswish",
    "hardtanh",
    "index_add",
    "index_copy",
    "index_fill",
    "index_put",
    "index_select",
    "index",
    "instance_norm",
    "is_floating_point",
    "is_pinned",
    "isnan",
    "item",
    "kl_div",
    "layer_norm",
    "le",
    "leaky_relu",
    "lerp",
    "lift",
    "linalg_cross",
    "linalg_matrix_norm",
    "linalg_norm",
    "linalg_vector_norm",
    "linear",
    "linspace",
    "log_sigmoid",
    "log_softmax",
    "log",
    "log10",
    "log1p",
    "log2",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "logit",
    "logsumexp",
    "lstm_cell",
    "lstm",
    "lt",
    "masked_fill",
    "masked_fill_",
    "matmul",
    "max_pool1d_with_indices",
    "max_pool2d_with_indices",
    "max_pool3d_with_indices",
    "max",
    "maximum",
    "meshgrid",
    "min",
    "minimum",
    "mish",
    "mm",
    "movedim",
    "mse_loss",
    "mul",
    "multinomial",
    "mv",
    "narrow",
    "native_layer_norm",
    "ne",
    "neg",
    "new_empty",
    "new_full",
    "new_ones",
    "new_zeros",
    "nonzero_numpy",
    "nonzero",
    "norm",
    "numel",
    "numpy_T",
    "one_hot",
    "ones_like",
    "ones",
    "onnx_placeholder",
    "pad",
    "pairwise_distance",
    "permute",
    "pixel_shuffle",
    "pixel_unshuffle",
    "pow",
    "prelu",
    "prim_constant_chunk",
    "prim_constant_split",
    "prim_constant",
    "prim_data",
    "prim_device",
    "prim_dtype",
    "prim_if",
    "prim_layout",
    "prim_list_construct",
    "prim_list_unpack",
    "prim_loop",
    "prim_max",
    "prim_min",
    "prim_shape",
    "prim_tolist",
    "prim_tuple_construct",
    "prim_type",
    "prim_unchecked_cast",
    "prim_uninitialized",
    "rand_like",
    "rand",
    "randint_like",
    "randint",
    "randn_like",
    "randn",
    "reciprocal",
    "reflection_pad",
    "relu",
    "relu6",
    "remainder",
    "repeat_interleave",
    "repeat",
    "replication_pad",
    "reshape_as",
    "reshape",
    "roll",
    "rrelu",
    "rsqrt",
    "rsub",
    "scalar_tensor",
    "scatter_add",
    "scatter",
    "select",
    "selu",
    "sigmoid",
    "sign",
    "silu",
    "sin",
    "size",
    "slice",
    "softmax",
    "softplus",
    "softshrink",
    "sort",
    "split_with_sizes",
    "split",
    "sqrt",
    "square",
    "squeeze",
    "stack",
    "std_mean",
    "std",
    "sub",
    "t",
    "take",
    "tan",
    "tanh",
    "tanhshrink",
    "tensor",
    "threshold",
    "to",
    "topk",
    "transpose",
    "true_divide",
    "type_as",
    "unbind",
    "unfold",
    "unsafe_chunk",
    "unsafe_split_with_sizes",
    "unsafe_split",
    "unsqueeze",
    "unsupported_complex_operators",
    "noop_complex_operators",
    "unused",
    "var_mean",
    "var",
    "view_as",
    "view",
    "where",
    "wrap_logical_op_with_cast_to",
    "wrap_logical_op_with_negation",
    "zeros_like",
    "zeros",
    "zero",
]
|
|
|
|
|
# ``registration.onnx_symbolic`` with ``opset=9`` pre-bound. Every
# ``@_onnx_symbolic("aten::...")`` decorator in this module goes through it,
# so the decorated function is associated with opset 9.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
|
|
|
|
|
def _export(name: str):
    """Build a decorator that publishes a function under ``name``.

    The decorated function is stored in this module's global namespace as
    ``name``, and ``name`` is appended to ``__all__``. The function itself is
    returned unchanged.
    """

    def register(func):
        module_namespace = globals()
        module_namespace[name] = func
        __all__.append(name)
        return func

    return register
|
|
|
|
|
def unused(g):
    """Create a placeholder value that represents a "missing" optional input.

    A ``prim::Constant`` node is created with no attributes and typed as an
    optional tensor.

    Args:
        g: Graph context used to create the node.

    Returns:
        The new ``prim::Constant`` value.
    """
    placeholder = g.op("prim::Constant")
    placeholder.setType(_C.OptionalType.ofTensor())
    return placeholder
|
|
|
|
|
@_onnx_symbolic("aten::_shape_as_tensor")
def _shape_as_tensor(g: jit_utils.GraphContext, input):
    """Export ``aten::_shape_as_tensor`` as a single ONNX ``Shape`` node."""
    shape_node = g.op("Shape", input)
    return shape_node
|
|
|
|
|
@_onnx_symbolic("aten::_reshape_from_tensor")
def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape):
    """Reshape ``input`` to ``shape``.

    When ``shape`` arrives as a Python list of values, those values are first
    joined into a single tensor with ``Concat`` along axis 0.
    """
    target_shape = (
        g.op("Concat", *shape, axis_i=0) if isinstance(shape, list) else shape
    )
    return reshape(g, input, target_shape)
|
|
|
|
|
@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
def reshape(g: jit_utils.GraphContext, self, shape):
    """Export ``aten::reshape`` by delegating to ``symbolic_helper._reshape_helper``."""
    reshaped = symbolic_helper._reshape_helper(g, self, shape)
    return reshaped
|
|
|
|
|
@_onnx_symbolic("aten::reshape_as")
@symbolic_helper.quantized_args(True)
def reshape_as(g: jit_utils.GraphContext, self, other):
    """Reshape ``self`` to the runtime shape of ``other``, read via ONNX ``Shape``."""
    return reshape(g, self, g.op("Shape", other))
|
|
|
|
|
@_onnx_symbolic("aten::add")
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    """Export ``aten::add`` (``self + alpha * other``) as ONNX ``Add``.

    Not meant to be called directly by users.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The first operand.
        other (Tensor): The second operand.
        alpha (optional): Multiplier for ``other``. When it is omitted, falsy,
            or its scalar value equals 1, no ``Mul`` node is emitted.

    Returns:
        The output of the ONNX ``Add`` node. Adding lists of tensors is
        reported as unsupported at this opset.
    """
    operand_is_tensor_list = symbolic_helper._is_value(
        self
    ) and symbolic_helper._is_tensor_list(self)
    if operand_is_tensor_list:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Add", 9, 11, "Add between list of tensors not supported", self
        )

    scaled_other = other
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        scaled_other = g.op("Mul", other, alpha)
    return g.op("Add", self, scaled_other)
|
|
|
|
|
@_onnx_symbolic("aten::sub")
def sub(g: jit_utils.GraphContext, self, other, alpha=None):
    """Export ``aten::sub`` (``self - alpha * other``) as ONNX ``Sub``.

    Not meant to be called directly by users.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The minuend.
        other (Tensor): The subtrahend.
        alpha (optional): Multiplier for ``other``. When it is omitted, falsy,
            or its scalar value equals 1, no ``Mul`` node is emitted.

    Returns:
        The output of the ONNX ``Sub`` node.
    """
    subtrahend = other
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        subtrahend = g.op("Mul", other, alpha)
    return g.op("Sub", self, subtrahend)
|
|
|
|
|
@_onnx_symbolic("aten::rsub")
def rsub(g: jit_utils.GraphContext, self, other, alpha=None):
    """Export ``aten::rsub`` (``other - alpha * self``) as ``sub`` with swapped operands."""
    minuend, subtrahend = other, self
    return sub(g, minuend, subtrahend, alpha=alpha)
|
|
|
|
|
@_onnx_symbolic("aten::mul")
def mul(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mul`` as ONNX ``Mul``, or as ``And`` when both inputs are boolean."""
    both_bool = symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other)
    if not both_bool:
        return g.op("Mul", self, other)
    # ONNX Mul has no boolean overload; for booleans, logical And is equivalent.
    return g.op("And", self, other)
|
|
|
|
|
@_onnx_symbolic("aten::div")
def div(g: jit_utils.GraphContext, self, other, *args):
    """Export ``aten::div``.

    With no extra arguments this is true division. Any extra positional
    arguments (the rounding mode) are forwarded to ``_div_rounding_mode``.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return true_divide(g, self, other)
|
|
|
|
|
@_onnx_symbolic("aten::addcmul")
@symbolic_helper.parse_args("v", "v", "v", "f")
def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0):
    """Export ``aten::addcmul``: ``self + (tensor1 * tensor2) * value``.

    ``value`` is embedded as a one-element constant tensor.
    """
    scale = g.op("Constant", value_t=torch.tensor([value]))
    elementwise_product = mul(g, tensor1, tensor2)
    scaled_product = mul(g, elementwise_product, scale)
    return add(g, self, scaled_product)
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "s")
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Dispatch ``aten::div`` with an explicit ``rounding_mode``.

    ``None`` uses ``true_divide``, ``"floor"`` uses ``_floor_divide`` and
    ``"trunc"`` uses ``_trunc_divide``.

    Raises:
        errors.SymbolicValueError: If ``rounding_mode`` is any other value.
    """
    handlers = {
        None: true_divide,
        "floor": _floor_divide,
        "trunc": _trunc_divide,
    }
    handler = handlers.get(rounding_mode)
    if handler is None:
        raise errors.SymbolicValueError(
            f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"',
            self,
        )
    return handler(g, self, other)
|
|
|
|
|
def _trunc_divide(g: jit_utils.GraphContext, self, other):
    """Divide ``self`` by ``other`` and drop the fractional part of the quotient.

    The ``Div`` result is cast to INT64 to discard the fraction. No truncation
    op is used here, and ``Floor`` would round negative quotients the wrong
    way. The result is then cast to the output type:

    * FLOAT if ``self``'s scalar type is unknown;
    * FLOAT if ``self`` is not floating point but ``other`` is;
    * otherwise ``self``'s own scalar type.
    """
    quotient = g.op("Div", self, other)
    quotient = g.op("Cast", quotient, to_i=_C_onnx.TensorProtoDataType.INT64)

    self_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if self_type == _type_utils.JitScalarType.UNDEFINED or (
        not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other)
    ):
        target_type = _C_onnx.TensorProtoDataType.FLOAT
    else:
        target_type = self_type.onnx_type()
    return g.op("Cast", quotient, to_i=target_type)
|
|
|
|
|
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Divide ``self`` by ``other``, rounding the quotient toward negative infinity.

    Floating-point inputs go through ``true_divide`` followed by ONNX ``Floor``.
    For non-floating inputs, the ``Div`` result is decreased by one wherever
    the operands have opposite signs and the division leaves a non-zero
    remainder.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        out = true_divide(g, self, other)
        return g.op("Floor", out)
    else:
        # Integer Div; the correction below assumes it rounds toward zero
        # (NOTE(review): confirm for the target runtime).
        div = g.op("Div", self, other)

        # The exact quotient is negative iff exactly one operand is negative.
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        negative = g.op(
            "Xor",
            symbolic_helper._lt_helper(g, self, zero),
            symbolic_helper._lt_helper(g, other, zero),
        )

        # Remainder of the truncating division. When it is non-zero and the
        # quotient is negative, truncation rounded up, so 1 must be subtracted.
        mod = g.op("Sub", self, g.op("Mul", div, other))
        fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))

        # Scale the boolean mask by an int64 one to obtain the per-element
        # correction, then apply it.
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        fixup = g.op("Mul", fixup_mask, one)
        return g.op("Sub", div, fixup)
|
|
|
|
|
@_onnx_symbolic("aten::floor_divide")
def floor_divide(g: jit_utils.GraphContext, self, other):
    """Export ``aten::floor_divide``.

    This keeps the legacy semantics: it lowers to ``_trunc_divide`` (rounding
    toward zero), not ``_floor_divide``.
    NOTE(review): confirm this still matches eager ``torch.floor_divide`` for
    operands of opposite sign.
    """
    truncated = _trunc_divide(g, self, other)
    return truncated
|
|
|
|
|
@_onnx_symbolic("aten::floordiv")
def floordiv(g: jit_utils.GraphContext, self, other):
    """Export ``aten::floordiv`` with the same lowering as ``floor_divide``."""
    quotient = floor_divide(g, self, other)
    return quotient
|
|
|
|
|
@_onnx_symbolic("aten::true_divide")
def true_divide(g: jit_utils.GraphContext, self, other):
    """Division whose result is always floating point.

    * If either input is floating point, a plain ONNX ``Div`` is emitted.
      Mixed input types are left as-is, presumably to be reconciled by a later
      type-analysis pass.
    * Otherwise both inputs are cast to the ONNX counterpart of
      ``torch.get_default_dtype()`` (FLOAT or DOUBLE) before ``Div``.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Div", self, other)

    default_dtype = torch.get_default_dtype()
    assert default_dtype is torch.float or default_dtype is torch.double
    if default_dtype is torch.double:
        target = _C_onnx.TensorProtoDataType.DOUBLE
    else:
        target = _C_onnx.TensorProtoDataType.FLOAT

    numerator = g.op("Cast", self, to_i=target)
    denominator = g.op("Cast", other, to_i=target)
    return g.op("Div", numerator, denominator)
|
|
|
|
|
@_onnx_symbolic("aten::reciprocal")
def reciprocal(g: jit_utils.GraphContext, self):
    """Export ``aten::reciprocal`` as ONNX ``Reciprocal``.

    Non-floating inputs are cast to FLOAT first, mirroring torch.reciprocal,
    which produces a floating result for them.
    """
    operand = self
    if not symbolic_helper._is_fp(operand):
        operand = g.op("Cast", operand, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return g.op("Reciprocal", operand)
|
|
|
|
|
@_onnx_symbolic("aten::cat")
@symbolic_helper.parse_args("v", "i")
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension.

    Constant inputs whose ``_get_tensor_dim_size(t, 0)`` is falsy are dropped
    before ``Concat`` is emitted. The node producing ``tensor_list`` is rewired
    in place so that it references only the remaining inputs.

    Parameters:
        g (jit_utils.GraphContext): Graph context.
        tensor_list (List[torch.Tensor]): List of tensors to concatenate.
        dim (int): Dimension along which to concatenate the tensors.

    Returns:
        ONNX graph node representing the concatenated tensor.
    """
    tensors = symbolic_helper._unpack_list(tensor_list)
    # Skip empty constants (e.g. ``torch.tensor([])``), which torch.cat ignores.
    # Keeping them could make ONNX shape inference see inputs of different
    # ranks.
    nonempty_tensors = []
    for t in tensors:
        if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size(
            t, 0
        ):
            continue
        nonempty_tensors.append(t)
    assert len(nonempty_tensors) > 0
    # Whenever both ranks are known, each kept input must match the first one.
    assert all(
        symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None
        or symbolic_helper._get_tensor_rank(t) is None
        or symbolic_helper._get_tensor_rank(t)
        == symbolic_helper._get_tensor_rank(nonempty_tensors[0])
        for t in nonempty_tensors
    )
    # Rewrite the list node so it only references the kept inputs.
    tensor_list.node().removeAllInputs()
    for t in nonempty_tensors:
        tensor_list.node().addInput(t)

    tensors = symbolic_helper._unpack_list(tensor_list)
    return g.op("Concat", *tensors, axis_i=dim)
|
|
|
|
|
@_onnx_symbolic("aten::stack")
@symbolic_helper.parse_args("v", "i")
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::stack``: unsqueeze each input at ``dim``, then ``Concat`` along ``dim``."""
    pieces = []
    for tensor in symbolic_helper._unpack_list(tensor_list):
        pieces.append(symbolic_helper._unsqueeze_helper(g, tensor, [dim]))
    return g.op("Concat", *pieces, axis_i=dim)
|
|
|
|
|
@_onnx_symbolic("aten::list")
def _list(g: jit_utils.GraphContext, self):
    """Export ``aten::list`` as the identity: the input value is returned unchanged."""
    return self
|
|
|
|
|
@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mm`` as ONNX ``Gemm`` with ``alpha=1`` and ``beta=0``.

    ``Gemm`` takes a third ``C`` operand. Because ``beta_f=0.0``, its value
    does not contribute, so a constant ``[1]`` is passed as a placeholder.
    """
    placeholder_c = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, placeholder_c, beta_f=0.0, alpha_f=1.0)
|
|
|
|
|
@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::bmm`` (batched matrix product) as ONNX ``MatMul``."""
    product = g.op("MatMul", self, other)
    return product
|
|
|
|
|
@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    """Export ``aten::matmul`` as ONNX ``MatMul``."""
    product = g.op("MatMul", self, other)
    return product
|
|
|
|
|
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    """Export ``aten::addmm``: ``beta * self + alpha * (mat1 @ mat2)``.

    ONNX ``Gemm`` is emitted by default. When a scalar type is known and either
    matrix has a known rank other than 2, the op is decomposed into ``MatMul``,
    optional ``Mul`` scalings and ``Add``.

    Args:
        g: Graph context.
        self: Tensor added to the scaled matrix product.
        mat1: Left operand of the matrix product.
        mat2: Right operand of the matrix product.
        beta: Constant multiplier for ``self`` (parsed with ``"t"``).
        alpha: Constant multiplier for the product (parsed with ``"t"``).

    Returns:
        The output of the emitted ``Gemm`` or ``Add`` node.
    """
    # Use the first known scalar type among self, mat1 and mat2.
    scalar_type = None
    for candidate in (
        symbolic_helper._try_get_scalar_type(self),
        symbolic_helper._try_get_scalar_type(mat1),
        symbolic_helper._try_get_scalar_type(mat2),
    ):
        if candidate is not None:
            scalar_type = candidate
            break

    mat1_rank = symbolic_helper._get_tensor_rank(mat1)
    mat2_rank = symbolic_helper._get_tensor_rank(mat2)

    def is_not_none_nor(v, u):
        return v is not None and v != u

    if scalar_type is not None and (
        is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2)
    ):
        res1 = g.op("MatMul", mat1, mat2)
        res2 = self

        alpha_value = symbolic_helper._scalar(alpha)
        beta_value = symbolic_helper._scalar(beta)

        if alpha_value != 1:
            alpha_const = g.op(
                "Constant",
                value_t=torch.tensor(alpha_value, dtype=scalar_type.dtype()),
            )
            res1 = g.op("Mul", res1, alpha_const)
        if beta_value != 1:
            # ``beta_value`` has already been converted by ``_scalar`` above.
            # Use it directly, as the alpha branch does, instead of converting
            # the already-converted number a second time.
            beta_const = g.op(
                "Constant",
                value_t=torch.tensor(beta_value, dtype=scalar_type.dtype()),
            )
            res2 = g.op("Mul", res2, beta_const)

        return g.op("Add", res1, res2)

    return g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )
|
|
|
|
|
@_onnx_symbolic("aten::neg")
def neg(g: jit_utils.GraphContext, self):
    """Export ``aten::neg`` as ONNX ``Neg``."""
    negated = g.op("Neg", self)
    return negated
|
|
|
|
|
@_onnx_symbolic("aten::sqrt")
def sqrt(g: jit_utils.GraphContext, self):
    """Export ``aten::sqrt`` as ONNX ``Sqrt``.

    Inputs of scalar type UINT8, INT8, INT16, INT or INT64 are cast to FLOAT
    first, mirroring torch, which produces a floating result for integer sqrt.
    """
    input_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    integer_types = {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT16,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT64,
    }
    operand = self
    if input_type in integer_types:
        operand = g.op("Cast", operand, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return g.op("Sqrt", operand)
|
|
|
|
|
@_onnx_symbolic("aten::rsqrt")
def rsqrt(g: jit_utils.GraphContext, self):
    """Export ``aten::rsqrt`` as ``1 / sqrt(self)``.

    The numerator is ``torch.ones(1)`` passed through
    ``symbolic_helper._if_scalar_type_as`` together with ``self``, presumably
    to match ``self``'s scalar type.
    """
    numerator = symbolic_helper._if_scalar_type_as(torch.ones(1), self)
    root = sqrt(g, self)
    return g.op("Div", numerator, root)
|
|
|
|
|
@_onnx_symbolic("aten::tanh")
# Quantized inputs use fixed output quantization parameters
# (scale 2/256, zero_point 128).
@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128)
def tanh(g: jit_utils.GraphContext, self):
    """Export ``aten::tanh`` as ONNX ``Tanh``."""
    result = g.op("Tanh", self)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::sin")
def sin(g: jit_utils.GraphContext, self):
    """Export ``aten::sin`` as ONNX ``Sin``."""
    result = g.op("Sin", self)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::cos")
def cos(g: jit_utils.GraphContext, self):
    """Export ``aten::cos`` as ONNX ``Cos``."""
    result = g.op("Cos", self)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::tan")
def tan(g: jit_utils.GraphContext, self):
    """Export ``aten::tan`` as ONNX ``Tan``."""
    result = g.op("Tan", self)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::asin")
def asin(g: jit_utils.GraphContext, self):
    """Export ``aten::asin`` as ONNX ``Asin``."""
    result = g.op("Asin", self)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::acos")
def acos(g: jit_utils.GraphContext, self):
    """Export ``aten::acos`` as ONNX ``Acos``."""
    result = g.op("Acos", self)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::atan")
def atan(g: jit_utils.GraphContext, self):
    """Export ``aten::atan`` as ONNX ``Atan``."""
    result = g.op("Atan", self)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::atan2")
def atan2(g: jit_utils.GraphContext, self, other):
    """Export ``aten::atan2(self, other)``, where ``self`` is y and ``other`` is x.

    Computes ``atan(y / x)`` and, where ``x < 0``, corrects the quadrant by
    adding ``pi`` when ``y >= 0`` and subtracting ``pi`` when ``y < 0``.

    Signed zeros are not distinguished: ``y == -0.0`` with ``x < 0`` yields
    ``+pi``, whereas torch returns ``-pi`` in that case.
    """
    slope = g.op("Div", self, other)
    base_angle = g.op("Atan", slope)
    const_zero = g.op("Constant", value_t=torch.tensor(0))
    const_pi = g.op("Constant", value_t=torch.tensor(math.pi))

    # y >= 0, written as Not(y < 0) because opset 9 has no GreaterOrEqual.
    # A strict ``Greater`` would send y == 0, x < 0 to -pi, while torch.atan2
    # returns +pi there.
    upper_half_plane = g.op("Not", g.op("Less", self, const_zero))
    corrected = g.op(
        "Where",
        upper_half_plane,
        g.op("Add", base_angle, const_pi),
        g.op("Sub", base_angle, const_pi),
    )

    # Only the left half-plane (x < 0) needs the +/- pi correction.
    left_half_plane = g.op("Less", other, const_zero)
    return g.op("Where", left_half_plane, corrected, base_angle)
|
|
|
|
|
@_onnx_symbolic("aten::sigmoid")
# Quantized inputs use fixed output quantization parameters
# (scale 1/256, zero_point 0).
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
def sigmoid(g: jit_utils.GraphContext, self):
    """Export ``aten::sigmoid`` as ONNX ``Sigmoid``.

    Not meant to be called directly by users.

    Args:
        g (jit_utils.GraphContext): Graph context.
        self (Tensor): The input tensor.

    Returns:
        The output of the ONNX ``Sigmoid`` node.
    """
    activated = g.op("Sigmoid", self)
    return activated
|
|
|
|
|
@_onnx_symbolic("aten::sign")
def sign(g: jit_utils.GraphContext, self):
    """Export ``aten::sign`` as ONNX ``Sign``."""
    result = g.op("Sign", self)
    return result
|
|
|
|
|
@symbolic_helper.quantized_args(True)
def _slice(g: jit_utils.GraphContext, input, axes, starts, ends):
    """Emit an ONNX ``Slice`` with static ``axes``/``starts``/``ends`` attributes.

    A single-axis slice from 0 to ``_constants.INT64_MAX`` covers the whole
    axis, so ``input`` is returned as-is in that case.
    """
    assert len(starts) == len(ends)
    covers_whole_axis = (
        len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX
    )
    if not covers_whole_axis:
        return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
    return input
|
|
|
|
|
@_onnx_symbolic(
    "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")]
)
@_onnx_symbolic(
    "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
# prod is registered with multi-dim reduction support disabled.
@_onnx_symbolic(
    "aten::prod",
    decorate=[
        symbolic_helper._apply_params(
            "ReduceProd", "prod", allow_multi_dim_support=False
        )
    ],
)
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
    """Shared entry point for the ``aten::sum`` / ``aten::mean`` / ``aten::prod`` symbolics.

    The registrations above bind ``onnx_op`` and ``name`` via
    ``_apply_params``. This function returns whatever
    ``symbolic_helper._reduce_with_dtype_helper`` builds from them.
    """
    built = symbolic_helper._reduce_with_dtype_helper(
        onnx_op, name, allow_multi_dim_support
    )
    return built
|
|
|
|
|
@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
    """``aten::cumsum`` cannot be exported at opset 9.

    Reports the op through ``symbolic_helper._onnx_opset_unsupported`` with
    opsets 9 and 11 (presumably the current opset and the first one that
    supports it). The helper's result is not returned.
    """
    symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)
|
|
|
|
|
@_onnx_symbolic("aten::_sample_dirichlet")
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator):
    """``aten::_sample_dirichlet`` has no ONNX export; report it as unsupported."""
    op_name = "_sample_dirichlet"
    return symbolic_helper._onnx_unsupported(op_name, self)
|
|
|
|
|
@_onnx_symbolic("aten::_standard_gamma")
def _standard_gamma(g: jit_utils.GraphContext, self, generator):
    """``aten::_standard_gamma`` has no ONNX export; report it as unsupported."""
    op_name = "_standard_gamma"
    return symbolic_helper._onnx_unsupported(op_name, self)
|
|
|
|
|
@_onnx_symbolic("aten::t")
def t(g: jit_utils.GraphContext, self):
    """Export ``aten::t`` (2-D transpose).

    For rank >= 2 this emits ``Transpose`` with ``perm=(1, 0)``. For rank 0,
    rank 1 or unknown rank, the transpose is the input itself, so an
    ``Identity`` node is emitted to mirror eager mode explicitly.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None and rank >= 2:
        return g.op("Transpose", self, perm_i=(1, 0))
    return g.op("Identity", self)
|
|
|
|
|
@_onnx_symbolic("aten::numpy_T")
@symbolic_helper.quantized_args(True)
def numpy_T(g: jit_utils.GraphContext, input):
    """Export ``aten::numpy_T``: reverse the order of all dimensions (requires a known rank)."""
    rank = symbolic_helper._get_tensor_rank(input)
    assert rank is not None
    reversed_axes = list(range(rank - 1, -1, -1))
    return g.op("Transpose", input, perm_i=reversed_axes)
|
|
|
|
|
@_onnx_symbolic("aten::expand")
@symbolic_helper.quantized_args(True)
def expand(g: jit_utils.GraphContext, self, size, implicit):
    """Implement the expand function for a pytorch tensor in ONNX according to specified `size`"""
    size = symbolic_helper._maybe_get_const(size, "is")
    if symbolic_helper._is_value(size):
        if symbolic_helper._is_packed_list(size):
            # Stack the packed list of sizes into a 1-D shape tensor.
            stacked = stack(g, size, 0)
            size = symbolic_helper._reshape_helper(
                g, stacked, g.op("Constant", value_t=torch.tensor([-1]))
            )
    else:
        size = g.op("Constant", value_t=torch.LongTensor(size))

    # A -1 entry means "keep this dimension". ONNX Expand broadcasts both
    # ways, so replacing -1 with 1 gives the same result.
    ones = ones_like(g, size, _type_utils.JitScalarType.INT64)
    minus_one = g.op("Constant", value_t=torch.tensor(-1))
    neg_ones = mul(g, ones, minus_one)
    is_minus_one = g.op("Equal", size, neg_ones)
    size = where(g, is_minus_one, ones, size)
    return g.op("Expand", self, size)
|
|
|
|
|
@_onnx_symbolic("aten::broadcast_to")
@symbolic_helper.quantized_args(True)
def broadcast_to(g: jit_utils.GraphContext, self, size):
    """Export ``aten::broadcast_to`` as ONNX ``Expand`` (same lowering as ``aten::expand``)."""
    size = symbolic_helper._maybe_get_const(size, "is")
    if symbolic_helper._is_value(size):
        if symbolic_helper._is_packed_list(size):
            # Stack the packed list of sizes into a 1-D shape tensor.
            stacked = stack(g, size, 0)
            size = symbolic_helper._reshape_helper(
                g, stacked, g.op("Constant", value_t=torch.tensor([-1]))
            )
    else:
        size = g.op("Constant", value_t=torch.LongTensor(size))

    # A -1 entry means "keep this dimension". ONNX Expand broadcasts both
    # ways, so replacing -1 with 1 gives the same result.
    ones = ones_like(g, size, _type_utils.JitScalarType.INT64)
    minus_one = g.op("Constant", value_t=torch.tensor(-1))
    neg_ones = mul(g, ones, minus_one)
    is_minus_one = g.op("Equal", size, neg_ones)
    size = where(g, is_minus_one, ones, size)
    return g.op("Expand", self, size)
|
|
|
|
|
@_onnx_symbolic("aten::expand_as")
@symbolic_helper.quantized_args(True, True)
def expand_as(g: jit_utils.GraphContext, self, other):
    """Export ``aten::expand_as`` as ONNX ``Expand`` to the runtime shape of ``other``.

    When ``self`` is a constant tensor, every dimension along which all of its
    slices are identical is collapsed to size 1 (via a mean). A smaller
    constant is then embedded in the graph, and ``Expand`` broadcasts it back.
    Non-constant ``self`` is expanded as-is.
    """
    self_t = symbolic_helper._maybe_get_const(self, "t")
    if isinstance(self_t, torch.Tensor):
        orig_type = self_t.dtype
        self_t = self_t.to(torch.double)
        dims = []
        for d in range(self_t.dim()):
            # Dimension d is redundant if broadcasting the mean over d
            # reproduces the tensor exactly.
            if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t):
                dims.append(d)
        # Rewrite the constant only when some dimension can be collapsed.
        # ``mean`` with an empty ``dim`` list reduces over all dimensions,
        # which would replace a non-uniform constant with its global mean.
        if dims:
            self = g.op(
                "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type)
            )

    shape = g.op("Shape", other)
    return g.op("Expand", self, shape)
|
|
|
|
|
@_onnx_symbolic("aten::embedding")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i", "b", "v")
def embedding(
    g: jit_utils.GraphContext,
    weight,
    indices,
    padding_idx,
    scale_grad_by_freq,
    sparse,
):
    """Export ``aten::embedding`` as an ONNX ``Gather`` on ``weight`` indexed by ``indices``.

    Gradient-only options cannot be represented in a training-mode export:
    ``scale_grad_by_freq`` raises, and ``padding_idx >= 0`` only warns.

    Raises:
        errors.SymbolicValueError: If ``scale_grad_by_freq`` is set while
            exporting in training mode.
    """
    training_export = GLOBALS.export_training
    if scale_grad_by_freq and training_export:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
            "for training mode. ONNX does not support scaling the gradients.",
            weight,
        )
    if padding_idx >= 0 and training_export:
        warnings.warn(
            "Warning: ONNX export of embedding with padding_idx >= 0 "
            "for training mode. "
            "ONNX does not support not updating the embedding vector at padding_idx during training."
        )

    gathered = g.op("Gather", weight, indices)
    return gathered
|
|
|
|
|
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """``aten::embedding_bag`` cannot be exported at opset 9.

    Always reports the op through ``symbolic_helper._onnx_unsupported``, with
    a more specific name when ``per_sample_weights`` is given. Both reports
    pass ``embedding_matrix`` so the error carries the same value context.
    """
    if not symbolic_helper._is_none(per_sample_weights):
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with per_sample_weights", embedding_matrix
        )

    return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix)
|
|
|
|
|
@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
def size(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::size``.

    Without ``dim`` this is the full ONNX ``Shape``. With ``dim``, a negative
    constant index is normalized by the tensor rank (when known) before
    ``symbolic_helper._size_helper`` is called.
    """
    if dim is None:
        return g.op("Shape", self)
    dim_value = symbolic_helper._maybe_get_const(dim, "i")
    if dim_value < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            dim = g.op("Constant", value_t=torch.tensor(dim_value + rank))
    return symbolic_helper._size_helper(g, self, dim)
|
|
|
|
|
@_onnx_symbolic("aten::transpose")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "i")
def transpose(g: jit_utils.GraphContext, self, dim0, dim1):
    """Export ``aten::transpose`` (swap two dimensions) as an ONNX ``Transpose`` permutation.

    Negative ``dim0``/``dim1`` index the permutation list from the end.

    Raises:
        errors.SymbolicValueError: If the input rank is unknown.
    """
    if dim0 == dim1:
        # Swapping a dimension with itself is a no-op.
        return self

    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of transpose for tensor of unknown rank.",
            self,
        )
    # ONNX Transpose takes a full permutation, not a pair of axes.
    perm = list(range(rank))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return g.op("Transpose", self, perm_i=perm)
|
|
|
|
|
@_onnx_symbolic("aten::permute")
@symbolic_helper.parse_args("v", "is")
def permute(g: jit_utils.GraphContext, self, dims):
    """Export ``aten::permute`` as ONNX ``Transpose``; the identity permutation is skipped."""
    identity = list(range(len(dims)))
    if dims != identity:
        return g.op("Transpose", self, perm_i=dims)
    return self
|
|
|
|
|
@_onnx_symbolic("aten::view")
@symbolic_helper.quantized_args(True)
def view(g: jit_utils.GraphContext, self, size):
    """Export ``aten::view`` with the same lowering as ``aten::reshape``."""
    reshaped = reshape(g, self, size)
    return reshaped
|
|
|
|
|
@_onnx_symbolic("aten::view_as")
def view_as(g: jit_utils.GraphContext, self, other):
    """Reshape ``self`` to the runtime shape of ``other``, read via ONNX ``Shape``."""
    return reshape(g, self, g.op("Shape", other))
|
|
|
|
|
@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    """Export ``aten::unsafe_chunk`` as a static ONNX ``Split`` along ``dim``.

    Pieces have length ``(size + chunks - 1) // chunks``, with one shorter
    trailing piece for any remainder. Both the static output count and the
    length of ``dim`` must be known.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self
        )
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented(
            "unsafe_chunk", "unknown dimension size", self
        )
    piece = (size + chunks - 1) // chunks
    full_pieces, remainder = divmod(size, piece)
    sizes = [piece] * full_pieces
    if remainder:
        sizes.append(remainder)
    return g.op("Split", self, split_i=sizes, axis_i=dim, outputs=_outputs)
|
|
|
|
|
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    """Export ``aten::split`` as a static ONNX ``Split`` along ``dim``.

    A list of sizes is forwarded to ``split_with_sizes``. A single size yields
    equal pieces plus a shorter trailing piece for any remainder. If the length
    of ``dim`` is unknown, it is inferred as ``split_size * _outputs``.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split", 9, 11, "Dynamic number of outputs not supported", self
        )
    if symbolic_helper._node_get(split_size_or_sizes.node(), "value").dim() > 0:
        return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)

    piece = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
    total = symbolic_helper._get_tensor_dim_size(self, dim)
    if total is None:
        if _outputs is None:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "split", 9, 11, "Unknown dimension size not supported", self
            )
        total = piece * _outputs

    full_pieces, remainder = divmod(total, piece)
    sizes = [piece] * full_pieces
    if remainder:
        sizes.append(remainder)
    return g.op("Split", self, split_i=sizes, axis_i=dim, outputs=_outputs)
|
|
|
|
|
@_onnx_symbolic("aten::unsafe_split")
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split`` with the same lowering as ``aten::split``."""
    pieces = split(g, self, split_size_or_sizes, dim, _outputs)
    return pieces
|
|
|
|
|
@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "is", "i", "i")
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    """Export ``aten::split_with_sizes`` as ONNX ``Split`` with explicit piece sizes.

    The number of outputs must be static at this opset.
    """
    if symbolic_helper._is_split_static(split_sizes, _outputs):
        return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)
    return symbolic_helper._onnx_opset_unsupported_detailed(
        "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self
    )
|
|
|
|
|
@_onnx_symbolic("aten::unsafe_split_with_sizes")
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split_with_sizes`` with the same lowering as ``split_with_sizes``."""
    pieces = split_with_sizes(g, self, split_sizes, dim, _outputs)
    return pieces
|
|
|
|
|
@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    """Export ``aten::unbind``: ``Split`` into size-1 pieces along ``dim``, then squeeze each piece.

    Requires the static output count ``_outputs``; without it the op is
    reported as unsupported at this opset.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unbind", 9, 11, "Dynamic number of outputs not supported", self
        )

    pieces = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs)
    if _outputs == 1:
        # With a single output g.op appears to return a bare value; wrap it so
        # the loop below can iterate uniformly.
        pieces = [pieces]
    return [symbolic_helper._squeeze_helper(g, piece, [dim]) for piece in pieces]
|
|
|
|
|
@_onnx_symbolic("aten::select")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
def select(g: jit_utils.GraphContext, self, dim, index):
    """Implement the select functionality for a pytorch tensor in ONNX.

    Picks the slice at ``index`` along ``dim`` and removes that dimension.
    A constant negative index is exported as ``Slice`` + ``Squeeze``; any
    other index (non-negative constant or runtime value) uses ``Gather``.
    """
    index = symbolic_helper._maybe_get_scalar(index)
    negative_constant = (not symbolic_helper._is_value(index)) and (index < 0)
    if not negative_constant:
        return g.op("Gather", self, index, axis_i=dim)

    # The slice ends one past the index; for -1 that means the end of the axis.
    end_index = _constants.INT64_MAX if index == -1 else index + 1
    slice_node = symbolic_helper._slice_helper(
        g, self, axes=[dim], starts=[index], ends=[end_index]
    )
    return symbolic_helper._squeeze_helper(g, slice_node, [dim])
|
|
|
|
|
@_onnx_symbolic("aten::square")
def square(g: jit_utils.GraphContext, self):
    """Export ``aten::square`` as ONNX ``Mul`` of the input with itself."""
    base = self
    return g.op("Mul", base, base)
|
|
|
|
|
@_onnx_symbolic("aten::squeeze")
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::squeeze`` for opset 9.

    With ``dim=None`` a plain ``Squeeze`` node (no axes) is emitted. Otherwise
    ``dim`` must be an export-time constant integer; a negative value is
    resolved against the input rank known at export time. Because the emitted
    ``Squeeze`` fails at runtime on a non-singleton axis (per the warning text
    below), the export decision is based on the export-time size of ``dim``:

    * size unknown: emit ``Squeeze`` on that axis (with a warning);
    * size > 1: emit nothing and return ``self`` unchanged (with a warning);
    * otherwise: emit ``Squeeze`` on that axis (with a warning).

    Returns:
        The squeezed value, ``self`` itself, or the result of
        ``symbolic_helper._unimplemented`` for a negative ``dim`` on an input
        of unknown rank.
    """
    if dim is None:
        return g.op("Squeeze", self)

    squeeze_dim = symbolic_helper._get_const(dim, "i", "dim")

    # Turn a negative axis into a non-negative one using the export-time rank;
    # the resulting graph is only correct for inputs of that same rank.
    if squeeze_dim < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            warnings.warn(
                "ONNX export squeeze with negative axis "
                + str(squeeze_dim)
                + " might cause the onnx model to be incorrect. "
                + "Negative axis is not supported in ONNX. "
                + "Axis is converted to "
                + str(squeeze_dim + rank)
                + " based on input shape at export time. "
                + "Passing an tensor of different rank in execution will be incorrect."
            )
            squeeze_dim += rank
        else:
            return symbolic_helper._unimplemented(
                "squeeze", "negative axis with unknown input rank", self
            )

    dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim)
    # Size of the axis unknown at export time: squeeze optimistically.
    if dim_size is None:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(squeeze_dim)
            + " on an input "
            + "with unknown shape. Note that if the size of dimension "
            + str(squeeze_dim)
            + " of the input "
            + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on "
            + "non-singleton dimensions, it is recommended to export this model using opset "
            + "version 11 or higher."
        )
        return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])
    # Non-singleton axis at export time: drop the squeeze entirely (presumably
    # mirroring torch.squeeze leaving such axes untouched). This is only valid
    # for static shapes.
    if dim_size > 1:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(squeeze_dim)
            + ". The size of "
            + "this dimension in the given input is "
            + str(dim_size)
            + ". The model will "
            + "be exported without the squeeze node. If the model is intended to be used with dynamic "
            + "input shapes, please use opset version 11 to "
            + "export the model."
        )
        return self

    # Remaining case: the export-time size of the axis is not > 1, so squeeze it.
    warnings.warn(
        "This model contains a squeeze operation on dimension "
        + str(squeeze_dim)
        + ". If the model is "
        + "intended to be used with dynamic input shapes, please use opset version 11 to export the model."
    )
    return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])
|
|
|
|
|
@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    """Export ``aten::prelu`` to ONNX ``PRelu``.

    The slope tensor ``weight`` is reshaped so that it broadcasts against
    ``self``:

    * ``self`` rank > 2: singleton axes ``1 .. rank-2`` are inserted into
      ``weight`` (presumably turning a per-channel ``[C]`` slope into
      ``[C, 1, ..., 1]`` — TODO confirm for non-1-D slopes);
    * ``self`` rank 0 with a ``[1]``-shaped weight: the weight is squeezed to
      a scalar.

    When either rank is unknown at export time, the corresponding adjustment
    and the rank sanity check are skipped instead of failing.

    Args:
        g: graph context used to emit ONNX nodes.
        self: input tensor.
        weight: slope tensor.

    Returns:
        The ``PRelu`` output value.
    """
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    # _get_tensor_sizes can return None when the shape is unknown. Calling
    # len() on it unconditionally raised TypeError and made the
    # ``weight_rank is not None`` guard below unreachable.
    weight_rank = len(weight_sizes) if weight_sizes is not None else None
    if self_rank is not None:
        if self_rank > 2:
            # Make weight unidirectionally broadcastable to self.
            weight = symbolic_helper._unsqueeze_helper(
                g, weight, list(range(1, self_rank - 1))
            )
        elif self_rank == 0 and weight_sizes == [1]:
            # self is a scalar but weight has rank 1: squeeze weight too.
            weight = symbolic_helper._squeeze_helper(g, weight, [0])
            weight_rank = 0

    if self_rank is not None and weight_rank is not None:
        assert self_rank >= weight_rank, (
            f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}"
        )
    return g.op("PRelu", self, weight)
|
|
|
|
|
@_onnx_symbolic("aten::silu")
def silu(g: jit_utils.GraphContext, input):
    """Export ``aten::silu`` as ``input * sigmoid(input)``."""
    gate = g.op("Sigmoid", input)
    return g.op("Mul", input, gate)
|
|
|
|
|
@_onnx_symbolic("aten::mish")
def mish(g: jit_utils.GraphContext, input):
    """Export ``aten::mish`` as ``input * tanh(softplus(input))``."""
    softplus_out = g.op("Softplus", input)
    tanh_out = g.op("Tanh", softplus_out)
    return g.op("Mul", input, tanh_out)
|
|
|
|
|
@_onnx_symbolic("aten::relu")
@symbolic_helper.quantized_args(True)
def relu(g: jit_utils.GraphContext, input):
    """Export ``aten::relu`` to ONNX ``Relu``.

    The op goes through ``_op_with_optional_float_cast`` with
    ``opset_before=14``, presumably so that non-float inputs are cast for the
    older ``Relu`` definitions — TODO confirm against the helper.
    """
    op_type = "Relu"
    return symbolic_helper._op_with_optional_float_cast(
        g, op_type, input, opset_before=14
    )
|
|
|
|
|
@_onnx_symbolic("aten::relu6")
@symbolic_helper.quantized_args(True)
def relu6(g: jit_utils.GraphContext, input):
    """Export ``aten::relu6`` as ``clamp(input, 0, 6)``."""
    lower, upper = 0, 6
    return clamp(g, input, lower, upper)
|
|
|
|
|
@_onnx_symbolic("aten::ceil")
def ceil(g: jit_utils.GraphContext, input):
    """Export ``aten::ceil`` to the ONNX ``Ceil`` operator."""
    rounded_up = g.op("Ceil", input)
    return rounded_up
|
|
|
|
|
@_onnx_symbolic("aten::floor")
def floor(g: jit_utils.GraphContext, input):
    """Export ``aten::floor`` to the ONNX ``Floor`` operator."""
    rounded_down = g.op("Floor", input)
    return rounded_down
|
|
|
|
|
@_onnx_symbolic("aten::len")
def _len(g: jit_utils.GraphContext, self):
    """Export ``aten::len`` as the size of axis 0, returned as a scalar."""
    axis_zero = g.op("Constant", value_t=torch.LongTensor([0]))
    leading_size = size(g, self, axis_zero)
    # size() yields a 1-element tensor here; squeeze it down to a scalar.
    return symbolic_helper._squeeze_helper(g, leading_size, [0])
|
|
|
|
|
@_onnx_symbolic("aten::threshold")
@symbolic_helper.parse_args("v", "t", "t")
def threshold(g: jit_utils.GraphContext, self, threshold, value):
    """Export ``aten::threshold`` for the only supported case.

    Only ``threshold == 0`` and ``value == 0`` are supported, which is exactly
    ``Relu``. Any other constant is reported via ``_unimplemented``. The
    threshold is checked before the value.
    """
    checks = (
        (threshold, "non-zero threshold"),
        (value, "non-zero value"),
    )
    for scalar_arg, reason in checks:
        if symbolic_helper._scalar(scalar_arg) != 0:
            return symbolic_helper._unimplemented("threshold", reason, self)
    return g.op("Relu", self)
|
|
|
|
|
@_onnx_symbolic("aten::leaky_relu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "b")
def leaky_relu(
    g: jit_utils.GraphContext,
    input: _C.Value,
    negative_slope: float,
    inplace: bool = False,
):
    """Export ``aten::leaky_relu`` to ONNX ``LeakyRelu``.

    ``inplace`` is accepted for signature compatibility but is not used: the
    emitted node is the same either way.
    """
    slope = negative_slope
    return g.op("LeakyRelu", input, alpha_f=slope)
|
|
|
|
|
@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
def glu(g: jit_utils.GraphContext, input, dim):
    """Export ``aten::glu``: split ``input`` in two halves along ``dim``.

    The result is ``first_half * sigmoid(second_half)``. When the size of
    ``dim`` is known at export time, it is asserted to be even.
    """
    dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
    assert dim_size is None or dim_size % 2 == 0

    first, second = g.op("Split", input, axis_i=dim, outputs=2)
    gate = g.op("Sigmoid", second)
    return g.op("Mul", first, gate)
|
|
|
|
|
@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::softmax`` for opset 9.

    Two strategies are used:

    * Rank known: the normalized axis is swapped to the last position with
      ``Transpose`` (if it is not already last), an ONNX ``Softmax`` is applied
      on that last axis, and the result is transposed back.
    * Rank unknown: softmax is built by hand as
      ``exp(x - max(x)) / sum(exp(x - max(x)))`` along ``dim``.

    The detour through the last axis presumably exists because pre-opset-13
    ONNX ``Softmax`` coerces its input to 2-D at ``axis`` (normalizing over all
    trailing axes), while PyTorch normalizes over a single axis — confirm
    against the ONNX operator spec.

    In both paths, an optional ``dtype`` adds a final ``Cast`` to that dtype.
    """
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is not None:
        # Normalize a negative axis to its non-negative equivalent.
        if dim < 0:
            dim = input_dim + dim

        is_transpose_required = input_dim != dim + 1

        if is_transpose_required:
            # A single swap of ``dim`` with the last axis; applying the same
            # permutation again undoes it.
            axes = list(range(input_dim))
            axes[dim], axes[-1] = axes[-1], axes[dim]
            input = g.op("Transpose", input, perm_i=axes)
            dim = input_dim - 1

        softmax = g.op("Softmax", input, axis_i=dim)
        # NOTE(review): the cast is skipped whenever dtype comes from a
        # prim::Constant node (e.g. a None default) — confirm this is intended
        # for constant non-None dtypes as well.
        if dtype and dtype.node().kind() != "prim::Constant":
            parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
            softmax = g.op(
                "Cast",
                softmax,
                to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(),
            )

        if is_transpose_required:
            softmax = g.op("Transpose", softmax, perm_i=axes)
        return softmax

    # Rank unknown: subtract the per-axis max for numerical stability, then
    # compute exp / sum(exp) along ``dim``.
    # NOTE(review): ``dim`` may still be negative here; confirm ReduceMax at
    # this opset accepts negative axes.
    input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1))

    exp = g.op("Exp", input)
    # NOTE: ``sum`` shadows the builtin within this function.
    sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim])
    softmax = g.op("Div", exp, sum)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        softmax = g.op(
            "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return softmax
|
|
|
|
|
@_onnx_symbolic("aten::softplus")
def softplus(g: jit_utils.GraphContext, self, beta, threshold):
    """Export ``aten::softplus``.

    With ``beta == 1`` this is a plain ``Softplus``. Otherwise it is computed
    as ``Softplus(self * beta) / beta``. ``threshold`` is accepted but not used
    by this export.
    """
    beta_value = symbolic_helper._maybe_get_const(beta, "f")
    if beta_value == 1:
        return g.op("Softplus", self)
    scaled = g.op("Mul", self, beta)
    return g.op("Div", g.op("Softplus", scaled), beta)
|
|
|
|
|
@_onnx_symbolic("aten::get_pool_ceil_padding")
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    """Compute the extra end-side padding that emulates ``ceil_mode=True``.

    ONNX pooling in this opset is exported without a ceil mode, so the pooling
    symbolics add this per-axis amount to the "end" pads instead.

    Args:
        input: the pooled tensor; its trailing ``len(padding)`` sizes must be
            known at export time.
        kernel_size: per-axis kernel extents.
        stride: per-axis strides.
        padding: per-axis symmetric padding already applied.

    Returns:
        A list of ints (one per spatial axis), or the result of
        ``symbolic_helper._unimplemented`` when the spatial sizes are unknown.
    """
    sizes = symbolic_helper._get_tensor_sizes(input)
    # Spatial extents: the last len(padding) axes.
    dim = sizes[-len(padding) :] if sizes is not None else None
    if dim is None or any(i is None for i in dim):
        return symbolic_helper._unimplemented(
            "get_pool_ceil_padding", "input size not accessible", input
        )
    # Output length per axis when rounding up.
    ceiled_output_dim = [
        int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i])))
        + 1
        for i in range(0, len(padding))
    ]
    # Drop the last window if it would start past the input plus the left
    # padding (i.e. entirely inside the right padding).
    ceiled_output_dim = [
        (
            ceiled_output_dim[i] - 1
            if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
            else ceiled_output_dim[i]
        )
        for i in range(0, len(ceiled_output_dim))
    ]
    # Extra padding needed so the last window fits; stride 1 never needs any.
    padding_ceil = [
        (
            0
            if (stride[i] == 1)
            else (
                kernel_size[i]
                - (
                    dim[i]
                    + 2 * padding[i]
                    - ((ceiled_output_dim[i] - 1) * stride[i] + 1)
                )
            )
        )
        for i in range(0, len(padding))
    ]
    # Clamp the extra padding to at most kernel_size - 1 when, together with
    # the existing padding, it would reach the kernel size.
    padding_ceil = [
        (
            (
                int(padding_ceil[i])
                if padding_ceil[i] < kernel_size[i] - 1
                else int(kernel_size[i] - 1)
            )
            if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
            else int(padding_ceil[i])
        )
        for i in range(0, len(padding_ceil))
    ]
    return padding_ceil
|
|
|
|
|
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        ),
        _export("max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        ),
        _export("max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        ),
        _export("max_pool3d"),
    ],
)
def _max_pool(name, tuple_fn, ndims, return_indices):
    """Build the ``MaxPool`` symbolic for ``max_pool{1,2,3}d[_with_indices]``.

    Args:
        name: ATen op name, used in diagnostics.
        tuple_fn: expands an int or sequence to an ``ndims``-tuple
            (``_single``/``_pair``/``_triple``).
        ndims: number of spatial axes.
        return_indices: if True, the symbolic returns ``(values, indices)``.

    Returns:
        The symbolic function. Dilation other than 1 is reported via
        ``_unimplemented``.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if set(tuple_fn(dilation)) != {1}:
            return symbolic_helper._unimplemented(name, "dilation", input)
        # An empty stride falls back to the kernel size.
        if not stride:
            stride = kernel_size
        # ONNX ``pads`` = all begin values followed by all end values.
        padding = tuple(tuple_fn(padding))
        if ceil_mode:
            # Emulate ceil mode by extending only the end-side padding.
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
        else:
            padding = padding * 2
        kwargs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": padding,
            "strides_i": tuple_fn(stride),
        }
        # Index conversion trick: ONNX MaxPool's indices are flattened over the
        # whole tensor (per the ONNX spec). A second MaxPool with a kernel and
        # stride of 1 yields, for every element, its own flattened index.
        # Slicing element 0 of every spatial axis from that gives the flattened
        # offset of each (N, C) plane's origin; subtracting it converts the
        # indices to offsets within each plane (presumably PyTorch's format).
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op(
                "MaxPool",
                input,
                outputs=2,
                kernel_shape_i=[1 for _ in range(ndims)],
                strides_i=[1 for _ in range(ndims)],
            )

            s = symbolic_helper._slice_helper(
                g,
                flattened_indices,
                axes=[2 + i for i in range(ndims)],
                starts=list(tuple_fn(0)),
                ends=list(tuple_fn(1)),
            )
            indices = sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r

    return symbolic_fn
|
|
|
|
|
# ``*_with_indices`` variants: the same MaxPool export as above, but the
# symbolic returns ``(values, indices)``. They are kept as module-level names
# because the adaptive max-pool registrations further down reuse them.
max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")(
    _max_pool(
        "max_pool1d_with_indices",
        torch.nn.modules.utils._single,
        1,
        return_indices=True,
    )
)
max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")(
    _max_pool(
        "max_pool2d_with_indices",
        torch.nn.modules.utils._pair,
        2,
        return_indices=True,
    )
)
max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")(
    _max_pool(
        "max_pool3d_with_indices",
        torch.nn.modules.utils._triple,
        3,
        return_indices=True,
    )
)
|
|
|
|
|
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[
        symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single),
        _export("avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[
        symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair),
        _export("avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[
        symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple),
        _export("avg_pool3d"),
    ],
)
def _avg_pool(name, tuple_fn):
    """Build the ``AveragePool`` symbolic for ``avg_pool{1,2,3}d``.

    Args:
        name: ATen op name, forwarded to ``_avgpool_helper`` for diagnostics.
        tuple_fn: expands an int or sequence to a per-axis tuple
            (``_single``/``_pair``/``_triple``).

    Returns:
        The symbolic function registered for the op.
    """

    @symbolic_helper.quantized_args(True)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: int | Sequence[int],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # An empty stride falls back to the kernel size.
        if not stride:
            stride = kernel_size
        # Normalizes padding to a tuple; divisor_override is passed along,
        # presumably so the helper can reject values it cannot export.
        padding = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, stride, divisor_override, name
        )
        assert isinstance(padding, tuple)
        adjusted_padding = padding

        # count_include_pad: zero-pad explicitly (the leading (0, 0) leaves the
        # N and C axes alone) so the padded elements are averaged in, then pool
        # with no further symmetric padding.
        if count_include_pad:
            input = symbolic_helper._op_with_optional_float_cast(
                g,
                "Pad",
                input,
                pads_i=((0,) * 2 + padding) * 2,
                mode_s="constant",
                value_f=0.0,
                opset_before=11,
            )
            adjusted_padding = (0,) * len(padding)
        # ONNX ``pads`` = begin values then end values; ceil mode only grows the
        # end side.
        # NOTE(review): with count_include_pad, ``input`` here is already padded
        # but the original ``padding`` is still passed to get_pool_ceil_padding —
        # confirm the ceil padding is not over-counted.
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            adjusted_padding = adjusted_padding + tuple(
                a + b for (a, b) in zip(padding_ceil, adjusted_padding)
            )
        else:
            adjusted_padding = adjusted_padding * 2
        output = g.op(
            "AveragePool",
            input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(stride),
            pads_i=adjusted_padding,
        )
        return output

    return symbolic_fn
|
|
|
|
|
@_onnx_symbolic(
    "aten::adaptive_avg_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
        ),
        _export("adaptive_avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
        ),
        _export("adaptive_avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
        ),
        _export("adaptive_avg_pool3d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool1d",
            "MaxPool",
            torch.nn.modules.utils._single,
            max_pool1d_with_indices,
        ),
        _export("adaptive_max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool2d",
            "MaxPool",
            torch.nn.modules.utils._pair,
            max_pool2d_with_indices,
        ),
        _export("adaptive_max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool3d",
            "MaxPool",
            torch.nn.modules.utils._triple,
            max_pool3d_with_indices,
        ),
        _export("adaptive_max_pool3d"),
    ],
)
def _adaptive_pool(name, type, tuple_fn, fn=None):
    """Build the symbolic for ``adaptive_{avg,max}_pool{1,2,3}d``.

    Supported cases:

    * ``output_size`` all ones: a global pool. ``GlobalAveragePool`` is used
      for averages. For max pooling, ``GlobalMaxPool`` is used only when the
      fixed-kernel path below is not possible, and ``None`` is returned for
      the indices.
    * each spatial input size divisible by the matching output size: a plain
      pool with kernel = stride = ``input_size // output_size``. For max
      pooling this goes through ``fn`` (a ``max_pool*_with_indices``
      symbolic) so that indices are produced.

    Args:
        name: ATen op name, used in diagnostics.
        type: ``"AveragePool"`` or ``"MaxPool"`` (shadows the builtin here).
        tuple_fn: expands a list to a per-axis tuple.
        fn: max-pool-with-indices symbolic; required when ``type == "MaxPool"``.

    Returns:
        The symbolic function.
    """

    @symbolic_helper.quantized_args(True, False)
    def symbolic_fn(g, input, output_size):
        output_size_value = output_size
        # output_size must be an export-time constant int list.
        try:
            output_size = symbolic_helper._parse_arg(output_size, "is")
        except Exception:
            # FIXME: broad catch; the exact exception raised by _parse_arg is
            # not visible here.
            return symbolic_helper._onnx_unsupported(
                "adaptive pooling, since output_size is not constant.", input
            )
        if output_size == [1] * len(output_size) and type == "AveragePool":
            return g.op("GlobalAveragePool", input)
        sizes = symbolic_helper._get_tensor_sizes(input)
        # Spatial sizes (everything after N and C); None if the shape is unknown.
        try:
            dim = sizes[2:]
        except Exception:
            # FIXME: broad catch; this is presumably meant for sizes being None.
            dim = None
        if dim is None or any(i is None for i in dim):
            if output_size == [1] * len(output_size):
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "input size not accessible", input
            )

        # Every input extent must be a multiple of its output extent.
        mod = [dim[i] % output_size[i] for i in range(0, len(dim))]
        if mod != [0] * len(mod):
            if output_size == [1] * len(output_size):
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "output size that are not factor of input size", output_size_value
            )
        k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]

        # Max pooling delegates to the with-indices symbolic:
        # (input, kernel, stride, padding=0, dilation=1, ceil_mode=False).
        if type == "MaxPool":
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))
        return output

    return symbolic_fn
|
|
|
|
|
def _prepare_onnx_paddings(dim: int, pad): |
|
"""Generate paddings in ONNX order based on pad in pytorch. |
|
Args: |
|
dim: the dimension of the tensor. |
|
pad: the paddings in pytorch. |
|
The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... |
|
""" |
|
|
|
|
|
|
|
|
|
paddings = list(pad[:]) + [0] * (dim * 2 - len(pad)) |
|
|
|
paddings = paddings[-2::-2] + paddings[-1::-2] |
|
return paddings |
|
|
|
|
|
def _convert_padding_node(input):
    """Resolve a padding argument to a Python list of ints when possible.

    A constant is returned as parsed by ``_maybe_get_const(..., "is")``. A
    packed list of values is unpacked, and each element must itself be a
    constant int. Otherwise the failure is reported through
    ``_onnx_opset_unsupported_detailed``, which points at opset 11. Any other
    non-constant value is returned unchanged.
    """
    padding = symbolic_helper._maybe_get_const(input, "is")
    is_packed_value = symbolic_helper._is_value(
        padding
    ) and symbolic_helper._is_packed_list(padding)
    if not is_packed_value:
        return padding

    elements = symbolic_helper._unpack_list(padding)
    try:
        return [symbolic_helper._get_const(v, "i", "padding") for v in elements]
    except Exception:
        # FIXME: broad catch; _get_const's exact exception type is not visible.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The sizes of the padding must be constant", input
        )
|
|
|
|
|
@_onnx_symbolic("aten::constant_pad_nd")
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
    """Export ``aten::constant_pad_nd`` to ONNX ``Pad`` in ``constant`` mode.

    In opset 9 both the fill ``value`` and the pad amounts must be export-time
    constants. A non-constant fill value is reported through
    ``_onnx_opset_unsupported_detailed``, which points at opset 11.
    """
    try:
        fill_value = symbolic_helper._get_const(value, "f", "value")
    except Exception:
        # FIXME: broad catch; _get_const's exact exception type is not visible.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The value for the padding must be constant", value
        )

    pad_amounts = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, pad_amounts)
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Pad",
        input,
        pads_i=onnx_pads,
        mode_s="constant",
        value_f=fill_value,
        opset_before=11,
    )
|
|
|
|
|
def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
    """Export circular (wrap-around) padding with ``Slice`` + ``Concat``.

    ``pad`` holds ``(left, right)`` pairs in PyTorch order (last dimension
    first). The pair at position ``idx`` counted from the END of ``pad`` is
    applied to axis ``2 + idx``. This matches PyTorch's ordering when the
    input rank is ``2 + len(pad) // 2`` — TODO confirm callers guarantee that.

    For each axis the output is the concatenation of:

    * the last ``pad_l`` elements, when ``pad_l > 0``;
    * the input cropped by any negative amounts on either side;
    * the first ``pad_r`` elements, when ``pad_r > 0``.

    Args:
        g: graph context used to emit ONNX nodes.
        input: tensor to pad.
        pad: padding amounts, either constant or a packed list of constants.

    Returns:
        The padded tensor value.
    """
    padding = _convert_padding_node(pad)
    assert len(padding) % 2 == 0
    ndim = len(padding) // 2

    cur = input
    for idx in range(ndim):
        pad_r = padding[-(2 * idx + 1)]
        pad_l = padding[-(2 * idx + 2)]
        tensors = []
        # NOTE(review): the wrapped-around slices are taken from the uncropped
        # ``cur`` even when the opposite side is negative; confirm this matches
        # aten's circular pad for mixed-sign pads.
        if pad_l > 0:
            left = symbolic_helper._slice_helper(
                g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX]
            )
            tensors.append(left)

        if pad_l < 0 or pad_r < 0:
            # Crop negative amounts. A non-negative pad_r must leave the slice
            # open-ended. The former ``-(max(0, -pad_r))`` produced an end of 0
            # there, which made the middle slice empty (e.g. pad_l < 0, pad_r == 0).
            start = builtins.max(0, -pad_l)
            end = pad_r if pad_r < 0 else _constants.INT64_MAX
            middle = symbolic_helper._slice_helper(
                g,
                cur,
                axes=[2 + idx],
                starts=[start],
                ends=[end],
            )
            tensors.append(middle)
        else:
            tensors.append(cur)

        if pad_r > 0:
            right = symbolic_helper._slice_helper(
                g, cur, axes=[2 + idx], starts=[0], ends=[pad_r]
            )
            tensors.append(right)

        cur = g.op("Concat", *tensors, axis_i=(2 + idx))

    return cur
|
|
|
|
|
@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
def reflection_pad(g: jit_utils.GraphContext, input, padding):
    """Export ``aten::reflection_pad{1,2,3}d`` to ONNX ``Pad`` in ``reflect`` mode."""
    pad_amounts = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, pad_amounts)
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Pad",
        input,
        pads_i=onnx_pads,
        mode_s="reflect",
        opset_before=11,
    )
|
|
|
|
|
@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
def replication_pad(g: jit_utils.GraphContext, input, padding):
    """Export ``aten::replication_pad{1,2,3}d`` to ONNX ``Pad`` in ``edge`` mode."""
    pad_amounts = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, pad_amounts)
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Pad",
        input,
        pads_i=onnx_pads,
        mode_s="edge",
        opset_before=11,
    )
|
|
|
|
|
@_onnx_symbolic("aten::pad")
def pad(
    g: jit_utils.GraphContext,
    input: _C.Value,
    pad: _C.Value,
    mode: _C.Value,
    value: _C.Value,
):
    """Dispatch the generic ``aten::pad`` to the mode-specific exporter.

    Raises:
        errors.SymbolicValueError: if ``mode`` is not ``"replicate"``,
            ``"reflect"``, ``"constant"`` or ``"circular"``.
    """
    mode = symbolic_helper._parse_arg(mode, "s")
    if mode == "constant":
        # Only constant padding consumes the fill value.
        return constant_pad_nd(g, input, pad, value)

    value_free_exporters = (
        ("replicate", replication_pad),
        ("reflect", reflection_pad),
        ("circular", _pad_circular),
    )
    for mode_name, exporter in value_free_exporters:
        if mode == mode_name:
            return exporter(g, input, pad)

    raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)
|
|
|
|
|
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[
        symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"),
        _export("upsample_nearest1d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[
        symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"),
        _export("upsample_nearest2d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[
        symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"),
        _export("upsample_nearest3d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[
        symbolic_helper._apply_params("upsample_linear1d", 3, "linear"),
        _export("upsample_linear1d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[
        symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"),
        _export("upsample_bilinear2d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[
        symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"),
        _export("upsample_trilinear3d"),
    ],
)
def _interpolate(name: str, dim: int, interpolate_mode: str):
    """Build the opset-9 ``Upsample`` symbolic for an ``upsample_*`` op.

    Args:
        name: ATen op name, used in diagnostics.
        dim: input rank passed to ``_interpolate_size_to_scales``.
        interpolate_mode: ONNX ``Upsample`` mode (``"nearest"`` or ``"linear"``).

    Returns:
        The symbolic function. ``align_corners=True`` is reported via
        ``_unimplemented``.
    """

    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        if symbolic_helper._maybe_get_scalar(align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        # No explicit scales given: derive them from the requested output size.
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Upsample", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
|
|
|
|
|
@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` to ONNX ``Upsample``.

    Scales and the ONNX mode are resolved by
    ``_interpolate_get_scales_and_mode``. ``recompute_scale_factor`` and
    ``antialias`` are accepted but not used by this export.
    """
    scales, onnx_mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, scales, mode_s=onnx_mode)
|
|
|
|
|
@_onnx_symbolic("aten::bitwise_not")
def bitwise_not(g: jit_utils.GraphContext, input):
    """Export ``aten::bitwise_not`` as ONNX ``Not``; only bool inputs are supported.

    Raises:
        errors.SymbolicValueError: if ``input`` is not boolean.
    """
    if symbolic_helper._is_bool(input):
        return g.op("Not", input)
    raise errors.SymbolicValueError(
        "ONNX export does NOT support exporting bitwise Not "
        "for non-boolean input values",
        input,
    )
|
|
|
|
|
@_onnx_symbolic("aten::bitwise_or")
def bitwise_or(g, self, other):
    """Export ``aten::bitwise_or`` as ONNX ``Or``; both operands must be bool.

    Raises:
        errors.SymbolicValueError: naming the first non-boolean operand
            (``self`` is checked before ``other``).
    """
    for operand, label in ((self, "self"), (other, "other")):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise OR "
                f"for non-boolean input values. {label}: ",
                operand,
            )
    return g.op("Or", self, other)
|
|
|
|
|
def wrap_logical_op_with_cast_to(to_type):
    """Decorator factory: cast both operands with ``_cast_<to_type>`` first.

    The cast helper is looked up in this module's globals by name each time
    the wrapped symbolic runs, and is called as ``helper(g, value, False)``.
    """
    cast_helper_name = f"_cast_{to_type}"

    def decorator(fn):
        @functools.wraps(fn)
        def wrap_with_cast(g, input, other):
            cast = globals()[cast_helper_name]
            lhs = cast(g, input, False)
            rhs = cast(g, other, False)
            return fn(g, lhs, rhs)

        return wrap_with_cast

    return decorator
|
|
|
|
|
def wrap_logical_op_with_negation(func: Callable) -> Callable:
    """Decorator: wrap a binary symbolic's result in an ONNX ``Not`` node."""

    @functools.wraps(func)
    def wrap_with_not(g, input, other):
        positive = func(g, input, other)
        return g.op("Not", positive)

    return wrap_with_not
|
|
|
|
|
@_onnx_symbolic("aten::__not_")
def __not_(g: jit_utils.GraphContext, self):
    """Export ``aten::__not_`` as ONNX ``Not``; only bool inputs are supported.

    Raises:
        errors.SymbolicValueError: if ``self`` is not boolean.
    """
    if symbolic_helper._is_bool(self):
        return g.op("Not", self)
    raise errors.SymbolicValueError(
        "ONNX export does NOT support exporting bitwise Not "
        "for non-boolean input values",
        self,
    )
|
|
|
|
|
@_onnx_symbolic("aten::eq")
@symbolic_helper.quantized_args(True, True)
def eq(g: jit_utils.GraphContext, self, other):
    """Export ``aten::eq``.

    Two special cases are folded into a bool ``Constant`` at export time:

    * both operands are devices, which are always treated as equal;
    * both operands are ``onnx::Constant`` string values, which are compared
      directly.

    Everything else becomes an ONNX ``Equal``.
    """
    both_devices = isinstance(self.type(), _C.DeviceObjType) and isinstance(
        other.type(), _C.DeviceObjType
    )
    if both_devices:
        return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool))

    lhs_node, rhs_node = self.node(), other.node()
    both_string_constants = (
        lhs_node.kind() == rhs_node.kind() == "onnx::Constant"
        and lhs_node.kindOf("value") == rhs_node.kindOf("value") == "s"
    )
    if both_string_constants:
        same = lhs_node.s("value") == rhs_node.s("value")
        return g.op("Constant", value_t=torch.tensor(same, dtype=torch.bool))

    return g.op("Equal", self, other)
|
|
|
|
|
@_onnx_symbolic("aten::ne")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
def ne(g: jit_utils.GraphContext, self, other):
    """Export ``aten::ne`` as ``Not(eq(self, other))``.

    The negation is added by the decorator.
    """
    equal = eq(g, self, other)
    return equal
|
|
|
|
|
@_onnx_symbolic("aten::gt")
@symbolic_helper.quantized_args(True, True)
def gt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::gt`` through the shared ``Greater`` helper."""
    greater = _gt_impl(g, input, other)
    return greater
|
|
|
|
|
def _gt_impl(g: jit_utils.GraphContext, input, other):
    """Emit ``Greater(input, other)``.

    If both operands are bool, they are first cast to int32, presumably
    because ONNX ``Greater`` does not accept bool inputs.
    """
    if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other):
        int32 = _C_onnx.TensorProtoDataType.INT32
        input, other = (g.op("Cast", operand, to_i=int32) for operand in (input, other))
    return g.op("Greater", input, other)
|
|
|
|
|
@_onnx_symbolic("aten::lt")
@symbolic_helper.quantized_args(True, True)
def lt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::lt`` through the shared ``Less`` helper."""
    less = _lt_impl(g, input, other)
    return less
|
|
|
|
|
def _lt_impl(g: jit_utils.GraphContext, input, other):
    """Emit ``Less(input, other)``.

    If both operands are bool, they are first cast to int32, presumably
    because ONNX ``Less`` does not accept bool inputs.
    """
    if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other):
        int32 = _C_onnx.TensorProtoDataType.INT32
        input, other = (g.op("Cast", operand, to_i=int32) for operand in (input, other))
    return g.op("Less", input, other)
|
|
|
|
|
@_onnx_symbolic("aten::ge")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
def ge(g: jit_utils.GraphContext, input, other):
    """Export ``aten::ge`` as ``Not(input < other)``.

    The negation is added by the decorator.
    """
    less = _lt_impl(g, input, other)
    return less
|
|
|
|
|
@_onnx_symbolic("aten::le")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
def le(g: jit_utils.GraphContext, input, other):
    """Export ``aten::le`` as ``Not(input > other)``.

    The negation is added by the decorator.
    """
    greater = _gt_impl(g, input, other)
    return greater
|
|
|
|
|
@_onnx_symbolic("aten::__and_")
def __and_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__and_`` as ONNX ``And``; both operands must be bool.

    Raises:
        errors.SymbolicValueError: for the first non-boolean operand
            (``input`` is checked before ``other``).
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise AND "
                "for non-boolean input values",
                operand,
            )
    return g.op("And", input, other)
|
|
|
|
|
@_onnx_symbolic("aten::__or_")
def __or_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__or_`` as ONNX ``Or``; both operands must be bool.

    Raises:
        errors.SymbolicValueError: for the first non-boolean operand
            (``input`` is checked before ``other``).
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise OR "
                "for non-boolean input values",
                operand,
            )
    return g.op("Or", input, other)
|
|
|
|
|
@_onnx_symbolic("aten::__xor_")
def __xor_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__xor_`` as ONNX ``Xor``; both operands must be bool.

    Raises:
        errors.SymbolicValueError: for the first non-boolean operand
            (``input`` is checked before ``other``).
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise XOR "
                "for non-boolean input values",
                operand,
            )
    return g.op("Xor", input, other)
|
|
|
|
|
@_onnx_symbolic("aten::logical_and")
@wrap_logical_op_with_cast_to("Bool")
def logical_and(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_and`` as ``And``; the decorator casts operands to bool."""
    conjunction = g.op("And", input, other)
    return conjunction
|
|
|
|
|
@_onnx_symbolic("aten::logical_or")
@wrap_logical_op_with_cast_to("Bool")
def logical_or(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_or`` as ``Or``; the decorator casts operands to bool."""
    disjunction = g.op("Or", input, other)
    return disjunction
|
|
|
|
|
@_onnx_symbolic("aten::logical_xor")
@wrap_logical_op_with_cast_to("Bool")
def logical_xor(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_xor`` as ``Xor``; the decorator casts operands to bool."""
    exclusive = g.op("Xor", input, other)
    return exclusive
|
|
|
|
|
@_onnx_symbolic("aten::logical_not")
def logical_not(g: jit_utils.GraphContext, input):
    """Export ``aten::logical_not``: cast to bool, then apply ``Not``."""
    as_bool = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
    return g.op("Not", as_bool)
|
|
|
|
|
@_onnx_symbolic("aten::__rshift_")
def __rshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__rshift_`` as ``self / 2 ** other``.

    ``other`` is first cast to ``self``'s scalar type (so that, e.g., a long
    ``self`` is not combined with a float shift amount). The power of two is
    computed with ``Pow`` and cast back to ``self``'s type before dividing.

    Returns:
        The shifted value, in ``self``'s scalar type.
    """
    self_scalar_type = _type_utils.JitScalarType.from_value(self)
    if (
        _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED)
        != self_scalar_type
    ):
        other = g.op(
            "Cast",
            other,
            to_i=self_scalar_type.onnx_type(),
        )

    # Pow (opsets 7-11) requires base and exponent to share one floating type.
    if symbolic_helper._is_fp(self):
        # ``other`` already has self's floating type. Build the base in that
        # type too; a fixed float32 base would mismatch a double/half exponent.
        two = g.op(
            "Constant", value_t=torch.tensor(2, dtype=self_scalar_type.dtype())
        )
    else:
        # Integer input: compute the power in float32 and cast back below.
        two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    two_pow = g.op("Pow", two, other)
    two_pow = g.op(
        "Cast",
        two_pow,
        to_i=self_scalar_type.onnx_type(),
    )
    rshift = g.op("Div", self, two_pow)
    return rshift
|
|
|
|
|
@_onnx_symbolic("aten::__lshift_")
def __lshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__lshift_`` as ``self * 2 ** other``.

    ``other`` is first cast to ``self``'s scalar type (so that, e.g., a long
    ``self`` is not combined with a float shift amount). The power of two is
    computed with ``Pow`` and cast back to ``self``'s type before multiplying.

    Returns:
        The shifted value, in ``self``'s scalar type.
    """
    self_scalar_type = _type_utils.JitScalarType.from_value(self)
    if (
        _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED)
        != self_scalar_type
    ):
        other = g.op(
            "Cast",
            other,
            to_i=self_scalar_type.onnx_type(),
        )

    # Pow (opsets 7-11) requires base and exponent to share one floating type.
    if symbolic_helper._is_fp(self):
        # ``other`` already has self's floating type. Build the base in that
        # type too; a fixed float32 base would mismatch a double/half exponent.
        two = g.op(
            "Constant", value_t=torch.tensor(2, dtype=self_scalar_type.dtype())
        )
    else:
        # Integer input: compute the power in float32 and cast back below.
        two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    two_pow = g.op("Pow", two, other)
    two_pow = g.op(
        "Cast",
        two_pow,
        to_i=self_scalar_type.onnx_type(),
    )
    lshift = g.op("Mul", self, two_pow)
    return lshift
|
|
|
|
|
@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
    """Export ``aten::where``.

    A non-bool ``condition`` is cast to bool first. In the three-argument form
    this is ONNX ``Where``. When ``self`` is None (the condition-only form),
    the ``nonzero`` indices of ``condition`` are unbound along axis 1 into
    ``_outputs`` values.
    """
    if not symbolic_helper._is_bool(condition):
        condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)

    if self is not None:
        return g.op("Where", condition, self, other)

    nonzero_indices = nonzero(g, condition)
    unbind_axis = g.op("Constant", value_t=torch.tensor(1))
    return symbolic_helper._unbind_helper(g, nonzero_indices, unbind_axis, _outputs)
|
|
|
|
|
@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::log_softmax`` to ONNX ``LogSoftmax``.

    The input rank must be known at export time. If ``dim`` is not the last
    axis, it is swapped to the last position with ``Transpose``. Then
    ``LogSoftmax`` is applied on that axis, and the result is transposed back.
    This is the same strategy as the rank-known path of ``softmax``. An
    optional ``dtype`` adds a ``Cast`` of the result.

    Returns:
        The log-softmax value, or the result of ``_unimplemented`` when the
        input rank is unknown.
    """
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is None:
        # Report the op actually being exported (this used to say "dim") and
        # attach the offending value for a better error location.
        return symbolic_helper._unimplemented(
            "log_softmax",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
            input,
        )
    if dim < 0:
        dim = input_dim + dim
    is_transpose_required = input_dim != dim + 1

    if is_transpose_required:
        # Single swap of ``dim`` with the last axis; the same permutation
        # undoes it afterwards.
        axes = list(range(input_dim))
        axes[dim], axes[-1] = axes[-1], axes[dim]
        input = g.op("Transpose", input, perm_i=axes)
        dim = input_dim - 1
    return_op = g.op("LogSoftmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return_op = g.op(
            "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    if is_transpose_required:
        return_op = g.op("Transpose", return_op, perm_i=axes)
    return return_op
|
|
|
|
|
@_onnx_symbolic("aten::_log_softmax")
@symbolic_helper.parse_args("v", "i", "i")
def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float):
    """Export ``aten::_log_softmax``.

    If ``half_to_float`` is set and the input is known to be fp16, the input is
    upcast to float32 before delegating to :func:`log_softmax`.
    """
    if half_to_float:
        input_scalar_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.UNDEFINED
        )
        if input_scalar_type == _type_utils.JitScalarType.HALF:
            input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return log_softmax(g, input, dim)
|
|
|
|
|
@_onnx_symbolic("aten::_convolution")
@symbolic_helper.parse_args(
    "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i"
)
def _convolution(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32=None,
):
    """Export ``aten::_convolution`` as ONNX ``Conv`` or ``ConvTranspose``.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Input tensor value.
        weight: Weight tensor value. Its spatial sizes (``shape[2:]``) must be
            statically known because they become ``kernel_shape``.
        bias: Optional bias. A rank-1 bias is passed to the ONNX node directly.
            Any other non-None bias is added afterwards with ``Add``.
        stride: Per-spatial-dimension strides.
        padding: Symmetric per-spatial-dimension padding, duplicated into the
            ONNX ``pads`` layout (all begins, then all ends).
        dilation: Per-spatial-dimension dilations.
        transposed: Emit ``ConvTranspose`` instead of ``Conv``.
        output_padding: Extra output padding; only allowed when ``transposed``.
        groups: Number of convolution groups.
        benchmark, deterministic, cudnn_enabled, allow_tf32: Backend hints
            that are not used by the export.

    Returns:
        The convolution output value.

    Raises:
        errors.SymbolicValueError: If the kernel shape is not statically known.
    """
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    # _get_tensor_sizes yields None for an unknown shape. Check that explicitly
    # rather than swallowing every exception raised around the slice.
    kernel_shape = None if weight_size is None else weight_size[2:]
    if kernel_shape is None or any(i is None for i in kernel_shape):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
            input,
        )

    args = [input, weight]
    # Conv/ConvTranspose accept only a 1-D bias as an input; other biases are
    # applied with a separate Add below.
    if (
        not symbolic_helper._is_none(bias)
        and symbolic_helper._get_tensor_rank(bias) == 1
    ):
        args.append(bias)

    kwargs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        # PyTorch padding is symmetric; ONNX wants begin pads followed by end pads.
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
    }

    if any(o != 0 for o in output_padding):
        # output_padding only has meaning for transposed convolutions.
        assert transposed
        assert len(stride) == len(output_padding)
        kwargs["output_padding_i"] = output_padding

    n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)

    if (
        not symbolic_helper._is_none(bias)
        and symbolic_helper._get_tensor_rank(bias) != 1
    ):
        return g.op("Add", n, bias)
    return n
|
|
|
|
|
@_onnx_symbolic("aten::_convolution_mode")
@symbolic_helper.parse_args(
    "v",
    "v",
    "v",
    "is",
    "s",
    "is",
    "i",
)
def _convolution_mode(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
):
    """Export a convolution whose padding is a string mode as ONNX ``Conv``.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Input tensor value.
        weight: Weight tensor value; ``shape[2:]`` must be statically known.
        bias: Optional bias. A rank-1 bias is passed to ``Conv`` directly; any
            other non-None bias is added afterwards with ``Add``.
        stride: Per-spatial-dimension strides.
        padding: ``"valid"`` or ``"same"``, mapped to ONNX ``auto_pad`` values
            ``"VALID"`` and ``"SAME_UPPER"``. Any other string is passed
            through unchanged.
        dilation: Per-spatial-dimension dilations.
        groups: Number of convolution groups.

    Returns:
        The convolution output value.

    Raises:
        errors.SymbolicValueError: If the kernel shape is not statically known.
    """
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    # _get_tensor_sizes yields None for an unknown shape; test that explicitly
    # instead of catching every exception.
    kernel_shape = None if weight_size is None else weight_size[2:]
    if kernel_shape is None or any(i is None for i in kernel_shape):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
            input,
        )

    args = [input, weight]
    # Only a 1-D bias can be fed to Conv directly.
    if (
        not symbolic_helper._is_none(bias)
        and symbolic_helper._get_tensor_rank(bias) == 1
    ):
        args.append(bias)

    if padding == "valid":
        padding = "VALID"
    elif padding == "same":
        padding = "SAME_UPPER"
    kwargs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        "auto_pad_s": padding,
        "dilations_i": dilation,
        "group_i": groups,
    }

    n = g.op("Conv", *args, **kwargs)

    if (
        not symbolic_helper._is_none(bias)
        and symbolic_helper._get_tensor_rank(bias) != 1
    ):
        return g.op("Add", n, bias)
    return n
|
|
|
|
|
@_onnx_symbolic("aten::convolution")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i")
def convolution(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
):
    """Export ``aten::convolution`` via :func:`_convolution`.

    ``aten::convolution`` carries no backend hints, so the four trailing
    ``_convolution`` arguments (benchmark, deterministic, cudnn_enabled,
    allow_tf32) are filled with ``None``.
    """
    unused_backend_hints = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        *unused_backend_hints,
    )
|
|
|
|
|
@_onnx_symbolic("aten::conv1d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
def conv1d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv1d``.

    ``padding`` arrives unparsed because it may be a string mode
    (``"valid"``/``"same"``), handled by :func:`_convolution_mode`, or an
    explicit integer list, handled by :func:`_convolution` as a plain
    (non-transposed) convolution.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode not in ["valid", "same"]:
        explicit_pads = symbolic_helper._parse_arg(padding, "is")
        # Non-transposed, no output padding, no backend hints.
        return _convolution(
            g,
            input,
            weight,
            bias,
            stride,
            explicit_pads,
            dilation,
            False,
            (),
            groups,
            None,
            None,
            None,
            None,
        )
    return _convolution_mode(
        g, input, weight, bias, stride, padding_mode, dilation, groups
    )
|
|
|
|
|
@_onnx_symbolic("aten::conv2d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
def conv2d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv2d``.

    ``padding`` arrives unparsed because it may be a string mode
    (``"valid"``/``"same"``), handled by :func:`_convolution_mode`, or an
    explicit integer list, handled by :func:`_convolution` as a plain
    (non-transposed) convolution.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode not in ["valid", "same"]:
        explicit_pads = symbolic_helper._parse_arg(padding, "is")
        # Non-transposed, no output padding, no backend hints.
        return _convolution(
            g,
            input,
            weight,
            bias,
            stride,
            explicit_pads,
            dilation,
            False,
            (),
            groups,
            None,
            None,
            None,
            None,
        )
    return _convolution_mode(
        g, input, weight, bias, stride, padding_mode, dilation, groups
    )
|
|
|
|
|
@_onnx_symbolic("aten::conv3d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
def conv3d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv3d``.

    ``padding`` arrives unparsed because it may be a string mode
    (``"valid"``/``"same"``), handled by :func:`_convolution_mode`, or an
    explicit integer list, handled by :func:`_convolution` as a plain
    (non-transposed) convolution.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode not in ["valid", "same"]:
        explicit_pads = symbolic_helper._parse_arg(padding, "is")
        # Non-transposed, no output padding, no backend hints.
        return _convolution(
            g,
            input,
            weight,
            bias,
            stride,
            explicit_pads,
            dilation,
            False,
            (),
            groups,
            None,
            None,
            None,
            None,
        )
    return _convolution_mode(
        g, input, weight, bias, stride, padding_mode, dilation, groups
    )
|
|
|
|
|
@_onnx_symbolic("aten::conv_transpose1d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose1d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose1d`` as a transposed :func:`_convolution`.

    ATen orders the trailing arguments as ``output_padding, groups,
    dilation``, which differs from ``_convolution``; they are re-ordered
    positionally here.
    """
    leading = (input, weight, bias, stride, padding, dilation)
    unused_backend_hints = (None, None, None, None)
    return _convolution(
        g, *leading, True, output_padding, groups, *unused_backend_hints
    )
|
|
|
|
|
@_onnx_symbolic("aten::conv_transpose2d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose2d`` as a transposed :func:`_convolution`.

    ATen orders the trailing arguments as ``output_padding, groups,
    dilation``, which differs from ``_convolution``; they are re-ordered
    positionally here.
    """
    leading = (input, weight, bias, stride, padding, dilation)
    unused_backend_hints = (None, None, None, None)
    return _convolution(
        g, *leading, True, output_padding, groups, *unused_backend_hints
    )
|
|
|
|
|
@_onnx_symbolic("aten::conv_transpose3d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose3d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose3d`` as a transposed :func:`_convolution`.

    ATen orders the trailing arguments as ``output_padding, groups,
    dilation``, which differs from ``_convolution``; they are re-ordered
    positionally here.
    """
    leading = (input, weight, bias, stride, padding, dilation)
    unused_backend_hints = (None, None, None, None)
    return _convolution(
        g, *leading, True, output_padding, groups, *unused_backend_hints
    )
|
|
|
|
|
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    """Export ``aten::batch_norm`` as ONNX ``BatchNormalization``.

    Under autocast with mixed input dtypes, opsets below 15 are rejected.
    ``momentum`` is passed to ONNX as ``1 - momentum`` because ONNX weights the
    running-statistic update in the opposite direction. In training mode the
    node has five outputs: the updated running stats are typed like the
    originals, and the saved mean/var are renamed as dead outputs. Only the
    normalized tensor is returned.
    """
    symbolic_helper.check_training_mode(training, "batch_norm")

    mixed_dtypes_under_autocast = (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
    )
    if mixed_dtypes_under_autocast and GLOBALS.export_onnx_opset_version < 15:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            9,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )

    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
    output_count = 5 if training else 1
    bn_node_out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        outputs=output_count,
    )
    if not training:
        return bn_node_out

    normalized, new_mean, new_var, saved_mean, saved_var = bn_node_out
    new_mean.setType(running_mean.type())
    new_var.setType(running_var.type())
    for dead_output in (saved_mean, saved_var):
        dead_output.setDebugName(
            "batch_norm_dead_output-" + dead_output.debugName()
        )
    return normalized
|
|
|
|
|
@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
def native_layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
) -> tuple[_C.Value, _C.Value, _C.Value]:
    """Export ``aten::native_layer_norm`` using primitive ONNX ops.

    Normalizes over the last ``len(normalized_shape)`` axes:
    ``(x - mean) / sqrt(mean((x - mean) ** 2) + eps)``, then optionally
    scales by ``weight`` and shifts by ``bias``.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Tensor to normalize.
        normalized_shape: Trailing shape to normalize over; only its length
            is used here.
        weight: Optional elementwise scale (None or a None value to skip).
        bias: Optional elementwise shift (None or a None value to skip).
        eps: Value added to the variance for numerical stability.

    Returns:
        ``(normalized, mean, rdenominator)``, where ``rdenominator`` is
        ``1 / sqrt(variance + eps)``.
    """
    # Negative axes [-n, ..., -1] address the trailing normalized dimensions.
    axes = [-i for i in range(len(normalized_shape), 0, -1)]

    two_cst = symbolic_helper._generate_wrapped_number(g, 2.0)
    eps_cst = symbolic_helper._generate_wrapped_number(g, eps)

    # Opset 18 moved ReduceMean's axes from an attribute to an input.
    if g.opset < 18:
        mean = g.op("ReduceMean", input, axes_i=axes)
    else:
        mean = g.op(
            "ReduceMean",
            input,
            g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
        )

    numerator = sub(g, input, mean)

    # For fp16 inputs the variance is computed in the eps constant's dtype
    # (presumably a wider float type — confirm against
    # _generate_wrapped_number).
    is_type_half = (
        _type_utils.JitScalarType.from_value(numerator)
        == _type_utils.JitScalarType.HALF
    )
    if is_type_half:
        eps_dtype = _type_utils.JitScalarType.from_value(eps_cst)
        numerator = g.op(
            "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type()
        )

    if g.opset < 18:
        variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
    else:
        variance = g.op(
            "ReduceMean",
            pow(g, numerator, two_cst),
            g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
        )

    denominator = sqrt(g, g.op("Add", variance, eps_cst))
    normalized = g.op("Div", numerator, denominator)

    # Cast the fp16 path back to the input dtype. `input_dtype` is reused by
    # the second `is_type_half` branch below.
    if is_type_half:
        input_dtype = _type_utils.JitScalarType.from_value(input)
        normalized = g.op(
            "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
        )

    if not (weight is None or symbolic_helper._is_none(weight)):
        normalized = mul(g, normalized, weight)
    if not (bias is None or symbolic_helper._is_none(bias)):
        normalized = add(g, normalized, bias)

    # rdenominator is returned in the input dtype on the fp16 path; otherwise
    # the module-level `reciprocal` symbolic is used.
    if is_type_half:
        denominator = g.op(
            "Cast",
            denominator,
            to_i=_type_utils.JitScalarType(input_dtype).onnx_type(),
        )
        rdenominator = g.op("Reciprocal", denominator)
    else:
        rdenominator = reciprocal(g, denominator)

    return normalized, mean, rdenominator
|
|
|
|
|
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
) -> _C.Value:
    """Export ``aten::layer_norm``.

    Delegates to :func:`native_layer_norm` and keeps only the normalized
    output, discarding the mean and inverse-std values. ``cudnn_enable`` is
    ignored.
    """
    result, _mean, _rstd = native_layer_norm(
        g, input, normalized_shape, weight, bias, eps
    )
    return result
|
|
|
|
|
@_onnx_symbolic("aten::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b")
def instance_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    use_input_stats: bool,
    momentum: Number,
    eps: Number,
    cudnn_enabled: bool,
):
    """Export ``aten::instance_norm``.

    A missing ``weight``/``bias`` is replaced by a constant of ones/zeros whose
    length is the size of dim 1. Without running statistics the op maps
    directly to ONNX ``InstanceNormalization``. With running statistics the
    input is reshaped from ``[N, C, *]`` to ``[1, N * C, *]``; weight, bias and
    running stats are repeated ``N`` times, :func:`batch_norm` is applied, and
    the result is viewed back to the input shape.

    Raises:
        errors.SymbolicValueError: If a required size is not known at export
            time: the channel size when weight/bias must be synthesized, or
            any input dimension when running statistics are used.
    """
    symbolic_helper.check_training_mode(use_input_stats, "instance_norm")
    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
    if weight is None or symbolic_helper._is_none(weight):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of instance_norm for unknown channel size.",
                input,
            )
        weight_value = torch.tensor(
            [1.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        weight = g.op("Constant", value_t=weight_value)
    if bias is None or symbolic_helper._is_none(bias):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of instance_norm for unknown channel size.",
                input,
            )
        bias_value = torch.tensor(
            [0.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        bias = g.op("Constant", value_t=bias_value)
    if (
        running_mean is None
        or symbolic_helper._is_none(running_mean)
        or running_var is None
        or symbolic_helper._is_none(running_var)
    ):
        return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)

    input_size = symbolic_helper._get_tensor_sizes(input)
    # The Reshape/view below bake the input shape into constants, so every
    # dimension must be static. Without these checks an unknown shape or a
    # dynamic non-batch dim fails with AttributeError/TypeError.
    if input_size is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "input shape.",
            input,
        )
    n = input_size[0]
    if n is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "batch size.",
            input,
        )
    if any(dim_size is None for dim_size in input_size[1:]):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for input with "
            "dynamic non-batch dimensions.",
            input,
        )
    c = input_size[1]
    # [N, C, *] -> [1, N * C, *] so batch_norm normalizes each instance-channel.
    input_size_reshape = input_size.copy()
    input_size_reshape[0] = 1
    input_size_reshape[1] = n * c
    # A single tiling factor shared by all four per-channel tensors.
    repeats = g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
    weight_ = repeat(g, weight, repeats)
    bias_ = repeat(g, bias, repeats)
    running_mean_ = repeat(g, running_mean, repeats)
    running_var_ = repeat(g, running_var, repeats)
    input_reshaped = g.op(
        "Reshape",
        input,
        g.op("Constant", value_t=torch.LongTensor(input_size_reshape)),
    )
    out = batch_norm(
        g,
        input_reshaped,
        weight_,
        bias_,
        running_mean_,
        running_var_,
        use_input_stats,
        momentum,
        eps,
        cudnn_enabled,
    )
    return view(g, out, g.op("Constant", value_t=torch.tensor(input_size)))
|
|
|
|
|
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "i", "i")
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    """Export ``aten::unfold`` by materializing every window with static slices.

    Each window ``[start, start + size)`` along ``dimension`` is sliced out.
    Its ``dimension`` axis is moved last, and the windows are stacked
    (Unsqueeze + Concat) at ``dimension``. The output therefore has
    ``dimension`` replaced by the window count and a trailing axis of length
    ``size``.

    The input rank and the static size of ``dimension`` must be known;
    otherwise the result of ``symbolic_helper._unimplemented`` is returned.
    """
    sizes = symbolic_helper._get_tensor_sizes(input)
    sizedim = None
    if sizes is not None and -len(sizes) <= dimension < len(sizes):
        # Normalize negative dims. The Transpose permutation and the
        # Unsqueeze/Concat axes below assume a non-negative index; a negative
        # one would put the window axis in the wrong place.
        dimension %= len(sizes)
        sizedim = sizes[dimension]
    if sizedim is None:
        return symbolic_helper._unimplemented(
            "Unfold", "input size not accessible", input
        )

    ndim = len(sizes)
    low_indices = range(0, sizedim, step)
    hi_indices = range(size, sizedim + 1, step)
    stack = [
        symbolic_helper._slice_helper(
            g, input, axes=[dimension], starts=[low], ends=[hi]
        )
        for low, hi in zip(low_indices, hi_indices)
    ]
    # Move the sliced axis to the end of each window.
    perm = list(range(ndim))
    perm.append(perm.pop(dimension))
    unsqueeze = [
        symbolic_helper._unsqueeze_helper(
            g, g.op("Transpose", t, perm_i=perm), [dimension]
        )
        for t in stack
    ]
    return g.op("Concat", *unsqueeze, axis_i=dimension)
|
|
|
|
|
@_onnx_symbolic("aten::elu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "t", "t", "t")
def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale):
    """Export ``aten::elu`` as ONNX ``Elu``.

    ONNX ``Elu`` has only ``alpha``, so a non-unit ``scale`` or
    ``input_scale`` is reported as unimplemented.
    """
    unsupported_factors = (
        ("scale", "does not support scale in Elu", scale),
        ("input_scale", "does not support input_scale in Elu", input_scale),
    )
    for factor_name, reason, factor in unsupported_factors:
        if factor and factor != 1.0:
            return symbolic_helper._unimplemented(factor_name, reason, factor)
    return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha))
|
|
|
|
|
@_onnx_symbolic("aten::selu")
@symbolic_helper.quantized_args(True)
def selu(g: jit_utils.GraphContext, input):
    """Export ``aten::selu`` as ONNX ``Selu``; no attributes are set, so ONNX defaults apply."""
    selu_out = g.op("Selu", input)
    return selu_out
|
|
|
|
|
@_onnx_symbolic("aten::index_select")
@symbolic_helper.parse_args("v", "i", "v")
def index_select(g: jit_utils.GraphContext, self, dim, index):
    """Export ``aten::index_select`` by delegating to ``symbolic_helper._select_helper``."""
    selected = symbolic_helper._select_helper(g, self, dim, index)
    return selected
|
|
|
|
|
@_onnx_symbolic("aten::index_put")
def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate):
    """Export ``aten::index_put`` for the only case opset 9 supports.

    With an empty index list the result is ``values`` (or ``self + values``
    when ``accumulate`` is set). Non-empty index lists are reported as
    requiring opset 11.
    """
    indices_list = (
        symbolic_helper._unpack_list(indices_list_value)
        if symbolic_helper._is_packed_list(indices_list_value)
        else [indices_list_value]
    )
    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if indices_list:
        symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self)
        return None
    if accumulate:
        return add(g, self, values)
    return values
|
|
|
|
|
@_onnx_symbolic("aten::index_fill")
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
    """Export ``aten::index_fill`` as a scatter of a broadcast fill value.

    The index is reshaped/expanded by ``_index_fill_reshape_helper``; ``value``
    is converted to ``self``'s scalar type and expanded to the index shape
    before being scattered into ``self`` along ``dim``.
    """
    index_shape, broadcast_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    fill_value = symbolic_helper._if_scalar_type_as(
        symbolic_helper._maybe_get_scalar(value), self
    )
    broadcast_value = expand(g, fill_value, index_shape, None)
    return scatter(g, self, dim, broadcast_index, broadcast_value)
|
|
|
|
|
@_onnx_symbolic("aten::index_copy")
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
    """Export ``aten::index_copy`` by scattering ``source`` at the expanded index."""
    _, broadcast_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    return scatter(g, self, dim, broadcast_index, source)
|
|
|
|
|
@_onnx_symbolic("aten::bucketize")
@symbolic_helper.parse_args("v", "v", "b", "b")
def bucketize(
    g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False
):
    """Export ``aten::bucketize``.

    ``boundaries`` (assumed 1-D, as ``torch.bucketize`` requires) is
    unsqueezed and expanded to ``[len(boundaries), *self.shape]``. Each
    element's bucket index is the count of boundaries it is greater than
    (``right=False``) or greater than or equal to (``right=True``), computed
    with ReduceSum over axis 0.

    Returns:
        An INT64 tensor shaped like ``self`` (INT32 if ``out_int32``).

    Raises:
        errors.SymbolicValueError: If the rank of ``self`` is unknown.
    """
    tensor_rank = symbolic_helper._get_tensor_rank(self)
    if tensor_rank is None:
        # Previously an `assert`, which is stripped under `python -O` and then
        # surfaces as an opaque TypeError from range().
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of bucketize for input of unknown rank.",
            self,
        )

    out_type = _C_onnx.TensorProtoDataType.INT64
    if out_int32:
        out_type = _C_onnx.TensorProtoDataType.INT32

    # Target shape [len(boundaries), *self.shape].
    new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0)
    unsqueeze_axes = list(range(1, tensor_rank + 1))
    expanded_boundaries = expand(
        g,
        symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes),
        new_shape,
        None,
    )
    if right:
        cond = ge(g, self, expanded_boundaries)
    else:
        cond = gt(g, self, expanded_boundaries)
    cond_out = g.op("Cast", cond, to_i=out_type)
    # Counting satisfied comparisons yields the bucket index.
    return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0)
|
|
|
|
|
@_onnx_symbolic("aten::type_as")
def type_as(g: jit_utils.GraphContext, self, other):
    """Export ``aten::type_as``.

    Returns ``self`` unchanged when both dtypes are known and equal. Otherwise
    casts ``self`` to ``other``'s dtype.

    Raises:
        errors.SymbolicValueError: If ``other``'s dtype is unknown.
    """
    self_dtype = symbolic_helper._try_get_scalar_type(self)
    other_dtype = symbolic_helper._try_get_scalar_type(other)
    if other_dtype is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of type_as for tensor "
            "of unknown dtype. Please check if the dtype of the "
            "parameter passed to the type_as function is correct.",
            other,
        )
    if self_dtype == other_dtype:
        return self
    return g.op("Cast", self, to_i=other_dtype.onnx_type())
|
|
|
|
|
@_onnx_symbolic("aten::cosine_similarity")
@symbolic_helper.parse_args("v", "v", "i", "f")
def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps):
    """Export ``aten::cosine_similarity``.

    Computes ``sum(x1*x2) / max(sqrt(sum(x1^2) * sum(x2^2)), eps)`` along
    ``dim``, with the reduced dimension dropped.
    """

    def _reduced_product(a, b):
        # Sum of the elementwise product over `dim`, without keeping the dim.
        return symbolic_helper._reducesum_helper(
            g, mul(g, a, b), axes_i=[dim], keepdims_i=0
        )

    dot_product = _reduced_product(x1, x2)
    x1_sq_norm = _reduced_product(x1, x1)
    x2_sq_norm = _reduced_product(x2, x2)
    norm_product = sqrt(g, mul(g, x1_sq_norm, x2_sq_norm))
    eps_floor = g.op("Constant", value_t=torch.tensor([eps]))
    return div(g, dot_product, max(g, norm_product, eps_floor))
|
|
|
|
|
@_onnx_symbolic("aten::pairwise_distance")
def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim):
    """Export ``aten::pairwise_distance``: ``||input1 - input2 + eps||_p`` over the last dim.

    Computed as ``sum(|input1 - input2 + eps| ** p) ** (1 / p)``, reducing
    the last axis and keeping it iff ``keepdim``.

    The previous lowering omitted ``abs``, which gives wrong results for odd
    ``p`` (e.g. ``p=1`` summed signed differences). It also applied ``eps`` to
    the exponent (``1 / (p + eps)``) instead of inside the norm as ATen does.
    """
    if not symbolic_helper._is_value(eps):
        eps = g.op("Constant", value_t=torch.tensor([eps]))
    # Match ATen: norm(x1 - x2 + eps, p, dim=-1, keepdim).
    diff = abs(g, add(g, sub(g, input1, input2), eps))
    inv_p = div(
        g,
        g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)),
        p,
    )
    summation = symbolic_helper._reducesum_helper(
        g,
        pow(g, diff, p),
        axes_i=[-1],
        keepdims_i=symbolic_helper._parse_arg(keepdim, "i"),
    )
    return pow(g, summation, inv_p)
|
|
|
|
|
@_onnx_symbolic("aten::clone")
def clone(g: jit_utils.GraphContext, input, unused_memory_format):
    """Export ``aten::clone``: no node is emitted; ``input`` is returned as-is.

    ``unused_memory_format`` is ignored.
    """
    cloned = input
    return cloned
|
|
|
|
|
@_onnx_symbolic("aten::abs")
def abs(g: jit_utils.GraphContext, self):
    """Export ``aten::abs`` as ONNX ``Abs``."""
    magnitude = g.op("Abs", self)
    return magnitude
|
|
|
|
|
@_onnx_symbolic("aten::log")
def log(g: jit_utils.GraphContext, self):
    """Export ``aten::log`` as ONNX ``Log`` (natural logarithm)."""
    natural_log = g.op("Log", self)
    return natural_log
|
|
|
|
|
@_onnx_symbolic("aten::log1p")
def log1p(g: jit_utils.GraphContext, self):
    """Export ``aten::log1p`` as ``log(1 + self)``, with 1 in ``self``'s scalar type."""
    one = symbolic_helper._if_scalar_type_as(torch.ones(1), self)
    shifted = add(g, one, self)
    return log(g, shifted)
|
|
|
|
|
@_onnx_symbolic("aten::log10")
def log10(g: jit_utils.GraphContext, self):
    """Export ``aten::log10`` via the change of base ``ln(x) / ln(10)``."""
    natural_log = log(g, self)
    ln_ten = g.op("Constant", value_t=torch.tensor([2.30258509299404568401]))
    return g.op("Div", natural_log, ln_ten)
|
|
|
|
|
@_onnx_symbolic("aten::pow")
def pow(g: jit_utils.GraphContext, self, exponent):
    """Export ``aten::pow`` as ONNX ``Pow``.

    A non-floating ``self`` is cast to float32. A non-floating ``exponent`` is
    cast to the (possibly updated) floating type of ``self``.
    """
    float_type = _type_utils.JitScalarType.from_value(self)
    if not symbolic_helper._is_fp(self):
        float_type = _type_utils.JitScalarType.FLOAT
        self = g.op("Cast", self, to_i=float_type.onnx_type())
    if not symbolic_helper._is_fp(exponent):
        exponent = g.op("Cast", exponent, to_i=float_type.onnx_type())
    return g.op("Pow", self, exponent)
|
|
|
|
|
@_onnx_symbolic("aten::clamp")
def clamp(g: jit_utils.GraphContext, self, min, max):
    """Export ``aten::clamp``.

    A missing bound delegates to the one-sided :func:`clamp_max` /
    :func:`clamp_min`. Two constant bounds become a single ``Clip`` with
    attributes; otherwise ``clamp_min`` is chained into ``clamp_max``.
    """
    if symbolic_helper._is_none(min):
        return clamp_max(g, self, max)
    if symbolic_helper._is_none(max):
        return clamp_min(g, self, min)
    if not (symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max)):
        return clamp_max(g, clamp_min(g, self, min), max)
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Clip",
        self,
        min_f=symbolic_helper._parse_arg(min, "f"),
        max_f=symbolic_helper._parse_arg(max, "f"),
        opset_before=12,
    )
|
|
|
|
|
@_onnx_symbolic("aten::clamp_min")
@symbolic_helper.parse_args("v", "v")
def clamp_min(g: jit_utils.GraphContext, self, min):
    """Export ``aten::clamp_min``.

    A constant bound becomes a ``Clip`` ``min`` attribute. A dynamic bound is
    cast to ``self``'s dtype and applied with ``Max``.
    """
    if not symbolic_helper._is_constant(min):
        self_type = _type_utils.JitScalarType.from_value(self)
        lower = g.op("Cast", min, to_i=self_type.onnx_type())
        return symbolic_helper._op_with_optional_float_cast(
            g, "Max", self, lower, opset_before=12
        )
    return symbolic_helper._op_with_optional_float_cast(
        g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12
    )
|
|
|
|
|
@_onnx_symbolic("aten::clamp_max")
@symbolic_helper.parse_args("v", "v")
def clamp_max(g: jit_utils.GraphContext, self, max):
    """Export ``aten::clamp_max``.

    A constant bound becomes a ``Clip`` ``max`` attribute. A dynamic bound is
    cast to ``self``'s dtype and applied with ``Min``.
    """
    if not symbolic_helper._is_constant(max):
        self_type = _type_utils.JitScalarType.from_value(self)
        upper = g.op("Cast", max, to_i=self_type.onnx_type())
        return symbolic_helper._op_with_optional_float_cast(
            g, "Min", self, upper, opset_before=12
        )
    return symbolic_helper._op_with_optional_float_cast(
        g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12
    )
|
|
|
|
|
@_onnx_symbolic("aten::max")
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Export all ``aten::max`` overloads via ``symbolic_helper._max_helper``.

    ``dim_or_y`` is either a reduction dim or a second tensor, depending on the
    overload being exported.
    """
    result = symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
def maximum(g: jit_utils.GraphContext, input, other):
    """Export ``aten::maximum`` as the binary (tensor, tensor) form of :func:`max`."""
    elementwise_max = max(g, input, dim_or_y=other)
    return elementwise_max
|
|
|
|
|
@_onnx_symbolic("aten::min")
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Export all ``aten::min`` overloads via ``symbolic_helper._min_helper``.

    ``dim_or_y`` is either a reduction dim or a second tensor, depending on the
    overload being exported.
    """
    result = symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
def minimum(g: jit_utils.GraphContext, input, other):
    """Export ``aten::minimum`` as the binary (tensor, tensor) form of :func:`min`."""
    elementwise_min = min(g, input, dim_or_y=other)
    return elementwise_min
|
|
|
|
|
@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::amax`` as ``ReduceMax`` with ``dim`` passed as the axes attribute."""
    reduce_attrs = {"axes_i": dim, "keepdims_i": keepdim}
    return g.op("ReduceMax", self, **reduce_attrs)
|
|
|
|
|
@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::amin`` as ``ReduceMin`` with ``dim`` passed as the axes attribute."""
    reduce_attrs = {"axes_i": dim, "keepdims_i": keepdim}
    return g.op("ReduceMin", self, **reduce_attrs)
|
|
|
|
|
@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::aminmax`` as a ``(ReduceMin, ReduceMax)`` pair.

    When ``dim`` is None, both reductions run over all axes.
    """
    reduce_kwargs = {"keepdims_i": keepdim}
    if not symbolic_helper._is_none(dim):
        reduce_kwargs["axes_i"] = [symbolic_helper._get_const(dim, "i", "dim")]

    lowest = g.op("ReduceMin", self, **reduce_kwargs)
    highest = g.op("ReduceMax", self, **reduce_kwargs)
    return lowest, highest
|
|
|
|
|
@_onnx_symbolic("aten::exp")
def exp(g: jit_utils.GraphContext, self):
    """Export ``aten::exp`` as ONNX ``Exp``."""
    exponential = g.op("Exp", self)
    return exponential
|
|
|
|
|
@_onnx_symbolic("aten::dropout_")
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "i")
def dropout(g: jit_utils.GraphContext, input, p, train):
    """Export ``aten::dropout``/``aten::dropout_``.

    In eval mode this is the identity. In training mode an ONNX ``Dropout``
    with ``ratio=p`` is emitted and only its first output (the data) is kept.
    """
    symbolic_helper.check_training_mode(train, "dropout")
    if train:
        dropped, _mask = g.op("Dropout", input, ratio_f=p, outputs=2)
        return dropped
    return input
|
|
|
|
|
@_onnx_symbolic(
    "aten::alpha_dropout_",
    decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")],
)
@_onnx_symbolic(
    "aten::feature_alpha_dropout_",
    decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")],
)
@_onnx_symbolic(
    "aten::feature_dropout_",
    decorate=[symbolic_helper._apply_params("aten::feature_dropout_")],
)
@_onnx_symbolic(
    "aten::feature_alpha_dropout",
    decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")],
)
@_onnx_symbolic(
    "aten::alpha_dropout",
    decorate=[symbolic_helper._apply_params("aten::alpha_dropout")],
)
@_onnx_symbolic(
    "aten::feature_dropout",
    decorate=[symbolic_helper._apply_params("aten::feature_dropout")],
)
def _unsupported_dropout(name: str):
    """Build a symbolic for the dropout variant ``name`` that only supports eval mode.

    The returned function is the identity when ``train`` is false and reports
    ``name`` as unimplemented in training mode.
    """

    @symbolic_helper.parse_args("v", "none", "b")
    def feature_dropout(g, input, p, train):
        # Eval-mode dropout is a no-op; training-mode variants are not exportable.
        if not train:
            return input
        return symbolic_helper._unimplemented(name, "training mode", input)

    return feature_dropout
|
|
|
|
|
@_onnx_symbolic("aten::norm")
@symbolic_helper.parse_args("v", "t", "is", "i", "v")
def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
    """Export ``aten::norm`` for ``p`` of 1 or 2 as ``ReduceL1``/``ReduceL2``.

    If ``dtype`` is given, the reduced result is cast to that dtype.

    Raises:
        errors.SymbolicValueError: For any other ``p``.
    """
    if p == 1:
        reduce_op_name = "ReduceL1"
    elif p == 2:
        reduce_op_name = "ReduceL2"
    else:
        raise errors.SymbolicValueError(
            "ONNX export only p-norms with p of 1 or 2", self
        )
    reducer = symbolic_helper._reduce_op_symbolic_helper(reduce_op_name)
    reduced = reducer(g, self, dim=dim, keepdim=keepdim)
    if dtype is None:
        return reduced
    target_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    return g.op(
        "Cast", reduced, to_i=_type_utils.JitScalarType(target_dtype).onnx_type()
    )
|
|
|
|
|
@_onnx_symbolic("aten::conv_tbc")
@symbolic_helper.parse_args("v", "v", "v", "i")
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
    """Export ``aten::conv_tbc`` via :func:`conv1d` with transposes around it.

    Assumes (per the op name — confirm) an input laid out as (time, batch,
    channel) and a weight as (kernel, in, out). Both are permuted into the
    layouts :func:`conv1d` consumes, a stride-1 / dilation-1 / single-group
    convolution with padding ``pad`` is applied, and the result is permuted
    back.
    """
    conv_input = g.op("Transpose", input, perm_i=[1, 2, 0])
    conv_weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
    conv_out = conv1d(g, conv_input, conv_weight, bias, [1], [pad], [1], 1)
    return g.op("Transpose", conv_out, perm_i=[2, 0, 1])
|
|
|
|
|
@_onnx_symbolic("aten::_unique")
@symbolic_helper.parse_args("v", "i", "i")
def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
    """``aten::_unique`` has no ONNX lowering; report it as unsupported."""
    unsupported = symbolic_helper._onnx_unsupported("_unique", input)
    return unsupported
|
|
|
|
|
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
    """``aten::_unique2`` is reported as needing opset 11 (no opset-9 lowering)."""
    symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
    return None
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Byte")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX UINT8; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.UINT8
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Char")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX INT8; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.INT8
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Short")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX INT16; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.INT16
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Int")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX INT32; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.INT32
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Long")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX INT64; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.INT64
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Half")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX FLOAT16; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.FLOAT16
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Float")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX FLOAT; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.FLOAT
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Double")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX DOUBLE; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.DOUBLE
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::_cast_Bool")
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX BOOL; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.BOOL
    return g.op("Cast", input, to_i=target_type)
|
|
|
|
|
@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty`` as :func:`zeros` of the same size and dtype.

    ``memory_format`` is not forwarded.
    """
    zero_filled = zeros(g, sizes, dtype, layout, device, pin_memory)
    return zero_filled
|
|
|
|
|
@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty_like`` as :func:`zeros_like`; ``memory_format`` is not forwarded."""
    zero_filled = zeros_like(g, input, dtype, layout, device, pin_memory)
    return zero_filled
|
|
|
|
|
@_onnx_symbolic("aten::new_empty")
def new_empty(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_empty`` via :func:`empty`.

    If ``dtype`` is None and ``self``'s scalar type is known, that scalar type
    is used instead.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    inherit = symbolic_helper._is_none(dtype) and inherited_dtype is not None
    effective_dtype = inherited_dtype if inherit else dtype
    return empty(g, sizes, effective_dtype, layout, device, pin_memory)
|
|
|
|
|
@_onnx_symbolic("aten::scalar_tensor")
def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options):
    """Export ``aten::scalar_tensor`` as a Cast of ``scalar`` to ``dtype`` (float32 if None)."""
    requested = symbolic_helper._get_const(dtype, "i", "dtype")
    target = _type_utils.JitScalarType.FLOAT if requested is None else requested
    return g.op("Cast", scalar, to_i=_type_utils.JitScalarType(target).onnx_type())
|
|
|
|
|
@_onnx_symbolic("aten::tensor")
def tensor(
    g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False
):
    """Export ``aten::tensor``.

    * Packed list (``prim::ListConstruct``): each element is reshaped to
      shape ``[1]``, cast to the target type, and the pieces are concatenated
      on axis 0. Without an explicit ``dtype``, the first element's type is
      used.
    * Otherwise: a tensor/scalar list is first packed with
      ``ConcatFromSequence`` (new leading axis). The value is then cast to
      ``dtype``, or to its own type when ``dtype`` is None.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")

    if not symbolic_helper._is_packed_list(data):
        if dtype is None:
            dtype = _type_utils.JitScalarType.from_value(data)
        is_sequence_input = symbolic_helper._is_list(data) and (
            symbolic_helper._is_tensor_list(data)
            or symbolic_helper._is_scalar_list(data)
        )
        if is_sequence_input:
            data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1)
        return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type())

    elements = symbolic_helper._unpack_list(data)
    if dtype is None:
        dtype = _type_utils.JitScalarType.from_value(elements[0])

    pieces = []
    for element in elements:
        # Turn every element into a 1-element vector so they can be concatenated.
        one_elem_shape = g.op("Constant", value_t=torch.LongTensor([1]))
        as_vector = symbolic_helper._reshape_helper(g, element, one_elem_shape)
        pieces.append(
            g.op(
                "Cast",
                as_vector,
                to_i=_type_utils.JitScalarType(dtype).onnx_type(),
            )
        )
    return g.op("Concat", *pieces, axis_i=0)
|
|
|
|
|
@_onnx_symbolic("aten::as_tensor")
def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None):
    """Export ``aten::as_tensor`` exactly like ``aten::tensor`` (no copy semantics in ONNX)."""
    result = tensor(g, data, dtype, device)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::zeros`` as ``ConstantOfShape`` filled with 0.

    ``dtype`` (already parsed to an int or None) defaults to FLOAT. An empty
    constant ``sizes`` list is replaced by an explicit empty int64 shape, so
    the output is a scalar.
    """
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    static_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(static_sizes, list) and not static_sizes:
        sizes = g.op("Constant", value_t=torch.tensor([], dtype=torch.int64))
    fill = torch.tensor([0], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", sizes, value_t=fill)
|
|
|
|
|
@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::zeros_like`` as ``ConstantOfShape(Shape(input))`` filled with 0.

    Without an explicit ``dtype``, the input's own scalar type is used, with
    FLOAT as the fallback when it is unknown.
    """
    input_shape = g.op("Shape", input)
    scalar_type = (
        _type_utils.JitScalarType.from_value(input, _type_utils.JitScalarType.FLOAT)
        if symbolic_helper._is_none(dtype)
        else _type_utils.JitScalarType(dtype)
    )
    fill = torch.tensor([0], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", input_shape, value_t=fill)
|
|
|
|
|
@_onnx_symbolic("aten::new_zeros")
def new_zeros(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_zeros`` via ``zeros``.

    When ``dtype`` is a None constant, ``self``'s scalar type is used, if known.
    """
    if symbolic_helper._is_none(dtype):
        inherited_dtype = symbolic_helper._try_get_scalar_type(self)
        if inherited_dtype is not None:
            dtype = inherited_dtype
    return zeros(g, sizes, dtype, layout, device, pin_memory)
|
|
|
|
|
@_onnx_symbolic("aten::zero")
def zero(g: jit_utils.GraphContext, self):
    """Export ``aten::zero`` as ``zeros_like(self)``.

    ``self``'s scalar type, which may be None when unknown, is passed as the
    dtype.
    """
    return zeros_like(g, self, symbolic_helper._try_get_scalar_type(self))
|
|
|
|
|
@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::ones`` as ``ConstantOfShape`` filled with 1.

    ``dtype`` (already parsed to an int or None) defaults to FLOAT. An empty
    constant ``sizes`` list becomes an explicit empty int64 shape, so the
    output is a scalar.
    """
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    static_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(static_sizes, list) and not static_sizes:
        sizes = g.op("Constant", value_t=torch.tensor([], dtype=torch.int64))
    fill = torch.tensor([1], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", sizes, value_t=fill)
|
|
|
|
|
@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::ones_like`` as ``ConstantOfShape(Shape(input))`` filled with 1.

    Without an explicit ``dtype``, the input's scalar type is used, with FLOAT
    as the fallback when it is unknown.
    """
    input_shape = g.op("Shape", input)
    scalar_type = (
        _type_utils.JitScalarType.from_value(input, _type_utils.JitScalarType.FLOAT)
        if symbolic_helper._is_none(dtype)
        else _type_utils.JitScalarType(dtype)
    )
    fill = torch.tensor([1], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", input_shape, value_t=fill)
|
|
|
|
|
@_onnx_symbolic("aten::new_ones")
def new_ones(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_ones`` via ``ones``.

    When ``dtype`` is a None constant, ``self``'s scalar type is used, if known.
    """
    if symbolic_helper._is_none(dtype):
        inherited_dtype = symbolic_helper._try_get_scalar_type(self)
        if inherited_dtype is not None:
            dtype = inherited_dtype
    return ones(g, sizes, dtype, layout, device, pin_memory)
|
|
|
|
|
@_onnx_symbolic("aten::full")
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    """Export ``aten::full``.

    * Constant ``value``: a single ``ConstantOfShape`` whose fill is ``value``
      converted to ``dtype`` (FLOAT when ``dtype`` is None).
    * Dynamic ``value``: ``zeros(sizes) + value``, computed with ``add`` and
      alpha 1.
    """
    const_value = symbolic_helper._maybe_get_const(value, "t")

    if not symbolic_helper._is_value(const_value):
        # Fill value is known at export time.
        const_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        scalar_type = (
            _type_utils.JitScalarType.FLOAT
            if const_dtype is None
            else _type_utils.JitScalarType(const_dtype)
        )
        static_sizes = symbolic_helper._maybe_get_const(sizes, "is")
        if isinstance(static_sizes, list) and not static_sizes:
            sizes = g.op("Constant", value_t=torch.tensor([], dtype=torch.int64))
        fill = const_value.view(1).to(scalar_type.dtype())
        return g.op("ConstantOfShape", sizes, value_t=fill)

    # Fill value only known at runtime.
    if dtype is None:
        dtype = _type_utils.JitScalarType.FLOAT
    zero_filled = zeros(g, sizes, dtype, layout, device)
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return add(g, zero_filled, value, alpha)
|
|
|
|
|
@_onnx_symbolic("aten::full_like")
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::full_like``.

    The output type is ``dtype``. When ``dtype`` is None, ``input``'s scalar
    type is used, falling back to FLOAT.

    * Constant ``fill_value``: ``ConstantOfShape(Shape(input))``.
    * Dynamic ``fill_value``: ``zeros_like(input) + Cast(fill_value)``.
    """
    fill_value = symbolic_helper._maybe_get_const(fill_value, "f")
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.from_value(input, _type_utils.JitScalarType.FLOAT)
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )

    if not symbolic_helper._is_value(fill_value):
        input_shape = g.op("Shape", input)
        fill = torch.tensor([fill_value], dtype=scalar_type.dtype())
        return g.op("ConstantOfShape", input_shape, value_t=fill)

    zero_filled = zeros_like(g, input, dtype, layout, device)
    cast_fill = g.op("Cast", fill_value, to_i=scalar_type.onnx_type())
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return add(g, zero_filled, cast_fill, alpha)
|
|
|
|
|
@_onnx_symbolic("aten::new_full")
def new_full(
    g: jit_utils.GraphContext,
    self,
    size,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
):
    """Export ``aten::new_full`` via ``full``.

    When ``dtype`` is a None constant, ``self``'s scalar type is used, if known.
    """
    if symbolic_helper._is_none(dtype):
        inherited_dtype = symbolic_helper._try_get_scalar_type(self)
        if inherited_dtype is not None:
            dtype = inherited_dtype
    return full(g, size, fill_value, dtype, layout, device, pin_memory)
|
|
|
|
|
@_onnx_symbolic("aten::eye")
def eye(g: jit_utils.GraphContext, *args):
    """Export ``aten::eye`` as ``EyeLike`` applied to a zero matrix.

    Two arities are handled:

    * 5 args ``(n, dtype, layout, device, pin_memory)``: an ``n x n`` matrix.
    * 6 args ``(n, m, dtype, layout, device, pin_memory)``: an ``n x m`` matrix.

    Any other arity is reported via ``_unimplemented``.
    """
    if len(args) == 5:
        n, dtype, layout, device, _pin_memory = args
        n_1d = symbolic_helper._unsqueeze_helper(g, n, [0])
        shape = g.op("Concat", n_1d, n_1d, axis_i=0)
    elif len(args) == 6:
        n, m, dtype, layout, device, _pin_memory = args
        n_1d = symbolic_helper._unsqueeze_helper(g, n, [0])
        m_1d = symbolic_helper._unsqueeze_helper(g, m, [0])
        shape = g.op("Concat", n_1d, m_1d, axis_i=0)
    else:
        return symbolic_helper._unimplemented(
            "aten::eye", f"with {len(args)} arguments"
        )

    zero_matrix = zeros(g, shape, dtype, layout, device)
    return g.op("EyeLike", zero_matrix)
|
|
|
|
|
@_onnx_symbolic("aten::slice")
def slice(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::slice`` for opset 9, which only supports ``step == 1``.

    * 4 args ``(dim, start, end, step)``: tensor slice along ``dim``.
      Constant bounds become a static ``Slice``. Non-constant bounds are
      rejected in strict ONNX export mode; otherwise the deprecated
      ``DynamicSlice`` op is emitted.
    * 3 args ``(start, end, step)``: slice along axis 0.

    A None ``start`` means 0 and a None ``end`` means ``INT64_MAX``.

    Raises:
        errors.SymbolicValueError: if ``step != 1``, or if the 4-arg form has
            dynamic inputs in ``OperatorExportTypes.ONNX`` mode.
    """

    def _is_none_constant(value):
        # An omitted start/end arrives as a prim::Constant of NoneType.
        return value.node().kind() == "prim::Constant" and isinstance(
            value.type(), _C.NoneType
        )

    if len(args) == 4:
        dim, start, end, step = args
        step = symbolic_helper._parse_arg(step, "i")
        if step != 1:
            raise errors.SymbolicValueError("step!=1 is currently not supported", self)
        is_start_none = _is_none_constant(start)
        is_end_none = _is_none_constant(end)
        is_start_onnx_const = start.node().kind() == "onnx::Constant"
        is_end_onnx_const = end.node().kind() == "onnx::Constant"
        if (
            ((not is_start_none) and (not is_start_onnx_const))
            or ((not is_end_none) and (not is_end_onnx_const))
            or dim.node().kind() != "onnx::Constant"
        ):
            if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
                raise errors.SymbolicValueError(
                    "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
                    "is a deprecated experimental op. Please use statically allocated "
                    "variables or export to a higher opset version.",
                    self,
                )
            else:
                start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0])
                end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0])
                dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0])
                return g.op(
                    "DynamicSlice",
                    self,
                    start_unsqueezed,
                    end_unsqueezed,
                    dim_unsqueezed,
                )
        else:
            start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
            end = (
                _constants.INT64_MAX
                if is_end_none
                else symbolic_helper._parse_arg(end, "i")
            )
            dim = symbolic_helper._parse_arg(dim, "i")
            return symbolic_helper._slice_helper(
                g, self, axes=[dim], starts=[start], ends=[end]
            )
    elif len(args) == 3:
        start, end, step = args
        # Previously ``step`` was silently ignored here, so a non-unit step
        # exported a wrong (unit-step) slice. Reject it like the 4-arg form.
        step = symbolic_helper._parse_arg(step, "i")
        if step != 1:
            raise errors.SymbolicValueError("step!=1 is currently not supported", self)
        dim = 0
        is_start_none = _is_none_constant(start)
        is_end_none = _is_none_constant(end)
        start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
        end = (
            _constants.INT64_MAX
            if is_end_none
            else symbolic_helper._parse_arg(end, "i")
        )
        return symbolic_helper._slice_helper(
            g, self, axes=[dim], starts=[start], ends=[end]
        )

    return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments")
|
|
|
|
|
@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
    """Export ``aten::hardtanh`` as ``Clip(self, min_val, max_val)``.

    The clip bounds are passed as float attributes (``opset_before=12``).
    The helper handles any float cast this needs.
    """
    clip_bounds = {"min_f": min_val, "max_f": max_val}
    return symbolic_helper._op_with_optional_float_cast(
        g, "Clip", self, opset_before=12, **clip_bounds
    )
|
|
|
|
|
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v")
def hardswish(g: jit_utils.GraphContext, self):
    """Export ``aten::hardswish`` as ``self * hardsigmoid(self)``."""
    gate = hardsigmoid(g, self)
    return g.op("Mul", self, gate)
|
|
|
|
|
@_onnx_symbolic("aten::hardsigmoid")
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
@symbolic_helper.parse_args("v")
def hardsigmoid(g: jit_utils.GraphContext, self):
    """Export ``aten::hardsigmoid`` as ONNX ``HardSigmoid`` with ``alpha = 1/6``.

    ``beta`` is not set, so this relies on ONNX's default for it (0.5 per the
    ONNX spec). The quantized output uses scale 1/256 and zero point 0, per
    the decorator.
    """
    slope = 1 / 6
    return g.op("HardSigmoid", self, alpha_f=slope)
|
|
|
|
|
@_onnx_symbolic("aten::tanhshrink")
@symbolic_helper.parse_args("v")
def tanhshrink(g: jit_utils.GraphContext, self):
    """Export ``aten::tanhshrink`` as ``self - tanh(self)``."""
    tanh_of_self = tanh(g, self)
    return g.op("Sub", self, tanh_of_self)
|
|
|
|
|
@_onnx_symbolic("aten::hardshrink")
@symbolic_helper.parse_args("v", "f")
def hardshrink(g: jit_utils.GraphContext, self, lambd):
    """Export ``aten::hardshrink``: keep ``self`` where ``|self| > lambd``, else 0.

    The constants are created in ``self``'s scalar type (FLOAT if unknown).
    """
    torch_dtype = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    ).dtype()
    threshold = g.op("Constant", value_t=torch.tensor(lambd, dtype=torch_dtype))
    above = gt(g, self, threshold)
    below = lt(g, self, neg(g, threshold))
    outside_band = logical_or(g, above, below)
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch_dtype))
    return g.op("Where", outside_band, self, zero)
|
|
|
|
|
@_onnx_symbolic("aten::softshrink")
@symbolic_helper.parse_args("v", "f")
def softshrink(g: jit_utils.GraphContext, self, lambd):
    """Export ``aten::softshrink``.

    The result is ``self - lambd`` where ``self > lambd``, ``self + lambd``
    where ``self < -lambd``, and 0 elsewhere. It is built as the sum of two
    ``Where`` branches, each zero outside its own region.
    """
    torch_dtype = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    ).dtype()

    def _zero():
        return g.op("Constant", value_t=torch.tensor(0, dtype=torch_dtype))

    threshold = g.op("Constant", value_t=torch.tensor(lambd, dtype=torch_dtype))

    upper_mask = gt(g, self, threshold)
    upper_part = g.op("Where", upper_mask, sub(g, self, threshold), _zero())

    lower_mask = lt(g, self, neg(g, threshold))
    lower_part = g.op("Where", lower_mask, add(g, self, threshold), _zero())

    return add(g, upper_part, lower_part)
|
|
|
|
|
@_onnx_symbolic("aten::alias")
def alias(g: jit_utils.GraphContext, self):
    """Export ``aten::alias`` as a no-op: the input value itself is the result."""
    aliased = self
    return aliased
|
|
|
|
|
@_onnx_symbolic("aten::unsqueeze")
@symbolic_helper.parse_args("v", "i")
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`"""
    if dim >= 0:
        return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim])

    # Negative axes are resolved against the rank seen at export time.
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        return symbolic_helper._unimplemented(
            "unsqueeze", "negative axis with unknown input rank", self
        )

    resolved_dim = dim + rank + 1
    warnings.warn(
        f"ONNX export unsqueeze with negative axis {dim!s} might cause the onnx "
        f"model to be incorrect. Negative axis is not supported in ONNX. "
        f"Axis is converted to {resolved_dim!s} based on input shape at export "
        f"time. Passing an tensor of different rank in execution will be incorrect."
    )
    return symbolic_helper._unsqueeze_helper(g, self, axes_i=[resolved_dim])
|
|
|
|
|
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` for opset 9 as a full-length ``TopK``.

    Opset-9 ``TopK`` has no "smallest" mode; it always selects the largest
    values. Only a descending sort can therefore be expressed. An ascending
    sort, which is ``torch.sort``'s default, used to be exported silently
    with descending results. It is now reported as unsupported, matching how
    ``topk`` treats ``largest=False``.

    The size of ``self`` along ``dim`` must be statically known, because it
    becomes ``k``.

    Returns:
        The ``(values, indices)`` outputs of ``TopK``, or the result of
        ``_unimplemented`` for unsupported inputs.
    """
    if out is not None:
        symbolic_helper._unimplemented(
            "Sort", "Out parameter is not supported for sort", self
        )
    if not decending:
        return symbolic_helper._unimplemented(
            "Sort", "Ascending sort is not supported", self
        )

    self_sizes = symbolic_helper._get_tensor_sizes(self)
    try:
        dim_size = self_sizes[dim]
    except (IndexError, TypeError):
        # TypeError: sizes unavailable (None) or dim is None;
        # IndexError: dim out of range for the known rank.
        dim_size = None

    if dim_size is None:
        return symbolic_helper._unimplemented("Sort", "input size not accessible", self)

    return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)
|
|
|
|
|
@_onnx_symbolic("aten::numel")
def numel(g: jit_utils.GraphContext, self):
    """Export ``aten::numel`` using the shared ``_numel_helper``."""
    element_count = symbolic_helper._numel_helper(g, self)
    return element_count
|
|
|
|
|
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none")
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` as ONNX ``TopK`` (values and indices).

    ``out`` and ``largest=False`` are reported through ``_unimplemented``.
    In modes where that does not raise, export continues with a largest-k
    ``TopK``. ``sorted`` is not consulted.
    """
    unsupported = (
        (out is not None, "Out parameter is not supported for topk"),
        (not largest, "Ascending TopK is not supported"),
    )
    for is_unsupported, reason in unsupported:
        if is_unsupported:
            symbolic_helper._unimplemented("TopK", reason, self)

    return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
|
|
|
|
|
@_onnx_symbolic("prim::convert_element_type")
def convert_element_type(g: jit_utils.GraphContext, self, *args):
    """Export ``prim::convert_element_type`` as a ``Cast``.

    The target type is the constant dtype passed as the first extra argument.
    """
    target_dtype = symbolic_helper._get_const(args[0], "i", "dtype")
    onnx_target = _type_utils.JitScalarType(target_dtype).onnx_type()
    return g.op("Cast", self, to_i=onnx_target)
|
|
|
|
|
@_onnx_symbolic("aten::to")
def to(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::to`` as an ONNX ``Cast``, or as a no-op for device-only moves.

    The overload is identified by ``len(args)``. If no dtype change is
    requested (see ``is_aten_to_device_only``), ``self`` is returned
    unchanged. Otherwise the target dtype is extracted and a ``Cast`` is
    emitted. Any other arity is reported via ``_onnx_unsupported``.
    """

    def is_aten_to_device_only(args):
        # Return True when this aten::to call carries no dtype change.
        if len(args) == 4:
            # 4 args: device-only when the first extra argument is a device
            # (prim::device node or Device-typed value) or an int list.
            return (
                args[0].node().kind() == "prim::device"
                or args[0].type().isSubtypeOf(_C.ListType.ofInts())
                or isinstance(args[0].type(), _C.DeviceObjType)
            )
        elif len(args) == 5:
            # 5 args: the dtype is the SECOND extra argument; a None dtype
            # means only the device (or other options) change.
            dtype = symbolic_helper._get_const(args[1], "i", "dtype")
            return dtype is None
        elif len(args) in (6, 7):
            # 6/7 args: the dtype is the FIRST extra argument; None means no
            # dtype change.
            dtype = symbolic_helper._get_const(args[0], "i", "dtype")
            return dtype is None
        return False

    if is_aten_to_device_only(args):
        # ONNX has no notion of device placement: pass the value through.
        return self

    if len(args) == 4:
        # The first extra argument is either a dtype (possibly folded into an
        # onnx::Constant) or a tensor whose dtype should be copied.
        dtype = args[0]
        if (
            symbolic_helper._is_value(args[0])
            and args[0].node().kind() == "onnx::Constant"
        ):
            tval = symbolic_helper._node_get(args[0].node(), "value")
            if isinstance(tval, torch.Tensor):
                if len(tval.shape) == 0:
                    # 0-d constant: treat its value as a scalar-type enum.
                    tval = tval.item()
                    dtype = int(tval)
                else:
                    dtype = tval

        if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor):
            # Not a scalar-type enum: cast to the scalar type of args[0]
            # (presumably the aten::to(Tensor, Tensor other, ...) overload —
            # TODO confirm).
            dtype = _type_utils.JitScalarType.from_value(args[0])
            return g.op(
                "Cast",
                self,
                to_i=dtype.onnx_type(),
            )
        else:
            # dtype is an integer scalar-type enum value.
            return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
    elif len(args) == 5:
        # The dtype is the second extra argument.
        dtype = symbolic_helper._get_const(args[1], "i", "dtype")
        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
    elif len(args) == 6:
        # The dtype is the first extra argument.
        dtype = symbolic_helper._get_const(args[0], "i", "dtype")
        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
    elif len(args) == 7:
        # The dtype is the first extra argument.
        dtype = symbolic_helper._get_const(args[0], "i", "dtype")
        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())

    return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self)
|
|
|
|
|
@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
    """Export ``aten::repeat`` as ``Tile(Expand(self, ones_like(repeats)), repeats)``.

    ``self`` is first expanded against a ones vector with the same length as
    ``repeats``. Under broadcasting, this lifts ``self`` to rank
    ``len(repeats)`` by adding leading size-1 axes, so ``Tile`` receives one
    repeat count per axis.
    """
    ones_shape = ones_like(g, repeats, _type_utils.JitScalarType.INT64)
    broadcast_self = g.op("Expand", self, ones_shape)
    return g.op("Tile", broadcast_self, repeats)
|
|
|
|
|
@_onnx_symbolic("aten::repeat_interleave")
def repeat_interleave(
    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
    """Export ``aten::repeat_interleave`` for opset 9.

    When ``dim`` is None, the input is flattened and repetition happens along
    axis 0 of the flattened tensor. A 0-D or single-element ``repeats`` is
    handled by ``_repeat_interleave_single_value_repeat_helper``.

    For a 1-D ``repeats``, the input and ``repeats`` are split element-wise
    along ``dim``. Each slice is expanded by its repeat count and the pieces
    are concatenated. This requires static sizes for ``repeats`` and for the
    input along ``dim``. ``output_size`` is not used.

    Raises:
        errors.SymbolicValueError: if the rank/size of ``repeats`` or the size
            of ``self`` is unknown, if ``repeats`` is not 0-D/1-D, or if a
            1-D ``repeats`` does not match the input size along ``dim``.
    """
    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(self)
    if repeats_dim is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
            self,
        )
    if repeats_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
            self,
        )
    if input_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown input size.",
            self,
        )

    if symbolic_helper._is_none(dim):
        # dim=None: operate on the flattened input.
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = torch.tensor(0, dtype=torch.int64)
        # ``self`` is now 1-D. Track its flattened length (unknown if any
        # original dimension is unknown) so the size checks and reshape
        # targets below describe the tensor actually being repeated rather
        # than the original multi-dimensional input.
        if any(size is None for size in input_sizes):
            input_sizes = [None]
        else:
            input_sizes = [math.prod(input_sizes)]
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    # Normalize a negative dim.
    if dim < 0:
        dim += len(input_sizes)

    # input_sizes: unknown -> 0 (used for "is it known?" checks and Reshape,
    # where 0 copies the input dim); input_sizes_temp: unknown -> -1 (Expand).
    input_sizes_temp = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            input_sizes[idx], input_sizes_temp[idx] = 0, -1

    # repeats is a scalar or a single-element tensor.
    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
        if input_sizes[dim] == 0:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
                self,
            )
        return symbolic_helper._repeat_interleave_single_value_repeat_helper(
            g, self, repeats, dim
        )

    # repeats is a 1-D tensor with one count per slice along dim.
    elif repeats_dim == 1:
        if input_sizes[dim] == 0:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
                self,
            )
        if repeats_sizes[0] is None:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported for cases with dynamic repeats",
                self,
            )
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if repeats_sizes[0] != input_sizes[dim]:
            raise errors.SymbolicValueError(
                "repeats must have the same size as input along dim", self
            )
        reps = repeats_sizes[0]
    else:
        raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self)

    final_splits = []
    r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0)
    i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim)
    input_sizes[dim], input_sizes_temp[dim] = -1, 1
    for idx, r_split in enumerate(r_splits):
        # Add an axis after dim, expand it to the repeat count, then fold it back.
        i_split = unsqueeze(g, i_splits[idx], dim + 1)
        r_concat = [
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])),
            r_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])),
        ]
        r_concat = g.op("Concat", *r_concat, axis_i=0)
        i_split = expand(g, i_split, r_concat, None)
        i_split = symbolic_helper._reshape_helper(
            g,
            i_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes)),
            allowzero=0,
        )
        final_splits.append(i_split)
    return g.op("Concat", *final_splits, axis_i=dim)
|
|
|
|
|
@_onnx_symbolic("aten::pixel_shuffle")
@symbolic_helper.parse_args("v", "i")
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
    """Export ``aten::pixel_shuffle`` for a 4-D ``(N, C, H, W)`` input.

    The output is ``(N, C / r**2, H * r, W * r)`` with ``r = upscale_factor``.
    If any of C/H/W is unknown, a shape-agnostic sequence of reshapes is
    used, with 0 entries copying the input dim. Otherwise a static
    reshape/transpose/reshape is emitted.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # ``_get_tensor_sizes`` may return None (e.g. unknown rank); previously
    # this crashed with TypeError in ``len``.
    if dims is None:
        return symbolic_helper._unimplemented(
            "pixel_shuffle", "input rank is unknown", self
        )
    if len(dims) != 4:
        return symbolic_helper._unimplemented(
            "pixel_shuffle", "only support 4d input", self
        )
    if any(i is None for i in dims[1:]):
        # (N, C, 1, 1, H, W) -> (N, C/r^2, r, r, H, W)
        after_view = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [2, 3]),
            g.op(
                "Constant",
                value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]),
            ),
            allowzero=0,
        )
        # -> (N, C', H, r, W, r)
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
        # Merge H with its r: (N, C', H*r, 1, W, r)
        reshape_h = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
            allowzero=0,
        )
        # Merge W with its r: (N, C', H*r, 1, W*r, 1)
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5])
    else:
        output_channel = dims[1] // upscale_factor // upscale_factor
        after_view = symbolic_helper._reshape_helper(
            g,
            self,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        upscale_factor,
                        upscale_factor,
                        dims[2],
                        dims[3],
                    ]
                ),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
        return symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        dims[2] * upscale_factor,
                        dims[3] * upscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
|
|
|
|
|
@_onnx_symbolic("aten::pixel_unshuffle")
@symbolic_helper.parse_args("v", "i")
def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor):
    """Export ``aten::pixel_unshuffle`` for a 4-D ``(N, C, H, W)`` input.

    The output is ``(N, C * r**2, H / r, W / r)`` with ``r = downscale_factor``.
    If any of C/H/W is unknown, a shape-agnostic sequence of reshapes is
    used, with 0 entries copying the input dim. Otherwise static shapes are
    emitted.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # ``_get_tensor_sizes`` may return None (e.g. unknown rank); previously
    # this crashed with TypeError in ``len``.
    if dims is None:
        return symbolic_helper._unimplemented(
            "pixel_unshuffle", "input rank is unknown", self
        )
    if len(dims) != 4:
        # Error previously named "pixel_shuffle" (copy-paste slip).
        return symbolic_helper._unimplemented(
            "pixel_unshuffle", "only support 4d input", self
        )
    if any(i is None for i in dims[1:]):
        # (N, C, H, 1, W) -> (N, C, H/r, r, W)
        reshape_h = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [3]),
            g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
            allowzero=0,
        )
        # -> (N, C, H/r, r, W/r, r)
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
            allowzero=0,
        )
        # -> (N, C, r, r, H/r, W/r)
        after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
        # -> (N, C*r*r, 1, 1, H/r, W/r)
        final_reshape = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3])
    else:
        output_channel = dims[1] * downscale_factor * downscale_factor
        after_view = symbolic_helper._reshape_helper(
            g,
            self,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        dims[1],
                        dims[2] // downscale_factor,
                        downscale_factor,
                        dims[3] // downscale_factor,
                        downscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
        return symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        dims[2] // downscale_factor,
                        dims[3] // downscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
|
|
|
|
|
def _generic_rnn(
    g: jit_utils.GraphContext,
    variant,
    input,
    initial_states,
    all_weights,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first=None,
    batch_sizes=None,
):
    """Lower a multi-layer RNN/GRU/LSTM to a stack of ONNX RNN/GRU/LSTM nodes.

    Args:
        variant: "GRU", "LSTM", or "RNN_<ACTIVATION>" (e.g. "RNN_TANH").
            The suffix is mapped to an ONNX activation name.
        initial_states: h0 for RNN/GRU; an (h0, c0) pair for LSTM.
        all_weights: flat list with 2 (no bias) or 4 weights per layer and
            direction.
        batch_first: if truthy, the input is transposed to sequence-first on
            entry and the output is transposed back on exit.
        batch_sizes: optional sequence lengths. When None, an unused input
            is passed instead.

    Returns:
        ``(output, h_n)`` for RNN/GRU, ``(output, h_n, c_n)`` for LSTM, or
        the result of ``_unimplemented`` for unsupported configurations
        (LSTM projections, dropout in training mode, unknown hidden size).
    """
    warnings.warn(
        "Exporting a model to ONNX with a batch_size other than 1, "
        + "with a variable length with "
        + variant
        + " can cause an error "
        + "when running the ONNX model with a different batch size. "
        + "Make sure to save the model with a batch size of 1, "
        + "or define the initial states (h0/c0) as inputs of the model. "
    )

    onnxActivations = [
        "Relu",
        "Tanh",
        "Sigmoid",
        "Affine",
        "LeakyRelu",
        "ThresholdedRelu",
        "ScaledTanh",
        "HardSigmoid",
        "Elu",
        "Softsign",
        "Softplus",
    ]
    # Lower-case name -> ONNX activation spelling (e.g. "tanh" -> "Tanh").
    variantToOnnxActivationMap = dict(
        zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations)
    )
    weights_per_layer = 4 if has_biases else 2
    # An unexpected LSTM weight count is taken to mean projection weights are
    # present, which this lowering does not support.
    if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * (
        1 + bidirectional
    ):
        return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input)
    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
    # One group of weights per (layer, direction).
    layer_weights = [
        all_weights[i : i + weights_per_layer]
        for i in range(0, len(all_weights), weights_per_layer)
    ]
    if batch_first:
        # Swap the first two axes to put the sequence axis first.
        input = g.op("Transpose", input, perm_i=[1, 0, 2])
    if dropout and train:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "dropout in training mode", input
        )

    if variant.startswith("RNN"):
        # "RNN_TANH" -> "Tanh"; from here on the variant is just "RNN".
        nonlinearity = variantToOnnxActivationMap[variant[4:].lower()]
        variant = "RNN"

    # Hidden size is read from the second dim of the first hidden-hidden weight.
    w_hh = all_weights[1]
    hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1)
    if hidden_size is None:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "unknown hidden size", input
        )

    unidirectional = not bidirectional

    prev_output = input

    h_outs = []
    if variant == "RNN" or variant == "GRU":
        h0 = initial_states
    elif variant == "LSTM":
        h0, c0 = initial_states
        c_outs = []

    sequence_lens = unused(g) if batch_sizes is None else batch_sizes

    # reform_permutation lists (start, end) gate-chunk intervals, in units of
    # hidden_size, used to reorder stacked gate weights. Presumably this maps
    # PyTorch's gate order onto ONNX's gate order — TODO confirm against the
    # PyTorch/ONNX gate layouts.
    if variant == "GRU":
        # Chunks reordered as 1, 0, 2 (first two gates swapped).
        reform_permutation = [(1, 2), (0, 1), (2, 3)]
    elif variant == "LSTM":
        # Chunks reordered as 0, 3, 1, 2.
        reform_permutation = [(0, 1), (3, 4), (1, 3)]

    def reform_weights(g, w, n, intervals):
        # Slice w along axis 0 into [x*n, y*n) pieces and re-concatenate them.
        slices = [
            symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n])
            for x, y in intervals
        ]
        return g.op("Concat", *slices, axis_i=0)

    def transform_weights_no_bias(layer_index):
        # (W, R) for one layer/direction, gate-reordered and given a leading
        # num_directions axis of size 1.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh = (
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            )
        return tuple(
            symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)
        )

    def transform_weights(layer_index):
        # Same as above, plus the two biases concatenated into one B input.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh, bias_ih, bias_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh, bias_ih, bias_hh = (
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            )
        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)
        return tuple(
            symbolic_helper._unsqueeze_helper(g, x, [0])
            for x in (weight_ih, weight_hh, bias_concat)
        )

    def retrieve_state(x, start, end):
        # Select this layer's slice of an initial state along axis 0; with a
        # single layer the whole state is used as-is.
        return (
            x
            if num_layers == 1
            else symbolic_helper._slice_helper(
                g, x, axes=[0], starts=[start], ends=[end]
            )
        )

    for i in range(num_layers):
        if unidirectional:
            if weights_per_layer == 4:
                weight_ih, weight_hh, bias_concat = transform_weights(i)
            else:
                weight_ih, weight_hh = transform_weights_no_bias(i)
                bias_concat = unused(g)

            state_indices = i, i + 1
        else:
            # Forward weights are at 2*i, backward at 2*i+1; stack them on the
            # leading (direction) axis.
            if weights_per_layer == 4:
                weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
                weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
                bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0)
            else:
                weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i)
                weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1)
                bias_concat = unused(g)

            weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0)
            weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0)

            state_indices = 2 * i, 2 * i + 2

        inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]

        inputs.append(retrieve_state(h0, *state_indices))
        if variant == "LSTM":
            inputs.append(retrieve_state(c0, *state_indices))

        extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
        if variant == "RNN":
            # One activation entry per direction.
            if bidirectional:
                activation = [nonlinearity, nonlinearity]
            else:
                activation = [nonlinearity]

            prev_output, h_out = g.op(
                "RNN",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                activations_s=activation,
                **extra_kwargs,
            )
        elif variant == "GRU":
            prev_output, h_out = g.op(
                "GRU",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                linear_before_reset_i=1,
                **extra_kwargs,
            )
        elif variant == "LSTM":
            prev_output, h_out, c_out = g.op(
                "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs
            )

        if bidirectional:
            # Assuming ONNX's Y layout (seq, num_directions, batch, hidden):
            # move num_directions next to hidden, then fold the two together
            # to get (seq, batch, num_directions * hidden).
            prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3])
            prev_output = symbolic_helper._reshape_helper(
                g,
                prev_output,
                g.op("Constant", value_t=torch.LongTensor([0, 0, -1])),
                allowzero=0,
            )
        else:
            # Drop the size-1 num_directions axis.
            prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])

        h_outs.append(h_out)
        if variant == "LSTM":
            c_outs.append(c_out)
    if batch_first:
        # Swap back to batch-first.
        prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)
    if variant == "RNN" or variant == "GRU":
        return prev_output, h_outs
    elif variant == "LSTM":
        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)
        return prev_output, h_outs, c_outs
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
def _lstm_full(
    g: jit_utils.GraphContext,
    input,
    hidden_v,
    weight_v,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first,
):
    """Export the padded-input ``aten::lstm`` overload via ``_generic_rnn``.

    ``hidden_v`` (h0, c0) and ``weight_v`` arrive as packed lists and are
    unpacked before delegation.
    """
    initial_states = symbolic_helper._unpack_list(hidden_v)
    all_weights = symbolic_helper._unpack_list(weight_v)
    return _generic_rnn(
        g,
        "LSTM",
        input,
        initial_states,
        all_weights,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first=batch_first,
    )
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
def _lstm_packed(
    g: jit_utils.GraphContext,
    input,
    batch_sizes,
    hidden_v,
    weight_v,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
):
    """Export the packed-sequence ``aten::lstm`` overload via ``_generic_rnn``.

    ``batch_sizes`` is forwarded as the sequence-lengths input. ``hidden_v``
    and ``weight_v`` are unpacked from lists first.
    """
    initial_states = symbolic_helper._unpack_list(hidden_v)
    all_weights = symbolic_helper._unpack_list(weight_v)
    return _generic_rnn(
        g,
        "LSTM",
        input,
        initial_states,
        all_weights,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_sizes=batch_sizes,
    )
|
|
|
|
|
@_onnx_symbolic("aten::lstm")
def lstm(g: jit_utils.GraphContext, *args):
    """Dispatch ``aten::lstm`` to the packed or the full exporter.

    Position 3 distinguishes the overloads. In the packed overload it holds
    the weight tensor list (``batch_sizes`` shifts everything by one). In the
    full overload it holds the ``has_biases`` flag.
    """
    is_packed = symbolic_helper._is_tensor_list(args[3])
    exporter = _lstm_packed if is_packed else _lstm_full
    return exporter(g, *args)
|
|
|
|
|
@_onnx_symbolic("aten::lstm_cell")
def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh):
    """Export one LSTM cell step as a single-layer, single-timestep LSTM.

    The input and every hidden-state tensor get a leading length-1 axis so
    ``_generic_rnn`` sees a one-step sequence. The returned hidden and cell
    states have that axis squeezed away again. Biases are used only when
    ``b_ih`` is a tensor.
    """
    step_input = symbolic_helper._unsqueeze_helper(g, self, [0])
    step_hidden = [
        symbolic_helper._unsqueeze_helper(g, state, [0])
        for state in symbolic_helper._unpack_list(hidden)
    ]
    has_biases = bool(symbolic_helper._is_tensor(b_ih))
    weight = (w_ih, w_hh, b_ih, b_hh) if has_biases else (w_ih, w_hh)
    _, h_outs, c_outs = _generic_rnn(
        g,
        "LSTM",
        step_input,
        step_hidden,
        weight,
        has_biases,
        num_layers=1,
        dropout=0,
        train=0,
        bidirectional=False,
        batch_first=False,
    )
    h_out = symbolic_helper._squeeze_helper(g, h_outs, [0])
    c_out = symbolic_helper._squeeze_helper(g, c_outs, [0])
    return h_out, c_out
|
|
|
|
|
@_onnx_symbolic(
    "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")]
)
@_onnx_symbolic(
    "aten::rnn_tanh",
    decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")],
)
@_onnx_symbolic(
    "aten::rnn_relu",
    decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")],
)
def _one_hidden_rnn(kind: str):
    """Build the symbolic for an RNN variant that carries a single hidden state.

    Registered for ``aten::gru``, ``aten::rnn_tanh`` and ``aten::rnn_relu``.
    ``kind`` is the variant string supplied by the ``_apply_params``
    decorations above ("GRU", "RNN_TANH" or "RNN_RELU"), and it is forwarded
    to ``_generic_rnn``.

    Returns:
        A ``symbolic(g, *args)`` callable that routes to the packed-sequence
        or the full (padded) overload.
    """

    @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
    def _rnn_packed(
        g,
        input,
        batch_sizes,
        hidden,
        weight_v,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
    ):
        # Packed overload: extra ``batch_sizes`` input, no batch_first flag.
        return _generic_rnn(
            g,
            kind,
            input,
            hidden,
            symbolic_helper._unpack_list(weight_v),
            has_biases,
            num_layers,
            dropout,
            train,
            bidirectional,
            batch_sizes=batch_sizes,
        )

    @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
    def _rnn_full(
        g,
        input,
        hidden,
        weight_v,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first,
    ):
        # Full overload: a single hidden tensor and a JIT list of weights.
        return _generic_rnn(
            g,
            kind,
            input,
            hidden,
            symbolic_helper._unpack_list(weight_v),
            has_biases,
            num_layers,
            dropout,
            train,
            bidirectional,
            batch_first,
        )

    def symbolic(g, *args):
        # args[3] is the weight list in the packed overload and the
        # has_biases flag in the full one, so it identifies the overload.
        if not symbolic_helper._is_tensor_list(args[3]):
            return _rnn_full(g, *args)
        return _rnn_packed(g, *args)

    return symbolic
|
|
|
|
|
@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
def _dim_arange(g: jit_utils.GraphContext, like, dim):
    """Export ``aten::_dim_arange``, i.e. ``arange(like.size(dim))``.

    The length of axis ``dim`` is gathered from the runtime shape of ``like``
    and passed to ``arange`` with dtype code ``4``.
    NOTE(review): 4 is int64 in the c10 ScalarType numbering; confirm.
    """
    runtime_shape = g.op("Shape", like)
    axis_index = g.op("Constant", value_t=torch.tensor(dim))
    stop = g.op("Gather", runtime_shape, axis_index, axis_i=0)
    return arange(g, stop, 4, None, None, None)
|
|
|
|
|
@_onnx_symbolic("aten::detach")
def detach(g: jit_utils.GraphContext, input):
    """Export ``aten::detach`` as a pass-through.

    Detaching only affects autograd tracking, so the exported graph forwards
    the input value unchanged.
    """
    passthrough = input
    return passthrough
|
|
|
|
|
@_onnx_symbolic("aten::contiguous")
@symbolic_helper.parse_args("v", "i")
def contiguous(g: jit_utils.GraphContext, input, memory_format):
    """Export ``aten::contiguous`` as an identity.

    Memory layout does not affect the exported graph. Memory-format codes 0-2
    are accepted, and any larger code raises ``SymbolicValueError``.
    """
    if memory_format <= 2:
        return input
    raise errors.SymbolicValueError(
        "onnx memory_format support is not implemented", input
    )
|
|
|
|
|
@_onnx_symbolic("aten::_pack_padded_sequence")
@symbolic_helper.parse_args("v", "v", "i")
def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first):
    """Export ``aten::_pack_padded_sequence`` as a ``prim::PackPadded`` node.

    A batch-first input has its first two axes swapped before packing.
    ``lengths`` must be a tensor, and it is cast to int32 unless it is already
    ``JitScalarType.INT``.

    Raises:
        errors.SymbolicValueError: if ``lengths`` is not a tensor value.
    """
    if batch_first:
        # Swap (batch, seq) -> (seq, batch); presumably PackPadded wants
        # time-major input.
        input = g.op("Transpose", input, perm_i=[1, 0, 2])

    is_tensor = lengths.type().isSubtypeOf(torch._C.TensorType.get())
    if not is_tensor:
        raise errors.SymbolicValueError(
            "'lengths' must be a Tensor for ONNX export", input
        )

    # The type query is safe now that lengths is known to be a tensor.
    lengths_type = _type_utils.JitScalarType.from_value(
        lengths, _type_utils.JitScalarType.UNDEFINED
    )
    if lengths_type != _type_utils.JitScalarType.INT:
        lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32)
    return g.op("prim::PackPadded", input, lengths, outputs=2)
|
|
|
|
|
@_onnx_symbolic("aten::_pad_packed_sequence")
@symbolic_helper.parse_args("v", "v", "i", "t", "v")
def _pad_packed_sequence(
    g: jit_utils.GraphContext,
    data,
    batch_sizes,
    batch_first,
    padding_value,
    total_length,
):
    """Export ``aten::_pad_packed_sequence`` as a ``prim::PadPacked`` node.

    Returns the padded data and the lengths. With ``batch_first`` the first
    two axes of the padded data are swapped. ``padding_value`` and
    ``total_length`` are accepted but not used by this export.
    """
    padded, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
    if not batch_first:
        return padded, lengths
    return g.op("Transpose", padded, perm_i=[1, 0, 2]), lengths
|
|
|
|
|
@_onnx_symbolic("aten::randint")
def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options):
    """Export ``aten::randint(low, high, size, dtype, ...)``.

    ONNX has no integer random generator, so floats between ``low`` and
    ``high`` are sampled and then converted:

    * Static ``shapes`` use ``RandomUniform``.
    * Dynamic ``shapes`` use ``RandomUniformLike`` over a ``ConstantOfShape``
      template.

    Each sample is rounded down with ``Floor``, cast to int64, and cast again
    to ``dtype`` when a different dtype was requested (default int64).

    ``low`` and ``high`` must be export-time constants (resolved with
    ``_get_const``). The trailing ``options`` are not used.

    Raises:
        Via ``symbolic_helper._onnx_unsupported`` when ``low`` or ``high``
        resolves to ``None``.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.INT64
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint", high)

    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if symbolic_helper._is_value(shape):
        # Shape only known at runtime: build a template tensor of that shape.
        shape_const = g.op(
            "ConstantOfShape",
            shapes,
            value_t=torch.tensor([0], dtype=torch.float),
        )
        randn = g.op(
            "RandomUniformLike",
            shape_const,
            low_f=low_i,
            high_f=high_i,
        )
    else:
        randn = g.op(
            "RandomUniform",
            shape_i=shape,
            low_f=low_i,
            high_f=high_i,
        )

    # Cast truncates toward zero. For a negative ``low`` that folds the whole
    # interval (-1, 1) onto 0 and makes ``low`` itself practically
    # unreachable. Flooring first maps each [k, k + 1) onto k, which gives the
    # uniform integer distribution randint promises. For non-negative samples
    # Floor and truncation agree.
    randn = g.op("Floor", randn)

    int_dtype = _type_utils.JitScalarType.INT64
    randint = g.op("Cast", randn, to_i=int_dtype.onnx_type())
    if int_dtype != scalar_type:
        randint = g.op("Cast", randint, to_i=scalar_type.onnx_type())
    return randint
|
|
|
|
|
@_onnx_symbolic("aten::randint_like")
def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options):
    """Export ``aten::randint_like(self, low, high, dtype, ...)``.

    Floats between ``low`` and ``high`` are sampled with ``RandomUniformLike``
    shaped like ``self``. They are rounded down with ``Floor``, cast to int64,
    and cast again to the requested dtype.

    When ``dtype`` is ``None`` the result takes ``self``'s dtype, falling back
    to int64 if that is unknown. This matches the other ``*_like`` exporters
    in this file. The trailing ``options`` are not used.

    Raises:
        Via ``symbolic_helper._onnx_unsupported`` when ``low`` or ``high``
        resolves to ``None``.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    if dtype is None:
        # ``*_like`` ops default to the dtype of their input tensor.
        scalar_type = _type_utils.JitScalarType.from_value(
            self, _type_utils.JitScalarType.INT64
        )
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint", high)

    if symbolic_helper._is_fp(self):
        # Floating-point template: sample in its own dtype.
        randn = g.op(
            "RandomUniformLike",
            self,
            low_f=low_i,
            high_f=high_i,
        )
    else:
        # RandomUniformLike only produces floating-point outputs and, without
        # ``dtype``, copies the input dtype. Integer/bool templates (the usual
        # ``randint_like(int_tensor, ...)`` case) therefore need an explicit
        # float dtype.
        randn = g.op(
            "RandomUniformLike",
            self,
            low_f=low_i,
            high_f=high_i,
            dtype_i=_C_onnx.TensorProtoDataType.FLOAT,
        )

    # Floor before casting: Cast truncates toward zero, which would bias the
    # distribution when ``low`` is negative.
    randn = g.op("Floor", randn)

    int_dtype = _type_utils.JitScalarType.INT64
    randint = g.op("Cast", randn, to_i=int_dtype.onnx_type())
    if int_dtype != scalar_type:
        randint = g.op("Cast", randint, to_i=scalar_type.onnx_type())
    return randint
|
|
|
|
|
@_onnx_symbolic("aten::randn")
def randn(g: jit_utils.GraphContext, shapes, dtype, *options):
    """Export ``aten::randn(size, dtype, ...)`` as a standard-normal sample.

    Static ``shapes`` map to ``RandomNormal``. Dynamic ``shapes`` go through
    ``RandomNormalLike`` over a ``ConstantOfShape`` template. The output dtype
    defaults to float. The trailing ``options`` are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if not symbolic_helper._is_value(shape):
        return g.op(
            "RandomNormal",
            shape_i=shape,
            dtype_i=scalar_type.onnx_type(),
        )
    template = g.op(
        "ConstantOfShape",
        shapes,
        value_t=torch.tensor([0], dtype=torch.float),
    )
    return g.op(
        "RandomNormalLike",
        template,
        dtype_i=scalar_type.onnx_type(),
    )
|
|
|
|
|
@_onnx_symbolic("aten::rand")
def rand(g: jit_utils.GraphContext, shapes, dtype, *options):
    """Export ``aten::rand(size, dtype, ...)`` as a uniform sample.

    Static ``shapes`` map to ``RandomUniform``. Dynamic ``shapes`` go through
    ``RandomUniformLike`` over a ``ConstantOfShape`` template. The output dtype
    defaults to float. The trailing ``options`` are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if not symbolic_helper._is_value(shape):
        return g.op(
            "RandomUniform",
            shape_i=shape,
            dtype_i=scalar_type.onnx_type(),
        )
    template = g.op(
        "ConstantOfShape",
        shapes,
        value_t=torch.tensor([0], dtype=torch.float),
    )
    return g.op(
        "RandomUniformLike",
        template,
        dtype_i=scalar_type.onnx_type(),
    )
|
|
|
|
|
@_onnx_symbolic("aten::randn_like")
def randn_like(
    g: jit_utils.GraphContext,
    self,
    dtype,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::randn_like`` as ``RandomNormalLike`` shaped like ``self``.

    Without an explicit ``dtype`` the output uses ``self``'s dtype, or float
    when that is unknown. ``layout``, ``device``, ``pin_memory`` and
    ``memory_format`` are not used.
    """
    requested = symbolic_helper._get_const(dtype, "i", "dtype")
    if requested is not None:
        scalar_type = _type_utils.JitScalarType(requested)
    else:
        scalar_type = _type_utils.JitScalarType.from_value(
            self, _type_utils.JitScalarType.FLOAT
        )
    onnx_dtype = scalar_type.onnx_type()
    return g.op("RandomNormalLike", self, dtype_i=onnx_dtype)
|
|
|
|
|
@_onnx_symbolic("aten::rand_like")
def rand_like(
    g: jit_utils.GraphContext,
    self,
    dtype,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::rand_like`` as ``RandomUniformLike`` shaped like ``self``.

    Without an explicit ``dtype`` the output uses ``self``'s dtype, or float
    when that is unknown. ``layout``, ``device``, ``pin_memory`` and
    ``memory_format`` are not used.
    """
    requested = symbolic_helper._get_const(dtype, "i", "dtype")
    if requested is not None:
        scalar_type = _type_utils.JitScalarType(requested)
    else:
        scalar_type = _type_utils.JitScalarType.from_value(
            self, _type_utils.JitScalarType.FLOAT
        )
    return g.op("RandomUniformLike", self, dtype_i=scalar_type.onnx_type())
|
|
|
|
|
@_onnx_symbolic("aten::rrelu")
@symbolic_helper.parse_args("v", "f", "f", "i", "none")
def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator):
    """Export ``aten::rrelu`` (randomized leaky ReLU).

    In training mode each element gets a random negative slope drawn between
    ``lower`` and ``upper`` and applied with ``PRelu``. In eval mode the fixed
    slope ``(lower + upper) / 2`` is used with ``LeakyRelu``. ``generator`` is
    ignored.
    """
    if training:
        slopes = g.op("RandomUniformLike", input, high_f=upper, low_f=lower)
        return g.op("PRelu", input, slopes)
    mean_slope = (upper + lower) / 2.0
    return g.op("LeakyRelu", input, alpha_f=mean_slope)
|
|
|
|
|
@_onnx_symbolic("aten::bernoulli")
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
    """Export ``aten::bernoulli`` by thresholding uniform noise.

    Uniform samples in ``[0, 1]`` are drawn with ``RandomUniformLike`` in the
    input's dtype. They are compared with ``Less`` against the probability:
    ``p`` when it is given, otherwise ``input`` itself. The boolean result is
    then cast back to the input dtype.

    An explicit ``out`` or ``generator`` is reported through
    ``symbolic_helper._unimplemented``, and so is an input whose dtype is not
    known at export time.
    """
    # Return the ``_unimplemented`` result, as the dtype check below and the
    # other symbolics in this file do. Then, in export modes where
    # ``_unimplemented`` does not raise, no graph is built that silently
    # drops the unsupported argument.
    if out is not None and not symbolic_helper._is_none(out):
        return symbolic_helper._unimplemented(
            "Bernoulli", "out parameter is not supported for bernoulli", input
        )
    if generator is not None and not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "Bernoulli", "generator is not supported for bernoulli", input
        )

    dtype = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.UNDEFINED
    )
    if dtype == _type_utils.JitScalarType.UNDEFINED:
        return symbolic_helper._unimplemented(
            "Bernoulli", "input dtype not accessible", input
        )

    rands = g.op(
        "RandomUniformLike",
        input,
        high_f=1.0,
        low_f=0.0,
        dtype_i=dtype.onnx_type(),
    )
    prob = p if p is not None and not symbolic_helper._is_none(p) else input
    output = g.op("Less", rands, prob)
    return g.op("Cast", output, to_i=dtype.onnx_type())
|
|
|
|
|
@_onnx_symbolic("aten::log_sigmoid")
@symbolic_helper.parse_args("v")
def log_sigmoid(g: jit_utils.GraphContext, input):
    """Export ``aten::log_sigmoid`` as ``Log(Sigmoid(input))``."""
    return g.op("Log", g.op("Sigmoid", input))
|
|
|
|
|
@_onnx_symbolic("aten::erf")
@symbolic_helper.parse_args("v")
def erf(g: jit_utils.GraphContext, input):
    """Export ``aten::erf`` as the ONNX ``Erf`` operator."""
    op_type = "Erf"
    return g.op(op_type, input)
|
|
|
|
|
@_onnx_symbolic("aten::flatten")
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    """Export ``aten::flatten(input, start_dim, end_dim)``.

    The input rank must be known at export time. The cases are:

    * Rank 0 is reshaped to ``[1]``; rank 1 becomes ``Identity``.
    * When the result is 2-D (flattening ``1..rank-1`` or ``0..rank-2``),
      ONNX ``Flatten`` is used.
    * Everything else goes through ``symbolic_helper._flatten_helper``.

    A negative ``end_dim`` is normalized by adding the rank.
    """
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
            input,
        )
    if rank == 0:
        return symbolic_helper._reshape_helper(g, input, [1])
    if rank == 1:
        return g.op("Identity", input)

    last_dim = end_dim + rank if end_dim < 0 else end_dim

    # ONNX Flatten always yields a 2-D tensor split at ``axis``.
    if start_dim == 1 and last_dim == rank - 1:
        return g.op("Flatten", input, axis_i=start_dim)
    if start_dim == 0 and last_dim == rank - 2:
        return g.op("Flatten", input, axis_i=last_dim + 1)

    return symbolic_helper._flatten_helper(g, input, start_dim, last_dim, rank)
|
|
|
|
|
@_onnx_symbolic("aten::nonzero")
@symbolic_helper.parse_args("v")
def nonzero(g: jit_utils.GraphContext, input):
    """Export ``torch.nonzero(x, as_tuple=False)``.

    ONNX ``NonZero`` yields one row per input dimension. Transposing it gives
    PyTorch's ``(num_nonzero, rank)`` layout.
    """
    per_dim_indices = g.op("NonZero", input)
    return t(g, per_dim_indices)
|
|
|
|
|
@_onnx_symbolic("aten::nonzero_numpy")
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
    """Export ``aten::nonzero_numpy`` (``nonzero(as_tuple=True)``).

    The ``(num_nonzero, rank)`` index matrix is unbound along axis 1, giving
    one index tensor per input dimension.
    """
    index_matrix = nonzero(g, input)
    return unbind(g, index_matrix, 1, _outputs=_outputs)
|
|
|
|
|
@_onnx_symbolic("aten::isnan")
@symbolic_helper.parse_args("v")
def isnan(g: jit_utils.GraphContext, input):
    """Export ``aten::isnan`` as the ONNX ``IsNaN`` operator."""
    return g.op("IsNaN", input)
|
|
|
|
|
@_onnx_symbolic("aten::any")
def _any(g: jit_utils.GraphContext, *args):
    """Export ``aten::any`` as ``ReduceSum(Cast(x, int64)) > 0``.

    Two call shapes are handled:

    * ``(self)`` reduces over every axis with ``keepdim=0``.
    * ``(self, dim, keepdim)`` takes ``dim`` as a constant. It may be a single
      int or a list, and it is flattened into a list of axes.
    """
    input = args[0]
    if len(args) == 1:
        axes, keep = None, 0
    else:
        _, raw_dim, raw_keepdim = args
        dim_tensor = symbolic_helper._parse_arg(raw_dim, "t")
        axes = [int(d) for d in dim_tensor.view(-1)]
        keep = symbolic_helper._parse_arg(raw_keepdim, "i")
    as_int = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
    total = symbolic_helper._reducesum_helper(
        g, as_int, axes_i=axes, keepdims_i=keep
    )
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))
    return gt(g, total, zero)
|
|
|
|
|
@_onnx_symbolic("aten::all")
def _all(g: jit_utils.GraphContext, *args):
    """Export ``aten::all`` via De Morgan: ``all(x) == not any(not x)``.

    Accepts ``(self)`` or ``(self, dim, keepdim)``. The optional arguments are
    forwarded unchanged to ``_any``.
    """
    negated = g.op("Not", args[0])
    if len(args) == 1:
        any_false = _any(g, negated)
    else:
        any_false = _any(g, negated, args[1], args[2])
    return g.op("Not", any_false)
|
|
|
|
|
@_onnx_symbolic("aten::narrow")
@symbolic_helper.parse_args("v", "i", "i", "i")
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
    """Export ``aten::narrow`` as a slice ``[start, start + length)`` on ``dim``."""
    stop = start + length
    return symbolic_helper._slice_helper(
        g, input, axes=[dim], starts=[start], ends=[stop]
    )
|
|
|
|
|
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Export ``aten::argmax`` through the shared ArgMin/ArgMax helper."""
    onnx_op = "ArgMax"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
|
|
|
|
|
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Export ``aten::argmin`` through the shared ArgMin/ArgMax helper."""
    onnx_op = "ArgMin"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
|
|
|
|
|
@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter`` with ONNX ``Scatter`` along ``dim``.

    A tensor ``src`` is used directly. A constant scalar ``src`` is cast to
    ``self``'s dtype when the dtypes differ (PyTorch allows a mismatched
    scalar), then expanded to ``index``'s shape before scattering.
    """
    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        self_scalar_type = _type_utils.JitScalarType.from_value(self)
        if self_scalar_type != src_type:
            src = g.op("Cast", src, to_i=self_scalar_type.onnx_type())
        return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim)
    return g.op("Scatter", self, index, src, axis_i=dim)
|
|
|
|
|
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter_add`` as ``self + scatter(zeros, dim, index, src)``.

    The zero tensor is a constant when ``self`` has fully static sizes, and
    ``zeros_like(self)`` otherwise. ``self``'s dtype must be known.
    """
    scalar_type = symbolic_helper._try_get_scalar_type(self)
    if scalar_type is None:
        return symbolic_helper._unimplemented(
            "scatter_add", "input dtype not accessible", self
        )
    static_sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False)
    if not static_sizes:
        zeros = zeros_like(g, self, scalar_type)
    else:
        zeros = g.op(
            "Constant", value_t=torch.zeros(static_sizes, dtype=scalar_type.dtype())
        )
    scattered = symbolic_helper._scatter_helper(g, zeros, dim, index, src)
    return add(g, self, scattered)
|
|
|
|
|
@_onnx_symbolic("aten::log2")
def log2(g: jit_utils.GraphContext, self):
    """Export ``aten::log2`` as ``log(x) / ln(2)``."""
    ln_two = 0.693147180559945309  # natural logarithm of 2
    natural_log = log(g, self)
    divisor = g.op("Constant", value_t=torch.tensor(ln_two))
    return g.op("Div", natural_log, divisor)
|
|
|
|
|
@_onnx_symbolic("aten::is_floating_point")
def is_floating_point(g: jit_utils.GraphContext, self):
    """Export ``aten::is_floating_point`` as a constant boolean.

    The answer is decided at export time from ``self``'s recorded dtype.
    """
    flag = 1 if symbolic_helper._is_fp(self) else 0
    return g.op("Constant", value_t=torch.BoolTensor([flag]))
|
|
|
|
|
@_onnx_symbolic("aten::__is_")
def __is_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__is_`` (``a is b``).

    A ``None`` on the right is resolved at export time to a constant boolean.
    Otherwise the operands are compared with ``eq``.
    """
    if not symbolic_helper._is_none(other):
        return eq(g, self, other)
    both_none = 1 if symbolic_helper._is_none(self) else 0
    return g.op("Constant", value_t=torch.BoolTensor([both_none]))
|
|
|
|
|
@_onnx_symbolic("aten::__isnot_")
@wrap_logical_op_with_negation
def __isnot_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__isnot_`` (``a is not b``).

    The body computes ``a is b``. The ``wrap_logical_op_with_negation``
    decorator (defined elsewhere in this module) is expected to negate it.
    """
    identical = __is_(g, self, other)
    return identical
|
|
|
|
|
@_onnx_symbolic("aten::one_hot")
def one_hot(g: jit_utils.GraphContext, self, num_classes):
    """Export ``aten::one_hot`` as ONNX ``OneHot`` on the last axis.

    Off/on values are the int64 constants ``[0, 1]``. When ``num_classes`` is
    uint8, int8, int32 or int16 it is cast to int64 first, presumably because
    runtimes accept only a limited set of depth types.
    """
    off_on_values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    narrow_int_types = {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT16,
    }
    depth_type = _type_utils.JitScalarType.from_value(
        num_classes, _type_utils.JitScalarType.UNDEFINED
    )
    if depth_type in narrow_int_types:
        num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("OneHot", self, num_classes, off_on_values, axis_i=-1)
|
|
|
|
|
@_onnx_symbolic("aten::gather")
@symbolic_helper.parse_args("v", "i", "v", "v")
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
    """Export ``torch.gather`` without ``GatherElements``.

    ONNX ``Gather`` does not match ``torch.gather``, and ``GatherElements``
    only appears in opset 11. Instead, ``index`` is one-hot encoded along a
    new axis ``dim`` of length ``self.size(dim)`` and multiplied with ``self``
    unsqueezed at ``dim + 1``. Summing over ``dim`` then picks the indexed
    elements. ``sparse_grad=True`` is unsupported.
    """
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True", self)
    scalar_type = _type_utils.JitScalarType.from_value(self)
    off_on_values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    axis_const = g.op("Constant", value_t=torch.LongTensor([dim]))
    depth = size(g, self, axis_const)
    selector = g.op("OneHot", index, depth, off_on_values, axis_i=dim)
    selector = g.op("Cast", selector, to_i=scalar_type.onnx_type())
    expanded_self = symbolic_helper._unsqueeze_helper(g, self, [dim + 1])
    product = g.op("Mul", expanded_self, selector)
    return symbolic_helper._reducesum_helper(g, product, axes_i=[dim], keepdims_i=0)
|
|
|
|
|
@symbolic_helper.parse_args("v", "is", "i", "i")
def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim):
    """Parse arguments and delegate to ``symbolic_helper._var_mean_helper``.

    Callers in this module unpack the result as ``(variance, mean)``.
    """
    result = symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::std")
def std(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::std`` as the square root of the exported variance."""
    variance, _mean = var_mean(g, input, *args)
    return g.op("Sqrt", variance)
|
|
|
|
|
@_onnx_symbolic("aten::var")
def var(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var`` as the variance half of ``var_mean``."""
    variance, _mean = var_mean(g, input, *args)
    return variance
|
|
|
|
|
@_onnx_symbolic("aten::var_mean")
def var_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var_mean`` for its overloads.

    With a single extra argument (presumably the ``unbiased``/correction
    flag) the reduction covers all dimensions and ``keepdim`` is ``None``.
    Otherwise ``(dim, correction, keepdim)`` are forwarded as given.
    """
    if len(args) != 1:
        return _var_mean(g, input, *args)
    (correction,) = args
    return _var_mean(g, input, None, correction, None)
|
|
|
|
|
@_onnx_symbolic("aten::std_mean")
def std_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::std_mean`` as ``(sqrt(variance), mean)``."""
    variance, mean = var_mean(g, input, *args)
    std_dev = g.op("Sqrt", variance)
    return std_dev, mean
|
|
|
|
|
@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
    """Export ``aten::logsumexp`` as ONNX ``ReduceLogSumExp`` over ``dim``."""
    return g.op(
        "ReduceLogSumExp",
        input,
        axes_i=dim,
        keepdims_i=keepdim,
    )
|
|
|
|
|
@_onnx_symbolic("aten::arange")
def arange(g: jit_utils.GraphContext, *args):
    """Export the ``aten::arange`` overloads, selected by argument count.

    The sequence is built without ONNX ``Range``. A 1-D ``ones`` tensor of the
    required length is created, and ``nonzero`` of it (squeezed on axis 1)
    yields the indices ``0..N-1``. These are then scaled/shifted where needed
    and cast to the dtype resolved by ``symbolic_helper._arange_cast_helper``.

    Argument layouts handled, inferred from the indices read below:

    * 2 args: ``(end, <unused>)`` with no dtype.
    * 5 args: ``(end, dtype, ...)``.
    * 4 args: ``(start, end, step, <unused>)`` with no dtype.
    * 7 args: ``(start, end, step, dtype, ...)``.
    * 6 args: ``(start, end, dtype, *rest)``, where ``rest`` is forwarded to
      ``ones`` (presumably layout/device/pin_memory).

    Any other count is reported through ``symbolic_helper._unimplemented``.
    """

    def _get_arange_dtype(dtype):
        # Resolve a constant dtype argument to its integer code.
        # NOTE(review): presumably non-constants pass through unchanged.
        dtype = symbolic_helper._maybe_get_const(dtype, "i")
        return dtype

    def _float_step_convert(range_tensor):
        # A floating-point element count is rounded up (Ceil) and cast to
        # int64 so it can size the ``ones`` tensor below.
        if symbolic_helper._is_fp(range_tensor):
            range_tensor = g.op(
                "Cast",
                g.op("Ceil", range_tensor),
                to_i=_type_utils.JitScalarType.INT64.onnx_type(),
            )
        return range_tensor

    if len(args) == 2 or len(args) == 5:
        if len(args) == 2:
            # (end, <unused>): no dtype argument.
            dtype = None
        else:
            # (end, dtype, ...): dtype is args[1].
            dtype = _get_arange_dtype(args[1])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, end=args[0], dtype=dtype
        )
        # 1-D tensor holding the element count, used as the shape of ``ones``.
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        range_tensor = _float_step_convert(end)
        # nonzero(ones(N)) has shape (N, 1) with rows 0..N-1; squeeze -> 0..N-1.
        arange_tensor = symbolic_helper._squeeze_helper(
            g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1]
        )
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )
    elif len(args) == 4 or len(args) == 7:
        if len(args) == 4:
            # (start, end, step, <unused>): no dtype argument.
            dtype = None
        else:
            # (start, end, step, dtype, ...): dtype is args[3].
            dtype = _get_arange_dtype(args[3])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=dtype
        )
        step = symbolic_helper._unsqueeze_helper(g, step, [0])
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        start = symbolic_helper._unsqueeze_helper(g, start, [0])
        # Element count = (end - start) / step, rounded up when fractional.
        range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step))
        arange_tensor = symbolic_helper._squeeze_helper(
            g, nonzero(g, ones(g, range_tensor, None, None, None)), [1]
        )
        # Map indices 0..N-1 to start + i * step.
        arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )
    elif len(args) == 6:
        # (start, end, dtype, ...): dtype is args[2]; the step is implicitly 1.
        dtype = _get_arange_dtype(args[2])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=dtype
        )
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        start = symbolic_helper._unsqueeze_helper(g, start, [0])
        range_tensor = _float_step_convert(g.op("Sub", end, start))
        # Indices 0..N-1 shifted by ``start``; args[3:] go straight to ``ones``.
        arange_tensor = g.op(
            "Add",
            symbolic_helper._squeeze_helper(
                g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1]
            ),
            start,
        )
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )

    return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments")
|
|
|
|
|
@_onnx_symbolic("aten::linspace")
def linspace(
    g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory
):
    """Export ``aten::linspace`` as ``start + arange(steps) * step``.

    ``step = (end - start) / (steps - 1)``. ``dtype``, ``layout``, ``device``
    and ``pin_memory`` are not used here.
    NOTE(review): the divisor is zero when ``steps == 1``; confirm how that
    case behaves.
    """
    indices = symbolic_helper._arange_helper(g, steps, None)
    span = sub(g, end, start)
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    intervals = sub(g, steps, one)
    step = div(g, span, intervals)
    return add(g, mul(g, indices, step), start)
|
|
|
|
|
@_onnx_symbolic("aten::lift")
def lift(g: jit_utils.GraphContext, self):
    """Export ``aten::lift`` as a pass-through of its input value."""
    lifted = self
    return lifted
|
|
|
|
|
@_onnx_symbolic("aten::masked_fill")
def masked_fill(g: jit_utils.GraphContext, self, mask, value):
    """Export ``Tensor.masked_fill`` as ONNX ``Where``.

    Elements of ``self`` where ``mask`` is True are replaced by ``value``.
    The mask is cast to bool, and a constant scalar ``value`` is converted to
    ``self``'s type.
    """
    bool_mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
    scalar_or_value = symbolic_helper._maybe_get_scalar(value)
    fill = symbolic_helper._if_scalar_type_as(scalar_or_value, self)
    return g.op("Where", bool_mask, fill, self)
|
|
|
|
|
@_onnx_symbolic("aten::masked_fill_")
def masked_fill_(g: jit_utils.GraphContext, self, mask, value):
    """Export the in-place ``masked_fill_`` with the out-of-place export."""
    filled = masked_fill(g, self, mask, value)
    return filled
|
|
|
|
|
@_onnx_symbolic("aten::index")
def index(g: jit_utils.GraphContext, self, index):
    """Export ``aten::index`` (tensor indexing such as ``t[idx]``, ``t[a, b]``).

    ``index`` is either a packed list of index values or a single value.
    Byte/bool masks are first converted to integer indices with ``nonzero``;
    only 1-D masks are handled correctly, as the emitted warning says.

    * One index: a select along axis 0.
    * Several entries, where ``None`` entries stand for ``:``:
      * No advanced indices: ``self`` is returned unchanged.
      * One advanced index: ``index_select`` on that axis.
      * Two or more: the advanced axes are transposed to the front and
        flattened into one axis. A single linear index
        ``sum_i(ind_i * prod_{j>i} size_j)`` is built and gathered with
        ``index_select``, and the result is reshaped (and transposed if
        needed) back to the indexed shape.

    The rank of ``self`` must be known for the multi-index path. Negative
    index values are not handled there, as the emitted warning says.
    """
    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]

    def try_mask_to_index(index):
        # Replace a uint8/bool mask with the integer positions of its True
        # entries; other indices (and None) are returned untouched.
        if not symbolic_helper._is_none(index) and (
            _type_utils.JitScalarType.from_value(
                index, _type_utils.JitScalarType.UNDEFINED
            )
            == _type_utils.JitScalarType.UINT8
            or symbolic_helper._is_bool(index)
        ):
            if g.opset < 9:
                raise errors.SymbolicValueError(
                    "Exporting masked indices are only supported after ONNX opset 9.",
                    self,
                )
            warnings.warn(
                "Exporting aten::index operator with indices of type Byte. "
                "Only 1-D indices are supported. In any other case, "
                "this will produce an incorrect ONNX graph."
            )
            index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1])
        return index

    indices = [try_mask_to_index(idx) for idx in indices]
    if len(indices) == 1:
        return symbolic_helper._select_helper(
            g, self, 0, indices[0], apply_reshape=False
        )
    else:
        # Several entries: each is either None (a ":" slice) or an advanced
        # index tensor. Record the axes that carry advanced indices.
        adv_idx_indices = [
            i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx)
        ]

        if len(adv_idx_indices) == 0:
            return self
        elif len(adv_idx_indices) == 1:
            return index_select(
                g, self, adv_idx_indices[0], indices[adv_idx_indices[0]]
            )
        else:
            rank = symbolic_helper._get_tensor_rank(self)
            if rank is None:
                return symbolic_helper._unimplemented(
                    "aten::index",
                    "operator of advanced indexing on tensor of unknown rank. ",
                    self,
                )

            warnings.warn(
                "Exporting aten::index operator of advanced indexing in opset "
                f"{GLOBALS.export_onnx_opset_version}"
                " is achieved by combination of multiple ONNX operators, "
                "including Reshape, Transpose, Concat, and Gather. "
                "If indices include negative values, the exported graph will produce incorrect results."
            )
            adv_idx_count = len(adv_idx_indices)
            shape_tensor = _shape_as_tensor(g, self)
            # One 1-element tensor per axis holding that axis' runtime size.
            dim_tensor_list = [
                g.op(
                    "Gather",
                    shape_tensor,
                    g.op("Constant", value_t=torch.LongTensor([dim])),
                    axis_i=0,
                )
                for dim in range(rank)
            ]

            # Move the advanced-indexed axes to the front, then flatten them
            # into a single leading axis.
            self = g.op(
                "Transpose",
                self,
                perm_i=adv_idx_indices
                + [i for i in range(rank) if i not in adv_idx_indices],
            )
            self = g.op("Flatten", self, axis_i=adv_idx_count)

            # Build the linear index into the flattened axis, starting from the
            # last advanced index. Mul/Add broadcast the index tensors, so
            # cum_adv_index also ends up with the broadcast index shape.
            cum_adv_index = indices[adv_idx_indices[-1]]
            multiplier = dim_tensor_list[adv_idx_indices[-1]]
            for i in range(adv_idx_count - 2, -1, -1):
                adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier)
                cum_adv_index = g.op("Add", cum_adv_index, adv_index)
                multiplier = g.op(
                    "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]]
                )

            # Gather along the flattened leading axis.
            self = index_select(g, self, 0, cum_adv_index)

            cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index)

            # If the advanced axes are consecutive, the indexed dims stay at the
            # position of the first advanced axis. Otherwise they move to the
            # front of the result.
            if adv_idx_indices == list(
                range(adv_idx_indices[0], adv_idx_indices[-1] + 1)
            ):
                # Unfold the non-advanced axes, keeping the gathered axis folded.
                folded_adv_idx_shape_list = [
                    g.op("Constant", value_t=torch.LongTensor([-1]))
                ] + [
                    dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices
                ]
                folded_adv_idx_shape = g.op(
                    "Concat", *folded_adv_idx_shape_list, axis_i=0
                )
                self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape)

                # Move the gathered axis back to the first advanced axis'
                # original position.
                adv_idx_permute = (
                    list(range(1, adv_idx_indices[0] + 1))
                    + [0]
                    + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1))
                )
                self = g.op("Transpose", self, perm_i=adv_idx_permute)

                # Final shape: leading plain dims, the broadcast index shape,
                # then the remaining plain dims.
                final_shape_list = (
                    [dim_tensor_list[i] for i in range(adv_idx_indices[0])]
                    + [cum_adv_index_shape_tensor]
                    + [
                        dim_tensor_list[i]
                        for i in range(adv_idx_indices[0], rank)
                        if i not in adv_idx_indices
                    ]
                )
                final_shape = g.op("Concat", *final_shape_list, axis_i=0)
            else:
                # Non-consecutive advanced axes: the broadcast index shape goes
                # first, followed by all plain dims in order.
                final_shape = g.op(
                    "Concat",
                    cum_adv_index_shape_tensor,
                    *[
                        dim_tensor_list[i]
                        for i in range(rank)
                        if i not in adv_idx_indices
                    ],
                    axis_i=0,
                )

            return symbolic_helper._reshape_helper(g, self, final_shape)
|
|
|
|
|
@_onnx_symbolic("aten::linalg_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def linalg_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: Sequence[int] | None,
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``torch.linalg.norm`` by routing to the vector or matrix norm.

    Routing:

    * ``dim`` and ``ord`` both ``None``: flatten the input and take its
      2-norm.
    * ``dim`` is ``None`` and the input is 1-D: vector norm of order ``ord``.
    * ``dim`` is ``None`` and the input has rank >= 2: matrix norm over
      ``[0, 1]``.
    * ``dim`` has one entry: vector norm, with ``ord`` defaulting to 2.
    * Otherwise: matrix norm over ``dim``.

    The input rank must be known when ``dim`` is ``None``.
    """
    # Set only on the vector-norm paths; ``None`` selects the matrix norm.
    ord_value = None
    if dim is None:
        if symbolic_helper._is_none(ord):
            self = symbolic_helper._reshape_helper(g, self, [-1])
            ord = g.op("Constant", value_t=torch.LongTensor([2]))
        self_dim = symbolic_helper._get_tensor_rank(self)
        if self_dim is None:
            return symbolic_helper._unimplemented(
                "dim", "Input rank must be known at export time.", self
            )
        if self_dim == 1:
            ord_value = symbolic_helper._parse_arg(ord, "f")
        else:
            dim = [0, 1]
    else:
        if len(dim) == 1:
            if symbolic_helper._is_none(ord):
                ord = g.op("Constant", value_t=torch.LongTensor([2]))
            ord_value = symbolic_helper._parse_arg(ord, "f")
    # ord=0 (the count of non-zero entries) parses to 0.0, which is falsy.
    # Compare against None so it still reaches the vector norm instead of the
    # matrix norm, which needs two dims.
    if ord_value is not None:
        return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype)
    return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
|
|
|
|
|
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: float,
    dim: Sequence[int] | None,
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``torch.linalg.vector_norm`` via the shared symbolic helper."""
    helper = symbolic_helper._linalg_vector_norm_helper
    return helper(g, self, ord, dim, keepdim, dtype)
|
|
|
|
|
@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def linalg_matrix_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: list[int],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``torch.linalg.matrix_norm`` for the orders expressible in ONNX.

    * ``"fro"`` or ``None``: the Frobenius norm.
    * ``"nuc"`` and ``+/-2``: unsupported, since they need singular values;
      reported through ``_unimplemented``.
    * ``1``/``-1``: ``max``/``min`` over ``dim[1]`` of ``sum(|x|)`` over
      ``dim[0]``.
    * ``inf``/``-inf``: the same with the two dims swapped.

    ``dim`` must hold (at least) two axes, and negative axes are wrapped
    using the input rank, which must be known. The caller's ``dim`` list is
    not modified.
    """
    ord_value = symbolic_helper._parse_arg(ord, "s")
    if ord_value == "fro":
        return frobenius_norm(g, self, dim, keepdim)
    if ord_value == "nuc":
        return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self)

    ord_value = symbolic_helper._parse_arg(ord, "f")
    if ord_value is None:
        return frobenius_norm(g, self, dim, keepdim)
    if ord_value == 2 or ord_value == -2:
        # ord=+/-2 need singular values, which ONNX cannot compute here.
        return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self)

    self_dim = symbolic_helper._get_tensor_rank(self)
    if self_dim is None:
        return symbolic_helper._unimplemented(
            "linalg.matrix_norm", "Input rank must be known at export time.", self
        )

    # Wrap negative axes on local copies; the incoming list is left intact.
    reduce_dim = dim[0] + self_dim if dim[0] < 0 else dim[0]
    select_dim = dim[1] + self_dim if dim[1] < 0 else dim[1]

    # inf-norms sum along the other axis.
    if ord_value == math.inf or ord_value == -math.inf:
        reduce_dim, select_dim = select_dim, reduce_dim
    # Without keepdim the first reduction removes an axis, shifting later ones.
    if select_dim > reduce_dim and not keepdim:
        select_dim -= 1

    abs_sum = symbolic_helper._reducesum_helper(
        g, g.op("Abs", self), axes_i=[reduce_dim], keepdims_i=keepdim
    )
    # ``max``/``min`` are this module's aten::max / aten::min symbolics.
    reduce_extreme = max if ord_value > 0 else min
    result, _indices = reduce_extreme(
        g,
        abs_sum,
        dim_or_y=g.op("Constant", value_t=torch.LongTensor([select_dim])),
        keepdim=keepdim,
    )
    return result
|
|
|
|
|
@_onnx_symbolic("aten::linalg_cross")
@symbolic_helper.parse_args("v", "v", "i")
def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1):
    """Export ``torch.linalg.cross`` by delegating to the ``cross`` export."""
    product = cross(g, input, other, dim)
    return product
|
|
|
|
|
@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "is", "b")
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
    """Export the Frobenius norm as ``sqrt(sum(x * x))`` over ``dim``."""
    squared = g.op("Mul", self, self)
    total = symbolic_helper._reducesum_helper(
        g, squared, axes_i=dim, keepdims_i=keepdim
    )
    return g.op("Sqrt", total)
|
|
|
|
|
@_onnx_symbolic("aten::multinomial")
@symbolic_helper.parse_args("v", "i", "b", "v")
def multinomial(
    g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None
):
    """Export ``aten::multinomial`` as ONNX ``Multinomial`` with int64 output.

    ONNX ``Multinomial`` takes unnormalized log-probabilities, so ``input`` is
    passed through ``log`` first. A custom ``generator``, or sampling without
    replacement when ``num_samples > 1``, is reported through
    ``symbolic_helper._unimplemented``.
    """
    # Return the ``_unimplemented`` result, as the other symbolics in this
    # file do. Otherwise, in export modes where it does not raise, the graph
    # below would be emitted anyway and would silently ignore the unsupported
    # argument (e.g. sample with replacement despite replacement=False).
    if generator is not None and not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "Multinomial", "generator is not supported for multinomial", input
        )
    if not replacement and num_samples > 1:
        return symbolic_helper._unimplemented(
            "Multinomial",
            "replacement=False when num_samples > 1 is not supported for multinomial",
            input,
        )

    log_input = log(g, input)
    return g.op(
        "Multinomial",
        log_input,
        dtype_i=_C_onnx.TensorProtoDataType.INT64,
        sample_size_i=num_samples,
    )
|
|
|
|
|
@_onnx_symbolic("aten::baddbmm")
def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha):
    """Export ``baddbmm`` as ``alpha * (batch1 @ batch2) + beta * self``.

    ``alpha`` and ``beta`` are cast to the scalar type of ``self`` before use.
    """
    self_type = _type_utils.JitScalarType.from_value(self)
    product = matmul(g, batch1, batch2)
    scaled_product = mul(g, product, g.op("Cast", alpha, to_i=self_type.onnx_type()))
    scaled_self = mul(g, self, g.op("Cast", beta, to_i=self_type.onnx_type()))
    return add(g, scaled_product, scaled_self)
|
|
|
|
|
@_onnx_symbolic("aten::meshgrid")
@symbolic_helper.parse_args("v", "s")
def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None):
    """Export ``torch.meshgrid``.

    Each input is flattened to 1-D. It is then reshaped so that its length
    occupies its own axis (every other axis has size 1) and expanded to the
    concatenation of all input lengths.

    For ``indexing="xy"`` the first two inputs are swapped before the grids are
    built and the first two outputs are swapped back afterwards. This only
    happens with at least two tensors, matching ``torch.meshgrid``. With a
    single tensor, ``"xy"`` behaves like ``"ij"``.

    Raises:
        SymbolicValueError: if ``indexing`` is not ``"ij"``, ``"xy"`` or None.
    """
    if indexing is None:
        indexing = "ij"
    elif indexing not in {"ij", "xy"}:
        raise errors.SymbolicValueError(
            f"Unsupported indexing: {indexing}", tensor_list
        )
    unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list)
    # The swap needs two tensors; without this guard a single-tensor "xy"
    # meshgrid fails on ``out[1]`` below.
    swap_first_two = indexing == "xy" and len(unpacked_tensor_list) >= 2
    if swap_first_two:
        unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1]
    tensors = [
        symbolic_helper._reshape_helper(
            g, t, g.op("Constant", value_t=torch.LongTensor([-1]))
        )
        for t in unpacked_tensor_list
    ]
    tensors_shape = [g.op("Shape", t) for t in tensors]
    out_shape = g.op("Concat", *tensors_shape, axis_i=0)
    out = []
    for i, t in enumerate(tensors):
        # Target shape [1, ..., len(t) (at axis i), ..., 1].
        shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len(
            tensors
        )
        shape_i[i] = tensors_shape[i]
        t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0))
        out.append(g.op("Expand", t_reshaped, out_shape))
    if swap_first_two:
        out[0], out[1] = out[1], out[0]
    return g.op("prim::ListConstruct", *out)
|
|
|
|
|
@_onnx_symbolic("aten::remainder")
def remainder(g: jit_utils.GraphContext, input, other):
    """Export ``remainder`` as ``input - floor_divide(input, other) * other``."""
    quotient = _floor_divide(g, input, other)
    return g.op("Sub", input, g.op("Mul", quotient, other))
|
|
|
|
|
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"):
    """Export GELU.

    * ``approximate == "tanh"``:
      ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    * anything else (exact form): ``x * (1 + erf(x / sqrt(2))) * 0.5``

    All numeric constants are created as double tensors.
    """
    if approximate != "tanh":
        sqrt_two = 1.4142135623730951
        erf_term = g.op(
            "Erf", g.op("Div", self, torch.tensor(sqrt_two, dtype=torch.double))
        )
        one_plus_erf = add(
            g, erf_term, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double))
        )
        return mul(
            g,
            mul(g, self, one_plus_erf),
            g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)),
        )

    beta = torch.tensor(math.sqrt(2 / math.pi), dtype=torch.double)
    kappa = torch.tensor(0.044715, dtype=torch.double)
    one = torch.tensor(1.0, dtype=torch.double)
    half = torch.tensor(0.5, dtype=torch.double)

    x_cubed = mul(g, self, mul(g, self, self))
    tanh_arg = mul(g, beta, add(g, self, mul(g, kappa, x_cubed)))
    return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", tanh_arg))))
|
|
|
|
|
@_onnx_symbolic("aten::group_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i")
def group_norm(
    g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled
):
    """Export ``group_norm`` through ``InstanceNormalization``.

    The input is reshaped to ``[0, num_groups, -1]``, where ``0`` copies the
    corresponding input dim per ONNX ``Reshape``. Each group then acts as one
    channel of an ``InstanceNormalization`` run with unit scale and zero shift,
    and the result is reshaped back to the input's shape.

    The per-channel ``weight`` and ``bias`` are applied afterwards. They default
    to ``[1.0]`` and ``[0.0]`` when absent and are unsqueezed on axes
    ``1 .. rank-2`` so that they broadcast against dim 1 of the input.
    ``cudnn_enabled`` is ignored.
    """
    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
    if channel_size is not None:
        assert channel_size % num_groups == 0
    input_rank = symbolic_helper._get_tensor_rank(input)
    if input_rank is None:
        return symbolic_helper._unimplemented("group_norm", "unknown input rank", input)

    grouped_input = symbolic_helper._reshape_helper(
        g, input, g.op("Constant", value_t=torch.LongTensor([0, num_groups, -1]))
    )

    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()
    # Neutral affine parameters: the real weight/bias are applied per channel below.
    unit_scale = g.op(
        "Constant",
        value_t=torch.tensor([1.0] * num_groups, dtype=input_dtype),
    )
    zero_shift = g.op(
        "Constant",
        value_t=torch.tensor([0.0] * num_groups, dtype=input_dtype),
    )
    normalized = symbolic_helper._reshape_helper(
        g,
        g.op(
            "InstanceNormalization", grouped_input, unit_scale, zero_shift, epsilon_f=eps
        ),
        g.op("Shape", input),
    )

    if weight is None or weight.node().mustBeNone():
        weight = g.op("Constant", value_t=torch.tensor([1.0], dtype=input_dtype))
    if bias is None or bias.node().mustBeNone():
        bias = g.op("Constant", value_t=torch.tensor([0.0], dtype=input_dtype))

    spatial_axes = list(range(1, input_rank - 1))
    scaled = mul(g, normalized, symbolic_helper._unsqueeze_helper(g, weight, spatial_axes))
    return add(g, scaled, symbolic_helper._unsqueeze_helper(g, bias, spatial_axes))
|
|
|
|
|
@_onnx_symbolic("aten::_weight_norm")
@symbolic_helper.parse_args("v", "v", "i")
def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim):
    """Export ``_weight_norm`` as ``weight_v / norm(weight_v) * weight_g``.

    The 2-norm is reduced over every axis except ``dim``. Values of ``dim``
    below ``-1`` are wrapped by the rank. ``dim == -1`` (or None) removes no
    axis, so the norm spans the whole tensor. The trailing ``1`` passed to
    ``norm`` is presumably ``keepdim`` (TODO confirm against ``norm``'s
    signature).

    Raises:
        SymbolicValueError: if the rank of ``weight_v`` is unknown.
    """
    rank = symbolic_helper._get_tensor_rank(weight_v)
    if rank is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.",
            weight_v,
        )

    reduce_axes = list(range(rank))
    # -1 itself is not wrapped: it is kept as the "reduce over everything" marker.
    if dim is not None and dim < -1:
        dim += rank
    if dim is not None and dim != -1:
        reduce_axes.remove(dim)

    norm_v = norm(g, weight_v, 2, reduce_axes, 1)
    return g.op("Mul", g.op("Div", weight_v, norm_v), weight_g)
|
|
|
|
|
@_onnx_symbolic("aten::dim")
def dim(g: jit_utils.GraphContext, self):
    """Return the rank of ``self`` as a graph value, computed as ``Size(Shape(self))``."""
    shape_vector = g.op("Shape", self)
    return g.op("Size", shape_vector)
|
|
|
|
|
@_onnx_symbolic("aten::__contains_")
def __contains_(g: jit_utils.GraphContext, self, element):
    """Constant-fold ``element in self`` when the list and the element are constants.

    Returns:
        A boolean ``Constant``.

    Raises:
        SymbolicValueError: if any list item or the element is not constant.
    """
    unpacked_list = symbolic_helper._unpack_list(self)
    everything_constant = all(
        symbolic_helper._is_constant(x) for x in unpacked_list
    ) and symbolic_helper._is_constant(element)
    if not everything_constant:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of __contains__ for non-constant list or element.",
            self,
        )

    needle = symbolic_helper._node_get(element.node(), "value")
    haystack = (symbolic_helper._node_get(item.node(), "value") for item in unpacked_list)
    return g.op("Constant", value_t=torch.tensor(needle in haystack))
|
|
|
|
|
@_onnx_symbolic("aten::__getitem_")
def __getitem_(g: jit_utils.GraphContext, self, i):
    """Export ``self[i]`` as ``select`` on dimension 0."""
    dim_zero = g.op("Constant", value_t=torch.tensor([0]))
    return select(g, self, dim_zero, i)
|
|
|
|
|
@_onnx_symbolic("aten::item")
def item(g: jit_utils.GraphContext, self):
    """Export ``Tensor.item()`` as the identity: the input value is passed through."""
    passthrough = self
    return passthrough
|
|
|
|
|
@_onnx_symbolic("aten::take")
def take(g: jit_utils.GraphContext, self, index):
    """Export ``torch.take``: index into the flattened input, then reshape like ``index``."""
    flat_self = symbolic_helper._reshape_helper(
        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    )
    gathered = index_select(g, flat_self, 0, index)
    return reshape_as(g, gathered, index)
|
|
|
|
|
def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target):
    """Pointwise KL divergence for a log-space target: ``exp(target) * (target - input)``."""
    log_ratio = sub(g, target, input)
    return mul(g, exp(g, target), log_ratio)
|
|
|
|
|
def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target):
    """Pointwise KL divergence for a probability-space target.

    Computes ``target * (log(target) - input)`` where ``target > 0`` and ``0``
    elsewhere, so zero-probability targets contribute nothing.
    """
    pointwise = mul(g, target, sub(g, log(g, target), input))
    zeros = zeros_like(g, pointwise)
    positive_target = gt(g, target, g.op("Constant", value_t=torch.tensor(0)))
    return where(g, positive_target, pointwise, zeros)
|
|
|
|
|
@_onnx_symbolic("aten::kl_div")
@symbolic_helper.parse_args("v", "v", "i", "b")
def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target):
    """Export ``kl_div``.

    The pointwise term uses the log-space formula when ``log_target`` is true
    and the probability-space formula otherwise. ``reduction`` is 0 (none),
    1 (mean over all elements) or 2 (sum over all elements). Other values are
    reported as unsupported.
    """
    pointwise_impl = (
        _kl_div_log_target_impl if log_target else _kl_div_non_log_target_impl
    )
    output = pointwise_impl(g, input, target)

    if reduction == 0:
        return output
    if reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    if reduction == 2:
        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "kl_div with reduction other than none, mean, or sum.", input
    )
|
|
|
|
|
@_onnx_symbolic("aten::mse_loss")
@symbolic_helper.parse_args("v", "v", "i")
def mse_loss(g: jit_utils.GraphContext, input, target, reduction):
    """Export ``mse_loss`` as ``(input - target) ** 2`` plus the requested reduction.

    ``reduction`` is 0 (none), 1 (mean over all elements) or 2 (sum over all
    elements). Other values are reported as unsupported.
    """
    # Emit the difference once and square it, rather than two identical Sub nodes.
    diff = sub(g, input, target)
    output = mul(g, diff, diff)
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "mse_loss with reduction other than none, mean, or sum.", input
        )
|
|
|
|
|
@_onnx_symbolic("aten::as_strided")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "is", "i")
def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None):
    """Export ``as_strided`` as a ``Gather`` over the flattened input.

    The flat index of output element ``(i_0, ..., i_{r-1})`` is
    ``offset + sum_k i_k * strides[k]``. If ``sizes`` is constant, the whole
    index tensor is computed eagerly with torch and embedded as a
    ``Constant``. Otherwise it is assembled in the graph from one broadcastable
    ``arange`` per dimension.
    """
    sizes = symbolic_helper._maybe_get_const(sizes, "is")
    rank = len(strides)
    self_1d = symbolic_helper._reshape_helper(
        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    )
    # Eager torch.Tensor in the static branch, graph value in the dynamic one.
    ind: torch.Tensor | _C.Value | None
    if not symbolic_helper._is_value(sizes):
        ind = torch.tensor([0], dtype=torch.long)
        for i, (size, stride) in enumerate(zip(sizes, strides)):
            # View shape [1, ..., -1 (axis i), ..., 1] so the aranges broadcast
            # into an r-dimensional index grid.
            r_size = [1] * rank
            r_size[i] = -1
            ind = ind + torch.arange(size).view(r_size) * stride
        if offset:
            ind = ind + offset
        return g.op("Gather", self_1d, g.op("Constant", value_t=ind))
    else:
        ind = None
        for i, stride in enumerate(strides):
            r_size = [1] * rank
            r_size[i] = -1
            size = select(
                g,
                sizes,
                g.op("Constant", value_t=torch.tensor([0])),
                g.op("Constant", value_t=torch.tensor(i)),
            )
            tmp_ind = symbolic_helper._reshape_helper(
                g,
                arange(g, size, 4, None, None, None),
                g.op("Constant", value_t=torch.tensor(r_size)),
            )
            tmp_ind = g.op(
                "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride]))
            )
            if ind is None:
                ind = tmp_ind
            else:
                ind = g.op("Add", ind, tmp_ind)
        if offset:
            # The offset must be the Constant's ``value_t`` attribute. Passed
            # positionally, the tensor would become an input of the Constant node.
            ind = g.op("Add", ind, g.op("Constant", value_t=torch.tensor([offset])))
        return g.op("Gather", self_1d, ind)
|
|
|
|
|
@_onnx_symbolic("aten::__derive_index")
def __derive_index(g: jit_utils.GraphContext, index, start, step):
    """Compute the range value for iteration ``index`` as ``start + index * step``."""
    offset = g.op("Mul", index, step)
    return g.op("Add", start, offset)
|
|
|
|
|
@_onnx_symbolic("aten::__range_length")
def __range_length(g: jit_utils.GraphContext, lo, hi, step):
    """Number of elements in ``range(lo, hi, step)`` as INT64: ``ceil((hi - lo) / step)``."""
    span = g.op("Sub", hi, lo)
    step_count = g.op("Ceil", true_divide(g, span, step))
    return g.op("Cast", step_count, to_i=_C_onnx.TensorProtoDataType.INT64)
|
|
|
|
|
@_onnx_symbolic("aten::linear")
def linear(g: jit_utils.GraphContext, input, weight, bias):
    """Export ``linear`` as ``input @ weight.T + bias``.

    A rank-2 input with a bias goes through ``addmm``, with both scale
    constants equal to 1. Otherwise a ``matmul`` is emitted and the bias, if
    present, is added afterwards.
    """
    rank = symbolic_helper._get_tensor_rank(input)
    weight_t = t(g, weight)
    has_bias = not bias.node().mustBeNone()

    if rank == 2 and has_bias:
        alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        return addmm(g, bias, input, weight_t, alpha, beta)

    result = matmul(g, input, weight_t)
    if has_bias:
        result = add(g, bias, result)
    return result
|
|
|
|
|
@_onnx_symbolic("aten::hann_window")
@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v")
def hann_window(
    g: jit_utils.GraphContext,
    window_length,
    periodic=True,
    dtype: int | None = None,
    layout=None,
    device=None,
    pin_memory=None,
    requires_grad=False,
):
    """Export ``torch.hann_window`` as ``sin(pi * n / N) ** 2`` for ``n = 0 .. L-1``.

    ``N`` is ``window_length`` when periodic and ``window_length - 1`` when
    ``periodic is False``. The math is done in FLOAT and the result is cast to
    ``dtype``. Without a ``dtype``, the default dtype is used, or FLOAT if the
    default is not floating point. ``layout``, ``device``, ``pin_memory`` and
    ``requires_grad`` are ignored.
    """
    if dtype is not None:
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        default_dtype = torch.get_default_dtype()
        if not default_dtype or not default_dtype.is_floating_point:
            default_dtype = torch.float
        scalar_type = _type_utils.JitScalarType.from_dtype(default_dtype)

    positions = g.op(
        "Cast",
        arange(g, window_length, 4, None, None, None),
        to_i=_C_onnx.TensorProtoDataType.FLOAT,
    )
    scaled = mul(
        g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), positions
    )

    denominator = window_length
    # Explicit ``is False``: ``periodic`` may also be None, which must not subtract.
    if periodic is False:
        denominator = sub(
            g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int))
        )
    angle = div(g, scaled, denominator)
    return g.op(
        "Cast",
        square(g, sin(g, angle)),
        to_i=scalar_type.onnx_type(),
    )
|
|
|
|
|
@_onnx_symbolic("aten::mv")
def mv(g: jit_utils.GraphContext, self, vec):
    """Export the matrix-vector product ``self @ vec`` through ``matmul``."""
    lhs, rhs = self, vec
    return matmul(g, lhs, rhs)
|
|
|
|
|
@_onnx_symbolic("aten::dot")
def dot(g: jit_utils.GraphContext, self, other):
    """Export the vector dot product ``self . other`` through ``matmul``."""
    lhs, rhs = self, other
    return matmul(g, lhs, rhs)
|
|
|
|
|
@_onnx_symbolic("aten::movedim")
@symbolic_helper.parse_args("v", "t", "t")
def movedim(g: jit_utils.GraphContext, self, source, destination):
    """Export ``torch.movedim`` as a single ``Transpose``.

    ``source`` and ``destination`` are constant tensors of dimension indices,
    flattened to 1-D. Each ``source[k]`` moves to position ``destination[k]``.
    The remaining dimensions keep their relative order and fill the free
    output positions.

    Negative indices are wrapped by the input rank before the permutation is
    built. Without that, ``x.movedim(-1, 0)`` would put ``-1`` into ``perm``
    and yield the wrong permutation.
    """
    source = source.view(-1)
    destination = destination.view(-1)

    assert source.size() == destination.size()

    if (source == destination).all():
        return self

    self_rank = symbolic_helper._get_tensor_rank(self)
    assert self_rank is not None

    src_list = [d + self_rank if d < 0 else d for d in source.tolist()]
    dst_list = [d + self_rank if d < 0 else d for d in destination.tolist()]

    perm = list(range(self_rank))

    # Mark the dims that were explicitly moved (-1) so that only the
    # untouched ones are paired up in order afterwards.
    src_dims = perm.copy()
    dst_dims = perm.copy()

    for src, dst in zip(src_list, dst_list):
        perm[dst] = src
        src_dims[src] = -1
        dst_dims[dst] = -1

    src_dims = [d for d in src_dims if d != -1]
    dst_dims = [d for d in dst_dims if d != -1]

    for src, dst in zip(src_dims, dst_dims):
        perm[dst] = src

    return g.op("Transpose", self, perm_i=perm)
|
|
|
|
|
@_onnx_symbolic("aten::fill")
@symbolic_helper.parse_args("v", "v")
def fill(g: jit_utils.GraphContext, self, value):
    """Export ``fill`` as ``full_like(self, value)`` using ``self``'s dtype (FLOAT if unknown)."""
    fallback_type = _type_utils.JitScalarType.FLOAT
    return full_like(
        g, self, value, _type_utils.JitScalarType.from_value(self, fallback_type)
    )
|
|
|
|
|
@_onnx_symbolic("aten::index_add")
def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
    """Export ``index_add`` through ``scatter_add``.

    ``other`` is padded with trailing singleton axes to ``self``'s rank and
    expanded against ``self`` sliced to length 1 along ``dim``. ``index`` is
    unsqueezed on both sides of ``dim`` and expanded to ``other``'s shape before
    the scatter. Duplicate indices are not handled correctly; a warning is
    emitted on every call.

    Raises:
        SymbolicValueError: if ``dim`` is not constant, if either rank is
            unknown, or if ``other`` is statically larger than ``self``
            along ``dim``.
    """
    warnings.warn(
        "Warning: ONNX export does not support duplicated values in 'index' field, "
        + "this will cause the ONNX model to be incorrect."
    )

    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        return symbolic_helper._unimplemented("index_add", "alpha != 1", self)

    dim = symbolic_helper._maybe_get_const(dim, "i")
    if dim is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function with "
            "unknown 'dim' value.",
            self,
        )

    self_dim_rank = symbolic_helper._get_tensor_rank(self)
    other_dim_rank = symbolic_helper._get_tensor_rank(other)
    if self_dim_rank is None or other_dim_rank is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function while "
            "the rank of self tensor or tensor to be added is unknown.",
            self,
        )

    # Append trailing singleton axes until ``other`` matches ``self``'s rank
    # (no-op when the ranks already match or ``other`` is larger).
    for _ in range(self_dim_rank - other_dim_rank):
        other = symbolic_helper._unsqueeze_helper(
            g, other, [symbolic_helper._get_tensor_rank(other)]
        )

    other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim)
    self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    sizes_known = other_dim_size is not None and self_dim_size is not None
    if sizes_known and other_dim_size > self_dim_size:
        raise errors.SymbolicValueError(
            "ONNX export does not support exporting 'index_add_()' function with "
            "duplicated values in 'index' parameter yet.",
            self,
        )

    # ``self`` restricted to length 1 along ``dim`` (full extent elsewhere).
    all_axes = list(range(self_dim_rank))
    slice_ends = [1 if axis == dim else sys.maxsize for axis in all_axes]
    new_shape = symbolic_helper._slice_helper(
        g, self, axes=all_axes, starts=[0] * self_dim_rank, ends=slice_ends
    )
    other = expand_as(g, other, new_shape)

    # Place the 1-D ``index`` on axis ``dim`` of a rank-``self_dim_rank`` tensor.
    for _ in range(dim):
        index = symbolic_helper._unsqueeze_helper(g, index, [0])
    for _ in range(self_dim_rank - dim - 1):
        index = symbolic_helper._unsqueeze_helper(
            g, index, [symbolic_helper._get_tensor_rank(index)]
        )

    return scatter_add(g, self, dim, expand_as(g, index, other), other)
|
|
|
|
|
@_onnx_symbolic("aten::roll")
@symbolic_helper.parse_args("v", "is", "is")
def roll(g: jit_utils.GraphContext, self, shifts, dims):
    """Export ``torch.roll``.

    For each ``(shift, dim)`` pair, applied in order, the tensor becomes
    ``Concat(x[-shift:], x[:-shift])`` along ``dim``.
    """
    assert len(shifts) == len(dims)

    rolled = self
    for shift, axis in zip(shifts, dims):
        tail = symbolic_helper._slice_helper(
            g, rolled, axes=[axis], starts=[-shift], ends=[sys.maxsize]
        )
        head = symbolic_helper._slice_helper(
            g, rolled, axes=[axis], starts=[0], ends=[-shift]
        )
        rolled = g.op("Concat", tail, head, axis_i=axis)
    return rolled
|
|
|
|
|
@_onnx_symbolic("aten::cross")
@symbolic_helper.parse_args("v", "v", "i")
def cross(g: jit_utils.GraphContext, input, other, dim=None):
    """Export the 3-vector cross product along ``dim`` using ``roll``.

    With a length-3 axis, ``roll(x, 2)[i] = x[i+1]`` and ``roll(x, 1)[i] = x[i+2]``
    (indices mod 3), so ``a x b = roll(a,2)*roll(b,1) - roll(a,1)*roll(b,2)``.
    """
    axis = symbolic_helper._get_dim_for_cross(input, dim)

    a_next = roll(g, input, [2], [axis])
    b_after = roll(g, other, [1], [axis])

    a_after = roll(g, input, [1], [axis])
    b_next = roll(g, other, [2], [axis])

    return sub(g, mul(g, a_next, b_after), mul(g, a_after, b_next))
|
|
|
|
|
@_onnx_symbolic("aten::cdist")
def cdist(
    g: jit_utils.GraphContext,
    x1,
    x2,
    p=2.0,
    compute_mode="use_mm_for_euclid_dist_if_necessary",
):
    """Export ``torch.cdist``.

    For ``p == 2``, the matmul-based ``_euclidean_dist`` is used when
    ``compute_mode`` parses to 1, or when it is None and both inputs have at
    least 25 rows (dim -2). In every other case ``x1`` and ``x2`` are
    unsqueezed so they broadcast against each other, and
    ``pairwise_distance`` is computed.
    """
    x1_rows = symbolic_helper._get_tensor_dim_size(x1, -2)
    x2_rows = symbolic_helper._get_tensor_dim_size(x2, -2)
    assert x1_rows is not None
    assert x2_rows is not None
    p_float = symbolic_helper._parse_arg(p, "f")
    compute_mode = symbolic_helper._parse_arg(compute_mode, "i")

    if p_float == 2.0:
        prefer_matmul = compute_mode == 1 or (
            compute_mode is None and x1_rows >= 25 and x2_rows >= 25
        )
        if prefer_matmul:
            return _euclidean_dist(g, x1, x2)

    rank = symbolic_helper._get_tensor_rank(x1)
    assert rank is not None
    # Assuming rows on dim -2 and features on dim -1, these broadcast as
    # [..., R1, 1, D] vs [..., 1, R2, D].
    x1_expanded = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1])
    x2_expanded = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2])
    return pairwise_distance(
        g, x1_expanded, x2_expanded, p, eps=1e-06, keepdim=False
    )
|
|
|
|
|
def _euclidean_dist(g: jit_utils.GraphContext, x1, x2):
    """Pairwise Euclidean distances between the last-axis vectors of ``x1`` and ``x2``.

    Uses ``||a - b||^2 = -2 a.b + ||a||^2 + ||b||^2`` as one matmul of the
    augmented matrices ``[-2 * x1, ||x1||^2, 1]`` and ``[x2, 1, ||x2||^2]``,
    both concatenated on the last axis; the second is transposed on its last
    two axes. Negative round-off is clamped to 0 before the square root.
    """
    rank = symbolic_helper._get_tensor_rank(x1)
    assert rank is not None

    x1_sq_norm = symbolic_helper._reducesum_helper(
        g,
        pow(g, x1, symbolic_helper._generate_wrapped_number(g, 2.0)),
        axes_i=[-1],
        keepdims_i=True,
    )
    x1_ones = ones_like(g, x1_sq_norm)
    x2_sq_norm = symbolic_helper._reducesum_helper(
        g,
        pow(g, x2, symbolic_helper._generate_wrapped_number(g, 2.0)),
        axes_i=[-1],
        keepdims_i=True,
    )
    x2_ones = ones_like(g, x2_sq_norm)

    x1_aug = g.op(
        "Concat",
        mul(g, symbolic_helper._generate_wrapped_number(g, -2.0), x1),
        x1_sq_norm,
        x1_ones,
        axis_i=-1,
    )
    x2_aug = g.op("Concat", x2, x2_ones, x2_sq_norm, axis_i=-1)
    sq_dist = matmul(g, x1_aug, transpose(g, x2_aug, -2, -1))

    result_type = _type_utils.JitScalarType.from_value(sq_dist)
    zero = g.op(
        "Cast",
        symbolic_helper._generate_wrapped_number(g, 0.0),
        to_i=result_type.onnx_type(),
    )
    sq_dist = symbolic_helper._op_with_optional_float_cast(
        g, "Max", sq_dist, zero, opset_before=12
    )
    return sqrt(g, sq_dist)
|
|
|
|
|
@_onnx_symbolic("aten::lerp")
def lerp(g: jit_utils.GraphContext, self, end, weight):
    """Export ``lerp(self, end, weight)``.

    Uses ``self + weight * (end - self)`` when ``weight < 0.5`` and
    ``end - (end - self) * (1 - weight)`` otherwise. The two forms are
    algebraically equal; each is evaluated from the closer endpoint.
    """
    diff = g.op("Sub", end, self)
    use_start = g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5)))
    from_start = g.op("Add", self, g.op("Mul", weight, diff))
    one_minus_weight = g.op(
        "Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight
    )
    from_end = g.op("Sub", end, g.op("Mul", diff, one_minus_weight))
    return where(g, use_start, from_start, from_end)
|
|
|
|
|
@_onnx_symbolic("aten::broadcast_tensors")
def broadcast_tensors(g: jit_utils.GraphContext, self):
    """Export ``broadcast_tensors``.

    Adding every input onto a zero tensor yields a value whose shape is the
    common broadcast shape. Each input is then expanded to that shape.
    """
    all_tensors = symbolic_helper._unpack_list(self)
    shape_carrier = zeros_like(g, all_tensors[0])
    for tensor in all_tensors:
        shape_carrier = add(g, shape_carrier, tensor)
    expanded = [expand_as(g, tensor, shape_carrier) for tensor in all_tensors]
    return g.op("prim::ListConstruct", *expanded)
|
|
|
|
|
@_onnx_symbolic("aten::is_pinned")
def is_pinned(g: jit_utils.GraphContext, self, device=None):
    """Emit nothing for ``is_pinned``; this symbolic always returns ``None``."""
    no_output = None
    return no_output
|
|
|
|
|
@_onnx_symbolic("prim::ConstantSplit")
def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim):
    """Split ``self`` along ``dim`` into pieces of ``split_size``.

    The last piece is shorter when the dimension is not divisible by
    ``split_size``. The dimension size must be statically known.
    """
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantSplit", "unknown dimension size", self
        )
    full_pieces, leftover = divmod(size, split_size)
    splits = [split_size] * full_pieces
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))
|
|
|
|
|
|
|
|
|
|
|
|
|
@_onnx_symbolic("prim::ConstantChunk")
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
    """Split ``self`` along ``dim`` into ``chunks`` pieces of size ``ceil(size / chunks)``.

    Delegates to ``prim_constant_split``; the dimension size must be statically known.
    """
    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantChunk", "unknown dimension size", self
        )
    chunk_len = (dim_size + chunks - 1) // chunks
    return prim_constant_split(g, self, chunk_len, dim)
|
|
|
|
|
@_onnx_symbolic("prim::shape")
def prim_shape(g: jit_utils.GraphContext, self):
    """Export ``prim::shape`` as an ONNX ``Shape`` node."""
    shape_value = g.op("Shape", self)
    return shape_value
|
|
|
|
|
@_onnx_symbolic("prim::max")
def prim_max(g: jit_utils.GraphContext, self, other):
    """Export binary ``prim::max`` as ONNX ``Max``.

    The op goes through ``_op_with_optional_float_cast`` with
    ``opset_before=12``; presumably this inserts float casts for opsets below
    12 (confirm against the helper).
    """
    lhs, rhs = self, other
    return symbolic_helper._op_with_optional_float_cast(
        g, "Max", lhs, rhs, opset_before=12
    )
|
|
|
|
|
@_onnx_symbolic("prim::min")
def prim_min(g: jit_utils.GraphContext, self, other=None):
    """Export ``prim::min``.

    With one argument, ``self`` is reduced with ``min``. A packed list is first
    combined with ``stack`` (presumably along dim 0, via the ``[0]``
    constant). With two arguments, the call is delegated to binary ``min``.
    """
    # ``other`` is either omitted (None) or a graph value: compare against
    # None explicitly instead of relying on the truthiness of a ``_C.Value``.
    if other is None:
        if symbolic_helper._is_packed_list(self):
            self = stack(g, self, g.op("Constant", value_t=torch.tensor([0])))
        return min(g, self)
    return min(g, self, other)
|
|
|
|
|
@_onnx_symbolic("prim::data")
def prim_data(g: jit_utils.GraphContext, self):
    """Export ``prim::data`` as the identity on the input value."""
    passthrough = self
    return passthrough
|
|
|
|
|
@_onnx_symbolic("prim::layout")
def prim_layout(g: jit_utils.GraphContext, self):
    """Report the tensor layout as the constant ``0`` regardless of input.

    NOTE(review): 0 presumably corresponds to ``torch.strided``; confirm
    against the layout enum.
    """
    layout_code = torch.tensor(0)
    return g.op("Constant", value_t=layout_code)
|
|
|
|
|
@_onnx_symbolic("prim::ListConstruct")
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Produce no ONNX output for ``prim::ListConstruct``; always returns ``None``."""
    no_output = None
    return no_output
|
|
|
|
|
@_onnx_symbolic("prim::ListUnpack")
def prim_list_unpack(
    g: jit_utils.GraphContext, *inputs, **kwargs
) -> list[_C.Value] | None:
    """Fold ``ListUnpack(ListConstruct(...))`` into the constructed elements.

    Returns the list elements when the single input comes straight from a
    ``prim::ListConstruct``, and ``None`` otherwise.
    """
    if len(inputs) != 1 or inputs[0].node().kind() != "prim::ListConstruct":
        return None
    return symbolic_helper._unpack_list(inputs[0])
|
|
|
|
|
@_onnx_symbolic("prim::TupleConstruct")
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Produce no ONNX output for ``prim::TupleConstruct``; always returns ``None``."""
    no_output = None
    return no_output
|
|
|
|
|
@_onnx_symbolic("prim::Uninitialized")
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Produce no ONNX output for ``prim::Uninitialized``; always returns ``None``."""
    no_output = None
    return no_output
|
|
|
|
|
|
|
|
|
|
|
|
|
@_onnx_symbolic("prim::unchecked_cast")
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
    """Export ``prim::unchecked_cast`` as the identity on the input value."""
    passthrough = self
    return passthrough
|
|
|
|
|
@_onnx_symbolic("prim::dtype")
def prim_dtype(g: jit_utils.GraphContext, self):
    """Return the scalar type of ``self`` as a constant, defaulting to FLOAT if unknown."""
    detected = symbolic_helper._try_get_scalar_type(self)
    # Compare with None explicitly: some valid scalar types may have a falsy
    # enum value.
    scalar_type = _type_utils.JitScalarType.FLOAT if detected is None else detected
    return g.op("Constant", value_t=torch.tensor(scalar_type))
|
|
|
|
|
@_onnx_symbolic("prim::tolist")
def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val):
    """Export ``tolist`` as the identity; only 1-D input tensors are supported.

    ``dim_val`` and ``elem_ty_val`` are the dimension and element-type
    annotations that have to match the input tensor. Only ``dim_val`` is
    checked here, and it must be at most 1.
    """
    requested_dim = symbolic_helper._maybe_get_const(dim_val, "i")
    if requested_dim <= 1:
        return input
    return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input)
|
|
|
|
|
|
|
|
|
|
|
@_onnx_symbolic("prim::device")
def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
    """Drop ``prim::device`` nodes that produce a device object.

    Any other output type is reported through ``_unimplemented``.
    """
    device_output = g.original_node.output()
    output_type = device_output.type()
    if not isinstance(output_type, _C.DeviceObjType):
        return symbolic_helper._unimplemented(
            "prim::device",
            f"output type should be 'DeviceObjType', not '{output_type.kind()}'",
            g.original_node.output(),
        )
    return None
|
|
|
|
|
@_onnx_symbolic("prim::Loop")
def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
    """Convert a TorchScript ``prim::Loop`` node into an ONNX ``Loop`` node.

    An ONNX ``Loop`` is created with the same number of outputs and sub-blocks
    as the original node. Each original sub-block is converted into the
    matching new block with ``_jit_pass_onnx_block``, and the new node is then
    post-processed by ``_jit_pass_fixup_onnx_controlflow_node``. Node-level
    shape/type inference runs when ``GLOBALS.onnx_shape_inference`` is set.

    Returns:
        The values returned by ``_jit_pass_fixup_onnx_controlflow_node``.
    """
    node = g.original_node
    env = g.env
    values_in_env = g.values_in_env
    params_dict = g.params_dict

    operator_export_type = GLOBALS.operator_export_type
    opset_version = GLOBALS.export_onnx_opset_version

    old_blocks = tuple(node.blocks())
    _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
        g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
    )

    for old_block, new_block_context in zip(old_blocks, new_block_contexts):
        # Copy the outer node's input types onto the sub-block's inputs before
        # converting it. Block input 0 takes the type of node input 0; block
        # input i (i > 0) takes the type of node input i + 1.
        # NOTE(review): this offset presumably reflects the prim::Loop
        # input/block-argument layout (trip count, condition, carried values);
        # confirm against the prim::Loop schema.
        for i, b_in in enumerate(old_block.inputs()):
            if i == 0 and i < len(inputs):
                b_in.setType(inputs[i].type())
            # Block inputs declared as OptionalType keep their own type.
            if (
                i > 0
                and (i + 1) < len(inputs)
                and not isinstance(b_in.type(), _C.OptionalType)
            ):
                b_in.setType(inputs[i + 1].type())
        # Convert the original block's nodes into the new Loop body block.
        torch._C._jit_pass_onnx_block(
            old_block,
            new_block_context.block,
            operator_export_type,
            env,
            values_in_env,
            False,
        )
    fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
        new_node, opset_version
    )

    if GLOBALS.onnx_shape_inference:
        torch._C._jit_pass_onnx_node_shape_type_inference(
            new_node, params_dict, opset_version
        )
    return fixed_outputs
|
|
|
|
|
@_onnx_symbolic("prim::If")
def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
    """Convert a TorchScript ``prim::If`` node.

    If the condition ``inputs[0]`` is produced by an ``onnx::Constant``, the
    branch is resolved at export time. Only the selected sub-block is
    converted, directly into the current block, and the converted values of
    its outputs are returned. Otherwise an ONNX ``If`` node with one converted
    block per original branch is emitted and post-processed by
    ``_jit_pass_fixup_onnx_controlflow_node``. Shape inference runs on it when
    enabled.

    Raises:
        SymbolicValueError: in the constant-condition case, if an output of the
            selected sub-block has no entry in the environment returned by the
            block conversion.
    """
    n = g.original_node
    block = g.block
    env = g.env
    values_in_env = g.values_in_env
    params_dict = g.params_dict

    operator_export_type = GLOBALS.operator_export_type
    opset_version = GLOBALS.export_onnx_opset_version

    # A condition coming from an onnx::Constant can be resolved statically.
    static_if = inputs[0].node().kind() == "onnx::Constant"
    if static_if:
        # Read the constant condition. A list-valued constant counts as true
        # only when all of its elements are truthy.
        input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist()
        const_value = (
            all(input_flag) if isinstance(input_flag, list) else bool(input_flag)
        )
        # Block 0 is taken for a true condition, block 1 otherwise.
        block_idx = 0 if const_value else 1
        current_b = list(n.blocks())[block_idx]
        # Convert only the selected branch, inlining it into the current block.
        # The returned env is used below to look up the converted branch outputs.
        env = torch._C._jit_pass_onnx_block(
            current_b,
            block,
            operator_export_type,
            env,
            values_in_env,
            True,
        )
        if_output_list = list(n.outputs())
        current_b_list = list(current_b.outputs())

        # Map each prim::If output to the converted value of the
        # corresponding output of the selected branch.
        final_b_list = []
        for idx in range(len(if_output_list)):
            if current_b_list[idx] not in env:
                raise errors.SymbolicValueError(
                    f"The sub block ATen output {current_b_list[idx]} is not in env.",
                    current_b_list[idx],
                )
            onnx_b = env[current_b_list[idx]]
            final_b_list.append(onnx_b)
        return final_b_list
    else:
        # Dynamic condition: emit an ONNX If with one converted block per branch.
        old_blocks = tuple(n.blocks())
        _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
            g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
        )

        for old_block, new_block_context in zip(old_blocks, new_block_contexts):
            torch._C._jit_pass_onnx_block(
                old_block,
                new_block_context.block,
                operator_export_type,
                env,
                values_in_env,
                False,
            )
        fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
            new_node, opset_version
        )

        if GLOBALS.onnx_shape_inference:
            torch._C._jit_pass_onnx_node_shape_type_inference(
                new_node, params_dict, opset_version
            )
        return fixed_outputs
|
|
|
|
|
@_onnx_symbolic("prim::Constant")
def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs):
    """Convert a TorchScript ``prim::Constant`` into ONNX constants.

    * ``None`` constants and device-object constants produce no value (``None``).
    * Tensor (``"t"``) and string (``"s"``) values become a single ``Constant``.
    * Lists of ints or floats become one tensor ``Constant``.
    * Lists of strings become a ``prim::ListConstruct`` of string ``Constant`` nodes.

    Raises:
        SymbolicValueError: for any other kind of constant value.
    """
    node = g.original_node
    if node.mustBeNone():
        return None

    output_type = node.output().type()
    if isinstance(output_type, _C.DeviceObjType):
        return None

    value_kind = node.kindOf("value")
    if value_kind == "t":
        return g.op("Constant", value_t=symbolic_helper._node_get(node, "value"))
    if value_kind == "s":
        return g.op("Constant", value_s=symbolic_helper._node_get(node, "value"))

    is_numeric_list = output_type.isSubtypeOf(
        _C.ListType.ofInts()
    ) or output_type.isSubtypeOf(_C.ListType.ofFloats())
    if is_numeric_list:
        return g.op(
            "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value"))
        )

    if output_type.isSubtypeOf(_C.ListType.ofStrings()):
        string_values = [
            g.op("Constant", value_s=text)
            for text in symbolic_helper._node_get(node, "value")
        ]
        return g.op("prim::ListConstruct", *string_values)

    raise errors.SymbolicValueError(
        f"Unsupported prim::Constant kind: '{value_kind}'. "
        f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
        node.output(),
    )
|
|
|
|
|
@_onnx_symbolic("prim::type")
def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs):
    """Fold ``prim::type`` of a device into a string ``Constant``.

    This works only when ``device_value`` is produced by ``prim::device`` and
    the device can be resolved statically. Otherwise the op is reported
    through ``_unimplemented``.
    """
    producer = device_value.node()
    if producer.kind() == "prim::device":
        device = jit_utils.get_device_from_value(producer.input())
        if device is not None:
            return g.op("Constant", value_s=str(device))

    return symbolic_helper._unimplemented(
        "prim::type",
        "Device type cannot be statically determined.",
        device_value,
    )
|
|
|
|
|
@_onnx_symbolic("onnx::Placeholder")
def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs):
    """Convert an ``onnx::Placeholder`` by expanding the pattern stored in its sub-block.

    The conversion is delegated to ``_jit_onnx_convert_pattern_from_subblock``,
    called with the current block, the original node and the environment maps.
    """
    return torch._C._jit_onnx_convert_pattern_from_subblock(
        g.block, g.original_node, g.env, g.values_in_env
    )
|
|
|
|
|
@_onnx_symbolic("aten::resolve_conj")
@_onnx_symbolic("aten::resolve_neg")
def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    """Export ``resolve_conj`` / ``resolve_neg`` as the identity.

    The input value is returned unchanged and no ONNX node is emitted.
    """
    passthrough = input
    return passthrough
|
|
|
|
|
@_onnx_symbolic("aten::_conj")
@_onnx_symbolic("aten::conj_physical")
def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    """Export ``_conj`` / ``conj_physical``.

    For non-complex inputs, conjugation is treated as the identity and
    forwarded to ``noop_complex_operators``. Complex inputs are reported as
    unsupported.
    """
    if not symbolic_helper.is_complex_value(input):
        return noop_complex_operators(g, input)
    return symbolic_helper._onnx_unsupported(
        "aten::_conj, aten::conj_physical",
        input,
    )
|
|
|
|
|
@_onnx_symbolic("aten::logit")
def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value):
    """Export ``logit`` as ``log(z / (1 - z))``.

    When ``eps`` is given, it is cast to ``self``'s type and ``z`` is ``self``
    clamped to ``[eps, 1 - eps]`` with two ``Where`` selections. Otherwise
    ``z`` is ``self``.
    """
    one = g.op("Constant", value_t=torch.tensor(1.0))

    if symbolic_helper._is_none(eps):
        z = self
    else:
        eps = g.op(
            "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()
        )
        upper = g.op("Sub", one, eps)
        # min(self, 1 - eps)
        below_upper = g.op("Greater", upper, self)
        capped = g.op("Where", below_upper, self, upper)
        # max(capped, eps)
        below_eps = g.op("Less", capped, eps)
        z = g.op("Where", below_eps, eps, capped)

    return g.op("Log", g.op("Div", z, g.op("Sub", one, z)))
|
|