import spaces from huggingface_hub import snapshot_download, hf_hub_download import os import subprocess import importlib, site from PIL import Image import uuid import shutil import time import cv2 import json import gradio as gr import sys import gc BASE = os.path.dirname(os.path.abspath(__file__)) PREPROCESS_DIR = os.path.join(BASE, "wan", "modules", "animate", "preprocess") sys.path.append(PREPROCESS_DIR) # Re-discover all .pth/.egg-link files for sitedir in site.getsitepackages(): site.addsitedir(sitedir) # Clear caches so importlib will pick up new modules importlib.invalidate_caches() def sh(cmd): subprocess.check_call(cmd, shell=True) try: sh("pip install flash-attn --no-build-isolation") print("Attempting to download and build sam2...") print("download sam") sam_dir = snapshot_download(repo_id="alexnasa/sam2") @spaces.GPU(duration=500) def install_sam(): os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" sh(f"cd {sam_dir} && python setup.py build_ext --inplace && pip install -e .") print("install sam") install_sam() # tell Python to re-scan site-packages now that the egg-link exists import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches() print("sam2 installed successfully.") except Exception as e: raise gr.Error("sam2 installation failed") import torch from generate import generate, load_model from preprocess_data import run as run_preprocess from preprocess_data import load_preprocess_models print(f"Torch version: {torch.__version__}") os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results" snapshot_download(repo_id="Wan-AI/Wan2.2-Animate-14B", local_dir="./Wan2.2-Animate-14B") wan_animate = load_model(True) rc_mapping = { "Video → Ref Image" : False, "Video ← Ref Image" : True } def preprocess_video(input_video_path, duration, session_id=None): if session_id is None: session_id = uuid.uuid4().hex output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id) os.makedirs(output_dir, exist_ok=True) process_video_path = os.path.join(output_dir, 'input_video.mp4') convert_video_to_30fps_and_clip(input_video_path, process_video_path, duration_s=duration) return process_video_path def extract_audio_from_video_ffmpeg(video_path, output_wav_path, sample_rate=None): """ Extracts the audio track from a video file and saves it as a WAV file. Args: video_path (str): Path to the input video file. output_wav_path (str): Path to save the extracted WAV file. sample_rate (int, optional): Output sample rate (e.g., 16000). If None, keep the original. """ cmd = [ 'ffmpeg', '-i', video_path, # Input video '-vn', # Disable video '-acodec', 'pcm_s16le', # 16-bit PCM (WAV format) '-ac', '1', # Mono channel (use '2' for stereo) '-y', # Overwrite output '-loglevel', 'error' # Cleaner output ] # Only add the sample rate option if explicitly specified if sample_rate is not None: cmd.extend(['-ar', str(sample_rate)]) cmd.append(output_wav_path) try: subprocess.run(cmd, check=True, capture_output=True, text=True) return True except subprocess.CalledProcessError as e: return False def combine_video_and_audio_ffmpeg(video_path, audio_path, output_video_path): """ Combines a silent MP4 video with a WAV audio file into a single MP4 with sound. Args: video_path (str): Path to the silent video file. audio_path (str): Path to the WAV audio file. output_video_path (str): Path to save the output MP4 with audio. """ cmd = [ 'ffmpeg', '-i', video_path, # Input video '-i', audio_path, # Input audio '-c:v', 'copy', # Copy video without re-encoding '-c:a', 'aac', # Encode audio as AAC (MP4-compatible) '-shortest', # Stop when the shortest stream ends '-y', # Overwrite output '-loglevel', 'error', output_video_path ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) except subprocess.CalledProcessError as e: raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}") def convert_video_to_30fps_and_clip(input_video_path, output_video_path, duration_s=2, target_fps=30): # Decide target box depending on orientation *as ffmpeg sees it*. # We'll just compute both and let expressions pick the right one. # If you truly want different targets by orientation, keep your is_portrait() and set these constants accordingly. # Build a crop expression that: # - never exceeds the input size # - keeps values even (required by yuv420p) # - stays centered crop_w_expr = "floor(min(in_w\,1280)/2)*2" crop_h_expr = "floor(min(in_h\,720)/2)*2" crop_x_expr = f"floor((in_w - {crop_w_expr})/2/2)*2" crop_y_expr = f"floor((in_h - {crop_h_expr})/2/2)*2" vf = ( f"crop={crop_w_expr}:{crop_h_expr}:{crop_x_expr}:{crop_y_expr}," f"fps={target_fps}" ) cmd = [ "ffmpeg", "-nostdin", "-hide_banner", "-y", "-i", input_video_path, "-t", str(duration_s), # Do crop and fps in one -vf so they see the same frame geometry "-vf", vf, # Make sure the output has even dims and a standard pix_fmt "-pix_fmt", "yuv420p", "-c:v", "libx264", "-c:a", "aac", output_video_path, ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) except subprocess.CalledProcessError as e: raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}") def is_portrait(video_file): # Get video information cap = cv2.VideoCapture(video_file) if not cap.isOpened(): error_msg = "Cannot open video file" gr.Warning(error_msg) orig_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() return orig_width < orig_height def calculate_time_required(max_duration_s, rc_bool): frames_count = 30 * max_duration_s chunks = frames_count // 77 + 1 if rc_bool: pose2d_tracking_duration_s = 75 iteration_per_step_s = 12 else: pose2d_tracking_duration_s = 65 iteration_per_step_s = 12 time_required = pose2d_tracking_duration_s + iteration_per_step_s * 5 * chunks print(f'for frames_count:{frames_count} doing {chunks} chunks the time_required is {time_required}') return time_required def update_time_required(max_duration_s, rc_str): rc_bool = rc_mapping[rc_str] duration_s = calculate_time_required(max_duration_s, rc_bool) duration_m = duration_s / 60 return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)") def get_duration(input_video, max_duration_s, edited_frame, rc_bool, session_id, progress): return calculate_time_required(max_duration_s, rc_bool) @spaces.GPU(duration=120) def _animate(input_video, max_duration_s, edited_frame, rc_bool, session_id = None, progress=gr.Progress(track_tqdm=True),): if session_id is None: session_id = uuid.uuid4().hex output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id) os.makedirs(output_dir, exist_ok=True) preprocess_dir = os.path.join(output_dir, "preprocess_dir") os.makedirs(preprocess_dir, exist_ok=True) output_video_path = os.path.join(output_dir, 'result.mp4') # --- Measure preprocess time --- start_preprocess = time.time() if is_portrait(input_video): w = 480 h = 832 else: w = 832 h = 480 tag_string = "retarget_flag" if rc_bool: tag_string = "replace_flag" # sh("python ./wan/modules/animate/preprocess/preprocess_data.py " # "--ckpt_path ./Wan2.2-Animate-14B/process_checkpoint " # f"--video_path {input_video} " # f"--refer_path {edited_frame} " # f"--save_path {preprocess_dir} " # f"--resolution_area {w} {h} --{tag_string} " # ) preprocess_model = load_preprocess_models() run_preprocess(preprocess_model, input_video, edited_frame, preprocess_dir, w, h, tag_string) preprocess_time = time.time() - start_preprocess print(f"Preprocess took {preprocess_time:.2f} seconds") # --- Measure generate time --- start_generate = time.time() generate(wan_animate, preprocess_dir, output_video_path, rc_bool) generate_time = time.time() - start_generate print(f"Generate took {generate_time:.2f} seconds") # --- Optional total time --- total_time = preprocess_time + generate_time print(f"Total time: {total_time:.2f} seconds") gc.collect() torch.cuda.empty_cache() return output_video_path def animate_scene(input_video, max_duration_s, edited_frame, rc_str, session_id = None, progress=gr.Progress(track_tqdm=True),): if not input_video: raise gr.Error("Please provide an video") if not edited_frame: raise gr.Error("Please provide an image") if session_id is None: session_id = uuid.uuid4().hex input_video = preprocess_video(input_video, max_duration_s, session_id) rc_bool = rc_mapping[rc_str] output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id) os.makedirs(output_dir, exist_ok=True) input_audio_path = os.path.join(output_dir, 'input_audio.wav') audio_extracted = extract_audio_from_video_ffmpeg(input_video, input_audio_path) edited_frame_png = os.path.join(output_dir, 'edited_frame.png') edited_frame_img = Image.open(edited_frame) edited_frame_img.save(edited_frame_png) output_video_path = _animate(input_video, max_duration_s, edited_frame_png, rc_bool, session_id, progress) final_video_path = os.path.join(output_dir, 'final_result.mp4') preprocess_dir = os.path.join(output_dir, "preprocess_dir") pose_video = os.path.join(preprocess_dir, 'src_pose.mp4') if rc_bool: mask_video = os.path.join(preprocess_dir, 'src_mask.mp4') bg_video = os.path.join(preprocess_dir, 'src_bg.mp4') face_video = os.path.join(preprocess_dir, 'src_face.mp4') else: mask_video = os.path.join(preprocess_dir, 'src_pose.mp4') bg_video = os.path.join(preprocess_dir, 'src_pose.mp4') face_video = os.path.join(preprocess_dir, 'src_pose.mp4') if audio_extracted: combine_video_and_audio_ffmpeg(output_video_path, input_audio_path, final_video_path) else: final_video_path = output_video_path print(f"task for {session_id} finalised") return final_video_path, pose_video, bg_video, mask_video, face_video css = """ #col-container { margin: 0 auto; max-width: 1600px; } #step-column { padding: 20px; border-radius: 8px; box-shadow: var(--card-shadow); margin: 10px; } #col-showcase { margin: 0 auto; max-width: 1100px; } .button-gradient { background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%; border: none; padding: 14px 28px; font-size: 16px; font-weight: bold; color: white; border-radius: 10px; cursor: pointer; transition: 0.3s ease-in-out; animation: 2s linear 0s infinite normal none running gradientAnimation; box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px; } .toggle-container { display: inline-flex; background-color: #ffd6ff; /* light pink background */ border-radius: 9999px; padding: 4px; position: relative; width: fit-content; font-family: sans-serif; } .toggle-container input[type="radio"] { display: none; } .toggle-container label { position: relative; z-index: 2; flex: 1; text-align: center; font-weight: 700; color: #4b2ab5; /* dark purple text for unselected */ padding: 6px 22px; border-radius: 9999px; cursor: pointer; transition: color 0.25s ease; } /* Moving highlight */ .toggle-highlight { position: absolute; top: 4px; left: 4px; width: calc(50% - 4px); height: calc(100% - 8px); background-color: #4b2ab5; /* dark purple background */ border-radius: 9999px; transition: transform 0.25s ease; z-index: 1; } /* When "True" is checked */ #true:checked ~ label[for="true"] { color: #ffd6ff; /* light pink text */ } /* When "False" is checked */ #false:checked ~ label[for="false"] { color: #ffd6ff; /* light pink text */ } /* Move highlight to right side when False is checked */ #false:checked ~ .toggle-highlight { transform: translateX(100%); } """ def start_session(request: gr.Request): return request.session_hash def cleanup(request: gr.Request): sid = request.session_hash if sid: print(f"{sid} left") d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid) shutil.rmtree(d1, ignore_errors=True) with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean()) as demo: session_state = gr.State() demo.load(start_session, outputs=[session_state]) with gr.Column(elem_id="col-container"): with gr.Row(): gr.HTML( """

