import os

import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank


class DeepSpeedOptimizer(object):
    pass


class ZeROOptimizer(DeepSpeedOptimizer):

    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
        """Load the sharded high-precision (hp) optimizer state saved under
        ``<checkpoint_dir>/zero`` into the low-precision (lp) parameter groups
        stored on ``self`` under the attribute named ``lp_groups_name``."""
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
        assert os.path.isfile(
            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
        # Full optimizer state (not just tensors) is stored, so weights_only loading is disabled.
        optim_sd = torch.load(optim_state_path, weights_only=False)
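
        # Checkpoint layout read by this function (a sketch inferred from the
        # code below; exact contents depend on the saving optimizer):
        #   <checkpoint_dir>/zero/optimizer_state.pt  - global state, including
        #       a 'param_groups' list mirroring self.optimizer.param_groups
        #   <checkpoint_dir>/zero/<param_name>/       - per-parameter hp shards,
        #       one directory per entry in self.param_names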

        # Restore optimizer-wide state shared across all param groups.
        self._load_global_state(optim_sd)

        # Tensor-model-parallel coordinates of this rank, used to select the
        # matching shard of each parameter's hp state.
        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        if self.mpu is None:
            logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
            tp_world_size = 1
        else:
            tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
                else self.mpu.get_tensor_model_parallel_world_size()

        for i, (param_group,
                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):

            # Optimizer state keys and step counts collected across this group's params.
            opt_keys = set()
            steps = []

            lp_groups = getattr(self, lp_groups_name)
            for lp in lp_groups[i]:
                if lp._hp_mapping is not None:
                    # Load this parameter's shard of the hp state for our tp rank.
                    step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                       tp_world_size)
                    for key in lp._hp_mapping.get_optim_state_keys():
                        opt_keys.add(key)
                    steps.append(step)

            hp_param = param_group['params'][0]
            # Every param in the group must have been saved at the same step.
            assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
            if steps[0] is not None:
                self.optimizer.state[hp_param]['step'] = steps[0]

            # Stitch the per-parameter hp fragments back into the flat optimizer
            # state tensors keyed on the flat hp parameter.
            map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

            # Restore the group's hyperparameters (lr, weight decay, ...);
            # 'params' is skipped because the tensors were rebuilt above.
            for key, value in loaded_param_group.items():
                if key == 'params':
                    continue
                param_group[key] = value
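

# A minimal usage sketch (illustrative only, not part of this module). It
# assumes a hypothetical subclass that stores its lp params in `bf16_groups`
# and provides the attributes the loader above relies on (`optimizer`, `mpu`,
# `param_names`, `_load_global_state`):
#
#     class MyBF16Optimizer(ZeROOptimizer):
#         def __init__(self, torch_optimizer, mpu, bf16_groups, param_names):
#             self.optimizer = torch_optimizer  # wrapped torch.optim optimizer
#             self.mpu = mpu                    # model-parallel utils, or None
#             self.bf16_groups = bf16_groups    # lp params, one list per group
#             self.param_names = param_names    # {param: name} for shard paths
#
#         def _load_global_state(self, sd):
#             ...                               # restore optimizer-wide state
#
#     opt = MyBF16Optimizer(torch_optimizer, mpu, bf16_groups, param_names)
#     opt.load_hp_checkpoint_state_from_checkpoint_dir(
#         lp_groups_name="bf16_groups", checkpoint_dir="path/to/global_step100")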