# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of https://github.com/pytorch/torchtune.

import warnings

import psutil
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks


class OffloadActivations(saved_tensors_hooks):
    """
    Context manager under which activation tensors created in the forward pass will be offloaded.

    Enable the memory efficiency technique of activation offloading, where activations bigger than `min_offload_size`
    bytes will be offloaded to CPU in the forward and brought back in the backward. This is in contrast to maintaining
    the activation on GPU VRAM throughout the program.

    This manager contains the option of using one additional CUDA stream to handle the communication between CUDA and
    CPU, which is intended to overlap with the default computation stream to improve runtime. We designed
    synchronization with a few heuristics for optimizing the tradeoff between runtime and memory usage.

    Args:
        use_pin_memory (`bool`, *optional*, defaults to `True`):
            Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor
            to be moved back onto the GPU more quickly, but is a limited resource.
        use_streams (`bool`, *optional*, defaults to `True`):
            Whether to use streams for performance optimization where the communications get overlapped with the
            computation. Requires a torch build after torch-2.5.0.
        min_offload_size (`int`, *optional*, defaults to `1024`):
            Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small,
            we do not want to waste bandwidth and resources moving it to CPU and back.
        max_fwd_stash_size (`int`, *optional*, defaults to `5`):
            Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
            the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
            more overlap between the communication and compute streams at the cost of increasing memory usage.
            Keeping alive fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime.

    Raises:
        ValueError: if `max_fwd_stash_size` is not at least `1`.

    Example:

    >>> with OffloadActivations():
    ...     outputs = model(inputs, labels=labels)
    ...     loss = outputs.loss
    ...     loss.backward()
    """
    def __init__(
        self,
        use_pin_memory: bool = True,
        use_streams: bool = True,
        min_offload_size: int = 1024,
        max_fwd_stash_size: int = 5,
    ) -> None:
        self.use_streams = use_streams

        self.min_tensor_size_bytes = min_offload_size  # we don't want to bother with small tensors
        self.tracker = {}  # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
        self.tensor_id = 0
        self.is_first_forward_call = True
        self.is_first_backward_call = True
        self.is_first_forward_pass = True

        # Managing cpu memory
        self.use_pin_memory = use_pin_memory
        self.virtual_memory_safe_pct = 60  # we should not exceed this percentage of memory

        self.accelerator_type = (
            torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
        )
        # NOTE: xpu doesn't have `default_stream` API, use `current_stream` instead
        self.s0 = (
            torch.xpu.current_stream() if self.accelerator_type == "xpu" else torch.cuda.default_stream()
        )  # comp stream

        # For streaming
        if self.use_streams:
            self.s1 = torch.Stream() if self.accelerator_type == "xpu" else torch.cuda.Stream()  # comms stream
            self.fwd_stash = {}  # tensor_id => (activation, ev1)
            if max_fwd_stash_size < 1:
                raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}")
            self.max_fwd_stash_size = max_fwd_stash_size

            self.bwd_tensor_stash = {}  # tensor_id => activation
            self.bwd_ev_stash = {}  # tensor_id => ev0
            self.curr_graph_id = None
            self.curr_autograd_node = None
        # -------- platform util functions -------- #
        def verify_sufficient_virtual_memory():
            curr_pct = get_cpu_ram_pct()
            if curr_pct > self.virtual_memory_safe_pct:
                warnings.warn(f"{curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used")

        def get_cpu_ram_pct() -> float:
            # get the percentage of memory used by the system
            return psutil.virtual_memory().percent

        def get_tensor_id() -> int:
            # create a unique id for each tensor we are managing
            self.tensor_id += 1
            return self.tensor_id

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            # get the number of bytes in a tensor, for memory management purposes
            return x.element_size() * x.nelement()  # x.element_size() * x._base_storage().nbytes()
        # -------- core pack / unpack work -------- #
        def pack_tensor(activation: torch.Tensor) -> int:
            # activations are passed in during forward pass - from here we take over and return a unique id
            if self.is_first_forward_call:
                if len(self.tracker) != 0:
                    raise ValueError("Backward pass should have cleared tracker of all tensors")

                # set training phase trackers
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            # query for basic tensor info
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # only offload large tensors if they are activations on an accelerator (our heuristic
            # for that is to check that they are not params or buffers)
            if (
                activation.device.type in ["cuda", "xpu"]
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer))
                )
            ):
                if self.use_streams:
                    # First, sync back and dereference previously offloaded tensors
                    # as the offloading should be done sufficiently long ago.
                    for id in list(self.fwd_stash.keys()):
                        if id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[id]
                        else:
                            break

                    # Sync in, offload, and add an event to sync back later
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
                with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
                    cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
                    cpu_tensor.copy_(activation, non_blocking=True)
                    self.tracker[tensor_id] = (
                        cpu_tensor,
                        True,  # True = (in future) modified
                    )

                if self.use_streams:
                    event = self.s1.record_event()

                    # Stash to keep activation alive til s1 is done
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not modified, tensor is as is

            return tensor_id
        def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            if unpack_tensor_id not in self.tracker:
                raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")

            maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
                maybe_accelerator_tensor = accelerator_tensor

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_accelerator_tensor
        def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                self.curr_graph_id = torch._C._current_graph_task_id()

                def wait_and_del_remaining_references() -> None:
                    for id in list(self.bwd_tensor_stash.keys()):
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                # Register a callback to the end of autograd to clean everything up
                torch.autograd.variable.Variable._execution_engine.queue_callback(wait_and_del_remaining_references)

                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            if unpack_tensor_id not in self.tracker:
                raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")

            maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # Get data on the current autograd node
                graph_id = torch._C._current_graph_task_id()
                node = torch._C._current_autograd_node()
                prev_node_ids = []

                # If we're on a new node, mark prev node's tensors to be freed later
                if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
                    self.curr_autograd_node = node
                    prev_node_ids = list(self.bwd_tensor_stash.keys())

                brought_back_from_cpu = True
                if unpack_tensor_id in self.fwd_stash:
                    maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0]
                    brought_back_from_cpu = False
                else:
                    # Kick off the process to bring tensors back
                    with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1):
                        accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
                        maybe_accelerator_tensor = accelerator_tensor

                    # Tell comp stream to wait for the info to be loaded before executing
                    self.s0.wait_stream(self.s1)

                    # Stash the tensor to keep memory alive until compute stream is complete
                    self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor

                # Note: [Track views of the unpacked]
                # Why do we get the use count of the unpacked tensor here? We want an
                # initial count to compare to later, during the post-hook of the
                # backward node, when we need to decide whether we're allowed to free
                # the tensor yet. In what obscure cases must we delay freeing the
                # tensor (and thus call record_stream)?
                # 1. Any of the outputs of the backward node is a view of the unpacked
                #    tensor.
                # 2. In the case that this unpacked tensor will be used in a
                #    checkpointed region, if one of the recomputed saved tensors ends
                #    up as a view of the unpacked tensor.
                # 3. The user abuses the system somehow and manually relies on the
                #    unpacked tensor to exist after the backward node has executed.
                storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata)

                def hook(outputs, inputs):
                    # create events for the current node inputs/outputs if they were streamed in
                    if brought_back_from_cpu:
                        # See Note: [Track views of the unpacked]
                        # If any of the outputs is a view of the tensor, OR if a view of
                        # the tensor has been saved as a part of checkpoint's recompute
                        # process, OR the user has manually kept a reference to the
                        # unpacked tensor, THEN the tensor might be used later and we
                        # cannot presume to delete it after only the current node is
                        # done! So we use our frenemy, record_stream, to ensure the
                        # Tensor stays unmessed with until it's done getting used in the
                        # compute stream (s0 here). Note that the con here is we introduce
                        # non-deterministic (thus higher) memory usage, but this case
                        # should not happen often.
                        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                        if torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) > storage_refcount:
                            unpacked_tensor.record_stream(self.s0)
                            del self.bwd_tensor_stash[unpack_tensor_id]
                        else:
                            event = self.s0.record_event()
                            self.bwd_ev_stash[unpack_tensor_id] = event

                    # if there are still things in the fwd_stash, get rid of them as we're in bwd now
                    for id in list(self.fwd_stash.keys()):
                        _, ev = self.fwd_stash[id]
                        self.s0.wait_event(ev)
                        del self.fwd_stash[id]

                    # wait on prev node's events and del those
                    for id in prev_node_ids:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                    return outputs

                node.register_hook(hook)

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_accelerator_tensor

        unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
        super().__init__(pack_tensor, unpack_tensor)


class NoOpManager(saved_tensors_hooks):
    """
    A `saved_tensors_hooks` manager used to disable any other `saved_tensors_hooks` manager applied before. This
    relies on the behavior that only the most recently registered `saved_tensors_hooks` manager will run.

    One example usage is to opt a local region of code out of activation offloading, which is usually applied
    globally to best track state.
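
    Example (illustrative sketch; `model.backbone` and `model.head` are placeholder submodules, and the enclosing
    `OffloadActivations` context is assumed to be active):

    >>> with OffloadActivations():
    ...     hidden = model.backbone(inputs)
    ...     with NoOpManager():
    ...         # the head's activation is needed again almost immediately, so skip offloading it
    ...         logits = model.head(hidden)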
""" | |
def __init__(self) -> None: | |
def noop(tensor): | |
return tensor | |
super().__init__(noop, noop) | |


def get_act_offloading_ctx_manager(
    model: nn.Module,
    use_pin_memory: bool = True,
    use_streams: bool = True,
    min_offload_size: int = 1024,
    max_fwd_stash_size: int = 5,
    warn_if_no_head: bool = True,
) -> OffloadActivations:
    """
    Returns the activation offloading context manager for the model. All but the last output Linear in every step
    will be offloaded.

    The returned `OffloadActivations` context manager offloads eligible activations to the CPU. A `NoOpManager` is
    registered around the detected output head (and any Liger modules) so that their activations are not offloaded.

    Args:
        model (`nn.Module`):
            Model to wrap with the activation offloading context manager.
        use_pin_memory (`bool`, *optional*, defaults to `True`):
            Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor
            to be moved back onto the GPU more quickly, but is a limited resource.
        use_streams (`bool`, *optional*, defaults to `True`):
            Whether to use streams for performance optimization where the communications get overlapped with the
            computation. Requires a torch build after torch-2.5.0.
        min_offload_size (`int`, *optional*, defaults to `1024`):
            Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small,
            we do not want to waste bandwidth and resources moving it to CPU and back.
        max_fwd_stash_size (`int`, *optional*, defaults to `5`):
            Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
            the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
            more overlap between the communication and compute streams at the cost of increasing memory usage.
            Keeping alive fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime.
        warn_if_no_head (`bool`, *optional*, defaults to `True`):
            Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output
            head is detected.

    Returns:
        `OffloadActivations`:
            Activation offloading context manager for the model.
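
    Example (illustrative sketch; `model`, `inputs`, and `labels` are placeholder names):

    >>> ctx = get_act_offloading_ctx_manager(model)
    >>> with ctx:
    ...     # forward pass offloads eligible activations to CPU
    ...     loss = model(inputs, labels=labels).loss
    >>> loss.backward()  # activations are brought back during the backward pass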
""" | |
activations_handling_ctx = OffloadActivations( | |
use_pin_memory=use_pin_memory, | |
use_streams=use_streams, | |
min_offload_size=min_offload_size, | |
max_fwd_stash_size=max_fwd_stash_size, | |
) | |
# Below is our hack to disable offloading the last output Linear in every | |
# step, as the cost for offloading the activation and then soon after bringing | |
# it back is expensive. | |
output_head_detected = False | |
noop_ctx = NoOpManager() | |
# Try to get the actual model if it's wrapped | |
unwrapped_model = model | |
if hasattr(unwrapped_model, "module"): | |
unwrapped_model = unwrapped_model.module | |
# check for PEFT models | |
if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model, "peft_config"): | |
unwrapped_model = unwrapped_model.base_model | |
# Check for different types of output heads | |
if hasattr(unwrapped_model, "output"): | |
if isinstance(unwrapped_model.output, nn.Module): | |
unwrapped_model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) | |
unwrapped_model.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) | |
output_head_detected = True | |
elif hasattr(unwrapped_model.output, "linear") and isinstance(unwrapped_model.output.linear, nn.Module): | |
unwrapped_model.output.linear.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) | |
unwrapped_model.output.linear.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) | |
output_head_detected = True | |
# Check for HuggingFace model output heads | |
elif hasattr(unwrapped_model, "lm_head"): | |
unwrapped_model.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) | |
unwrapped_model.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) | |
output_head_detected = True | |
# Check for decoder-based models | |
elif hasattr(unwrapped_model, "decoder"): | |
decoder = unwrapped_model.decoder | |
if hasattr(decoder, "output"): | |
decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) | |
decoder.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) | |
output_head_detected = True | |
# Some models have lm_head in the decoder | |
elif hasattr(decoder, "lm_head"): | |
decoder.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) | |
decoder.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) | |
output_head_detected = True | |
# Check for transformer models with final layer norm | |
elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(unwrapped_model, "ln_f"): | |
final_norm = getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f | |
final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) | |
final_norm.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) | |
output_head_detected = True | |
# Check for models with head module | |
elif hasattr(unwrapped_model, "head") and isinstance(unwrapped_model.head, nn.Module): | |
unwrapped_model.head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) | |
unwrapped_model.head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) | |
output_head_detected = True | |
if not output_head_detected and warn_if_no_head: | |
warnings.warn( | |
"During activation offloading, no output head was detected. If your model has an output head, it will be " | |
"offloaded. This usually greatly slows training, given the large vocabulary size. To change this " | |
"behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by " | |
"passing `warn_if_no_head=False`." | |
) | |
# Disable offloading for any Liger modules | |
for name, module in unwrapped_model.named_modules(): | |
if "liger" in name.lower(): | |
module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) | |
module.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) | |
return activations_handling_ctx | |