|
from __future__ import annotations |
|
|
|
from collections import defaultdict, namedtuple |
|
from typing import Any |
|
|
|
import yaml |
|
|
|
from torchgen.executorch.model import ETKernelIndex, ETKernelKey |
|
from torchgen.gen import LineLoader, parse_native_yaml |
|
from torchgen.model import ( |
|
BackendMetadata, |
|
DispatchKey, |
|
FunctionSchema, |
|
NativeFunction, |
|
OperatorName, |
|
) |
|
from torchgen.utils import NamespaceHelper |
|
|
|
|
|
|
|
# Two-field record pairing `native_functions` with `et_kernel_indices`.
# Nothing in this section constructs it; presumably it holds the parsed
# NativeFunctions and the ETKernelIndex of one ET yaml file — confirm at the
# construction sites.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])


# ExecuTorch-specific keys that may appear on an operator entry.
# extract_kernel_fields() collects these per operator, and strip_et_fields()
# deletes them so the remaining entries can be handed to parse_native_yaml().
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
|
|
|
|
|
def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
    """Given a loaded yaml representing kernel assignment information, extract the
    mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance).

    The input is left untouched. `__line__` bookkeeping keys (presumably added
    by the yaml loader, e.g. `LineLoader`) are filtered into fresh dicts rather
    than popped from the caller's nested dicts. This means the same loaded
    yaml can be parsed more than once, and entries without `__line__` are
    accepted.

    Args:
        ei: Dict keys {kernels, type_alias, dim_order_alias}.
            See ETKernelKey for description of arguments.
            Each item of `kernels` provides a `kernel_name` and an optional
            `arg_meta` mapping.

    Returns:
        A dict from each ETKernelKey to the BackendMetadata of the kernel that
        serves it. It is empty when `ei` has no `kernels` key. A kernel entry
        without `arg_meta` is registered under a single default key.

    Raises:
        AssertionError: if two kernel entries produce the same kernel key.
    """
    e = ei.copy()
    if (kernels := e.pop("kernels", None)) is None:
        return {}

    type_alias: dict[str, list[str]] = e.pop("type_alias", {})
    # "__line__" is bookkeeping, not an alias, and must not reach
    # ETKernelKey.gen_from_yaml. Filter it into a new dict instead of popping
    # it, so the caller's loaded yaml is not modified.
    dim_order_alias: dict[str, list[str]] = {
        name: value
        for name, value in e.pop("dim_order_alias", {}).items()
        if name != "__line__"
    }

    kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}

    for entry in kernels:
        arg_meta = entry.get("arg_meta")
        if arg_meta is not None:
            # Same non-mutating filter as above. A missing "__line__" is not an error.
            arg_meta = {k: v for k, v in arg_meta.items() if k != "__line__"}

        kernel_name = entry.get("kernel_name")
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            kernel_name, max_level=3
        )
        kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
        backend_metadata = BackendMetadata(
            kernel=namespace_helper.entity_name,
            structured=False,
            cpp_namespace=(kernel_namespace + "::native"),
        )

        # Without arg_meta, the kernel is the fallback for every input type.
        kernel_keys = (
            [ETKernelKey((), default=True)]
            if arg_meta is None
            else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)
        )

        for kernel_key in kernel_keys:
            # `e` now holds only the entry's remaining keys (e.g. "__line__"),
            # which helps locate the offending yaml entry.
            assert kernel_key not in kernel_mapping, (
                "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
            )
            kernel_mapping[kernel_key] = backend_metadata

    return kernel_mapping
|
|
|
|
|
def parse_et_yaml_struct(es: object) -> ETKernelIndex:
    """Build an ETKernelIndex from a loaded yaml list of operator entries.

    For each entry, the operator name is parsed from its `func` schema string.
    The entry's remaining fields are turned into a mapping of `kernel keys` to
    `BackendMetadata` (the kernel instance to use for that key) by
    `parse_from_yaml`. Operators that yield no kernels are left out of the
    index.

    Raises:
        AssertionError: if an entry's `func` is not a string, or if an operator
            already recorded in the index appears again.
    """
    indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
    for entry in es:
        # Shallow copy, so that removing "func" does not modify the caller's entry.
        kernel_fields = entry.copy()

        schema_str = kernel_fields.pop("func")
        assert isinstance(schema_str, str), f"not a str: {schema_str}"
        op_name = FunctionSchema.parse(
            NamespaceHelper.from_namespaced_entity(
                namespaced_entity=schema_str, max_level=1
            ).entity_name
        ).name

        assert op_name not in indices, f"Duplicate func found in yaml: {op_name} already"

        kernel_index = parse_from_yaml(kernel_fields)
        if kernel_index:
            indices[op_name] = kernel_index

    return ETKernelIndex(indices)
|
|
|
|
|
def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
    """Collect the kernel-key related ET fields of each operator entry.

    Returns a `defaultdict(dict)` keyed by operator name (parsed from each
    entry's `func`). Each value holds the `ET_FIELDS` keys whose value in the
    entry is not None. An operator with none of these fields gets no key.
    The entries themselves are not modified; the collected values are the
    same objects found in `es`.
    """
    fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
    for entry in es:
        schema_str = entry.get("func")
        assert isinstance(schema_str, str), f"not a str: {schema_str}"
        op_name = FunctionSchema.parse(
            NamespaceHelper.from_namespaced_entity(
                namespaced_entity=schema_str, max_level=1
            ).entity_name
        ).name

        present = {key: entry[key] for key in ET_FIELDS if entry.get(key) is not None}
        # Only touch the defaultdict when something was found, so operators
        # without ET fields do not get an empty entry.
        if present:
            fields[op_name].update(present)

    return fields
|
|
|
|
|
def parse_et_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
    """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
    of fields to persist from native_functions.yaml to functions.yaml.

    Args:
        path: Path of the operator yaml file to read.
        tags_yaml_path: Path of the tags yaml, forwarded to `parse_native_yaml`.
        ignore_keys: Dispatch keys forwarded to `parse_native_yaml`.
        skip_native_fns_gen: Forwarded to `parse_native_yaml`.

    Returns:
        `(native_functions, et_kernel)`, where `et_kernel` maps each operator
        name to its ExecuTorch-specific fields (see `extract_kernel_fields`).
    """
    # Decode as UTF-8 explicitly instead of relying on the platform's locale
    # encoding (e.g. cp1252 on Windows), which can misread the yaml.
    with open(path, encoding="utf-8") as f:
        es = yaml.load(f, Loader=LineLoader)

    # Must run before strip_et_fields, which deletes these keys from the entries
    # in place. The collected values stay valid because only the keys are removed.
    et_kernel = extract_kernel_fields(es)

    # Remove ET specific fields from entries for BC compatibility
    strip_et_fields(es)

    native_yaml = parse_native_yaml(
        path,
        tags_yaml_path,
        ignore_keys,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=es,
    )
    return native_yaml.native_functions, et_kernel
|
|
|
|
|
def strip_et_fields(es: object) -> None:
    """Delete every ET-specific key (see `ET_FIELDS`) from each operator entry.

    The entries are modified in place, for backward compatibility with
    consumers that expect plain native_functions.yaml entries. Keys that are
    not present are ignored. Returns None.
    """
    for op_entry in es:
        for et_key in ET_FIELDS:
            if et_key in op_entry:
                del op_entry[et_key]
|
|