jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Set
import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum
def _make_offload_state_key(key):
return f"{key}_offload_buffer"
def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
    """Copy the Adam moment tensors of ``optimizer`` into buffers on ``device``.

    Note that this assumes the state structure of DeepSpeed Adam: only the
    ``"exp_avg"`` and ``"exp_avg_sq"`` entries of each per-parameter state are
    handled, and entries that are absent are skipped.

    Args:
        optimizer: Optimizer whose ``state`` mapping is walked.
        device: Device on which the offload buffers are allocated.
        pin_memory (bool): If True, a newly allocated buffer is passed through
            the accelerator's ``pin_memory``. Only applied when the buffer is
            first created.
        non_blocking (bool): Forwarded to ``Tensor.copy_``.
    """
    moment_names = ("exp_avg", "exp_avg_sq")

    def _stash(param_state, name):
        buf_name = _make_offload_state_key(name)
        # The buffer is allocated on first use and kept in the state dict, so
        # subsequent offloads reuse the same storage instead of reallocating.
        if buf_name not in param_state:
            param_state[buf_name] = torch.empty_like(param_state[name], device=device)
            if pin_memory:
                param_state[buf_name] = get_accelerator().pin_memory(param_state[buf_name])
        buffer = param_state[buf_name]
        buffer.copy_(param_state[name], non_blocking=non_blocking)
        # Re-point the state tensor at the buffer; the tensor object itself is
        # kept so existing references to it remain valid.
        param_state[name].data = buffer

    for param_state in optimizer.state.values():
        for name in moment_names:
            if name in param_state:
                _stash(param_state, name)
def reload_adam_states(optimizer, device, non_blocking: bool = False):
    """Move the offloaded Adam moment tensors of ``optimizer`` back to ``device``.

    Note that this assumes the state structure of DeepSpeed Adam: for every
    per-parameter state, the ``"exp_avg"`` and ``"exp_avg_sq"`` tensors (when
    present) are re-pointed at a copy of their offload buffer on ``device``.
    The offload buffer entry itself is left in the state.

    Args:
        optimizer: Optimizer whose ``state`` mapping is walked.
        device: Destination device for the moment tensors.
        non_blocking (bool): Forwarded to ``Tensor.to``.

    Note:
        The offload buffer is looked up unconditionally, so a state that holds
        a moment tensor but no matching offload buffer makes the lookup fail.
    """
    for param_state in optimizer.state.values():
        for name in ("exp_avg", "exp_avg_sq"):
            if name not in param_state:
                continue
            stashed = param_state[_make_offload_state_key(name)]
            param_state[name].data = stashed.to(device, non_blocking=non_blocking)
def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
    """Retrieve the devices of the specified state of the model.

    Args:
        model (DeepSpeedEngine): The model whose device allocations are to be checked.
        state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.

    Returns:
        Set[torch.device]: A set of devices of the specified state. For
        ``contiguous_grad_buffer`` the set is empty when the optimizer currently
        holds no bucket buffer.

    Note:
        A ``state`` that matches none of the handled enum members falls through
        and the function implicitly returns ``None``.
    """
    if state == OffloadStateTypeEnum.hp_params:
        return {model.optimizer.get_hp_param_device(p) for p in model.parameters()}

    if state == OffloadStateTypeEnum.lp_params:
        return {p.ds_tensor.device for p in model.parameters()}

    if state == OffloadStateTypeEnum.lp_grads:
        return {model.optimizer.grad_partitions_flat_buffer.device}

    if state == OffloadStateTypeEnum.optim_states:
        # Both Adam moments are inspected; the union covers the case where they
        # live on different devices.
        exp_avg_devices = {model.optimizer.get_hp_param_device(p, "exp_avg") for p in model.parameters()}
        exp_avg_sq_devices = {model.optimizer.get_hp_param_device(p, "exp_avg_sq") for p in model.parameters()}
        return exp_avg_devices | exp_avg_sq_devices

    if state == OffloadStateTypeEnum.contiguous_grad_buffer:
        # The attribute is spelled with its name-mangled form because it is
        # accessed from outside the optimizer class.
        ipg_buffer = model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer
        if ipg_buffer is None:
            # `{}` would be an empty dict, not a set; return a real empty set
            # so the result matches the declared return type.
            return set()
        return {ipg_buffer.device}