|
|
|
import contextlib |
|
|
|
import torch |
|
|
|
|
|
@contextlib.contextmanager |
|
def optimized_execution(should_optimize): |
|
"""Context manager that controls whether the JIT's executor will run optimizations before executing a function.""" |
|
stored_flag = torch._C._get_graph_executor_optimize() |
|
torch._C._set_graph_executor_optimize(should_optimize) |
|
try: |
|
yield |
|
finally: |
|
torch._C._set_graph_executor_optimize(stored_flag) |
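
# A minimal usage sketch for ``optimized_execution`` (``scripted_fn`` and ``x`` are
# hypothetical names for an existing scripted function and an example input):
#
#     with torch.jit.optimized_execution(False):
#         out = scripted_fn(x)  # runs without graph-executor optimizations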
|
|
|
|
|
@contextlib.contextmanager |
|
def fuser(name): |
|
"""Context manager that facilitates switching between backend fusers. |
|
|
|
    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
    * ``fuser3`` - enables oneDNN Graph
    * ``none`` - disables all fusers
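
    A usage sketch (``scripted_fn`` and ``x`` are assumed to be an existing
    scripted function and an example input)::

        with torch.jit.fuser("fuser1"):
            out = scripted_fn(x)  # fused by NNC only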
|
""" |
|
old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() |
|
old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() |
|
old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() |
|
old_nvfuser_state = torch._C._jit_nvfuser_enabled() |
|
old_llga_state = torch._C._jit_llga_enabled() |
|
if name == "fuser0": |
|
torch._C._jit_override_can_fuse_on_cpu(True) |
|
torch._C._jit_override_can_fuse_on_gpu(True) |
|
torch._C._jit_set_texpr_fuser_enabled(False) |
|
torch._C._jit_set_nvfuser_enabled(False) |
|
torch._C._jit_set_llga_enabled(False) |
|
elif name == "fuser1": |
|
        # NNC needs the profiling executor with graph optimizations turned on;
        # both calls return the previous setting while installing the new one
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
|
torch._C._jit_override_can_fuse_on_cpu(True) |
|
torch._C._jit_override_can_fuse_on_gpu(True) |
|
torch._C._jit_set_texpr_fuser_enabled(True) |
|
torch._C._jit_set_nvfuser_enabled(False) |
|
torch._C._jit_set_llga_enabled(False) |
|
elif name == "fuser2": |
|
torch._C._jit_override_can_fuse_on_cpu(False) |
|
torch._C._jit_override_can_fuse_on_gpu(False) |
|
torch._C._jit_set_texpr_fuser_enabled(False) |
|
torch._C._jit_set_nvfuser_enabled(True) |
|
torch._C._jit_set_llga_enabled(False) |
|
elif name == "fuser3": |
|
        # oneDNN Graph likewise requires the profiling executor with graph optimizations enabled
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
|
torch._C._jit_override_can_fuse_on_cpu(True) |
|
torch._C._jit_override_can_fuse_on_gpu(False) |
|
torch._C._jit_set_texpr_fuser_enabled(True) |
|
torch._C._jit_set_nvfuser_enabled(False) |
|
torch._C._jit_set_llga_enabled(True) |
|
elif name == "none": |
|
torch._C._jit_override_can_fuse_on_cpu(False) |
|
torch._C._jit_override_can_fuse_on_gpu(False) |
|
torch._C._jit_set_texpr_fuser_enabled(False) |
|
torch._C._jit_set_nvfuser_enabled(False) |
|
torch._C._jit_set_llga_enabled(False) |
|
else: |
|
        raise ValueError(f"unrecognized fuser option (name: {name})")
|
try: |
|
yield |
|
finally: |
|
if name in ["fuser1", "fuser3"]: |
|
torch._C._jit_set_profiling_executor(old_profiling_executor) |
|
torch._C._get_graph_executor_optimize(old_profiling_mode) |
|
|
|
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse) |
|
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse) |
|
torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state) |
|
torch._C._jit_set_nvfuser_enabled(old_nvfuser_state) |
|
torch._C._jit_set_llga_enabled(old_llga_state) |
|
|
|
|
|
# convenience alias for retrieving the optimized graph that the executor ran most recently
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
|
|
|
|
|
def _get_differentiable_graph_node(node, diff_node): |
|
    # recursively collect every prim::DifferentiableGraph node into ``diff_node``
    if node.kind() == "prim::DifferentiableGraph":
|
diff_node.append(node) |
|
else: |
|
for block in node.blocks(): |
|
for n in block.nodes(): |
|
_get_differentiable_graph_node(n, diff_node) |
|
|
|
|
|
def _graph_for(self, *args, **kwargs): |
|
return _script_method_graph_for(self, self, *args, **kwargs) |
|
|
|
|
|
def _script_method_graph_for(self, parent, *args, **kwargs): |
|
try: |
|
dbs = parent.get_debug_state() |
|
eps = list(dbs.execution_plans.values()) |
|
assert len(eps) == 1 |
|
        graph = eps[0].graph.copy()

        # graph executor states for the differentiable nodes embedded in the graph
        fw_states = eps[0].code.differentiable_op_executor_states()
|
diff_nodes: list[torch._C.Node] = [] |
|
for n in graph.nodes(): |
|
_get_differentiable_graph_node(n, diff_nodes) |
|
|
|
assert len(fw_states) == len(diff_nodes) |
|
|
|
for n, state in zip(diff_nodes, fw_states): |
|
            fw_execution_plans = list(state.execution_plans.values())
            # the subgraph can only be updated when there is a unique execution
            # plan; skip nodes that cannot be updated and patch the rest on a
            # best-effort basis
            if len(fw_execution_plans) == 1:
|
n.g_("Subgraph", fw_execution_plans[0].graph) |
|
|
|
return graph |
|
    except Exception:
        # fallback: execute the function and return the optimized graph the
        # executor recorded for that run
        self(*args, **kwargs)
|
return last_executed_optimized_graph() |
|
|
|
|
|
def set_fusion_strategy(strategy: list[tuple[str, int]]): |
|
"""Set the type and number of specializations that can occur during fusion. |
|
|
|
    Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC"
    and depth is an integer.
|
|
|
    Behavior - static vs dynamic:
        In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined
        based on some initial profiling runs.
        In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple
        shapes are possible.
|
|
|
In both cases, we also recompile on new striding behavior, device, or dtype. |
|
|
|
    Behavior - fallback functions & depth:
        When an input doesn't match the format required by the specialized compiled op, it will run
        a fallback function. Fallback functions are recursively compiled and specialized based
        on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided
        to limit the number of specializations that can be compiled before giving up on recompiling and
        falling back to a completely un-fused, un-specialized implementation.
|
|
|
    The list of (type, depth) pairs controls the type of specializations and the number of
    specializations. For example: [("STATIC", 2), ("DYNAMIC", 2)] indicates that the first
    two specializations will use static fusions, the following two specializations will use
    dynamic fusion, and any inputs that satisfy none of the 4 options will run an
    unfused implementation.
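
    A minimal call sketch (the strategy values are illustrative)::

        torch.jit.set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 2)])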
|
|
|
    NB: in the future, as more fusion backends are added, there may be more granular
    APIs for specific fusers.
|
""" |
|
return torch._C._jit_set_fusion_strategy(strategy) |
|
|