import logging
import os
from argparse import ArgumentParser
from datetime import timedelta
from pathlib import Path

import pandas as pd
import tensordict as td
import torch
import torch.distributed as distributed
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from mmaudio.data.data_setup import error_avoidance_collate
from mmaudio.data.extraction.vgg_sound import VGGSound
from mmaudio.model.utils.features_utils import FeaturesUtils
from mmaudio.utils.dist_utils import local_rank, world_size

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

SAMPLING_RATE = 16000
DURATION_SEC = 8.0
NUM_SAMPLES = 128000
vae_path = './ext_weights/v1-16.pth'
bigvgan_path = './ext_weights/best_netG.pt'
mode = '16k'
"""
NOTE: the 16 kHz setting above (128000 samples) divides evenly by
(STFT hop size * VAE downsampling ratio) = 1024. At 44.1 kHz, 8 s of audio is
352800 (8 * 44100) samples, which is not divisible by 1024; 353280 is the next
integer that is.
"""
synchformer_ckpt = './ext_weights/synchformer_state_dict.pth'

BATCH_SIZE = 16
NUM_WORKERS = 16

log = logging.getLogger()
log.setLevel(logging.INFO)

data_cfg = {
    'example': {
        'root': './training/example_videos',
        'subset_name': './training/example_video.tsv',
        'normalize_audio': True,
    },
}
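# Further splits can be registered in data_cfg with the same shape. The entry below is
# a hypothetical example (the paths are placeholders, not files shipped with this repo):
#
# data_cfg['vgg_test'] = {
#     'root': './training/vgg_videos',
#     'subset_name': './training/vgg_test.tsv',
#     'normalize_audio': True,
# }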


def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=1))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


def setup_dataset(split: str):
    dataset = VGGSound(
        data_cfg[split]['root'],
        tsv_path=data_cfg[split]['subset_name'],
        sample_rate=SAMPLING_RATE,
        duration_sec=DURATION_SEC,
        audio_samples=NUM_SAMPLES,
        normalize_audio=data_cfg[split]['normalize_audio'],
    )
    # No shuffling: extraction only needs each rank to visit its own slice of the
    # dataset once, in a deterministic order.
    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=False)
    loader = DataLoader(dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS,
                        sampler=sampler,
                        drop_last=False,
                        collate_fn=error_avoidance_collate)

    return dataset, loader


@torch.inference_mode()
def extract():
    distributed_setup()

    parser = ArgumentParser()
    parser.add_argument('--latent_dir',
                        type=Path,
                        default='./training/example_output/video-latents')
    parser.add_argument('--output_dir', type=Path, default='./training/example_output/memmap')
    args = parser.parse_args()

    latent_dir = args.latent_dir
    output_dir = args.output_dir

    torch.cuda.set_device(local_rank)
    feature_extractor = FeaturesUtils(tod_vae_ckpt=vae_path,
                                      enable_conditions=True,
                                      bigvgan_vocoder_ckpt=bigvgan_path,
                                      synchformer_ckpt=synchformer_ckpt,
                                      mode=mode).eval().cuda()
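    # FeaturesUtils wraps the encoders used below: the audio VAE (encode_audio),
    # CLIP (encode_video_with_clip), Synchformer (encode_video_with_sync), and the
    # text encoder (encode_text).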

    for split in data_cfg.keys():
        print(f'Extracting latents for the {split} split')
        this_latent_dir = latent_dir / split
        this_latent_dir.mkdir(parents=True, exist_ok=True)

        dataset, loader = setup_dataset(split)
        log.info(f'Number of samples: {len(dataset)}')
        log.info(f'Number of batches: {len(loader)}')

        for curr_iter, data in enumerate(tqdm(loader)):
            output = {
                'id': data['id'],
                'caption': data['caption'],
            }

            audio = data['audio'].cuda()
            dist = feature_extractor.encode_audio(audio)
            output['mean'] = dist.mean.detach().cpu().transpose(1, 2)
            output['std'] = dist.std.detach().cpu().transpose(1, 2)

            clip_video = data['clip_video'].cuda()
            clip_features = feature_extractor.encode_video_with_clip(clip_video)
            output['clip_features'] = clip_features.detach().cpu()

            sync_video = data['sync_video'].cuda()
            sync_features = feature_extractor.encode_video_with_sync(sync_video)
            output['sync_features'] = sync_features.detach().cpu()

            caption = data['caption']
            text_features = feature_extractor.encode_text(caption)
            output['text_features'] = text_features.detach().cpu()

            torch.save(output, this_latent_dir / f'r{local_rank}_{curr_iter}.pth')
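            # Each rank writes its own shard, named r{rank}_{iteration}.pth; rank 0
            # merges all shards for this split after the barrier below.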

        distributed.barrier()

        if local_rank == 0:
            print('Extraction done. Combining the results.')

            used_id = set()
            list_of_ids_and_labels = []
            output_data = {
                'mean': [],
                'std': [],
                'clip_features': [],
                'sync_features': [],
                'text_features': [],
            }

            for t in tqdm(sorted(os.listdir(this_latent_dir))):
                data = torch.load(this_latent_dir / t, weights_only=True)
                bs = len(data['id'])

                for bi in range(bs):
                    this_id = data['id'][bi]
                    this_caption = data['caption'][bi]
                    if this_id in used_id:
                        print('Duplicate id:', this_id)
                        continue

                    list_of_ids_and_labels.append({'id': this_id, 'label': this_caption})
                    used_id.add(this_id)
                    output_data['mean'].append(data['mean'][bi])
                    output_data['std'].append(data['std'][bi])
                    output_data['clip_features'].append(data['clip_features'][bi])
                    output_data['sync_features'].append(data['sync_features'][bi])
                    output_data['text_features'].append(data['text_features'][bi])

            output_dir.mkdir(parents=True, exist_ok=True)
            output_df = pd.DataFrame(list_of_ids_and_labels)
            output_df.to_csv(output_dir / f'vgg-{split}.tsv', sep='\t', index=False)

            print(f'Output: {len(output_df)}')

            output_data = {k: torch.stack(v) for k, v in output_data.items()}
            td.TensorDict(output_data).memmap_(output_dir / f'vgg-{split}')
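            # Note: the memory-mapped TensorDict written above can later be read back
            # lazily, e.g. with td.TensorDict.load_memmap(output_dir / f'vgg-{split}')
            # (assuming a tensordict version that provides load_memmap).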


if __name__ == '__main__':
    extract()
    distributed.destroy_process_group()