import os
import re
import torch
import types
from typing import List, Tuple, Union
from dataclasses import dataclass

from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE)


@dataclass
class SubparamShape:
    patterns: List[str]
    shape: Tuple[Union[Tuple[int], int]]
    partition_dim: int


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
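    """Load this parameter's high-precision (hp) checkpoint state from ``folder``.

    Each ``<key>.pt`` file in ``folder`` holds the full, unpartitioned state for
    one key (the fp32 weight or an optimizer state tensor). The state is re-sliced
    for the target ``tp_rank``/``tp_world_size`` and only the fragment owned by
    this rank's low-precision parameter partition is loaded. Returns the saved
    optimizer step if a ``step.pt`` file is present, otherwise None.
    """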
    hp_mapping = self._hp_mapping
    hp_mapping.optim_fragment = {}

    # Collect the state keys saved for this parameter; each key (e.g. "fp32",
    # "exp_avg", "exp_avg_sq", "step") is stored in its own "<key>.pt" file.
    hp_keys = []
    for file in os.listdir(folder):
        pattern = r'(.+)\.pt$'
        match = re.search(pattern, file)
        if match:
            hp_keys.append(match.group(1))

    step = None
    for key in hp_keys:
        ckpt_file = os.path.join(folder, f"{key}.pt")
        ckpt_dict = torch.load(ckpt_file, weights_only=False)

        # The optimizer step count is stored as a plain value rather than a
        # tensor dict, so capture it and move on.
        if key == "step":
            step = ckpt_dict
            continue

        full_hp_param = ckpt_dict[PARAM]
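        # If the loaded full parameter already matches this parameter's target
        # shape, it was not partitioned across tensor-parallel ranks (e.g. a
        # replicated parameter), so treat it as a single-rank load below.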
        if full_hp_param.shape == self.shape:
            tp_rank = 0
            tp_world_size = 1
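        # Vocabulary (embedding) tensors are stored unpadded in the universal
        # checkpoint; re-pad them here to the vocab size implied by the target
        # shape and TP degree before slicing.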
        is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
        if is_vocab_tensor:
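            # The padded vocab size is not stored explicitly, so derive it from
            # the target per-rank shape multiplied by the TP world size.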
            padded_target_vocab_size = self.shape[0] * tp_world_size
            assert padded_target_vocab_size >= full_hp_param.shape[0], \
                f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
            if padded_target_vocab_size > full_hp_param.shape[0]:
                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                # Append zero rows along dim 0 (the vocab dimension).
                full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)

        full_param_numel = full_hp_param.numel()
        tp_slice_numel = self.numel()
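        # Sanity check: the full parameter must split evenly into tp_world_size
        # slices of this parameter's size.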
        assert full_param_numel == tp_world_size * tp_slice_numel, \
            f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'

        sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None)
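        # Undo the concatenation performed when the TP slices were merged into
        # the universal checkpoint. Three layouts are handled below:
        #   1. SUB_PARAM_SHAPE set: a packed parameter whose partition dim is
        #      made up of several sub-sections, each sliced per rank separately,
        #   2. PARAM_N_SUB_PARAMS > 1: a container of equal-sized sub-params,
        #      each split across TP ranks along CAT_DIM,
        #   3. otherwise: a plain chunk of the parameter along CAT_DIM.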
        chunk_dim = ckpt_dict.get(CAT_DIM, 0)
        n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
        if sub_param_shape:
            partition_dim = sub_param_shape.partition_dim
            sub_dim_sizes = sub_param_shape.shape[partition_dim]
            if not isinstance(sub_dim_sizes, tuple):
                sub_dim_sizes = (sub_dim_sizes, )

            partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape]
            full_hp_param = full_hp_param.view(partition_shape)

            # Slice each sub-section along the partition dim for this TP rank,
            # then re-assemble the rank-local parameter from those slices.
            offset = 0
            merged_chunks = []
            for sub_dim_size in sub_dim_sizes:
                sub_params_tp_slice = full_hp_param.narrow(partition_dim, offset, sub_dim_size).chunk(
                    tp_world_size, dim=partition_dim)[tp_rank]
                merged_chunks.append(sub_params_tp_slice)
                offset += sub_dim_size
            tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim)
        elif n_sub_params > 1:
            sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
            sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
            tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
        else:
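            # Plain case: the universal tensor is the concatenation of the TP
            # slices along chunk_dim, so take this rank's chunk.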
            tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]

        tp_hp_slice = tp_hp_slice.flatten()

        # Narrow the flattened TP slice down to the sub-range covered by this
        # rank's low-precision (lp) fragment.
        lp_frag_address = hp_mapping.lp_fragment_address
        tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
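        # The fp32 master weight is copied directly into the existing
        # high-precision fragment; all other keys (optimizer state) are staged
        # in optim_fragment for the optimizer to restore.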
        if key == FP32_WEIGHT_KEY:
            dst_tensor = hp_mapping.get_hp_fragment()
            assert dst_tensor.numel() == lp_frag_address.numel, \
                f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
            dst_tensor.data.copy_(tp_hp_fragment.data)
        else:
            assert tp_hp_fragment.numel() == lp_frag_address.numel, \
                f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'

            hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()

    return step


def enable_universal_checkpoint(param_list):
    # Attach load_hp_checkpoint_state to each parameter so that universal
    # checkpoint state can be loaded per-parameter via a bound method call.
    for param in param_list:
        param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, param)