import torch

from .constants import (BASE_OPTIMIZER_STATE, GROUP_PADDINGS, OPTIMIZER_STATE_DICT, PARTITION_COUNT)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)

GROUP_STATE_KEY = 'state'

class ZeROCheckpoint(object):
    """Indexes the ZeRO partition files under `dir` by 3D-parallel rank
    (pipeline, tensor, data) and merges their state on demand."""

    def __init__(self, dir):
        basic_folder_validation(dir)

        self.dir = dir
        self.file_list = get_zero_files(dir)
        self.num_files = len(self.file_list)
        assert self.num_files > 0, f'No ZeRO files found in {dir}'

        # The source topology is read from the checkpoint itself; the target
        # topology defaults to the source until reshape() is called.
        self.src_3d = get_model_3d_descriptor(dir)
        self.target_3d = model_3d_desc(pp_degree=self.src_3d.pp_degree,
                                       tp_degree=self.src_3d.tp_degree,
                                       dp_degree=self.src_3d.dp_degree)
        self._3d_file_map = self.src_3d.reshape(self.target_3d)
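
    # `_3d_file_map` maps a data-parallel index to a 2D (pp, tp) grid:
    # `_3d_file_map[dp_index].get_data(pp_index, tp_index)` yields the indices
    # into `self.file_list` that the (pp, tp, dp) rank should load and merge
    # (see get_file_indices_for_rank / get_files_for_rank below).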

    def get_src_world_size(self):
        return self.src_3d.world_size()

    def get_src_tp_degree(self):
        return self.src_3d.tp_degree

    def get_src_pp_degree(self):
        return self.src_3d.pp_degree

    def get_src_dp_degree(self):
        return self.src_3d.dp_degree

    def get_file_indices_for_rank(self, pp_index, tp_index, dp_index):
        assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}'
        dp_2d_map = self._3d_file_map[dp_index]
        return dp_2d_map.get_data(pp_index, tp_index)

    def get_files_for_rank(self, pp_index, tp_index, dp_index):
        file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index)
        return [self.file_list[idx] for idx in file_idx_list]

    def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=None, strip_tensor_paddings=True):
        # Use None instead of a mutable default argument.
        keys_to_ignore = keys_to_ignore or []
        state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
        merged_sd = None
        for state_file in state_file_list:
            sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False)
            for key in keys_to_ignore:
                sd.pop(key, None)

            if strip_tensor_paddings:
                self._strip_tensor_paddings(sd)

            if merged_sd is None:
                merged_sd = sd
            else:
                merged_sd = merge_state(merged_sd, sd)

        self._update_partition_count(merged_sd)
        if strip_tensor_paddings:
            self._clear_group_paddings(merged_sd)

        return merged_sd

    def print_3d_index_map(self, tag=None):
        if tag:
            print(f'3D index map: {tag}')
        for dp_index, _2d_map in enumerate(self._3d_file_map):
            _2d_map.print_data(f'dp = {dp_index}')

    def print_3d_file_map(self, tag=None):
        if tag:
            print(f'3D file map: {tag}')
        for dp_index, _2d_map in enumerate(self._3d_file_map):
            # pp_degree and tp_degree are integers, so iterate over range()
            # of each rather than the values themselves.
            for pp_index in range(_2d_map.pp_degree):
                for tp_index in range(_2d_map.tp_degree):
                    file_index_list = _2d_map.get_data(pp_index, tp_index)
                    file_list = [self.file_list[idx] for idx in file_index_list]
                    print(f'{pp_index}, {tp_index}, {dp_index} => {file_list}')

    def reshape(self, target_3d_desc: model_3d_desc):
        self.target_3d = target_3d_desc
        self._3d_file_map = self.src_3d.reshape(self.target_3d)

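    # Illustrative sketch (degrees below are hypothetical): re-index the
    # files for a new topology, e.g. collapsing to pure data parallelism:
    #
    #   ckpt.reshape(model_3d_desc(pp_degree=1, tp_degree=1, dp_degree=8))
    #   ckpt.print_3d_file_map(tag='after reshape')
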
    def _strip_tensor_paddings(self, sd):
        # ZeRO pads each flat partition so it divides evenly across the
        # data-parallel ranks; trim that trailing padding from the optimizer
        # state tensors so the merged state reflects true parameter counts.
        param_group_states = self._get_param_group_states(sd)
        if param_group_states is None:
            return

        group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
        if group_paddings is None:
            return

        for key, group_state in param_group_states.items():
            if group_paddings[key] == 0:
                continue
            for state_name, state_value in group_state.items():
                if state_name != "step" and torch.is_tensor(state_value):
                    raw_length = state_value.numel() - group_paddings[key]
                    group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
                else:
                    group_state[state_name] = state_value

    def _clear_group_paddings(self, sd):
        group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
        if group_paddings:
            num_groups = len(group_paddings)
            sd[OPTIMIZER_STATE_DICT][GROUP_PADDINGS] = [0] * num_groups

    def _get_optimizer_state(self, sd, state_key):
        optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
        if optimizer_state is None:
            return None

        return optimizer_state.get(state_key, None)

    def _get_param_group_states(self, sd):
        optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
        if optimizer_state is None:
            return None

        base_optimizer_state = optimizer_state.get(BASE_OPTIMIZER_STATE, None)
        if base_optimizer_state is None:
            return None

        return base_optimizer_state.get(GROUP_STATE_KEY, None)

    def _update_partition_count(self, sd):
        # After merging, the state is partitioned across the target DP degree,
        # so overwrite the per-group partition counts accordingly.
        partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT)
        if partition_counts:
            num_groups = len(partition_counts)
            sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree] * num_groups
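
# Usage sketch (illustrative, not part of the module API): given a
# hypothetical directory 'checkpoints/global_step1000' produced by a ZeRO
# training run, the partitioned optimizer state for a single 3D-parallel
# rank could be inspected and merged like this:
#
#   ckpt = ZeROCheckpoint('checkpoints/global_step1000')
#   print(f'source topology: pp={ckpt.get_src_pp_degree()}, '
#         f'tp={ckpt.get_src_tp_degree()}, dp={ckpt.get_src_dp_degree()}')
#   ckpt.print_3d_index_map(tag='source')
#   sd = ckpt.get_state_for_rank(pp_index=0, tp_index=0, dp_index=0)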