# mypy: allow-untyped-defs
import re

import torch._C as C


"""
PythonDispatcher class is a thin python-binding to C++ dispatcher and it
is designed to show how dispatcher precompute works. In particular,
it shows for a certain op `foo`, what the computed dispatch table looks
like after user register their kernels to certains dispatch keys.

In the real C++ dispatcher we support many dispatch keys for different
functionalities. For simplicity PythonDispatcher only supports dispatch
keys for a single example of each use case. These use cases are listed below:

- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference &
    autograd kernel in pytorch core library.
    E.g. CPU, CUDA
- FPGA/AutogradOther: represents in-tree backends which we usually have backend specific
    inference kernels, but they share the same autograd kernel specified in AutogradOther.
    E.g. FPGA, SparseCsrCPU
- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd
    kernel defined in pytorch core library. Backend owner is responsible for registering both
    inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support.
    E.g. XLA, XPU, MPS
- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc.
    Kernels registered to this key MUST work for inference for all backends.
- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther.
    Kernels registered to this key MUST work for autograd for all backends.
- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
    Kernels registered to this key MUST work for both inference + autograd for all backends.

Note we only allow registrations to alias keys inside pytorch core library. E.g
you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
kernel from torch-xla extension, instead you should upstream the kernel into
pytorch/pytorch repo so that it's available for all backends and continuously
tested even without the extension.

Usage:
  dispatcher = PythonDispatcher()
  dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
  print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend.
  # For more debugging information
  # print(dispatcher.keys())
  # print(dispatcher.registrations())
  # print(dispatcher.rawRegistrations())
  # print(dispatcher.rawDispatchTable())
PythonDispatcher calls C++ dispatcher under the hood for to precompute dispatch table.
This file only provides the simplified API for developers, relevant test code is located in
test/test_dispatch.py
"""


class PythonDispatcher:
    namespace = "__test__"
    name = "foo"
    # fmt: off
    runtime_keys = [
        "CPU", "AutogradCPU",
        "FPGA", "AutogradOther",
        "XLA", "AutogradXLA",
        "Lazy", "AutogradLazy",
    ]
    # fmt: on
    alias_keys = [
        "CompositeExplicitAutograd",
        "Autograd",
        "CompositeImplicitAutograd",
    ]
    supported_keys = runtime_keys + alias_keys

    def __init__(self) -> None:
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
        self.ref.def_("foo(Tensor x) -> Tensor")

    """
    Returns a list of dispatch keys supported by PythonDispatcher.
    You can register kernels to these keys.
    """

    def keys(self):
        return self.supported_keys

    """
    Register kernels to the target dispatchKeys.
    dispatchKeys(list[str]): a list of dispatch keys that you want to register
      your own kernel. Note that you don't need to write the kernel yourself in
      this PythonDispatcher.E.g. for CPU key, a kernel(e.g fn_CPU for CPU) is
      automatically generated and registered.
    """

    def register(self, dispatchKeys):
        # Overriden is not supported and triggers a warning in C++ dispatcher.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(
                f"Overriden is not allowed but found duplicates in {dispatchKeys}."
            )
        # We currently forbid this in codegen instead of C++ dispatcher.
        if (
            "CompositeImplicitAutograd" in dispatchKeys
            and "CompositeExplicitAutograd" in dispatchKeys
        ):
            raise RuntimeError(
                "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
            )
        for key in dispatchKeys:
            if key not in self.supported_keys:
                raise RuntimeError(
                    f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
                )
            self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key)

    """
    Helper function to format (key, kernel).
    """

    def _format_line(self, key, kernel):
        return f"{key:<15} {kernel}\n"

    """
    Helper function to print a table header.
    """

    def _format_header(self, header):
        s = f"""
{header}
"""
        s += self._format_line("key", "kernel")
        s += "---------------------------\n"
        return s

    """
    Returns raw output of all registration info for debugging only.
    Use registrations() for a simplified version.
    """

    def rawRegistrations(self):
        return C._dispatch_dump(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    """
    Returns raw output of computed dispatch table for debugging only.
    Use dispatchTable() for a simplified version.
    """

    def rawDispatchTable(self):
        return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    """
    Returns a table(str) including all the registrations from users.
    Note this includes registrations to both runtime keys and alias keys.
    """

    def registrations(self):
        output = self._format_header("Registered Kernels")
        state = self.rawRegistrations()
        state_entries = state.split("\n")
        for line in state_entries:
            first = line.split(":")[0]
            if any(first.startswith(k) for k in self.supported_keys):
                kernel = line.split("::")[0].split(" ")[1]
                output += self._format_line(first, kernel)
        return output

    """
    Returns the computed dispatch table(str). Note this only include
    runtime keys, registrations to alias keys have been decoded to their
    mapped runtime keys.
    """

    def dispatchTable(self):
        output = self._format_header("Computed Dispatch Table")
        table = self.rawDispatchTable()
        table_entries = table.split("\n")
        regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
        for line in table_entries:
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = regex.sub("[", line)
                output += self._format_line(k, entry.split(": ")[1])
        return output