from __future__ import annotations

import functools
import sys

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import jit_utils, registration
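# The symbolic functions in this file are registered for ONNX opset 12 via the
# `_onnx_symbolic` partial (a `functools.partial` of `registration.onnx_symbolic`
# with opset=12) defined after `__all__`.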
__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)


def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
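    # ONNX Einsum has no Bool overload, so compute in INT64 and cast the result back.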
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)


@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)


@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
def outer(g: jit_utils.GraphContext, input, other):
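    # Cast `other` to the dtype of `input` so both Einsum inputs agree.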
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(input):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
        )
    return _einsum_helper(g, "i,j->ij", [input, other])


def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> tuple[torch._C.Value, torch._C.Value | None]:
    symbolic_helper.check_training_mode(train, "dropout")
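    # In eval mode dropout is a no-op, so return the input unchanged with no mask.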
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask


@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
def dropout(g: jit_utils.GraphContext, input, p, train):
    masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return masked


@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
def native_dropout(g: jit_utils.GraphContext, input, p, train):
    return _dropout_returns_masked_input_and_mask(g, input, p, train)


@_onnx_symbolic("aten::nll_loss")
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
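    # The aten reduction argument is an integer: 0 = "none", 1 = "mean", 2 = "sum".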
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]
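    # Always set the ignore_index attribute, even at aten's default of -100, since
    # NegativeLogLikelihoodLoss declares it as optional with no default value.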
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return nllloss


@_onnx_symbolic("aten::nll_loss2d")
def nll_loss2d(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::nll_loss_nd")
def nll_loss_nd(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::cross_entropy_loss")
def cross_entropy_loss(
    g: jit_utils.GraphContext,
    self,
    target,
    weight,
    reduction,
    ignore_index,
    label_smoothing,
):
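    # The aten reduction argument is an integer: 0 = "none", 1 = "mean", 2 = "sum".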
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if label_smoothing is not None and label_smoothing > 0.0:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX does not support label_smoothing", self
        )
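    # Always set the ignore_index attribute; SoftmaxCrossEntropyLoss declares it as
    # optional with no default value.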
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return celoss


@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def binary_cross_entropy_with_logits(
    g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
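    # Expand the loss explicitly:
    #   -(pos_weight * target * log(sigmoid(x)) + (1 - target) * log(1 - sigmoid(x)))
    # then apply the optional element-wise weight and the requested reduction.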
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = opset9.sigmoid(g, input)
    log_sig_x = opset9.log(g, sig_x)
    sub_1_x = opset9.sub(g, p, sig_x)
    sub_1_y = opset9.sub(g, p, target)
    log_1_x = opset9.log(g, sub_1_x)
    if pos_weight is None or symbolic_helper._is_none(pos_weight):
        output = opset9.neg(
            g,
            opset9.add(
                g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
            ),
        )
    else:
        output = opset9.neg(
            g,
            opset9.add(
                g,
                opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
                opset9.mul(g, sub_1_y, log_1_x),
            ),
        )

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
            input,
        )


@_onnx_symbolic("aten::celu")
def celu(g: jit_utils.GraphContext, self, alpha):
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
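    # ONNX Celu is only defined for float32, so double inputs are computed in float32
    # and cast back to double.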
    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.DOUBLE
    ):
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        out = g.op("Celu", self, alpha_f=alpha)
        return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)

    return g.op("Celu", self, alpha_f=alpha)


@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")


@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")


@_onnx_symbolic("aten::pow")
def pow(g: jit_utils.GraphContext, self, exponent):
    return g.op("Pow", self, exponent)


@_onnx_symbolic("aten::ge")
def ge(g: jit_utils.GraphContext, input, other):
    return g.op("GreaterOrEqual", input, other)


@_onnx_symbolic("aten::le")
def le(g: jit_utils.GraphContext, input, other):
    return g.op("LessOrEqual", input, other)


@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
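    # If both `size` and `step` are compile-time constants, reuse the simpler opset 9
    # lowering; otherwise build an ONNX Loop that extracts one window per iteration.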
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
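        # low_indices[i] / hi_indices[i] hold the start / one-past-the-end offset of
        # window i.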
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = symbolic_helper._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        hi_size = symbolic_helper._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
        )

        ndim = symbolic_helper._get_tensor_rank(input)
        assert ndim is not None
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        loop_condition = g.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        loop_len = g.op("Min", low_size, hi_size)
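        # Each Loop iteration gathers the start/end of one window, slices it out, and
        # transposes `dimension` to the end; the per-iteration results are collected
        # through the Loop's scan output.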
        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        cond = utils._add_input_to_block(loop_block)

        starts = loop_context.op("Gather", low_indices, block_input_iter)
        ends = loop_context.op("Gather", hi_indices, block_input_iter)
        axes = loop_context.op("Constant", value_t=torch.tensor([2]))
        starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
        stack = loop_context.op("Slice", input, starts, ends, axes)

        unsqueeze = symbolic_helper._unsqueeze_helper(
            loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
        )
        unsqueeze_list.append(unsqueeze)
        concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
        cond_out = loop_context.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, concat)
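        # Rearrange the Loop output so the unfolded windows land on `dimension`,
        # matching the layout produced by aten::unfold.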
        loop_output = loop.node().output()
        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])

        return squeeze

    return symbolic_helper._unimplemented("Unfold", "input size not accessible")


@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
    if out is not None:
        symbolic_helper._unimplemented(
            "Tensordot", "Out parameter is not supported for tensordot."
        )

    dim_count_a = symbolic_helper._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
            input_a,
        )

    dim_count_b = symbolic_helper._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
            input_b,
        )
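    # Normalize negative contraction dims to their positive counterparts.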
    dims_a = [
        (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
        for i in range(len(dims_a))
    ]
    dims_b = [
        (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
        for i in range(len(dims_b))
    ]

    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
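    # tensordot as a matrix product: move the contracted dims of input_a to the end
    # and those of input_b to the front, flatten each tensor to 2-D, contract with
    # Einsum ("ij,jk->ik"), and reshape the result back to the kept dims of both inputs.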
    new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
    new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)

    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
    )
    shape_sizes = [
        left_sizes_a,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", output_a)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
    )
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
    )
    shape_sizes = [
        slices,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return opset9._reshape_from_tensor(g, output, shape_sizes)