# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import importlib
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

from PIL import Image
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.multitalk_model import WanModel, WanLayerNorm, WanRMSNorm
from .modules.t5 import T5EncoderModel, T5LayerNorm, T5RelativeEmbedding
from .modules.vae import WanVAE, CausalConv3d, RMS_norm, Upsample
from .utils.multitalk_utils import MomentumBuffer, adaptive_projected_guidance
from src.vram_management import AutoWrappedLinear, AutoWrappedModule, enable_vram_management


def torch_gc():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def resize_and_centercrop(cond_image, target_size):
    """Resize an image or tensor so it covers target_size (no padding), then center crop."""
    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size

    # Calculate the scaling factor for resizing
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w

    # Compute the final size
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
        # crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # to tensor, then crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor[:, :, None, :, :]

    return cropped_tensor


def timestep_transform(
    t,
    shift=5.0,
    num_timesteps=1000,
):
    t = t / num_timesteps
    # shift the timestep based on ratio
    new_t = shift * t / (1 + (shift - 1) * t)
    new_t = new_t * num_timesteps
    return new_t
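# Worked example of the shift mapping above (plain arithmetic, shown for reference):
# with shift=5.0 and num_timesteps=1000, t=500 maps to
#   5.0 * 0.5 / (1 + 4.0 * 0.5) * 1000 = 833.33...,
# while t=0 and t=1000 map to themselves, so intermediate timesteps are pushed
# toward the high-noise end of the schedule without changing its endpoints.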
class MultiTalkPipeline:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
        num_timesteps=1000,
        use_timestep_transform=True,
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir, config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size
            from .distributed.xdit_context_parallel import (
                usp_dit_forward_multitalk,
                usp_attn_forward_multitalk,
                usp_crossattn_multi_forward_multitalk,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward_multitalk, block.self_attn)
                block.audio_cross_attn.forward = types.MethodType(
                    usp_crossattn_multi_forward_multitalk, block.audio_cross_attn)
            self.model.forward = types.MethodType(usp_dit_forward_multitalk, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        self.model.to(self.param_dtype)

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt
        self.num_timesteps = num_timesteps
        self.use_timestep_transform = use_timestep_transform
        self.cpu_offload = False
        self.model_names = ["model"]
        self.vram_management = False

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """compatible with diffusers add_noise()"""
        timesteps = timesteps.float() / self.num_timesteps
        timesteps = timesteps.view(timesteps.shape + (1,) * (len(noise.shape) - 1))
        return (1 - timesteps) * original_samples + timesteps * noise
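    # Example of the interpolation performed by add_noise (shown for reference):
    # with num_timesteps=1000 and timesteps=torch.tensor([500]), the result is
    #   0.5 * original_samples + 0.5 * noise,
    # i.e. a linear (rectified-flow style) blend whose noise fraction is t / num_timesteps.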
    def enable_vram_management(self, num_persistent_param_in_dit=None):
        dtype = next(iter(self.model.parameters())).dtype
        enable_vram_management(
            self.model,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                WanLayerNorm: AutoWrappedModule,
                WanRMSNorm: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.param_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.param_dtype,
                computation_device=self.device,
            ),
        )
        self.enable_cpu_offload()

    def enable_cpu_offload(self):
        self.cpu_offload = True

    def load_models_to_device(self, loadmodel_names=[]):
        # only load models to device if cpu_offload is enabled
        if not self.cpu_offload:
            return
        # offload the unneeded models to cpu
        for model_name in self.model_names:
            if model_name not in loadmodel_names:
                model = getattr(self, model_name)
                if not isinstance(model, nn.Module):
                    model = model.model
                if model is not None:
                    if (
                        hasattr(model, "vram_management_enabled")
                        and model.vram_management_enabled
                    ):
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
        # load the needed models to device
        for model_name in loadmodel_names:
            model = getattr(self, model_name)
            if not isinstance(model, nn.Module):
                model = model.model
            if model is not None:
                if (
                    hasattr(model, "vram_management_enabled")
                    and model.vram_management_enabled
                ):
                    for module in model.modules():
                        if hasattr(module, "onload"):
                            module.onload()
                else:
                    model.to(self.device)
        # free the cuda cache
        torch.cuda.empty_cache()
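    # A minimal low-VRAM usage sketch, assuming the caller is responsible for flipping
    # self.vram_management (generate() only consults that flag; __init__ leaves it False
    # and enable_vram_management() above does not set it):
    #
    #   pipeline.enable_vram_management(num_persistent_param_in_dit=0)
    #   pipeline.vram_management = True
    #   video = pipeline.generate(input_data, ...)
    #
    # With VRAM management enabled, load_models_to_device() drives the offload()/onload()
    # hooks installed on wrapped modules instead of moving whole models with .cpu()/.to().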
    def generate(self,
                 input_data,
                 size_buckget='multitalk-480',
                 motion_frame=25,
                 frame_num=81,
                 shift=5.0,
                 sampling_steps=40,
                 text_guide_scale=5.0,
                 audio_guide_scale=4.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True,
                 max_frames_num=1000,
                 face_scale=0.05,
                 progress=True,
                 extra_args=None):
        r"""
        Generates video frames from input image and text prompt using diffusion process.

        Args:
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM
        """
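        # Expected layout of `input_data`, reconstructed from the accesses below
        # (paths are placeholders; 'person2' and 'bbox' are only needed for two speakers,
        # and each bbox is read as (x_min, y_min, x_max, y_max) in source-image pixels):
        #
        #   input_data = {
        #       'prompt': 'A conversation between two people in a cafe...',
        #       'cond_image': 'examples/ref_image.png',
        #       'cond_audio': {
        #           'person1': 'examples/person1_audio_emb.pt',   # torch.load-able embedding
        #           'person2': 'examples/person2_audio_emb.pt',
        #       },
        #       'bbox': {   # optional
        #           'person1': [x_min, y_min, x_max, y_max],
        #           'person2': [x_min, y_min, x_max, y_max],
        #       },
        #   }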
        # init teacache
        if extra_args.use_teacache:
            self.model.teacache_init(
                sample_steps=sampling_steps,
                teacache_thresh=extra_args.teacache_thresh,
                model_scale=extra_args.size,
            )
        else:
            self.model.disable_teacache()

        input_prompt = input_data['prompt']
        cond_file_path = input_data['cond_image']
        cond_image = Image.open(cond_file_path).convert('RGB')

        # decide a proper size bucket
        bucket_config_module = importlib.import_module("wan.utils.multitalk_utils")
        if size_buckget == 'multitalk-480':
            bucket_config = getattr(bucket_config_module, 'ASPECT_RATIO_627')
        elif size_buckget == 'multitalk-720':
            bucket_config = getattr(bucket_config_module, 'ASPECT_RATIO_960')
        src_h, src_w = cond_image.height, cond_image.width
        ratio = src_h / src_w
        closest_bucket = sorted(list(bucket_config.keys()), key=lambda x: abs(float(x) - ratio))[0]
        target_h, target_w = bucket_config[closest_bucket][0]

        cond_image = resize_and_centercrop(cond_image, (target_h, target_w))
        cond_image = cond_image / 255
        cond_image = (cond_image - 0.5) * 2  # normalization
        cond_image = cond_image.to(self.device)  # 1 C 1 H W

        # read audio embeddings
        audio_embedding_path_1 = input_data['cond_audio']['person1']
        if len(input_data['cond_audio']) == 1:
            HUMAN_NUMBER = 1
            audio_embedding_path_2 = None
        else:
            HUMAN_NUMBER = 2
            audio_embedding_path_2 = input_data['cond_audio']['person2']

        full_audio_embs = []
        audio_embedding_paths = [audio_embedding_path_1, audio_embedding_path_2]
        for human_idx in range(HUMAN_NUMBER):
            audio_embedding_path = audio_embedding_paths[human_idx]
            if not os.path.exists(audio_embedding_path):
                continue
            full_audio_emb = torch.load(audio_embedding_path)
            if torch.isnan(full_audio_emb).any():
                continue
            if full_audio_emb.shape[0] <= frame_num:
                continue
            full_audio_embs.append(full_audio_emb)
        assert len(full_audio_embs) == HUMAN_NUMBER, \
            "Audio file does not exist or its length does not satisfy frame_num."

        # preprocess text embedding
        if n_prompt == "":
            n_prompt = self.sample_neg_prompt
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context, context_null = self.text_encoder([input_prompt, n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]
        torch_gc()

        # prepare params for video generation
        indices = (torch.arange(2 * 2 + 1) - 2) * 1
        clip_length = frame_num
        is_first_clip = True
        arrive_last_frame = False
        cur_motion_frames_num = 1
        audio_start_idx = 0
        audio_end_idx = audio_start_idx + clip_length
        gen_video_list = []
        torch_gc()

        # set random seed and init noise
        seed = seed if seed >= 0 else random.randint(0, 99999999)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True

        # start video generation iteratively
        while True:
            audio_embs = []
            # split audio with window size
            for human_idx in range(HUMAN_NUMBER):
                center_indices = torch.arange(
                    audio_start_idx,
                    audio_end_idx,
                    1,
                ).unsqueeze(1) + indices.unsqueeze(0)
                center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0] - 1)
                audio_emb = full_audio_embs[human_idx][center_indices][None, ...].to(self.device)
                audio_embs.append(audio_emb)
            audio_embs = torch.concat(audio_embs, dim=0).to(self.param_dtype)
            torch_gc()

            h, w = cond_image.shape[-2], cond_image.shape[-1]
            lat_h, lat_w = h // self.vae_stride[1], w // self.vae_stride[2]
            max_seq_len = ((frame_num - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
                self.patch_size[1] * self.patch_size[2])
            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

            noise = torch.randn(
                16, (frame_num - 1) // 4 + 1,
                lat_h, lat_w,
                dtype=torch.float32,
                device=self.device)

            # get mask
            msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
            msk[:, cur_motion_frames_num:] = 0
            msk = torch.concat([
                torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
            ], dim=1)
            msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
            msk = msk.transpose(1, 2).to(self.param_dtype)  # B 4 T H W

            with torch.no_grad():
                # get clip embedding
                self.clip.model.to(self.device)
                clip_context = self.clip.visual(cond_image[:, :, -1:, :, :]).to(self.param_dtype)
                if offload_model:
                    self.clip.model.cpu()
                torch_gc()

                # zero padding and vae encode
                video_frames = torch.zeros(1, cond_image.shape[1], frame_num - cond_image.shape[2], target_h, target_w).to(self.device)
                padding_frames_pixels_values = torch.concat([cond_image, video_frames], dim=2)
                y = self.vae.encode(padding_frames_pixels_values)
                y = torch.stack(y).to(self.param_dtype)  # B C T H W
                cur_motion_frames_latent_num = int(1 + (cur_motion_frames_num - 1) // 4)
                latent_motion_frames = y[:, :, :cur_motion_frames_latent_num][0]  # C T H W
                y = torch.concat([msk, y], dim=1)  # B 4+C T H W
                torch_gc()
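            # The masks built below route each audio stream to a spatial region of the
            # reference image: with an explicit 'bbox' input each person gets their box,
            # otherwise the frame is split into a left and a right half (shrunk by
            # face_scale on every side); everything not covered becomes the background mask.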
            # construct human mask
            human_masks = []
            if HUMAN_NUMBER == 1:
                background_mask = torch.ones([src_h, src_w])
                human_mask1 = torch.ones([src_h, src_w])
                human_mask2 = torch.ones([src_h, src_w])
                human_masks = [human_mask1, human_mask2, background_mask]
            elif HUMAN_NUMBER == 2:
                if 'bbox' in input_data:
                    assert len(input_data['bbox']) == len(input_data['cond_audio']), \
                        "The number of target bbox should be the same with cond_audio"
                    background_mask = torch.zeros([src_h, src_w])
                    for _, person_bbox in input_data['bbox'].items():
                        x_min, y_min, x_max, y_max = person_bbox
                        human_mask = torch.zeros([src_h, src_w])
                        human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
                        background_mask += human_mask
                        human_masks.append(human_mask)
                else:
                    x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
                    background_mask = torch.zeros([src_h, src_w])
                    human_mask1 = torch.zeros([src_h, src_w])
                    human_mask2 = torch.zeros([src_h, src_w])
                    # use a local half-width so src_w is not mutated across clips
                    half_w = src_w // 2
                    lefty_min, lefty_max = int(half_w * face_scale), int(half_w * (1 - face_scale))
                    righty_min, righty_max = int(half_w * face_scale + half_w), int(half_w * (1 - face_scale) + half_w)
                    human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
                    human_mask2[x_min:x_max, righty_min:righty_max] = 1
                    background_mask += human_mask1
                    background_mask += human_mask2
                    human_masks = [human_mask1, human_mask2]
                background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
                human_masks.append(background_mask)

            ref_target_masks = torch.stack(human_masks, dim=0).to(self.device)
            # resize and centercrop for ref_target_masks
            ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
            _, _, _, lat_h, lat_w = y.shape
            ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(lat_h, lat_w), mode='nearest').squeeze()
            ref_target_masks = (ref_target_masks > 0)
            ref_target_masks = ref_target_masks.float().to(self.device)

            torch_gc()

            @contextmanager
            def noop_no_sync():
                yield

            no_sync = getattr(self.model, 'no_sync', noop_no_sync)

            # evaluation mode
            with torch.no_grad(), no_sync():

                # prepare timesteps
                timesteps = list(np.linspace(self.num_timesteps, 1, sampling_steps, dtype=np.float32))
                timesteps.append(0.)
                timesteps = [torch.tensor([t], device=self.device) for t in timesteps]
                if self.use_timestep_transform:
                    timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps]

                # sample videos
                latent = noise

                # prepare condition and uncondition configs
                arg_c = {
                    'context': [context],
                    'clip_fea': clip_context,
                    'seq_len': max_seq_len,
                    'y': y,
                    'audio': audio_embs,
                    'ref_target_masks': ref_target_masks,
                }
                arg_null_text = {
                    'context': [context_null],
                    'clip_fea': clip_context,
                    'seq_len': max_seq_len,
                    'y': y,
                    'audio': audio_embs,
                    'ref_target_masks': ref_target_masks,
                }
                arg_null = {
                    'context': [context_null],
                    'clip_fea': clip_context,
                    'seq_len': max_seq_len,
                    'y': y,
                    'audio': torch.zeros_like(audio_embs)[-1:],
                    'ref_target_masks': ref_target_masks,
                }
                torch_gc()

                if not self.vram_management:
                    self.model.to(self.device)
                else:
                    self.load_models_to_device(["model"])

                # injecting motion frames
                if not is_first_clip:
                    latent_motion_frames = latent_motion_frames.to(latent.dtype).to(self.device)
                    motion_add_noise = torch.randn_like(latent_motion_frames).contiguous()
                    add_latent = self.add_noise(latent_motion_frames, motion_add_noise, timesteps[0])
                    _, T_m, _, _ = add_latent.shape
                    latent[:, :T_m] = add_latent

                # infer with APG
                # refer https://arxiv.org/abs/2410.02416
                if extra_args.use_apg:
                    text_momentumbuffer = MomentumBuffer(extra_args.apg_momentum)
                    audio_momentumbuffer = MomentumBuffer(extra_args.apg_momentum)

                progress_wrap = partial(tqdm, total=len(timesteps) - 1) if progress else (lambda x: x)
                for i in progress_wrap(range(len(timesteps) - 1)):
                    timestep = timesteps[i]
                    latent_model_input = [latent.to(self.device)]

                    # inference with CFG strategy
                    noise_pred_cond = self.model(
                        latent_model_input, t=timestep, **arg_c)[0]
                    torch_gc()
                    noise_pred_drop_text = self.model(
                        latent_model_input, t=timestep, **arg_null_text)[0]
                    torch_gc()
                    noise_pred_uncond = self.model(
                        latent_model_input, t=timestep, **arg_null)[0]
                    torch_gc()
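                    # The two guidance branches below are combined as
                    #   noise_pred = uncond
                    #              + text_guide_scale  * (cond - drop_text)     # text direction
                    #              + audio_guide_scale * (drop_text - uncond)   # audio direction
                    # so text and audio conditioning get independent scales; the APG variant
                    # applies the same two directions through adaptive projected guidance
                    # (arXiv:2410.02416) with per-direction momentum buffers.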
                    if extra_args.use_apg:
                        # corrected update direction (APG)
                        diff_uncond_text = noise_pred_cond - noise_pred_drop_text
                        diff_uncond_audio = noise_pred_drop_text - noise_pred_uncond
                        noise_pred = noise_pred_cond \
                            + (text_guide_scale - 1) * adaptive_projected_guidance(
                                diff_uncond_text,
                                noise_pred_cond,
                                momentum_buffer=text_momentumbuffer,
                                norm_threshold=extra_args.apg_norm_threshold) \
                            + (audio_guide_scale - 1) * adaptive_projected_guidance(
                                diff_uncond_audio,
                                noise_pred_cond,
                                momentum_buffer=audio_momentumbuffer,
                                norm_threshold=extra_args.apg_norm_threshold)
                    else:
                        # vanilla CFG strategy
                        noise_pred = noise_pred_uncond + text_guide_scale * (
                            noise_pred_cond - noise_pred_drop_text) + \
                            audio_guide_scale * (noise_pred_drop_text - noise_pred_uncond)
                    noise_pred = -noise_pred

                    # update latent
                    dt = timesteps[i] - timesteps[i + 1]
                    dt = dt / self.num_timesteps
                    latent = latent + noise_pred * dt[:, None, None, None]

                    # injecting motion frames
                    if not is_first_clip:
                        latent_motion_frames = latent_motion_frames.to(latent.dtype).to(self.device)
                        motion_add_noise = torch.randn_like(latent_motion_frames).contiguous()
                        add_latent = self.add_noise(latent_motion_frames, motion_add_noise, timesteps[i + 1])
                        _, T_m, _, _ = add_latent.shape
                        latent[:, :T_m] = add_latent

                    x0 = [latent.to(self.device)]
                    del latent_model_input, timestep

                if offload_model:
                    if not self.vram_management:
                        self.model.cpu()
                    torch_gc()

                videos = self.vae.decode(x0)

            # cache generated samples
            videos = torch.stack(videos).cpu()  # B C T H W
            if is_first_clip:
                gen_video_list.append(videos)
            else:
                gen_video_list.append(videos[:, :, cur_motion_frames_num:])

            # decide whether generation is done
            if arrive_last_frame:
                break

            # update next condition frames
            is_first_clip = False
            cur_motion_frames_num = motion_frame

            cond_image = videos[:, :, -cur_motion_frames_num:].to(torch.float32).to(self.device)
            audio_start_idx += (frame_num - cur_motion_frames_num)
            audio_end_idx = audio_start_idx + clip_length

            # repeat audio emb if the window runs past the end of the audio
            if audio_end_idx >= min(max_frames_num, len(full_audio_embs[0])):
                arrive_last_frame = True
                miss_lengths = []
                source_frames = []
                for human_inx in range(HUMAN_NUMBER):
                    source_frame = len(full_audio_embs[human_inx])
                    source_frames.append(source_frame)
                    if audio_end_idx >= len(full_audio_embs[human_inx]):
                        miss_length = audio_end_idx - len(full_audio_embs[human_inx]) + 3
                        add_audio_emb = torch.flip(full_audio_embs[human_inx][-1 * miss_length:], dims=[0])
                        full_audio_embs[human_inx] = torch.cat([full_audio_embs[human_inx], add_audio_emb], dim=0)
                        miss_lengths.append(miss_length)
                    else:
                        miss_lengths.append(0)

            if max_frames_num <= frame_num:
                break

            torch_gc()

        if offload_model:
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        gen_video_samples = torch.cat(gen_video_list, dim=2)[:, :, :int(max_frames_num)]
        gen_video_samples = gen_video_samples.to(torch.float32)
        if max_frames_num > frame_num and sum(miss_lengths) > 0:
            # split video frames (drop the frames added by audio padding)
            gen_video_samples = gen_video_samples[:, :, :-1 * miss_lengths[0]]

        if dist.is_initialized():
            dist.barrier()
        del noise, latent
        torch_gc()

        return gen_video_samples[0] if self.rank == 0 else None
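# Minimal single-GPU usage sketch. The config lookup (`WAN_CONFIGS['multitalk-14B']`),
# the checkpoint directory, and the example asset paths below are assumptions about how
# this repo is packaged; adjust them to the actual launcher before running.
if __name__ == "__main__":
    from types import SimpleNamespace
    from wan.configs import WAN_CONFIGS  # assumed config registry keyed by task name

    cfg = WAN_CONFIGS['multitalk-14B']
    pipeline = MultiTalkPipeline(
        config=cfg,
        checkpoint_dir='weights/MultiTalk',        # placeholder path
        device_id=0,
    )

    input_data = {
        'prompt': 'Two people having a lively conversation in a cafe.',
        'cond_image': 'examples/ref_image.png',    # placeholder path
        'cond_audio': {
            'person1': 'examples/person1_emb.pt',  # placeholder audio-embedding path
        },
    }
    # Only use_teacache / use_apg are read from extra_args when both features are disabled.
    extra_args = SimpleNamespace(use_teacache=False, use_apg=False)

    video = pipeline.generate(
        input_data,
        size_buckget='multitalk-480',
        frame_num=81,
        sampling_steps=40,
        seed=42,
        extra_args=extra_args,
    )
    if video is not None:
        print('generated video tensor:', tuple(video.shape))  # C, T, H, W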