from typing import Set

import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum


def _make_offload_state_key(key):
    # Key under which the offload buffer for `key` is cached in the optimizer state dict.
    return f"{key}_offload_buffer"


def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
    """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""
    def move_key(state, key):
        offload_buf_key = _make_offload_state_key(key)
        if offload_buf_key not in state:
            state[offload_buf_key] = torch.empty_like(state[key], device=device)
            if pin_memory:
                state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key])
        state[offload_buf_key].copy_(state[key], non_blocking=non_blocking)
        state[key].data = state[offload_buf_key]

    for _, state in optimizer.state.items():
        if "exp_avg" in state:
            move_key(state, "exp_avg")
        if "exp_avg_sq" in state:
            move_key(state, "exp_avg_sq")


def reload_adam_states(optimizer, device, non_blocking: bool = False):
    """Move offloaded optimizer states back to device. Note that this assumes the state structure of DeepSpeed Adam."""
    def move_back_key(state, key):
        state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking)

    for _, state in optimizer.state.items():
        if "exp_avg" in state:
            move_back_key(state, "exp_avg")
        if "exp_avg_sq" in state:
            move_back_key(state, "exp_avg_sq")


def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
    """Retrieve the devices of the specified state of the model.

    Args:
        model (DeepSpeedEngine): The model whose device allocations are to be checked.
        state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.

    Returns:
        Set[torch.device]: A set of devices of the specified state.
    """
    if state == OffloadStateTypeEnum.hp_params:
        return set(model.optimizer.get_hp_param_device(p) for p in model.parameters())
    elif state == OffloadStateTypeEnum.lp_params:
        return set(p.ds_tensor.device for p in model.parameters())
    elif state == OffloadStateTypeEnum.lp_grads:
        return {model.optimizer.grad_partitions_flat_buffer.device}
    elif state == OffloadStateTypeEnum.optim_states:
        return set(model.optimizer.get_hp_param_device(p, "exp_avg") for p in model.parameters()) | \
               set(model.optimizer.get_hp_param_device(p, "exp_avg_sq") for p in model.parameters())
    elif state == OffloadStateTypeEnum.contiguous_grad_buffer:
        # The IPG bucket buffer is accessed via its name-mangled private attribute on
        # DeepSpeedZeroOptimizer_Stage3; it may be None when no buffer is currently allocated.
        if model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer is None:
            return set()
        return {model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer.device}
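

# Illustrative sketch (hypothetical helper, not part of this module): report where each
# offloadable state of a DeepSpeedEngine currently lives, e.g. to verify that an
# offload/reload round trip landed the states on the expected device.
def _example_report_state_devices(engine):
    for state_type in (OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_params,
                       OffloadStateTypeEnum.optim_states):
        devices = get_state_devices(engine, state_type)
        print(f"{state_type.name}: {sorted(str(d) for d in devices)}")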