|
|
|
import contextlib |
|
import platform |
|
import uuid |
|
import warnings |
|
import weakref |
|
from collections import defaultdict |
|
from typing import (
    Any,
    Callable,
    ContextManager,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)
|
import enum |
|
from weakref import ReferenceType |
|
|
|
import torch |
|
import torch.fx.traceback as fx_traceback |
|
from torch._functorch._aot_autograd.functional_utils import is_fun |
|
from torch.utils._pytree import tree_map |
|
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode |
|
from torch.utils._python_dispatch import TorchDispatchMode |
|
|
|
__all__ = [ |
|
"checkpoint", |
|
"checkpoint_sequential", |
|
"CheckpointError", |
|
"CheckpointFunction", |
|
"check_backward_validity", |
|
"detach_variable", |
|
"get_device_states", |
|
"set_device_states", |
|
"noop_context_fn", |
|
"set_checkpoint_early_stop", |
|
"DefaultDeviceType", |
|
"set_checkpoint_debug_enabled", |
|
"CheckpointPolicy", |
|
"SelectiveCheckpointContext", |
|
"create_selective_checkpoint_contexts", |
|
"SAC_IGNORED_OPS", |
|
] |
|
|
|
_DEFAULT_DETERMINISM_MODE = "default" |
|
|
|
_checkpoint_debug_enabled: Optional[bool] = None |
|
|
|
|
|
@contextlib.contextmanager |
|
def set_checkpoint_debug_enabled(enabled: Optional[bool]): |
|
""" |
|
Context manager that sets whether checkpoint should print additional debug |
|
information when running. See the ``debug`` flag for |
|
:func:`~torch.utils.checkpoint.checkpoint` for more information. Note that |
|
when set, this context manager overrides the value of ``debug`` passed to |
|
checkpoint. To defer to the local setting, pass ``None`` to this context. |
|
|
|
Args: |
|
        enabled (bool): Whether checkpoint should print debug information.
            Default is ``None``, which defers to the local ``debug`` setting.
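
    Example::

        >>> # xdoctest: +SKIP("stub: ``fn`` and ``inp`` are placeholder names")
        >>> # Minimal usage sketch: force debug output for any checkpoint run
        >>> # under this context, regardless of its local ``debug`` flag.
        >>> with set_checkpoint_debug_enabled(True):
        ...     out = checkpoint(fn, inp, use_reentrant=False)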
|
""" |
|
global _checkpoint_debug_enabled |
|
try: |
|
prev = _checkpoint_debug_enabled |
|
_checkpoint_debug_enabled = enabled |
|
yield |
|
finally: |
|
_checkpoint_debug_enabled = prev |
|
|
|
|
|
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: |
|
if isinstance(inputs, tuple): |
|
out = [] |
|
for inp in inputs: |
|
if not isinstance(inp, torch.Tensor): |
|
out.append(inp) |
|
continue |
|
|
|
x = inp.detach() |
|
x.requires_grad = inp.requires_grad |
|
out.append(x) |
|
return tuple(out) |
|
else: |
|
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: "
            f"{type(inputs).__name__}"
        )
|
|
|
|
|
def check_backward_validity(inputs: Iterable[Any]) -> None: |
|
if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): |
|
warnings.warn( |
|
"None of the inputs have requires_grad=True. Gradients will be None" |
|
) |
|
|
|
|
|
def _get_device_module(device="cuda"): |
|
if device == "meta": |
|
return torch.device("meta") |
|
device_module = getattr(torch, device) |
|
return device_module |
|
|
|
|
|
class DefaultDeviceType: |
|
r""" |
|
A class that manages the default device type for checkpointing. |
|
|
|
If no non-CPU tensors are present, the default device type will |
|
be used. The default value is 'cuda'. The device type is used in |
|
the checkpointing process when determining which device states |
|
to save and restore for recomputation. |
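
    Example::

        >>> # xdoctest: +SKIP("stub")
        >>> # Illustrative sketch: route device RNG stashing to another
        >>> # accelerator backend when no non-CPU tensors are passed in.
        >>> DefaultDeviceType.set_device_type("xpu")
        >>> DefaultDeviceType.get_device_type()
        'xpu'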
|
""" |
|
|
|
_default_device_type = "cuda" |
|
|
|
@staticmethod |
|
def set_device_type(device: str = "cuda"): |
|
""" |
|
Set the default device type for checkpointing. |
|
|
|
Args: |
|
device (str): The device type to be set as default. Default is 'cuda'. |
|
""" |
|
DefaultDeviceType._default_device_type = device |
|
|
|
@staticmethod |
|
def get_device_type() -> str: |
|
""" |
|
Get the current default device type for checkpointing. |
|
|
|
Returns: |
|
str: The current default device type. |
|
""" |
|
return DefaultDeviceType._default_device_type |
|
|
|
|
|
def _infer_device_type(*args): |
|
device_types = [] |
|
|
|
def add_device_types(arg): |
|
nonlocal device_types |
|
        if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
|
device_types.append(arg.device.type) |
|
tree_map(add_device_types, args) |
|
|
|
device_types_set = set(device_types) |
|
if len(device_types_set) > 1: |
|
        warnings.warn(
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
            "Device state will only be saved for devices of a single device type, and the remaining "
            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
            f"\nDevice types: {sorted(device_types_set)}, first device type: {device_types[0]}"
        )
|
if len(device_types) == 0: |
|
return DefaultDeviceType.get_device_type() |
|
elif "cuda" in device_types_set: |
|
return "cuda" |
|
else: |
|
return device_types[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: |
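    """Get the RNG state of each non-CPU device that tensors in ``args`` live on.

    Returns a tuple ``(device_ids, states)``; pass both to
    :func:`set_device_states` to restore the per-device RNG states before
    recomputation.
    """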
|
|
|
|
|
fwd_device_ids = [] |
|
|
|
def add_device_ids(arg): |
|
nonlocal fwd_device_ids |
|
if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}: |
|
fwd_device_ids.append(arg.get_device()) |
|
tree_map(add_device_ids, args) |
|
|
|
fwd_device_states = [] |
|
device_module = _get_device_module(_infer_device_type(*args)) |
|
for device_id in fwd_device_ids: |
|
with device_module.device(device_id): |
|
fwd_device_states.append(device_module.get_rng_state()) |
|
|
|
return fwd_device_ids, fwd_device_states |
|
|
|
|
|
def set_device_states(devices, states, *, device_type=None) -> None: |
|
"""Sets random number generator states for the specified devices. |
|
|
|
Args: |
|
devices: Device ids to set states for. |
|
states: States to set. |
|
device_type: ``device_type`` of the devices to set states for. Default |
|
is the device returned by a call to ``DefaultDeviceType.get_device_type()``, |
|
which is ``cuda`` if not changed by calling ``DefaultDeviceType::set_device_type()``. |
|
""" |
|
if device_type is None: |
|
device_type = DefaultDeviceType.get_device_type() |
|
if device_type == "meta": |
|
return |
|
device_module = _get_device_module(device_type) |
|
for device, state in zip(devices, states): |
|
with device_module.device(device): |
|
device_module.set_rng_state(state) |
|
|
|
|
|
def _get_autocast_kwargs(device_type="cuda"): |
|
if torch.amp.is_autocast_available(device_type): |
|
device_autocast_kwargs = { |
|
"enabled": torch.is_autocast_enabled(device_type), |
|
"dtype": torch.get_autocast_dtype(device_type), |
|
"cache_enabled": torch.is_autocast_cache_enabled(), |
|
} |
|
else: |
|
device_autocast_kwargs = None |
|
|
|
cpu_autocast_kwargs = { |
|
"enabled": torch.is_autocast_enabled('cpu'), |
|
"dtype": torch.get_autocast_dtype('cpu'), |
|
"cache_enabled": torch.is_autocast_cache_enabled(), |
|
} |
|
|
|
return device_autocast_kwargs, cpu_autocast_kwargs |
|
|
|
|
|
class CheckpointFunction(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, run_function, preserve_rng_state, *args): |
|
check_backward_validity(args) |
|
ctx.run_function = run_function |
|
ctx.preserve_rng_state = preserve_rng_state |
|
|
|
ctx.device_type = _infer_device_type(*args) |
|
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs( |
|
ctx.device_type |
|
) |
|
if preserve_rng_state: |
|
ctx.fwd_cpu_state = torch.get_rng_state() |
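            # Don't eagerly initialize the device context by accident: only
            # stash device RNG state if the device module has already been
            # initialized.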
|
|
|
|
|
|
|
|
|
ctx.had_device_in_fwd = False |
|
device_module = _get_device_module(ctx.device_type) |
|
if getattr(device_module, "_initialized", False): |
|
ctx.had_device_in_fwd = True |
|
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) |
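        # Save non-tensor inputs directly on ctx; tensor inputs are routed
        # through save_for_backward so that autograd manages their lifetime.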
|
|
|
|
|
|
|
ctx.inputs = [] |
|
ctx.tensor_indices = [] |
|
tensor_inputs = [] |
|
for i, arg in enumerate(args): |
|
if torch.is_tensor(arg): |
|
tensor_inputs.append(arg) |
|
ctx.tensor_indices.append(i) |
|
ctx.inputs.append(None) |
|
else: |
|
ctx.inputs.append(arg) |
|
|
|
ctx.save_for_backward(*tensor_inputs) |
|
|
|
with torch.no_grad(): |
|
outputs = run_function(*args) |
|
return outputs |
|
|
|
@staticmethod |
|
def backward(ctx, *args): |
|
if not torch.autograd._is_checkpoint_valid(): |
|
raise RuntimeError( |
|
"When use_reentrant=True, torch.utils.checkpoint is incompatible" |
|
" with .grad() or passing an `inputs` parameter to .backward()." |
|
" To resolve this error, you can either set use_reentrant=False," |
|
" or call .backward() without passing the `inputs` argument." |
|
) |
|
|
|
inputs = list(ctx.inputs) |
|
tensor_indices = ctx.tensor_indices |
|
tensors = ctx.saved_tensors |
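        # Fill in inputs with the tensors saved by save_for_backward.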
|
|
|
|
|
for i, idx in enumerate(tensor_indices): |
|
inputs[idx] = tensors[i] |
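
        # Stash the surrounding RNG state, mimic the state that was present
        # during the original forward, and restore the surrounding state once
        # recomputation is done.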
|
|
|
|
|
|
|
|
|
rng_devices = [] |
|
if ctx.preserve_rng_state and ctx.had_device_in_fwd: |
|
rng_devices = ctx.fwd_devices |
|
with torch.random.fork_rng( |
|
devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type |
|
): |
|
if ctx.preserve_rng_state: |
|
torch.set_rng_state(ctx.fwd_cpu_state) |
|
if ctx.had_device_in_fwd: |
|
set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type) |
|
detached_inputs = detach_variable(tuple(inputs)) |
|
|
|
device_autocast_ctx = torch.amp.autocast( |
|
device_type=ctx.device_type, **ctx.device_autocast_kwargs |
|
) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext() |
|
with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): |
|
outputs = ctx.run_function(*detached_inputs) |
|
|
|
if isinstance(outputs, torch.Tensor): |
|
outputs = (outputs,) |
|
|
|
|
|
outputs_with_grad = [] |
|
args_with_grad = [] |
|
for i in range(len(outputs)): |
|
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: |
|
outputs_with_grad.append(outputs[i]) |
|
args_with_grad.append(args[i]) |
|
if len(outputs_with_grad) == 0: |
|
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " so this checkpoint() is not necessary"
            )
|
torch.autograd.backward(outputs_with_grad, args_with_grad) |
|
grads = tuple( |
|
inp.grad if isinstance(inp, torch.Tensor) else None |
|
for inp in detached_inputs |
|
) |
|
|
|
return (None, None) + grads |
|
|
|
|
|
def noop_context_fn(): |
|
return contextlib.nullcontext(), contextlib.nullcontext() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch._disable_dynamo |
|
def checkpoint( |
|
function, |
|
*args, |
|
use_reentrant: Optional[bool] = None, |
|
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, |
|
determinism_check: str = _DEFAULT_DETERMINISM_MODE, |
|
debug: bool = False, |
|
**kwargs |
|
): |
|
r"""Checkpoint a model or part of the model. |
|
|
|
Activation checkpointing is a technique that trades compute for memory. |
|
Instead of keeping tensors needed for backward alive until they are used in |
|
gradient computation during backward, forward computation in checkpointed |
|
regions omits saving tensors for backward and recomputes them during the |
|
backward pass. Activation checkpointing can be applied to any part of a |
|
model. |
|
|
|
There are currently two checkpointing implementations available, determined |
|
by the :attr:`use_reentrant` parameter. It is recommended that you use |
|
    ``use_reentrant=False``. Please refer to the note below for a discussion of
|
their differences. |
|
|
|
.. warning:: |
|
|
|
If the :attr:`function` invocation during the backward pass differs |
|
from the forward pass, e.g., due to a global variable, the checkpointed |
|
    version may not be equivalent, potentially causing an error to be
    raised or silently incorrect gradients.
|
|
|
.. warning:: |
|
|
|
        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.5 we will raise an exception if ``use_reentrant`` is not passed.
|
If you are using the ``use_reentrant=True`` variant, please refer to the |
|
note below for important considerations and potential limitations. |
|
|
|
.. note:: |
|
|
|
The reentrant variant of checkpoint (``use_reentrant=True``) and |
|
the non-reentrant variant of checkpoint (``use_reentrant=False``) |
|
differ in the following ways: |
|
|
|
* Non-reentrant checkpoint stops recomputation as soon as all needed |
|
intermediate activations have been recomputed. This feature is enabled |
|
by default, but can be disabled with :func:`set_checkpoint_early_stop`. |
|
Reentrant checkpoint always recomputes :attr:`function` in its |
|
entirety during the backward pass. |
|
|
|
* The reentrant variant does not record the autograd graph during the |
|
forward pass, as it runs with the forward pass under |
|
:func:`torch.no_grad`. The non-reentrant version does record the |
|
autograd graph, allowing one to perform backward on the graph within |
|
checkpointed regions. |
|
|
|
* The reentrant checkpoint only supports the |
|
:func:`torch.autograd.backward` API for the backward pass without its |
|
`inputs` argument, while the non-reentrant version supports all ways |
|
of performing the backward pass. |
|
|
|
* At least one input and output must have ``requires_grad=True`` for the |
|
reentrant variant. If this condition is unmet, the checkpointed part |
|
of the model will not have gradients. The non-reentrant version does |
|
not have this requirement. |
|
|
|
* The reentrant version does not consider tensors in nested structures |
|
(e.g., custom objects, lists, dicts, etc) as participating in |
|
autograd, while the non-reentrant version does. |
|
|
|
* The reentrant checkpoint does not support checkpointed regions with |
|
detached tensors from the computational graph, whereas the |
|
non-reentrant version does. For the reentrant variant, if the |
|
checkpointed segment contains tensors detached using ``detach()`` or |
|
with :func:`torch.no_grad`, the backward pass will raise an error. |
|
This is because ``checkpoint`` makes all the outputs require gradients |
|
and this causes issues when a tensor is defined to have no gradient in |
|
the model. To avoid this, detach the tensors outside of the |
|
``checkpoint`` function. |
|
|
|
Args: |
|
function: describes what to run in the forward pass of the model or |
|
part of the model. It should also know how to handle the inputs |
|
            passed as the tuple. For example, in an LSTM, if the user passes
|
``(activation, hidden)``, :attr:`function` should correctly use the |
|
first input as ``activation`` and the second input as ``hidden`` |
|
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint. Note that under
            torch.compile, this flag has no effect and the RNG state is
            always preserved.
|
Default: ``True`` |
|
use_reentrant(bool): |
|
specify whether to use the activation checkpoint variant that |
|
requires reentrant autograd. This parameter should be passed |
|
explicitly. In version 2.5 we will raise an exception if |
|
``use_reentrant`` is not passed. If ``use_reentrant=False``, |
|
``checkpoint`` will use an implementation that does not require |
|
reentrant autograd. This allows ``checkpoint`` to support additional |
|
functionality, such as working as expected with |
|
``torch.autograd.grad`` and support for keyword arguments input into |
|
the checkpointed function. |
|
context_fn(Callable, optional): A callable returning a tuple of two |
|
context managers. The function and its recomputation will be run |
|
under the first and second context managers respectively. |
|
This argument is only supported if ``use_reentrant=False``. |
|
determinism_check(str, optional): A string specifying the determinism |
|
check to perform. By default it is set to ``"default"`` which |
|
compares the shapes, dtypes, and devices of the recomputed tensors |
|
            against those of the saved tensors. To turn off this check, specify
|
``"none"``. Currently these are the only two supported values. |
|
Please open an issue if you would like to see more determinism |
|
            checks. This argument is only supported if ``use_reentrant=False``;
            if ``use_reentrant=True``, the determinism check is always disabled.
|
debug(bool, optional): If ``True``, error messages will also include |
|
            a trace of the operators run during the original forward computation
|
as well as the recomputation. This argument is only supported if |
|
``use_reentrant=False``. |
|
args: tuple containing inputs to the :attr:`function` |
|
|
|
Returns: |
|
Output of running :attr:`function` on :attr:`*args` |
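
    Example::

        >>> # xdoctest: +SKIP("stub: ``layer`` is a placeholder module")
        >>> # Minimal sketch: recompute ``layer``'s activations during the
        >>> # backward pass instead of keeping them alive after forward.
        >>> inp = torch.randn(2, 16, requires_grad=True)
        >>> out = torch.utils.checkpoint.checkpoint(layer, inp, use_reentrant=False)
        >>> out.sum().backward()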
|
""" |
|
if use_reentrant is None: |
|
warnings.warn( |
|
"torch.utils.checkpoint: the use_reentrant parameter should be " |
|
"passed explicitly. In version 2.5 we will raise an exception " |
|
"if use_reentrant is not passed. use_reentrant=False is " |
|
"recommended, but if you need to preserve the current default " |
|
"behavior, you can pass use_reentrant=True. Refer to docs for more " |
|
"details on the differences between the two variants.", |
|
stacklevel=2 |
|
) |
|
use_reentrant = True |
|
|
|
|
|
preserve = kwargs.pop("preserve_rng_state", True) |
|
if kwargs and use_reentrant: |
|
raise ValueError( |
|
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) |
|
) |
|
|
|
if use_reentrant: |
|
if context_fn is not noop_context_fn or debug is not False: |
|
raise ValueError( |
|
"Passing `context_fn` or `debug` is only supported when " |
|
"use_reentrant=False." |
|
) |
|
return CheckpointFunction.apply(function, preserve, *args) |
|
else: |
|
gen = _checkpoint_without_reentrant_generator( |
|
function, preserve, context_fn, determinism_check, debug, *args, **kwargs |
|
) |
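        # Run the pre-forward (setup) portion of the generator.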
|
|
|
next(gen) |
|
ret = function(*args, **kwargs) |
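        # Run the post-forward logic; the generator is exhausted after its
        # single yield.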
|
|
|
try: |
|
next(gen) |
|
except StopIteration: |
|
return ret |
|
|
|
|
|
def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs): |
|
r"""Checkpoint a sequential model to save memory. |
|
|
|
Sequential models execute a list of modules/functions in order |
|
(sequentially). Therefore, we can divide such a model in various segments |
|
and checkpoint each segment. All segments except the last will not store |
|
the intermediate activations. The inputs of each checkpointed segment will |
|
be saved for re-running the segment in the backward pass. |
|
|
|
.. warning:: |
|
        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.5 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please see
        :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.
|
|
|
    .. warning::
        Since PyTorch 1.4, ``checkpoint_sequential`` allows only one Tensor as
        the input and intermediate outputs, just like :class:`torch.nn.Sequential`.
|
|
|
Args: |
|
functions: A :class:`torch.nn.Sequential` or the list of modules or |
|
functions (comprising the model) to run sequentially. |
|
segments: Number of chunks to create in the model |
|
input: A Tensor that is input to :attr:`functions` |
|
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
|
Default: ``True`` |
|
use_reentrant(bool): |
|
specify whether to use the activation checkpoint variant that |
|
requires reentrant autograd. This parameter should be passed |
|
explicitly. In version 2.5 we will raise an exception if |
|
``use_reentrant`` is not passed. If ``use_reentrant=False``, |
|
``checkpoint`` will use an implementation that does not require |
|
reentrant autograd. This allows ``checkpoint`` to support additional |
|
functionality, such as working as expected with |
|
``torch.autograd.grad`` and support for keyword arguments input into |
|
the checkpointed function. |
|
|
|
Returns: |
|
        Output of running :attr:`functions` sequentially on :attr:`input`
|
|
|
Example: |
|
>>> # xdoctest: +SKIP("stub") |
|
>>> model = nn.Sequential(...) |
|
>>> input_var = checkpoint_sequential(model, chunks, input_var) |
|
""" |
|
if use_reentrant is None: |
|
warnings.warn( |
|
"torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " |
|
"parameter should be passed explicitly. " |
|
"In version 2.5 we will raise an exception if use_reentrant " |
|
"is not passed. use_reentrant=False is " |
|
"recommended, but if you need to preserve the current default " |
|
"behavior, you can pass use_reentrant=True. Refer to docs for more " |
|
"details on the differences between the two variants." |
|
) |
|
use_reentrant = True |
|
|
|
|
|
preserve = kwargs.pop("preserve_rng_state", True) |
|
if kwargs: |
|
raise ValueError( |
|
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) |
|
) |
|
|
|
def run_function(start, end, functions): |
|
def forward(input): |
|
for j in range(start, end + 1): |
|
input = functions[j](input) |
|
return input |
|
|
|
return forward |
|
|
|
if isinstance(functions, torch.nn.Sequential): |
|
functions = list(functions.children()) |
|
|
|
segment_size = len(functions) // segments |
|
|
|
end = -1 |
|
for start in range(0, segment_size * (segments - 1), segment_size): |
|
end = start + segment_size - 1 |
|
input = checkpoint( |
|
run_function(start, end, functions), |
|
input, |
|
use_reentrant=use_reentrant, |
|
preserve_rng_state=preserve, |
|
) |
|
return run_function(end + 1, len(functions) - 1, functions)(input) |
|
|
|
|
|
def _internal_assert(cond): |
|
if not cond: |
|
raise AssertionError( |
|
"Something went unexpectedly wrong in activation checkpoint. " |
|
"Please report this bug by filing an issue to PyTorch." |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_enable_checkpoint_early_stop = True |
|
|
|
|
|
@contextlib.contextmanager |
|
def set_checkpoint_early_stop(enable: bool): |
|
"""Context manager that sets whether checkpoint should stop recomputation early. |
|
|
|
By default, non-reentrant checkpoint stops recomputation as soon as it |
|
has computed all needed Tensors. This context manager can be used to disable |
|
that feature if it is problematic for your specific application. |
|
|
|
This context manager only needs to be active when forward is run. It does |
|
not need to be active during backward. |
|
|
|
Example:: |
|
|
|
>>> # xdoctest: +SKIP(failing) |
|
>>> message = "saved tensors default hooks are disabled" |
|
>>> with set_checkpoint_early_stop(False): |
|
... # Any checkpoint under this context manager will respect this |
|
... # context manager, even if its backward is performed outside. |
|
... out = checkpoint(fn, inputs) |
|
... |
|
>>> out.backward() |
|
""" |
|
global _enable_checkpoint_early_stop |
|
try: |
|
prev = _enable_checkpoint_early_stop |
|
_enable_checkpoint_early_stop = enable |
|
yield |
|
finally: |
|
_enable_checkpoint_early_stop = prev |
|
|
|
|
|
class _Handle: |
|
pass |
|
|
|
|
|
class _Holder: |
|
def __init__(self): |
|
self.handles: Dict[int, Optional[_Handle]] = {} |
|
|
|
|
|
class _NoopSaveInputs(torch.autograd.Function): |
|
@staticmethod |
|
def forward(*args): |
|
return torch.empty((0,)) |
|
|
|
@staticmethod |
|
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: |
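        # Only tensors can be saved via save_for_backward; record where each
        # tensor sat in ``inputs`` so positions can be restored in get_args.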
|
|
|
|
|
tensor_indices, tensors = zip( |
|
*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)] |
|
) |
|
idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)} |
|
|
|
args = [None if isinstance(o, torch.Tensor) else o for o in inputs] |
|
|
|
def get_args(saved_tensors): |
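            # Splice the saved tensors back into their original argument
            # positions; non-tensor args were captured directly in ``args``.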
|
|
|
|
|
|
|
|
|
ret = [ |
|
saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o |
|
for i, o in enumerate(args) |
|
] |
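            # Drop the leading element: it corresponds to the dummy tensor
            # that _NoopSaveInputs.apply received as its first argument.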
|
|
|
|
|
return ret[1:] |
|
|
|
ctx.get_args = get_args |
|
ctx.save_for_backward(*tensors) |
|
|
|
@staticmethod |
|
def backward(ctx, *grad_outputs): |
|
raise AssertionError("Did not expect to backward on this graph") |
|
|
|
|
|
class _CheckpointFrame: |
|
def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn): |
|
self.recompute_fn = recompute_fn |
|
self.input_saver = None |
|
self.weak_holders: List[ReferenceType] = [] |
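
        # Recomputed tensors, keyed first by graph task id (gid) and then by
        # the _Handle stored in the corresponding _Holder; WeakKeyDictionary
        # lets an entry be freed once its handle is dropped.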
|
|
|
|
|
|
|
self.recomputed: DefaultDict[ |
|
int, weakref.WeakKeyDictionary[_Handle, torch.Tensor] |
|
] = defaultdict(weakref.WeakKeyDictionary) |
|
|
|
|
|
self.recomp_counter: DefaultDict[int, int] = defaultdict(int) |
|
self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool) |
|
|
|
|
|
self.early_stop = early_stop |
|
|
|
|
|
self.metadata_fn = metadata_fn |
|
self.unpack_error_cb = unpack_error_cb |
|
self.x_metadatas = [] |
|
self.forward_completed = False |
|
self.ignore_saved_mismatch = False |
|
|
|
def check_recomputed_tensors_match(self, gid): |
|
if self.ignore_saved_mismatch: |
|
|
|
|
|
return |
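
        # weak_holders records, in order, one holder per tensor saved by the
        # original forward, while recomp_counter[gid] counts tensors saved
        # during recomputation; the two must agree.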
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        if len(self.weak_holders) != self.recomp_counter[gid]:
|
|
|
|
|
|
|
|
|
|
|
raise CheckpointError( |
|
"torch.utils.checkpoint: A different number of tensors was saved " |
|
"during the original forward and recomputation.\n" |
|
f"Number of tensors saved during forward: {len(self.weak_holders)}\n" |
|
f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}" |
|
) |
|
|
|
|
|
|
|
nb_meta_different = [] |
|
for idx, weak_holder in enumerate(self.weak_holders): |
|
holder = weak_holder() |
|
if holder is None: |
|
continue |
|
|
|
|
|
|
|
|
|
_internal_assert(gid in holder.handles) |
|
|
|
|
|
_internal_assert(holder.handles[gid] is not None) |
|
|
|
_internal_assert(holder.handles[gid] in self.recomputed[gid]) |
|
|
|
x_meta = self.x_metadatas[idx] |
|
recomputed_x = self.recomputed[gid][holder.handles[gid]] |
|
if x_meta != self.metadata_fn(recomputed_x): |
|
nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x))) |
|
|
|
if len(nb_meta_different) > 0: |
|
mismatched_tensors = "" |
|
for idx, x_meta, recomputed_meta in nb_meta_different: |
|
mismatched_tensors += ( |
|
f"tensor at position {idx}:\n" |
|
f"saved metadata: {x_meta}\n" |
|
f"recomputed metadata: {recomputed_meta}\n" |
|
) |
|
raise CheckpointError( |
|
"torch.utils.checkpoint: Recomputed values for the following tensors " |
|
"have different metadata than during the forward pass.\n" |
|
f"{mismatched_tensors}" |
|
) |
|
|
|
|
|
_checkpoint_error_template = """ \ |
|
An error happened while unpacking tensors; dumping logs of the latest computation
|
because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`. |
|
Scroll all the way down for guidance on how to navigate these logs. |
|
|
|
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ |
|
| 1. Stack traces of the operators that ran in the original forward | |
|
+------------------------------------------------------------------------------+ |
|
|
|
{forward_traces} |
|
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ |
|
| 2. Stack traces of the operators that ran during recomputation | |
|
+------------------------------------------------------------------------------+ |
|
|
|
{recompute_traces} |
|
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ |
|
| 3. Log of operators in the original forward and recomputation | |
|
+------------------------------------------------------------------------------+ |
|
(Scroll up to correlate stack traces with each operation listed below. This |
|
helps identify their source in the code.) |
|
|
|
IMPORTANT: Differences in "detach" calls between the original forward and the |
|
recomputation are expected. They are introduced by the checkpointing |
|
mechanism and can be ignored. |
|
|
|
Operations executed during the original forward: |
|
|
|
{forward_ops} |
|
|
|
Operations executed during recomputation: |
|
|
|
{recompute_ops} |
|
|
|
+------------------------------------------------------------------------------+ |
|
ERROR: Detected non-determinism while running activation checkpointing |
|
|
|
You are seeing this error because you passed `debug=True` to checkpoint and
the tensors saved during the original forward differ from those saved during
recomputation. This can happen if different operators were run in the
original forward and in the recomputation.
|
|
|
To identify where the mismatch may be coming from, you can do the following: |
|
|
|
1) Compare the operators ran during original forward and recomputation to |
|
see where they differ. These operators are printed above in the order they |
|
were executed. |
|
|
|
2) Review the stack trace for each operator to locate its invocation source. |
|
Each operator's stack trace is printed in their execution order. |
|
|
|
Note that the logs can be quite long. Here's how they are structured: |
|
(Tip: you can Ctrl-f for these headers) |
|
|
|
1. Stack traces of the operators that ran in the original forward |
|
2. Stack traces of the operators that ran during recomputation |
|
3. Log of operators in the original forward and recomputation |
|
4. Error message <--- You are here |
|
-------------------------------------------------------------------------------- |
|
""" |
|
|
|
class CheckpointError(RuntimeError): |
|
pass |
|
|
|
|
|
def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]: |
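    # Returns a context_fn that logs ops (with stack traces) during both the
    # original forward and the recomputation, plus an error callback that
    # formats those logs into _checkpoint_error_template on mismatch.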
|
|
|
|
|
|
|
|
|
|
|
cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux' |
|
|
|
class CaptureLogs: |
|
def __init__(self): |
|
self.logs = None |
|
self.tbs = None |
|
|
|
def get_context_manager(self): |
|
@contextlib.contextmanager |
|
def logging_mode(): |
|
with LoggingTensorMode(), \ |
|
capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: |
|
self.logs, self.tbs = logs_and_tb |
|
yield logs_and_tb |
|
return logging_mode() |
|
|
|
capture_logs_fwd = CaptureLogs() |
|
capture_logs_recompute = CaptureLogs() |
|
|
|
def unpack_error_cb(e: CheckpointError): |
|
def get_str_tb(label, capture_logs): |
|
out = "" |
|
total_len = len(capture_logs.logs) |
|
for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)): |
|
out += f"{log} ({i + 1} of {total_len} in {label})\n\n" |
|
found_torch_dispatch = False |
|
for line in tb: |
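                    # Skip all frames up to and including __torch_dispatch__
                    # so the printed trace starts at user code.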
|
|
|
is_torch_dispatch = line['name'] == '__torch_dispatch__' |
|
if not found_torch_dispatch and not is_torch_dispatch: |
|
continue |
|
elif is_torch_dispatch: |
|
found_torch_dispatch = True |
|
continue |
|
out += f"{line['filename']}:{line['line']}:{line['name']}\n" |
|
out += "\n\n" |
|
return out |
|
assert capture_logs_fwd.logs is not None |
|
assert capture_logs_recompute.logs is not None |
|
raise CheckpointError( |
|
_checkpoint_error_template.format( |
|
forward_traces=get_str_tb("original", capture_logs_fwd), |
|
recompute_traces=get_str_tb("recompute", capture_logs_recompute), |
|
forward_ops="\n".join(capture_logs_fwd.logs), |
|
recompute_ops="\n".join(capture_logs_recompute.logs) |
|
) |
|
) from e |
|
|
|
def context_fn(): |
|
return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager() |
|
|
|
return context_fn, unpack_error_cb |
|
|
|
def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: |
|
|
|
return { |
|
"shape": x.shape, |
|
"dtype": x.dtype, |
|
"device": x.device |
|
} |
|
|
|
_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = { |
|
_DEFAULT_DETERMINISM_MODE: _default_meta_extractor, |
|
"none": lambda _: None, |
|
} |
|
|
|
|
|
class _StopRecomputationError(Exception): |
|
pass |
|
|
|
|
|
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): |
|
def __init__(self, target_frame_ref: ReferenceType, gid: int): |
|
def pack_hook(x): |
|
x = x.detach() if x.requires_grad else x |
|
target_frame = target_frame_ref() |
|
assert target_frame is not None |
|
recomp_idx = target_frame.recomp_counter[gid] |
|
target_frame.recomp_counter[gid] += 1 |
|
|
|
if recomp_idx >= len(target_frame.weak_holders): |
|
assert not target_frame.early_stop |
|
if not target_frame.forward_completed: |
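                    # grad was run inside the checkpointed region before its
                    # forward completed (only possible when early stop is
                    # disabled); tolerate the saved-tensor count mismatch in
                    # check_recomputed_tensors_match.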
|
|
|
|
|
|
|
|
|
|
|
target_frame.ignore_saved_mismatch = True |
|
return x |
|
raise CheckpointError( |
|
"torch.utils.checkpoint: trying to save more tensors during " |
|
"recomputation than during the original forward pass." |
|
) |
|
|
|
holder = target_frame.weak_holders[recomp_idx]() |
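
            # The holder may have been freed if the saved variable it backed
            # was already released; its recomputed value is then never needed.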
|
|
|
|
|
|
|
if holder is not None: |
|
_internal_assert(holder.handles.get(gid, None) is None) |
|
holder.handles[gid] = _Handle() |
|
target_frame.recomputed[gid][holder.handles[gid]] = x |
|
|
|
if target_frame.early_stop and target_frame.recomp_counter[gid] == len( |
|
target_frame.weak_holders |
|
): |
|
raise _StopRecomputationError |
|
|
|
return x |
|
|
|
def unpack_hook(x): |
|
|
|
|
|
return x |
|
|
|
super().__init__(pack_hook, unpack_hook) |
|
|
|
|
|
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): |
|
def __init__(self, frame): |
|
def pack_hook(x): |
|
|
|
holder = _Holder() |
|
frame.weak_holders.append(weakref.ref(holder)) |
|
|
|
if frame.metadata_fn is not None: |
|
with torch.no_grad(): |
|
frame.x_metadatas.append(frame.metadata_fn(x)) |
|
return holder |
|
|
|
def unpack_hook(holder): |
|
gid = torch._C._current_graph_task_id() |
|
if gid == -1: |
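                # We are unpacking outside of a backward call, e.g. via
                # ctx.saved_tensors in a custom autograd Function; generate a
                # temporary id to key the recomputation.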
|
|
|
gid = int(uuid.uuid4()) |
|
|
|
if not frame.is_recomputed[gid]: |
|
ctx = frame.input_saver.grad_fn |
|
args = ctx.get_args(ctx.saved_tensors) |
|
|
|
try: |
|
with _recomputation_hook( |
|
weakref.ref(frame), gid |
|
), torch.autograd.enable_grad(): |
|
frame.recompute_fn(*args) |
|
except _StopRecomputationError: |
|
pass |
|
frame.is_recomputed[gid] = True |
|
frame.check_recomputed_tensors_match(gid) |
|
|
|
_internal_assert(gid in holder.handles) |
|
|
|
if holder.handles[gid] is None: |
|
raise CheckpointError( |
|
"torch.utils.checkpoint: Unpack is being triggered for a tensor that was already " |
|
"unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do " |
|
"so only once. Otherwise please open an issue with details on your use case." |
|
) |
|
_internal_assert(holder.handles[gid] in frame.recomputed[gid]) |
|
ret = frame.recomputed[gid][holder.handles[gid]] |
|
holder.handles[gid] = None |
|
return ret |
|
|
|
if frame.unpack_error_cb is not None: |
|
def unpack_hook_with_error_cb(holder): |
|
try: |
|
return unpack_hook(holder) |
|
except CheckpointError as e: |
|
frame.unpack_error_cb(e) |
|
super().__init__(pack_hook, unpack_hook_with_error_cb) |
|
else: |
|
super().__init__(pack_hook, unpack_hook) |
|
|
|
|
|
def _is_compiling(func, args, kwargs): |
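    # Heuristic for detecting AOTAutograd/compile tracing: under tracing,
    # tensor args are functionalized.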
|
|
|
|
|
|
|
for arg in args: |
|
if isinstance(arg, torch.Tensor) and is_fun(arg): |
|
return True |
|
return False |
|
|
|
|
|
class _VersionWrapper: |
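    # Wraps a cached value together with the tensor's version counter at
    # save time, so that mutation of the cache entry before reuse can be
    # detected.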
|
|
|
def __init__(self, val): |
|
self.val: Union[torch.Tensor, Any] = val |
|
self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None |
|
|
|
def get_val(self, allow_cache_entry_mutation): |
|
if self.version is not None and not allow_cache_entry_mutation: |
|
if self.val._version != self.version: |
|
|
|
raise RuntimeError( |
|
"Tensor cached during selective activation checkpoint has been mutated" |
|
) |
|
return self.val |
|
|
|
|
|
def _maybe_detach(x, any_ret_has_alias_info): |
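    # Detach before caching: this avoids reference cycles and, for ops whose
    # returns may alias their inputs, strips autograd metadata that must not
    # be present when the cached value is replayed.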
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): |
|
with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): |
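            # Keep the ADInplaceOrView key in the dispatch set so that the
            # detach below still propagates version counters for views.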
|
|
|
|
|
|
|
|
|
|
|
x = x.detach() |
|
return x |
|
|
|
|
|
class SelectiveCheckpointContext: |
|
""" |
|
Context passed to policy function during selective checkpointing. |
|
|
|
This class is used to pass relevant metadata to the policy function during |
|
selective checkpointing. The metadata includes whether the current invocation |
|
of the policy function is during recomputation or not. |
|
|
|
Example: |
|
>>> # xdoctest: +SKIP(stub) |
|
>>> |
|
>>> def policy_fn(ctx, op, *args, **kwargs): |
|
>>> print(ctx.is_recompute) |
|
>>> |
|
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) |
|
>>> |
|
>>> out = torch.utils.checkpoint.checkpoint( |
|
>>> fn, x, y, |
|
>>> use_reentrant=False, |
|
>>> context_fn=context_fn, |
|
>>> ) |
|
""" |
|
def __init__(self, *, is_recompute): |
|
self.is_recompute = is_recompute |
|
|
|
|
|
class CheckpointPolicy(enum.Enum): |
|
""" |
|
Enum for specifying the policy for checkpointing during backpropagation. |
|
|
|
The following policies are supported: |
|
|
|
- ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward |
|
pass and will not be recomputed during the backward pass |
|
- ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the |
|
forward pass and will be recomputed during the backward pass |
|
|
|
Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden |
|
by other subsystems like `torch.compile`. |
|
|
|
.. note:: |
|
A policy function that always returns ``PREFER_RECOMPUTE`` is |
|
equivalent to vanilla checkpointing. |
|
|
|
        A policy function that returns ``PREFER_SAVE`` for every op is
|
NOT equivalent to not using checkpointing. Using such a policy would |
|
save additional tensors not limited to ones that are actually needed for |
|
gradient computation. |
|
""" |
|
MUST_SAVE = 0 |
|
PREFER_SAVE = 1 |
|
MUST_RECOMPUTE = 2 |
|
PREFER_RECOMPUTE = 3 |
|
|
|
|
|
def _policy_from_bool(b): |
|
|
|
return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE |
|
|
|
|
|
SAC_IGNORED_OPS = { |
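    # Checkpoint inserts a different number of detach calls between the
    # original forward and the recompute, so detach must stay out of the
    # save/replay pairing.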
|
|
|
torch.ops.aten.detach.default, |
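    # Metadata queries (prim.device and FunctionalTensor's metadata fns
    # below) may likewise run a different number of times between forward
    # and recompute.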
|
|
|
|
|
|
|
torch.ops.prim.device.default, |
|
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) |
|
|
|
|
|
class _CachingTorchDispatchMode(TorchDispatchMode): |
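    # Runs during the original forward: applies the policy to each op and
    # stashes the outputs of ops that should be saved into ``storage``.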
|
|
|
def __init__(self, policy_fn, storage): |
|
self.policy_fn = policy_fn |
|
self.storage = storage |
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
|
if func in SAC_IGNORED_OPS: |
|
return func(*args, **kwargs) |
|
|
|
kwargs = {} if kwargs is None else kwargs |
|
policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False), |
|
func, *args, **kwargs) |
|
if isinstance(policy, bool): |
|
policy = _policy_from_bool(policy) |
|
|
|
is_compiling = _is_compiling(func, args, kwargs) |
|
|
|
if is_compiling: |
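            # Record the policy on the FX node metadata so that the
            # compile-time partitioner can honor it.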
|
|
|
fx_traceback.current_meta["recompute"] = policy |
|
|
|
out = func(*args, **kwargs) |
|
|
|
any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) |
|
|
|
if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: |
|
self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) |
|
return out |
|
|
|
class _CachedTorchDispatchMode(TorchDispatchMode): |
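    # Runs during recomputation: replays saved outputs from ``storage``
    # instead of re-executing the ops the policy chose to save.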
|
|
|
def __init__(self, policy_fn, storage, allow_cache_entry_mutation): |
|
self.policy_fn = policy_fn |
|
self.storage = storage |
|
self.allow_cache_entry_mutation = allow_cache_entry_mutation |
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
|
if func in SAC_IGNORED_OPS: |
|
return func(*args, **kwargs) |
|
|
|
kwargs = {} if kwargs is None else kwargs |
|
policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True), |
|
func, *args, **kwargs) |
|
if isinstance(policy, bool): |
|
policy = _policy_from_bool(policy) |
|
|
|
is_compiling = _is_compiling(func, args, kwargs) |
|
|
|
if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: |
|
storage = self.storage.get(func) |
|
if storage is None: |
|
raise RuntimeError(f"{func} encountered during backward, but not found in storage") |
|
if len(storage) == 0: |
|
raise RuntimeError( |
|
"Trying to backward an extra time. You are only allowed to backward once " |
|
"on any region computed under selective activation checkpoint." |
|
) |
|
out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) |
|
else: |
|
out = func(*args, **kwargs) |
|
return out |
|
|
|
|
|
def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): |
|
""" |
|
Helper to avoid recomputing certain ops during activation checkpointing. |
|
|
|
Use this with `torch.utils.checkpoint.checkpoint` to control which |
|
operations are recomputed during the backward pass. |
|
|
|
Args: |
|
policy_fn_or_list (Callable or List): |
|
- If a policy function is provided, it should accept a |
|
:class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and |
|
kwargs to the op, and return a :class:`CheckpointPolicy` enum value |
|
indicating whether the execution of the op should be recomputed or not. |
|
- If a list of operations is provided, it is equivalent to a policy |
|
returning `CheckpointPolicy.MUST_SAVE` for the specified |
|
operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other |
|
operations. |
|
        allow_cache_entry_mutation (bool, optional): To ensure correctness, an
            error is raised by default if any tensor cached by selective
            activation checkpoint is mutated. If set to `True`, this check
            is disabled.
|
Returns: |
|
A tuple of two context managers. |
|
|
|
Example: |
|
>>> # xdoctest: +REQUIRES(LINUX) |
|
>>> import functools |
|
>>> |
|
>>> x = torch.rand(10, 10, requires_grad=True) |
|
>>> y = torch.rand(10, 10, requires_grad=True) |
|
>>> |
|
>>> ops_to_save = [ |
|
>>> torch.ops.aten.mm.default, |
|
>>> ] |
|
>>> |
|
>>> def policy_fn(ctx, op, *args, **kwargs): |
|
>>> if op in ops_to_save: |
|
>>> return CheckpointPolicy.MUST_SAVE |
|
>>> else: |
|
>>> return CheckpointPolicy.PREFER_RECOMPUTE |
|
>>> |
|
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) |
|
>>> |
|
>>> # or equivalently |
|
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) |
|
>>> |
|
>>> def fn(x, y): |
|
>>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y |
|
>>> |
|
>>> out = torch.utils.checkpoint.checkpoint( |
|
>>> fn, x, y, |
|
>>> use_reentrant=False, |
|
>>> context_fn=context_fn, |
|
>>> ) |
|
""" |
|
|
|
|
|
if isinstance(policy_fn_or_list, list): |
|
for op in policy_fn_or_list: |
|
if not isinstance(op, torch._ops.OpOverload): |
|
_extra_msg = ( |
|
"Please update the OpOverloadPacket to a specific OpOverload." |
|
"For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." |
|
) if isinstance(op, torch._ops.OpOverloadPacket) else "" |
|
raise ValueError( |
|
f"Expected op in `op_list` to be an OpOverload but got: {op} " |
|
f"of type {type(op)}. {_extra_msg}" |
|
) |
|
|
|
def policy_fn(ctx, op, *args, **kwargs): |
|
if op in policy_fn_or_list: |
|
return CheckpointPolicy.MUST_SAVE |
|
else: |
|
return CheckpointPolicy.PREFER_RECOMPUTE |
|
elif callable(policy_fn_or_list): |
|
policy_fn = policy_fn_or_list |
|
else: |
|
raise TypeError("policy_fn_or_list must be either a function or a list of ops.") |
|
|
|
storage: Dict[Any, List[Any]] = defaultdict(list) |
|
return ( |
|
_CachingTorchDispatchMode(policy_fn, storage), |
|
_CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), |
|
) |
|
|
|
|
|
|
|
|
|
def _checkpoint_without_reentrant_generator( |
|
fn, |
|
preserve_rng_state=True, |
|
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, |
|
determinism_check: str = _DEFAULT_DETERMINISM_MODE, |
|
debug: bool = False, |
|
*args, |
|
**kwargs |
|
): |
|
"""Checkpointing without reentrant autograd. |
|
|
|
Args: |
|
fn: describes what to run in the forward pass of the model or |
|
part of the model. It should also know how to handle the inputs |
|
            passed as the tuple. For example, in an LSTM, if the user passes
|
``(activation, hidden)``, :attr:`function` should correctly use the |
|
first input as ``activation`` and the second input as ``hidden`` |
|
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
|
Default: ``True`` |
|
context_fn(Callable, optional): A callable returning a tuple of two |
|
context managers. The function and its recomputation will be run |
|
under the first and second context managers respectively. |
|
determinism_check(str, optional): A string specifying the determinism |
|
check to perform. By default it is set to ``"default"`` which |
|
compares the shapes, dtypes, and devices of the recomputed tensors |
|
            against those of the saved tensors. To turn off this check, specify
|
``"none"``. Currently these are the only two supported values. |
|
Please open an issue if you would like to see more determinism |
|
checks. |
|
debug(bool, optional): If ``True``, error messages will also include |
|
            a trace of the operators run during the original forward computation
|
as well as the recomputation. |
|
*args: Arguments to pass in to the given ``function``. |
|
**kwargs: Keyword arguments to pass into the given ``function``. |
|
""" |
|
unpack_error_cb = None |
|
|
|
    enable_debug = debug if _checkpoint_debug_enabled is None else _checkpoint_debug_enabled
    if enable_debug:
|
if context_fn != noop_context_fn: |
|
raise ValueError( |
|
"debug=True is incompatible with non-default context_fn" |
|
) |
|
context_fn, unpack_error_cb = _get_debug_context_and_cb() |
|
|
|
if determinism_check in _allowed_determinism_checks_to_fns: |
|
metadata_fn = _allowed_determinism_checks_to_fns[determinism_check] |
|
else: |
|
raise ValueError( |
|
f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, " |
|
f"but got {determinism_check}" |
|
) |
|
|
|
device_type = _infer_device_type(*args) |
|
device_module = _get_device_module(device_type) |
|
forward_context, recompute_context = context_fn() |
|
if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn: |
|
assert ( |
|
isinstance(forward_context, TorchDispatchMode) and |
|
isinstance(recompute_context, TorchDispatchMode) |
|
), \ |
|
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \ |
|
"must generate a tuple of two `TorchDispatchMode`s." |
|
|
|
device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type) |
|
|
|
if preserve_rng_state: |
|
fwd_cpu_state = torch.get_rng_state() |
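        # Don't eagerly initialize the device context by accident: only stash
        # device RNG state if the device module was already initialized.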
|
|
|
|
|
|
|
|
|
|
|
had_device_in_fwd = False |
|
if getattr(device_module, "_initialized", False): |
|
had_device_in_fwd = True |
|
fwd_devices, fwd_device_states = get_device_states(*args) |
|
|
|
def recompute_fn(*inputs): |
|
kwargs, *args = inputs |
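
        # Stash the surrounding RNG state, mimic the state present during the
        # original forward, and restore the surrounding state afterwards.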
|
|
|
|
|
rng_devices = [] |
|
if preserve_rng_state and had_device_in_fwd: |
|
rng_devices = fwd_devices |
|
with torch.random.fork_rng( |
|
devices=rng_devices, enabled=preserve_rng_state, device_type=device_type |
|
): |
|
if preserve_rng_state: |
|
torch.set_rng_state(fwd_cpu_state) |
|
if had_device_in_fwd: |
|
set_device_states(fwd_devices, fwd_device_states, device_type=device_type) |
|
|
|
device_autocast_ctx = torch.amp.autocast( |
|
device_type=device_type, **device_autocast_kwargs |
|
) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext() |
|
with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context: |
|
fn(*args, **kwargs) |
|
|
|
new_frame = _CheckpointFrame( |
|
recompute_fn, |
|
_enable_checkpoint_early_stop, |
|
unpack_error_cb, |
|
metadata_fn |
|
) |
|
dummy = torch.empty((0,), requires_grad=True) |
|
new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args) |
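
    # If the ambient grad mode is off, input_saver has no grad_fn and nothing
    # will be saved for backward; there is nothing to checkpoint.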
|
|
|
|
|
if new_frame.input_saver.grad_fn is None: |
|
yield |
|
return |
|
|
|
with _checkpoint_hook(new_frame), forward_context: |
|
yield |
|
new_frame.forward_completed = True |
|
|
|
if getattr(device_module, "_initialized", False) and \ |
|
preserve_rng_state and not had_device_in_fwd: |
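        # The device was first initialized inside the checkpointed forward,
        # so its RNG state was never stashed.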
|
|
|
|
|
raise RuntimeError( |
|
"PyTorch's device state was initialized in the forward pass " |
|
"of a Checkpoint, which is not allowed. Please open an issue " |
|
"if you need this feature." |
|
) |
|
|
|
return |
|
|