rahul7star committed
Commit c84c1eb · verified · Parent(s): b440049

Update app.py

Files changed (1):
  1. app.py (+73 -470)
app.py CHANGED
@@ -10,340 +10,40 @@ import numpy as np
  from PIL import Image
  import random
 
- # Base MODEL_ID (using original Wan model that's compatible with diffusers)
- #MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
- MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-720P"
+ MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
 
- # FusionX enhancement LoRAs (based on FusionX composition)
  LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
- LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"
+ LORA_FILENAME = "FusionX_LoRa/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
 
-
- # Additional enhancement LoRAs for FusionX-like quality
- # ACCVIDEO_LORA_REPO = "alibaba-pai/Wan2.1-Fun-Reward-LoRAs"
- # MPS_LORA_FILENAME = "Wan2.1-MPS-Reward-LoRA.safetensors"
-
- # Load enhanced model components
- print("🚀 Loading FusionX Enhanced Wan2.1 I2V Model...")
  image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
  vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
  pipe = WanImageToVideoPipeline.from_pretrained(
      MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
  )
-
- # FusionX optimized scheduler settings
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
  pipe.to("cuda")
 
- # Load FusionX enhancement LoRAs
- from huggingface_hub import hf_hub_download
- import torch
-
- from huggingface_hub import hf_hub_download
- import torch
-
- # Load FusionX enhancement LoRAs
- lora_adapters = []
- lora_weights = []
-
- # Print all named parameters (safely) from any pipeline
- def print_named_params(module, module_name=""):
-     print(f"\n🔍 Parameters in {module_name or 'pipeline'}:")
-     for name, param in module.named_parameters():
-         print(f"{name}: {param.shape}")
-
- # Try printing known submodules in the pipeline
- def print_all_pipeline_keys(pipe):
-     print("🧠 Exploring pipeline structure:")
-     for attr in dir(pipe):
-         if not attr.startswith("_"):
-             try:
-                 obj = getattr(pipe, attr)
-                 if isinstance(obj, torch.nn.Module):
-                     print_named_params(obj, attr)
-             except Exception as e:
-                 print(f"⚠️ Could not inspect {attr}: {e}")
-
- # Print LoRA file contents
- def print_lora_checkpoint_keys(path):
-     try:
-         lora_checkpoint = torch.load(path, map_location="cpu")
-         print("\n📦 LoRA Checkpoint Keys:")
-         for k, v in lora_checkpoint.items():
-             print(f"{k}: {v.shape}")
-     except Exception as e:
-         print(f"❌ Failed to load LoRA for inspection: {e}")
-
- # Step 1: Explore the pipeline model structure
- print_all_pipeline_keys(pipe)
-
- # Step 2: Download LoRA file
- lora_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
-
- # Step 3: Print LoRA checkpoint keys
- print_lora_checkpoint_keys(lora_path)
-
- # Step 4: Load and apply the LoRA
- try:
-     pipe.load_lora_weights(lora_path, adapter_name="main")
-     pipe.set_adapters(["main"], adapter_weights=[1.0])
-     pipe.fuse_lora()
-     print("✅ LoRA applied successfully.")
- except Exception as e:
-     print(f"❌ Failed to load LoRA: {e}")
+ causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
+ pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
+ pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
+ pipe.fuse_lora()
 
  MOD_VALUE = 32
- DEFAULT_H_SLIDER_VALUE = 576  # FusionX optimized default
- DEFAULT_W_SLIDER_VALUE = 1024  # FusionX optimized default
- NEW_FORMULA_MAX_AREA = 576.0 * 1024.0  # Updated for FusionX
+ DEFAULT_H_SLIDER_VALUE = 640
+ DEFAULT_W_SLIDER_VALUE = 1024
+ NEW_FORMULA_MAX_AREA = 640.0 * 1024.0
 
- SLIDER_MIN_H, SLIDER_MAX_H = 128, 1080
- SLIDER_MIN_W, SLIDER_MAX_W = 128, 1920
+ SLIDER_MIN_H, SLIDER_MAX_H = 128, 1024
+ SLIDER_MIN_W, SLIDER_MAX_W = 128, 1024
  MAX_SEED = np.iinfo(np.int32).max
 
  FIXED_FPS = 24
  MIN_FRAMES_MODEL = 8
- MAX_FRAMES_MODEL = 121  # FusionX supports up to 121 frames
+ MAX_FRAMES_MODEL = 81
 
- # Enhanced prompts for FusionX-style output
- default_prompt_i2v = "Cinematic motion, smooth animation, detailed textures, dynamic lighting, professional cinematography"
- default_negative_prompt = "Static image, no motion, blurred details, overexposed, underexposed, low quality, worst quality, JPEG artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, watermark, text, signature, three legs, many people in the background, walking backwards"
+ default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+ default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
 
- # Enhanced CSS for FusionX theme
- custom_css = """
119
- /* Enhanced FusionX theme with cinematic styling */
120
- .gradio-container {
121
- font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
122
- background: linear-gradient(135deg, #1a1a2e 0%, #16213e 25%, #0f3460 50%, #533a7d 75%, #6a4c93 100%) !important;
123
- background-size: 400% 400% !important;
124
- animation: cinematicShift 20s ease infinite !important;
125
- }
126
-
127
- @keyframes cinematicShift {
128
- 0% { background-position: 0% 50%; }
129
- 25% { background-position: 100% 50%; }
130
- 50% { background-position: 100% 100%; }
131
- 75% { background-position: 0% 100%; }
132
- 100% { background-position: 0% 50%; }
133
- }
134
-
135
- /* Main container with cinematic glass effect */
136
- .main-container {
137
- backdrop-filter: blur(15px);
138
- background: rgba(255, 255, 255, 0.08) !important;
139
- border-radius: 25px !important;
140
- padding: 35px !important;
141
- box-shadow: 0 12px 40px 0 rgba(31, 38, 135, 0.4) !important;
142
- border: 1px solid rgba(255, 255, 255, 0.15) !important;
143
- position: relative;
144
- overflow: hidden;
145
- }
146
-
147
- .main-container::before {
148
- content: '';
149
- position: absolute;
150
- top: 0;
151
- left: 0;
152
- right: 0;
153
- bottom: 0;
154
- background: linear-gradient(45deg, rgba(255,255,255,0.1) 0%, transparent 50%, rgba(255,255,255,0.05) 100%);
155
- pointer-events: none;
156
- }
157
-
158
- /* Enhanced header with FusionX branding */
159
- h1 {
160
- background: linear-gradient(45deg, #ffffff, #f0f8ff, #e6e6fa) !important;
161
- -webkit-background-clip: text !important;
162
- -webkit-text-fill-color: transparent !important;
163
- background-clip: text !important;
164
- font-weight: 900 !important;
165
- font-size: 2.8rem !important;
166
- text-align: center !important;
167
- margin-bottom: 2.5rem !important;
168
- text-shadow: 2px 2px 8px rgba(0,0,0,0.3) !important;
169
- position: relative;
170
- }
171
-
172
- h1::after {
173
- content: '🎬 FusionX Enhanced';
174
- display: block;
175
- font-size: 1rem;
176
- color: #6a4c93;
177
- margin-top: 0.5rem;
178
- font-weight: 500;
179
- }
180
 
181
- /* Enhanced component containers */
182
- .input-container, .output-container {
183
- background: rgba(255, 255, 255, 0.06) !important;
184
- border-radius: 20px !important;
185
- padding: 25px !important;
186
- margin: 15px 0 !important;
187
- backdrop-filter: blur(10px) !important;
188
- border: 1px solid rgba(255, 255, 255, 0.12) !important;
189
- box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.1) !important;
190
- }
191
 
192
- /* Cinematic input styling */
193
- input, textarea, .gr-box {
194
- background: rgba(255, 255, 255, 0.95) !important;
195
- border: 1px solid rgba(106, 76, 147, 0.3) !important;
196
- border-radius: 12px !important;
197
- color: #1a1a2e !important;
198
- transition: all 0.4s ease !important;
199
- box-shadow: 0 2px 8px rgba(106, 76, 147, 0.1) !important;
200
- }
201
-
202
- input:focus, textarea:focus {
203
- background: rgba(255, 255, 255, 1) !important;
204
- border-color: #6a4c93 !important;
205
- box-shadow: 0 0 0 3px rgba(106, 76, 147, 0.15) !important;
206
- transform: translateY(-1px) !important;
207
- }
208
-
209
- /* Enhanced FusionX button */
210
- .generate-btn {
211
- background: linear-gradient(135deg, #6a4c93 0%, #533a7d 50%, #0f3460 100%) !important;
212
- color: white !important;
213
- font-weight: 700 !important;
214
- font-size: 1.2rem !important;
215
- padding: 15px 40px !important;
216
- border-radius: 60px !important;
217
- border: none !important;
218
- cursor: pointer !important;
219
- transition: all 0.4s ease !important;
220
- box-shadow: 0 6px 20px rgba(106, 76, 147, 0.4) !important;
221
- position: relative;
222
- overflow: hidden;
223
- }
224
-
225
- .generate-btn::before {
226
- content: '';
227
- position: absolute;
228
- top: 0;
229
- left: -100%;
230
- width: 100%;
231
- height: 100%;
232
- background: linear-gradient(90deg, transparent, rgba(255,255,255,0.3), transparent);
233
- transition: left 0.5s ease;
234
- }
235
-
236
- .generate-btn:hover::before {
237
- left: 100%;
238
- }
239
-
240
- .generate-btn:hover {
241
- transform: translateY(-3px) scale(1.02) !important;
242
- box-shadow: 0 8px 25px rgba(106, 76, 147, 0.6) !important;
243
- }
244
-
245
- /* Enhanced slider styling */
246
- input[type="range"] {
247
- background: transparent !important;
248
- }
249
-
250
- input[type="range"]::-webkit-slider-track {
251
- background: linear-gradient(90deg, rgba(106, 76, 147, 0.3), rgba(83, 58, 125, 0.5)) !important;
252
- border-radius: 8px !important;
253
- height: 8px !important;
254
- }
255
-
256
- input[type="range"]::-webkit-slider-thumb {
257
- background: linear-gradient(135deg, #6a4c93, #533a7d) !important;
258
- border: 3px solid white !important;
259
- border-radius: 50% !important;
260
- cursor: pointer !important;
261
- width: 22px !important;
262
- height: 22px !important;
263
- -webkit-appearance: none !important;
264
- box-shadow: 0 2px 8px rgba(106, 76, 147, 0.3) !important;
265
- }
266
-
267
- /* Enhanced accordion */
268
- .gr-accordion {
269
- background: rgba(255, 255, 255, 0.04) !important;
270
- border-radius: 15px !important;
271
- border: 1px solid rgba(255, 255, 255, 0.08) !important;
272
- margin: 20px 0 !important;
273
- backdrop-filter: blur(5px) !important;
274
- }
275
-
276
- /* Enhanced labels */
277
- label {
278
- color: #ffffff !important;
279
- font-weight: 600 !important;
280
- font-size: 1rem !important;
281
- margin-bottom: 8px !important;
282
- text-shadow: 1px 1px 2px rgba(0,0,0,0.5) !important;
283
- }
284
-
285
- /* Enhanced image upload */
286
- .image-upload {
287
- border: 3px dashed rgba(106, 76, 147, 0.4) !important;
288
- border-radius: 20px !important;
289
- background: rgba(255, 255, 255, 0.03) !important;
290
- transition: all 0.4s ease !important;
291
- position: relative;
292
- }
293
-
294
- .image-upload:hover {
295
- border-color: rgba(106, 76, 147, 0.7) !important;
296
- background: rgba(255, 255, 255, 0.08) !important;
297
- transform: scale(1.01) !important;
298
- }
299
-
300
- /* Enhanced video output */
301
- video {
302
- border-radius: 20px !important;
303
- box-shadow: 0 8px 30px rgba(0, 0, 0, 0.4) !important;
304
- border: 2px solid rgba(106, 76, 147, 0.3) !important;
305
- }
306
-
307
- /* Enhanced examples section */
308
- .gr-examples {
309
- background: rgba(255, 255, 255, 0.04) !important;
310
- border-radius: 20px !important;
311
- padding: 25px !important;
312
- margin-top: 25px !important;
313
- border: 1px solid rgba(255, 255, 255, 0.1) !important;
314
- }
315
-
316
- /* Enhanced checkbox */
317
- input[type="checkbox"] {
318
- accent-color: #6a4c93 !important;
319
- transform: scale(1.2) !important;
320
- }
321
-
322
- /* Responsive enhancements */
323
- @media (max-width: 768px) {
324
- h1 { font-size: 2.2rem !important; }
325
- .main-container { padding: 25px !important; }
326
- .generate-btn { padding: 12px 30px !important; font-size: 1.1rem !important; }
327
- }
328
-
329
- /* Badge container styling */
330
- .badge-container {
331
- display: flex;
332
- justify-content: center;
333
- gap: 15px;
334
- margin: 20px 0;
335
- flex-wrap: wrap;
336
- }
337
-
338
- .badge-container img {
339
- border-radius: 8px;
340
- transition: transform 0.3s ease;
341
- }
342
-
343
- .badge-container img:hover {
344
- transform: scale(1.05);
345
- }
346
- """
347
 
  def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
                                    min_slider_h, max_slider_h,
@@ -385,19 +85,18 @@ def get_duration(input_image, prompt, height, width,
                   guidance_scale, steps,
                   seed, randomize_seed,
                   progress):
-     # FusionX optimized duration calculation
-     if steps > 8 and duration_seconds > 3:
-         return 100
-     elif steps > 8 or duration_seconds > 3:
-         return 80
+     if steps > 4 and duration_seconds > 2:
+         return 90
+     elif steps > 4 or duration_seconds > 2:
+         return 75
      else:
-         return 65
+         return 60
 
  @spaces.GPU(duration=get_duration)
  def generate_video(input_image, prompt, height, width,
-                    negative_prompt=default_negative_prompt, duration_seconds=3,
-                    guidance_scale=1, steps=8,  # FusionX optimized default
-                    seed=42, randomize_seed=False,
+                    negative_prompt=default_negative_prompt, duration_seconds = 2,
+                    guidance_scale = 1, steps = 4,
+                    seed = 42, randomize_seed = False,
                     progress=gr.Progress(track_tqdm=True)):
 
      if input_image is None:
@@ -412,19 +111,11 @@ def generate_video(input_image, prompt, height, width,
 
      resized_image = input_image.resize((target_w, target_h))
 
-     # Enhanced prompt for FusionX-style output
-     enhanced_prompt = f"{prompt}, cinematic quality, smooth motion, detailed animation, dynamic lighting"
-
      with torch.inference_mode():
          output_frames_list = pipe(
-             image=resized_image,
-             prompt=enhanced_prompt,
-             negative_prompt=negative_prompt,
-             height=target_h,
-             width=target_w,
-             num_frames=num_frames,
-             guidance_scale=float(guidance_scale),
-             num_inference_steps=int(steps),
+             image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
+             height=target_h, width=target_w, num_frames=num_frames,
+             guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
              generator=torch.Generator(device="cuda").manual_seed(current_seed)
          ).frames[0]
 
@@ -433,143 +124,55 @@ def generate_video(input_image, prompt, height, width,
      export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
      return video_path, current_seed
 
- # --- Gradio UI ---
- css="""
- #title{text-align: center}
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
- #title img{width: 100px; margin-right: 0.5em}
- #col-container { margin: 0 auto; max-width: 1000px; } /* Increased max-width for gallery */
- #gallery .grid-wrap{height: 20vh !important; max-height: 250px !important;}
- .custom_lora_card { border: 1px solid #e0e0e0; border-radius: 8px; padding: 10px; margin-top: 10px; background-color: #f9f9f9; }
- .card_internal { display: flex; align-items: center; }
- .card_internal img { margin-right: 1em; border-radius: 4px; }
- .card_internal div h4 { margin-bottom: 0.2em; }
- .card_internal div small { font-size: 0.9em; color: #555; }
- #lora_list_link { font-size: 90%; background: var(--block-background-fill); padding: 0.5em 1em; border-radius: 8px; display:inline-block; margin-top:10px;}
- """
-
- with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])) as demo:
-     with gr.Column(elem_classes=["main-container"]):
-         gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA + Audio")
-
-         with gr.Row():
-             with gr.Column(elem_classes=["input-container"]):
-                 input_image_component = gr.Image(
-                     type="pil",
-                     label="🖼️ Input Image (auto-resized to target H/W)",
-                     elem_classes=["image-upload"]
-                 )
-                 prompt_input = gr.Textbox(
-                     label="✏️ Enhanced Prompt (FusionX-style enhancements applied)",
-                     value=default_prompt_i2v,
-                     lines=3
-                 )
-                 duration_seconds_input = gr.Slider(
-                     minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1),
-                     maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1),
-                     step=0.1,
-                     value=2,
-                     label="⏱️ Duration (seconds)",
-                     info=f"FusionX Enhanced supports {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps. Recommended: 2-5 seconds"
-                 )
-
-                 with gr.Accordion("⚙️ Advanced FusionX Settings", open=False):
-                     negative_prompt_input = gr.Textbox(
-                         label="❌ Negative Prompt (FusionX Enhanced)",
-                         value=default_negative_prompt,
-                         lines=4
-                     )
-                     seed_input = gr.Slider(
-                         label="🎲 Seed",
-                         minimum=0,
-                         maximum=MAX_SEED,
-                         step=1,
-                         value=42,
-                         interactive=True
-                     )
-                     randomize_seed_checkbox = gr.Checkbox(
-                         label="🔀 Randomize seed",
-                         value=True,
-                         interactive=True
-                     )
-                     with gr.Row():
-                         height_input = gr.Slider(
-                             minimum=SLIDER_MIN_H,
-                             maximum=SLIDER_MAX_H,
-                             step=MOD_VALUE,
-                             value=DEFAULT_H_SLIDER_VALUE,
-                             label=f"📏 Output Height (FusionX optimized: {MOD_VALUE} multiples)"
-                         )
-                         width_input = gr.Slider(
-                             minimum=SLIDER_MIN_W,
-                             maximum=SLIDER_MAX_W,
-                             step=MOD_VALUE,
-                             value=DEFAULT_W_SLIDER_VALUE,
-                             label=f"📐 Output Width (FusionX optimized: {MOD_VALUE} multiples)"
-                         )
-                     steps_slider = gr.Slider(
-                         minimum=1,
-                         maximum=20,
-                         step=1,
-                         value=8,  # FusionX optimized
-                         label="🚀 Inference Steps (FusionX Enhanced: 8-10 recommended)",
-                         info="FusionX Enhanced delivers excellent results in just 8-10 steps!"
-                     )
-                     guidance_scale_input = gr.Slider(
-                         minimum=0.0,
-                         maximum=20.0,
-                         step=0.5,
-                         value=1.0,
-                         label="🎯 Guidance Scale (FusionX optimized)",
-                         visible=False
-                     )
-
-                 generate_button = gr.Button(
-                     "🎬 Generate FusionX Enhanced Video",
-                     variant="primary",
-                     elem_classes=["generate-btn"]
-                 )
-
-             with gr.Column(elem_classes=["output-container"]):
-                 video_output = gr.Video(
-                     label="🎥 FusionX Enhanced Generated Video",
-                     autoplay=True,
-                     interactive=False
-                 )
-
-         input_image_component.upload(
-             fn=handle_image_upload_for_dims_wan,
-             inputs=[input_image_component, height_input, width_input],
-             outputs=[height_input, width_input]
-         )
-
-         input_image_component.clear(
-             fn=handle_image_upload_for_dims_wan,
-             inputs=[input_image_component, height_input, width_input],
-             outputs=[height_input, width_input]
-         )
-
-         ui_inputs = [
-             input_image_component, prompt_input, height_input, width_input,
-             negative_prompt_input, duration_seconds_input,
-             guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
-         ]
-         generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
+ with gr.Blocks() as demo:
+     gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA")
+     gr.Markdown("[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan 2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors) and is compatible with 🧨 diffusers")
+     with gr.Row():
+         with gr.Column():
+             input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
+             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+             duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
+                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
+                 with gr.Row():
+                     height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
+                     width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
+                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
+                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)
+
+             generate_button = gr.Button("Generate Video", variant="primary")
 
          with gr.Column():
-             gr.Examples(
-                 examples=[
-                     ["peng.png", "a penguin gracefully dancing in the pristine snow, cinematic motion with detailed feathers", 576, 576],
-                     ["frog.jpg", "the frog jumps energetically with smooth, lifelike motion and detailed texture", 576, 576],
-                 ],
-                 inputs=[input_image_component, prompt_input, height_input, width_input],
-                 outputs=[video_output, seed_input],
-                 fn=generate_video,
-                 cache_examples="lazy",
-                 label="🌟 FusionX Enhanced Example Gallery"
-             )
+             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
+
+     input_image_component.upload(
+         fn=handle_image_upload_for_dims_wan,
+         inputs=[input_image_component, height_input, width_input],
+         outputs=[height_input, width_input]
+     )
+
+     input_image_component.clear(
+         fn=handle_image_upload_for_dims_wan,
+         inputs=[input_image_component, height_input, width_input],
+         outputs=[height_input, width_input]
+     )
+
+     ui_inputs = [
+         input_image_component, prompt_input, height_input, width_input,
+         negative_prompt_input, duration_seconds_input,
+         guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
+     ]
+     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
+
+     gr.Examples(
+         examples=[
+             ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
+             ["forg.jpg", "the frog jumps around", 448, 832],
+         ],
+         inputs=[input_image_component, prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
+     )
 
  if __name__ == "__main__":
      demo.queue().launch()
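
The heart of the commit is the model swap (720P checkpoint → 720P-Diffusers) and replacing roughly 80 lines of exploratory LoRA probing with four direct calls. Isolated for clarity, a minimal sketch of the new loading path using the same `huggingface_hub` and `diffusers` APIs as the committed code (the image encoder, scheduler setup, and `pipe.to("cuda")` are omitted here; 0.95 is simply the adapter weight committed above, not a tuned recommendation):

```python
# Minimal sketch of the LoRA path added in this commit.
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from huggingface_hub import hf_hub_download

MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
LORA_FILENAME = "FusionX_LoRa/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"

vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)

# Download one file from the repo, then attach it as a named PEFT adapter.
lora_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(lora_path, adapter_name="causvid_lora")
pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
# fuse_lora() bakes the scaled weights into the base model so inference
# pays no per-step adapter overhead; unfuse_lora() would reverse it.
pipe.fuse_lora()
```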
 
 
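On the frame budget: `MAX_FRAMES_MODEL` drops from 121 to 81, so at the fixed 24 fps the duration slider now tops out at `round(81/24, 1)` = 3.4 seconds. The actual duration-to-frames conversion lives in unchanged lines outside the hunks, so the helper below is a hypothetical reconstruction of that clamp (`frames_for_duration` is a name introduced for illustration):

```python
# Hypothetical reconstruction of the frame-count clamp; the real lines
# are unchanged context that the diff does not show.
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81  # lowered from 121 in this commit

def frames_for_duration(duration_seconds: float) -> int:
    """Map a UI duration in seconds to a frame count within model limits."""
    num_frames = int(round(duration_seconds * FIXED_FPS))
    return max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, num_frames))

assert frames_for_duration(2) == 48    # the default slider value
assert frames_for_duration(10) == 81   # clamped to the new maximum
```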
 
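Only the signature of `_calculate_new_dimensions_wan` appears as context; its body is untouched by this commit. Judging from the constants (`MOD_VALUE = 32`, `NEW_FORMULA_MAX_AREA = 640.0 * 1024.0`, slider bounds 128-1024), the idea is to scale the uploaded image to a target pixel area while preserving aspect ratio, then snap each side to a multiple of 32. A sketch under those assumptions (`suggest_dimensions` is hypothetical, not the function's actual body):

```python
def suggest_dimensions(img_w, img_h, mod_val=32, max_area=640.0 * 1024.0,
                       min_side=128, max_side=1024):
    """Aspect-preserving resize to ~max_area pixels, snapped to mod_val."""
    aspect = img_h / img_w
    width = (max_area / aspect) ** 0.5   # solve w * (aspect * w) == max_area
    height = aspect * width

    def snap(value):
        # Round to the nearest multiple of mod_val, then clamp to the
        # slider bounds (both bounds are themselves multiples of 32).
        snapped = round(value / mod_val) * mod_val
        return max(min_side, min(max_side, snapped))

    return snap(width), snap(height)

print(suggest_dimensions(1920, 1080))  # a 16:9 input -> (1024, 608) here
```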
 
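`@spaces.GPU(duration=get_duration)` is what makes the retuned thresholds matter: on ZeroGPU Spaces, `duration` may be a callable that receives the same arguments as the decorated function and returns the GPU reservation in seconds, so a cheap 4-step, 2-second job now books 60 s instead of a flat worst case. A stripped-down version of the pattern (it only runs inside a Space where the `spaces` package is available; the pipeline call is stubbed out):

```python
import spaces  # Hugging Face ZeroGPU helper, present on ZeroGPU Spaces

def get_duration(input_image, prompt, height, width,
                 negative_prompt, duration_seconds,
                 guidance_scale, steps,
                 seed, randomize_seed,
                 progress):
    # Invoked by the decorator with the same arguments as the wrapped
    # function; the return value is the GPU time to reserve, in seconds.
    if steps > 4 and duration_seconds > 2:
        return 90
    elif steps > 4 or duration_seconds > 2:
        return 75
    return 60

@spaces.GPU(duration=get_duration)
def generate_video(input_image, prompt, height, width,
                   negative_prompt="", duration_seconds=2,
                   guidance_scale=1, steps=4,
                   seed=42, randomize_seed=False,
                   progress=None):
    # ... run the diffusion pipeline inside the reserved GPU window ...
    pass
```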