"""This file exports ONNX ops for opset 14.

Note [ONNX operators that are added/updated in opset 14]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    HardSwish, Trilu

Updated operators:
    Reshape
    Add, Sub, Mul, Div
    GRU, LSTM, RNN
    BatchNorm, Cumsum, Relu
"""

from __future__ import annotations

import functools

import torch
from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration


__all__ = [
    "hardswish",
    "tril",
    "triu",
    "reshape",
    "batch_norm",
    "quantized_hardswish",
    "scaled_dot_product_attention",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)


@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
def hardswish(g: jit_utils.GraphContext, self):
    return g.op("HardSwish", self)


@_onnx_symbolic("aten::tril")
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
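    # ONNX Trilu keeps the lower triangle when upper_i=0 (torch.tril) and the
    # upper triangle when upper_i=1 (torch.triu, below).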
return g.op("Trilu", self, diagonal, upper_i=0) |
|
|
|
|
|
@_onnx_symbolic("aten::triu") |
|
def triu(g: jit_utils.GraphContext, self, diagonal, out=None): |
|
return g.op("Trilu", self, diagonal, upper_i=1) |


@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v")
def reshape(g: jit_utils.GraphContext, self, shape):
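    # allowzero is left at 0, so a 0 in `shape` copies the corresponding
    # dimension from the input tensor (the pre-opset-14 Reshape behavior).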
    return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)
@_onnx_symbolic("aten::batch_norm") |
|
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") |
|
def batch_norm( |
|
g: jit_utils.GraphContext, |
|
input, |
|
weight, |
|
bias, |
|
running_mean, |
|
running_var, |
|
training, |
|
momentum, |
|
eps, |
|
cudnn_enabled, |
|
): |
|
if ( |
|
torch.is_autocast_enabled() |
|
and not symbolic_helper.args_have_same_dtype( |
|
[input, weight, bias, running_mean, running_var] |
|
) |
|
and GLOBALS.export_onnx_opset_version < 15 |
|
): |
|
return symbolic_helper._onnx_opset_unsupported_detailed( |
|
"BatchNormalization", |
|
14, |
|
15, |
|
"All input tensors must have the same `dtype`." |
|
" Turn off Autocast or export using opset version 15.", |
|
input, |
|
) |
|
|
|
symbolic_helper.check_training_mode(training, "batch_norm") |
|
weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( |
|
g, input, weight, bias, running_mean, running_var |
|
) |
|
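    # ONNX BatchNormalization updates running stats as
    # `running = momentum * running + (1 - momentum) * batch`, while PyTorch
    # weights the new batch statistic by `momentum`, hence `1 - momentum` below.
    # In training mode the node also returns the updated running mean/variance.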
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        training_mode_i=0 if not training else 1,
        outputs=1 if not training else 3,
    )
    if not training:
        return out
    else:
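        # Propagate the running-stat types onto the updated outputs so the
        # exported graph keeps the original type information.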
        res, new_running_mean, new_running_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        return res


@_onnx_symbolic("quantized::hardswish")
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
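    # Dequantize the input, apply the float HardSwish, then requantize the
    # result with the requested output scale and zero point.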
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = hardswish(g, x)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("aten::scaled_dot_product_attention")
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
def scaled_dot_product_attention(
    g: jit_utils.GraphContext,
    query: torch._C.Value,
    key: torch._C.Value,
    value: torch._C.Value,
    attn_mask: torch._C.Value | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: torch._C.Value | None = None,
    enable_gqa: bool = False,
):
    assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
        "is_causal and attn_mask cannot be set at the same time"
    )
    assert not enable_gqa, (
        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
    )
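
    # Default to a 1 / sqrt(embedding_size) scale when none is provided.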
    if symbolic_helper._is_none(scale):
        scale = _attention_scale(g, query)

    if is_causal:
        attn_mask = _causal_attention_mask(g, query, key)
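
    # Swap the last two axes of key so that MatMul(query, key_transposed)
    # computes Q @ K^T: (..., L, E) x (..., E, S) -> (..., L, S).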
    key_shape_builtin = symbolic_helper._get_tensor_rank(key)
    key_transposed_axes = list(range(key_shape_builtin))
    key_transposed_axes[-1], key_transposed_axes[-2] = (
        key_transposed_axes[-2],
        key_transposed_axes[-1],
    )
    key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)
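
    # Scale query and key by sqrt(scale) individually so their product is
    # scaled by `scale`; splitting the factor helps numerical stability.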
    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
    key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
    mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)

    if symbolic_helper._is_none(attn_mask):
        mul_qk_add = mul_qk
    elif (
        _type_utils.JitScalarType.from_value(attn_mask)
        == _type_utils.JitScalarType.BOOL
    ):
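        # A boolean mask marks allowed positions; convert it to an additive
        # float mask: 0.0 where attention is allowed, -inf where it is masked.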
        const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
        const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
        attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    elif _type_utils.JitScalarType.from_value(attn_mask) in (
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.BFLOAT16,
    ):
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    else:
        raise ValueError(
            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
        )

    attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
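
    # Only add a Dropout node when dropout_p is nonzero; the ratio is supplied
    # as a constant second input.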
    if dropout_p != 0:
        attn_weight = g.op(
            "Dropout",
            attn_weight,
            g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
        )

    return g.op("MatMul", attn_weight, value)


def _attention_scale(
    g: jit_utils.GraphContext, query: torch._C.Value
) -> torch._C.Value:
    """Calculate the scale factor for the attention result.

    Args:
        query: Tensor of shape [..., L, E]

    Returns:
        Scalar scale factor := 1 / math.sqrt(query.size(-1))
    """
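    # Compute 1 / sqrt(E) in-graph from the last dimension of query's shape so
    # the scale also works for dynamically shaped inputs.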
    query_shape = g.op("Shape", query)
    query_shape_last = g.op(
        "Slice",
        query_shape,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
        g.op(
            "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
        ),
    )
    embedding_size = g.op(
        "Cast",
        query_shape_last,
        to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
    )
    const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float))
    scale = g.op("Div", const_one, g.op("Sqrt", embedding_size))
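
    # Cast the scale back to the query's dtype (`const_one` above is float32,
    # so the Div result may not match the original type).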
    scale = g.op(
        "Cast",
        scale,
        to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
    )
    return scale


def _causal_attention_mask(
    g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
) -> torch._C.Value:
    """Create a causal mask for the given query and key tensors.

    Equivalent to::
        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_mask = torch.zeros(L, S, dtype=torch.float)
        attn_mask = attn_mask.masked_fill(~mask, -float("inf"))

    Args:
        query: Tensor of shape [..., L, E]
        key: Tensor of shape [..., S, E]

    Returns:
        Tensor of shape [L, S]
    """
    query_shape = g.op("Shape", query)
    key_shape = g.op("Shape", key)

    last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64))
    target_length = g.op("Slice", query_shape, second_last_idx, last_idx)
    source_length = g.op("Slice", key_shape, second_last_idx, last_idx)

    size = g.op("Concat", target_length, source_length, axis_i=0)
    const_one = g.op("Constant", value_t=torch.tensor([1.0]))
    attn_mask = g.op("Expand", const_one, size)
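
    # Keep the lower triangle of the all-ones [L, S] tensor (the causal
    # positions), then turn it into an additive mask: 0 where attention is
    # allowed, -inf where it is not.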
attn_mask = g.op("Trilu", attn_mask, upper_i=0) |
|
|
|
const_zero = g.op("Constant", value_t=torch.tensor([0.0])) |
|
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) |
|
attn_mask = g.op( |
|
"Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero |
|
) |
|
return attn_mask |
|
|