|
|
|
|
|
|
|
|
|
|
|
# This file implements the ONNX symbolic functions for opset 13.

import functools
|
|
|
import torch |
|
import torch._C._onnx as _C_onnx |
|
from torch.onnx import ( |
|
_constants, |
|
_type_utils, |
|
errors, |
|
symbolic_helper, |
|
symbolic_opset11 as opset11, |
|
symbolic_opset9 as opset9, |
|
utils, |
|
) |
|
from torch.onnx._internal import jit_utils, registration |
|
|
|
|
|
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13) |
|
|
|
|
|
@_onnx_symbolic("aten::softmax") |
|
@symbolic_helper.parse_args("v", "i", "none") |
|
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): |
|
softmax = g.op("Softmax", input, axis_i=dim) |
|
if dtype and dtype.node().kind() != "prim::Constant": |
|
parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") |
|
softmax = g.op( |
|
"Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() |
|
) |
|
|
|
return softmax |
|
|
|
|
|
@_onnx_symbolic("aten::log_softmax") |
|
@symbolic_helper.parse_args("v", "i", "none") |
|
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): |
|
return_op = g.op("LogSoftmax", input, axis_i=dim) |
|
if dtype and dtype.node().kind() != "prim::Constant": |
|
parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") |
|
return_op = g.op( |
|
"Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() |
|
) |
|
return return_op |
|
|
|
|
|
@_onnx_symbolic("aten::frobenius_norm") |
|
@symbolic_helper.parse_args("v", "v", "i") |
|
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): |
|
dim_val = symbolic_helper._maybe_get_const(dim, "is") |
|
if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0: |
|
return g.op("ReduceL2", self, keepdims_i=0) |
|
sqr = g.op("Mul", self, self) |
|
sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim) |
|
return g.op("Sqrt", sumsqr) |
|
|
|
|
|
@_onnx_symbolic("aten::split") |
|
@symbolic_helper.parse_args("v", "v", "i", "i") |
|
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): |
|
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): |
|
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) |
|
if _outputs is None: |
|
return split_out |
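        # If the number of splits and outputs are statically known, emit
        # individual Slice nodes instead of a SplitToSequence.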
|
|
|
if ( |
|
symbolic_helper._is_packed_list(split_size_or_sizes) |
|
and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs |
|
): |
|
split_sizes = [ |
|
symbolic_helper._unsqueeze_helper(g, v, [0]) |
|
for v in symbolic_helper._unpack_list(split_size_or_sizes) |
|
] |
|
|
|
start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) |
|
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) |
|
res = [] |
|
for i in range(_outputs): |
|
            end = g.op(
                "Add", start, split_sizes[i]
            )  # split_sizes has the same length as _outputs
|
res.append(g.op("Slice", self, start, end, axis)) |
|
start = end |
|
return res |
|
return [ |
|
g.op( |
|
"SequenceAt", |
|
split_out, |
|
g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), |
|
) |
|
for i in range(_outputs) |
|
] |
|
|
|
split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") |
|
if split_val.dim() > 0: |
|
return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs) |
|
split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") |
|
|
|
size = symbolic_helper._get_tensor_dim_size(self, dim) |
|
if size is None: |
|
if _outputs is not None: |
|
size = split_size * _outputs |
|
else: |
|
raise errors.SymbolicValueError( |
|
"Unknown dimension size not supported", self |
|
) |
|
splits = [split_size] * (size // split_size) |
|
leftover = size % split_size |
|
if leftover: |
|
splits.append(leftover) |
|
splits = g.op("Constant", value_t=torch.tensor(splits)) |
|
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) |
|
|
|
|
|
@_onnx_symbolic("aten::split_with_sizes") |
|
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): |
|
return split(g, self, split_sizes, dim, _outputs) |
|
|
|
|
|
@_onnx_symbolic("aten::unsafe_split") |
|
def unsafe_split( |
|
g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None |
|
): |
|
return split(g, self, split_size_or_sizes, dim, _outputs) |
|
|
|
|
|
@_onnx_symbolic("aten::unsafe_split_with_sizes") |
|
def unsafe_split_with_sizes( |
|
g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None |
|
): |
|
return split_with_sizes(g, self, split_sizes, dim, _outputs) |
|
|
|
|
|
@_onnx_symbolic("aten::tensor_split") |
|
@symbolic_helper.parse_args("v", "v", "i", "i") |
|
def tensor_split( |
|
g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None |
|
): |
|
axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) |
|
axis = opset11.unsqueeze(g, axis, 0) |
|
const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) |
|
|
|
if symbolic_helper._is_split_static(indices_or_sections, _outputs): |
|
split_val = symbolic_helper._node_get(indices_or_sections.node(), "value") |
|
|
|
if split_val.dim() > 0: |
|
start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) |
|
res = [] |
|
assert _outputs is not None |
|
for i in range(_outputs - 1): |
|
end = g.op( |
|
"Gather", |
|
indices_or_sections, |
|
g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), |
|
axis_i=0, |
|
) |
|
res.append(g.op("Slice", self, start, end, axis)) |
|
start = end |
|
|
|
end = symbolic_helper._size_helper(g, self, axis) |
|
res.append(g.op("Slice", self, start, end, axis)) |
|
return res |
|
|
|
split_size = symbolic_helper._get_const( |
|
indices_or_sections, "i", "indices_or_sections" |
|
) |
|
|
|
size = symbolic_helper._get_tensor_dim_size(self, dim) |
|
if size is None: |
|
if _outputs is not None: |
|
size = split_size * _outputs |
|
else: |
|
raise errors.SymbolicValueError( |
|
"Unknown dimension size not supported", self |
|
) |
|
|
|
min_split_size = size // split_size |
|
num_splits_one_extra = size % split_size |
|
|
|
splits = num_splits_one_extra * [min_split_size + 1] |
|
leftover = (split_size - num_splits_one_extra) * [min_split_size] |
|
|
|
splits = g.op( |
|
"Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long) |
|
) |
|
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) |
|
|
|
if ( |
|
symbolic_helper._is_tensor(indices_or_sections) |
|
and symbolic_helper._get_tensor_rank(indices_or_sections) == 1 |
|
): |
|
loop_len = symbolic_helper._size_helper( |
|
g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0)) |
|
) |
|
loop_len = opset11.unsqueeze(g, loop_len, 0) |
|
loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL) |
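        # Prepend a zero to indices_or_sections so that it serves as the
        # start index of the first Slice in the loop below.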
|
|
|
|
|
|
|
padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) |
|
indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0) |
|
|
|
final_splits = g.op("SequenceEmpty") |
|
|
|
loop, (loop_context,), _ = jit_utils.add_op_with_blocks( |
|
g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1 |
|
) |
|
|
|
loop_block = loop_context.block |
|
block_input_iter = utils._add_input_to_block(loop_block) |
|
cond = utils._add_input_to_block(loop_block) |
|
final_splits = utils._add_input_to_block(loop_block) |
|
|
|
start = loop_context.op( |
|
"Gather", indices_or_sections, block_input_iter, axis_i=0 |
|
) |
|
end = loop_context.op( |
|
"Gather", |
|
indices_or_sections, |
|
loop_context.op("Add", block_input_iter, const_1), |
|
axis_i=0, |
|
) |
|
|
|
        slice_out = loop_context.op("Slice", self, start, end, axis)
        final_splits = loop_context.op("SequenceInsert", final_splits, slice_out)
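        # Loop outputs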
|
|
|
|
|
cond_out = loop_context.op("Identity", loop_condition) |
|
utils._add_output_to_block(loop_block, cond_out) |
|
utils._add_output_to_block(loop_block, final_splits) |
|
|
|
loop_out = loop.node().output() |
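        # Slice the remaining tail, from the last split index to the end
        # of the dimension.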
|
start = g.op( |
|
"Gather", |
|
indices_or_sections, |
|
g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)), |
|
axis_i=0, |
|
) |
|
start = opset11.unsqueeze(g, start, 0) |
|
end = symbolic_helper._size_helper(g, self, axis) |
|
|
|
last_slice = g.op("Slice", self, start, end, axis) |
|
|
|
return g.op("SequenceInsert", loop_out, last_slice) |
|
|
|
else: |
|
dim_size = symbolic_helper._size_helper(g, self, axis) |
|
min_split_size = g.op("Div", dim_size, indices_or_sections) |
|
min_split_size_plus_1 = g.op( |
|
"Add", |
|
min_split_size, |
|
const_1, |
|
) |
|
num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections) |
|
splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra) |
|
leftover = g.op( |
|
"Tile", |
|
min_split_size, |
|
g.op( |
|
"Sub", |
|
opset11.unsqueeze(g, indices_or_sections, 0), |
|
num_splits_one_extra, |
|
), |
|
) |
|
|
|
splits = g.op("Concat", splits, leftover, axis_i=0) |
|
if _outputs is None: |
|
return g.op("SplitToSequence", self, splits, axis_i=dim) |
|
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) |
|
|
|
|
|
@_onnx_symbolic("aten::unbind") |
|
@symbolic_helper.parse_args("v", "i", "i") |
|
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): |
|
if _outputs is None: |
|
return g.op( |
|
"SplitToSequence", |
|
self, |
|
g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), |
|
axis_i=dim, |
|
keepdims_i=0, |
|
) |
|
|
|
splits = g.op("Constant", value_t=torch.tensor([1] * _outputs)) |
|
outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs) |
|
outputs = [outputs] if _outputs == 1 else outputs |
|
squeezed_outputs = [ |
|
g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim]))) |
|
for out in outputs |
|
] |
|
return squeezed_outputs |
|
|
|
|
|
@_onnx_symbolic("aten::nonzero_numpy") |
|
|
|
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): |
|
return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs) |
|
|
|
|
|
@_onnx_symbolic("aten::where") |
|
@symbolic_helper.parse_args("v", "v", "v", "i") |
|
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): |
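    # Assumes that torch.where's first argument takes only Bool and Byte tensors.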
|
|
|
if not symbolic_helper._is_bool(condition): |
|
condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) |
|
if self is None: |
|
condition = opset9.nonzero(g, condition) |
|
return symbolic_helper._unbind_helper( |
|
g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs |
|
) |
|
return g.op("Where", condition, self, other) |
|
|
|
|
|
@_onnx_symbolic("aten::fake_quantize_per_channel_affine") |
|
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i") |
|
def fake_quantize_per_channel_affine( |
|
g: jit_utils.GraphContext, |
|
inputs, |
|
scale, |
|
zero_point, |
|
axis, |
|
quant_min=-128, |
|
quant_max=127, |
|
): |
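    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts
    # activations to be in the range (0, 127) in some quantized observers.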
|
|
|
|
|
if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: |
|
raise errors.SymbolicValueError( |
|
"For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " |
|
f"Got ({quant_min}, {quant_max})", |
|
inputs, |
|
) |
|
|
|
if quant_min == 0: |
|
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) |
|
else: |
|
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) |
|
quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis) |
|
if (quant_min, quant_max) == (0, 127): |
|
quantized = g.op( |
|
"Clip", |
|
quantized, |
|
opset9.unused(g), |
|
g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), |
|
) |
|
return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis) |
|
|
|
|
|
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") |
|
@symbolic_helper.parse_args("v", "v", "v", "i", "i") |
|
def fake_quantize_per_tensor_affine( |
|
g: jit_utils.GraphContext, |
|
inputs, |
|
scale, |
|
zero_point, |
|
quant_min=-128, |
|
quant_max=127, |
|
): |
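    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts
    # activations to be in the range (0, 127) in some quantized observers.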
|
|
|
|
|
if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: |
|
raise errors.SymbolicValueError( |
|
"For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " |
|
f"Got ({quant_min}, {quant_max})", |
|
inputs, |
|
) |
|
if quant_min == 0: |
|
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) |
|
else: |
|
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) |
|
if ( |
|
_type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) |
|
!= _type_utils.JitScalarType.FLOAT |
|
): |
|
scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) |
|
quantized = g.op("QuantizeLinear", inputs, scale, zero_point) |
|
if (quant_min, quant_max) == (0, 127): |
|
quantized = g.op( |
|
"Clip", |
|
quantized, |
|
opset9.unused(g), |
|
g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), |
|
) |
|
return g.op("DequantizeLinear", quantized, scale, zero_point) |
|
|
|
|
|
def _reduce_op_symbolic(onnx_op_name): |
|
def symbolic(g, self, dim=None, keepdim=None): |
|
self = symbolic_helper._maybe_cast_reduce_op_input(g, self) |
|
if dim is None: |
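            # all-reduce path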
|
|
|
return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name) |
|
else: |
|
keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim") |
|
return g.op(onnx_op_name, self, dim, keepdims_i=keepdim) |
|
|
|
return symbolic |
|
|
|
|
|
@_onnx_symbolic( |
|
"aten::sum", |
|
decorate=[symbolic_helper._apply_params("ReduceSum", "sum")], |
|
) |
|
def _reduce_with_dtype(onnx_op, name): |
|
symbolic = _reduce_op_symbolic(onnx_op) |
|
|
|
@symbolic_helper._overload_by_arg_count |
|
def reduce(g, *args, **kwargs): |
|
@symbolic_helper.parse_args("v", "none") |
|
def reduce_nodim(g, self, dtype): |
|
dtype_onnx = None |
|
if dtype.node().kind() == "onnx::Constant": |
|
dtype = symbolic_helper._get_const(dtype, "i", "dtype") |
|
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() |
|
self = g.op("Cast", self, to_i=dtype_onnx) |
|
elif dtype.node().kind() != "prim::Constant": |
|
return symbolic_helper._unimplemented(name, "dtype", dtype) |
|
result = symbolic(g, self) |
|
if dtype_onnx is not None: |
|
result_dtype_onnx = _type_utils.JitScalarType.from_value( |
|
result |
|
).onnx_type() |
|
if result_dtype_onnx != dtype_onnx: |
|
result = g.op("Cast", result, to_i=dtype_onnx) |
|
return result |
|
|
|
@symbolic_helper.parse_args("v", "v", "i", "none") |
|
def reduce_dim(g, self, dim, keepdim, dtype): |
|
dtype_onnx = None |
|
if dtype.node().kind() == "onnx::Constant": |
|
dtype = symbolic_helper._get_const(dtype, "i", "dtype") |
|
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() |
|
self = g.op("Cast", self, to_i=dtype_onnx) |
|
elif dtype.node().kind() != "prim::Constant": |
|
return symbolic_helper._unimplemented(name, "dtype", dtype) |
|
result = symbolic(g, self, dim, keepdim) |
|
if dtype_onnx is not None: |
|
result_dtype_onnx = _type_utils.JitScalarType.from_value( |
|
result |
|
).onnx_type() |
|
if result_dtype_onnx != dtype_onnx: |
|
result = g.op("Cast", result, to_i=dtype_onnx) |
|
return result |
|
|
|
return reduce_nodim, reduce_dim |
|
|
|
return reduce |
|
|
|
|
|
|
|
|
|
|
|
@_onnx_symbolic("aten::unflatten") |
|
def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size): |
|
input_dim = symbolic_helper._get_tensor_rank(input) |
|
if input_dim is None: |
|
return symbolic_helper._unimplemented( |
|
"dim", |
|
"ONNX and PyTorch use different strategies to split the input. " |
|
"Input rank must be known at export time.", |
|
) |
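    # dim could be negative; normalize it via (dim + rank) % rank.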
|
|
|
|
|
input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64)) |
|
dim = g.op("Add", input_dim, dim) |
|
dim = g.op("Mod", dim, input_dim) |
|
|
|
input_size = g.op("Shape", input) |
|
|
|
head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) |
|
head_end_idx = g.op( |
|
"Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) |
|
) |
|
head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx) |
|
|
|
dim_plus_one = g.op( |
|
"Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) |
|
) |
|
tail_start_idx = g.op( |
|
"Reshape", |
|
dim_plus_one, |
|
g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)), |
|
) |
|
tail_end_idx = g.op( |
|
"Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) |
|
) |
|
tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx) |
|
|
|
final_shape = g.op( |
|
"Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0 |
|
) |
|
|
|
return symbolic_helper._reshape_helper(g, input, final_shape) |
|
|
|
|
|
@_onnx_symbolic("aten::unsafe_chunk") |
|
@symbolic_helper.parse_args("v", "i", "i", "i") |
|
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): |
|
if _outputs is None: |
|
return g.op( |
|
"SplitToSequence", |
|
self, |
|
g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), |
|
axis_i=dim, |
|
keepdims_i=0, |
|
) |
|
|
|
size = symbolic_helper._get_tensor_dim_size(self, dim) |
|
if size is None: |
|
return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size") |
|
split_size = (size + chunks - 1) // chunks |
|
splits = [split_size] * (size // split_size) |
|
leftover = size % split_size |
|
if leftover: |
|
splits.append(leftover) |
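    # TODO: So far we don't have a module using this method. We'll keep
    # this as a constant unless we see a request for dynamic splits.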
|
|
|
|
|
|
|
|
|
splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) |
|
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) |
|
|
|
|
|
@_onnx_symbolic("aten::tile") |
|
def tile(g: jit_utils.GraphContext, self, dims): |
|
self_shape = g.op("Shape", self) |
|
self_rank = g.op("Size", self_shape) |
|
dims_rank = g.op("Size", dims) |
|
diff = g.op("Sub", self_rank, dims_rank) |
|
const_zero = g.op("Constant", value_t=torch.tensor([0])) |
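    # 1. If dims is shorter than self.shape, prepend ones to dims.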
|
|
|
|
|
dims_shorter_than_self_shape = g.op("Greater", diff, const_zero) |
|
( |
|
if_op_greater, |
|
(if_context_greater, else_context_greater), |
|
_, |
|
) = jit_utils.add_op_with_blocks( |
|
g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1 |
|
) |
|
const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1])) |
|
diff_1d_greater = if_context_greater.op("Reshape", diff, const_one) |
|
    expand_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater)
|
    dims_ = if_context_greater.op("Concat", expand_ones_greater, dims, axis_i=0)
|
utils._add_output_to_block(if_context_greater.block, dims_) |
|
identity_dim = else_context_greater.op("Identity", dims) |
|
utils._add_output_to_block(else_context_greater.block, identity_dim) |
|
dims_final = if_op_greater.node().output() |
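    # 2. If dims is longer than self.shape, prepend ones to self.shape.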
|
|
|
|
|
dims_longer_than_self_shape = g.op("Less", diff, const_zero) |
|
( |
|
if_op_less, |
|
(if_context_less, else_context_less), |
|
_, |
|
) = jit_utils.add_op_with_blocks( |
|
g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1 |
|
) |
|
const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1])) |
|
diff_1d_less = if_context_less.op( |
|
"Reshape", |
|
if_context_less.op("Abs", diff), |
|
const_one, |
|
) |
|
    expand_ones_less = if_context_less.op("Expand", const_one, diff_1d_less)
|
self_final_shape = if_context_less.op( |
|
"Concat", exapnd_ones_less, self_shape, axis_i=0 |
|
) |
|
self_ = if_context_less.op("Reshape", self, self_final_shape) |
|
utils._add_output_to_block(if_context_less.block, self_) |
|
identity_self = else_context_less.op("Identity", self) |
|
utils._add_output_to_block(else_context_less.block, identity_self) |
|
self_final = if_op_less.node().output() |
|
|
|
dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64) |
|
return g.op("Tile", self_final, dims_final) |
|
|
|
|
|
@_onnx_symbolic("aten::repeat_interleave") |
|
def repeat_interleave( |
|
g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None |
|
): |
|
repeats_dim = symbolic_helper._get_tensor_rank(repeats) |
|
repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) |
|
input_sizes = symbolic_helper._get_tensor_sizes(self) |
|
if repeats_dim is None: |
|
raise errors.SymbolicValueError( |
|
"Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", |
|
self, |
|
) |
|
if repeats_sizes is None: |
|
raise errors.SymbolicValueError( |
|
"Unsupported: ONNX export of repeat_interleave for unknown repeats size.", |
|
self, |
|
) |
|
if input_sizes is None: |
|
raise errors.SymbolicValueError( |
|
"Unsupported: ONNX export of repeat_interleave for unknown input size.", |
|
self, |
|
) |
|
|
|
final_dim = dim |
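    # If dim is None, the input is flattened and the repeats are applied to
    # the flat output array.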
|
|
|
|
|
if symbolic_helper._is_none(dim): |
|
self = symbolic_helper._reshape_helper( |
|
g, self, g.op("Constant", value_t=torch.tensor([-1])) |
|
) |
|
dim = torch.tensor(0, dtype=torch.int64) |
|
else: |
|
dim = symbolic_helper._maybe_get_scalar(dim) |
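    # Handle cases where dim is negative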
|
|
|
|
|
if dim < 0: |
|
dim += len(input_sizes) |
|
|
|
output_sizes = input_sizes.copy() |
|
for idx, input_size in enumerate(input_sizes): |
|
if input_size is None: |
|
output_sizes[idx], input_sizes[idx] = 0, -1 |
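    # Cases where repeats is a single value tensor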
|
|
|
|
|
if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): |
|
return symbolic_helper._repeat_interleave_single_value_repeat_helper( |
|
g, self, repeats, dim |
|
) |
|
|
|
cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None |
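    # If the input size along dim is dynamic or the repeats vector is dynamic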
|
|
|
if output_sizes[dim] == 0 or cond_dynamic_repeats: |
|
reps = symbolic_helper._size_helper(g, self, dim) |
|
reps = opset11.unsqueeze(g, reps, 0) |
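        # As repeats is dynamic, use a Where node as a substitute for an if
        # statement: if repeats has a single element, expand it to the size of
        # the dimension; otherwise keep the original tensor.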
|
|
|
|
|
|
|
|
|
if cond_dynamic_repeats: |
|
repeat_dim = symbolic_helper._size_helper( |
|
g, repeats, g.op("Constant", value_t=torch.LongTensor([0])) |
|
) |
|
repeat_cond = g.op( |
|
"Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])) |
|
) |
|
repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats) |
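    # There are cases where repeats is a 1-d tensor with multiple values and
    # dim points at a dynamic axis, e.g. input.shape -> [1, 1, *] (with * the
    # dynamic axis) and dim = 2. PyTorch can repeat-interleave when the runtime
    # value of * matches the number of elements in repeats, so fall back to the
    # opset 9 implementation for that case.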
|
|
|
|
|
|
|
|
|
|
|
|
|
else: |
|
return opset9.repeat_interleave(g, self, repeats, final_dim) |
|
|
|
reps_like = g.op( |
|
"ConstantOfShape", |
|
g.op("Shape", repeats), |
|
value_t=torch.tensor([1], dtype=torch.long), |
|
) |
|
r_splits = split(g, repeats, reps_like, 0) |
|
i_splits = split(g, self, reps_like, dim) |
|
|
|
output_sizes[dim], input_sizes[dim] = -1, 1 |
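    # Create a loop that iterates over each value along the dimension and
    # expands it according to the corresponding entry of the repeats tensor.
    # Example:
    #   input = [[1, 2], [3, 4]]
    #   repeats = [1, 2]
    #   dim = 0
    #   output = [[1, 2], [3, 4], [3, 4]]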
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loop_condition = g.op("Constant", value_t=torch.tensor(1)) |
|
loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) |
|
loop_len = reps |
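    # Create an empty sequence to store the expanded slices.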
|
|
|
|
|
final_splits = g.op("SequenceEmpty") |
|
|
|
|
|
loop, (loop_context,), _ = jit_utils.add_op_with_blocks( |
|
g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1 |
|
) |
|
|
|
loop_block = loop_context.block |
|
block_input_iter = utils._add_input_to_block(loop_block) |
|
cond = utils._add_input_to_block(loop_block) |
|
final_splits = utils._add_input_to_block(loop_block) |
|
|
|
r_split = loop_context.op("SequenceAt", r_splits, block_input_iter) |
|
i_split = loop_context.op("SequenceAt", i_splits, block_input_iter) |
|
|
|
i_split = opset11.unsqueeze(loop_context, i_split, dim + 1) |
|
r_concat = [ |
|
loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])), |
|
r_split, |
|
loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])), |
|
] |
|
r_concat = loop_context.op("Concat", *r_concat, axis_i=0) |
|
i_split = opset9.expand(loop_context, i_split, r_concat, None) |
|
i_split = symbolic_helper._reshape_helper( |
|
loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)) |
|
) |
|
final_splits = loop_context.op("SequenceInsert", final_splits, i_split) |
|
|
|
|
|
cond_out = loop_context.op( |
|
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL |
|
) |
|
utils._add_output_to_block(loop_block, cond_out) |
|
utils._add_output_to_block(loop_block, final_splits) |
|
|
|
loop_out = loop.node().output() |
|
loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim) |
|
return loop_out |
|
|
|
|
|
@_onnx_symbolic("aten::diagonal") |
|
@symbolic_helper.parse_args("v", "i", "i", "i") |
|
def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2): |
|
rank = symbolic_helper._get_tensor_rank(self) |
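    # Replace negative indexing when the rank is known.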
|
|
|
if rank is not None: |
|
dim1 = dim1 if dim1 >= 0 else dim1 + rank |
|
dim2 = dim2 if dim2 >= 0 else dim2 + rank |
|
|
|
dim1_size = opset9.size( |
|
g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1])) |
|
) |
|
dim2_size = opset9.size( |
|
g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2])) |
|
) |
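    # Build an identity mask over the (dim1, dim2) plane; EyeLike places the
    # ones on the requested diagonal via the k attribute.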
|
|
|
mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0) |
|
mask = opset9.zeros(g, mask_shape, None, None, None) |
|
mask = g.op("EyeLike", mask, k_i=offset) |
|
|
|
|
|
if rank is not None: |
|
axes = list(range(rank)) |
|
axes.remove(dim1) |
|
axes.remove(dim2) |
|
self = g.op("Transpose", self, perm_i=axes + [dim1, dim2]) |
|
else: |
|
return symbolic_helper._unimplemented("diagonal", "unknown input rank") |
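    # Multiply input and mask to zero out everything except the diagonal, then
    # sum along the last axis. For example:
    # [[1.1, 1.2, 1.3],      [[1, 0, 0],       [[1.1, 0,   0  ],
    #  [2.1, 2.2, 2.3],  *    [0, 1, 0],   =    [0,   2.2, 0  ],
    #  [3.1, 3.2, 3.3]]       [0, 0, 1]]        [0,   0,   3.3]]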
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result = g.op("Mul", self, mask) |
|
result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0) |
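    # Compute the length of the diagonal from the offset and the two dim
    # sizes, clamped at zero when the offset runs past the matrix. Positive
    # offsets are folded into diag_size so that offset can be reset to 0.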
|
|
|
|
|
|
|
|
|
offset_op = g.op("Constant", value_t=torch.LongTensor([offset])) |
|
if offset >= 0: |
|
diag_size = g.op( |
|
"Max", |
|
g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)), |
|
g.op("Constant", value_t=torch.LongTensor([0])), |
|
) |
|
offset = 0 |
|
else: |
|
diag_size = g.op( |
|
"Max", |
|
g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size), |
|
g.op("Constant", value_t=torch.LongTensor([0])), |
|
) |
|
diag_size = g.op("Concat", diag_size, axis_i=0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None) |
|
select_window = g.op( |
|
"CumSum", |
|
select_window_ones_fill, |
|
g.op("Constant", value_t=torch.LongTensor([0])), |
|
) |
|
select_window = g.op( |
|
"Add", |
|
select_window, |
|
g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])), |
|
) |
|
|
|
gather_shape = [ |
|
opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis]))) |
|
for axis in list(range(rank))[:-2] |
|
] |
|
gather_shape.append(diag_size) |
|
gather_shape = g.op("Concat", *gather_shape, axis_i=0) |
|
gather_indices = opset9.zeros(g, gather_shape, 4, None, None) |
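    # When the offset runs past the matrix, diag_size is 0 and GatherND would
    # read out of range. Guard the gather with an If node: the then-branch
    # gathers the diagonal, the else-branch returns a zero-filled float tensor
    # (dtype 6) of the expected shape.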
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
overrun_cond = g.op( |
|
"Not", |
|
g.op( |
|
"Equal", |
|
diag_size, |
|
g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)), |
|
), |
|
) |
|
|
|
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( |
|
g, "If", overrun_cond, n_blocks=2 |
|
) |
|
|
|
gather_indices_if_block = if_context.op("Add", gather_indices, select_window) |
|
gather_indices_if_block = symbolic_helper._unsqueeze_helper( |
|
if_context, gather_indices_if_block, [rank - 1] |
|
) |
|
final_non_overrun = if_context.op( |
|
"GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2 |
|
) |
|
final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None) |
|
utils._add_output_to_block(if_context.block, final_non_overrun) |
|
utils._add_output_to_block(else_context.block, final_overrun) |
|
return if_op |
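# Quantized ops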
|
|
|
|
|
|
|
|
|
|
|
@_onnx_symbolic("quantized::linear") |
|
def quantized_linear( |
|
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.linear(g, input, weight, bias) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::linear_relu") |
|
def quantized_linear_relu( |
|
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.linear(g, input, weight, bias) |
|
output = opset9.relu(g, output) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv1d_relu") |
|
def quantized_conv1d_relu( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) |
|
output = opset9.relu(g, output) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv2d_relu") |
|
def quantized_conv2d_relu( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) |
|
output = opset9.relu(g, output) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv3d_relu") |
|
def quantized_conv3d_relu( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) |
|
output = opset9.relu(g, output) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv1d") |
|
def quantized_conv1d( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv2d") |
|
def quantized_conv2d( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv3d") |
|
def quantized_conv3d( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv_transpose1d") |
|
def quantized_conv_transpose1d( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
output_padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
    output = opset9.conv_transpose1d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv_transpose2d") |
|
def quantized_conv_transpose2d( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
output_padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv_transpose2d( |
|
g, input, weight, bias, stride, padding, output_padding, groups, dilation |
|
) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv_transpose3d") |
|
def quantized_conv_transpose3d( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
output_padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper( |
|
g, bias, input_scale, weight_scale, axis |
|
) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv_transpose3d( |
|
g, input, weight, bias, stride, padding, output_padding, groups, dilation |
|
) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|