|
|
|
"""This file exports ONNX ops for opset 15. |
|
|
|
Note [ONNX operators that are added/updated in opset 15] |
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
|
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set |
|
New operators: |
|
Bernoulli |
|
CastLike |
|
Optional |
|
OptionalGetElement |
|
OptionalHasElement |
|
|
|
Updated operators: |
|
BatchNormalization https://github.com/onnx/onnx/pull/3545 |
|
Backwards compatible |
|
TODO: test coverage for mixed types inputs. |
|
Pow https://github.com/onnx/onnx/pull/3412 |
|
Backwards compatible |
|
TODO: bfloat16 support. |
|
Shape https://github.com/onnx/onnx/pull/3580 |
|
Backwards compatible |
|
TODO: optional start/end attribute. |
|
""" |
|
|
|
|
|
|
|
|
|
import functools |
|
|
|
import torch |
|
from torch import _C |
|
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 |
|
from torch.onnx._internal import jit_utils, registration |
|
|
|
|
|
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) |
|
|
|
|
|
@_onnx_symbolic("aten::__is_") |
|
def aten__is_(g: jit_utils.GraphContext, self, other): |
|
if symbolic_helper._is_none(other): |
|
if isinstance(self.type(), _C.OptionalType): |
|
none = g.op("OptionalHasElement", self) |
|
return g.op("Not", none) |
|
else: |
|
return g.op("Constant", value_t=torch.BoolTensor([0])) |
|
return opset9.eq(g, self, other) |
|
|
|
|
|
@_onnx_symbolic("aten::__isnot_") |
|
@opset9.wrap_logical_op_with_negation |
|
def aten__isnot_(g: jit_utils.GraphContext, self, other): |
|
return aten__is_(g, self, other) |
|
|
|
|
|
@_onnx_symbolic("aten::bernoulli") |
|
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): |
|
if out is not None and not symbolic_helper._is_none(out): |
|
symbolic_helper._unimplemented( |
|
"Bernoulli", "out parameter is not supported for bernoulli", input |
|
) |
|
if generator is not None and not symbolic_helper._is_none(generator): |
|
symbolic_helper._unimplemented( |
|
"Bernoulli", "generator is not supported for bernoulli", input |
|
) |
|
if p is None or symbolic_helper._is_none(p): |
|
return g.op("Bernoulli", input) |
|
return opset9.bernoulli(g, input, p, generator, out) |
|
|
|
|
|
@_onnx_symbolic("prim::unchecked_cast") |
|
def prim_unchecked_cast(g: jit_utils.GraphContext, self): |
|
|
|
|
|
|
|
if isinstance(self.type(), _C.OptionalType): |
|
return g.op("OptionalGetElement", self) |
|
|
|
return self |
|
|