File size: 4,374 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# mypy: allow-untyped-defs
import torch
from torch import Tensor
aten = torch.ops.aten
import inspect
import warnings
from typing import Callable, Optional, TypeVar
from typing_extensions import ParamSpec
from torch.types import Number
decomposition_table: dict[str, torch.jit.ScriptFunction] = {}
function_name_set: set[str] = set()
_T = TypeVar("_T")
_P = ParamSpec("_P")
def check_decomposition_has_type_annotations(f):
    """Assert that every parameter of ``f`` and its return value are annotated.

    Args:
        f: Candidate decomposition callable to inspect.

    Raises:
        AssertionError: If a parameter or the return value has no annotation.
            The message names the offending parameter (if any) and function.
    """
    # Public alias of the private ``inspect._empty`` sentinel.
    inspect_empty = inspect.Parameter.empty
    # Plain functions expose ``__name__`` (not ``name``); fall back to repr for
    # callables such as functools.partial objects that lack one.
    func_name = getattr(f, "__name__", repr(f))
    sig = inspect.signature(f)
    for param in sig.parameters.values():
        # Identity check: the sentinel must not go through an annotation
        # object's own ``__ne__``.
        assert (
            param.annotation is not inspect_empty
        ), f"No signature on param {param.name} for function {func_name}"
    assert (
        sig.return_annotation is not inspect_empty
    ), f"No return annotation for function {func_name}"
def signatures_match(decomposition_sig, torch_op_sig):
    """Return whether a decomposition's signature is compatible with an op's.

    Parameters are compared positionally by name and annotation. Default
    values are compared only when both sides specify one. Finally, the return
    annotations must be equal.

    Args:
        decomposition_sig: ``inspect.Signature`` of the Python decomposition.
        torch_op_sig: ``inspect.Signature`` describing the torch op.

    Returns:
        True if the signatures are considered to match, False otherwise.

    Warns:
        UserWarning: For each decomposition parameter named ``self``.
    """
    decomp_params = decomposition_sig.parameters
    op_params = torch_op_sig.parameters
    if len(decomp_params) != len(op_params):
        return False

    # Public alias of the private ``inspect._empty`` sentinel (loop-invariant).
    inspect_empty = inspect.Parameter.empty
    for decomp_param, op_param in zip(decomp_params.values(), op_params.values()):
        # Full Parameter equality can't be checked yet because not every field
        # (e.g. the default value) is correctly deduced in torch_op_sig. 'kind'
        # isn't compared because kwarg-only parameters with defaults are not
        # yet supported in TorchScript (TS).
        if decomp_param.name == "self":
            warnings.warn("PyTorch uses 'input' instead of 'self' on public api")
        if decomp_param.name != op_param.name:
            return False
        if decomp_param.annotation != op_param.annotation:
            return False

        decomp_default = decomp_param.default
        op_default = op_param.default
        # The default isn't always inferred as present on the torch schema,
        # but when both sides specify one they must agree. The sentinel is
        # tested by identity so an arbitrary default's ``__ne__`` (which may
        # raise or return a non-bool) is only invoked for a real comparison.
        if (
            decomp_default is not inspect_empty
            and op_default is not inspect_empty
            and decomp_default != op_default
        ):
            return False

    return decomposition_sig.return_annotation == torch_op_sig.return_annotation
