|
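"""
Convert a DeepSpeed checkpoint (ZeRO stage 1/2, optionally with tensor/pipeline
parallelism, or ZeRO stage 3) into DeepSpeed's universal checkpoint layout:
per-parameter fp32/exp_avg/exp_avg_sq slice files under <output_folder>/zero plus
the common (unsharded) optimizer state.

Example invocation (the script file name and the paths are illustrative only):

    python ds_to_universal.py \
        --input_folder  /path/to/checkpoints/global_step1000 \
        --output_folder /path/to/checkpoints/global_step1000_universal
"""
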
from functools import partial
from itertools import chain
import argparse
import glob
import itertools
import math
from concurrent.futures import ProcessPoolExecutor
import os
import re
import shutil
import torch
import tqdm
|
from deepspeed.checkpoint import DeepSpeedCheckpoint |
|
from deepspeed.checkpoint import (
    OPTIMIZER_STATE_DICT,
    ZERO_STAGE,
    BASE_OPTIMIZER_STATE,
    SINGLE_PARTITION_OF_FP32_GROUPS,
    PARAM_GROUPS,
    PARAM_SLICE_MAPPINGS,
    PARAM_SHAPES,
    PARAM,
    CAT_DIM,
    PARAM_N_SUB_PARAMS,
    SUB_PARAM_SHAPE,
    VOCAB_TENSOR,
    UNIVERSAL_CHECKPOINT_INFO,
    UNIVERSAL_CHECKPOINT_VERSION_KEY,
    UNIVERSAL_CHECKPOINT_VERSION_VALUE,
    VOCABULARY_PARAMETER_PATTERNS,
    PIPELINE_REPLICATED_PARAMETER_PATTERNS,
    TP_REPLICATED_PARAMETER_PATTERNS,
    PARAMETER_TO_AVERAGE_PATTERNS,
    PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
    PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
    PARAMETER_WITH_SUB_PARAMS,
    SubparamShape,
)
|
|
|
|
|
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed checkpoint folder')
    parser.add_argument('--output_folder', type=str, required=True, help='Output universal checkpoint folder')
    parser.add_argument('--num_extract_workers',
                        default=4,
                        type=int,
                        help='How many parallel processes to extract zero shards')
    parser.add_argument(
        '--num_merge_workers',
        default=2,
        type=int,
        help='How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers)')
    parser.add_argument('--keep_temp_folder',
                        action='store_true',
                        help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
    parser.add_argument('--no_strict',
                        dest='strict',
                        action='store_false',
                        help='Do not perform validity checks on converted checkpoint.')
    parser.add_argument('--inject_missing_state',
                        action='store_true',
                        help='Inject missing checkpoint state into the checkpoint if it is absent.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args
|
|
|
|
|
def atoi(text): |
|
return int(text) if text.isdigit() else text |
|
|
|
|
|
def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]
|
|
|
|
|
def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree): |
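    '''Return a [tp][pp]-indexed nested list of Megatron-style checkpoint paths of the form
    <base_folder>/iter_XXXXXXX/mp_rank_XX[_YYY]/model_optim_rng.pt for the given iteration.'''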
|
path_list = [] |
|
iter_folder = f'iter_{iteration:07d}' |
|
for i in range(0, tp_degree): |
|
path_list.append([]) |
|
for j in range(0, pp_degree): |
|
rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}' |
|
ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt') |
|
path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path)) |
|
|
|
return path_list |
|
|
|
|
|
def _save_checkpoint(file_path, chkpt_sd): |
|
dir, _ = os.path.split(file_path) |
|
os.makedirs(dir, exist_ok=True) |
|
torch.save(chkpt_sd, file_path) |
|
|
|
|
|
def extract_zero_shards(dir, ds_checkpoint, indices_3D): |
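    '''Extract one ZeRO stage 1/2 shard: for the given (pp, tp, dp) rank, slice the flat
    exp_avg/exp_avg_sq/fp32 (and optional step) state into per-parameter fragments and dump
    them to disk. Pipeline-replicated parameters are only taken from pp rank 0.'''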
|
pp_index, tp_index, dp_index = indices_3D |
|
sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index) |
|
|
|
|
|
|
|
optim_sd = sd[OPTIMIZER_STATE_DICT] |
|
param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS] |
|
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) |
|
pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, []) |
|
|
|
|
|
|
|
state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"] |
|
|
|
fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS] |
|
param_groups_cnt = len(state_groups) |
|
|
|
for param_group_id in range(param_groups_cnt): |
|
|
|
        flat_state = dict(
            exp_avg=state_groups[param_group_id]["exp_avg"],
            exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
            fp32=fp32_groups[param_group_id],
        )
|
|
|
if "step" in state_groups[param_group_id]: |
|
flat_state["step"] = state_groups[param_group_id]["step"] |
|
|
|
for name, fragment_mapping in param_slice_mappings[param_group_id].items(): |
|
if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params): |
|
|
|
continue |
|
|
|
|
|
for state_key in flat_state.keys(): |
|
                dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name,
                                    fragment_mapping.start, fragment_mapping.numel)
|
|
|
|
|
def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index): |
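    '''Extract one ZeRO stage 3 shard: walk the flat partition owned by dp_index and dump each
    parameter's exp_avg/exp_avg_sq/fp32 fragment (excluding alignment padding) to disk.'''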
|
state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False) |
|
|
|
    flat_state = dict(
        exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
        exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
        fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
    )
|
|
|
offset = 0 |
|
for name, shape in param_shapes.items(): |
|
unpartitioned_numel = shape.numel() |
|
partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree) |
|
padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel)) |
|
for state_key in flat_state.keys(): |
|
            dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
                                padding_free_numel)
|
offset += partitioned_numel |
|
|
|
|
|
# Global counter of dumped fragments (incremented in dump_param_fragment; kept for debugging).
cnt = 0
|
|
|
|
|
def dp_index_to_str(dp_index): |
|
return f"{dp_index:0>2d}" |
|
|
|
|
|
def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): |
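    '''Save the [offset, offset + numel) fragment of a flat state tensor to
    <dir>/<param_name>/<tp_index>/<state_name>.<dp_index>.'''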
|
|
|
global cnt |
|
|
|
param_base_path = os.path.join(dir, param_name, str(tp_index)) |
|
os.makedirs(param_base_path, exist_ok=True) |
|
|
|
cnt += 1 |
|
|
|
path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}") |
|
|
|
|
|
|
|
|
|
    # The optimizer "step" state may be a scalar rather than a flat tensor; only slice tensor states.
    if state_name != "step" and torch.is_tensor(state_flat_tensor):
        state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
|
_save_checkpoint(path, state_flat_tensor) |
|
|
|
|
|
def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None): |
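    '''For each tp index, load this parameter's per-dp-rank shard files of `state`,
    concatenate them in dp order (reshaping to slice_shape if given), and return the
    resulting list of per-tp slices. For "step", all shards must agree and the scalar is returned.'''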
|
slices = [] |
|
for tp_index in range(tp_degree): |
|
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") |
|
paths = glob.glob(f"{prefix_path}.*") |
|
|
|
if len(paths) == 0: |
|
continue |
|
|
|
pattern = re.compile(f"{prefix_path}\\.([0-9]+)") |
|
dp_indices = set() |
|
for p in paths: |
|
m = pattern.match(p) |
|
if m: |
|
dp_indices.add(int(m.group(1))) |
|
else: |
|
raise ValueError(f"Cannot parse dp_rank from {p}") |
|
|
|
paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))] |
|
shards = [torch.load(p, weights_only=False) for p in paths] |
|
|
|
if state == "step": |
|
assert all(v == shards[0] for v in shards), "All shards must have the same step value" |
|
slice = shards[0] |
|
else: |
|
if slice_shape is None: |
|
slice = torch.cat(shards, dim=0) |
|
else: |
|
slice = torch.cat(shards, dim=0).reshape(slice_shape) |
|
|
|
slices.append(slice) |
|
return slices |
|
|
|
|
|
def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape): |
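    '''Merge the tp slices of a single parameter into its full tensor, dispatching on the
    universal-checkpoint patterns (tp-replicated, averaged, row-parallel, 2-sub-params,
    generic sub-params, vocabulary), strip vocabulary padding where applicable, and save
    fp32/exp_avg/exp_avg_sq (plus layout metadata) under the output folder.
    Returns the set of configured patterns that this parameter did not match.'''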
|
|
|
name, shape = name_and_shape |
|
slice_base_path = os.path.join(slice_dir, name) |
|
param_base_path = os.path.join(dir, name) |
|
|
|
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) |
|
replicated_parameters = universal_checkpoint_info.get(TP_REPLICATED_PARAMETER_PATTERNS, []) |
|
parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, []) |
|
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, []) |
|
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, []) |
|
parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, []) |
|
parameter_with_sub_params = universal_checkpoint_info.get(PARAMETER_WITH_SUB_PARAMS, []) |
|
|
|
    unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
                             vocabulary_parameters + parameters_with_2_sub_params_cat_dim_0)
|
unmatched_patterns.update(chain.from_iterable(SubparamShape(**s).patterns for s in parameter_with_sub_params)) |
|
|
|
def get_matched_pattern(patterns_, name_): |
|
matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)] |
|
        assert len(matched_) <= 1, f'Got more than one matching pattern: {matched_} for {name_}'
|
if matched_: |
|
pattern_ = matched_[0] |
|
unmatched_patterns.discard(pattern_) |
|
return pattern_ |
|
return None |
|
|
|
def get_matched_sub_params_pattern(name_): |
|
for subparam_shape_dict in parameter_with_sub_params: |
|
subparam_shape = SubparamShape(**subparam_shape_dict) |
|
for pattern_ in subparam_shape.patterns: |
|
if re.match(pattern_, name_): |
|
unmatched_patterns.discard(pattern_) |
|
return subparam_shape |
|
return None |
|
|
|
matched_sub_params_shape = get_matched_sub_params_pattern(name) |
|
|
|
step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape) |
|
if step_merged: |
|
_save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0]) |
|
|
|
for state in ("fp32", "exp_avg", "exp_avg_sq"): |
|
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape) |
|
final_path = os.path.join(param_base_path, f"{state}.pt") |
|
|
|
|
|
|
|
ckpt_dict = {} |
|
if get_matched_pattern(replicated_parameters, name): |
|
if len(slices) > 1: |
|
assert all([slices[0].equal(other_slice) for other_slice in slices[1:]]) |
|
param = slices[0] |
|
|
|
elif get_matched_pattern(parameters_to_average, name): |
|
param = sum(slices) / len(slices) |
|
|
|
elif get_matched_pattern(parameters_with_2_sub_params_cat_dim_0, name): |
|
cat_dim = 0 |
|
chunked_slices = [torch.chunk(s, 2, dim=cat_dim) for s in slices] |
|
merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=cat_dim) |
|
merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=cat_dim) |
|
param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim) |
|
ckpt_dict[CAT_DIM] = cat_dim |
|
ckpt_dict[PARAM_N_SUB_PARAMS] = 2 |
|
elif matched_sub_params_shape: |
|
merged_chunks = [] |
|
partition_dim = matched_sub_params_shape.partition_dim |
|
|
|
sub_dim_sizes = matched_sub_params_shape.shape[partition_dim] |
|
if not isinstance(sub_dim_sizes, tuple): |
|
sub_dim_sizes = (sub_dim_sizes, ) |
|
|
|
partition_shape = [sum(d) if isinstance(d, tuple) else d for d in matched_sub_params_shape.shape] |
|
partition_shape = [d // tp_degree if i == partition_dim else d for i, d in enumerate(partition_shape)] |
|
slices = [s.view(partition_shape) for s in slices] |
|
|
|
offset = 0 |
|
for sub_dim_size in sub_dim_sizes: |
|
part_sub_dim_size = sub_dim_size // tp_degree |
|
                merged_chunks.append(
                    torch.cat([s.narrow(partition_dim, offset, part_sub_dim_size) for s in slices], dim=partition_dim))
|
offset += part_sub_dim_size |
|
param = torch.cat(merged_chunks, dim=partition_dim) |
|
ckpt_dict[SUB_PARAM_SHAPE] = matched_sub_params_shape |
|
else: |
|
cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0 |
|
|
|
param = torch.cat(slices, dim=cat_dim) |
|
ckpt_dict[CAT_DIM] = cat_dim |
|
|
|
        if get_matched_pattern(vocabulary_parameters, name):
            # Strip padded vocabulary rows beyond the original vocab size.
            original_vocab_size = universal_checkpoint_info['original_vocab_size']
            param = param[:original_vocab_size, :]
            ckpt_dict[VOCAB_TENSOR] = True
|
|
|
|
|
|
|
ckpt_dict[PARAM] = param |
|
_save_checkpoint(final_path, ckpt_dict) |
|
|
|
return unmatched_patterns |
|
|
|
|
|
def merge_zero3_slices(dp_degree, dir, slice_dir, name): |
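    '''Merge the per-dp-rank ZeRO stage 3 shards of a single parameter and save its
    fp32/exp_avg/exp_avg_sq tensors under the output folder.'''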
|
slice_base_path = os.path.join(slice_dir, name) |
|
param_base_path = os.path.join(dir, name) |
|
|
|
for state in ("fp32", "exp_avg", "exp_avg_sq"): |
|
slices = _merge_zero_shards(slice_base_path, state, 1) |
|
final_path = os.path.join(param_base_path, f"{state}.pt") |
|
_save_checkpoint(final_path, slices[0]) |
|
|
|
|
|
def _do_parallel_work(do_work, work_chunks, num_workers): |
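    '''Apply do_work to every item of work_chunks, using a process pool when num_workers > 1
    and running serially otherwise; return the list of results.'''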
|
results = [] |
|
if num_workers > 1: |
|
with ProcessPoolExecutor(max_workers=num_workers) as executor: |
|
future_list = [executor.submit(do_work, work) for work in work_chunks] |
|
for f in tqdm.tqdm(future_list): |
|
results.append(f.result()) |
|
    else:
        # Run the work serially in this process when only one worker is requested.
        for work in tqdm.tqdm(work_chunks):
|
results.append(do_work(work)) |
|
return results |
|
|
|
|
|
def _extract_zero_shard_files(args, ds_checkpoint, temp_dir): |
|
    _3d_range_list = list(
        itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
                          range(ds_checkpoint.dp_degree)))
|
|
|
|
|
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint) |
|
_do_parallel_work(do_work, _3d_range_list, args.num_extract_workers) |
|
|
|
|
|
def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir): |
|
do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir) |
|
_do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers) |
|
|
|
|
|
def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir): |
|
zero_output_folder = os.path.join(args.output_folder, "zero") |
|
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree) |
|
unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers) |
|
|
|
|
|
|
|
    # A pattern counts as unused only if every merge worker left it unmatched.
    sets = [set(lst) for lst in unmatched_patterns_lists]
    unmatched_patterns = list(set.intersection(*sets))
|
if args.strict: |
|
assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices' |
|
elif unmatched_patterns: |
|
print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices') |
|
|
|
|
|
def _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir): |
|
zero_output_folder = os.path.join(args.output_folder, "zero") |
|
do_work = partial(merge_zero3_slices, dp_degree, zero_output_folder, temp_dir) |
|
_do_parallel_work(do_work, param_shapes.keys(), args.num_merge_workers) |
|
|
|
|
|
def _zero_partitioned_param_info(unpartitioned_numel, world_size): |
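    '''Return (per-rank partition size, number of padding elements) for a parameter of
    unpartitioned_numel elements partitioned evenly across world_size ranks.'''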
|
remainder = unpartitioned_numel % world_size |
|
padding_numel = (world_size - remainder) if remainder else 0 |
|
partitioned_numel = math.ceil(unpartitioned_numel / world_size) |
|
return partitioned_numel, padding_numel |
|
|
|
|
|
def _parse_model_states_stage3(files): |
|
return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES] |
|
|
|
|
|
def _save_optimizer_state(args, ds_checkpoint): |
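    '''Save the optimizer state that is not sharded across ranks (everything except the base
    optimizer state, fp32 partitions and slice mappings), plus param_groups, to
    <output_folder>/zero/optimizer_state.pt.'''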
|
sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS] |
|
sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0) |
|
|
|
optim_sd = sd[OPTIMIZER_STATE_DICT] |
|
output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states} |
|
output_sd[PARAM_GROUPS] = optim_sd[BASE_OPTIMIZER_STATE][PARAM_GROUPS] |
|
zero_output_folder = os.path.join(args.output_folder, "zero") |
|
output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt") |
|
_save_checkpoint(output_file_path, output_sd) |
|
|
|
|
|
def _save_optimizer_state_stage3(args, optim_files): |
|
sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False) |
|
output_sd = sd[OPTIMIZER_STATE_DICT] |
|
output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS] |
|
zero_output_folder = os.path.join(args.output_folder, "zero") |
|
output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt") |
|
_save_checkpoint(output_file_path, output_sd) |
|
|
|
|
|
def _get_optim_files(checkpoint_dir): |
|
return _get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") |
|
|
|
|
|
def _get_model_state_files(checkpoint_dir): |
|
return _get_checkpoint_files(checkpoint_dir, "*_model_states.pt") |
|
|
|
|
|
def _get_checkpoint_files(checkpoint_dir, glob_pattern): |
|
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) |
|
|
|
if len(ckpt_files) == 0: |
|
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") |
|
|
|
return ckpt_files |
|
|
|
|
|
def _get_zero_stage(optim_files): |
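    '''Read the ZeRO stage recorded in the first optimizer state file (defaults to 1).'''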
|
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False) |
|
optimizer_state = state_dict[OPTIMIZER_STATE_DICT] |
|
zero_stage = optimizer_state.get(ZERO_STAGE, 1) |
|
return zero_stage |
|
|
|
|
|
def _inject_missing_state(ds_checkpoint): |
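    '''If neither the checkpoint's global state nor its first mp_rank file carries
    UNIVERSAL_CHECKPOINT_INFO, create a minimal one recording only the universal
    checkpoint version.'''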
|
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state: |
|
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False) |
|
if UNIVERSAL_CHECKPOINT_INFO not in sd: |
|
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {} |
|
            ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
                UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE
|
|
|
|
|
def _check_for_required_state(ds_checkpoint): |
|
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) |
|
    assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in the checkpoint. Verify that the client creates this state.'
|
|
|
|
|
def main(args): |
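    '''Drive the conversion: extract ZeRO fragments, merge them into per-parameter files,
    save the common optimizer state, copy the model-state files from the input folder,
    and point the parent folder's `latest_universal` tag at the output step folder.'''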
|
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')
|
|
|
optim_files = _get_optim_files(args.input_folder) |
|
zero_stage = _get_zero_stage(optim_files) |
|
|
|
if zero_stage <= 2: |
|
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder) |
|
if args.inject_missing_state: |
|
_inject_missing_state(ds_checkpoint) |
|
else: |
|
_check_for_required_state(ds_checkpoint) |
|
|
|
iteration = ds_checkpoint.get_iteration() |
|
|
|
        checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree,
                                                    ds_checkpoint.pp_degree)
|
|
|
slice_shapes = [] |
|
for mp_rank_file in ds_checkpoint.mp_rank_files: |
|
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False) |
|
slice_shapes += mp_sd[PARAM_SHAPES] |
|
|
|
|
|
        # Flatten the per-rank lists of {name: shape} dicts into a single mapping (duplicate names collapse to one entry).
        slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
|
temp_dir = os.path.join(args.output_folder, 'tmp') |
|
|
|
print('*** 1. Extracting ZeRO fragments') |
|
_extract_zero_shard_files(args, ds_checkpoint, temp_dir) |
|
|
|
print('*** 2. Merging slices .....') |
|
_merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir) |
|
|
|
print('*** 3. Saving common optimizer states') |
|
_save_optimizer_state(args, ds_checkpoint) |
|
|
|
if not args.keep_temp_folder: |
|
shutil.rmtree(temp_dir, ignore_errors=True) |
|
|
|
|
|
        # Copy the tensor/pipeline-parallel mp* model files from the input into the output folder.
        for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
            shutil.copy2(f, args.output_folder)
|
|
|
else: |
|
model_files = _get_model_state_files(args.input_folder) |
|
param_shapes = _parse_model_states_stage3(model_files) |
|
param_shapes = {k: v for d in param_shapes for k, v in d.items()} |
|
dp_degree = len(model_files) |
|
|
|
temp_dir = os.path.join(args.output_folder, 'tmp') |
|
|
|
print('*** 1. Extracting ZeRO fragments') |
|
_extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir) |
|
|
|
print('*** 2. Merging slices .....') |
|
_merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir) |
|
|
|
print('*** 3. Saving common optimizer states') |
|
_save_optimizer_state_stage3(args, optim_files) |
|
|
|
if not args.keep_temp_folder: |
|
shutil.rmtree(temp_dir, ignore_errors=True) |
|
|
|
|
|
        # Copy the *model_states.pt files from the input into the output folder.
        for f in glob.glob(os.path.join(args.input_folder, '*model_states.pt')):
            shutil.copy2(f, args.output_folder)
|
|
|
|
|
    # Record the newly written step folder as the latest universal checkpoint.
    checkpoint_root_folder, step_folder = os.path.split(args.output_folder)
|
latest_file = os.path.join(checkpoint_root_folder, 'latest_universal') |
|
with open(latest_file, "w") as f: |
|
f.write(step_folder) |
|
|
|
print('*** Done!') |
|
|
|
|
|
if __name__ == "__main__": |
|
args = parse_arguments() |
|
main(args) |
|
|