|
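# Unified training script: every optimizer step combines three objectives,
#   - t2i: masked image-token prediction for text-to-image generation,
#   - lm:  masked (diffusion-style) language modeling on text-only data,
#   - mmu: multimodal understanding (image captioning).
# Typically launched via `accelerate launch` with an OmegaConf YAML config;
# the exact CLI handling lives in training.utils.get_config.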
|
import os |
|
import sys |
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
os.environ["TOKENIZERS_PARALLELISM"] = "true" |
|
import json |
|
import logging |
|
import math |
|
import shutil |
|
import time |
|
from pathlib import Path |
|
from typing import Union |
|
|
|
import numpy as np |
|
from PIL import Image |
|
from omegaconf import OmegaConf |
|
import wandb |
|
import torch |
|
from torch.optim import AdamW |
|
from lightning.pytorch.utilities import CombinedLoader |
|
|
|
from transformers import AutoTokenizer, AutoConfig |
|
from accelerate import Accelerator |
|
from accelerate.logging import get_logger |
|
from accelerate.utils import DistributedType, set_seed |
|
|
|
from training.data import Text2ImageDataset |
|
from training.utils import get_config, flatten_omega_conf, image_transform, mask_or_random_replace_tokens, AverageMeter
|
from training.imagenet_dataset import ImageNetDataset |
|
from parquet import RefinedWebDataset |
|
|
|
from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig |
|
from training.prompting_utils import UniversalPrompting |
|
from models.lr_schedulers import get_scheduler |
|
from models.logging import set_verbosity_info, set_verbosity_error |
|
|
|
from torch.utils.data import DataLoader |
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
|
|
SYSTEM_PROMPT_LEN = 28 |
|
|
try: |
|
import apex |
|
|
|
is_apex_available = True |
|
except ImportError: |
|
is_apex_available = False |
|
|
|
logger = get_logger(__name__, log_level="INFO") |
|
|
|
|
|
def get_vq_model_class(model_type): |
|
if model_type == "magvitv2": |
|
return MAGVITv2 |
|
    elif model_type == "vq16":
        # NOTE: VQ_16 is not imported in this file; add its import before selecting this model type.
        return VQ_16
|
else: |
|
raise ValueError(f"model_type {model_type} not supported.") |
|
|
|
|
|
def main(): |
|
|
|
|
|
|
|
config = get_config() |
|
|
|
|
|
if config.training.enable_tf32: |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.backends.cudnn.benchmark = True |
|
torch.backends.cudnn.deterministic = False |
|
|
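    # Accelerator handles distributed training (DDP / DeepSpeed), mixed precision,
    # gradient accumulation, and the wandb tracker.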
|
config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") |
|
accelerator = Accelerator( |
|
gradient_accumulation_steps=config.training.gradient_accumulation_steps, |
|
mixed_precision=config.training.mixed_precision, |
|
log_with="wandb", |
|
project_dir=config.experiment.logging_dir, |
|
split_batches=True, |
|
) |
|
|
|
total_batch_size_per_gpu = (config.training.batch_size_t2i |
|
+ config.training.batch_size_lm |
|
+ config.training.batch_size_mmu) |
|
total_batch_size = ( |
|
(config.training.batch_size_t2i + config.training.batch_size_lm + config.training.batch_size_mmu) |
|
* accelerator.num_processes * config.training.gradient_accumulation_steps |
|
) |
|
|
|
if accelerator.distributed_type == DistributedType.DEEPSPEED: |
|
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( |
|
total_batch_size_per_gpu |
|
) |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
|
datefmt="%m/%d/%Y %H:%M:%S", |
|
level=logging.INFO, |
|
) |
|
logger.info(accelerator.state, main_process_only=False) |
|
if accelerator.is_local_main_process: |
|
set_verbosity_info() |
|
else: |
|
set_verbosity_error() |
|
|
|
|
|
|
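    # Initialize the wandb tracker on the main process, resuming an existing run
    # when a run_id is provided in the config.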
|
if accelerator.is_main_process: |
|
resume_wandb_run = config.wandb.resume |
|
run_id = config.wandb.get("run_id", None) |
|
if run_id is None: |
|
resume_wandb_run = False |
|
run_id = wandb.util.generate_id() |
|
config.wandb.run_id = run_id |
|
|
|
wandb_init_kwargs = dict( |
|
name=config.experiment.name, |
|
id=run_id, |
|
resume=resume_wandb_run, |
|
entity=config.wandb.get("entity", None), |
|
config_exclude_keys=[], |
|
) |
|
wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} |
|
        wandb_config.pop("experiment.resume_from_checkpoint", None)
|
|
|
accelerator.init_trackers( |
|
config.experiment.project, |
|
config=wandb_config, |
|
init_kwargs={"wandb": wandb_init_kwargs}, |
|
) |
|
|
|
if accelerator.is_main_process: |
|
os.makedirs(config.experiment.output_dir, exist_ok=True) |
|
config_path = Path(config.experiment.output_dir) / "config.yaml" |
|
logging.info(f"Saving config to {config_path}") |
|
OmegaConf.save(config, config_path) |
|
|
|
|
|
if config.training.seed is not None: |
|
set_seed(config.training.seed) |
|
|
|
|
|
|
|
|
|
logger.info("Loading models and optimizer") |
|
|
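    # Tokenizer plus the UniversalPrompting wrapper, which formats each task
    # (t2i / lm / mmu) with its special tokens; ignore_id marks positions that
    # are excluded from the loss.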
|
tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") |
|
|
|
uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, |
|
special_tokens=( |
|
"<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", |
|
"<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>" |
|
), |
|
ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) |
|
|
|
print('special tokens : \n', uni_prompting.sptids_dict) |
|
|
|
|
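    # Frozen VQ tokenizer: encodes images into discrete codes and decodes codes
    # back to pixels for visualization.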
|
vq_model = get_vq_model_class(config.model.vq_model.type) |
|
if config.model.vq_model.get("pretrained_model_path", None): |
|
vq_model = vq_model().to(accelerator.device) |
|
state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] |
|
vq_model.load_state_dict(state_dict) |
|
else: |
|
vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) |
|
vq_model.eval() |
|
vq_model.requires_grad_(False) |
|
|
|
|
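    # Build the MMaDA model: merge the pretrained base config with the overrides in
    # config.model.mmada, then resize the token embeddings to new_vocab_size so that
    # image-code tokens can live past the text vocabulary.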
|
base_config = AutoConfig.from_pretrained(config.model.mmada.pretrained_model_path).to_dict() |
|
mmada_config_dict = {k: v for k, v in config.model.mmada.items()} |
|
merged_config = {**base_config, **mmada_config_dict} |
|
mmada_config = MMadaConfig(**merged_config) |
|
model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16, config=mmada_config) |
|
model.resize_token_embeddings(mmada_config.new_vocab_size) |
|
model.config.embedding_size = model.config.vocab_size |
|
model = model.to(accelerator.device) |
|
|
|
mask_id = model.config.mask_token_id |
|
|
|
|
|
|
|
|
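    # Optimizer setup: biases, layer norms, and embeddings are excluded from weight decay.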
|
optimizer_config = config.optimizer.params |
|
|
|
|
|
no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] |
|
optimizer_grouped_parameters = [ |
|
{ |
|
"params": [p for n, p in model.named_parameters() if |
|
p.requires_grad and not any(nd in n for nd in no_decay)], |
|
"weight_decay": optimizer_config.weight_decay, |
|
}, |
|
{ |
|
"params": [p for n, p in model.named_parameters() if |
|
p.requires_grad and any(nd in n for nd in no_decay)], |
|
"weight_decay": 0.0, |
|
}, |
|
] |
|
|
|
optimizer_type = config.optimizer.name |
|
if optimizer_type == "adamw": |
|
optimizer = AdamW( |
|
optimizer_grouped_parameters, |
|
lr=optimizer_config.learning_rate, |
|
betas=(optimizer_config.beta1, optimizer_config.beta2), |
|
weight_decay=optimizer_config.weight_decay, |
|
eps=optimizer_config.epsilon, |
|
) |
|
else: |
|
raise ValueError(f"Optimizer {optimizer_type} not supported") |
|
|
|
|
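    # Mask schedule used both for corrupting image tokens during training and as the
    # noise schedule for iterative decoding at generation time (defaults to cosine).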
|
if config.get("mask_schedule", None) is not None: |
|
schedule = config.mask_schedule.schedule |
|
args = config.mask_schedule.get("params", {}) |
|
mask_schedule = get_mask_schedule(schedule, **args) |
|
else: |
|
mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) |
|
|
|
lr_scheduler = get_scheduler( |
|
config.lr_scheduler.scheduler, |
|
optimizer=optimizer, |
|
num_training_steps=config.training.max_train_steps, |
|
num_warmup_steps=config.lr_scheduler.params.warmup_steps, |
|
min_lr_scale=config.lr_scheduler.params.min_lr_scale |
|
) |
|
|
|
|
|
|
|
|
|
logger.info("Creating dataloaders and lr_scheduler") |
|
|
|
total_batch_size_t2i_without_accum = config.training.batch_size_t2i * accelerator.num_processes |
|
total_batch_size_t2i = ( |
|
config.training.batch_size_t2i * accelerator.num_processes * config.training.gradient_accumulation_steps |
|
) |
|
|
|
|
|
|
|
|
|
|
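    # Dataloaders: an image-generation (t2i) loader, an image-understanding (mmu)
    # loader, and a text-only LM loader, combined below into a single CombinedLoader.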
|
preproc_config = config.dataset.preprocessing |
|
dataset_config = config.dataset.params |
|
|
|
|
|
if config.dataset.gen_type == "t2i": |
|
dataset = Text2ImageDataset( |
|
train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, |
|
tokenizer=None, |
|
max_seq_length=preproc_config.max_seq_length, |
|
num_train_examples=config.experiment.max_train_examples_t2i, |
|
per_gpu_batch_size=config.training.batch_size_t2i, |
|
global_batch_size=total_batch_size_t2i_without_accum, |
|
num_workers=dataset_config.num_workers, |
|
resolution=preproc_config.resolution, |
|
shuffle_buffer_size=dataset_config.shuffle_buffer_size, |
|
pin_memory=dataset_config.pin_memory, |
|
persistent_workers=dataset_config.persistent_workers, |
|
external_caption_path=dataset_config.external_caption_path, |
|
external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, |
|
external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, |
|
external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, |
|
) |
|
train_dataloader_t2i = dataset.train_dataloader |
|
num_update_steps_per_epoch = math.ceil( |
|
train_dataloader_t2i.num_batches / config.training.gradient_accumulation_steps) |
|
num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) |
|
|
|
elif config.dataset.gen_type == "t2i_parquet": |
|
|
|
num_update_steps_per_epoch = math.ceil(config.experiment.max_train_examples_t2i / total_batch_size_t2i) |
|
num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) |
|
|
|
        # NOTE: create_imagetext_dataloader is not imported in this file; the *_parquet
        # branches assume it is provided by the parquet data utilities.
        train_dataloader_t2i = create_imagetext_dataloader(
|
train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, |
|
batch_size=config.training.batch_size_t2i, |
|
image_size=preproc_config.resolution, |
|
num_workers=dataset_config.num_workers, |
|
num_readers=32, |
|
predefined_steps=num_update_steps_per_epoch, |
|
drop_last=True, |
|
shuffle=True, |
|
shuffle_buffer_size=dataset_config.shuffle_buffer_size |
|
) |
|
|
|
elif config.dataset.gen_type == "imagenet1k": |
|
dataset_imagenet = ImageNetDataset( |
|
dataset_config.train_t2i_shards_path_or_url, |
|
image_size=preproc_config.resolution, |
|
) |
|
|
|
print('process index : ', |
|
accelerator.process_index, ', ', accelerator.num_processes, |
|
"Length: ", len(dataset_imagenet)) |
|
|
|
if accelerator.num_processes > 1: |
|
sampler = DistributedSampler(dataset_imagenet, |
|
num_replicas=accelerator.num_processes, |
|
rank=accelerator.process_index, |
|
shuffle=True, |
|
) |
|
shuffle = False |
|
else: |
|
sampler = None |
|
shuffle = True |
|
|
|
train_dataloader_t2i = DataLoader(dataset_imagenet, batch_size=config.training.batch_size_t2i, |
|
sampler=sampler, collate_fn=dataset_imagenet.collate_fn, |
|
shuffle=shuffle, num_workers=dataset_config.num_workers) |
|
num_update_steps_per_epoch = math.ceil(len(dataset_imagenet) / total_batch_size_t2i) |
|
num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) |
|
|
|
else: |
|
        raise ValueError(f"Unsupported dataset type {config.dataset.gen_type}")
|
|
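    # Multimodal-understanding (mmu) dataloader: image-caption pairs used for
    # captioning-style supervision.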
|
total_batch_size_mmu_without_accum = config.training.batch_size_mmu * accelerator.num_processes |
|
|
|
if config.dataset.und_type == "captioning": |
|
dataset_mmu = Text2ImageDataset( |
|
train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, |
|
tokenizer=None, |
|
max_seq_length=preproc_config.max_seq_length, |
|
num_train_examples=config.experiment.max_train_examples_mmu, |
|
per_gpu_batch_size=config.training.batch_size_mmu, |
|
global_batch_size=total_batch_size_mmu_without_accum, |
|
num_workers=dataset_config.num_workers, |
|
resolution=preproc_config.resolution, |
|
shuffle_buffer_size=dataset_config.shuffle_buffer_size, |
|
pin_memory=dataset_config.pin_memory, |
|
persistent_workers=dataset_config.persistent_workers, |
|
external_caption_path=dataset_config.external_caption_path, |
|
external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, |
|
external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, |
|
external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, |
|
is_captioning=True, |
|
add_caption_prompt=dataset_config.add_caption_prompt, |
|
) |
|
train_dataloader_mmu = dataset_mmu.train_dataloader |
|
|
|
elif config.dataset.und_type == "captioning_parquet": |
|
train_dataloader_mmu = create_imagetext_dataloader( |
|
train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, |
|
batch_size=config.training.batch_size_mmu, |
|
image_size=preproc_config.resolution, |
|
num_workers=dataset_config.num_workers, |
|
num_readers=32, |
|
predefined_steps=num_update_steps_per_epoch, |
|
drop_last=True, |
|
shuffle=True, |
|
shuffle_buffer_size=dataset_config.shuffle_buffer_size, |
|
is_captioning=True |
|
) |
|
|
|
else: |
|
raise NotImplementedError(f"Unsupported dataset type {config.dataset.und_type}") |
|
|
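    # Text-only language-modeling data (RefinedWeb), sharded across processes by rank.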
|
|
|
dataset_lm = RefinedWebDataset(data_path=dataset_config.train_lm_shards_path_or_url, |
|
rank=accelerator.process_index, |
|
world_size=accelerator.num_processes, |
|
num_workers=dataset_config.num_workers) |
|
|
|
train_dataloader_lm = torch.utils.data.DataLoader(dataset_lm, batch_size=config.training.batch_size_lm, |
|
sampler=None, collate_fn=dataset_lm.collate_fn, |
|
num_workers=dataset_config.num_workers) |
|
|
|
|
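    # Each combined batch yields one t2i sub-batch, one lm sub-batch, and one mmu sub-batch.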
|
iterables = { |
|
"t2i_flow": train_dataloader_t2i, |
|
"lm_flow": train_dataloader_lm, |
|
"mmu_flow": train_dataloader_mmu, |
|
} |
|
|
|
combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) |
|
|
|
|
|
|
|
|
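    # Optionally resume from the most recent checkpoint-<step> directory in the output dir.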
|
global_step = 0 |
|
first_epoch = 0 |
|
|
|
if config.experiment.resume_from_checkpoint: |
|
dirs = os.listdir(config.experiment.output_dir) |
|
logger.info(f"dirs: {dirs}") |
|
dirs = [d for d in dirs if d.startswith("checkpoint")] |
|
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) |
|
path = dirs[-1] if len(dirs) > 0 else None |
|
logger.info(f"path: {path}") |
|
if path is not None: |
|
path = os.path.join(config.experiment.output_dir, path) |
|
logger.info(f"Resuming from checkpoint: {path}") |
|
global_step = int(os.path.basename(path).split("-")[1]) |
|
first_epoch = global_step // num_update_steps_per_epoch |
|
if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): |
|
state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") |
|
model.load_state_dict(state_dict, strict=True) |
|
del state_dict |
|
            elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'):
                from transformers.modeling_utils import load_sharded_checkpoint
                load_sharded_checkpoint(model, f'{path}/unwrapped_model/')
|
|
|
            elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'):
                from transformers.modeling_utils import load_sharded_checkpoint
                load_sharded_checkpoint(model, f'{path}/unwrapped_model/')
|
else: |
|
raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") |
|
else: |
|
logger.info("Not resuming from checkpoint") |
|
|
|
|
|
|
|
|
|
logger.info("Preparing model, optimizer and dataloaders") |
|
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) |
|
|
|
vq_model.to(device=accelerator.device) |
|
|
|
mask_dtype = model.get_input_embeddings().weight.dtype |
|
|
|
|
|
|
|
|
|
logger.info("***** Running training *****") |
|
logger.info(f" Num training steps = {config.training.max_train_steps}") |
|
logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") |
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") |
|
logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") |
|
|
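    # t2i inputs: VQ-encode the images, shift the codes past the text vocabulary,
    # mask them according to the mask schedule, and wrap everything in the t2i prompt.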
|
@torch.no_grad() |
|
def prepare_inputs_and_labels( |
|
pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], |
|
            texts: Union[str, list],
|
min_masking_rate: float = 0.0, |
|
is_train: bool = True, |
|
): |
|
|
|
image_tokens = vq_model.get_code(pixel_values_or_image_ids) |
|
image_tokens = image_tokens + len(uni_prompting.text_tokenizer) |
|
|
|
input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( |
|
image_tokens, |
|
mask_id, |
|
config, |
|
mask_schedule=mask_schedule, |
|
is_train=is_train, |
|
) |
|
input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') |
|
return input_ids, labels, mask_prob, image_tokens, masks |
|
|
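    # lm inputs: sample a per-example masking probability in [eps, 1) and replace
    # tokens with mask_id accordingly (diffusion-style text corruption).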
|
@torch.no_grad() |
|
    def prepare_inputs_and_labels_for_text(
            texts: Union[str, list], max_seq_len, eps=1e-3
    ):
        input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm')
|
b, l = input_ids_lm.shape |
|
t = torch.rand(b, device=input_ids_lm.device) |
|
p_mask = (1 - eps) * t + eps |
|
p_mask = p_mask[:, None].repeat(1, l) |
|
|
|
masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask |
|
|
|
noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) |
|
masked_indices = noisy_batch == mask_id |
|
|
|
return noisy_batch, labels_lm, p_mask |
|
|
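    # mmu inputs: same token-masking scheme, but prompt positions are restored so they
    # are never masked; answer_lengths counts the answer tokens per example for the loss.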
|
@torch.no_grad() |
|
def prepare_inputs_and_labels_for_mmu( |
|
input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 |
|
): |
|
b, l = input_ids_mmu.shape |
|
t = torch.rand(b, device=input_ids_mmu.device) |
|
p_mask = (1 - eps) * t + eps |
|
p_mask = p_mask[:, None].repeat(1, l) |
|
|
|
masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask |
|
|
|
noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) |
|
masked_indices = noisy_batch == mask_id |
|
noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] |
|
masked_indices = noisy_batch == mask_id |
|
|
|
prompt_masks = prompt_masks.to(torch.int64) |
|
answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) |
|
answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) |
|
|
|
return noisy_batch, labels_mmu, p_mask, answer_lengths |
|
|
|
|
|
|
|
batch_time_m = AverageMeter() |
|
data_time_m = AverageMeter() |
|
end = time.time() |
|
|
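    # Training loop: for every combined batch, build t2i, lm, and mmu inputs, concatenate
    # them along the batch dimension, run a single forward pass, and optimize the
    # coefficient-weighted sum of the three losses.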
|
for epoch in range(first_epoch, num_train_epochs): |
|
model.train() |
|
for batch, batch_idx, dataloader_idx in combined_dataloader: |
|
|
|
batch_size_t2i = batch["t2i_flow"]["images"].shape[0] |
|
batch_size_lm = len(batch["lm_flow"]["input_ids"]) |
|
batch_size_mmu = batch["mmu_flow"]["images"].shape[0] |
|
|
|
|
|
|
|
|
|
pixel_values, texts = batch["t2i_flow"]["images"], batch["t2i_flow"]["input_ids"] |
|
pixel_values = pixel_values.to(accelerator.device, non_blocking=True) |
|
data_time_m.update(time.time() - end) |
|
|
|
|
|
( |
|
input_ids, |
|
labels, |
|
mask_prob, |
|
image_tokens_ori, |
|
t2i_masks |
|
) = prepare_inputs_and_labels(pixel_values, texts, config.training.min_masking_rate) |
|
|
|
|
|
|
|
|
|
max_seq_len = input_ids.shape[-1] |
|
texts_lm = batch["lm_flow"]["input_ids"] |
|
( |
|
input_ids_lm, |
|
labels_lm, |
|
p_mask_lm |
|
) = prepare_inputs_and_labels_for_text(texts_lm, max_seq_len) |
|
input_ids = torch.cat((input_ids, input_ids_lm.to(input_ids.device)), dim=0) |
|
labels = torch.cat((labels, labels_lm.to(input_ids.device)), dim=0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
if "llava" in config.dataset.und_type: |
|
pixel_values_mmu, input_ids_mmu, labels_mmu = (batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"],batch["mmu_flow"]["labels"]) |
|
pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) |
|
input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) |
|
image_tokens_mmu = vq_model.get_code(pixel_values_mmu) |
|
image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) |
|
|
|
input_ids_mmu = torch.cat([ |
|
(torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to( |
|
accelerator.device), |
|
(torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to( |
|
accelerator.device), |
|
image_tokens_mmu, |
|
(torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to( |
|
accelerator.device), |
|
input_ids_mmu, |
|
], dim=1).long() |
|
|
|
labels_mmu = torch.cat([ |
|
(torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), |
|
(torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), |
|
torch.ones_like(image_tokens_mmu) * uni_prompting.ignore_id, |
|
(torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), |
|
labels_mmu.to(accelerator.device) |
|
], dim=1).long() |
|
|
|
else: |
|
pixel_values_mmu, texts_mmu = batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"] |
|
pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) |
|
image_tokens_mmu = vq_model.get_code(pixel_values_mmu) |
|
image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) |
|
|
|
input_ids_mmu, prompt_masks, labels_mmu = uni_prompting((image_tokens_mmu, texts_mmu), 'mmu') |
|
( |
|
input_ids_mmu, |
|
labels_mmu, |
|
p_mask_mmu, |
|
answer_lengths |
|
) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks, labels_mmu) |
|
input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) |
|
|
|
input_ids = torch.cat((input_ids, input_ids_mmu.to(input_ids.device)), dim=0) |
|
labels = torch.cat((labels, labels_mmu.to(input_ids.device)), dim=0) |
|
|
|
if global_step == 0 and epoch == 0: |
|
logger.info("Input ids: {}".format(input_ids)) |
|
logger.info("Labels: {}".format(labels)) |
|
|
|
with accelerator.accumulate(model): |
|
logits, loss_t2i, loss_lm, loss_mmu = model.forward_process( |
|
input_ids=input_ids, |
|
labels=labels, |
|
batch_size_t2i=batch_size_t2i, |
|
batch_size_lm=batch_size_lm, |
|
batch_size_mmu=batch_size_mmu, |
|
max_seq_length=config.dataset.preprocessing.max_seq_length, |
|
p_mask_lm=p_mask_lm, |
|
p_mask_mmu=p_mask_mmu, |
|
answer_lengths=answer_lengths, |
|
t2i_masks=t2i_masks |
|
) |
|
|
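                # Gather the per-task losses across processes for logging, then combine
                # them with the configured coefficients for the backward pass.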
|
avg_loss_t2i = accelerator.gather(loss_t2i.repeat(config.training.batch_size_t2i)).mean() |
|
avg_loss_lm = accelerator.gather(loss_lm.repeat(config.training.batch_size_lm)).mean() |
|
avg_loss_mmu = accelerator.gather(loss_mmu.repeat(config.training.batch_size_mmu)).mean() |
|
loss = config.training.t2i_coeff * loss_t2i + \ |
|
config.training.lm_coeff * loss_lm + \ |
|
config.training.mmu_coeff * loss_mmu |
|
|
|
avg_masking_rate = accelerator.gather(mask_prob.repeat(config.training.batch_size_t2i)).mean() |
|
|
|
accelerator.backward(loss) |
|
|
|
if config.training.max_grad_norm is not None and accelerator.sync_gradients: |
|
accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) |
|
|
|
optimizer.step() |
|
lr_scheduler.step() |
|
|
|
|
|
if ( |
|
accelerator.sync_gradients |
|
and (global_step + 1) % config.experiment.log_grad_norm_every == 0 |
|
and accelerator.is_main_process |
|
): |
|
log_grad_norm(model, accelerator, global_step + 1) |
|
|
|
optimizer.zero_grad(set_to_none=True) |
|
|
|
|
|
if accelerator.sync_gradients: |
|
|
|
batch_time_m.update(time.time() - end) |
|
end = time.time() |
|
|
|
|
|
if (global_step + 1) % config.experiment.log_every == 0: |
|
samples_per_second_per_gpu = ( |
|
config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val |
|
) |
|
logs = { |
|
"step_loss_t2i": avg_loss_t2i.item(), |
|
"step_loss_mmu": avg_loss_mmu.item(), |
|
"step_loss_lm": avg_loss_lm.item(), |
|
"lr": lr_scheduler.get_last_lr()[0], |
|
"avg_masking_rate": avg_masking_rate.item(), |
|
"samples/sec/gpu": samples_per_second_per_gpu, |
|
"data_time": data_time_m.val, |
|
"batch_time": batch_time_m.val, |
|
} |
|
accelerator.log(logs, step=global_step + 1) |
|
|
|
logger.info( |
|
f"Step: {global_step + 1} " |
|
f"Loss_t2i: {avg_loss_t2i.item():0.4f} " |
|
f"Loss_mmu: {avg_loss_mmu.item():0.4f} " |
|
f"Loss_lm: {avg_loss_lm.item():0.4f} " |
|
f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " |
|
f"Batch (t): {batch_time_m.val:0.4f} " |
|
f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" |
|
) |
|
|
|
|
|
batch_time_m.reset() |
|
data_time_m.reset() |
|
|
|
|
|
if (global_step + 1) % config.experiment.save_every == 0: |
|
save_checkpoint(model, config, accelerator, global_step + 1) |
|
|
|
if ((global_step + 1) % config.experiment.generate_every == 0 or global_step == 0) and accelerator.is_main_process: |
|
generate_images( |
|
model, |
|
vq_model, |
|
uni_prompting, |
|
accelerator, |
|
config, |
|
global_step + 1, |
|
mask_schedule=mask_schedule, |
|
) |
|
|
|
visualize_predictions( |
|
model, |
|
vq_model, |
|
uni_prompting, |
|
config, |
|
global_step + 1, |
|
input_ids, |
|
image_tokens_ori, |
|
batch["t2i_flow"]["images"], |
|
texts, |
|
logits, |
|
accelerator |
|
) |
|
|
|
understanding_images( |
|
model, |
|
vq_model, |
|
uni_prompting, |
|
accelerator, |
|
config, |
|
global_step + 1, |
|
) |
|
|
|
global_step += 1 |
|
|
|
if global_step >= config.training.max_train_steps: |
|
break |
|
|
|
accelerator.wait_for_everyone() |
|
|
|
|
|
save_checkpoint(model, config, accelerator, global_step) |
|
|
|
|
|
if accelerator.is_main_process: |
|
model = accelerator.unwrap_model(model) |
|
model.save_pretrained(config.experiment.output_dir, safe_serialization=True) |
|
|
|
accelerator.end_training() |
|
|
|
|
|
@torch.no_grad() |
|
def visualize_predictions( |
|
model, |
|
vq_model, |
|
uni_prompting, |
|
config, |
|
global_step, |
|
input_ids, |
|
image_tokens_ori, |
|
ori_images, |
|
texts, |
|
logits, |
|
accelerator |
|
): |
|
logger.info("Visualizing predictions...") |
|
model.eval() |
|
|
|
recons_images = vq_model.decode_code(image_tokens_ori - len(uni_prompting.text_tokenizer)) |
|
recons_images = torch.clamp((recons_images + 1.0) / 2.0, min=0.0, max=1.0) |
|
recons_images *= 255.0 |
|
recons_images = recons_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) |
|
|
|
images = torch.clamp((ori_images + 1.0) / 2.0, min=0.0, max=1.0) |
|
images *= 255.0 |
|
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) |
|
    # Offset of the image-codebook logits within the full vocabulary.
    vocab_offset = len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens
    predictions = logits[
        :config.training.batch_size_t2i,
        -(config.model.mmada.num_vq_tokens + 1):-1,
        vocab_offset:vocab_offset + config.model.mmada.codebook_size,
    ]
|
|
|
predictions = predictions.argmax(axis=-1) |
|
mask_token_id = accelerator.unwrap_model(model).config.mask_token_id - len(uni_prompting.text_tokenizer) |
|
input_ids = input_ids[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:] - len(uni_prompting.text_tokenizer) |
|
mask_ratio = list((torch.where(input_ids == mask_token_id, 1, 0).sum( |
|
dim=-1) / config.model.mmada.num_vq_tokens).cpu().numpy()) |
|
predicted_images = torch.where(input_ids == mask_token_id, predictions, input_ids) |
|
predicted_images = vq_model.decode_code(predicted_images) |
|
predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0) |
|
predicted_images *= 255.0 |
|
predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) |
|
predicted_images = np.concatenate((images, recons_images, predicted_images), 2) |
|
pil_images = [Image.fromarray(image) for image in predicted_images] |
|
|
|
|
|
wandb_images = [wandb.Image(image, caption=f'mask ratio: {r:0.2f} \n caption: {texts[i]}') for i, (image, r) in |
|
enumerate(zip(pil_images, mask_ratio))] |
|
wandb.log({"Original images v.s. Reconstructed images v.s. Predicted images": wandb_images}, step=global_step) |
|
|
|
model.train() |
|
|
|
|
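# Sample images for the validation prompts by iteratively decoding a fully masked
# image-token sequence (t2i_generate) and log the results to wandb.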
|
@torch.no_grad() |
|
def generate_images( |
|
model, |
|
vq_model, |
|
uni_prompting, |
|
accelerator, |
|
config, |
|
global_step, |
|
mask_schedule, |
|
): |
|
logger.info("Generating images...") |
|
model.eval() |
|
|
|
|
|
with open(config.dataset.params.validation_prompts_file, "r") as f: |
|
validation_prompts = f.read().splitlines() |
|
|
|
|
|
mask_dtype = model.get_input_embeddings().weight.dtype |
|
mask_token_id = accelerator.unwrap_model(model).config.mask_token_id |
|
image_tokens = torch.ones((len(validation_prompts), config.model.mmada.num_vq_tokens), dtype=torch.long, |
|
device=accelerator.device) * mask_token_id |
|
input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') |
|
if config.training.guidance_scale > 0: |
|
uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') |
|
else: |
|
uncond_input_ids = None |
|
uncond_attention_mask = None |
|
if accelerator.mixed_precision == "fp16": |
|
weight_dtype = torch.float16 |
|
elif accelerator.mixed_precision == "bf16": |
|
weight_dtype = torch.bfloat16 |
|
else: |
|
weight_dtype = torch.float32 |
|
|
|
with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): |
|
|
|
gen_token_ids = accelerator.unwrap_model(model).t2i_generate( |
|
input_ids=input_ids, |
|
uncond_input_ids=uncond_input_ids, |
|
attention_mask=attention_mask, |
|
uncond_attention_mask=uncond_attention_mask, |
|
guidance_scale=config.training.guidance_scale, |
|
temperature=config.training.get("generation_temperature", 1.0), |
|
timesteps=config.training.generation_timesteps, |
|
noise_schedule=mask_schedule, |
|
noise_type=config.training.get("noise_type", "mask"), |
|
predict_all_tokens=config.training.get("predict_all_tokens", False), |
|
seq_len=config.model.mmada.num_vq_tokens, |
|
uni_prompting=uni_prompting, |
|
config=config, |
|
) |
|
|
|
|
|
gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) |
|
images = vq_model.decode_code(gen_token_ids) |
|
|
|
model.train() |
|
|
|
if config.training.get("pre_encode", False): |
|
del vq_model |
|
|
|
|
|
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) |
|
images *= 255.0 |
|
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) |
|
pil_images = [Image.fromarray(image) for image in images] |
|
|
|
|
|
wandb_images = [wandb.Image(image, caption=validation_prompts[i]) for i, image in enumerate(pil_images)] |
|
wandb.log({"Generated images": wandb_images}, step=global_step) |
|
|
|
|
|
|
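# Caption a folder of evaluation images with a fixed "describe this image" prompt
# and log the image/response pairs to wandb.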
|
@torch.no_grad() |
|
def understanding_images( |
|
model, |
|
vq_model, |
|
uni_prompting, |
|
accelerator, |
|
config, |
|
global_step, |
|
): |
|
logger.info("Understanding images...") |
|
model.eval() |
|
|
|
file_list = os.listdir(config.dataset.params.mmu_image_root) |
|
file_list = [f for f in file_list if f.lower().endswith(('.jpg', '.png', '.jpeg'))] |
|
responses = ['' for i in range(len(file_list))] |
|
images = [] |
|
|
|
device = accelerator.device |
|
|
|
if accelerator.mixed_precision == "fp16": |
|
weight_dtype = torch.float16 |
|
elif accelerator.mixed_precision == "bf16": |
|
weight_dtype = torch.bfloat16 |
|
else: |
|
weight_dtype = torch.float32 |
|
|
|
for i, file_name in enumerate(file_list): |
|
image_path = os.path.join(config.dataset.params.mmu_image_root, file_name) |
|
image_ori = Image.open(image_path).convert("RGB") |
|
image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) |
|
image = image.unsqueeze(0) |
|
images.append(image) |
|
image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) |
|
batch_size = 1 |
|
|
|
        input_ids = uni_prompting.text_tokenizer([
            '<|start_header_id|>user<|end_header_id|>\n'
            + "Please describe this image in detail."
            + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
        ])['input_ids']
|
input_ids = torch.tensor(input_ids).to(device) |
|
|
|
input_ids = torch.cat([ |
|
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), |
|
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), |
|
image_tokens, |
|
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), |
|
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), |
|
input_ids |
|
], dim=1).long() |
|
with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): |
|
output_ids = accelerator.unwrap_model(model).mmu_generate(input_ids) |
|
|
|
|
|
text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) |
|
responses[i] += text[0] |
|
model.train() |
|
images = torch.cat(images, dim=0) |
|
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) |
|
images *= 255.0 |
|
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) |
|
pil_images = [Image.fromarray(image) for image in images] |
|
|
|
|
|
wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] |
|
wandb.log({"Understanding images": wandb_images}, step=global_step) |
|
|
|
|
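# Save checkpoint-<global_step>; when checkpoints_total_limit is set, the oldest
# checkpoints are removed first on the main process.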
|
def save_checkpoint(model, config, accelerator, global_step): |
|
output_dir = config.experiment.output_dir |
|
checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) |
|
|
|
|
|
if accelerator.is_main_process and checkpoints_total_limit is not None: |
|
checkpoints = os.listdir(output_dir) |
|
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] |
|
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) |
|
|
|
|
|
if len(checkpoints) >= checkpoints_total_limit: |
|
num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 |
|
removing_checkpoints = checkpoints[0:num_to_remove] |
|
|
|
logger.info( |
|
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" |
|
) |
|
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") |
|
|
|
for removing_checkpoint in removing_checkpoints: |
|
removing_checkpoint = os.path.join(output_dir, removing_checkpoint) |
|
shutil.rmtree(removing_checkpoint) |
|
|
|
save_path = Path(output_dir) / f"checkpoint-{global_step}" |
|
|
|
|
|
|
|
state_dict = accelerator.get_state_dict(model) |
|
if accelerator.is_main_process: |
|
unwrapped_model = accelerator.unwrap_model(model) |
|
unwrapped_model.save_pretrained( |
|
save_path / "unwrapped_model", |
|
save_function=accelerator.save, |
|
state_dict=state_dict, |
|
safe_serialization=True |
|
) |
|
json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) |
|
logger.info(f"Saved state to {save_path}") |
|
|
|
|
|
def log_grad_norm(model, accelerator, global_step): |
|
for name, param in model.named_parameters(): |
|
if param.grad is not None: |
|
grads = param.grad.detach().data |
|
grad_norm = (grads.norm(p=2) / grads.numel()).item() |
|
accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|