# Wan2.2-Animate-ZEROGPU / generate.py
# alex
# further optimisation
# 76cd760
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import random
import torch
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import merge_video_audio, save_video, str2bool
# Prompt text used verbatim by both the "t2v-A14B" and "ti2v-5B" examples.
_BOXING_CATS_PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."

# Prompt text and reference image used verbatim by both the "i2v-A14B" and
# "s2v-14B" examples.
_BEACH_CAT_PROMPT = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
_BEACH_CAT_IMAGE = "examples/i2v_input.JPG"

# Example inputs keyed by task name. _validate_args falls back to these
# entries ("prompt", "image", "audio", "tts_*") for options left unset.
EXAMPLE_PROMPT = {
    "t2v-A14B": {
        "prompt": _BOXING_CATS_PROMPT,
    },
    "i2v-A14B": {
        "prompt": _BEACH_CAT_PROMPT,
        "image": _BEACH_CAT_IMAGE,
    },
    "ti2v-5B": {
        "prompt": _BOXING_CATS_PROMPT,
    },
    "animate-14B": {
        "prompt": "视频中的人在做动作",
        "video": "",
        "pose": "",
        "mask": "",
    },
    "s2v-14B": {
        "prompt": _BEACH_CAT_PROMPT,
        "image": _BEACH_CAT_IMAGE,
        "audio": "examples/talk.wav",
        "tts_prompt_audio": "examples/zero_shot_prompt.wav",
        "tts_prompt_text": "希望你以后能够做的比我还好呦。",
        "tts_text": "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。",
    },
}
def _validate_args(args):
    """Check required options and fill unset ones with per-task defaults.

    ``args`` is modified in place:

    * ``prompt``, ``image``, ``audio`` and the ``tts_*`` fields fall back to
      the matching ``EXAMPLE_PROMPT`` entry for ``args.task`` when unset.
    * ``sample_steps``, ``sample_shift``, ``sample_guide_scale`` and
      ``frame_num`` fall back to the values in ``WAN_CONFIGS[args.task]``.
    * a negative ``base_seed`` is replaced by a random integer drawn from
      ``[0, sys.maxsize]``.

    Raises:
        AssertionError: if ``ckpt_dir`` is missing, the task is unknown,
            the "i2v-A14B" task has no image, or (for tasks whose name
            does not contain "s2v") the size is not in ``SUPPORTED_SIZES``.
    """
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    task = args.task
    assert task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    example = EXAMPLE_PROMPT[task]
    has_audio_example = "audio" in example

    if args.prompt is None:
        args.prompt = example["prompt"]

    if args.image is None and "image" in example:
        args.image = example["image"]

    if args.audio is None and args.enable_tts is False and has_audio_example:
        args.audio = example["audio"]

    # With TTS enabled, a missing prompt audio or text resets all three TTS
    # fields to the example values (tts_prompt_text included).
    tts_incomplete = args.tts_prompt_audio is None or args.tts_text is None
    if tts_incomplete and args.enable_tts is True and has_audio_example:
        args.tts_prompt_audio = example["tts_prompt_audio"]
        args.tts_prompt_text = example["tts_prompt_text"]
        args.tts_text = example["tts_text"]

    if task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    # Sampling options left as None inherit the task configuration values.
    cfg = WAN_CONFIGS[task]
    for option in ("sample_steps", "sample_shift", "sample_guide_scale",
                   "frame_num"):
        if getattr(args, option) is None:
            setattr(args, option, getattr(cfg, option))

    if args.base_seed < 0:
        args.base_seed = random.randint(0, sys.maxsize)

    # Size check
    if 's2v' not in task:
        assert args.size in SUPPORTED_SIZES[
            args.
            task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
class _Args:
    """Empty attribute holder; _parse_args sets the generation options on an instance."""
def _parse_args():
    """Build the fixed option set used by this script and validate it.

    No command line is read: every option is assigned a hard-coded value on
    a fresh ``_Args`` instance, which is then passed through
    ``_validate_args`` (filling the options left as ``None``) and returned.
    """
    defaults = (
        # core generation options
        ("task", "animate-14B"),
        ("size", "720*1280"),  # alternative kept for reference: "1280*720"
        ("frame_num", None),
        ("ckpt_dir", "./Wan2.2-Animate-14B/"),
        ("offload_model", False),
        ("ulysses_size", 1),
        ("t5_fsdp", False),
        ("t5_cpu", False),
        ("dit_fsdp", False),
        ("prompt", None),
        ("use_prompt_extend", False),
        ("prompt_extend_method", "local_qwen"),  # ["dashscope", "local_qwen"]
        ("prompt_extend_model", None),
        ("prompt_extend_target_lang", "zh"),  # ["zh", "en"]
        ("base_seed", 1234),
        ("image", None),
        ("sample_solver", "unipc"),  # ['unipc', 'dpm++']
        ("sample_steps", None),
        ("sample_shift", None),
        ("sample_guide_scale", None),
        ("convert_model_dtype", True),
        # animate
        ("refert_num", 1),
        # s2v-only
        ("num_clip", None),
        ("audio", None),
        ("enable_tts", False),
        ("tts_prompt_audio", None),
        ("tts_prompt_text", None),
        ("tts_text", None),
        ("pose_video", None),
        ("start_from_ref", False),
        ("infer_frames", 80),
    )

    options = _Args()
    for name, value in defaults:
        setattr(options, name, value)

    _validate_args(options)
    return options
def _init_logging(rank):
    """Configure the root logger according to the process rank.

    Rank 0 logs at INFO level to stdout with a timestamped format; every
    other rank is configured at ERROR level with the default settings.
    """
    if rank != 0:
        logging.basicConfig(level=logging.ERROR)
        return

    # set format
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[stdout_handler],
    )
