# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import itertools
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
from typing import Literal, Optional

import cv2
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                               get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanT2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
    ):
r""" | |
Initializes the Wan text-to-video generation model components. | |
Args: | |
config (EasyDict): | |
Object containing model parameters initialized from config.py | |
checkpoint_dir (`str`): | |
Path to directory containing model checkpoints | |
device_id (`int`, *optional*, defaults to 0): | |
Id of target GPU device | |
rank (`int`, *optional*, defaults to 0): | |
Process rank for distributed training | |
t5_fsdp (`bool`, *optional*, defaults to False): | |
Enable FSDP sharding for T5 model | |
dit_fsdp (`bool`, *optional*, defaults to False): | |
Enable FSDP sharding for DiT model | |
use_usp (`bool`, *optional*, defaults to False): | |
Enable distribution strategy of USP. | |
t5_cpu (`bool`, *optional*, defaults to False): | |
Whether to place T5 model on CPU. Only works without t5_fsdp. | |
""" | |
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None)

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        # condition
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if use_usp:
            from xfuser.core.distributed import \
                get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (usp_attn_forward,
                                                            usp_dit_forward)
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt
    def generate(self,
                 input_prompt,
                 size=(1280, 720),
                 frame_num=81,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
r""" | |
Generates video frames from text prompt using diffusion process. | |
Args: | |
input_prompt (`str`): | |
Text prompt for content generation | |
size (tupele[`int`], *optional*, defaults to (1280,720)): | |
Controls video resolution, (width,height). | |
frame_num (`int`, *optional*, defaults to 81): | |
How many frames to sample from a video. The number should be 4n+1 | |
shift (`float`, *optional*, defaults to 5.0): | |
Noise schedule shift parameter. Affects temporal dynamics | |
sample_solver (`str`, *optional*, defaults to 'unipc'): | |
Solver used to sample the video. | |
sampling_steps (`int`, *optional*, defaults to 40): | |
Number of diffusion sampling steps. Higher values improve quality but slow generation | |
guide_scale (`float`, *optional*, defaults 5.0): | |
Classifier-free guidance scale. Controls prompt adherence vs. creativity | |
n_prompt (`str`, *optional*, defaults to ""): | |
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` | |
seed (`int`, *optional*, defaults to -1): | |
Random seed for noise generation. If -1, use random seed. | |
offload_model (`bool`, *optional*, defaults to True): | |
If True, offloads models to CPU during generation to save VRAM | |
Returns: | |
torch.Tensor: | |
Generated video frames tensor. Dimensions: (C, N H, W) where: | |
- C: Color channels (3 for RGB) | |
- N: Number of frames (81) | |
- H: Frame height (from size) | |
- W: Frame width from size) | |
""" | |
        # preprocess
        F = frame_num
        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
                        size[1] // self.vae_stride[1],
                        size[0] // self.vae_stride[2])

        seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                            (self.patch_size[1] * self.patch_size[2]) *
                            target_shape[1] / self.sp_size) * self.sp_size
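        # Worked example (assuming the default vae_stride=(4, 8, 8) and
        # patch_size=(1, 2, 2)): for size=(832, 480) and frame_num=81,
        # target_shape = (16, 21, 60, 104) and, with sp_size=1,
        # seq_len = ceil(60 * 104 / (2 * 2) * 21) = 32760 patch tokens.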
if n_prompt == "": | |
n_prompt = self.sample_neg_prompt | |
seed = seed if seed >= 0 else random.randint(0, sys.maxsize) | |
seed_g = torch.Generator(device=self.device) | |
seed_g.manual_seed(seed) | |
if not self.t5_cpu: | |
self.text_encoder.model.to(self.device) | |
context = self.text_encoder([input_prompt], self.device) | |
context_null = self.text_encoder([n_prompt], self.device) | |
if offload_model: | |
self.text_encoder.model.cpu() | |
else: | |
context = self.text_encoder([input_prompt], torch.device('cpu')) | |
context_null = self.text_encoder([n_prompt], torch.device('cpu')) | |
context = [t.to(self.device) for t in context] | |
context_null = [t.to(self.device) for t in context_null] | |
noise = [ | |
torch.randn( | |
target_shape[0], | |
target_shape[1], | |
target_shape[2], | |
target_shape[3], | |
dtype=torch.float32, | |
device=self.device, | |
generator=seed_g) | |
] | |
        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)
        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latents = noise

            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}

            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = latents
                timestep = [t]

                timestep = torch.stack(timestep)

                self.model.to(self.device)
                noise_pred_cond, _ = self.model(
                    latent_model_input, t=timestep, **arg_c)
                noise_pred_uncond, _ = self.model(
                    latent_model_input, t=timestep, **arg_null)
                noise_pred_cond, noise_pred_uncond = noise_pred_cond[0], noise_pred_uncond[0]

                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latents[0].unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latents = [temp_x0.squeeze(0)]

            x0 = latents
            if offload_model:
                self.model.cpu()
            if self.rank == 0:
                videos = self.vae.decode(x0)

        del noise, latents
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
    def load_video_frames(self, video_path, size=(832, 480)):
        r"""
        Load video frames from the given path, preprocess them, and encode them with the VAE.

        Args:
            video_path (str): Path to the video file.
            size (tuple[`int`], *optional*, defaults to (832, 480)): Target resolution
                (width, height) for resizing frames.

        Returns:
            torch.Tensor: Latent tensor produced by the VAE with shape (C, F, H, W).
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video file: {video_path}")

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert BGR to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Resize frame to target size
            frame = cv2.resize(frame, size, interpolation=cv2.INTER_AREA)
            # Convert to tensor and normalize to [-1, 1]
            frame = torch.from_numpy(frame).float().permute(2, 0, 1) / 127.5 - 1.0
            # Convert to tensor and normalize to [0, 1]
            # frame = torch.from_numpy(frame).float().permute(2, 0, 1) / 255
            frames.append(frame)
        cap.release()

        if not frames:
            raise ValueError(f"No frames found in video: {video_path}")

        # Stack frames into a single tensor of shape (C, F, H, W)
        frames_tensor = torch.stack(frames).permute(1, 0, 2, 3).to(self.device)
        latents = self.vae.video_encode(frames_tensor)
        return latents  # [C, F, H, W]
    def find_subtokens_range(self, source_tokens, target_tokens):
        """
        Find the positions of target_tokens inside source_tokens.
        Returns the list of matching token indices (from the start index up to,
        but not including, the end index, like a slice); returns None if no match is found.
        """
        valid_len = len(source_tokens)
        # Slide a window over the valid range of source to match target
        for i in range(valid_len - len(target_tokens) + 1):
            if source_tokens[i:i + len(target_tokens)] == target_tokens:
                # return (i, i + len(target_tokens) - 1)
                return list(range(*(i, i + len(target_tokens))))
        return None
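    # Illustrative example (assumed token lists, not from the original file):
    #   find_subtokens_range(['a', 'cat', 'on', 'the', 'mat'], ['on', 'the'])
    #   returns [2, 3]; an unmatched target returns None.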
    def create_binary_mask(  # Renamed slightly for clarity
        self,
        attn_map: torch.Tensor,
        n: int,
        pooling_mode: Literal['max', 'avg'] = 'max',
        threshold: Optional[float] = 0.5,
        threshold_method: Literal['fixed', 'otsu'] = 'fixed',
        normalize_per_slice=False
    ) -> torch.Tensor:
        # --- Preparation ---
        original_shape = attn_map.shape
        C, T, H, W = original_shape
        device = attn_map.device
        map_batched = attn_map.reshape(C * T, 1, H, W)
        padding = (n - 1) // 2

        if pooling_mode == 'max':
            smoothed_map_batched = torch.nn.functional.max_pool2d(
                map_batched, kernel_size=n, stride=1, padding=padding)
        elif pooling_mode == 'avg':
            smoothed_map_batched = torch.nn.functional.avg_pool2d(
                map_batched, kernel_size=n, stride=1, padding=padding)
        else:
            raise ValueError(f"Unsupported pooling_mode: {pooling_mode}")

        map_to_binarize = smoothed_map_batched  # Start with the smoothed map

        if normalize_per_slice:
            flat_map = map_to_binarize.view(C * T, -1)
            # Calculate min and max per slice (image in the batch C*T)
            min_vals = torch.min(flat_map, dim=1, keepdim=True)[0]
            max_vals = torch.max(flat_map, dim=1, keepdim=True)[0]
            # Calculate range, handle the case where min == max (flat slice)
            range_vals = max_vals - min_vals
            range_vals = torch.where(range_vals == 0,
                                     torch.tensor(1.0, device=device, dtype=map_to_binarize.dtype),
                                     range_vals)
            min_vals_b = min_vals.view(C * T, 1, 1, 1)
            range_vals_b = range_vals.view(C * T, 1, 1, 1)
            normalized_map_batched = (map_to_binarize - min_vals_b) / range_vals_b
            map_to_binarize = torch.clamp(normalized_map_batched, 0.0, 1.0)

        # Restore the original (C, T, H, W) layout before thresholding
        map_to_binarize = map_to_binarize.squeeze(1).view(C, T, H, W)

        binary_mask = torch.zeros_like(map_to_binarize, dtype=torch.bool, device=device)
        if threshold_method == 'fixed':
            if normalize_per_slice:
                binary_mask = map_to_binarize > torch.mean(map_to_binarize).item()
            else:
                binary_mask = map_to_binarize > threshold

        return binary_mask.to(device=device, dtype=torch.float32)
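    # Usage sketch (hedged; the attention map below is a stand-in tensor, but the
    # call mirrors how edit() invokes this method):
    #   attn = torch.rand(16, 21, 60, 104)
    #   mask = create_binary_mask(attn, n=11, pooling_mode='avg',
    #                             threshold=torch.mean(attn).item(),
    #                             threshold_method='fixed')
    #   # mask: float tensor of 0.0 / 1.0 values with shape (16, 21, 60, 104)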
    def soften_mask_edges(
        self,
        binary_mask: torch.Tensor,
        decay_factor: float = 0.1
    ) -> torch.Tensor:
        """
        Softens the edges of a binary mask by assigning values to background pixels (0)
        based on their distance to the nearest foreground pixel (1).

        Pixels originally equal to 1 remain 1.
        Pixels originally equal to 0 get a value exp(-decay_factor * distance),
        where distance is the Euclidean distance to the nearest 1.
        This means pixels closer to the original mask edge get values closer to 1,
        and pixels farther away get values closer to 0.

        Args:
            binary_mask (torch.Tensor): The input binary mask. Expected to be a 4D
                                        tensor (C, T, H, W) with values 0.0 or 1.0.
            decay_factor (float): Controls how quickly the softened value decays
                                  with distance. A larger value means a faster decay
                                  (sharper transition near the edge), a smaller value
                                  means a slower decay (softer, more spread-out transition).
                                  Defaults to 0.1.

        Returns:
            torch.Tensor: A mask tensor with the same dimensions as the input,
                          where original mask areas are 1.0 and background areas
                          have softened values based on distance. Output is float32.

        Raises:
            TypeError: If binary_mask is not a PyTorch Tensor.
            ValueError: If binary_mask is not 4D.
            ImportError: If scipy is not installed.
        """
        from scipy.ndimage import distance_transform_edt

        # --- Input Validation ---
        if not isinstance(binary_mask, torch.Tensor):
            raise TypeError("binary_mask must be a PyTorch Tensor.")
        if binary_mask.ndim != 4:
            raise ValueError(f"Input binary_mask must be 4D (C, T, H, W), got {binary_mask.ndim}D")
        if not decay_factor > 0:
            raise ValueError("decay_factor must be positive.")

        # --- Preparation ---
        original_shape = binary_mask.shape
        C, T, H, W = original_shape
        device = binary_mask.device
        dtype = torch.float32  # Ensure output is float

        # Create an output tensor initialized with zeros
        softened_mask = torch.zeros_like(binary_mask, dtype=dtype, device=device)

        # --- Process each slice (C, T) independently ---
        for c in range(C):
            for t in range(T):
                # Extract the 2D slice
                mask_slice = binary_mask[c, t, :, :]

                # Move to CPU and convert to a NumPy boolean array for SciPy.
                # distance_transform_edt expects background (0) to be True
                # and foreground (1) to be False.
                inverted_mask_slice_np = (mask_slice == 0).cpu().numpy()

                # Compute the Euclidean Distance Transform.
                # distance_map_np contains the distance from each True pixel (background)
                # to the nearest False pixel (foreground/mask).
                # Pixels that were originally part of the mask (False in inverted) have distance 0.
                distance_map_np = distance_transform_edt(inverted_mask_slice_np)

                # Convert distances to softened values (0 to 1) using exponential decay
                # exp(-k * distance). Larger distance -> smaller value.
                # Distance 0 (original mask) -> exp(0) = 1, so no epsilon is needed to
                # keep non-mask values strictly below 1.
                softened_values_np = np.exp(-decay_factor * distance_map_np)

                # Convert back to a PyTorch tensor and move to the original device
                softened_values_slice = torch.from_numpy(softened_values_np).to(device=device, dtype=dtype)

                # --- Combine original mask and softened background ---
                # Use torch.where for clarity and efficiency:
                # where the original mask was 1, keep 1.0;
                # where the original mask was 0, use the calculated softened value.
                final_slice = torch.where(
                    mask_slice.bool(),  # Condition: True where original mask was 1
                    torch.tensor(1.0, device=device, dtype=dtype),  # Value if True
                    softened_values_slice  # Value if False (use calculated softened value)
                )

                # Place the processed slice into the output tensor
                softened_mask[c, t, :, :] = final_slice

        return softened_mask
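    # Behavior sketch (follows directly from exp(-decay_factor * distance)):
    # with decay_factor=0.1, a background pixel 10 px from the mask gets
    # exp(-1) ≈ 0.37, one 30 px away gets exp(-3) ≈ 0.05, and mask pixels stay 1.0.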
    def generate_conflict_map(  # Renamed slightly for clarity
        self,
        vsrc: torch.Tensor,
        vtar: torch.Tensor,
        normalize: bool = True,  # Default True: it simplifies tuning 'k' later
        norm_type: int = 2,
        epsilon: float = 1e-6
    ) -> torch.Tensor:
        """
        Generates a conflict map indicating the magnitude of difference between
        two velocity fields (vsrc and vtar) and returns it as a 4D tensor.

        The conflict at each spatial-temporal location is defined as the L-p norm
        (default L2, Euclidean distance) of the difference vector between vsrc and
        vtar along the channel dimension. The resulting scalar map (F, H, W) is
        then expanded to match the input shape (C, F, H, W) by repeating the
        scalar value across the channel dimension.

        Args:
            vsrc (torch.Tensor): The source velocity field tensor. Expected shape:
                                 (Channels, Frames, Height, Width).
            vtar (torch.Tensor): The target velocity field tensor. Must have the
                                 same shape and device as vsrc.
            normalize (bool, optional): If True, normalize the underlying 3D conflict
                                        map to the range [0, 1] using min-max scaling
                                        before expanding it to 4D. Recommended for
                                        easier tuning of 'k' in downstream functions.
                                        Defaults to True.
            norm_type (int, optional): The order of the norm (p-norm). Default is 2 (L2 norm).
                                       Use 1 for L1 norm (Manhattan distance), etc.
            epsilon (float, optional): A small value added to the denominator during
                                       normalization to prevent division by zero if
                                       all conflict values are identical. Defaults to 1e-6.

        Returns:
            torch.Tensor: The conflict map tensor, expanded to 4D.
                          Shape: (Channels, Frames, Height, Width). The value is
                          uniform across the channel dimension for each (F, H, W).

        Raises:
            TypeError: If inputs are not PyTorch tensors.
            ValueError: If input tensors do not have the same shape or are not 4D.
        """
        # --- Input Validation ---
        if not isinstance(vsrc, torch.Tensor) or not isinstance(vtar, torch.Tensor):
            raise TypeError("Inputs vsrc and vtar must be PyTorch tensors.")
        if vsrc.ndim != 4 or vtar.ndim != 4:
            raise ValueError(f"Input tensors must be 4D (C, F, H, W), "
                             f"got shapes {vsrc.shape} and {vtar.shape}")
        if vsrc.shape != vtar.shape:
            raise ValueError(f"Input tensors vsrc and vtar must have the same shape, "
                             f"got {vsrc.shape} and {vtar.shape}")
        if vsrc.device != vtar.device:
            logging.warning(f"Input tensors are on different devices ({vsrc.device}, {vtar.device}). "
                            f"Proceeding with calculations, but ensure this is intended.")

        # --- Conflict Calculation ---
        vsrc_float = vsrc.float()
        vtar_float = vtar.float()
        num_channels = vsrc.shape[0]

        difference = vtar_float - vsrc_float
        # Calculate the norm along the channel dimension -> shape (F, H, W)
        conflict_map_3d = torch.norm(difference, p=norm_type, dim=0)

        # --- Optional Normalization (applied to the 3D map) ---
        if normalize:
            # Avoid normalization issues on empty tensors
            if conflict_map_3d.numel() > 0:
                min_val = torch.min(conflict_map_3d)
                max_val = torch.max(conflict_map_3d)
                denominator = max_val - min_val
                if denominator < epsilon:
                    logging.warning(f"Conflict map values are nearly constant ({min_val.item()} to {max_val.item()}). "
                                    f"Normalization results in a map of all zeros.")
                    conflict_map_3d = torch.zeros_like(conflict_map_3d)
                else:
                    conflict_map_3d = (conflict_map_3d - min_val) / denominator
            else:
                logging.warning("Conflict map has zero elements, skipping normalization.")

        # --- Expand to 4D ---
        # Add channel dim and expand
        conflict_map_4d = conflict_map_3d.unsqueeze(0).expand(num_channels, -1, -1, -1)

        return conflict_map_4d
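    # Shape sketch: for vsrc and vtar of latent shape (16, 21, 60, 104), the per-location
    # L2 difference gives a (21, 60, 104) map, which is min-max normalized and then
    # repeated across the 16 channels, so the returned tensor is again (16, 21, 60, 104).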
    def compute_dynamic_source_mask(  # Renamed slightly for clarity
        self,
        conflict_map_4d: torch.Tensor,
        k: float,
        function_type: str = 'exponential_squared',
        clamp_output: bool = True
        # warn_threshold removed; normalization is best done in the generating function
    ) -> torch.Tensor:
        """
        Computes the Dynamic Source Mask M(p, t) based on a 4D conflict map input.

        Applies a chosen function element-wise to the conflict map values to generate
        the mask. High conflict should result in a mask value near 0, while low
        conflict should result in a value near 1. Assumes the input conflict map
        may have redundant values across the channel dimension.

        Args:
            conflict_map_4d (torch.Tensor): A tensor representing the conflict C(p, t)
                                            between Vsrc and Vtar. Expected shape:
                                            (Channels, Frames, Height, Width).
                                            Normalizing this input (e.g., via
                                            generate_conflict_map with normalize=True)
                                            is recommended for easier tuning of 'k'.
            k (float): A positive sensitivity parameter. Controls how aggressively
                       the mask value decreases as conflict increases.
            function_type (str, optional): Specifies the function f(C) used to map
                                           conflict C to the mask value M element-wise.
                                           Options:
                                           - 'exponential_squared': M = exp(-k * C^2)
                                           - 'exponential': M = exp(-k * C)
                                           - 'inverse': M = 1 / (1 + k * C)
                                           - 'inverse_squared': M = 1 / (1 + k * C^2)
                                           Defaults to 'exponential_squared'.
            clamp_output (bool, optional): If True, clamps the output mask values
                                           strictly to the range [0.0, 1.0].
                                           Defaults to True.

        Returns:
            torch.Tensor: The dynamic source mask tensor M(p, t).
                          Shape: (Channels, Frames, Height, Width).

        Raises:
            TypeError: If conflict_map_4d is not a PyTorch tensor.
            ValueError: If conflict_map_4d is not 4D, k is not positive, or an
                        invalid function_type is provided.
        """
        # --- Input Validation ---
        if not isinstance(conflict_map_4d, torch.Tensor):
            raise TypeError("Input conflict_map_4d must be a PyTorch tensor.")
        if conflict_map_4d.ndim != 4:
            raise ValueError(f"Input conflict_map_4d must be 4D (C, F, H, W), "
                             f"got {conflict_map_4d.ndim}D shape {conflict_map_4d.shape}")
        # num_channels = conflict_map_4d.shape[0]  # Get C if needed elsewhere
        if not isinstance(k, (int, float)) or k <= 0:
            raise ValueError(f"Parameter k must be a positive number, got {k}")

        supported_functions = ['exponential_squared', 'exponential', 'inverse', 'inverse_squared']
        if function_type not in supported_functions:
            raise ValueError(f"Invalid function_type '{function_type}'. "
                             f"Supported types are: {supported_functions}")

        # --- Mask Calculation (applied element-wise on the 4D tensor) ---
        # Note: conflict_map_4d already includes the (potentially redundant) channel dim
        conflict_map_float = conflict_map_4d.float()
        k_float = float(k)

        if function_type == 'exponential_squared':
            # Applies exp(-k * C^2) element-wise
            mask_4d = torch.exp(-k_float * torch.pow(conflict_map_float, 2))
        elif function_type == 'exponential':
            # Applies exp(-k * C) element-wise
            mask_4d = torch.exp(-k_float * conflict_map_float)
        elif function_type == 'inverse':
            # Applies 1 / (1 + k * C) element-wise
            mask_4d = 1.0 / (1.0 + k_float * conflict_map_float)
        elif function_type == 'inverse_squared':
            # Applies 1 / (1 + k * C^2) element-wise
            mask_4d = 1.0 / (1.0 + k_float * torch.pow(conflict_map_float, 2))
        else:
            raise NotImplementedError(f"Function type {function_type} calculation not implemented.")

        # --- Optional Clamping ---
        if clamp_output:
            mask_4d = torch.clamp(mask_4d, min=0.0, max=1.0)

        # --- Return the 4D Mask ---
        # No expansion needed, it's already 4D
        return mask_4d
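    # Numeric sketch for the default 'exponential_squared' mapping with the k=0.5 used
    # in edit() below: a normalized conflict C=0 gives M=1.0, C=0.5 gives
    # exp(-0.125) ≈ 0.88, and C=1 gives exp(-0.5) ≈ 0.61; edit() then rescales this
    # to a per-pixel weight 5 - 4*M lying in [1, 5].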
    def edit(self,
             target_prompt,
             size=(832, 480),
             frame_num=81,
             shift=5.0,
             sample_solver='unipc',
             sampling_steps=50,
             guide_scale=5.0,
             tar_guide_scale=10.0,
             n_prompt="",
             seed=-1,
             offload_model=True,
             source_video_path=None,
             source_prompt=None,
             nmax_step=50,
             nmin_step=0,
             n_avg=5,
             worse_avg=3,
             omega=3,
             source_words=None,
             target_words=None,
             window_size=11,
             decay_factor=0.1,
             tmd_window_size=11,
             tmd_stride=8):
        # preprocess
        F = frame_num
        W, H = size
        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
                        size[1] // self.vae_stride[1], size[0] // self.vae_stride[2])
        seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                            (self.patch_size[1] * self.patch_size[2]) *
                            target_shape[1] / self.sp_size) * self.sp_size

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt
        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        # seed_g = torch.Generator(device=self.device)
        torch.manual_seed(seed)

        # Load the source video and encode it into VAE latents
        x_src = self.load_video_frames(source_video_path)
        C_latent, F_latent, H_latent, W_latent = x_src.shape

        # Validate TMD (Temporal MultiDiffusion) parameters
        if tmd_window_size > F_latent:
            logging.warning(f"tmd_window_size ({tmd_window_size}) > latent frames ({F_latent}). Using full sequence as one window.")
            tmd_window_size = F_latent
            tmd_stride = F_latent
        elif tmd_stride <= 0:
            logging.warning(f"Invalid tmd_stride ({tmd_stride}). Setting stride to window size / 2.")
            tmd_stride = max(1, tmd_window_size // 2)

        # Locate the token positions of the edited words within each prompt
        source_words_idx = None
        target_words_idx = None
        if source_words:
            tk1 = self.text_encoder.tokenizer.tokenizer.tokenize(source_prompt, add_special_tokens=True)
            tk2 = self.text_encoder.tokenizer.tokenizer.tokenize(source_words, add_special_tokens=True)
            source_words_idx = self.find_subtokens_range(tk1[:-1], tk2[:-1])
        if target_words:
            tk1 = self.text_encoder.tokenizer.tokenizer.tokenize(target_prompt, add_special_tokens=True)
            tk2 = self.text_encoder.tokenizer.tokenizer.tokenize(target_words, add_special_tokens=True)
            target_words_idx = self.find_subtokens_range(tk1[:-1], tk2[:-1])

        # Encode the source and target text prompts
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context_src = self.text_encoder([source_prompt], self.device)
            context_tar = self.text_encoder([target_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context_tar = self.text_encoder([target_prompt], torch.device('cpu'))
            context_src = self.text_encoder([source_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context_src = [t.to(self.device) for t in context_src]
            context_tar = [t.to(self.device) for t in context_tar]
            context_null = [t.to(self.device) for t in context_null]

        arg_src_c = {'context': context_src, 'seq_len': seq_len, 'words_indices': source_words_idx, 'block_id': 18, 'type': 'src'}
        arg_tar_c = {'context': context_tar, 'seq_len': seq_len, 'words_indices': target_words_idx, 'block_id': 18, 'type': 'tar'}
        arg_unc = {'context': context_null, 'seq_len': seq_len}

        # Initialize the editing trajectory from the source latents
        zt_edit = x_src.clone()  # [16, 21, 60, 104]: C x Frames x H x W
        conflict_mask = torch.ones_like(x_src)

        # Set up the sampling scheduler
        if sample_solver == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False, solver_order=2)
            sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
            timesteps = sample_scheduler.timesteps
        elif sample_solver == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False, solver_order=1)
            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
            timesteps, _ = retrieve_timesteps(sample_scheduler, device=self.device, sigmas=sampling_sigmas)
        else:
            raise NotImplementedError("Unsupported solver.")

        self.model.to(self.device)
        # Editing process
        with amp.autocast(dtype=self.param_dtype), torch.no_grad():
            for index, t in enumerate(tqdm(timesteps)):
                t_next = timesteps[index + 1] if index + 1 < len(timesteps) else 0
                arg_src_c["timestep"] = t
                arg_tar_c["timestep"] = t
                arg_unc["timestep"] = t
                timestep = torch.tensor([t], device=self.device)
                relative_index = nmax_step - (sampling_steps - index)
                v_list = []

                if sampling_steps - (index + 1) >= nmax_step:
                    continue
                if sampling_steps - (index + 1) >= nmin_step:
                    t_i = t / 1000.0
                    t_im1 = t_next / 1000.0
                    v_delta_sum = torch.zeros_like(zt_edit)
                    v_worse = torch.zeros_like(zt_edit)
                    v_mask = torch.zeros_like(zt_edit)

                    for time in range(n_avg):
                        # --- Temporal MultiDiffusion Logic ---
                        V_delta_accumulator = torch.zeros_like(zt_edit)
                        window_counts = torch.zeros_like(zt_edit)  # Use float for division later
                        window_mask_sum = torch.zeros_like(zt_edit)
                        fwd_noise = torch.randn_like(zt_edit[:, 0:tmd_window_size, :, :], device=self.device)

                        # Build the temporal window start indices
                        # for f_start in range(0, F_latent, tmd_stride):
                        if tmd_window_size >= F_latent:
                            window_starts = [0]
                        else:
                            window_starts = list(range(0, F_latent - tmd_window_size, tmd_stride))
                            last_possible_start = F_latent - tmd_window_size
                            if not window_starts or window_starts[-1] < last_possible_start:
                                if last_possible_start >= 0:
                                    window_starts.append(last_possible_start)
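                        # Worked example (assuming the defaults tmd_window_size=11,
                        # tmd_stride=8 and an 81-frame input, i.e. F_latent=21):
                        # range(0, 10, 8) yields [0, 8], and the last possible start 10
                        # is appended, so window_starts = [0, 8, 10].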
                        for f_start in window_starts:
                            f_end = F_latent if tmd_window_size >= F_latent else f_start + tmd_window_size

                            # --- Calculate V_delta within the window ---
                            # Extract window slices
                            zt_edit_w = zt_edit[:, f_start:f_end, :, :]
                            x_src_w = x_src[:, f_start:f_end, :, :]

                            # Compute zt_src by interpolating between the source latents and noise
                            zt_src = (1 - t_i) * x_src_w + t_i * fwd_noise
                            zt_tar = zt_edit_w + zt_src - x_src_w

                            # Source and target noise predictions
                            noise_pred_src, src_attn_map = self.model([zt_src], t=timestep, **arg_src_c)
                            noise_pred_tar, tar_attn_map = self.model([zt_tar], t=timestep, **arg_tar_c)
                            noise_pred_src, noise_pred_tar = noise_pred_src[0], noise_pred_tar[0]

                            # Unconditional predictions
                            noise_pred_uncond_src, _ = self.model([zt_src], t=timestep, **arg_unc)
                            noise_pred_uncond_tar, _ = self.model([zt_tar], t=timestep, **arg_unc)
                            noise_pred_uncond_src, noise_pred_uncond_tar = noise_pred_uncond_src[0], noise_pred_uncond_tar[0]

                            # Classifier-free-guided noise predictions
                            noise_pred_src_guided = noise_pred_uncond_src + guide_scale * (noise_pred_src - noise_pred_uncond_src)
                            noise_pred_tar_guided = noise_pred_uncond_tar + tar_guide_scale * (noise_pred_tar - noise_pred_uncond_tar)

                            sum_attn_mask = torch.zeros_like(zt_edit_w)
                            conflict_map = self.generate_conflict_map(noise_pred_src_guided, noise_pred_tar_guided)
                            raw_mask = self.compute_dynamic_source_mask(conflict_map, 0.5)
                            clamped_mask = torch.clamp(raw_mask, min=0.0, max=1.0)
                            conflict_mask = 5.0 - 4.0 * clamped_mask

                            if src_attn_map is not None:
                                src_attn_mask = self.create_binary_mask(src_attn_map,
                                                                        n=window_size,
                                                                        pooling_mode='avg',
                                                                        threshold=torch.mean(src_attn_map).item(),
                                                                        threshold_method='fixed')
                                sum_attn_mask += src_attn_mask
                            if tar_attn_map is not None:
                                tar_attn_mask = self.create_binary_mask(tar_attn_map,
                                                                        n=window_size,
                                                                        pooling_mode='avg',
                                                                        threshold=torch.mean(tar_attn_map).item(),
                                                                        threshold_method='fixed')
                                sum_attn_mask += tar_attn_mask
                            sum_attn_mask = torch.clamp(sum_attn_mask, min=0.0, max=1.0)

                            V_delta = noise_pred_tar_guided - noise_pred_src_guided

                            # Accumulate results
                            V_delta_accumulator[:, f_start:f_end, :, :] += V_delta
                            window_counts[:, f_start:f_end, :, :] += 1.0
                            window_mask_sum[:, f_start:f_end, :, :] += sum_attn_mask

                        V_delta_final = V_delta_accumulator / torch.clamp(window_counts, min=1.0)  # Avoid division by zero
                        v_delta_sum += V_delta_final
                        v_list.append(V_delta_final)
                        v_mask += window_mask_sum
                        if time < worse_avg:
                            v_worse += V_delta_final

                    v_list = torch.stack(v_list, dim=0)
                    V_delta_better = v_list.mean(dim=0)
                    v_trend = []
                    for worse_set in itertools.combinations(v_list, worse_avg):
                        v_worse = torch.zeros_like(V_delta_better)
                        for worse in worse_set:
                            v_worse += worse
                        v_worse = v_worse / worse_avg
                        v_trend.append(V_delta_better - v_worse)
                    v_trend = torch.stack(v_trend, dim=0)
                    v_trend = v_trend.mean(dim=0)

                    v_mask = torch.clamp(v_mask, min=0.0, max=1.0)
                    v_mask = self.soften_mask_edges(v_mask, decay_factor=decay_factor)

                    V_delta_final = V_delta_better + (omega - 1) * v_trend
                    V_delta_final = V_delta_final * v_mask
                    zt_edit = zt_edit + (t_im1 - t_i) * V_delta_final
                else:
                    # Sample with the target prompt only
                    noise_pred_uncond, _ = self.model([zt_edit], t=timestep, context=context_null, seq_len=seq_len)
                    noise_pred_cond, _ = self.model([zt_edit], t=timestep, context=context_tar, seq_len=seq_len)
                    noise_pred_cond, noise_pred_uncond = noise_pred_cond[0], noise_pred_uncond[0]
                    noise_pred = noise_pred_uncond + 6 * (
                        noise_pred_cond - noise_pred_uncond)
                    temp_x0 = sample_scheduler.step(
                        noise_pred.unsqueeze(0),
                        t,
                        zt_edit.unsqueeze(0),
                        return_dict=False)[0]
                    zt_edit = temp_x0.squeeze(0)

        # Decode the edited latents
        if offload_model:
            self.model.cpu()
        if self.rank == 0:
            videos = self.vae.decode([zt_edit])
            return videos[0]
        return None