File size: 11,687 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 |
import contextlib
import threading
from collections.abc import Generator, Iterable
from typing import Any, Callable, Optional, Union
from torch.utils._exposed_in import exposed_in
from .custom_ops import custom_op, CustomOpDef
from .infer_schema import infer_schema
@exposed_in("torch.library")
def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
    """Create a custom operator whose implementation is backed by 1+ triton kernels.

    This is a more structured way of using triton kernels with PyTorch.
    Prefer using triton kernels with no ``torch.library`` custom operator wrappers
    (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because
    that is simpler;
    only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you
    want to create an operator that behaves like PyTorch built-in operators.
    For example, you may use a ``torch.library`` wrapper API to define the
    behavior of the triton kernel when passed a tensor subclass or under
    a TorchDispatchMode.

    Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op`
    when the implementation
    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
    custom operators as opaque (:func:`torch.compile` and
    :func:`torch.export.export` will never trace into them), but ``triton_op``
    makes the implementation visible to these subsystems, allowing them
    to optimize the triton kernel(s).

    Note that ``fn`` must only consist of calls to PyTorch-understood
    operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch.library.wrap_triton`.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        fn (Callable, optional): The function implementing the operator
            (positional-only). When omitted, ``triton_op`` returns a decorator
            that expects the function as its single argument.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".

    Returns:
        The registered ``CustomOpDef`` when ``fn`` is given; otherwise a
        decorator that produces one when applied to a function.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch.library import triton_op, wrap_triton
        >>>
        >>> import triton
        >>> from triton import language as tl
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> @triton_op("mylib::add", mutates_args={})
        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # NB: we need to wrap the triton kernel in a call to wrap_triton
        >>>     wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> @torch.compile
        >>> def f(x, y):
        >>>     return add(x, y)
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>>
        >>> z = f(x, y)
        >>> assert torch.allclose(z, x + y)

    """

    def dec(fn: Callable[..., object]) -> CustomOpDef:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton kernel, so
            # no need to go through HOP dispatch
            with set_wrap_triton_enabled(False):
                return fn(*args, **kwargs)

        # Honor an explicitly supplied schema; otherwise infer one. Inference
        # must look at the user's ``fn`` (not ``backend_fn``), because
        # ``backend_fn`` only exposes ``*args, **kwargs`` and carries no
        # type annotations.
        op_schema = (
            schema
            if schema is not None
            else infer_schema(fn, mutates_args=mutates_args)
        )
        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            schema=op_schema,
        )
        from .._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        # We decompose the operator when FunctionalTensorMode is active.
        # The goal is to decompose the operator in AOTDispatcher.
        # - With torch.compile, this means that the backend (usually Inductor)
        #   can see a call to the triton kernel(s) and so it can directly optimize
        #   them by inlining them into the lowering process.
        def functional_decomp(  # type: ignore[no-untyped-def]
            mode, op, types, args, kwargs
        ):
            # NOTE [Export custom triton op]
            # For torch.export (strict and non-strict), we don't do functional decomposition.
            # Instead, we preserve the custom triton ops as custom ops. This is because we want
            # the exported program to be high-level and serializable. If we decompose
            # the custom op to a functional hop and make it a node in exported program,
            # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited
            # functions and triton dtypes. This is undesirable because:
            # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes.
            # - exported program will contain the implementation detail (e.g. triton source code) for a specific
            #   backend (GPU), which is probably at a wrong level of abstraction.
            # - changes to triton or the serialization logic for triton arguments can be BC breaking
            #
            # In the short term, we expect users to have a separate aot_compile stage that compiles the exported program
            # into a Cubin file on the same machine that users call export, which does autotuning and removes triton
            # dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.
            # In the long term, we may export multiple cubins for the triton op directly
            from torch.export._trace import custom_triton_ops_decomposition_disabled

            if custom_triton_ops_decomposition_disabled():
                # Keep the op as an opaque custom op in the exported graph.
                return mode.__torch_dispatch__(op, types, args, kwargs)
            else:
                # Re-run the user's function under the mode so its contents
                # are traced instead of the opaque op.
                with mode:
                    return fn(*args, **kwargs)

        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    # Support both ``triton_op(name, fn, ...)`` and ``@triton_op(name, ...)``.
    if fn is None:
        return dec
    else:
        return dec(fn)
# Fallback used by any thread that has not yet set ``wrap_triton_enabled.value``.
wrap_triton_enabled_default = True

# Per-thread holder for the switch; the current setting lives in its ``value``
# attribute, which is absent until a thread first assigns it.
wrap_triton_enabled = threading.local()
@contextlib.contextmanager
def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]:
    """If triton kernels annotated with @wrap_triton should dispatch via HOP
    or go straight to the triton kernel execution.

    We have this switch because eager-mode performance of HOP dispatch is slow
    enough to matter (~1ms) and we know that wrap_triton isn't necessary in
    some situations (eager-mode with regular Tensors)

    Args:
        enabled: Value the switch takes for the duration of the ``with`` body.

    On exit (normal or via an exception) the switch is reset to the value that
    was in effect when the context was entered.
    """
    # Capture and install the state *before* entering ``try`` so the
    # ``finally`` clause can never run with ``prev`` unbound.
    prev = is_wrap_triton_enabled()
    wrap_triton_enabled.value = enabled
    try:
        yield
    finally:
        wrap_triton_enabled.value = prev
def is_wrap_triton_enabled() -> bool:
    """Return the current thread's wrap_triton switch.

    Falls back to ``wrap_triton_enabled_default`` when the current thread has
    not assigned ``wrap_triton_enabled.value``.
    """
    try:
        return wrap_triton_enabled.value
    except AttributeError:
        # No value stored on this thread's local yet.
        return wrap_triton_enabled_default
def capture_triton(triton_kernel: Callable, /) -> Any:
    """Legacy name for ``wrap_triton``; this API has been renamed.

    Forwards ``triton_kernel`` to ``wrap_triton`` and returns its result
    unchanged.
    """
    wrapped = wrap_triton(triton_kernel)
    return wrapped
@exposed_in("torch.library")
def wrap_triton(triton_kernel: Callable, /) -> Any:
    """Allows capture of a triton kernel into a graph via make_fx or
    non-strict ``torch.export``.

    These technologies perform Dispatcher-based tracing (via
    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
    The ``wrap_triton`` API wraps a triton kernel into a callable that
    can actually be traced into a graph.

    Please use this API together with :func:`torch.library.triton_op`.

    Args:
        triton_kernel: A ``triton.jit`` function or ``triton.autotune``
            autotuner (positional-only).

    Returns:
        A ``TraceableTritonKernelWrapper`` around ``triton_kernel``, or
        ``triton_kernel`` itself when the wrap_triton switch is disabled for
        the current thread.

    Raises:
        RuntimeError: If ``triton_kernel`` is neither a ``JITFunction`` nor an
            ``Autotuner``.

    Examples:

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch.library import wrap_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like

    """
    # Imported lazily, inside the function, in the same order as before.
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    accepted_kernel_types = (JITFunction, Autotuner)
    if isinstance(triton_kernel, accepted_kernel_types):
        if is_wrap_triton_enabled():
            return TraceableTritonKernelWrapper(triton_kernel, None, None)
        # Switch is off: hand back the raw kernel so it launches directly.
        return triton_kernel

    raise RuntimeError(
        "wrap_triton only works on functions annotated with triton.jit or triton.autotune"
    )
|