File size: 7,663 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
# Represents all kernels used by an Executorch model.
# It maintains a dict[OperatorName, dict[ETKernelKey, BackendMetadata]] structure.
from __future__ import annotations
import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum
from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
)
from torchgen.utils import assert_never
KERNEL_KEY_VERSION = 1
# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
class ScalarType(IntEnum):
Byte = 0
Char = 1
Short = 2
Int = 3
Long = 4
Float = 6
Double = 7
Bool = 11
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
arg_name: str
dtype: str
# The order of the dimensions if entry is a Tensor
dim_order: tuple[int, ...]
def to_native_string(self) -> str:
dtype_str = ScalarType[self.dtype].value
dim_str = str(self.dim_order)[1:-1].replace(" ", "")
return f"{dtype_str};{dim_str}"
@dataclass(frozen=True)
class ETKernelKey:
# Field undefined is default = True
arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
# Indicator for this kernel being used as a catch all
default: bool = False
version: int = KERNEL_KEY_VERSION
@staticmethod
def gen_from_yaml(
args: dict[str, tuple[str, str]],
type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val
dim_order_alias_map: dict[str, list[int]],
) -> list[ETKernelKey]:
"""Generate ETKernelKeys from arg kernel specs
Multiple ETKernelKeys are returned due to dtype permutations from utilizing
type_alias_map (actualizing each potential type permutation as a KernelKey)
Args:
args: Mapping from argument name to kernel specs
Kernel specs are a tuple of (dtype, dim_order).
Currently tuple entries must be aliased via the alias map arguments
type_alias_map: Mapping from type alias to potential type enums
i.e { T0 : [Double, Int] } means T0 can be either Double or Int
Used for lookup by args
dim_order_alias_map: Mapping from alias to a list of dimension orders
Used for lookup by args
"""
# Cast to dim order to int
dim_order_alias_map = {
k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
}
kernel_keys = []
# Get all used Dtype Alias
dtype_alias_used = set()
for type_alias, dim_order in args.values():
# Enforce usage of alias initially
# TODO: Support inlined arguments
assert type_alias in type_alias_map, "Undefined type alias: " + str(
type_alias
)
assert dim_order in dim_order_alias_map, (
f"Undefined dim_order alias: {dim_order}"
)
dtype_alias_used.add(type_alias)
# Generate all permutations of dtype alias values
alias_dtypes = [
[(alias, dtype) for dtype in type_alias_map[alias]]
for alias in dtype_alias_used
]
alias_permutations = [
dict(permutation) for permutation in list(itertools.product(*alias_dtypes))
]
# Using each alias value permutation, generate kernel keys
op_arg_cache = {}
for permutation in alias_permutations:
arg_list = []
for arg_name, arg_spec in args.items():
dtype = permutation[arg_spec[0]]
dim_order = dim_order_alias_map[arg_spec[1]] # type: ignore[assignment]
if (
cache_key := (arg_name, dtype, tuple(dim_order))
) not in op_arg_cache:
op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key) # type: ignore[arg-type]
arg_list.append(op_arg_cache[cache_key])
kernel_keys.append(ETKernelKey(tuple(arg_list)))
return kernel_keys
def to_native_string(self) -> str:
if self.default:
return "default"
return (
"v"
+ str(KERNEL_KEY_VERSION)
+ "/"
+ "|".join([arg.to_native_string() for arg in self.arg_meta])
)
@dataclass(frozen=True)
class ETKernelIndex:
index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
m = self.get_kernels(g)
return m is not None
def get_kernels(
self, g: NativeFunction | NativeFunctionsGroup
) -> dict[ETKernelKey, BackendMetadata]:
if isinstance(g, NativeFunction):
f = g
elif isinstance(g, NativeFunctionsGroup):
f = g.functional
else:
assert_never(g)
if f.func.name not in self.index:
return {}
return self.index[f.func.name]
@staticmethod
def grow_from_backend_indices(
kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None:
for dk in backend_indices:
index = backend_indices[dk]
for op, backend_metadata in index.items():
if op in kernel_index:
kernel_index[op][ETKernelKey(default=True)] = backend_metadata
else:
kernel_index[op] = {ETKernelKey(default=True): backend_metadata}
@staticmethod
def from_backend_indices(
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> ETKernelIndex:
kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = (
defaultdict(dict)
)
ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
return ETKernelIndex(kernel_index)
def grow(
self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
return self
def _to_backend_index(self) -> BackendIndex:
"""
WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
"""
index: dict[OperatorName, BackendMetadata] = {}
for op in self.index:
kernel_dict = self.index[op]
assert len(kernel_dict.values()) == 1, (
f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
)
index[op] = kernel_dict.get(
ETKernelKey(default=True),
BackendMetadata(kernel="", structured=False, cpp_namespace=""),
)
return BackendIndex(
dispatch_key=DispatchKey.CPU,
use_out_as_primary=False,
device_guard=False,
external=False,
index=index,
)
# Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
@staticmethod
def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
combined = defaultdict(dict, index_a.index.copy())
for op, entry in index_b.index.items():
for key, metadata in entry.items():
combined[op][key] = metadata
return ETKernelIndex(combined)
|