File size: 5,422 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations

from collections import defaultdict, namedtuple
from typing import Any

import yaml

from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader, parse_native_yaml
from torchgen.model import (
    BackendMetadata,
    DispatchKey,
    FunctionSchema,
    NativeFunction,
    OperatorName,
)
from torchgen.utils import NamespaceHelper


# Result container with two fields: the parsed NativeFunctions
# (`native_functions`) and the ET kernel indices (`et_kernel_indices`).
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])

# Keys of a native_functions.yaml operator entry that determine which kernels
# should be used. `extract_kernel_fields` collects these keys per operator and
# `strip_et_fields` removes them from the loaded entries.
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]


def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
    """Given a loaded yaml representing kernel assignment information, extract the
    mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)

    Neither `ei` nor any of its nested mappings is modified, so the same loaded
    yaml can safely be parsed more than once.

    Args:
        ei: Dict keys {kernels, type_alias, dim_order_alias}
            See ETKernelKey for description of arguments

    Returns:
        A dict with one entry per kernel key produced for each item of
        `kernels`. Empty when `ei` has no `kernels` key or its value is None.

    Raises:
        AssertionError: if the same kernel key is produced more than once.
    """
    e = ei.copy()
    if (kernels := e.pop("kernels", None)) is None:
        return {}

    type_alias: dict[str, list[str]] = e.pop("type_alias", {})  # type: ignore[assignment]
    # `ei.copy()` is shallow, so the nested alias dict is still the caller's
    # object. Build a filtered copy instead of popping "__line__" in place.
    # NOTE(review): "__line__" is presumably the bookkeeping key added by
    # LineLoader -- confirm against torchgen.gen.
    dim_order_alias: dict[str, list[str]] = {
        alias: dim_order
        for alias, dim_order in e.pop("dim_order_alias", {}).items()  # type: ignore[attr-defined]
        if alias != "__line__"
    }

    kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}

    for entry in kernels:  # type: ignore[attr-defined]
        raw_arg_meta = entry.get("arg_meta")
        arg_meta = None
        if raw_arg_meta is not None:
            # Same reasoning as above: drop "__line__" from a copy, and do not
            # fail when the key is absent (e.g. on a second parse of the same
            # data, or yaml loaded without line markers).
            arg_meta = {
                arg: spec for arg, spec in raw_arg_meta.items() if arg != "__line__"
            }

        kernel_name = entry.get("kernel_name")
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            kernel_name, max_level=3
        )
        kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
        backend_metadata = BackendMetadata(
            kernel=namespace_helper.entity_name,
            structured=False,
            cpp_namespace=(kernel_namespace + "::native"),
        )

        # A kernel without `arg_meta` is registered under the single default key.
        kernel_keys = (
            [ETKernelKey((), default=True)]
            if arg_meta is None
            else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)  # type: ignore[arg-type]
        )

        for kernel_key in kernel_keys:
            assert kernel_key not in kernel_mapping, (
                "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
            )
            kernel_mapping[kernel_key] = backend_metadata

    return kernel_mapping


def parse_et_yaml_struct(es: object) -> ETKernelIndex:
    """Build an `ETKernelIndex` from a loaded yaml list of operator entries.

    The `func` schema of each entry yields the operator name; the remainder of
    the entry is passed to `parse_from_yaml` to obtain that operator's mapping
    of `kernel keys` to `BackendMetadata`. Operators whose mapping comes back
    empty are left out of the index.
    """
    kernels_by_op: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
    for raw_entry in es:  # type: ignore[attr-defined]
        # Shallow copy so popping "func" does not alter the caller's entry.
        entry = raw_entry.copy()

        schema_str = entry.pop("func")
        assert isinstance(schema_str, str), f"not a str: {schema_str}"
        helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=schema_str, max_level=1
        )
        opname = FunctionSchema.parse(helper.entity_name).name

        assert (
            opname not in kernels_by_op
        ), f"Duplicate func found in yaml: {opname} already"

        kernel_mapping = parse_from_yaml(entry)
        if kernel_mapping:
            kernels_by_op[opname] = kernel_mapping

    return ETKernelIndex(kernels_by_op)


def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
    """Collect the `ET_FIELDS` values of every operator in a loaded yaml list.

    The result is keyed by the operator name parsed from each entry's `func`
    schema. Only fields whose value is not None are recorded, and an operator
    with no such field gets no entry at all.
    """
    per_op_fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
    for op_entry in es:  # type: ignore[attr-defined]
        schema_str = op_entry.get("func")
        assert isinstance(schema_str, str), f"not a str: {schema_str}"
        helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=schema_str, max_level=1
        )
        opname = FunctionSchema.parse(helper.entity_name).name

        for field in ET_FIELDS:
            value = op_entry.get(field)
            if value is None:
                continue
            per_op_fields[opname][field] = value

    return per_op_fields


def parse_et_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
    """Load the yaml at `path` and return `(native_functions, kernel_fields)`.

    `kernel_fields` is the operator-indexed dict of ET fields to persist from
    native_functions.yaml to functions.yaml, as produced by
    `extract_kernel_fields`. `native_functions` comes from `parse_native_yaml`,
    which is given the loaded entries after the ET fields have been removed.
    """
    with open(path) as yaml_file:
        loaded_entries = yaml.load(yaml_file, Loader=LineLoader)

    # Must run before stripping: the ET fields are read from the entries here.
    kernel_fields = extract_kernel_fields(loaded_entries)

    # Remove ET specific fields from entries for BC compatibility
    strip_et_fields(loaded_entries)

    parsed = parse_native_yaml(
        path,
        tags_yaml_path,
        ignore_keys,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=loaded_entries,
    )
    return parsed.native_functions, kernel_fields


def strip_et_fields(es: object) -> None:
    """Remove the ET specific fields (`ET_FIELDS`) from every operator entry
    of a loaded yaml list, for BC compatibility.

    The entries are modified in place; fields that are absent are ignored.
    """
    for op_entry in es:  # type: ignore[attr-defined]
        for et_key in ET_FIELDS:
            if et_key in op_entry:
                del op_entry[et_key]