# Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py
# The PyTorch implementation ignores the `schema` argument; this version is modified
# to pass the user-provided schema through to custom_op.

from typing import Optional, Callable, Iterable, Union

from torch.library import custom_op, CustomOpDef
from torch._library.triton import set_wrap_triton_enabled


def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
    # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to
    # False, it behaves like torch.library.custom_op instead: the operator is not decomposed,
    # so Inductor cannot trace inside it.
    allow_decomposition: bool = True,
) -> Callable:
    def dec(fn: Callable[..., object]) -> CustomOpDef:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton kernel, so
            # no need to go through HOP dispatch
            with set_wrap_triton_enabled(False):
                return fn(*args, **kwargs)

        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            # This is the only difference from the PyTorch implementation: the
            # user-provided schema is passed through instead of being ignored.
            schema=schema,
        )
        # Imported locally, as in the upstream implementation.
        from torch._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        if allow_decomposition:
            # We decompose the operator when FunctionalTensorMode is active.
            # The goal is to decompose the operator in AOTDispatcher.
            # - With torch.compile, this means that the backend (usually Inductor)
            #   can see a call to the triton kernel(s) and so it can directly optimize
            #   them by inlining them into the lowering process.
            def functional_decomp(  # type: ignore[no-untyped-def]
                mode, op, types, args, kwargs
            ):
                from torch.export._trace import custom_triton_ops_decomposition_disabled

                if custom_triton_ops_decomposition_disabled():
                    return mode.__torch_dispatch__(op, types, args, kwargs)
                else:
                    with mode:
                        return fn(*args, **kwargs)

            result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)

        return result

    if fn is None:
        return dec
    else:
        return dec(fn)
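

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the adapted upstream file): wraps a
# minimal Triton add kernel with an explicit schema, which this variant honors
# unlike torch.library.triton_op. The op name "mylib::add", the kernel, and the
# block size are hypothetical; running this requires triton and a CUDA device.
if __name__ == "__main__":
    import torch
    import triton
    import triton.language as tl
    from torch.library import wrap_triton

    @triton.jit
    def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    # Pass allow_decomposition=False here to get torch.library.custom_op-like
    # behavior, where the backend treats the op as opaque.
    @triton_op("mylib::add", mutates_args=(), schema="(Tensor x, Tensor y) -> Tensor")
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n = out.numel()
        grid = (triton.cdiv(n, 1024),)
        wrap_triton(_add_kernel)[grid](x, y, out, n, BLOCK_SIZE=1024)
        return out

    a = torch.randn(4096, device="cuda")
    b = torch.randn(4096, device="cuda")
    torch.testing.assert_close(add(a, b), a + b)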