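"""FSDP utilities for Bagel training: model wrapping, checkpoint saving/loading, and EMA updates."""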
import functools
import os

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from safetensors.torch import load_file, save_file

from modeling.bagel.modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
from modeling.bagel.qwen2_navit import (
    Qwen2DecoderLayer,
    Qwen2MoEDecoderLayer,
    Qwen2MoTDecoderLayer,
)
from modeling.bagel.siglip_navit import SiglipEncoderLayer, SiglipVisionTransformer


class FSDPConfig:
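    """Holds the FSDP settings consumed by `fsdp_wrapper`: sharding strategy, backward
    prefetch, CPU offload, and the replicate/shard group sizes for hybrid sharding."""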
    def __init__(
        self,
        sharding_strategy,
        backward_prefetch,
        cpu_offload,
        num_replicate,
        num_shard=8,
    ):
        self.sharding_strategy = sharding_strategy
        self.backward_prefetch = backward_prefetch
        self.cpu_offload = cpu_offload
        self.num_replicate = num_replicate
        self.num_shard = num_shard


def fsdp_wrapper(original_model, fsdp_config, ignored_modules=[]):
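    """Wrap `original_model` with FSDP: transformer-block auto-wrapping, bf16 mixed precision,
    and a (replicate, shard) device mesh when the strategy is HYBRID_SHARD."""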
    if fsdp_config.sharding_strategy == 'HYBRID_SHARD':
        device_mesh = init_device_mesh(
            "cuda",
            mesh_shape=(fsdp_config.num_replicate, fsdp_config.num_shard),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        device_mesh = None

    return FSDP(
        original_model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                Qwen2DecoderLayer,
                Qwen2MoEDecoderLayer,
                Qwen2MoTDecoderLayer,
                SiglipEncoderLayer,
                SiglipVisionTransformer,
                MLPconnector,
                TimestepEmbedder,
                PositionEmbedding,
            },
        ),
        ignored_modules=ignored_modules,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        device_id=dist.get_rank() % torch.cuda.device_count(),
        sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy],
        backward_prefetch=BackwardPrefetch[fsdp_config.backward_prefetch],
        cpu_offload=CPUOffload(offload_params=fsdp_config.cpu_offload),
        device_mesh=device_mesh,
    )


class FSDPCheckpoint:
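    """Static helpers for saving and resuming FSDP training state
    (model/EMA weights, optimizer, scheduler, and dataloader status)."""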
    @staticmethod
    def fsdp_save_ckpt(
        ckpt_dir,
        train_steps,
        model,
        ema_model,
        optimizer,
        scheduler,
        data_status,
        logger,
        fsdp_config,
    ):
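        """Write a checkpoint to `ckpt_dir/{train_steps:07d}`: rank-0 full model and EMA weights
        as safetensors, one optimizer shard per rank (or per shard group), and the scheduler
        and data status from rank 0."""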
        save_path = os.path.join(ckpt_dir, f"{train_steps:07d}")
        os.makedirs(save_path, exist_ok=True)
        logger.info(f"Saving checkpoint to {save_path}.")

        if ema_model is not None:
            with FSDP.state_dict_type(
                ema_model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                ema_state_dict = ema_model.state_dict()
                if dist.get_rank() == 0:
                    save_file(ema_state_dict, os.path.join(save_path, "ema.safetensors"))

        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
        ):
            model_state_dict = model.state_dict()
            if dist.get_rank() == 0:
                save_file(model_state_dict, os.path.join(save_path, "model.safetensors"))

        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                shard_index = dist.get_rank()
                total_shards = dist.get_world_size()
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                shard_index = dist.get_rank() % fsdp_config.num_shard
                total_shards = fsdp_config.num_shard
            else:
                raise NotImplementedError

            optimizer_save_path = os.path.join(
                save_path, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                torch.save(optimizer.state_dict(), optimizer_save_path)
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                if dist.get_rank() < fsdp_config.num_shard:
                    torch.save(optimizer.state_dict(), optimizer_save_path)
            else:
                raise NotImplementedError

        if dist.get_rank() == 0 and scheduler is not None:
            torch.save(scheduler.state_dict(), os.path.join(save_path, "scheduler.pt"))

        if dist.get_rank() == 0 and data_status is not None:
            torch.save(data_status, os.path.join(save_path, "data_status.pt"))

        dist.barrier()
        return

    @staticmethod
    def try_load_ckpt(resume_from, logger, model, ema_model=None, resume_from_ema=False):
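        """Load model weights (and, if given, EMA weights) from `resume_from`; when
        `resume_from_ema` is set the EMA checkpoint is loaded into the model instead."""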
        if resume_from is not None and os.path.exists(resume_from):
            logger.info(f"Loading checkpoint from {resume_from}.")
            if resume_from_ema:
                model_state_dict_path = os.path.join(resume_from, "ema.safetensors")
            else:
                possible_model_files = ["model.safetensors", "model_bf16.safetensors", "pytorch_model.bin"]
                model_state_dict_path = None
                for model_file in possible_model_files:
                    candidate_path = os.path.join(resume_from, model_file)
                    if os.path.exists(candidate_path):
                        model_state_dict_path = candidate_path
                        break

                if model_state_dict_path is None:
                    raise FileNotFoundError(
                        f"Could not find model file in {resume_from}. Tried: {possible_model_files}"
                    )

            if model_state_dict_path.endswith(".safetensors"):
                model_state_dict = load_file(model_state_dict_path, device="cpu")
            else:
                # .bin checkpoints are pickled torch state dicts, which safetensors cannot read.
                model_state_dict = torch.load(model_state_dict_path, map_location="cpu")
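            # Skip the position-embedding buffers so the freshly initialized ones are kept
            # (the load below uses strict=False).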
            model_state_dict.pop('latent_pos_embed.pos_embed')
            model_state_dict.pop('vit_pos_embed.pos_embed')
            msg = model.load_state_dict(model_state_dict, strict=False)
            logger.info(msg)
            del model_state_dict

            if ema_model is not None:
                ema_state_dict_path = os.path.join(resume_from, "ema.safetensors")
                if not os.path.exists(ema_state_dict_path):
                    logger.info(f"Replicating EMA model from {model_state_dict_path}.")
                    ema_state_dict_path = model_state_dict_path
                if ema_state_dict_path.endswith(".safetensors"):
                    ema_state_dict = load_file(ema_state_dict_path, device="cpu")
                else:
                    ema_state_dict = torch.load(ema_state_dict_path, map_location="cpu")

                ema_state_dict.pop('latent_pos_embed.pos_embed')
                ema_state_dict.pop('vit_pos_embed.pos_embed')
                msg = ema_model.load_state_dict(ema_state_dict, strict=False)
                logger.info(msg)
                del ema_state_dict
        else:
            logger.info("Training from scratch.")
        return model, ema_model

    @staticmethod
    def try_load_train_state(resume_from, optimizer, scheduler, fsdp_config):
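        """Restore optimizer and scheduler state, the step to resume from, and this rank's
        dataloader status; returns (optimizer, scheduler, train_steps, data_status)."""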
        if resume_from is not None and os.path.exists(resume_from):
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                shard_index = dist.get_rank()
                total_shards = dist.get_world_size()
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                shard_index = dist.get_rank() % fsdp_config.num_shard
                total_shards = fsdp_config.num_shard
            else:
                raise NotImplementedError

            optimizer_state_dict_path = os.path.join(
                resume_from, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            optimizer_state_dict = torch.load(optimizer_state_dict_path, map_location="cpu", weights_only=True)
            optimizer.load_state_dict(optimizer_state_dict)
            del optimizer_state_dict

            scheduler_state_dict_path = os.path.join(resume_from, "scheduler.pt")
            scheduler_state_dict = torch.load(scheduler_state_dict_path, weights_only=True, map_location="cpu")
            scheduler.load_state_dict(scheduler_state_dict)
            del scheduler_state_dict

            train_steps = int(os.path.basename(os.path.normpath(resume_from))) + 1
            """
            data_status = [
                {
                    dataset_name: {
                        worker_id: [parquet_idx, row_group_id, row_idx],
                    },
                },
            ]
            """
            data_status_path = os.path.join(resume_from, "data_status.pt")
            if os.path.exists(data_status_path):
                data_status = torch.load(data_status_path, weights_only=True, map_location="cpu")
                local_rank = dist.get_rank()
                if local_rank < len(data_status):
                    data_status = data_status[local_rank]
                else:
                    data_status = None
            else:
                data_status = None
        else:
            train_steps = 0
            data_status = None
        return optimizer, scheduler, train_steps, data_status


def grad_checkpoint_check_fn(module):
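    """Return True for the transformer/connector blocks that should be gradient-checkpointed."""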
    module_options = (
        Qwen2DecoderLayer,
        SiglipEncoderLayer,
        MLPconnector,
        Qwen2MoEDecoderLayer,
        Qwen2MoTDecoderLayer,
    )
    return isinstance(module, module_options)


def fsdp_ema_setup(ema_model, fsdp_config, ignored_modules=[]):
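    """Freeze the EMA copy and wrap it with FSDP using the same settings as the online model."""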
    for param in ema_model.parameters():
        param.requires_grad = False

    ema_model = fsdp_wrapper(ema_model, fsdp_config, ignored_modules=ignored_modules)
    return ema_model


@torch.no_grad()
def fsdp_ema_update(ema_model, model, decay=0.9999):
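    """In-place EMA update over matching FSDP flat parameters: ema = decay * ema + (1 - decay) * new."""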
    ema_handles = traversal_utils._get_fsdp_handles(ema_model)
    new_handles = traversal_utils._get_fsdp_handles(model)
    assert len(ema_handles) == len(new_handles)
    ema_params = []
    new_params = []

    for ema_handle, new_handle in zip(ema_handles, new_handles):
        if ema_handle.flat_param is not None and new_handle.flat_param.requires_grad:
            ema_params.append(ema_handle.flat_param.data)
            new_params.append(new_handle.flat_param.data.to(dtype=ema_handle.flat_param.dtype))

    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, new_params, alpha=1 - decay)