Wan2.2-Animate-14B

[Model]

HF Space By:

""" ) with gr.Row(): with gr.Column(elem_id="step-column"): gr.HTML("""
1. Upload a Video
""") input_video = gr.Video(label="Input Video", height=512) max_duration_slider = gr.Slider(2, 8, 2, step=2, label="Max Duration", visible=False) gr.Examples( examples=[ [ "./examples/test_example.mp4", ], ], inputs=[input_video], cache_examples=False, ) with gr.Column(elem_id="step-column"): gr.HTML("""
2. Upload a Ref Image
""") edited_frame = gr.Image(label="Ref Image", type="filepath", height=512) replace_character_string = gr.Radio( ["Video → Ref Image", "Video ← Ref Image"], value="Video ← Ref Image", show_label=False ) gr.Examples( examples=[ [ "./examples/ali.png", ], [ "./examples/amber.png", ], [ "./examples/ella.png", ], [ "./examples/sydney.png", ], ], inputs=[edited_frame], cache_examples=False, ) with gr.Column(elem_id="step-column"): gr.HTML("""
3. Wan Animate it!
""") output_video = gr.Video(label="Edited Video", height=512) time_required = gr.Text(value="⌚ Zero GPU Required: ~195.0s (3.2 mins)", show_label=False, visible=False) action_button = gr.Button("Wan Animate 🦆", variant='primary', elem_classes="button-gradient") with gr.Accordion("Preprocessed Data", open=False, visible=True): with gr.Row(): pose_video = gr.Video(label="Pose Video") bg_video = gr.Video(label="Background Video") face_video = gr.Video(label="Face Video") mask_video = gr.Video(label="Mask Video") with gr.Row(): with gr.Column(elem_id="col-showcase"): gr.Examples( examples=[ [ "./examples/okay.mp4", 2, "./examples/amber.png", "Video ← Ref Image" ], [ "./examples/superman.mp4", 2, "./examples/superman.png", "Video ← Ref Image" ], [ "./examples/test_example.mp4", 2, "./examples/ella.png", "Video ← Ref Image" ], [ "./examples/paul.mp4", 2, "./examples/man.png", "Video → Ref Image" ], [ "./examples/desi.mp4", 2, "./examples/desi.png", "Video ← Ref Image" ], ], inputs=[input_video, max_duration_slider, edited_frame, replace_character_string], outputs=[output_video, pose_video, bg_video, mask_video, face_video], fn=animate_scene, cache_examples=True, ) action_button.click(fn=animate_scene, inputs=[input_video, max_duration_slider, edited_frame, replace_character_string, session_state], outputs=[output_video, pose_video, bg_video, mask_video, face_video]) max_duration_slider.change(update_time_required, inputs=[max_duration_slider, replace_character_string], outputs=[time_required]) replace_character_string.change(update_time_required, inputs=[max_duration_slider, replace_character_string], outputs=[time_required]) if __name__ == "__main__": demo.queue() demo.unload(cleanup) demo.launch(ssr_mode=False, share=True)