from __future__ import annotations

import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.types import kernel_signature
from torchgen.context import with_native_function_and_index
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
from torchgen.utils import mapMaybe


def torch_api_key_word_prefix(backend_index: BackendIndex) -> str:
    if backend_index.external:
        return ""

    # Although the Intel GPU ATen library is out-of-tree, it still uses torchgen to produce
    # structured kernels, and those generated kernels must be visible to the Intel GPU ATen
    # library. We therefore prefix them with "TORCH_XPU_API" rather than "TORCH_API", because
    # "TORCH_API" means "hidden" for out-of-tree backends. In-tree backends such as CPU and
    # CUDA keep the "TORCH_API" prefix, which for them means "visible".
    device_torch_api_key_word_mapping = {
        "XPU": "TORCH_XPU_API",
    }

    return (
        device_torch_api_key_word_mapping.get(
            backend_index.dispatch_key.name, "TORCH_API"
        )
        + " "
    )


@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
    sig = kernel_signature(f, backend_index)
    metadata = backend_index.get_kernel(f)
    if metadata is None:
        return None
    if "legacy::" in metadata.kernel:
        return None
    else:
        prefix = "static" if backend_index.external else "TORCH_API"
        return f"{prefix} {sig.decl(name=metadata.kernel)};"


@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
    meta_name = meta.name(g)
    out_args = structured.impl_arguments(g)
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    prefix = torch_api_key_word_prefix(backend_index)
    return [
        f"""\
struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
void impl({", ".join(a.decl() for a in out_args)});
}};
"""
    ]


# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
    g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> list[str]:
    metadata = backend_index.get_kernel(g)
    if isinstance(g, NativeFunctionsGroup):
        if metadata is not None and metadata.structured:
            if backend_index.external:
                # Structured hasn't been tested with external backends yet.
                raise AssertionError(
                    "Structured external backend functions are not implemented yet."
                )
            else:
                return gen_structured(g, backend_index)
        else:
            return list(
                mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
            )
    else:
        x = gen_unstructured(g, backend_index)
        return [] if x is None else [x]
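

# Illustrative sketch (not part of the generator itself): assuming a structured
# "add.out" entry whose in-tree CPU/CUDA kernel is named "add_out", gen_structured
# would emit a forward declaration into NativeFunctions.h roughly like:
#
#   struct TORCH_API structured_add_out : public at::meta::structured_add_Tensor {
#     void impl(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out);
#   };
#
# while gen_unstructured would emit a plain kernel declaration such as
#
#   TORCH_API at::Tensor relu(const at::Tensor & self);
#
# For an XPU BackendIndex, torch_api_key_word_prefix swaps "TORCH_API" for
# "TORCH_XPU_API" in the structured case above.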