# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime

# Installed before the heavy imports below so warnings raised while importing
# torch / wan are suppressed as well.
warnings.filterwarnings('ignore')

import random

import torch
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import merge_video_audio, save_video, str2bool

# Checkpoint directory shared by the fixed config in `_parse_args` and by
# `load_model`, so the two cannot silently drift apart.
_DEFAULT_CKPT_DIR = "./Wan2.2-Animate-14B/"

# Example prompts reused by more than one task.
_BOXING_CATS_PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
_BEACH_CAT_PROMPT = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."

# Per-task defaults used by `_validate_args` to fill in inputs that were not
# set explicitly. A task must appear here to pass validation.
EXAMPLE_PROMPT = {
    "t2v-A14B": {
        "prompt": _BOXING_CATS_PROMPT,
    },
    "i2v-A14B": {
        "prompt": _BEACH_CAT_PROMPT,
        "image": "examples/i2v_input.JPG",
    },
    "ti2v-5B": {
        "prompt": _BOXING_CATS_PROMPT,
    },
    "animate-14B": {
        # English: "The person in the video is performing actions".
        "prompt": "视频中的人在做动作",
        "video": "",
        "pose": "",
        "mask": "",
    },
    "s2v-14B": {
        "prompt": _BEACH_CAT_PROMPT,
        "image": "examples/i2v_input.JPG",
        "audio": "examples/talk.wav",
        "tts_prompt_audio": "examples/zero_shot_prompt.wav",
        "tts_prompt_text": "希望你以后能够做的比我还好呦。",
        "tts_text": "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
    },
}


def _validate_args(args):
    """Check *args* and fill unset options in place from task defaults.

    Defaults come from ``EXAMPLE_PROMPT[args.task]`` (prompt / image / audio /
    TTS inputs) and from ``WAN_CONFIGS[args.task]`` (sampling steps, shift,
    guide scale, frame count). A negative ``base_seed`` is replaced by a
    random one.

    Raises:
        AssertionError: if the checkpoint dir is unset, the task is unknown,
            an i2v task has no image, or the size is unsupported for the task.
    """
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    example = EXAMPLE_PROMPT[args.task]
    if args.prompt is None:
        args.prompt = example["prompt"]
    if args.image is None and "image" in example:
        args.image = example["image"]
    # Without TTS, fall back to the example driving audio; with TTS enabled,
    # fall back to the full set of example TTS inputs instead.
    if args.audio is None and not args.enable_tts and "audio" in example:
        args.audio = example["audio"]
    if ((args.tts_prompt_audio is None or args.tts_text is None)
            and args.enable_tts and "audio" in example):
        args.tts_prompt_audio = example["tts_prompt_audio"]
        args.tts_prompt_text = example["tts_prompt_text"]
        args.tts_text = example["tts_text"]

    if args.task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    cfg = WAN_CONFIGS[args.task]

    if args.sample_steps is None:
        args.sample_steps = cfg.sample_steps

    if args.sample_shift is None:
        args.sample_shift = cfg.sample_shift

    if args.sample_guide_scale is None:
        args.sample_guide_scale = cfg.sample_guide_scale

    if args.frame_num is None:
        args.frame_num = cfg.frame_num

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)

    # Size check (s2v tasks skip it).
    if 's2v' not in args.task:
        assert args.size in SUPPORTED_SIZES[args.task], (
            f"Unsupport size {args.size} for task {args.task}, supported sizes are: "
            f"{', '.join(SUPPORTED_SIZES[args.task])}")


class _Args:
    """Plain attribute bag standing in for an ``argparse.Namespace``."""


def _parse_args():
    """Build and validate the fixed animate-14B generation config.

    Despite the name, nothing is read from the command line: every option is
    hard-coded below and then completed/checked by ``_validate_args``.

    Returns:
        _Args: the validated configuration.
    """
    args = _Args()

    # core generation options
    args.task = "animate-14B"
    # alternative: "1280*720"
    args.size = "720*1280"
    args.frame_num = None
    args.ckpt_dir = _DEFAULT_CKPT_DIR
    args.offload_model = False
    args.ulysses_size = 1
    args.t5_fsdp = False
    args.t5_cpu = False
    args.dit_fsdp = False
    args.prompt = None
    args.use_prompt_extend = False
    args.prompt_extend_method = "local_qwen"  # ["dashscope", "local_qwen"]
    args.prompt_extend_model = None
    args.prompt_extend_target_lang = "zh"  # ["zh", "en"]
    args.base_seed = 1234
    args.image = None
    args.sample_solver = "unipc"  # ['unipc', 'dpm++']
    args.sample_steps = None
    args.sample_shift = None
    args.sample_guide_scale = None
    args.convert_model_dtype = True

    # animate
    args.refert_num = 1

    # s2v-only
    args.num_clip = None
    args.audio = None
    args.enable_tts = False
    args.tts_prompt_audio = None
    args.tts_prompt_text = None
    args.tts_text = None
    args.pose_video = None
    args.start_from_ref = False
    args.infer_frames = 80

    _validate_args(args)

    return args


def _init_logging(rank):
    """Configure root logging: INFO to stdout on rank 0, ERROR elsewhere.

    ``logging.basicConfig`` does nothing once the root logger has handlers,
    so only the first call in a process takes effect.
    """
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def load_model(use_relighting_lora=False, *, checkpoint_dir=_DEFAULT_CKPT_DIR):
    """Instantiate ``wan.WanAnimate`` on device 0 with rank 0, FSDP and
    sequence parallelism disabled.

    Args:
        use_relighting_lora: forwarded unchanged to ``wan.WanAnimate``.
        checkpoint_dir: directory holding the Animate-14B weights; defaults
            to the same path the fixed config in ``_parse_args`` uses.

    Returns:
        The constructed ``wan.WanAnimate`` instance.
    """
    cfg = WAN_CONFIGS["animate-14B"]
    # NOTE(review): _parse_args sets convert_model_dtype=True, but the model
    # is built with False here; confirm which one is intended.
    return wan.WanAnimate(
        config=cfg,
        checkpoint_dir=checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        convert_model_dtype=False,
        use_relighting_lora=use_relighting_lora
    )


def generate(wan_animate, preprocess_dir, save_file, replace_flag=False):
    """Run one generation with a preloaded model and save the video on rank 0.

    Args:
        wan_animate: model object exposing ``generate(...)``, e.g. the value
            returned by ``load_model``.
        preprocess_dir: forwarded as ``src_root_path``; presumably the output
            directory of the Wan-Animate preprocessing step.
        save_file: path the resulting video is written to (rank 0 only).
        replace_flag: forwarded unchanged to ``wan_animate.generate``;
            presumably selects replacement vs. animation mode — confirm.
    """
    args = _parse_args()
    rank = int(os.getenv("RANK", 0))
    _init_logging(rank)

    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Input prompt: {args.prompt}")
    print(f'rank:{rank}')

    logging.info("Generating video ...")
    video = wan_animate.generate(
        src_root_path=preprocess_dir,
        replace_flag=replace_flag,
        refert_num=args.refert_num,
        clip_len=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)

    if rank == 0:
        # `video[None]` prepends a singleton batch axis for save_video; the
        # value_range assumes the model output lies in [-1, 1] — confirm.
        save_video(
            tensor=video[None],
            save_file=save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    del video

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # NOTE(review): this tears down any process group on every call, so if a
    # caller initialises torch.distributed and then calls generate() several
    # times, only the first call runs with it — confirm this is intended.
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")