import os
import time
import argparse
import json
import torch
import traceback
import gc
import random

# These imports rely on your existing code structure
# They must match the location of your WAN code, etc.
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.modules.attention import get_attention_modes
from wan.utils.utils import cache_video
from mmgp import offload, safetensors2, profile_type

try:
    import triton
except ImportError:
    pass

DATA_DIR = "ckpts"

# --------------------------------------------------
# HELPER FUNCTIONS
# --------------------------------------------------

def sanitize_file_name(file_name):
    """Clean up file name from special chars."""
    return (
        file_name.replace("/", "")
        .replace("\\", "")
        .replace(":", "")
        .replace("|", "")
        .replace("?", "")
        .replace("<", "")
        .replace(">", "")
        .replace('"', "")
    )

def extract_preset(lset_name, lora_dir, loras):
    """

    Load a .lset JSON that lists the LoRA files to apply, plus multipliers

    and possibly a suggested prompt prefix.

    """
    lset_name = sanitize_file_name(lset_name)
    if not lset_name.endswith(".lset"):
        lset_name_filename = os.path.join(lora_dir, lset_name + ".lset")
    else:
        lset_name_filename = os.path.join(lora_dir, lset_name)

    if not os.path.isfile(lset_name_filename):
        raise ValueError(f"Preset '{lset_name}' not found in {lora_dir}")

    with open(lset_name_filename, "r", encoding="utf-8") as reader:
        text = reader.read()
    lset = json.loads(text)

    loras_choices_files = lset["loras"]
    loras_choices = []
    missing_loras = []
    for lora_file in loras_choices_files:
        # Build absolute path and see if it is in loras
        full_lora_path = os.path.join(lora_dir, lora_file)
        if full_lora_path in loras:
            idx = loras.index(full_lora_path)
            loras_choices.append(str(idx))
        else:
            missing_loras.append(lora_file)

    if len(missing_loras) > 0:
        missing_list = ", ".join(missing_loras)
        raise ValueError(f"Missing LoRA files for preset: {missing_list}")

    loras_mult_choices = lset["loras_mult"]
    prompt_prefix = lset.get("prompt", "")
    full_prompt = lset.get("full_prompt", False)
    return loras_choices, loras_mult_choices, prompt_prefix, full_prompt

def get_attention_mode(args_attention, installed_modes):
    """

    Decide which attention mode to use: either the user choice or auto fallback.

    """
    if args_attention == "auto":
        for candidate in ["sage2", "sage", "sdpa"]:
            if candidate in installed_modes:
                return candidate
        return "sdpa"  # last fallback
    elif args_attention in installed_modes:
        return args_attention
    else:
        raise ValueError(
            f"Requested attention mode '{args_attention}' not installed. "
            f"Installed modes: {installed_modes}"
        )

def load_i2v_model(model_filename, text_encoder_filename, is_720p):
    """

    Load the i2v model with a specific size config and text encoder.

    """
    size_tag = "720p" if is_720p else "480p"
    print(f"Loading 14B-{size_tag} i2v model ...")
    # Both sizes use the same 'i2v-14B' config; only the checkpoint file differs.
    cfg = WAN_CONFIGS['i2v-14B']
    wan_model = wan.WanI2V(
        config=cfg,
        checkpoint_dir=DATA_DIR,
        model_filename=model_filename,
        text_encoder_filename=text_encoder_filename
    )
    # Pipe structure
    pipe = {
        "transformer": wan_model.model,
        "text_encoder": wan_model.text_encoder.model,
        "text_encoder_2": wan_model.clip.model,
        "vae": wan_model.vae.model
    }
    return wan_model, pipe

def setup_loras(pipe, lora_dir, lora_preset, num_inference_steps):
    """

    Load loras from a directory, optionally apply a preset.

    """
    from pathlib import Path
    import glob

    if not lora_dir or not Path(lora_dir).is_dir():
        print("No valid --lora-dir provided or directory doesn't exist, skipping LoRA setup.")
        return [], [], [], "", "", False

    # Gather LoRA files
    loras = sorted(
        glob.glob(os.path.join(lora_dir, "*.sft"))
        + glob.glob(os.path.join(lora_dir, "*.safetensors"))
    )
    loras_names = [Path(x).stem for x in loras]

    # Offload them with no activation
    offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False)

    # If user gave a preset, apply it
    default_loras_choices = []
    default_loras_multis_str = ""
    default_prompt_prefix = ""
    preset_applied_full_prompt = False
    if lora_preset:
        loras_choices, loras_mult, prefix, full_prompt = extract_preset(lora_preset, lora_dir, loras)
        default_loras_choices = loras_choices
        # If user stored loras_mult as a list or string in JSON, unify that to str
        if isinstance(loras_mult, list):
            # Just store them in a single line
            default_loras_multis_str = " ".join([str(x) for x in loras_mult])
        else:
            default_loras_multis_str = str(loras_mult)
        default_prompt_prefix = prefix
        preset_applied_full_prompt = full_prompt

    return (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        default_prompt_prefix,
        preset_applied_full_prompt
    )

def parse_loras_and_activate(
    transformer,
    loras,
    loras_choices,
    loras_mult_str,
    num_inference_steps
):
    """
    Activate the chosen LoRAs with multipliers on the pipeline's transformer.
    Supports stepwise expansion (e.g. "0.5,0.8" spreads values across steps).
    """
    if not loras or not loras_choices:
        # no LoRAs selected
        return

    # Handle multipliers
    def is_float_or_comma_list(x):
        """

        Example: "0.5", or "0.8,1.0", etc. is valid.

        """
        if not x:
            return False
        for chunk in x.split(","):
            try:
                float(chunk.strip())
            except ValueError:
                return False
        return True

    # Convert multiline or spaced lines to a single list
    lines = [
        line.strip()
        for line in loras_mult_str.replace("\r", "\n").split("\n")
        if line.strip() and not line.strip().startswith("#")
    ]
    # Now combine them by space
    joined_line = " ".join(lines)  # "1.0 2.0,3.0"
    if not joined_line.strip():
        multipliers = []
    else:
        multipliers = joined_line.split(" ")

    # Expand each item
    final_multipliers = []
    for mult in multipliers:
        mult = mult.strip()
        if not mult:
            continue
        if is_float_or_comma_list(mult):
            # Could be "0.7" or "0.5,0.6"
            if "," in mult:
                # expand over steps
                chunk_vals = [float(x.strip()) for x in mult.split(",")]
                expanded = expand_list_over_steps(chunk_vals, num_inference_steps)
                final_multipliers.append(expanded)
            else:
                final_multipliers.append(float(mult))
        else:
            raise ValueError(f"Invalid LoRA multiplier: '{mult}'")

    # If fewer multipliers than chosen LoRAs => pad with 1.0
    needed = len(loras_choices) - len(final_multipliers)
    if needed > 0:
        final_multipliers += [1.0]*needed

    # Actually activate them
    offload.activate_loras(transformer, loras_choices, final_multipliers)

def expand_list_over_steps(short_list, num_steps):
    """

    If user gave (0.5, 0.8) for example, expand them over `num_steps`.

    The expansion is simply linear slice across steps.

    """
    result = []
    inc = len(short_list) / float(num_steps)
    idxf = 0.0
    for _ in range(num_steps):
        value = short_list[int(idxf)]
        result.append(value)
        idxf += inc
    return result

def download_models_if_needed(transformer_filename_i2v, text_encoder_filename, local_folder=DATA_DIR):
    """

    Checks if all required WAN 2.1 i2v files exist locally under 'ckpts/'.

    If not, downloads them from a Hugging Face Hub repo.

    Adjust the 'repo_id' and needed files as appropriate.

    """
    # os is already imported at module level; Path is only needed here
    from pathlib import Path

    try:
        from huggingface_hub import hf_hub_download, snapshot_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required for automatic model download. "
            "Please install it via `pip install huggingface_hub`."
        ) from e

    # Identify just the filename portion for each path
    def basename(path_str):
        return os.path.basename(path_str)

    repo_id = "DeepBeepMeep/Wan2.1"
    target_root = local_folder

    # You can customize this list as needed for i2v usage.
    # At minimum you need:
    #   1) The requested i2v transformer file
    #   2) The requested text encoder file
    #   3) VAE file
    #   4) The open-clip xlm-roberta-large weights
    #
    # If your i2v config references additional files, add them here.
    needed_files = [
        "Wan2.1_VAE.pth",
        "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        basename(text_encoder_filename),
        basename(transformer_filename_i2v),
    ]

    # The original script also downloads an entire "xlm-roberta-large" folder
    # via snapshot_download. If you require that for your pipeline,
    # you can add it here, for example:
    subfolder_name = "xlm-roberta-large"
    if not Path(os.path.join(target_root, subfolder_name)).exists():
        snapshot_download(repo_id=repo_id, allow_patterns=subfolder_name + "/*", local_dir=target_root)

    for filename in needed_files:
        local_path = os.path.join(target_root, filename)
        if not os.path.isfile(local_path):
            print(f"File '{filename}' not found locally. Downloading from {repo_id} ...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=target_root
            )
        else:
            # Already present
            pass

    print("All required i2v files are present.")


# --------------------------------------------------
# ARGUMENT PARSER
# --------------------------------------------------

def parse_args():
    parser = argparse.ArgumentParser(
        description="Image-to-Video inference using WAN 2.1 i2v"
    )
    # Model + Tools
    parser.add_argument(
        "--quantize-transformer",
        action="store_true",
        help="Use on-the-fly transformer quantization"
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable PyTorch 2.0 compile for the transformer"
    )
    parser.add_argument(
        "--attention",
        type=str,
        default="auto",
        help="Which attention to use: auto, sdpa, sage, sage2, flash"
    )
    parser.add_argument(
        "--profile",
        type=int,
        default=4,
        help="Memory usage profile number [1..5]; see original script or use 2 if you have low VRAM"
    )
    parser.add_argument(
        "--preload",
        type=int,
        default=0,
        help="Megabytes of the diffusion model to preload in VRAM (only used in some profiles)"
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="Verbosity level [0..5]"
    )

    # i2v Model
    parser.add_argument(
        "--transformer-file",
        type=str,
        default=f"{DATA_DIR}/wan2.1_image2video_480p_14B_quanto_int8.safetensors",
        help="Which i2v model to load"
    )
    parser.add_argument(
        "--text-encoder-file",
        type=str,
        default=f"{DATA_DIR}/models_t5_umt5-xxl-enc-quanto_int8.safetensors",
        help="Which text encoder to use"
    )

    # LoRA
    parser.add_argument(
        "--lora-dir",
        type=str,
        default="",
        help="Path to a directory containing i2v LoRAs"
    )
    parser.add_argument(
        "--lora-preset",
        type=str,
        default="",
        help="A .lset preset name in the lora_dir to auto-apply"
    )

    # Generation Options
    parser.add_argument("--prompt", type=str, default=None, required=True, help="Prompt for generation")
    parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
    parser.add_argument("--resolution", type=str, default="832x480", help="WxH")
    parser.add_argument("--frames", type=int, default=64, help="Number of frames (16=1s if fps=16). Must be multiple of 4 +/- 1 in WAN.")
    parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps.")
    parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale")
    parser.add_argument("--flow-shift", type=float, default=3.0, help="Flow shift parameter. Generally 3.0 for 480p, 5.0 for 720p.")
    parser.add_argument("--riflex", action="store_true", help="Enable RIFLEx for longer videos")
    parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
    parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
    parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
    parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
    parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
    parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")

    # LoRA usage
    parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
    parser.add_argument("--loras-mult", type=str, default="", help="Multipliers for each chosen LoRA. Example: '1.0 1.2,1.3' etc.")

    # Input
    parser.add_argument(
        "--input-image",
        type=str,
        required=True,
        help="Path to the input image (a single image; multi-image input is not handled here)."
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="output.mp4",
        help="Where to save the resulting video."
    )

    return parser.parse_args()
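
# Example invocation (script name and paths are illustrative):
#   python i2v_inference.py --prompt "a cat walking on a beach" \
#       --input-image cat.png --resolution 832x480 --frames 65 --steps 30 \
#       --output-file cat.mp4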

# --------------------------------------------------
# MAIN
# --------------------------------------------------

def main():
    args = parse_args()

    # Setup environment
    offload.default_verboseLevel = args.verbose
    installed_attn_modes = get_attention_modes()

    # Decide attention
    chosen_attention = get_attention_mode(args.attention, installed_attn_modes)
    offload.shared_state["_attention"] = chosen_attention

    # Determine i2v resolution format
    if "720" in args.transformer_file:
        is_720p = True
    else:
        is_720p = False

    # Make sure we have the needed models locally
    download_models_if_needed(args.transformer_file, args.text_encoder_file)

    # Load i2v
    wan_model, pipe = load_i2v_model(
        model_filename=args.transformer_file,
        text_encoder_filename=args.text_encoder_file,
        is_720p=is_720p
    )
    wan_model._interrupt = False

    # Offload / profile
    # e.g. for your script:  offload.profile(pipe, profile_no=args.profile, compile=..., quantizeTransformer=...)
    # pass the budgets if you want, etc.
    kwargs = {}
    if args.profile == 2 or args.profile == 4:
        # preload is in MB
        if args.preload == 0:
            budgets = {"transformer": 100, "text_encoder": 100, "*": 1000}
        else:
            budgets = {"transformer": args.preload, "text_encoder": 100, "*": 1000}
        kwargs["budgets"] = budgets
    elif args.profile == 3:
        kwargs["budgets"] = {"*": "70%"}

    compile_choice = "transformer" if args.compile else ""
    # Create the offload object
    offloadobj = offload.profile(
        pipe,
        profile_no=args.profile,
        compile=compile_choice,
        quantizeTransformer=args.quantize_transformer,
        **kwargs
    )

    # If user wants to use LoRAs
    (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        preset_prompt_prefix,
        preset_full_prompt
    ) = setup_loras(pipe, args.lora_dir, args.lora_preset, args.steps)

    # Combine user prompt with preset prompt if the preset indicates so
    if preset_prompt_prefix:
        if preset_full_prompt:
            # Full override
            user_prompt = preset_prompt_prefix
        else:
            # Just prefix
            user_prompt = preset_prompt_prefix + "\n" + args.prompt
    else:
        user_prompt = args.prompt

    # Actually parse user LoRA choices if they did not rely purely on the preset
    if args.loras_choices:
        # If user gave e.g. "0,1", we treat that as new additions
        lora_choice_list = [x.strip() for x in args.loras_choices.split(",")]
    else:
        # Use the defaults from the preset
        lora_choice_list = default_loras_choices

    # Activate them
    parse_loras_and_activate(
        pipe["transformer"], loras, lora_choice_list, args.loras_mult or default_loras_multis_str, args.steps
    )

    # Negative prompt
    negative_prompt = args.negative_prompt or ""

    # Sanity check resolution
    if "*" in args.resolution.lower():
        print("ERROR: resolution must be e.g. 832x480 not '832*480'. Fixing it.")
        resolution_str = args.resolution.lower().replace("*", "x")
    else:
        resolution_str = args.resolution

    try:
        width, height = [int(x) for x in resolution_str.split("x")]
    except ValueError as e:
        raise ValueError(f"Invalid resolution: '{resolution_str}' (expected WxH, e.g. 832x480)") from e

    # Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
    if args.slg_layers:
        slg_list = [int(x) for x in args.slg_layers.split(",")]
    else:
        slg_list = None

    # Additional checks (from your original code).
    if "480p" in args.transformer_file:
        # The 480p model cannot exceed the 832x480 pixel area
        if width * height > 832 * 480:
            raise ValueError("You must use the 720p i2v model to generate larger than 832x480.")

    # Handle random seed
    if args.seed < 0:
        args.seed = random.randint(0, 999999999)
    print(f"Using seed={args.seed}")

    # Setup tea cache if needed
    trans = wan_model.model
    trans.enable_cache = (args.teacache > 0)
    if trans.enable_cache:
        if "480p" in args.transformer_file:
            # example from your code
            trans.coefficients = [-3.02331670e+02,  2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
        elif "720p" in args.transformer_file:
            trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
        else:
            raise ValueError("Teacache not supported for this model variant")

    # Attempt generation
    print("Starting generation ...")
    start_time = time.time()

    # Read the input image
    if not os.path.isfile(args.input_image):
        raise ValueError(f"Input image does not exist: {args.input_image}")

    from PIL import Image
    input_img = Image.open(args.input_image).convert("RGB")

    # Multiple input images could be loaded here; we handle a single image for demonstration.

    # Define the generation call
    #  - frames must be a multiple of 4 plus 1 (e.g. 65, 81) per the original
    #    script's note, so we round accordingly:
    frame_count = (args.frames // 4)*4 + 1  # ensures it's 4*N+1
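    # e.g. --frames 64 -> 65 generated frames; --frames 81 -> 81 frames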
    # RIFLEx
    enable_riflex = args.riflex

    # If teacache => reset counters
    if trans.enable_cache:
        trans.teacache_counter = 0
        trans.teacache_multiplier = args.teacache
        trans.cache_start_step = int(args.teacache_start * args.steps / 100.0)
        trans.num_steps = args.steps
        trans.teacache_skipped_steps = 0
        trans.previous_residual_uncond = None
        trans.previous_residual_cond = None

    # VAE Tiling
    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
    if device_mem_capacity >= 28000:  # 81 frames at 720p requires about 28 GB VRAM
        use_vae_config = 1
    elif device_mem_capacity >= 8000:
        use_vae_config = 2
    else:
        use_vae_config = 3

    if use_vae_config == 1:
        VAE_tile_size = 0
    elif use_vae_config == 2:
        VAE_tile_size = 256
    else:
        VAE_tile_size = 128

    print('Using VAE tile size of', VAE_tile_size)
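    # (A tile size of 0 appears to disable tiling entirely; smaller tiles reduce
    # peak VRAM during VAE decode at some speed cost.)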

    # Actually run the i2v generation
    try:
        sample_frames = wan_model.generate(
            input_prompt=user_prompt,
            image_start=input_img,
            frame_num=frame_count,
            width=width,
            height=height,
            # max_area=MAX_AREA_CONFIGS[f"{width}*{height}"],  # or you can pass your custom
            shift=args.flow_shift,
            sampling_steps=args.steps,
            guide_scale=args.guidance_scale,
            n_prompt=negative_prompt,
            seed=args.seed,
            offload_model=False,
            callback=None,  # or define your own callback if you want
            enable_RIFLEx=enable_riflex,
            VAE_tile_size=VAE_tile_size,
            joint_pass=slg_list is None,  # set if you want a small speed improvement without SLG
            slg_layers=slg_list,
            slg_start=args.slg_start,
            slg_end=args.slg_end,
        )
    except Exception as e:
        offloadobj.unload_all()
        gc.collect()
        torch.cuda.empty_cache()

        err_str = f"Generation failed with error: {e}"
        # Attempt to detect OOM errors
        s = str(e).lower()
        if any(keyword in s for keyword in ["memory", "cuda", "alloc"]):
            raise RuntimeError("Likely out-of-VRAM or out-of-RAM error. " + err_str) from e
        else:
            traceback.print_exc()
            raise RuntimeError(err_str) from e

    # After generation
    offloadobj.unload_all()
    gc.collect()
    torch.cuda.empty_cache()

    if sample_frames is None:
        raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")

    # If teacache was used, we can see how many steps were skipped
    if trans.enable_cache:
        print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")

    # Save result
    sample_frames = sample_frames.cpu()  # shape = c, t, h, w => [3, T, H, W]
    os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)

    # Use the provided helper from your code to store the MP4
    # By default, you used cache_video(tensor=..., save_file=..., fps=16, ...)
    # or you can do your own. We'll do the same for consistency:
    cache_video(
        tensor=sample_frames[None],  # shape => [1, c, T, H, W]
        save_file=args.output_file,
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1)
    )

    end_time = time.time()
    elapsed_s = end_time - start_time
    print(f"Done! Output written to {args.output_file}. Generation time: {elapsed_s:.1f} seconds.")

if __name__ == "__main__":
    main()