import argparse
from typing import Optional

from PIL import Image
import torch
import torchvision.transforms.functional as TF
from tqdm import tqdm
from accelerate import Accelerator, init_empty_weights

# --- FIX STARTS HERE ---
import torch._dynamo

torch._dynamo.config.suppress_errors = True
# --- FIX ENDS HERE ---
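# NOTE: suppress_errors tells TorchDynamo to log and swallow compilation errors and
# fall back to eager execution instead of raising, so torch.compile failures do not
# abort training.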
from dataset.image_video_dataset import ARCHITECTURE_WAN, ARCHITECTURE_WAN_FULL
from hv_generate_video import resize_image_to_bucket
from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

from utils import model_utils
from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
from wan.configs import WAN_CONFIGS
from wan.modules.clip import CLIPModel
from wan.modules.model import WanModel, detect_wan_sd_dtype, load_wan_model
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanNetworkTrainer(NetworkTrainer):
    def __init__(self):
        super().__init__()

    # region model specific

    @property
    def architecture(self) -> str:
        return ARCHITECTURE_WAN

    @property
    def architecture_full_name(self) -> str:
        return ARCHITECTURE_WAN_FULL

    def handle_model_specific_args(self, args):
        self.config = WAN_CONFIGS[args.task]
        self._i2v_training = "i2v" in args.task

        self.dit_dtype = detect_wan_sd_dtype(args.dit)
        if self.dit_dtype == torch.float16:
            assert args.mixed_precision in ["fp16", "no"], "DiT weights are in fp16, mixed precision must be fp16 or no"
        elif self.dit_dtype == torch.bfloat16:
            assert args.mixed_precision in ["bf16", "no"], "DiT weights are in bf16, mixed precision must be bf16 or no"

        if args.fp8_scaled and self.dit_dtype.itemsize == 1:
            raise ValueError(
                "DiT weights are already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights."
            )

        args.dit_dtype = model_utils.dtype_to_str(self.dit_dtype)

    @property
    def i2v_training(self) -> bool:
        return self._i2v_training

    def process_sample_prompts(
        self,
        args: argparse.Namespace,
        accelerator: Accelerator,
        sample_prompts: str,
    ):
        config = self.config
        device = accelerator.device
        t5_path, clip_path, fp8_t5 = args.t5, args.clip, args.fp8_t5

        logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
        prompts = load_prompts(sample_prompts)

        def encode_for_text_encoder(text_encoder):
            sample_prompts_te_outputs = {}  # (prompt) -> (embeds, mask)
            # with accelerator.autocast(), torch.no_grad():  # this causes NaN if dit_dtype is fp16
            t5_dtype = config.t5_dtype
            with torch.amp.autocast(device_type=device.type, dtype=t5_dtype), torch.no_grad():
                for prompt_dict in prompts:
                    if "negative_prompt" not in prompt_dict:
                        prompt_dict["negative_prompt"] = self.config["sample_neg_prompt"]
                    for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", None)]:
                        if p is None:
                            continue
                        if p not in sample_prompts_te_outputs:
                            logger.info(f"cache Text Encoder outputs for prompt: {p}")
                            prompt_outputs = text_encoder([p], device)
                            sample_prompts_te_outputs[p] = prompt_outputs

            return sample_prompts_te_outputs

        # Load Text Encoder 1 and encode
        logger.info(f"loading T5: {t5_path}")
        t5 = T5EncoderModel(text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=t5_path, fp8=fp8_t5)

        logger.info("encoding with Text Encoder 1")
        te_outputs_1 = encode_for_text_encoder(t5)
        del t5

        # load CLIP and encode image (for I2V training)
        sample_prompts_image_embs = {}
        for prompt_dict in prompts:
            if prompt_dict.get("image_path", None) is not None:
                sample_prompts_image_embs[prompt_dict["image_path"]] = None

        if len(sample_prompts_image_embs) > 0:
            logger.info(f"loading CLIP: {clip_path}")
            assert clip_path is not None, "CLIP path is required for I2V training"

            clip = CLIPModel(dtype=config.clip_dtype, device=device, weight_path=clip_path)
            clip.model.to(device)

            logger.info("Encoding images to CLIP context")
            with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
                for image_path in sample_prompts_image_embs:
                    logger.info(f"Encoding image: {image_path}")
                    img = Image.open(image_path).convert("RGB")
                    img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)  # -1 to 1
                    clip_context = clip.visual([img[:, None, :, :]])
                    sample_prompts_image_embs[image_path] = clip_context

            del clip
            clean_memory_on_device(device)

        # prepare sample parameters
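        # Each entry below is the original prompt dict plus the cached tensors:
        # "t5_embeds", optionally "negative_t5_embeds", and "clip_embeds" when the
        # prompt references an image_path (I2V sampling).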
        sample_parameters = []
        for prompt_dict in prompts:
            prompt_dict_copy = prompt_dict.copy()

            p = prompt_dict.get("prompt", "")
            prompt_dict_copy["t5_embeds"] = te_outputs_1[p][0]

            p = prompt_dict.get("negative_prompt", None)
            if p is not None:
                prompt_dict_copy["negative_t5_embeds"] = te_outputs_1[p][0]

            p = prompt_dict.get("image_path", None)
            if p is not None:
                prompt_dict_copy["clip_embeds"] = sample_prompts_image_embs[p]

            sample_parameters.append(prompt_dict_copy)

        clean_memory_on_device(accelerator.device)

        return sample_parameters

    def do_inference(
        self,
        accelerator,
        args,
        sample_parameter,
        vae,
        dit_dtype,
        transformer,
        discrete_flow_shift,
        sample_steps,
        width,
        height,
        frame_count,
        generator,
        do_classifier_free_guidance,
        guidance_scale,
        cfg_scale,
        image_path=None,
    ):
        """architecture dependent inference"""
        model: WanModel = transformer

        device = accelerator.device
        if cfg_scale is None:
            cfg_scale = 5.0
        do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0

        # Calculate latent video length based on VAE version
        latent_video_length = (frame_count - 1) // self.config["vae_stride"][0] + 1
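        # e.g. with the standard Wan2.1 temporal VAE stride of 4 (assumed here; the actual
        # value comes from config["vae_stride"][0]), 81 input frames give (81 - 1) // 4 + 1 = 21 latent frames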

        # Get embeddings
        context = sample_parameter["t5_embeds"].to(device=device)
        if do_classifier_free_guidance:
            context_null = sample_parameter["negative_t5_embeds"].to(device=device)
        else:
            context_null = None

        num_channels_latents = 16  # model.in_dim
        vae_scale_factor = self.config["vae_stride"][1]

        # Initialize latents
        lat_h = height // vae_scale_factor
        lat_w = width // vae_scale_factor
        shape_or_frame = (1, num_channels_latents, 1, lat_h, lat_w)
        latents = []
        for _ in range(latent_video_length):
            latents.append(torch.randn(shape_or_frame, generator=generator, device=device, dtype=dit_dtype))
        latents = torch.cat(latents, dim=2)

        if self.i2v_training:
            # Move VAE to the appropriate device for sampling: consider caching image latents on the CPU in advance
            vae.to(device)
            vae.eval()

            image = Image.open(image_path)
            image = resize_image_to_bucket(image, (width, height))  # returns a numpy array
            image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).float()  # C, 1, H, W
            image = image / 127.5 - 1  # -1 to 1

            # Create mask for the required number of frames
            msk = torch.ones(1, frame_count, lat_h, lat_w, device=device)
            msk[:, 1:] = 0
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
            msk = msk.transpose(1, 2)  # B, C, T, H, W

            with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
                # Zero padding for the required number of frames only
                padding_frames = frame_count - 1  # The first frame is the input image
                image = torch.concat([image, torch.zeros(3, padding_frames, height, width)], dim=1).to(device=device)
                y = vae.encode([image])[0]
                y = y[:, :latent_video_length]  # may not be needed
                y = y.unsqueeze(0)  # add batch dim
                image_latents = torch.concat([msk, y], dim=1)
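                # The I2V conditioning concatenates the 4-channel frame mask with the
                # 16-channel VAE latent of the (zero-padded) reference clip, 20 channels in total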
vae.to("cpu") | |
clean_memory_on_device(device) | |
else: | |
image_latents = None | |

        # use the default value for num_train_timesteps (1000)
        scheduler = FlowUniPCMultistepScheduler(shift=1, use_dynamic_shifting=False)
        scheduler.set_timesteps(sample_steps, device=device, shift=discrete_flow_shift)
        timesteps = scheduler.timesteps

        # Generate noise for the required number of frames only
        noise = torch.randn(16, latent_video_length, lat_h, lat_w, dtype=torch.float32, generator=generator, device=device).to(
            "cpu"
        )

        # prepare the model input
        max_seq_len = latent_video_length * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
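        # max_seq_len counts the patch tokens fed to the DiT. Assuming the usual Wan2.1
        # spatial patch size of 2x2: e.g. 832x480 at spatial stride 8 with 21 latent
        # frames -> 21 * 60 * 104 // (2 * 2) = 32760 tokens.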
        arg_c = {"context": [context], "seq_len": max_seq_len}
        arg_null = {"context": [context_null], "seq_len": max_seq_len}

        if self.i2v_training:
            # I2V training
            arg_c["clip_fea"] = sample_parameter["clip_embeds"].to(device=device, dtype=dit_dtype)
            arg_c["y"] = image_latents
            arg_null["clip_fea"] = arg_c["clip_fea"]
            arg_null["y"] = image_latents

        # Wrap the inner loop with tqdm to track progress over timesteps
        prompt_idx = sample_parameter.get("enum", 0)

        latent = noise
        with torch.no_grad():
            for i, t in enumerate(tqdm(timesteps, desc=f"Sampling timesteps for prompt {prompt_idx+1}")):
                latent_model_input = [latent.to(device=device)]
                timestep = t.unsqueeze(0)

                with accelerator.autocast():
                    noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to("cpu")
                    if do_classifier_free_guidance:
                        noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to("cpu")
                    else:
                        noise_pred_uncond = None

                if do_classifier_free_guidance:
                    noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond

                temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator)[0]
                latent = temp_x0.squeeze(0)

        # Move VAE to the appropriate device for sampling
        vae.to(device)
        vae.eval()

        # Decode latents to video
        logger.info(f"Decoding video from latents: {latent.shape}")
        latent = latent.unsqueeze(0)  # add batch dim
        latent = latent.to(device=device)

        with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
            video = vae.decode(latent)[0]  # vae returns list
        video = video.unsqueeze(0)  # add batch dim
        del latent

        logger.info("Decoding complete")
        video = video.to(torch.float32).cpu()
        video = (video / 2 + 0.5).clamp(0, 1)  # -1 to 1 -> 0 to 1

        vae.to("cpu")
        clean_memory_on_device(device)

        return video

    def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
        vae_path = args.vae

        logger.info(f"Loading VAE model from {vae_path}")
        cache_device = torch.device("cpu") if args.vae_cache_cpu else None
        vae = WanVAE(vae_path=vae_path, device="cpu", dtype=vae_dtype, cache_device=cache_device)
        return vae

    def load_transformer(
        self,
        accelerator: Accelerator,
        args: argparse.Namespace,
        dit_path: str,
        attn_mode: str,
        split_attn: bool,
        loading_device: str,
        dit_weight_dtype: Optional[torch.dtype],
    ):
        model = load_wan_model(
            self.config,
            self.i2v_training,
            accelerator.device,
            dit_path,
            attn_mode,
            split_attn,
            loading_device,
            dit_weight_dtype,
            args.fp8_scaled,
        )
        return model

    def scale_shift_latents(self, latents):
        return latents

    def call_dit(
        self,
        args: argparse.Namespace,
        accelerator: Accelerator,
        transformer,
        latents: torch.Tensor,
        batch: dict[str, torch.Tensor],
        noise: torch.Tensor,
        noisy_model_input: torch.Tensor,
        timesteps: torch.Tensor,
        network_dtype: torch.dtype,
    ):
        model: WanModel = transformer

        # I2V training
        if self.i2v_training:
            image_latents = batch["latents_image"]
            clip_fea = batch["clip"]
            image_latents = image_latents.to(device=accelerator.device, dtype=network_dtype)
            clip_fea = clip_fea.to(device=accelerator.device, dtype=network_dtype)
        else:
            image_latents = None
            clip_fea = None

        context = [t.to(device=accelerator.device, dtype=network_dtype) for t in batch["t5"]]

        # ensure the hidden state will require grad
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in context:
                t.requires_grad_(True)
            if image_latents is not None:
                image_latents.requires_grad_(True)
            if clip_fea is not None:
                clip_fea.requires_grad_(True)

        # call DiT
        lat_f, lat_h, lat_w = latents.shape[2:5]
        seq_len = lat_f * lat_h * lat_w // (self.config.patch_size[0] * self.config.patch_size[1] * self.config.patch_size[2])
        latents = latents.to(device=accelerator.device, dtype=network_dtype)
        noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
        with accelerator.autocast():
            model_pred = model(noisy_model_input, t=timesteps, context=context, clip_fea=clip_fea, seq_len=seq_len, y=image_latents)
        model_pred = torch.stack(model_pred, dim=0)  # list to tensor

        # flow matching loss
        target = noise - latents

        return model_pred, target

    # endregion model specific


def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Wan2.1 specific parser setup"""
    parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
    parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT")
    parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
    parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
    parser.add_argument(
        "--clip",
        type=str,
        default=None,
        help="text encoder (CLIP) checkpoint path, optional. Required when training an I2V model.",
    )
    parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
    return parser


if __name__ == "__main__":
    parser = setup_parser_common()
    parser = wan_setup_parser(parser)

    args = parser.parse_args()
    args = read_config_from_file(args, parser)

    args.dit_dtype = None  # automatically detected
    if args.vae_dtype is None:
        args.vae_dtype = "bfloat16"  # make bfloat16 the default for the VAE

    trainer = WanNetworkTrainer()
    trainer.train(args)
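
# Example invocation (illustrative only: paths are placeholders, and the --dit /
# --vae / --mixed_precision flags are assumed to be defined by setup_parser_common(),
# which provides the args.dit / args.vae / args.mixed_precision values used above):
#
#   accelerate launch <this_script.py> \
#       --task t2v-14B \
#       --dit /path/to/wan2.1_dit.safetensors \
#       --vae /path/to/wan2.1_vae.safetensors \
#       --t5 /path/to/t5_encoder.safetensors \
#       --mixed_precision bf16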