def load_model(use_relighting_lora=False,
               *,
               checkpoint_dir="./Wan2.2-Animate-14B/",
               device_id=0):
    """Construct the ``wan.WanAnimate`` pipeline for the "animate-14B" task.

    The pipeline is built for a single process: rank 0, no FSDP sharding,
    no sequence parallelism, and the T5 encoder not forced onto the CPU.

    Args:
        use_relighting_lora: Forwarded to ``wan.WanAnimate`` unchanged.
        checkpoint_dir: Directory passed as ``checkpoint_dir``. Defaults to
            the path that was previously hard-coded.
        device_id: Device index passed as ``device_id``. Defaults to 0, the
            previously hard-coded value.

    Returns:
        The constructed ``wan.WanAnimate`` instance.
    """
    cfg = WAN_CONFIGS["animate-14B"]
    return wan.WanAnimate(
        config=cfg,
        checkpoint_dir=checkpoint_dir,
        device_id=device_id,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        # NOTE(review): _parse_args sets convert_model_dtype to True, but
        # that value is not consulted here — confirm False is intended.
        convert_model_dtype=False,
        use_relighting_lora=use_relighting_lora,
    )
def generate(wan_animate, preprocess_dir, save_file, replace_flag=False):
    """Run one animate generation pass and write the result to a video file.

    Options come from ``_parse_args()`` (hard-coded values plus per-task
    defaults). The process rank is read from the ``RANK`` environment
    variable (0 when unset); only rank 0 saves the video.

    Args:
        wan_animate: Pipeline object exposing ``generate(...)``, e.g. the
            value returned by ``load_model``.
        preprocess_dir: Passed to the pipeline as ``src_root_path``.
        save_file: Output path handed to ``save_video``.
        replace_flag: Forwarded to the pipeline as ``replace_flag``.
    """
    args = _parse_args()
    rank = int(os.getenv("RANK", 0))
    _init_logging(rank)

    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Input prompt: {args.prompt}")
    if args.image is not None:
        # The image path is only reported: nothing below consumes the
        # decoded image, so it is no longer opened and converted here.
        logging.info(f"Input image: {args.image}")

    print(f'rank:{rank}')
    logging.info("Generating video ...")
    video = wan_animate.generate(
        src_root_path=preprocess_dir,
        replace_flag=replace_flag,
        refert_num=args.refert_num,
        clip_len=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)

    if rank == 0:
        save_video(
            tensor=video[None],
            save_file=save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

    # Drop the reference to the generated tensor before synchronizing.
    del video

    # synchronize() raises when CUDA is unavailable, so only call it when a
    # CUDA device is actually present.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")