def register_decomposition(
    aten_op: torch._ops.OpOverload,
    registry: Optional[dict[str, torch.jit.ScriptFunction]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """Build a decorator that scripts a function and registers it for ``aten_op``.

    The decorated function is compiled with ``torch.jit.script``. Its graph is
    then inlined and given two rounds of peephole optimization followed by
    constant propagation. The resulting ScriptFunction is stored in
    ``registry`` (``decomposition_table`` when omitted) under the key
    ``str(aten_op._schema)``.

    Args:
        aten_op: The aten operator overload being decomposed.
        registry: Destination mapping; defaults to ``decomposition_table``.

    Returns:
        A decorator that returns the original, unscripted function unchanged.
    """

    def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
        nonlocal registry
        if registry is None:
            registry = decomposition_table
        assert isinstance(aten_op, torch._ops.OpOverload)

        fn_name = f.__name__
        # Scripted functions are serialized by name, so names must be unique.
        assert fn_name not in function_name_set, f"Duplicated function name {fn_name}"
        function_name_set.add(fn_name)

        scripted = torch.jit.script(f)
        torch._C._jit_pass_inline(scripted.graph)
        # Two optimization rounds: peephole, then constant propagation.
        for _round in range(2):
            torch._C._jit_pass_peephole(scripted.graph)
            torch._C._jit_pass_constant_propagation(scripted.graph)

        schema_key = str(aten_op._schema)
        registry[schema_key] = scripted
        return f

    return decomposition_decorator
# TODO: replace torch.sigmoid -> aten.sigmoid
@register_decomposition(aten.var.correction)
def var_decomposition(
    input: Tensor,
    dim: Optional[list[int]] = None,
    correction: Optional[Number] = None,
    keepdim: bool = False,
) -> Tensor:
    """Decompose ``aten.var.correction`` into mean/subtract/square/sum ops.

    ``register_decomposition`` compiles this body with ``torch.jit.script``,
    so it must stay within the TorchScript subset. The unusual variable
    handling below is deliberate for that reason.

    Args:
        input: Tensor whose variance is computed.
        dim: Dimensions to reduce over. ``None`` is normalized to an empty
            list. For an empty list the element count ``n`` is
            ``input.numel()``, i.e. it is treated as a reduction over all
            elements.
        correction: Subtracted from ``n`` to form the divisor. ``None``
            behaves like ``1``. Otherwise it must be an int or a float.
        keepdim: Forwarded to the final ``aten.sum``.

    Returns:
        ``sum((input - mean) ** 2) / max(0, n - correction)``. A non-positive
        divisor is clamped to 0, so the result then depends on tensor
        division by zero (presumably inf/nan — confirm against aten.var).

    Raises:
        RuntimeError: If ``correction`` is neither an int nor a float.
    """
    if dim is None:
        # Bind an explicitly annotated empty list first; presumably so
        # TorchScript sees list[int] rather than an untyped ``[]``.
        dim_i: list[int] = []
        dim = dim_i

    # Number of elements each output value averages over.
    if isinstance(dim, (tuple, list)) and len(dim) == 0:
        n = input.numel()
    else:
        n = 1
        # ``dim_i`` is reused here as an int loop variable (hence the ignores).
        for dim_i in dim:  # type: ignore[assignment]
            n *= input.shape[dim_i]  # type: ignore[call-overload]

    # Third positional arg is presumably keepdim=True, so ``mean`` keeps the
    # reduced dims and broadcasts against ``input``.
    mean = aten.mean(input, dim, True)
    sub = input - mean
    sq = sub * sub
    # NOTE: this local shadows the builtin ``sum`` within this function.
    sum = aten.sum(sq, dim, keepdim)

    # Divisor n - correction, computed as a Python float.
    if correction is None:
        denom = float(n - 1)
    else:
        if isinstance(correction, int):
            denom = float(n - correction)
        elif isinstance(correction, float):
            denom = float(n) - correction
        else:
            raise RuntimeError("correction must be int or float")

    # Clamp a negative divisor to 0 before dividing.
    return sum / max(0, denom)
@register_decomposition(aten.var.default)
def var(input: Tensor, unbiased: bool = True) -> Tensor:
    """Decompose ``aten.var.default`` in terms of ``var_decomposition``.

    No ``dim`` is passed, so the variance is taken over all elements.
    ``unbiased=True`` uses ``correction=1`` (divide by n - 1).
    ``unbiased=False`` uses ``correction=0`` (divide by n).
    """
    if unbiased:
        correction = 1
    else:
        correction = 0
    return var_decomposition(input, correction=correction)
|