File size: 1,185 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
"""This file provides a location for operators that help with exporting models via ONNX.
For example, `shape_as_tensor` and `reshape_from_tensor_shape`
exist to make all dynamic-size operations traceable.
NOTE: at one point these functions were implemented differently.
Since then we have implemented these directly in ATen, so this
file is kept purely for backward-compatibility.
"""
from __future__ import annotations
__all__: list[str] = []
import torch
# Defined as a function (not a bare alias) so the docstring below is actually
# attached to the public name; a string literal placed before an assignment is
# just a discarded expression statement.
def shape_as_tensor(x: torch.Tensor) -> torch.Tensor:
    """Get the shape of a tensor as a tensor.

    Kept for backward-compatibility; it simply forwards to the ATen op
    ``torch._shape_as_tensor``.

    Args:
        x (Tensor): The input tensor.

    Returns:
        Tensor: A tensor of shape ``[len(x.shape)]`` containing the size of
        each dimension of ``x``.

    Example:
        >>> x = torch.randn(2, 3)
        >>> shape_as_tensor(x)
        tensor([2, 3])
    """
    return torch._shape_as_tensor(x)
# Defined as a function (not a bare alias) so the docstring below is actually
# attached to the public name; a string literal placed before an assignment is
# just a discarded expression statement.
def reshape_from_tensor_shape(x: torch.Tensor, shape: torch.Tensor) -> torch.Tensor:
    """Reshape a tensor to the given shape.

    This function is used to make dynamic-size operations traceable when
    exporting models via ONNX. It is kept for backward-compatibility and
    simply forwards to the ATen op ``torch._reshape_from_tensor``.

    Args:
        x (Tensor): The tensor to be reshaped.
        shape (Tensor): The target shape.

    Returns:
        Tensor: The reshaped tensor.
    """
    return torch._reshape_from_tensor(x, shape)
|