import os
import random
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download

from src.pipeline_wan_nag import NAGWanPipeline
from src.transformer_wan_nag import NagWanTransformer3DModel
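
# Gradio demo: Wan2.1-T2V-14B text-to-video with NAG (Normalized Attention Guidance)
# and optional LoRA adapters. Assumes a CUDA device (e.g. a ZeroGPU Space).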
# Model and UI constants (defaults match the Wan2.1 T2V setup below)
MOD_VALUE = 32
DEFAULT_DURATION_SECONDS = 4
DEFAULT_STEPS = 4
DEFAULT_SEED = 2025
DEFAULT_H_SLIDER_VALUE = 480
DEFAULT_W_SLIDER_VALUE = 832
NEW_FORMULA_MAX_AREA = 480.0 * 832.0
SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
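# VAE is kept in fp32 for decode quality; the rest of the pipeline runs in bf16.
# UniPC scheduler is re-created from the pipeline config with flow_shift=5.0.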
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = NAGWanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
pipe.to("cuda")
# Patch the transformer class so the NAG-aware attention processors and forward pass are used
pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward
# --- Predefined LoRAs ---
AVAILABLE_LORAS = [
    {
        "label": "CausVid LoRA",
        "repo_id": "Kijai/WanVideo_comfy",
        "filename": "Wan21_CausVid_14B_T2V_lora_rank32.safetensors",
        "adapter_name": "causvid_lora",
        "default_weight": 0.95,
        "scale_blocks": ["blocks.0"],
    },
    {
        "label": "Detail Enhancer V1",
        "repo_id": "vrgamedevgirl84/Wan14BT2VFusioniX",
        "filename": "OtherLoRa's/DetailEnhancerV1.safetensors",
        "adapter_name": "mps_lora",
        "default_weight": 0.7,
    },
]
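# Each entry: "label" shown in the UI, "repo_id"/"filename" locating the weights on the Hub,
# "adapter_name" used when loading into the pipeline, "default_weight" for the UI slider,
# and optional "scale_blocks" whose lora_B weights are down-scaled after loading.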
def load_loras_from_ui(selected_labels, weights, custom_repo=None, custom_file=None, custom_weight=1.0):
    lora_adapters = []
    lora_weights = []
    selected_configs = []
    for label in selected_labels:
        lora = next((l for l in AVAILABLE_LORAS if l["label"] == label), None)
        if lora:
            config = lora.copy()
            # `weights` holds one slider value per AVAILABLE_LORAS entry, so index by
            # the LoRA's position in that list rather than by selection order.
            config["weight"] = float(weights[AVAILABLE_LORAS.index(lora)])
            selected_configs.append(config)
    # if custom_repo and custom_file:
    #     adapter_name = os.path.splitext(os.path.basename(custom_file))[0]
    #     selected_configs.append({
    #         "repo_id": custom_repo,
    #         "filename": custom_file,
    #         "adapter_name": adapter_name,
    #         "weight": float(custom_weight),
    #     })
    for config in selected_configs:
        snapshot_path = snapshot_download(
            repo_id=config["repo_id"],
            allow_patterns=[config["filename"]],
            repo_type="model",
        )
        lora_path = os.path.join(snapshot_path, config["filename"])
        pipe.load_lora_weights(lora_path, adapter_name=config["adapter_name"])
        if config.get("scale_blocks"):
            # Down-scale the lora_B weights of the listed blocks (empirical 0.25 factor).
            for name, param in pipe.transformer.named_parameters():
                if "lora_B" in name and any(b in name for b in config["scale_blocks"]):
                    param.data *= 0.25
        lora_adapters.append(config["adapter_name"])
        lora_weights.append(config["weight"])
    if lora_adapters:
        pipe.set_adapters(lora_adapters, adapter_weights=lora_weights)
        # Fusing folds the adapters into the base weights, so they stay applied
        # for subsequent generations in this process.
        pipe.fuse_lora()
        print(f"✅ Fused LoRAs: {lora_adapters}")
# def get_duration(
# prompt,
# nag_negative_prompt, nag_scale,
# height, width, duration_seconds,
# steps,
# seed, randomize_seed,
# compare,
# ):
# duration = int(duration_seconds) * int(steps) * 2.25 + 5
# if compare:
# duration *= 2
# return duration
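# Fixed 200 s ZeroGPU allocation per call; the commented-out get_duration above sketches
# a dynamic estimate based on duration_seconds, steps, and the compare flag.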
@spaces.GPU(duration=200)
def generate_video(prompt, nag_negative_prompt, nag_scale,
                   height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE, duration_seconds=DEFAULT_DURATION_SECONDS,
                   steps=DEFAULT_STEPS, seed=DEFAULT_SEED, randomize_seed=False, compare=True):
    # Snap the requested resolution to the model's multiple-of-MOD_VALUE grid.
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
    # Frame count = seconds * fps + 1, clamped to what the model supports.
    num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    with torch.inference_mode():
        nag_output_frames_list = pipe(
            prompt=prompt,
            nag_negative_prompt=nag_negative_prompt,
            nag_scale=nag_scale,
            nag_tau=3.5,
            nag_alpha=0.5,
            height=target_h, width=target_w, num_frames=num_frames,
            guidance_scale=0.,
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
        ).frames[0]
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            nag_video_path = tmpfile.name
        export_to_video(nag_output_frames_list, nag_video_path, fps=FIXED_FPS)
        if compare:
            # Baseline pass without the NAG-specific arguments, same seed for a fair comparison.
            baseline_output_frames_list = pipe(
                prompt=prompt,
                nag_negative_prompt=nag_negative_prompt,
                height=target_h, width=target_w, num_frames=num_frames,
                guidance_scale=0.,
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed),
            ).frames[0]
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
                baseline_video_path = tmpfile.name
            export_to_video(baseline_output_frames_list, baseline_video_path, fps=FIXED_FPS)
        else:
            baseline_video_path = None
    return nag_video_path, baseline_video_path, current_seed
# --- Gradio UI ---
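# Left column: prompt, NAG and LoRA controls; right column: NAG output and optional baseline video.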
with gr.Blocks() as demo:
    gr.Markdown("# Wan2.1-T2V-14B + NAG + LoRA Control")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            nag_negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NAG_NEGATIVE_PROMPT)
            nag_scale = gr.Slider(1., 20., value=11., step=0.25, label="NAG Scale")
            compare = gr.Checkbox(label="Compare with baseline", value=True)
            with gr.Accordion("Advanced", open=False):
                steps_slider = gr.Slider(1, 8, value=DEFAULT_STEPS, label="Inference Steps")
                duration_seconds_input = gr.Slider(1, 5, value=DEFAULT_DURATION_SECONDS, label="Duration (seconds)")
                seed_input = gr.Slider(0, MAX_SEED, step=1, value=DEFAULT_SEED, label="Seed")
                randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
                height_input = gr.Slider(SLIDER_MIN_H, SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label="Height")
                width_input = gr.Slider(SLIDER_MIN_W, SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label="Width")
            with gr.Accordion("LoRA Settings", open=False):
                lora_selector = gr.CheckboxGroup([l["label"] for l in AVAILABLE_LORAS], label="Select Predefined LoRAs")
                lora_sliders = [
                    gr.Slider(0.0, 1.5, value=l["default_weight"], label=f"{l['label']} Weight")
                    for l in AVAILABLE_LORAS
                ]
                # custom_repo = gr.Textbox(label="Custom Repo ID (optional)", placeholder="e.g. my-user/my-repo")
                # custom_file = gr.Textbox(label="Custom Filename (optional)", placeholder="e.g. my_model.safetensors")
                # custom_weight = gr.Slider(0.0, 1.5, value=1.0, label="Custom Weight")
            generate_button = gr.Button("Generate Video")
        with gr.Column():
            nag_video_output = gr.Video(label="Video with NAG")
            baseline_video_output = gr.Video(label="Baseline Video")
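
    # Event-handler inputs arrive as: generate_video's arguments in order, then the LoRA
    # CheckboxGroup selection, then one weight value per entry in AVAILABLE_LORAS.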
    def generate_wrapper(*args):
        n_loras = len(AVAILABLE_LORAS)
        selected_labels = args[-(n_loras + 1)]
        if not isinstance(selected_labels, list):
            selected_labels = []  # CheckboxGroup can return None when nothing is selected
        lora_weights = list(args[-n_loras:])
        if selected_labels:
            load_loras_from_ui(selected_labels, lora_weights)
        return generate_video(*args[:-(n_loras + 1)])

    inputs = [
        prompt,
        nag_negative_prompt, nag_scale,
        height_input, width_input, duration_seconds_input,
        steps_slider, seed_input, randomize_seed_checkbox, compare,
        lora_selector,  # ✅ CheckboxGroup - must come BEFORE the weight sliders
    ] + lora_sliders
    generate_button.click(
        fn=generate_wrapper,
        inputs=inputs,
        outputs=[nag_video_output, baseline_video_output, seed_input],
    )

if __name__ == "__main__":
    demo.queue().launch()