from __future__ import annotations

import argparse
import os
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, TextIO, TYPE_CHECKING

import yaml

from torchgen import dest
from torchgen.api import cpp as aten_cpp
from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType
from torchgen.context import (
    method_with_native_function,
    method_with_nested_native_function,
    with_native_function_and_index,
)
from torchgen.executorch.api import et_cpp
from torchgen.executorch.api.custom_ops import (
    ComputeNativeFunctionStub,
    gen_custom_ops_registration,
)
from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature
from torchgen.executorch.api.unboxing import Unboxing
from torchgen.executorch.model import ETKernelIndex, ETKernelKey, ETParsedYaml
from torchgen.executorch.parse import ET_FIELDS, parse_et_yaml, parse_et_yaml_struct
from torchgen.gen import (
    get_custom_build_selector,
    get_native_function_declarations,
    get_native_function_declarations_from_ns_grouped_kernels,
    get_native_function_schema_registrations,
    LineLoader,
    parse_native_yaml,
)
from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DEFAULT_KERNEL_NAMESPACE,
    DispatchKey,
    FunctionSchema,
    Location,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
    Variant,
)
from torchgen.utils import (
    context,
    FileManager,
    make_file_manager,
    mapMaybe,
    NamespaceHelper,
)


if TYPE_CHECKING:
    from collections.abc import Sequence

    from torchgen.selective_build.selector import SelectiveBuilder


def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
    """
    A wrapper that effectively returns `sig.decl(include_context=True)`.
    For an ATen kernel, the ATen codegen knows nothing about the Executorch
    `contextArg`, so this wrapper prepends it to the argument list manually.
    """
    if isinstance(sig, ExecutorchCppSignature):
        return sig.decl()

    returns_type = aten_cpp.returns_type(sig.func.returns).cpp_type()
    cpp_args = [a.decl() for a in sig.arguments()]
    cpp_args_str = ", ".join([contextArg.decl()] + cpp_args)
    sig_decl = f"{returns_type} {sig.name()}({cpp_args_str})"
    return sig_decl
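

# For an ATen `CppSignature`, the wrapper above produces a declaration of roughly
# this shape (illustrative only; the exact context type and argument rendering
# come from `contextArg.decl()` and each argument's `decl()`):
#
#   at::Tensor & add_outf(torch::executor::KernelRuntimeContext & context,
#                         const at::Tensor & self, const at::Tensor & other,
#                         const at::Scalar & alpha, at::Tensor & out)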


def static_dispatch(
    sig: CppSignature | ExecutorchCppSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find the native kernel that implements it and
    generate a call to it. If no backend (or more than one backend) provides a
    kernel, the generated code asserts unreachable instead.
    A simplified version of register_dispatch_key.py.
    Arguments:
        sig: The CppSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch for.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);"
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    backends = [b for b in backend_indices if b.has_kernel(f)]
    static_block = None
    if len(backends) == 1:
        backend_metadata = backends[0].get_kernel(f)
        if backend_metadata:
            args = ", ".join(a.name for a in sig.arguments())
            static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});"
    else:
        static_block = f"""
ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}.");
    """
    return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    {static_block}
}}
"""


@dataclass(frozen=True)
class ComputeFunction:
    static_dispatch_backend_indices: list[BackendIndex]

    selector: SelectiveBuilder

    use_aten_lib: bool

    is_custom_op: Callable[[NativeFunction], bool]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        is_method_variant = False
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return None

        if Variant.function not in f.variants and Variant.method in f.variants:
            is_method_variant = True

        # The only other supported case is a function-only variant; anything else errors out.
        elif not (Variant.function in f.variants and Variant.method not in f.variants):
            raise Exception(
                f"Can't handle native function {f.func} with the following variant specification {f.variants}."
            )

        sig: CppSignature | ExecutorchCppSignature = (
            CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            if self.use_aten_lib
            else ExecutorchCppSignature.from_native_function(f)
        )
        if self.use_aten_lib and not self.is_custom_op(f):
            comma = ", "

            if is_method_variant:
                return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    return {sig.arguments()[0].name}.{sig.name()}({comma.join(e.name for e in sig.arguments()[1:])});
}}
"""
            else:
                return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    return at::{sig.name()}({comma.join(e.name for e in sig.arguments())});
}}
"""

        else:
            return static_dispatch(
                sig,
                f,
                backend_indices=self.static_dispatch_backend_indices,
            )
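

# Roughly, for a root operator with `use_aten_lib=True` and a function variant,
# ComputeFunction emits a thin inline wrapper of this shape (illustrative sketch,
# not verbatim codegen output; the signature comes from _sig_decl_wrapper above):
#
#   // aten::add.out
#   TORCH_API inline at::Tensor & add_outf(/* context + ATen arguments */) {
#       return at::add_outf(self, other, alpha, out);
#   }
#
# Without the ATen lib (or for custom ops) it falls back to static_dispatch() instead.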


@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    selector: SelectiveBuilder

    use_aten_lib: bool

    add_exception_boundary: bool

    @method_with_nested_native_function
    def __call__(
        self,
        unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
    ) -> str:
        f: NativeFunction = unbox_kernel_entry[0]
        kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0]
        kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]

        op_name = f"{f.namespace}::{f.func.name}"
        if not self.selector.is_root_operator(op_name):
            return ""

        if not isinstance(kernel_key, list):
            kernel_key = [kernel_key]
        used_kernel_keys = self.selector.et_get_selected_kernels(
            op_name, [k.to_native_string() for k in kernel_key]
        )
        if not used_kernel_keys:
            return ""
        sig: CppSignature | ExecutorchCppSignature
        argument_type_gen: Callable[..., NamedCType]
        return_type_gen: Callable[..., CType]
        if self.use_aten_lib:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            argument_type_gen = aten_cpp.argumenttype_type
            return_type_gen = aten_cpp.returns_type
            arguments = sig.arguments()
            kernel_call = f"torch::executor::{f.namespace}::{sig.name()}"
        else:
            sig = ExecutorchCppSignature.from_native_function(f)
            argument_type_gen = et_cpp.argumenttype_type
            return_type_gen = et_cpp.returns_type
            arguments = sig.arguments(include_context=False)
            kernel_call = f"{kernel_meta.cpp_namespace}::{kernel_meta.kernel}"

        binding_list, code_list = Unboxing(
            argument_type_gen=argument_type_gen
        ).convert_arguments(arguments)

        code_connector = "\n\t"
        arg_connector = ", "

        args_str = f"{arg_connector.join(e.name for e in binding_list)}"
        event_tracer_output_logging = ""
        output_ids = []

        if len(f.func.returns) == 0:
            if len(f.func.arguments.out) == 0:
                raise Exception(
                    f"Can't handle native function {f.func} with no returns and no out yet."
                )
            out = f.func.arguments.out[0]
            return_assignment = f"""stack[{len(binding_list)}] = &{out.name};"""
            ret_prefix = ""
            output_ids = [len(binding_list)]
        else:
            if len(f.func.arguments.out) == 0:
                return_assignment = (
                    f"""*stack[{len(binding_list)}] = EValue(result_);"""
                )
                ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = "
                output_ids = [len(binding_list)]
            else:
                return_assignment = ""
                ret_prefix = ""
                output_ids = [
                    len(binding_list) - (i + 1)
                    for i in reversed(range(len(f.func.arguments.out)))
                ]

        for output_id in output_ids:
            event_tracer_output_logging += (
                f"internal::event_tracer_log_evalue("
                f"context.internal_event_tracer(), "
                f"*stack[{output_id}]);\n"
            )

        exception_boundary_begin = ""
        exception_boundary_end = ""
        if self.add_exception_boundary:
            indent = " " * 8
            exception_boundary_begin = indent + "try {"
            exception_boundary_end = f"""{indent}}} catch (const std::exception& ex) {{
{indent} ET_LOG(Error, "Kernel threw an exception: %s", ex.what());
{indent} context.fail(torch::executor::Error::Internal);
{indent}}}"""
        newline = "\n    "
        return "\n".join(
            [
                f"""
Kernel(
    "{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != "default" else ""}
    []({contextArg.defn()}, EValue** stack) {{
        {code_connector.join(code_list)}

        {exception_boundary_begin}
        internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_{f.func.name}");
        EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
        {ret_prefix}{kernel_call}(context, {args_str});
        {event_tracer_output_logging}
        {return_assignment}
        {exception_boundary_end}
    }}
),
"""
                for k in used_kernel_keys
            ]
        )
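

# Each entry produced above registers one unboxing lambda. The generated C++ looks
# roughly like this (illustrative sketch only; argument unboxing, tracing, and the
# kernel namespace all depend on the yaml entry being processed):
#
#   Kernel(
#       "aten::add.out",
#       "<kernel key>",   // this line is omitted when the key is "default"
#       [](torch::executor::KernelRuntimeContext & context, EValue** stack) {
#           // ...unboxing code for each argument, generated by Unboxing()...
#           EXECUTORCH_SCOPE_PROF("native_call_add.out");
#           torch::executor::native::add_out(context, self, other, alpha, out);
#           // ...event tracer logging and (for non-out ops) writing the result back to `stack`...
#       }
#   ),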


def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    kernel_index: ETKernelIndex,
    manual_registration: bool,
    add_exception_boundary: bool = False,
) -> None:
    # Each work item pairs a NativeFunction with one (kernel key, kernel metadata) entry.
    def key_func(
        item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
    ) -> str:
        return item[0].root_name + ":" + item[1][0].to_native_string()

    items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [
        (native_function, (kernel_key, metadata))
        for native_function in native_functions
        for kernel_key, metadata in kernel_index.get_kernels(native_function).items()
    ]

    header = ["Functions.h" if use_aten_lib else "NativeFunctions.h"]
    filename = (
        "RegisterKernels.cpp"
        if manual_registration
        else "RegisterCodegenUnboxedKernels.cpp"
    )
    cpu_fm.write_sharded(
        filename,
        items,
        key_fn=key_func,
        env_callable=lambda unbox_kernel_entry: {
            "unboxed_kernels": [
                ComputeCodegenUnboxedKernels(
                    selector, use_aten_lib, add_exception_boundary
                )(unbox_kernel_entry)
            ],
            # Only emit the header include for the first entry; later entries share it.
            "fn_header": header if unbox_kernel_entry == items[0] else [],
        },
        num_shards=1,
        sharded_keys={"unboxed_kernels", "fn_header"},
    )


@with_native_function_and_index
def compute_native_function_declaration(
    g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex
) -> list[str]:
    assert isinstance(g, NativeFunction)
    sig = ExecutorchCppSignature.from_native_function(f=g)
    metadata_list = kernel_index.get_kernels(g).values()
    if metadata_list is None:
        return []

    # Emit each kernel declaration twice: once without and once with the runtime
    # context argument.
    def gen_decl(metadata: BackendMetadata, include_context: bool) -> str:
        return f"{sig.decl(name=metadata.kernel, include_context=include_context)};"

    return [
        gen_decl(metadata, include_context)
        for include_context in [False, True]
        for metadata in metadata_list
    ]


def gen_functions_declarations(
    *,
    native_functions: Sequence[NativeFunction],
    kernel_index: ETKernelIndex,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    custom_ops_native_functions: Sequence[NativeFunction] | None = None,
) -> str:
    """
    Generates namespace-separated C++ function API inline declarations/definitions.
    Native functions are grouped by namespace, and the generated code is wrapped
    inside namespace blocks.

    E.g., for `custom_1::foo.out` in the yaml file we will generate a C++ API as a
    symbol in `torch::executor::custom_1::foo_out`. This way we avoid a symbol
    conflict when another `custom_2::foo.out` is available.
    """

    # Convert the ET kernel index to a plain BackendIndex; static dispatch below
    # works in terms of BackendIndex.
    backend_index = kernel_index._to_backend_index()

    ns_grouped_functions = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_functions[native_function.namespace].append(native_function)
    functions_declarations = ""
    newline = "\n"
    for namespace in ns_grouped_functions:
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=3,
        )
        declarations = list(
            mapMaybe(
                ComputeFunction(
                    static_dispatch_backend_indices=[backend_index],
                    selector=selector,
                    use_aten_lib=use_aten_lib,
                    is_custom_op=lambda f: custom_ops_native_functions is not None
                    and f in custom_ops_native_functions,
                ),
                ns_grouped_functions[namespace],
            )
        )
        functions_declarations += f"""
{ns_helper.prologue}
{newline.join(declarations)}
{ns_helper.epilogue}
"""
    return functions_declarations


def get_ns_grouped_kernels(
    *,
    native_functions: Sequence[NativeFunction],
    kernel_index: ETKernelIndex,
    native_function_decl_gen: Callable[
        [
            NativeFunctionsGroup | NativeFunction,
            ETKernelIndex,
        ],
        list[str],
    ],
) -> dict[str, list[str]]:
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    for f in native_functions:
        native_function_namespaces = set()
        op_kernels = kernel_index.get_kernels(f)
        for backend_metadata in op_kernels.values():
            if backend_metadata:
                namespace = backend_metadata.cpp_namespace
                native_function_namespaces.add(namespace)
            else:
                namespace = DEFAULT_KERNEL_NAMESPACE
            assert len(native_function_namespaces) <= 1, (
                f"Codegen only supports one namespace per operator, got {native_function_namespaces}"
            )
            ns_grouped_kernels[namespace].extend(
                native_function_decl_gen(f, kernel_index)
            )
    return ns_grouped_kernels


def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    gen_custom_ops_header: bool,
    custom_ops_native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    cpu_fm: FileManager,
    use_aten_lib: bool,
) -> None:
    """Generate headers.

    Args:
        native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops.
        gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h
        custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops.
        selector (SelectiveBuilder): for selective build.
        kernel_index (ETKernelIndex): kernel collection
        cpu_fm (FileManager): file manager that manages the output stream
        use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types.
    """
    aten_headers = ["#include <ATen/Functions.h>"]
    backend_indices = {DispatchKey.CPU: kernel_index._to_backend_index()}
    if gen_custom_ops_header:
        cpu_fm.write_with_template(
            "CustomOpsNativeFunctions.h",
            "NativeFunctions.h",
            lambda: {
                "nativeFunctions_declarations": get_native_function_declarations(
                    grouped_native_functions=custom_ops_native_functions,
                    backend_indices=backend_indices,
                    native_function_decl_gen=dest.compute_native_function_declaration,
                ),
                "headers": [
                    "#include <ATen/ATen.h>",
                    "#include <torch/torch.h>",
                ],
            },
        )
        aten_headers.append('#include "CustomOpsNativeFunctions.h"')
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": aten_headers
            if use_aten_lib
            else ['#include "NativeFunctions.h"'],
            "Functions_declarations": gen_functions_declarations(
                native_functions=native_functions,
                kernel_index=kernel_index,
                selector=selector,
                use_aten_lib=use_aten_lib,
                custom_ops_native_functions=custom_ops_native_functions,
            ),
        },
    )
    cpu_fm.write(
        "RegisterKernels.h",
        lambda: {
            "generated_comment": "@" + "generated by torchgen/gen_executorch.py",
        },
    )
    headers = {
        "headers": [
            "#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.",
            "#include <executorch/runtime/kernel/kernel_runtime_context.h>",
        ],
    }
    if use_aten_lib:
        headers["headers"].append("#include <executorch/codegen/macros.h> // TORCH_API")
        cpu_fm.write(
            "NativeFunctions.h",
            lambda: dict(
                {
                    "nativeFunctions_declarations": get_native_function_declarations(
                        grouped_native_functions=native_functions,
                        backend_indices=backend_indices,
                        native_function_decl_gen=dest.compute_native_function_declaration,
                    ),
                },
                **headers,
            ),
        )
    else:
        ns_grouped_kernels = get_ns_grouped_kernels(
            native_functions=native_functions,
            kernel_index=kernel_index,
            native_function_decl_gen=compute_native_function_declaration,
        )
        cpu_fm.write(
            "NativeFunctions.h",
            lambda: dict(
                {
                    "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels(
                        ns_grouped_kernels=ns_grouped_kernels,
                    ),
                },
                **headers,
            ),
        )


def gen_custom_ops(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    cpu_fm: FileManager,
    rocm: bool,
) -> None:
    dispatch_key = DispatchKey.CPU
    (
        anonymous_definition,
        static_init_dispatch_registrations,
    ) = gen_custom_ops_registration(
        native_functions=native_functions,
        selector=selector,
        kernel_index=kernel_index,
        rocm=rocm,
    )
    cpu_fm.write_with_template(
        f"Register{dispatch_key}CustomOps.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": '#include "CustomOpsNativeFunctions.h"',
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": anonymous_definition,
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )
    cpu_fm.write_with_template(
        f"Register{dispatch_key}Stub.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": "",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": list(
                mapMaybe(ComputeNativeFunctionStub(), native_functions)
            ),
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions,
        schema_selector=selector,
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "schema_registrations": schema_registrations,
            "aten_schema_registrations": aten_schema_registrations,
        },
    )


def translate_native_yaml(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: str | None,
    use_aten_lib: bool,
    out_file: TextIO,
) -> None:
    """Translates the Executorch DSL dialect into the same syntax as
    native_functions.yaml. The major difference is that the Executorch DSL dialect
    supports an "op" key, which refers to the operator name in native_functions.yaml.

    For example, a functions.yaml may have the following entry:

    - op: add.out
      ...

    It needs to be translated to the following:

    - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
      ...

    We look up the operator schema for "add.out" in aten_yaml_path and add it
    to the original functions.yaml. We also add the required field "variants",
    which for Executorch is always "function".

    For ATen mode we don't have to do the translation because native_yaml_path is
    the same as native_functions.yaml.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to the ATen operator yaml file native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
        out_file: The IO object that we are writing into.
    Returns:
        None
    """
    if use_aten_lib:
        with open(aten_yaml_path) as aten_yaml:
            out_file.writelines(aten_yaml.readlines())
        return

    native_functions, persisted_fields = parse_et_yaml(
        aten_yaml_path,
        tags_yaml_path,
        None,
        skip_native_fns_gen=False,
    )

    func_to_scoped_name: dict[FunctionSchema, str] = {
        f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
    }
    op_to_scoped_name: dict[OperatorName, str] = {
        func.name: name for func, name in func_to_scoped_name.items()
    }

    schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
    kernel_persist_dict: dict[str, dict[str, Any]] = {
        op_to_scoped_name[op]: v for op, v in persisted_fields.items()
    }

    if (
        not native_yaml_path
        or not os.path.exists(native_yaml_path)
        or os.stat(native_yaml_path).st_size == 0
    ):
        return
    with open(native_yaml_path) as native_yaml:
        native_es = yaml.load(native_yaml, Loader=LineLoader)
        if not native_es:
            return
        for e in native_es:
            assert isinstance(e.get("__line__"), int), e
            loc = Location(native_yaml_path, e.pop("__line__"))
            with context(lambda: f"in {loc}:\n  "):
                if "variants" not in e:
                    e["variants"] = "function"
                if "func" in e:
                    continue
                assert isinstance(e.get("op"), str), e
                opname = e.pop("op")
                if "::" not in opname:
                    opname = "aten::" + opname
                assert opname in schema_dict
                e["func"] = schema_dict.get(opname)

                # Write out persisted kernel information for this operator, if any.
                if opname in kernel_persist_dict:
                    for k, v in kernel_persist_dict[opname].items():
                        e[k] = v

        yaml.dump(native_es, out_file, width=1000)
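

# A minimal sketch of the translation above (illustrative only; the schema string
# is looked up from native_functions.yaml, and persisted ET-only fields such as
# `kernels` are copied through unchanged):
#
#   - op: add.out           # input functions.yaml entry (Executorch DSL)
#     kernels: ...
#
# becomes
#
#   - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#     variants: function    # added when missing
#     kernels: ...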


def parse_yaml(
    path: str | None,
    tags_yaml_path: str,
    function_filter: Callable[[NativeFunction], bool],
    skip_native_fns_gen: bool = False,
) -> tuple[
    list[NativeFunction],
    dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex,
]:
    if path and os.path.exists(path) and os.stat(path).st_size > 0:
        with open(path) as f:
            es = yaml.load(f, Loader=LineLoader)

        # Check whether the yaml uses the ET kernel index structure.
        kernel_index = (
            parse_et_yaml_struct(es) if any("kernels" in e for e in es) else None
        )

        # Remove ET-specific fields from the entries before handing them to the
        # regular native yaml parser.
        for entry in es:
            for field in ET_FIELDS:
                entry.pop(field, None)

        parsed_yaml = parse_native_yaml(
            path,
            tags_yaml_path,
            None,
            skip_native_fns_gen=skip_native_fns_gen,
            loaded_yaml=es,
        )
        native_functions = list(filter(function_filter, parsed_yaml.native_functions))
        op_names = [f.func.name for f in native_functions]

        # (1) Return an ETKernelIndex if a kernel index is present.
        if kernel_index is not None:
            filtered_index = {
                op_name: kernel_mapping
                for op_name, kernel_mapping in kernel_index.index.items()
                if op_name in op_names
            }
            return native_functions, ETKernelIndex(index=filtered_index)

        # (2) Otherwise return plain BackendIndices.
        def map_index(
            m: dict[OperatorName, BackendMetadata],
        ) -> dict[OperatorName, BackendMetadata]:
            return {op: m[op] for op in m if op in op_names}

        backend_indices = {
            k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items()
        }

        return native_functions, backend_indices
    else:
        return [], {}


def parse_yaml_files(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: str | None,
    custom_ops_yaml_path: str | None,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> tuple[ETParsedYaml, ETParsedYaml | None]:
    """Parses functions.yaml and custom_ops.yaml files.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to the ATen operator yaml file native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If
            the path does not exist in the filesystem, it is ignored.
        selector: For selective build.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
    Returns:
        A tuple with two elements:
        [0]: The parsed results of concatenating the contents of
             `native_yaml_path` and `custom_ops_yaml_path`.
        [1]: The parsed results of the contents of `custom_ops_yaml_path`, if
             present. If not present, None.
    """
    import tempfile

    def function_filter(f: NativeFunction) -> bool:
        return selector.is_native_function_selected(f)

    with tempfile.TemporaryDirectory() as tmpdirname:
        translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
        with open(translated_yaml_path, "w") as translated:
            translate_native_yaml(
                tags_yaml_path,
                aten_yaml_path,
                native_yaml_path,
                use_aten_lib,
                translated,
            )

        translated_functions, translated_indices = parse_yaml(
            translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib
        )
        custom_ops_functions, custom_ops_indices = parse_yaml(
            custom_ops_yaml_path, tags_yaml_path, function_filter, True
        )

        # Normalize both results to ETKernelIndex before merging.
        if not isinstance(translated_indices, ETKernelIndex):
            translated_indices = ETKernelIndex.from_backend_indices(translated_indices)
        if not isinstance(custom_ops_indices, ETKernelIndex):
            custom_ops_indices = ETKernelIndex.from_backend_indices(custom_ops_indices)

        combined_functions = translated_functions + custom_ops_functions
        combined_kernel_index = ETKernelIndex.merge_indices(
            translated_indices, custom_ops_indices
        )
        combined_yaml = ETParsedYaml(combined_functions, combined_kernel_index)
        custom_ops_parsed_yaml = ETParsedYaml(custom_ops_functions, custom_ops_indices)

        return combined_yaml, custom_ops_parsed_yaml


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate operator source files")

    # Although we don't refer to --source-path directly, make_file_manager()
    # expects it to point at the directory containing the codegen templates.
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for kernel templates",
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--custom-ops-yaml-path",
        "--custom_ops_yaml_path",
        help="path to the custom_ops.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--aten-yaml-path",
        "--aten_yaml_path",
        help="path to native_functions.yaml file.",
    )

    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/generated",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--tags-path",
        help="Path to tags.yaml. Required by yaml parsing in codegen system.",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--use-aten-lib",
        "--use_aten_lib",
        action="store_true",
        help="a boolean flag to indicate whether we use ATen kernels or not; "
        "in the future this flag will be per operator",
    )
    parser.add_argument(
        "--manual_registration",
        "--manual-registration",
        action="store_true",
        help="a boolean flag to indicate whether we want to manually call "
        "register_kernels() or rely on static init.",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources"],
        default=["headers", "sources"],
        help="Generate only a subset of files",
    )
    parser.add_argument(
        "--add-exception-boundary",
        "--add_exception_boundary",
        action="store_true",
        help="whether to add a try/catch in the generated kernel wrapper to "
        "convert exceptions to clean failures.",
    )
    options = parser.parse_args()
    assert options.tags_path, "tags.yaml is required by codegen yaml parsing."

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
        aten_yaml_path=options.aten_yaml_path,
        tags_yaml_path=options.tags_path,
        native_yaml_path=options.functions_yaml_path,
        custom_ops_yaml_path=options.custom_ops_yaml_path,
        selector=selector,
        use_aten_lib=options.use_aten_lib,
    )
    native_functions, kernel_index = (
        parsed_yaml.native_functions,
        parsed_yaml.kernel_index,
    )
    custom_ops_native_functions = (
        custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else []
    )

    cpu_fm = make_file_manager(options=options)

    if "headers" in options.generate:
        # CustomOpsNativeFunctions.h is only generated when a custom_ops.yaml is provided.
        gen_headers(
            native_functions=native_functions,
            gen_custom_ops_header=options.custom_ops_yaml_path,
            custom_ops_native_functions=custom_ops_native_functions,
            selector=selector,
            kernel_index=kernel_index,
            cpu_fm=cpu_fm,
            use_aten_lib=options.use_aten_lib,
        )

    if "sources" in options.generate:
        gen_unboxing(
            native_functions=native_functions,
            cpu_fm=cpu_fm,
            selector=selector,
            use_aten_lib=options.use_aten_lib,
            kernel_index=kernel_index,
            manual_registration=options.manual_registration,
            add_exception_boundary=options.add_exception_boundary,
        )
        if custom_ops_native_functions:
            gen_custom_ops(
                native_functions=custom_ops_native_functions,
                selector=selector,
                kernel_index=kernel_index,
                cpu_fm=cpu_fm,
                rocm=options.rocm,
            )

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        for fm, prefix in [
            (cpu_fm, ""),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
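

# Example invocation (illustrative; the module path and every file path below are
# placeholders, not taken from a real build):
#
#   python -m torchgen.gen_executorch \
#       --source-path <codegen-source-dir> \
#       --aten-yaml-path <path-to>/native_functions.yaml \
#       --tags-path <path-to>/tags.yaml \
#       --functions-yaml-path <path-to>/functions.yaml \
#       --custom-ops-yaml-path <path-to>/custom_ops.yaml \
#       --install-dir build/generated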


if __name__ == "__main__":
    main()