import logging
import math
import random
from datetime import timedelta
from pathlib import Path

import hydra
import numpy as np
import torch
import torch.distributed as distributed
from hydra import compose
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from torch.distributed.elastic.multiprocessing.errors import record

from mmaudio.data.data_setup import setup_training_datasets, setup_val_datasets
from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K
from mmaudio.runner import Runner
from mmaudio.sample import sample
from mmaudio.utils.dist_utils import info_if_rank_zero, local_rank, world_size
from mmaudio.utils.logger import TensorboardLogger
from mmaudio.utils.synthesize_ema import synthesize_ema

# allow TF32 on Ampere+ GPUs: faster matmuls/convolutions at negligible precision cost
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()


def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=2))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size
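

# NOTE (assumption): this script expects to be launched through torchrun (or another
# torch.distributed launcher) so that RANK/LOCAL_RANK/WORLD_SIZE are set before
# init_process_group runs, e.g.:
#   torchrun --nproc_per_node=8 train.py exp_id=my_exp
# The override shown is illustrative; see config/train_config.yaml for the real keys.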


@record
@hydra.main(version_base='1.3.2', config_path='config', config_name='train_config.yaml')
def train(cfg: DictConfig):
    # initial setup
    torch.cuda.set_device(local_rank)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    distributed_setup()
    num_gpus = world_size
    run_dir = HydraConfig.get().run.dir

    # compose the evaluation config under the same experiment id
    eval_cfg = compose('eval_config', overrides=[f'exp_id={cfg.exp_id}'])

    # pick the sequence configuration matching the model's output sampling rate
    if cfg.model.endswith('16k'):
        seq_cfg = CONFIG_16K
    elif cfg.model.endswith('44k'):
        seq_cfg = CONFIG_44K
    else:
        raise ValueError(f'Unknown model: {cfg.model}')
    # open_dict temporarily disables struct mode so these keys can be written
    with open_dict(cfg):
        cfg.data_dim.latent_seq_len = seq_cfg.latent_seq_len
        cfg.data_dim.clip_seq_len = seq_cfg.clip_seq_len
        cfg.data_dim.sync_seq_len = seq_cfg.sync_seq_len

    # wrap the Python logger in a TensorBoard logger; rank-0 gating is handled
    # through is_rank0 (this assignment shadows the module-level `log` here)
    log = TensorboardLogger(cfg.exp_id,
                            run_dir,
                            logging.getLogger(),
                            is_rank0=(local_rank == 0),
                            enable_email=cfg.enable_email and not cfg.debug)

    info_if_rank_zero(log, f'All configuration: {cfg}')
    info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}')
    info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}')

    # seed everything identically across ranks for matching initialization
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    info_if_rank_zero(log, f'Training configuration: {cfg}')
    # cfg.batch_size is the global batch size; convert it to a per-GPU batch size
    cfg.batch_size //= num_gpus
    info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}')
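
    # Worked example (assuming the usual convention that the yaml batch_size is
    # global): batch_size=128 with 8 GPUs leaves 16 samples per rank per step,
    # so the effective global batch stays at 128.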

    total_iterations = cfg['num_iterations']

    # set up training and validation datasets/loaders
    dataset, sampler, loader = setup_training_datasets(cfg)
    info_if_rank_zero(log, f'Number of training samples: {len(dataset)}')
    info_if_rank_zero(log, f'Number of training batches: {len(loader)}')

    val_dataset, val_loader, eval_loader = setup_val_datasets(cfg)
    info_if_rank_zero(log, f'Number of val samples: {len(val_dataset)}')
    val_cfg = cfg.data.ExtractedVGG_val

    # latent statistics used to normalize the audio latents during training
    latent_mean, latent_std = dataset.compute_latent_stats()

    # construct the trainer and put it in training mode
    trainer = Runner(cfg,
                     log=log,
                     run_path=run_dir,
                     for_training=True,
                     latent_mean=latent_mean,
                     latent_std=latent_std).enter_train()
    # clone a fixed RNG state to reuse for every validation/inference pass
    eval_rng_clone = trainer.rng.graphsafe_get_state()
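
    # Why the clone (my reading of the pattern below): each evaluation swaps this
    # fixed state in and snapshots/restores the training state around it, so eval
    # noise is identical across iterations and never perturbs the training stream.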

    # load a specified checkpoint, if any
    if cfg['checkpoint'] is not None:
        curr_iter = trainer.load_checkpoint(cfg['checkpoint'])
        cfg['checkpoint'] = None
        info_if_rank_zero(log, 'Model checkpoint loaded!')
    else:
        # otherwise resume from the latest checkpoint in the run directory
        checkpoint = trainer.get_latest_checkpoint_path()
        if checkpoint is not None:
            curr_iter = trainer.load_checkpoint(checkpoint)
            info_if_rank_zero(log, 'Latest checkpoint loaded!')
        else:
            # fresh start; optionally initialize from pretrained weights
            curr_iter = 0
            if cfg['weights'] is not None:
                info_if_rank_zero(log, 'Loading weights from the disk')
                trainer.load_weights(cfg['weights'])
                cfg['weights'] = None
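
    # Example resume invocations (illustrative; `checkpoint` and `weights` are the
    # config keys consumed above, passed here as Hydra CLI overrides):
    #   torchrun --nproc_per_node=8 train.py exp_id=my_exp checkpoint=/path/to/ckpt.pt
    #   torchrun --nproc_per_node=8 train.py exp_id=my_exp weights=/path/to/init.pt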

    # convert the iteration budget into an (approximate) epoch count
    total_epoch = math.ceil(total_iterations / len(loader))
    current_epoch = curr_iter // len(loader)
    info_if_rank_zero(log, f'We will approximately use {total_epoch} epochs.')

    # training loop
    try:
        # re-seed so each rank draws different augmentation randomness
        np.random.seed(np.random.randint(2**30 - 1) + local_rank * 1000)
        while curr_iter < total_iterations:
            # crucial for proper shuffling with a distributed sampler
            sampler.set_epoch(current_epoch)
            current_epoch += 1
            log.debug(f'Current epoch: {current_epoch}')

            trainer.enter_train()
            trainer.log.data_timer.start()
            for data in loader:
                trainer.train_pass(data, curr_iter)

                # validation pass
                if (curr_iter + 1) % cfg.val_interval == 0:
                    # swap in the fixed eval RNG state; restore training state after
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: validating')
                    for data in val_loader:
                        trainer.validation_pass(data, curr_iter)
                    distributed.barrier()
                    trainer.val_integrator.finalize('val', curr_iter, ignore_timer=True)
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)

                # inference + evaluation pass
                if (curr_iter + 1) % cfg.eval_interval == 0:
                    save_eval = (curr_iter + 1) % cfg.save_eval_interval == 0
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: running inference')
                    # eval_loader is assumed non-empty, so audio_path is always bound
                    for data in eval_loader:
                        audio_path = trainer.inference_pass(data,
                                                            curr_iter,
                                                            val_cfg,
                                                            save_eval=save_eval)
                    distributed.barrier()
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)
                    trainer.eval(audio_path, curr_iter, val_cfg)

                curr_iter += 1
                if curr_iter >= total_iterations:
                    break
    except Exception as e:
        log.error(f'Error occurred at iteration {curr_iter}!')
        log.critical(str(e))
        raise
    finally:
        # always persist the current state unless we are in debug mode
        if not cfg.debug:
            trainer.save_checkpoint(curr_iter)
            trainer.save_weights(curr_iter)

    # free trainer state before the final EMA synthesis and sampling
    del trainer
    torch.cuda.empty_cache()

    # synthesize and save the post-hoc EMA weights on rank 0 only
    if local_rank == 0:
        log.info(f'Synthesizing EMA with sigma={cfg.ema.default_output_sigma}')
        ema_sigma = cfg.ema.default_output_sigma
        state_dict = synthesize_ema(cfg, ema_sigma, step=None)
        save_path = Path(run_dir) / f'{cfg.exp_id}_ema_final.pth'
        torch.save(state_dict, save_path)
        log.info(f'Synthesized EMA saved to {save_path}!')
    distributed.barrier()

    # final evaluation pass with the composed eval config
    log.info(f'Evaluation: {eval_cfg}')
    sample(eval_cfg)
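
    # Downstream usage sketch (assumption: consumers read the EMA file directly;
    # the actual loading path may live elsewhere in this repo):
    #   ema_state = torch.load(Path(run_dir) / f'{cfg.exp_id}_ema_final.pth')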

    log.complete()
    distributed.barrier()
    distributed.destroy_process_group()


if __name__ == '__main__':
    train()