from contextlib import contextmanager
from dataclasses import dataclass
import functools
import threading
from typing import Any, Dict, Generator, Optional, Tuple
import weakref

import torch
from torch import Tensor
import torch.nn as nn
import torch.utils.checkpoint as torch_checkpoint

from fairscale.internal.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors

from .checkpoint_utils import patch_batchnorm


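# Per-thread checkpointing state: whether the current thread is inside a
# checkpointed forward pass, inside a recomputation pass, or inside a region
# where checkpointing has been explicitly disabled.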
@dataclass
class ThreadLocalCheckpointingState(threading.local):
    is_checkpointing: bool = False
    is_recomputing: bool = False
    is_checkpointing_disabled: bool = False


thread_local = ThreadLocalCheckpointingState()


@contextmanager
def disable_checkpointing() -> Generator[None, None, None]:
    """Makes :func:`is_checkpointing_disabled` return :data:`True` within a context."""
    orig = thread_local.is_checkpointing_disabled
    thread_local.is_checkpointing_disabled = True
    try:
        yield
    finally:
        thread_local.is_checkpointing_disabled = orig


@contextmanager
def enable_checkpointing() -> Generator[None, None, None]:
    """Makes :func:`is_checkpointing` return :data:`True` within a context."""
    orig = thread_local.is_checkpointing
    thread_local.is_checkpointing = True
    try:
        yield
    finally:
        thread_local.is_checkpointing = orig


@contextmanager
def enable_recomputing() -> Generator[None, None, None]:
    """Makes :func:`is_recomputing` return :data:`True` within a context."""
    orig = thread_local.is_recomputing
    thread_local.is_recomputing = True
    try:
        yield
    finally:
        thread_local.is_recomputing = orig


def is_checkpointing() -> bool:
    """Whether the current forward propagation is under checkpointing.

    Returns:
        bool: :data:`True` if it's under checkpointing.
    """
    return thread_local.is_checkpointing


def is_recomputing() -> bool:
    """Whether the current forward propagation is under checkpoint
    recomputation. Use this to prevent duplicated side-effects at forward
    propagation::

        class Counter(nn.Module):
            def __init__(self):
                super().__init__()
                self.counter = 0

            def forward(self, input):
                if not is_recomputing():
                    self.counter += 1
                return input

    Returns:
        bool: :data:`True` if it's under checkpoint recomputation.
    """
    return thread_local.is_recomputing


def checkpoint_wrapper(
    module: nn.Module,
    offload_to_cpu: bool = False,
) -> nn.Module:
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:

    - wraps an nn.Module, so that all subsequent calls will use checkpointing
    - handles keyword arguments in the forward
    - handles non-Tensor outputs from the forward
    - supports offloading activations to CPU

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))

    To understand the benefits of checkpointing and the `offload_to_cpu` flag,
    let's divide activations into 2 types: inner activations and outer
    activations w.r.t. the checkpointed modules. The inner ones are saved
    by activation checkpointing, the outer ones are saved by offload_to_cpu.

    In terms of GPU memory savings:

        - When inner ones are large in size and outer ones are small,
          checkpointing helps a lot, offload_to_cpu may help a little.
        - When inner ones are small and outer ones are large,
          checkpointing helps little, offload_to_cpu helps a lot.
        - When both inner and outer are large, both help and the
          benefit is additive.

    .. Note::

        The first and last layers are not likely to benefit from the `offload_to_cpu` flag
        because (1) there are typically other references to the first layer's input, so
        the GPU memory won't be freed; (2) the input to the last layer is immediately
        used by the backward pass and won't result in memory savings.

    Args:
        module (nn.Module):
            The module to be wrapped
        offload_to_cpu (bool):
            Whether to offload activations to CPU.

    Returns:
        (nn.Module):
            Wrapped module
    """

    patch_batchnorm(module)

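    # Patch the forward on this instance with a checkpointed version. The original
    # (unbound) forward and a weak reference to the module are bound with
    # functools.partial; using a weak reference avoids a module -> module.forward ->
    # module reference cycle that would keep the module (and its GPU memory) alive.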
    module.forward = functools.partial(
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module


def _checkpointed_forward(
    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    module = weak_self()

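    # If gradients are disabled (e.g. during evaluation) or checkpointing has been
    # explicitly disabled, run the original forward directly without checkpointing.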
    if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
        return original_forward(module, *args, **kwargs)

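    # Autograd Functions only take positional arguments, so flatten the keyword
    # arguments into a positional list and remember their keys so the original call
    # signature can be reconstructed inside the checkpoint.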
    args = (module,) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }

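    # The empty tensor with requires_grad=True ensures that backward() is invoked even
    # when no user input requires grad. A side effect is that the checkpointed output
    # always reports requires_grad=True, so the real value is recorded in
    # parent_ctx_dict["output_requires_grad"] and the output is detached below when
    # grads are not actually needed.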
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    output_requires_grad = parent_ctx_dict["output_requires_grad"]
    if not isinstance(output, torch.Tensor):
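        # Detach each tensor when the output should not require grad; otherwise the
        # dummy input above would force requires_grad=True on everything.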
        output = [x.detach() if not output_requires_grad else x for x in output]

        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)

    else:
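        # Same as above: detach the tensor output when grads are not actually needed.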
        if not output_requires_grad:
            output = output.detach()

    return output


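# Helpers to capture and restore RNG state so that random operations (e.g. dropout)
# behave identically during the original forward pass and the recomputation that
# happens in backward.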
def get_rng_state() -> Dict[str, Any]:
    state = {"torch_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state: Dict[str, Any]) -> None:
    torch.set_rng_state(state["torch_rng_state"])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state["cuda_rng_state"])


def is_autocast_enabled() -> bool:
    """Similar to torch.is_autocast_enabled, but compatible with torch 1.5.1"""
    if hasattr(torch, "is_autocast_enabled"):
        return torch.is_autocast_enabled()
    return False


@contextmanager
def autocast(enabled: bool) -> Generator:
    """Similar to torch.cuda.amp.autocast, but compatible with torch 1.5.1"""
    if enabled:
        with torch.cuda.amp.autocast(enabled):
            yield
    else:
        yield


class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but supports non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """

    @staticmethod
    def forward(
        ctx: Any,
        dummy_tensor_requires_grad: torch.Tensor,
        run_function: Any,
        parent_ctx_dict: Dict[str, Any],
        kwarg_keys: Tuple[str, ...],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        torch_checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state()
        ctx.had_autocast_in_fwd = is_autocast_enabled()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad(), enable_checkpointing():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)
            the_module = unpacked_args[0]

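            # We ran under torch.no_grad(), so outputs.requires_grad cannot be trusted
            # here. Compute whether the output should require grad from the module's
            # parameters and the tensor inputs instead.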
            parameters = list(the_module.parameters())

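            # Modules whose parameters have been flattened (e.g. fairscale's
            # FlattenParamsWrapper) expose the original parameters as views; include
            # them so requires_grad is detected correctly.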
            if hasattr(the_module, "_unflattened_param_views"):
                parameters += the_module._unflattened_param_views

            output_requires_grad = any(param.requires_grad for param in parameters) or any(
                x.requires_grad for x in tensor_inputs
            )
            parent_ctx_dict["output_requires_grad"] = output_requires_grad

            if not isinstance(outputs, torch.Tensor):
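                # Autograd Functions don't accept non-Tensor outputs, so split them out
                # and hand the non-Tensor part back to the caller through
                # *parent_ctx_dict*, returning only the tensors here.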
                outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
                parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs

        return outputs

    @staticmethod
    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            tensor_inputs = tuple(t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs))
            for i, need_grad in enumerate(ctx.grad_requirements):
                tensor_inputs[i].requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

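        # Capture the RNG state at the start of backward so it can be restored after
        # the recomputation below.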
        bwd_rng_state = get_rng_state()

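        # Restore the RNG state saved during the original forward pass so the
        # recomputation reproduces the same random operations.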
        set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
            tensor_outputs, _ = split_non_tensors(outputs)

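        # Put the RNG state back to what it was when backward() started.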
        set_rng_state(bwd_rng_state)

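        # Run backward only on the recomputed tensor outputs that require grad, paired
        # with their incoming gradients.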
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])

        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, this checkpoint() is not necessary")

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)

        return (None, None, None, None) + grads