# wan2-1-video-generation/scripts/generate_ode_pairs.py
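"""Generate ODE sampling trajectories ("ODE pairs") for the Wan2.1 text-to-video model.

For every caption, run a 48-step classifier-free-guided flow-matching sampler
from pure noise and store five latent states along the ODE trajectory
(steps 0, 12, 24, 36, and 48) as ``{prompt: tensor}`` in one ``.pt`` file per
prompt. Work is sharded across distributed ranks.

A typical launch (hypothetical paths; this assumes ``launch_distributed_job()``
initializes torch.distributed from the usual torchrun environment variables):

    torchrun --nproc_per_node=8 scripts/generate_ode_pairs.py \
        --output_folder ode_pairs --caption_path captions.txt
"""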
import argparse
import math
import os

import torch
import torch.distributed as dist
from tqdm import tqdm

from utils.dataset import TextDataset
from utils.distributed import launch_distributed_job
from utils.scheduler import FlowMatchScheduler
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
def init_model(device):
    # Frozen Wan2.1 diffusion backbone and text encoder, both in fp32.
    model = WanDiffusionWrapper().to(device).to(torch.float32)
    encoder = WanTextEncoder().to(device).to(torch.float32)
    model.model.requires_grad_(False)

    # Flow-matching sampler: 48 denoising steps with timestep shift 8.0.
    scheduler = FlowMatchScheduler(
        shift=8.0, sigma_min=0.0, extra_one_step=True)
    scheduler.set_timesteps(num_inference_steps=48, denoising_strength=1.0)
    scheduler.sigmas = scheduler.sigmas.to(device)

    # Default Wan negative prompt (Chinese). Rough translation: "vivid colors,
    # overexposed, static, blurry details, subtitles, style, artwork, painting,
    # still frame, overall gray, worst quality, low quality, JPEG compression
    # artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly
    # drawn face, deformed, disfigured, malformed limbs, fused fingers,
    # motionless frame, cluttered background, three legs, many people in the
    # background, walking backwards".
    sample_neg_prompt = '色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走'
    # The unconditional (negative) embedding is computed once and reused.
    unconditional_dict = encoder(
        text_prompts=[sample_neg_prompt]
    )

    return model, encoder, scheduler, unconditional_dict
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--caption_path", type=str, required=True)
    parser.add_argument("--guidance_scale", type=float, default=6.0)
    args = parser.parse_args()

    launch_distributed_job()
    device = torch.cuda.current_device()

    # Inference only: disable autograd and allow TF32 matmuls for speed.
    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model, encoder, scheduler, unconditional_dict = init_model(device=device)
    dataset = TextDataset(args.caption_path)

    # exist_ok=True makes this safe to call concurrently on every rank.
    os.makedirs(args.output_folder, exist_ok=True)
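    # Prompts are sharded across ranks: rank r handles dataset indices
    # r, r + world_size, r + 2 * world_size, ...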
    for index in tqdm(
        range(int(math.ceil(len(dataset) / dist.get_world_size()))),
        disable=dist.get_rank() != 0,
    ):
        prompt_index = index * dist.get_world_size() + dist.get_rank()
        if prompt_index >= len(dataset):
            continue

        prompt = dataset[prompt_index]
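        # Encode the positive prompt; the negative (unconditional) embedding
        # was precomputed in init_model and is shared across prompts.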
        conditional_dict = encoder(text_prompts=prompt)

        # Initial noise in Wan2.1 latent space:
        # [batch=1, frames=21, channels=16, height=60, width=104].
        latents = torch.randn(
            [1, 21, 16, 60, 104], dtype=torch.float32, device=device
        )

        noisy_input = []
        for progress_id, t in enumerate(tqdm(scheduler.timesteps)):
            # Broadcast the scalar timestep to every latent frame.
            timestep = t * torch.ones([1, 21], device=device, dtype=torch.float32)
            # Record the latent state *before* this denoising step.
            noisy_input.append(latents)
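            # Classifier-free guidance: evaluate the model with and without the
            # text condition and extrapolate by guidance_scale (default 6.0).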
            _, x0_pred_cond = model(
                latents, conditional_dict, timestep
            )
            _, x0_pred_uncond = model(
                latents, unconditional_dict, timestep
            )
            x0_pred = x0_pred_uncond + args.guidance_scale * (
                x0_pred_cond - x0_pred_uncond
            )
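            # The scheduler steps on flow (velocity) predictions, so convert the
            # guided x0 prediction back to flow form. Batch and frame dims are
            # flattened for the call and restored afterwards.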
            flow_pred = model._convert_x0_to_flow_pred(
                scheduler=scheduler,
                x0_pred=x0_pred.flatten(0, 1),
                xt=latents.flatten(0, 1),
                timestep=timestep.flatten(0, 1)
            ).unflatten(0, x0_pred.shape[:2])

            latents = scheduler.step(
                flow_pred.flatten(0, 1),
                scheduler.timesteps[progress_id] * torch.ones(
                    [1, 21], device=device, dtype=torch.long).flatten(0, 1),
                latents.flatten(0, 1)
            ).unflatten(dim=0, sizes=flow_pred.shape[:2])
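        # The loop recorded the state before each of the 48 steps; appending the
        # final latent gives 49 states along the ODE trajectory.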
        noisy_input.append(latents)

        # Keep 5 evenly spaced states (steps 0, 12, 24, 36, 48) and save them
        # as {prompt: tensor of shape [1, 5, 21, 16, 60, 104]}.
        noisy_inputs = torch.stack(noisy_input, dim=1)
        noisy_inputs = noisy_inputs[:, [0, 12, 24, 36, -1]]
        stored_data = noisy_inputs

        torch.save(
            {prompt: stored_data.cpu().detach()},
            os.path.join(args.output_folder, f"{prompt_index:05d}.pt")
        )
    # Wait for all ranks to finish writing before exiting.
    dist.barrier()
if __name__ == "__main__":
    main()