# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import importlib
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

from PIL import Image
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.multitalk_model import WanModel, WanLayerNorm, WanRMSNorm
from .modules.t5 import T5EncoderModel, T5LayerNorm, T5RelativeEmbedding
from .modules.vae import WanVAE, CausalConv3d, RMS_norm, Upsample
from .utils.multitalk_utils import MomentumBuffer, adaptive_projected_guidance
from src.vram_management import AutoWrappedLinear, AutoWrappedModule, enable_vram_management


def torch_gc():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
def resize_and_centercrop(cond_image, target_size):
    """
    Resize image or tensor to the target size without padding.
    """
    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size

    # Calculate the scaling factor for resizing
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w

    # Compute the final size
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
        # crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # tensor and crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor[:, :, None, :, :]

    return cropped_tensor
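# Worked example (illustrative numbers only, not tied to any particular size bucket):
# for a 1920x1080 PIL image and target_size=(448, 896),
#   scale_h = 448/1080 ~= 0.415, scale_w = 896/1920 ~= 0.467, scale = 0.467,
#   final (h, w) = (504, 896), and the center crop then removes 56 rows to reach 448x896.
# The PIL branch returns a tensor of shape (1, 3, 1, 448, 896), i.e. B C T H W;
# the tensor branch keeps the input's leading channel dimension, e.g. (C, 448, 896).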
def timestep_transform(
    t,
    shift=5.0,
    num_timesteps=1000,
):
    t = t / num_timesteps
    # shift the timestep based on ratio
    new_t = shift * t / (1 + (shift - 1) * t)
    new_t = new_t * num_timesteps
    return new_t
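# Worked example: with shift=5.0 and num_timesteps=1000 the endpoints are unchanged
# (0 -> 0, 1000 -> 1000), while intermediate values are pushed toward the noisy end
# of the schedule, e.g. 100 -> ~357 and 500 -> ~833, since on the normalized timestep
# new_t = shift * t / (1 + (shift - 1) * t).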
class MultiTalkPipeline:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
        num_timesteps=1000,
        use_timestep_transform=True
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size
            from .distributed.xdit_context_parallel import (
                usp_dit_forward_multitalk,
                usp_attn_forward_multitalk,
                usp_crossattn_multi_forward_multitalk
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward_multitalk, block.self_attn)
                block.audio_cross_attn.forward = types.MethodType(
                    usp_crossattn_multi_forward_multitalk, block.audio_cross_attn)
            self.model.forward = types.MethodType(usp_dit_forward_multitalk, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        self.model.to(self.param_dtype)

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt
        self.num_timesteps = num_timesteps
        self.use_timestep_transform = use_timestep_transform
        self.cpu_offload = False
        self.model_names = ["model"]
        self.vram_management = False
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        compatible with diffusers add_noise()
        """
        timesteps = timesteps.float() / self.num_timesteps
        timesteps = timesteps.view(timesteps.shape + (1,) * (len(noise.shape) - 1))
        return (1 - timesteps) * original_samples + timesteps * noise
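    # Note: this is the linear (rectified-flow style) interpolation
    # x_t = (1 - t/num_timesteps) * x_0 + (t/num_timesteps) * noise,
    # so timestep 0 returns the clean sample and timestep num_timesteps returns
    # pure noise. generate() uses it to re-noise the motion-frame latents to the
    # current denoising timestep before every update step.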
    def enable_vram_management(self, num_persistent_param_in_dit=None):
        dtype = next(iter(self.model.parameters())).dtype
        enable_vram_management(
            self.model,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                WanLayerNorm: AutoWrappedModule,
                WanRMSNorm: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.param_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.param_dtype,
                computation_device=self.device,
            ),
        )
        self.enable_cpu_offload()
    def enable_cpu_offload(self):
        self.cpu_offload = True
    def load_models_to_device(self, loadmodel_names=[]):
        # only load models to device if cpu_offload is enabled
        if not self.cpu_offload:
            return
        # offload the unneeded models to cpu
        for model_name in self.model_names:
            if model_name not in loadmodel_names:
                model = getattr(self, model_name)
                if not isinstance(model, nn.Module):
                    model = model.model
                if model is not None:
                    if (
                        hasattr(model, "vram_management_enabled")
                        and model.vram_management_enabled
                    ):
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
        # load the needed models to device
        for model_name in loadmodel_names:
            model = getattr(self, model_name)
            if not isinstance(model, nn.Module):
                model = model.model
            if model is not None:
                if (
                    hasattr(model, "vram_management_enabled")
                    and model.vram_management_enabled
                ):
                    for module in model.modules():
                        if hasattr(module, "onload"):
                            module.onload()
                else:
                    model.to(self.device)
        # refresh the CUDA cache
        torch.cuda.empty_cache()
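    # Usage note: when self.vram_management is set, generate() calls
    # self.load_models_to_device(["model"]) right before the denoising loop;
    # entries of self.model_names that are not requested are offloaded to CPU
    # (module-by-module if they were wrapped by enable_vram_management), while
    # the requested ones are moved or onloaded back to self.device.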
    def generate(self,
                 input_data,
                 size_buckget='multitalk-480',
                 motion_frame=25,
                 frame_num=81,
                 shift=5.0,
                 sampling_steps=40,
                 text_guide_scale=5.0,
                 audio_guide_scale=4.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True,
                 max_frames_num=1000,
                 face_scale=0.05,
                 progress=True,
                 extra_args=None):
        r"""
        Generates video frames from input image and text prompt using diffusion process.

        Args:
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM
        """
        # init teacache
        if extra_args.use_teacache:
            self.model.teacache_init(
                sample_steps=sampling_steps,
                teacache_thresh=extra_args.teacache_thresh,
                model_scale=extra_args.size,
            )
        else:
            self.model.disable_teacache()

        input_prompt = input_data['prompt']
        cond_file_path = input_data['cond_image']
        cond_image = Image.open(cond_file_path).convert('RGB')

        # decide a proper size
        bucket_config_module = importlib.import_module("wan.utils.multitalk_utils")
        if size_buckget == 'multitalk-480':
            bucket_config = getattr(bucket_config_module, 'ASPECT_RATIO_627')
        elif size_buckget == 'multitalk-720':
            bucket_config = getattr(bucket_config_module, 'ASPECT_RATIO_960')
        else:
            raise ValueError(f"Unsupported size_buckget: {size_buckget!r}")
        src_h, src_w = cond_image.height, cond_image.width
        ratio = src_h / src_w
        closest_bucket = sorted(list(bucket_config.keys()), key=lambda x: abs(float(x) - ratio))[0]
        target_h, target_w = bucket_config[closest_bucket][0]
        cond_image = resize_and_centercrop(cond_image, (target_h, target_w))
        cond_image = cond_image / 255
        cond_image = (cond_image - 0.5) * 2  # normalization
        cond_image = cond_image.to(self.device)  # 1 C 1 H W

        # read audio embeddings
        audio_embedding_path_1 = input_data['cond_audio']['person1']
        if len(input_data['cond_audio']) == 1:
            HUMAN_NUMBER = 1
            audio_embedding_path_2 = None
        else:
            HUMAN_NUMBER = 2
            audio_embedding_path_2 = input_data['cond_audio']['person2']

        full_audio_embs = []
        audio_embedding_paths = [audio_embedding_path_1, audio_embedding_path_2]
        for human_idx in range(HUMAN_NUMBER):
            audio_embedding_path = audio_embedding_paths[human_idx]
            if not os.path.exists(audio_embedding_path):
                continue
            full_audio_emb = torch.load(audio_embedding_path)
            if torch.isnan(full_audio_emb).any():
                continue
            if full_audio_emb.shape[0] <= frame_num:
                continue
            full_audio_embs.append(full_audio_emb)
        assert len(full_audio_embs) == HUMAN_NUMBER, "Audio file is missing, contains NaNs, or is shorter than frame_num."
        # preprocess text embedding
        if n_prompt == "":
            n_prompt = self.sample_neg_prompt
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context, context_null = self.text_encoder([input_prompt, n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]
        torch_gc()

        # prepare params for video generation
        indices = (torch.arange(2 * 2 + 1) - 2) * 1
        clip_length = frame_num
        is_first_clip = True
        arrive_last_frame = False
        cur_motion_frames_num = 1
        audio_start_idx = 0
        audio_end_idx = audio_start_idx + clip_length
        gen_video_list = []
        torch_gc()

        # set random seed and init noise
        seed = seed if seed >= 0 else random.randint(0, 99999999)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        # start video generation iteratively
        while True:
            audio_embs = []
            # split audio with window size
            for human_idx in range(HUMAN_NUMBER):
                center_indices = torch.arange(
                    audio_start_idx,
                    audio_end_idx,
                    1,
                ).unsqueeze(
                    1
                ) + indices.unsqueeze(0)
                center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0] - 1)
                audio_emb = full_audio_embs[human_idx][center_indices][None, ...].to(self.device)
                audio_embs.append(audio_emb)
            audio_embs = torch.concat(audio_embs, dim=0).to(self.param_dtype)
            torch_gc()

            h, w = cond_image.shape[-2], cond_image.shape[-1]
            lat_h, lat_w = h // self.vae_stride[1], w // self.vae_stride[2]
            max_seq_len = ((frame_num - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
                self.patch_size[1] * self.patch_size[2])
            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

            noise = torch.randn(
                16, (frame_num - 1) // 4 + 1,
                lat_h,
                lat_w,
                dtype=torch.float32,
                device=self.device)

            # get mask
            msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
            msk[:, cur_motion_frames_num:] = 0
            msk = torch.concat([
                torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
            ], dim=1)
            msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
            msk = msk.transpose(1, 2).to(self.param_dtype)  # B 4 T H W

            with torch.no_grad():
                # get clip embedding
                self.clip.model.to(self.device)
                clip_context = self.clip.visual(cond_image[:, :, -1:, :, :]).to(self.param_dtype)
                if offload_model:
                    self.clip.model.cpu()
                torch_gc()

                # zero padding and vae encode
                video_frames = torch.zeros(1, cond_image.shape[1], frame_num - cond_image.shape[2], target_h, target_w).to(self.device)
                padding_frames_pixels_values = torch.concat([cond_image, video_frames], dim=2)
                y = self.vae.encode(padding_frames_pixels_values)
                y = torch.stack(y).to(self.param_dtype)  # B C T H W
                cur_motion_frames_latent_num = int(1 + (cur_motion_frames_num - 1) // 4)
                latent_motion_frames = y[:, :, :cur_motion_frames_latent_num][0]  # C T H W
                y = torch.concat([msk, y], dim=1)  # B 4+C T H W
                torch_gc()
            # construct human mask
            human_masks = []
            if HUMAN_NUMBER == 1:
                background_mask = torch.ones([src_h, src_w])
                human_mask1 = torch.ones([src_h, src_w])
                human_mask2 = torch.ones([src_h, src_w])
                human_masks = [human_mask1, human_mask2, background_mask]
            elif HUMAN_NUMBER == 2:
                if 'bbox' in input_data:
                    assert len(input_data['bbox']) == len(input_data['cond_audio']), "The number of target bboxes must match the number of cond_audio entries."
                    background_mask = torch.zeros([src_h, src_w])
                    for _, person_bbox in input_data['bbox'].items():
                        x_min, y_min, x_max, y_max = person_bbox
                        human_mask = torch.zeros([src_h, src_w])
                        human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
                        background_mask += human_mask
                        human_masks.append(human_mask)
                else:
                    x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
                    background_mask = torch.zeros([src_h, src_w])
                    human_mask1 = torch.zeros([src_h, src_w])
                    human_mask2 = torch.zeros([src_h, src_w])
                    src_w = src_w // 2
                    lefty_min, lefty_max = int(src_w * face_scale), int(src_w * (1 - face_scale))
                    righty_min, righty_max = int(src_w * face_scale + src_w), int(src_w * (1 - face_scale) + src_w)
                    human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
                    human_mask2[x_min:x_max, righty_min:righty_max] = 1
                    background_mask += human_mask1
                    background_mask += human_mask2
                    human_masks = [human_mask1, human_mask2]
                background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
                human_masks.append(background_mask)
            ref_target_masks = torch.stack(human_masks, dim=0).to(self.device)
            # resize and center-crop ref_target_masks
            ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
            _, _, _, lat_h, lat_w = y.shape
            ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(lat_h, lat_w), mode='nearest').squeeze()
            ref_target_masks = (ref_target_masks > 0)
            ref_target_masks = ref_target_masks.float().to(self.device)
            torch_gc()
            # noop_no_sync must be a context manager, since it is used in the
            # `with ... no_sync():` block below when the model has no no_sync().
            @contextmanager
            def noop_no_sync():
                yield

            no_sync = getattr(self.model, 'no_sync', noop_no_sync)
            # evaluation mode
            with torch.no_grad(), no_sync():

                # prepare timesteps
                timesteps = list(np.linspace(self.num_timesteps, 1, sampling_steps, dtype=np.float32))
                timesteps.append(0.)
                timesteps = [torch.tensor([t], device=self.device) for t in timesteps]
                if self.use_timestep_transform:
                    timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps]

                # sample videos
                latent = noise

                # prepare condition and uncondition configs
                arg_c = {
                    'context': [context],
                    'clip_fea': clip_context,
                    'seq_len': max_seq_len,
                    'y': y,
                    'audio': audio_embs,
                    'ref_target_masks': ref_target_masks
                }

                arg_null_text = {
                    'context': [context_null],
                    'clip_fea': clip_context,
                    'seq_len': max_seq_len,
                    'y': y,
                    'audio': audio_embs,
                    'ref_target_masks': ref_target_masks
                }

                arg_null = {
                    'context': [context_null],
                    'clip_fea': clip_context,
                    'seq_len': max_seq_len,
                    'y': y,
                    'audio': torch.zeros_like(audio_embs)[-1:],
                    'ref_target_masks': ref_target_masks
                }
                torch_gc()

                if not self.vram_management:
                    self.model.to(self.device)
                else:
                    self.load_models_to_device(["model"])

                # injecting motion frames
                if not is_first_clip:
                    latent_motion_frames = latent_motion_frames.to(latent.dtype).to(self.device)
                    motion_add_noise = torch.randn_like(latent_motion_frames).contiguous()
                    add_latent = self.add_noise(latent_motion_frames, motion_add_noise, timesteps[0])
                    _, T_m, _, _ = add_latent.shape
                    latent[:, :T_m] = add_latent

                # infer with APG
                # refer https://arxiv.org/abs/2410.02416
                if extra_args.use_apg:
                    text_momentumbuffer = MomentumBuffer(extra_args.apg_momentum)
                    audio_momentumbuffer = MomentumBuffer(extra_args.apg_momentum)

                progress_wrap = partial(tqdm, total=len(timesteps) - 1) if progress else (lambda x: x)
                for i in progress_wrap(range(len(timesteps) - 1)):

                    timestep = timesteps[i]
                    latent_model_input = [latent.to(self.device)]

                    # inference with CFG strategy
                    noise_pred_cond = self.model(
                        latent_model_input, t=timestep, **arg_c)[0]
                    torch_gc()
                    noise_pred_drop_text = self.model(
                        latent_model_input, t=timestep, **arg_null_text)[0]
                    torch_gc()
                    noise_pred_uncond = self.model(
                        latent_model_input, t=timestep, **arg_null)[0]
                    torch_gc()

                    if extra_args.use_apg:
                        # correct update direction
                        diff_uncond_text = noise_pred_cond - noise_pred_drop_text
                        diff_uncond_audio = noise_pred_drop_text - noise_pred_uncond
                        noise_pred = noise_pred_cond + \
                            (text_guide_scale - 1) * adaptive_projected_guidance(
                                diff_uncond_text,
                                noise_pred_cond,
                                momentum_buffer=text_momentumbuffer,
                                norm_threshold=extra_args.apg_norm_threshold) + \
                            (audio_guide_scale - 1) * adaptive_projected_guidance(
                                diff_uncond_audio,
                                noise_pred_cond,
                                momentum_buffer=audio_momentumbuffer,
                                norm_threshold=extra_args.apg_norm_threshold)
                    else:
                        # vanilla CFG strategy
                        noise_pred = noise_pred_uncond + text_guide_scale * (
                            noise_pred_cond - noise_pred_drop_text) + \
                            audio_guide_scale * (noise_pred_drop_text - noise_pred_uncond)

                    noise_pred = -noise_pred

                    # update latent
                    dt = timesteps[i] - timesteps[i + 1]
                    dt = dt / self.num_timesteps
                    latent = latent + noise_pred * dt[:, None, None, None]

                    # injecting motion frames
                    if not is_first_clip:
                        latent_motion_frames = latent_motion_frames.to(latent.dtype).to(self.device)
                        motion_add_noise = torch.randn_like(latent_motion_frames).contiguous()
                        add_latent = self.add_noise(latent_motion_frames, motion_add_noise, timesteps[i + 1])
                        _, T_m, _, _ = add_latent.shape
                        latent[:, :T_m] = add_latent

                    x0 = [latent.to(self.device)]
                    del latent_model_input, timestep

                if offload_model:
                    if not self.vram_management:
                        self.model.cpu()
                    torch_gc()
                videos = self.vae.decode(x0)

            # cache generated samples
            videos = torch.stack(videos).cpu()  # B C T H W
            if is_first_clip:
                gen_video_list.append(videos)
            else:
                gen_video_list.append(videos[:, :, cur_motion_frames_num:])

            # decide whether generation is done
            if arrive_last_frame: break

            # update next condition frames
            is_first_clip = False
            cur_motion_frames_num = motion_frame
            cond_image = videos[:, :, -cur_motion_frames_num:].to(torch.float32).to(self.device)
            audio_start_idx += (frame_num - cur_motion_frames_num)
            audio_end_idx = audio_start_idx + clip_length

            # repeat audio emb by mirror-padding if the next window runs past the end
            if audio_end_idx >= min(max_frames_num, len(full_audio_embs[0])):
                arrive_last_frame = True
                miss_lengths = []
                source_frames = []
                for human_inx in range(HUMAN_NUMBER):
                    source_frame = len(full_audio_embs[human_inx])
                    source_frames.append(source_frame)
                    if audio_end_idx >= len(full_audio_embs[human_inx]):
                        miss_length = audio_end_idx - len(full_audio_embs[human_inx]) + 3
                        add_audio_emb = torch.flip(full_audio_embs[human_inx][-1 * miss_length:], dims=[0])
                        full_audio_embs[human_inx] = torch.cat([full_audio_embs[human_inx], add_audio_emb], dim=0)
                        miss_lengths.append(miss_length)
                    else:
                        miss_lengths.append(0)

            if max_frames_num <= frame_num: break

            torch_gc()

        if offload_model:
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        gen_video_samples = torch.cat(gen_video_list, dim=2)[:, :, :int(max_frames_num)]
        gen_video_samples = gen_video_samples.to(torch.float32)
        if max_frames_num > frame_num and sum(miss_lengths) > 0:
            # drop the frames that correspond to the mirror-padded audio
            gen_video_samples = gen_video_samples[:, :, :-1 * miss_lengths[0]]

        if dist.is_initialized():
            dist.barrier()

        del noise, latent
        torch_gc()

        return gen_video_samples[0] if self.rank == 0 else None
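# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; paths, the config object, and the attributes
# expected on `extra_args` -- use_teacache, teacache_thresh, size, use_apg,
# apg_momentum, apg_norm_threshold -- must match your own setup):
#
#   pipeline = MultiTalkPipeline(config, checkpoint_dir="weights/MultiTalk", device_id=0)
#   pipeline.enable_vram_management()  # optional, for low-VRAM GPUs
#
#   input_data = {
#       'prompt': "Two people having a conversation in a cafe",
#       'cond_image': "examples/ref.png",
#       'cond_audio': {'person1': "examples/person1_emb.pt",   # torch-saved audio
#                      'person2': "examples/person2_emb.pt"},  # embeddings, each
#                                                              # longer than frame_num
#       # 'bbox': {'person1': (x_min, y_min, x_max, y_max), ...}  # optional
#   }
#   video = pipeline.generate(input_data, size_buckget='multitalk-480',
#                             frame_num=81, sampling_steps=40, extra_args=args)
#   # `video` is a C x T x H x W float tensor on rank 0, and None on other ranks.
# ---------------------------------------------------------------------------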