|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from deepspeed.utils.torch import required_torch_version |
|
|
|
# Metaclass of torch.autograd.Function; the FunctionMeta classes in this module derive from it.
base_meta: type = type(torch.autograd.Function)

# Shared buffer: the patched backward wrappers append captured argument tuples here
# while capture is enabled. The same list object is handed out to callers.
backward_inputs: list = []

# Capture switch read by the wrappers at backward time; set by patch_compiled_func
# and cleared by unpatch_compiled_func.
enabled_patched_func: bool = False

# The torch.autograd.Function class saved by patch_compiled_func so it can be restored.
original_grad_fn: "type | None" = None
|
|
|
def _unwrap_staticmethod(fn):
    """Return the plain function behind *fn* when it is a ``staticmethod`` object.

    A method declared with ``@staticmethod`` appears in a class namespace
    (``dct``) as a ``staticmethod`` wrapper. Such wrappers are only directly
    callable on Python >= 3.10, so they are unwrapped before being invoked
    as ``fn(ctx, ...)``. Anything else (including ``None``) is returned unchanged.
    """
    return fn.__func__ if isinstance(fn, staticmethod) else fn


# NOTE: on torch < 2.6 neither branch below runs, so FunctionMeta is left
# undefined and patch_compiled_func() will fail with a NameError.
if required_torch_version(min_version=2.7):

    class FunctionMeta(base_meta):
        """Metaclass that wraps ``CompiledFunction._backward_impl`` (torch >= 2.7).

        While ``enabled_patched_func`` is true, each call to the wrapped
        ``_backward_impl`` appends its ``all_args`` to ``backward_inputs`` and sets
        ``compiled_bw = None`` on the owning class, then delegates to the original.
        Classes with any other name are created unchanged.
        """

        def __new__(cls, name, bases, dct):
            if name != "CompiledFunction":
                return super().__new__(cls, name, bases, dct)

            original_backward_impl = _unwrap_staticmethod(dct.get("_backward_impl"))

            def wrapped_backward_impl(ctx, all_args):
                assert original_backward_impl is not None

                if enabled_patched_func:
                    backward_inputs.append(all_args)
                    # NOTE(review): clearing the cached compiled backward presumably
                    # forces torch to rebuild it on the next call — confirm intent.
                    wrapped_backward_impl.owner_class.compiled_bw = None

                return original_backward_impl(ctx, all_args)

            # owner_class is filled in once the class object exists (below).
            wrapped_backward_impl.owner_class = None
            dct["_backward_impl"] = staticmethod(wrapped_backward_impl)
            new_class = super().__new__(cls, name, bases, dct)
            wrapped_backward_impl.owner_class = new_class
            return new_class

elif required_torch_version(min_version=2.6):

    class FunctionMeta(base_meta):
        """Metaclass that wraps ``CompiledFunction._backward_prologue`` (torch 2.6).

        The wrapper always calls the original prologue first. While
        ``enabled_patched_func`` is true, it appends the prologue's return value to
        ``backward_inputs`` and sets ``compiled_bw = None`` on the owning class.
        It then returns that value unchanged. Classes with any other name are
        created unchanged.
        """

        def __new__(cls, name, bases, dct):
            if name != "CompiledFunction":
                return super().__new__(cls, name, bases, dct)

            original_backward_prologue = _unwrap_staticmethod(dct.get("_backward_prologue"))

            def wrapped_backward_prologue(ctx, *grad_outputs):
                assert original_backward_prologue is not None

                all_args = original_backward_prologue(ctx, *grad_outputs)
                if enabled_patched_func:
                    backward_inputs.append(all_args)
                    # NOTE(review): clearing the cached compiled backward presumably
                    # forces torch to rebuild it on the next call — confirm intent.
                    wrapped_backward_prologue.owner_class.compiled_bw = None

                return all_args

            # owner_class is filled in once the class object exists (below).
            wrapped_backward_prologue.owner_class = None
            dct["_backward_prologue"] = staticmethod(wrapped_backward_prologue)
            new_class = super().__new__(cls, name, bases, dct)
            wrapped_backward_prologue.owner_class = new_class
            return new_class
|
|
|
|
|
def patch_compiled_func():
    """Enable capture and replace ``torch.autograd.Function`` with a patched subclass.

    Sets ``enabled_patched_func`` to True. Installs ``PatchedFunction`` (built with
    ``FunctionMeta``) as ``torch.autograd.Function``, saving the previous class in
    ``original_grad_fn`` so ``unpatch_compiled_func`` can restore it.

    Calling this again while already patched only re-enables capture. It does not
    overwrite the saved original with the patched class, which would make the real
    ``torch.autograd.Function`` unrecoverable.

    Returns:
        The module-level ``backward_inputs`` list (the live buffer, not a copy).
    """
    global enabled_patched_func
    global original_grad_fn

    enabled_patched_func = True

    # Already patched: the current class was built by FunctionMeta, so keep the
    # previously saved original instead of replacing it with PatchedFunction.
    if isinstance(torch.autograd.Function, FunctionMeta):
        return backward_inputs

    class PatchedFunction(torch.autograd.Function, metaclass=FunctionMeta):
        pass

    original_grad_fn = torch.autograd.Function
    torch.autograd.Function = PatchedFunction

    return backward_inputs
|
|
|
|
|
def unpatch_compiled_func():
    """Disable capture and restore the ``torch.autograd.Function`` saved by patching.

    Always sets ``enabled_patched_func`` to False. If no original class was saved
    (i.e. ``patch_compiled_func`` never ran, or this was already called), it does
    nothing else. Otherwise it would assign ``None`` to ``torch.autograd.Function``.
    After restoring, the saved reference is cleared so a repeated call cannot
    clobber a later reassignment of ``torch.autograd.Function``.
    """
    global enabled_patched_func
    global original_grad_fn

    enabled_patched_func = False

    if original_grad_fn is None:
        return

    torch.autograd.Function = original_grad_fn
    original_grad_fn = None
|
|
|
|
|
def get_backward_inputs() -> list:
    """Return the module-level buffer of captured backward inputs.

    This is the same list object the patched wrappers append to (and that
    ``patch_compiled_func`` returns); it is not copied.
    """
    return backward_inputs
|
|