Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. | |
| import argparse | |
| import logging | |
| import os | |
| import sys | |
| import warnings | |
| from datetime import datetime | |
| warnings.filterwarnings('ignore') | |
| import random | |
| import torch | |
| import torch.distributed as dist | |
| from PIL import Image | |
| import wan | |
| from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS | |
| from wan.distributed.util import init_distributed_group | |
| from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander | |
| from wan.utils.utils import merge_video_audio, save_video, str2bool | |
# Prompt texts that are shared by more than one task entry below.
_CAT_BOXING_PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
_BEACH_CAT_PROMPT = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."

# Per-task example inputs, keyed by task name. `_validate_args` reads the
# "prompt", "image", "audio" and "tts_*" entries as fallbacks for options
# that were left unset.
EXAMPLE_PROMPT = {
    "t2v-A14B": {"prompt": _CAT_BOXING_PROMPT},
    "i2v-A14B": {
        "prompt": _BEACH_CAT_PROMPT,
        "image": "examples/i2v_input.JPG",
    },
    "ti2v-5B": {"prompt": _CAT_BOXING_PROMPT},
    "animate-14B": {
        "prompt": "视频中的人在做动作",
        "video": "",
        "pose": "",
        "mask": "",
    },
    "s2v-14B": {
        "prompt": _BEACH_CAT_PROMPT,
        "image": "examples/i2v_input.JPG",
        "audio": "examples/talk.wav",
        "tts_prompt_audio": "examples/zero_shot_prompt.wav",
        "tts_prompt_text": "希望你以后能够做的比我还好呦。",
        "tts_text": "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。",
    },
}
def _validate_args(args):
    """Check ``args`` and fill unset options in place.

    Asserts that a checkpoint directory is given and that the task is known
    to both ``WAN_CONFIGS`` and ``EXAMPLE_PROMPT``. Options left as ``None``
    are then filled from the task's example inputs (prompt, image, audio /
    TTS fields) and from the task config (sampling options, frame count).
    A negative ``base_seed`` is replaced by a random non-negative one, and
    for every task whose name does not contain "s2v" the requested size must
    be listed in ``SUPPORTED_SIZES``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    example = EXAMPLE_PROMPT[args.task]

    if args.prompt is None:
        args.prompt = example["prompt"]
    if "image" in example and args.image is None:
        args.image = example["image"]

    # The audio fallback and the TTS fallback are mutually exclusive: the
    # first needs enable_tts to be exactly False, the second exactly True.
    if "audio" in example:
        if args.enable_tts is False:
            if args.audio is None:
                args.audio = example["audio"]
        elif args.enable_tts is True:
            if args.tts_prompt_audio is None or args.tts_text is None:
                # All three TTS fields are replaced together.
                args.tts_prompt_audio = example["tts_prompt_audio"]
                args.tts_prompt_text = example["tts_prompt_text"]
                args.tts_text = example["tts_text"]

    if args.task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    # Options that fall back to the same-named attribute of the task config.
    task_cfg = WAN_CONFIGS[args.task]
    for option in ("sample_steps", "sample_shift", "sample_guide_scale",
                   "frame_num"):
        if getattr(args, option) is None:
            setattr(args, option, getattr(task_cfg, option))

    if args.base_seed < 0:
        args.base_seed = random.randint(0, sys.maxsize)

    # Size check
    if 's2v' not in args.task:
        assert args.size in SUPPORTED_SIZES[
            args.
            task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
| class _Args: | |
| pass | |
def _parse_args():
    """Build the fixed option set used for inference.

    Nothing is read from the command line: every option is hard-coded here,
    attached to a fresh ``_Args`` instance, run through ``_validate_args``
    and returned.
    """
    option_values = {
        # core generation options
        "task": "animate-14B",
        "size": "720*1280",  # previously "1280*720"
        "frame_num": None,
        "ckpt_dir": "./Wan2.2-Animate-14B/",
        "offload_model": False,
        "ulysses_size": 1,
        "t5_fsdp": False,
        "t5_cpu": False,
        "dit_fsdp": False,
        "prompt": None,
        "use_prompt_extend": False,
        "prompt_extend_method": "local_qwen",  # ["dashscope", "local_qwen"]
        "prompt_extend_model": None,
        "prompt_extend_target_lang": "zh",  # ["zh", "en"]
        "base_seed": 1234,
        "image": None,
        "sample_solver": "unipc",  # ['unipc', 'dpm++']
        "sample_steps": None,
        "sample_shift": None,
        "sample_guide_scale": None,
        "convert_model_dtype": True,
        # animate
        "refert_num": 1,
        # s2v-only
        "num_clip": None,
        "audio": None,
        "enable_tts": False,
        "tts_prompt_audio": None,
        "tts_prompt_text": None,
        "tts_text": None,
        "pose_video": None,
        "start_from_ref": False,
        "infer_frames": 80,
    }

    args = _Args()
    for option, value in option_values.items():
        setattr(args, option, value)

    _validate_args(args)
    return args
| def _init_logging(rank): | |
| # logging | |
| if rank == 0: | |
| # set format | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="[%(asctime)s] %(levelname)s: %(message)s", | |
| handlers=[logging.StreamHandler(stream=sys.stdout)]) | |
| else: | |
| logging.basicConfig(level=logging.ERROR) | |
def load_model(use_relighting_lora = False):
    """Create a ``wan.WanAnimate`` model for the "animate-14B" config.

    The checkpoint directory, device id 0 and rank 0 are fixed here, and
    FSDP, sequence parallelism (``use_sp``), T5-on-CPU and model dtype
    conversion are all switched off.

    Args:
        use_relighting_lora: forwarded unchanged to ``wan.WanAnimate``.

    Returns:
        The constructed ``wan.WanAnimate`` instance.
    """
    model_kwargs = dict(
        config=WAN_CONFIGS["animate-14B"],
        checkpoint_dir="./Wan2.2-Animate-14B/",
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        convert_model_dtype=False,
        use_relighting_lora=use_relighting_lora,
    )
    return wan.WanAnimate(**model_kwargs)
def generate(wan_animate, preprocess_dir, save_file, replace_flag = False):
    """Run one animate generation and save the result on rank 0.

    Options come from ``_parse_args``; the process rank is read from the
    ``RANK`` environment variable (default 0).

    Args:
        wan_animate: object exposing ``generate(...)``, e.g. the value
            returned by ``load_model``.
        preprocess_dir: passed to the model as ``src_root_path``.
        save_file: output path handed to ``save_video`` (rank 0 only).
        replace_flag: forwarded unchanged to the model's ``generate``.

    Returns:
        None. The video is written to ``save_file`` as a side effect.
    """
    args = _parse_args()
    rank = int(os.getenv("RANK", 0))
    _init_logging(rank)
    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Input prompt: {args.prompt}")
    # The image is only logged. It is never passed to the model, so it is
    # not opened or decoded here.
    if args.image is not None:
        logging.info(f"Input image: {args.image}")
    print(f'rank:{rank}')

    logging.info("Generating video ...")
    video = wan_animate.generate(
        src_root_path=preprocess_dir,
        replace_flag=replace_flag,
        refert_num=args.refert_num,
        clip_len=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)

    if rank == 0:
        save_video(
            tensor=video[None],
            save_file=save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    # Drop the reference to the generated tensor before synchronizing.
    del video

    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")