# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
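"""Monkey patches for AOTAutograd and TorchInductor used by DeepSpeed's compile path.

The helpers here swap DeepSpeed's forward/backward compilers into `AotAutograd`,
rewrite placeholder metadata for ZeRO-3 partitioned parameters, and register
Inductor fallback lowerings for the `torch.ops.dc.*` custom ops that restrict
buffer reuse.
"""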
import torch
try:
import torch.utils._pytree as pytree
from torch._functorch.aot_autograd import create_aot_dispatcher_function
from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
from torch._inductor.virtualized import V
from torch._inductor.scheduler import Scheduler
original_create_aot_dispatcher_function = create_aot_dispatcher_function
except ImportError:
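    # Inductor/AOTAutograd internals may be missing or renamed in other torch
    # versions; swallow the ImportError so this module can still be imported.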
pass
from .util import get_input_nodes
from .graph_param import DSGraphParamManager
def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
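    """Wrap `original_compiler` so that `dc_compiler` runs on the graph first.

    Under ZeRO-3 partitioning, inputs that correspond to DS parameters are replaced
    with empty fake tensors so Inductor does not bake in the materialized sizes.
    """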
def wrapped_compiler(gm, fake_inputs):
mod_graph = dc_compiler(gm, fake_inputs)
        # dc_compiler returns None for the symint case; pass it through unchanged.
if mod_graph is None:
return None
if z3_partition:
            # Inductor validates input sizes estimated from the first trace, where the DS
            # tensors are materialized. Patch the input tensors to avoid the validation error.
patched_inputs = []
if bwd:
param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph)
param_names = [n.name for n in param_nodes_bw]
else:
param_names = graph_param_manager[graph_id].param_names
input_nodes = get_input_nodes(gm.graph)
            from torch._subclasses.fake_tensor import is_fake
            from torch._dynamo.utils import to_fake_tensor
            for in_node, in_v in zip(input_nodes, fake_inputs):
                if in_node.name in param_names:
                    assert is_fake(in_v), f"Input {in_v} should be a fake tensor"
                    # Replace the materialized DS param with an empty fake tensor of
                    # matching dtype/device.
                    patched_inputs.append(
                        to_fake_tensor(torch.empty([0], dtype=in_v.dtype, device=in_v.device), in_v.fake_mode))
                else:
                    patched_inputs.append(in_v)
patched_inputs = tuple(patched_inputs)
else:
patched_inputs = fake_inputs
return original_compiler(gm, patched_inputs)
return wrapped_compiler
def wrap_partition_fn(partition_fn, real_inputs, param_indices):
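    """Wrap `partition_fn` so that, after partitioning, placeholder metadata of DS
    parameters is reset to an empty tensor in both the forward and backward graphs.
    """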
def wrapped_partition_fn(*args, **kwargs):
fw_module, bw_module = partition_fn(*args, **kwargs)
# get parameter names
pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)
def fix_placeholder_meta(graph):
for n in graph.nodes:
if n.op == "placeholder" and n.name in pm.param_names:
n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)
fix_placeholder_meta(fw_module.graph)
fix_placeholder_meta(bw_module.graph)
return fw_module, bw_module
return wrapped_partition_fn
def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
param_indices, param_manager):
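    """Patch `AotAutograd.__init__` so that backends built from it route the forward,
    backward, and inference graphs through the DeepSpeed compilers and, under ZeRO-3
    partitioning, use the wrapped partition function.
    """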
from torch._dynamo.backends.common import AotAutograd
import functools
def patch_aotautograd():
# Unpatch if it was already patched
if hasattr(AotAutograd, "__original_init"):
AotAutograd.__init__ = AotAutograd.__original_init
original_init = AotAutograd.__init__
@functools.wraps(original_init)
def patched_init(self, **kwargs):
kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"],
make_fw_graph,
z3_partition,
graph_id,
param_manager,
bwd=False)
kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"],
make_bw_graph,
z3_partition,
graph_id,
param_manager,
bwd=True)
kwargs["inference_compiler"] = kwargs["fw_compiler"]
if z3_partition:
kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)
original_init(self, **kwargs)
AotAutograd.__original_init = original_init
AotAutograd.__init__ = patched_init
patch_aotautograd()
def register_custom_ops():
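    """Register Inductor fallback lowerings for the `torch.ops.dc.*` custom ops with
    buffer-reuse restrictions, and disable the scheduler's dead node elimination.
    """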
def fallback_handler_no_reuse(kernel,
never_reuse_input,
never_reuse_output,
force_free_input,
add_to_fallback_set=True):
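        """Build a fallback-lowering handler for `kernel` that can mark its inputs
        and outputs as never-reuse and, with `force_free_input`, free the first
        input right after the generated kernel call.
        """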
if add_to_fallback_set:
fallbacks.add(kernel)
def handler(*args, **kwargs):
def wrap_tensors(x):
                out = TensorBox.create(x) if isinstance(x, IRNode) else x
if out is not None and never_reuse_output:
V.graph.never_reuse_buffers.add(out.get_name())
return out
class CustomDCKernel(FallbackKernel):
def __init__(self, op, *args, **kwargs):
super().__init__(op, *args, **kwargs)
def add_to_never_reuse(x):
if isinstance(x, IRNode):
                            assert hasattr(x, "get_name"), f"{x.__class__} does not implement get_name"
V.graph.never_reuse_buffers.add(x.get_name())
if never_reuse_input:
pytree.tree_map(add_to_never_reuse, args)
def get_var_name_for_arg(self, arg: str):
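                    """Best-effort recovery of the Python variable name from a
                    codegen'd argument string: either a bare identifier or the
                    buffer inside a `reinterpret_tensor(...)` expression.
                    Returns None when neither form matches.
                    """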
if arg.isidentifier():
return arg
import re
match = re.match(r"reinterpret_tensor\((\w+),", arg)
if match:
return match.group(1)
return None
def codegen(self, wrapper):
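                    """Like FallbackKernel.codegen, but when `force_free_input` is
                    set, also emit `<var> = None` for the first argument so the
                    generated wrapper drops its reference right after the call.
                    """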
if not force_free_input:
return super().codegen(wrapper)
kernel = self.op_overload
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
V.graph.wrapper_code.generate_fallback_kernel(self, args)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
var_name = self.get_var_name_for_arg(args[0])
if var_name:
wrapper.writeline(f"{var_name} = None")
self.codegen_unbacked_symbol_defs(wrapper)
kernel_cls = CustomDCKernel if force_free_input else FallbackKernel
return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs))
return handler
def register_fallback_no_reuse(op_overload,
never_reuse_input=False,
never_reuse_output=False,
force_free_input=False):
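        """Register `op_overload` as a realized-input fallback lowering with the
        requested buffer-reuse restrictions.
        """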
add_needs_realized_inputs(op_overload)
return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse(
op_overload,
never_reuse_input=never_reuse_input,
never_reuse_output=never_reuse_output,
force_free_input=force_free_input))
    # Inductor tries to reuse output buffers when possible. We need to disable this
    # behavior for some of the custom ops.
    # Even so, the memory region still seems to be reused in some cases, so we also
    # clone the inputs for some ops.
register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True)
register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True)
register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False)
register_fallback_no_reuse(torch.ops.dc.reduce_grad.default,
never_reuse_input=True,
never_reuse_output=True,
force_free_input=True)
register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)
if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched:
Scheduler.is_dc_patched = True
        # Disable dead node elimination, presumably because the dc ops have side effects
        # (communication, freeing buffers) that the scheduler could mistake for dead code.
        Scheduler.dead_node_elimination = lambda _: None
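# Illustrative usage sketch (not part of DeepSpeed's public API; `my_graph_id`,
# `my_fw_graph`, `my_bw_graph`, `example_inputs`, `indices`, and `managers` are
# hypothetical stand-ins supplied by the surrounding compile backend):
#
#   patch_create_aot_dispatcher_function(my_graph_id, True, my_fw_graph, my_bw_graph,
#                                        example_inputs, indices, managers)
#   register_custom_ops()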