import torch

try:
    import torch.utils._pytree as pytree
    from torch._functorch.aot_autograd import create_aot_dispatcher_function
    from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
    from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
    from torch._inductor.virtualized import V
    from torch._inductor.scheduler import Scheduler

    original_create_aot_dispatcher_function = create_aot_dispatcher_function
except ImportError:
    pass

from .util import get_input_nodes
from .graph_param import DSGraphParamManager


def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
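    """Return a compiler that runs dc_compiler on the graph before handing it to original_compiler.

    If dc_compiler returns None, the wrapper returns None as well. Under ZeRO-3 partitioning
    (z3_partition), fake inputs that correspond to DeepSpeed parameters are swapped for size-0 fake
    tensors, matching the released (partitioned) state of those parameters.
    """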

    def wrapped_compiler(gm, fake_inputs):
        mod_graph = dc_compiler(gm, fake_inputs)

        if mod_graph is None:
            return None

        if z3_partition:
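            # Collect the parameter names for this graph (backward graphs read them from the bwd mapping),
            # then swap each matching fake input for a size-0 fake tensor; other inputs pass through unchanged.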
            patched_inputs = []
            if bwd:
                param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph)
                param_names = [n.name for n in param_nodes_bw]
            else:
                param_names = graph_param_manager[graph_id].param_names
            input_nodes = get_input_nodes(gm.graph)

            for in_node, in_v in zip(input_nodes, fake_inputs):
                ds_param = in_node.name in param_names
                if ds_param:
                    from torch._subclasses.fake_tensor import is_fake
                    from torch._dynamo.utils import to_fake_tensor
                    assert is_fake(in_v), f"Input {in_v} should be fake tensor"
                    patched_inputs.append(
                        to_fake_tensor(torch.empty([0], dtype=in_v.dtype, device=in_v.device), in_v.fake_mode))
                else:
                    patched_inputs.append(in_v)

            patched_inputs = tuple(patched_inputs)
        else:
            patched_inputs = fake_inputs

        return original_compiler(gm, patched_inputs)

    return wrapped_compiler


def wrap_partition_fn(partition_fn, real_inputs, param_indices):
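    """Wrap AOTAutograd's partition function for the ZeRO-3 partitioning path.

    After the forward/backward split, a DSGraphParamManager is built from the forward graph and the
    meta["val"] of every placeholder that names a DeepSpeed parameter is reset to a size-0 tensor in
    both modules, matching the released (partitioned) parameter state.
    """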

    def wrapped_partition_fn(*args, **kwargs):
        fw_module, bw_module = partition_fn(*args, **kwargs)

        pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)

        def fix_placeholder_meta(graph):
            for n in graph.nodes:
                if n.op == "placeholder" and n.name in pm.param_names:
                    n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)

        fix_placeholder_meta(fw_module.graph)
        fix_placeholder_meta(bw_module.graph)

        return fw_module, bw_module

    return wrapped_partition_fn


def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
                                         param_indices, param_manager):
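    """Monkey-patch AotAutograd.__init__ so that the forward/backward/inference compilers (and, under
    ZeRO-3 partitioning, the partition function) are wrapped with patch_compiler / wrap_partition_fn
    for this graph_id before AOTAutograd runs."""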

    from torch._dynamo.backends.common import AotAutograd
    import functools

    def patch_aotautograd():
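        # If AotAutograd.__init__ was already patched by an earlier call, restore the original first
        # so the new patch does not stack on top of the old one.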
        if hasattr(AotAutograd, "__original_init"):
            AotAutograd.__init__ = AotAutograd.__original_init

        original_init = AotAutograd.__init__

        @functools.wraps(original_init)
        def patched_init(self, **kwargs):
            kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"],
                                                   make_fw_graph,
                                                   z3_partition,
                                                   graph_id,
                                                   param_manager,
                                                   bwd=False)
            kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"],
                                                   make_bw_graph,
                                                   z3_partition,
                                                   graph_id,
                                                   param_manager,
                                                   bwd=True)
            kwargs["inference_compiler"] = kwargs["fw_compiler"]

            if z3_partition:
                kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)

            original_init(self, **kwargs)

        AotAutograd.__original_init = original_init
        AotAutograd.__init__ = patched_init

    patch_aotautograd()


def register_custom_ops():
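    """Register Inductor lowerings for the torch.ops.dc custom ops as fallback kernels with
    buffer-reuse restrictions, and patch the Inductor Scheduler so dead node elimination is a no-op."""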

    def fallback_handler_no_reuse(kernel,
                                  never_reuse_input,
                                  never_reuse_output,
                                  force_free_input,
                                  add_to_fallback_set=True):
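        """Create a lowering handler that falls back to `kernel` while restricting buffer reuse.

        Input and/or output buffers can be added to V.graph.never_reuse_buffers; when force_free_input
        is set, CustomDCKernel is used so the generated wrapper code also drops the first input."""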
        if add_to_fallback_set:
            fallbacks.add(kernel)

        def handler(*args, **kwargs):

            def wrap_tensors(x):
                out = TensorBox.create(x) if isinstance(x, torch._inductor.ir.IRNode) else x
                if out is not None and never_reuse_output:
                    V.graph.never_reuse_buffers.add(out.get_name())
                return out

            class CustomDCKernel(FallbackKernel):
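                """FallbackKernel variant that marks its IRNode inputs as never-reuse and, when
                force_free_input is set, overrides codegen to release the first argument right after
                the kernel call."""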

                def __init__(self, op, *args, **kwargs):
                    super().__init__(op, *args, **kwargs)

                    def add_to_never_reuse(x):
                        if isinstance(x, IRNode):
                            assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}"
                            V.graph.never_reuse_buffers.add(x.get_name())

                    if never_reuse_input:
                        pytree.tree_map(add_to_never_reuse, args)

                def get_var_name_for_arg(self, arg: str):
                    if arg.isidentifier():
                        return arg

                    import re
                    match = re.match(r"reinterpret_tensor\((\w+),", arg)
                    if match:
                        return match.group(1)
                    return None

                def codegen(self, wrapper):
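                    """Standard fallback codegen unless force_free_input is set; in that case emit the
                    kernel call and then assign None to the variable backing the first argument so the
                    wrapper code drops its reference early."""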
                    if not force_free_input:
                        return super().codegen(wrapper)

                    kernel = self.op_overload
                    self.codegen_comment(wrapper)
                    args = [*self.codegen_args(), *self.codegen_kwargs()]

                    V.graph.wrapper_code.generate_fallback_kernel(self, args)
                    if isinstance(self.layout, Layout):
                        self.codegen_size_asserts(wrapper)

                    var_name = self.get_var_name_for_arg(args[0])
                    if var_name:
                        wrapper.writeline(f"{var_name} = None")

                    self.codegen_unbacked_symbol_defs(wrapper)

            kernel_cls = CustomDCKernel if force_free_input else FallbackKernel
            return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs))

        return handler

    def register_fallback_no_reuse(op_overload,
                                   never_reuse_input=False,
                                   never_reuse_output=False,
                                   force_free_input=False):
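        """Register op_overload as a fallback lowering with realized inputs, using the handler above."""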
        add_needs_realized_inputs(op_overload)
        return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse(
            op_overload,
            never_reuse_input=never_reuse_input,
            never_reuse_output=never_reuse_output,
            force_free_input=force_free_input))

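    # Register the dc custom ops as fallback lowerings. The flags control which input/output buffers
    # Inductor is forbidden to reuse; reduce_grad additionally frees its first input in the generated code.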
    register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False)
    register_fallback_no_reuse(torch.ops.dc.reduce_grad.default,
                               never_reuse_input=True,
                               never_reuse_output=True,
                               force_free_input=True)
    register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)
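    # Patch the Scheduler once so that dead node elimination becomes a no-op; the dc ops act by side
    # effect, so their nodes must survive even when their outputs appear unused.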
    if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched:
        Scheduler.is_dc_patched = True
        Scheduler.dead_node_elimination = lambda _: None