import os
import time
import argparse
import json
import torch
import traceback
import gc
import random

# These imports rely on your existing code structure
# They must match the location of your WAN code, etc.
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.modules.attention import get_attention_modes
from wan.utils.utils import cache_video
from mmgp import offload, safetensors2, profile_type

# triton is optional: import it when it is installed and carry on silently
# when it is not.
try:
    import triton
except ImportError:
    pass

# Local checkpoint folder. It is passed as `checkpoint_dir` when the model is
# loaded, is the default download target, and prefixes the default
# --transformer-file / --text-encoder-file paths.
DATA_DIR = "ckpts"

# --------------------------------------------------
# HELPER FUNCTIONS
# --------------------------------------------------

def sanitize_file_name(file_name):
    """Return `file_name` with path-unsafe characters removed.

    The characters dropped are: slash, backslash, colon, pipe, question
    mark, both angle brackets and the double quote. Every other character is
    kept unchanged and in order.
    """
    # A translation table that maps each unwanted character to "delete".
    strip_table = str.maketrans("", "", '/\\:|?<>"')
    return file_name.translate(strip_table)

def extract_preset(lset_name, lora_dir, loras):
    """
    Read a ``.lset`` JSON preset from `lora_dir` and resolve it against `loras`.

    The preset must contain a "loras" list (file names relative to
    `lora_dir`) and a "loras_mult" entry; "prompt" and "full_prompt" are
    optional and default to "" and False.

    Returns a 4-tuple: the indices of the preset's LoRAs inside `loras`
    (as strings), the raw "loras_mult" value, the prompt text, and the
    "full_prompt" flag.

    Raises ValueError when the preset file does not exist or when any LoRA
    it lists is not present in `loras`.
    """
    lset_name = sanitize_file_name(lset_name)
    file_name = lset_name if lset_name.endswith(".lset") else lset_name + ".lset"
    preset_path = os.path.join(lora_dir, file_name)

    if not os.path.isfile(preset_path):
        raise ValueError(f"Preset '{lset_name}' not found in {lora_dir}")

    with open(preset_path, "r", encoding="utf-8") as handle:
        preset = json.load(handle)

    selected_indices = []
    not_found = []
    for lora_file in preset["loras"]:
        # Entries in `loras` are compared as lora_dir-joined paths.
        candidate = os.path.join(lora_dir, lora_file)
        try:
            selected_indices.append(str(loras.index(candidate)))
        except ValueError:
            not_found.append(lora_file)

    if not_found:
        missing_list = ", ".join(not_found)
        raise ValueError(f"Missing LoRA files for preset: {missing_list}")

    return (
        selected_indices,
        preset["loras_mult"],
        preset.get("prompt", ""),
        preset.get("full_prompt", False),
    )

def get_attention_mode(args_attention, installed_modes):
    """
    Resolve the attention mode to use.

    An explicit request is returned as-is when it is in `installed_modes`
    and rejected with ValueError otherwise. The special value "auto" picks
    the first available of sage2, sage, sdpa, and falls back to "sdpa" when
    none of them is listed.
    """
    if args_attention != "auto":
        if args_attention not in installed_modes:
            raise ValueError(
                f"Requested attention mode '{args_attention}' not installed. "
                f"Installed modes: {installed_modes}"
            )
        return args_attention

    preference_order = ("sage2", "sage", "sdpa")
    available = (mode for mode in preference_order if mode in installed_modes)
    return next(available, "sdpa")

def load_i2v_model(model_filename, text_encoder_filename, is_720p):
    """
    Build the WanI2V model and the component dict used for offloading.

    Both resolutions use the 'i2v-14B' config and the same constructor
    arguments; `is_720p` only selects which message is printed.

    Returns `(wan_model, pipe)` where `pipe` maps "transformer",
    "text_encoder", "text_encoder_2" and "vae" to the matching sub-models
    of `wan_model`.
    """
    print(
        "Loading 14B-720p i2v model ..."
        if is_720p
        else "Loading 14B-480p i2v model ..."
    )

    wan_model = wan.WanI2V(
        config=WAN_CONFIGS['i2v-14B'],
        checkpoint_dir=DATA_DIR,
        model_filename=model_filename,
        text_encoder_filename=text_encoder_filename,
    )

    # Components exposed by name to the offload/profile machinery.
    components = {
        "transformer": wan_model.model,
        "text_encoder": wan_model.text_encoder.model,
        "text_encoder_2": wan_model.clip.model,
        "vae": wan_model.vae.model,
    }
    return wan_model, components

def setup_loras(pipe, lora_dir, lora_preset, num_inference_steps):
    """
    Discover LoRA files in `lora_dir`, register them on the transformer
    (inactive), and optionally read defaults from a preset.

    Returns a 6-tuple:
        (lora paths, lora names without extension, preset LoRA choices,
         preset multipliers as one space-separated string, preset prompt,
         preset full-prompt flag)

    When `lora_dir` is empty or not a directory nothing is loaded and
    ([], [], [], "", "", False) is returned. `num_inference_steps` is
    accepted but not used here.
    """
    from pathlib import Path
    import glob

    if not (lora_dir and Path(lora_dir).is_dir()):
        print("No valid --lora-dir provided or directory doesn't exist, skipping LoRA setup.")
        return [], [], [], "", "", False

    # Collect *.sft and *.safetensors files, sorted by path.
    lora_files = sorted(
        path
        for pattern in ("*.sft", "*.safetensors")
        for path in glob.glob(os.path.join(lora_dir, pattern))
    )
    lora_stems = [Path(path).stem for path in lora_files]

    # Register every LoRA on the transformer without activating any.
    offload.load_loras_into_model(pipe["transformer"], lora_files, activate_all_loras=False)

    if not lora_preset:
        return lora_files, lora_stems, [], "", "", False

    choices, mult, prefix, full_prompt = extract_preset(lora_preset, lora_dir, lora_files)
    # The preset may store multipliers as a list or as a single value;
    # either way hand back one string.
    if isinstance(mult, list):
        mult_str = " ".join(str(item) for item in mult)
    else:
        mult_str = str(mult)

    return lora_files, lora_stems, choices, mult_str, prefix, full_prompt

def parse_loras_and_activate(
    transformer,
    loras,
    loras_choices,
    loras_mult_str,
    num_inference_steps
):
    """
    Parse the multiplier string and activate the chosen LoRAs on `transformer`.

    `loras_mult_str` is a space- or newline-separated list of multipliers;
    lines starting with "#" are ignored. Each entry is either one float
    ("0.7") or a comma-separated list ("0.5,0.8"), which is expanded over
    `num_inference_steps`. Missing trailing multipliers default to 1.0.

    Does nothing when `loras` or `loras_choices` is empty. Raises ValueError
    for an entry that is not a float or a comma-separated list of floats.
    """
    if not (loras and loras_choices):
        # no LoRAs selected
        return

    def parse_values(token):
        """Return the floats of a comma-separated token, or None if malformed."""
        if not token:
            return None
        try:
            return [float(part.strip()) for part in token.split(",")]
        except ValueError:
            return None

    # Keep non-empty, non-comment lines and flatten them into one line.
    kept_lines = []
    for raw_line in loras_mult_str.replace("\r", "\n").split("\n"):
        stripped = raw_line.strip()
        if stripped and not stripped.startswith("#"):
            kept_lines.append(stripped)
    flattened = " ".join(kept_lines)  # "1.0 2.0,3.0"
    tokens = flattened.split(" ") if flattened.strip() else []

    resolved = []
    for token in tokens:
        token = token.strip()
        if not token:
            continue
        values = parse_values(token)
        if values is None:
            raise ValueError(f"Invalid LoRA multiplier: '{token}'")
        if "," in token:
            # Per-step schedule: spread the listed values over all steps.
            resolved.append(expand_list_over_steps(values, num_inference_steps))
        else:
            resolved.append(values[0])

    # Fewer multipliers than chosen LoRAs => the rest run at 1.0
    shortfall = len(loras_choices) - len(resolved)
    if shortfall > 0:
        resolved.extend([1.0] * shortfall)

    offload.activate_loras(transformer, loras_choices, resolved)

def expand_list_over_steps(short_list, num_steps):
    """
    Stretch `short_list` into a list of length `num_steps`.

    Step ``i`` takes the element at index ``floor(i * len(short_list) /
    num_steps)``, so each value covers a contiguous, equally sized run of
    steps (e.g. ``[0.5, 0.8]`` over 4 steps -> ``[0.5, 0.5, 0.8, 0.8]``).

    The index is computed with integer arithmetic. Accumulating a float
    increment instead drifts: ten additions of 0.1 give 0.9999999999999999,
    which made the first value of a 3-item list cover 11 of 30 steps.

    Returns an empty list when `num_steps` is zero or negative. Raises
    IndexError when `short_list` is empty and `num_steps` is positive.
    """
    count = len(short_list)
    return [short_list[(step * count) // num_steps] for step in range(num_steps)]

def download_models_if_needed(transformer_filename_i2v, text_encoder_filename, local_folder=DATA_DIR):
    """
    Make sure the files needed for i2v exist under `local_folder`,
    downloading any that are missing from the Hugging Face Hub.

    The required set is the VAE weights, the open-clip xlm-roberta-large
    weights, and the base names of the requested text encoder and
    transformer files. The "xlm-roberta-large" sub-folder is fetched as a
    snapshot when it does not exist locally.

    Raises ImportError when `huggingface_hub` is not installed.
    """
    import os
    from pathlib import Path

    try:
        from huggingface_hub import hf_hub_download, snapshot_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required for automatic model download. "
            "Please install it via `pip install huggingface_hub`."
        ) from e

    repo_id = "DeepBeepMeep/Wan2.1"
    target_root = local_folder

    # Only the file-name part of each path is looked up in the repo and in
    # `target_root`. If the i2v config references additional files, add
    # them to this list.
    required_files = [
        "Wan2.1_VAE.pth",
        "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        os.path.basename(text_encoder_filename),
        os.path.basename(transformer_filename_i2v),
    ]

    # Whole sub-folder, fetched in one go when absent.
    snapshot_subdir = "xlm-roberta-large"
    if not Path(target_root, snapshot_subdir).exists():
        snapshot_download(repo_id=repo_id, allow_patterns=snapshot_subdir + "/*", local_dir=target_root)

    for filename in required_files:
        if os.path.isfile(os.path.join(target_root, filename)):
            # Already present
            continue
        print(f"File '{filename}' not found locally. Downloading from {repo_id} ...")
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=target_root)

    print("All required i2v files are present.")


# --------------------------------------------------
# ARGUMENT PARSER
# --------------------------------------------------

def parse_args():
    """
    Define the command-line interface of the i2v script and parse it.

    `--prompt` and `--input-image` are required; every other option has a
    default. Returns the populated `argparse.Namespace`.
    """
    cli = argparse.ArgumentParser(description="Image-to-Video inference using WAN 2.1 i2v")
    add = cli.add_argument

    # Model + Tools
    add("--quantize-transformer", action="store_true",
        help="Use on-the-fly transformer quantization")
    add("--compile", action="store_true",
        help="Enable PyTorch 2.0 compile for the transformer")
    add("--attention", type=str, default="auto",
        help="Which attention to use: auto, sdpa, sage, sage2, flash")
    add("--profile", type=int, default=4,
        help="Memory usage profile number [1..5]; see original script or use 2 if you have low VRAM")
    add("--preload", type=int, default=0,
        help="Megabytes of the diffusion model to preload in VRAM (only used in some profiles)")
    add("--verbose", type=int, default=1,
        help="Verbosity level [0..5]")

    # i2v Model
    add("--transformer-file", type=str,
        default=f"{DATA_DIR}/wan2.1_image2video_480p_14B_quanto_int8.safetensors",
        help="Which i2v model to load")
    add("--text-encoder-file", type=str,
        default=f"{DATA_DIR}/models_t5_umt5-xxl-enc-quanto_int8.safetensors",
        help="Which text encoder to use")

    # LoRA
    add("--lora-dir", type=str, default="",
        help="Path to a directory containing i2v LoRAs")
    add("--lora-preset", type=str, default="",
        help="A .lset preset name in the lora_dir to auto-apply")

    # Generation Options
    add("--prompt", type=str, default=None, required=True,
        help="Prompt for generation")
    add("--negative-prompt", type=str, default="",
        help="Negative prompt")
    add("--resolution", type=str, default="832x480",
        help="WxH")
    add("--frames", type=int, default=64,
        help="Number of frames (16=1s if fps=16). Must be multiple of 4 +/- 1 in WAN.")
    add("--steps", type=int, default=30,
        help="Number of denoising steps.")
    add("--guidance-scale", type=float, default=5.0,
        help="Classifier-free guidance scale")
    add("--flow-shift", type=float, default=3.0,
        help="Flow shift parameter. Generally 3.0 for 480p, 5.0 for 720p.")
    add("--riflex", action="store_true",
        help="Enable RIFLEx for longer videos")
    add("--teacache", type=float, default=0.25,
        help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
    add("--teacache-start", type=float, default=0.1,
        help="Teacache start step percentage [0..100]")
    add("--seed", type=int, default=-1,
        help="Random seed. -1 means random each time.")
    add("--slg-layers", type=str, default=None,
        help="Which layers to use for skip layer guidance")
    add("--slg-start", type=float, default=0.0,
        help="Percentage in to start SLG")
    add("--slg-end", type=float, default=1.0,
        help="Percentage in to end SLG")

    # LoRA usage
    add("--loras-choices", type=str, default="",
        help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
    add("--loras-mult", type=str, default="",
        help="Multipliers for each chosen LoRA. Example: '1.0 1.2,1.3' etc.")

    # Input / output
    add("--input-image", type=str, default=None, required=True,
        help="Path to an input image (or multiple).")
    add("--output-file", type=str, default="output.mp4",
        help="Where to save the resulting video.")

    return cli.parse_args()

# --------------------------------------------------
# MAIN
# --------------------------------------------------

def main():
    """
    Run one image-to-video generation driven by the command-line arguments.

    Steps, in order: parse and validate the arguments, choose the attention
    mode, make sure the model files are present, load the model and apply
    the offload profile, set up LoRAs and the final prompt, configure
    TeaCache and VAE tiling, call `wan_model.generate`, and write the
    result with `cache_video`.

    Raises:
        ValueError: missing input image, unparsable resolution, resolution
            above 832x480 with a "480p" transformer file, unavailable
            attention mode, or TeaCache enabled for a transformer file name
            containing neither "480p" nor "720p".
        RuntimeError: generation raised, or returned no frames.
    """
    args = parse_args()

    # ---- Validate everything that depends only on the arguments first, so
    # ---- a bad value fails before the model download and load.
    if not os.path.isfile(args.input_image):
        raise ValueError(f"Input image does not exist: {args.input_image}")

    # Accept "832x480", "832X480" and (with a warning) "832*480".
    resolution_str = args.resolution.lower()
    if "*" in resolution_str:
        print("ERROR: resolution must be e.g. 832x480 not '832*480'. Fixing it.")
        resolution_str = resolution_str.replace("*", "x")
    try:
        # ValueError covers both a non-integer part and a wrong number of parts.
        width, height = [int(x) for x in resolution_str.split("x")]
    except ValueError as err:
        raise ValueError(f"Invalid resolution: '{resolution_str}'") from err

    # Comma-separated layer indices -> list of ints (None when not given).
    slg_list = [int(x) for x in args.slg_layers.split(",")] if args.slg_layers else None

    # A 480p model cannot exceed the 832x480 area.
    if "480p" in args.transformer_file and width * height > 832 * 480:
        raise ValueError("You must use the 720p i2v model to generate bigger than 832x480.")

    # ---- Environment
    offload.default_verboseLevel = args.verbose
    installed_attn_modes = get_attention_modes()
    chosen_attention = get_attention_mode(args.attention, installed_attn_modes)
    offload.shared_state["_attention"] = chosen_attention

    # The resolution variant is inferred from the transformer file name.
    is_720p = "720" in args.transformer_file

    # Make sure we have the needed models locally
    download_models_if_needed(args.transformer_file, args.text_encoder_file)

    # ---- Load i2v
    wan_model, pipe = load_i2v_model(
        model_filename=args.transformer_file,
        text_encoder_filename=args.text_encoder_file,
        is_720p=is_720p
    )
    wan_model._interrupt = False

    # ---- Offload / profile: profiles 2 and 4 get per-component budgets
    # ---- (--preload, in MB, replaces the transformer budget when non-zero);
    # ---- profile 3 gets a single percentage budget.
    kwargs = {}
    if args.profile in (2, 4):
        kwargs["budgets"] = {
            "transformer": 100 if args.preload == 0 else args.preload,
            "text_encoder": 100,
            "*": 1000,
        }
    elif args.profile == 3:
        kwargs["budgets"] = {"*": "70%"}

    compile_choice = "transformer" if args.compile else ""
    offloadobj = offload.profile(
        pipe,
        profile_no=args.profile,
        compile=compile_choice,
        quantizeTransformer=args.quantize_transformer,
        **kwargs
    )

    # ---- LoRAs
    (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        preset_prompt_prefix,
        preset_full_prompt
    ) = setup_loras(pipe, args.lora_dir, args.lora_preset, args.steps)

    # Combine the user prompt with the preset prompt: a "full prompt" preset
    # replaces the user prompt, otherwise the preset text is a prefix line.
    if preset_prompt_prefix:
        if preset_full_prompt:
            user_prompt = preset_prompt_prefix
        else:
            user_prompt = preset_prompt_prefix + "\n" + args.prompt
    else:
        user_prompt = args.prompt

    # Explicit --loras-choices win over the preset's choices.
    if args.loras_choices:
        lora_choice_list = [x.strip() for x in args.loras_choices.split(",")]
    else:
        lora_choice_list = default_loras_choices

    parse_loras_and_activate(
        pipe["transformer"], loras, lora_choice_list, args.loras_mult or default_loras_multis_str, args.steps
    )

    negative_prompt = args.negative_prompt or ""

    # ---- Seed: any negative value means "pick one at random".
    if args.seed < 0:
        args.seed = random.randint(0, 999999999)
    print(f"Using seed={args.seed}")

    # ---- TeaCache: coefficients are selected from the transformer file name.
    trans = wan_model.model
    trans.enable_cache = (args.teacache > 0)
    if trans.enable_cache:
        if "480p" in args.transformer_file:
            trans.coefficients = [-3.02331670e+02,  2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
        elif "720p" in args.transformer_file:
            trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
        else:
            raise ValueError("Teacache not supported for this model variant")

    # ---- Generation
    print("Starting generation ...")
    start_time = time.time()

    from PIL import Image
    input_img = Image.open(args.input_image).convert("RGB")

    # Round the requested frame count down to a multiple of 4, then add 1,
    # so it always has the form 4*N+1 (e.g. 81, 65, ...).
    frame_count = (args.frames // 4) * 4 + 1

    # Reset the TeaCache counters before the run.
    if trans.enable_cache:
        trans.teacache_counter = 0
        trans.teacache_multiplier = args.teacache
        # NOTE(review): --teacache-start is divided by 100 here, so the
        # default 0.1 means 0.1% of the steps - confirm that is intended.
        trans.cache_start_step = int(args.teacache_start * args.steps / 100.0)
        trans.num_steps = args.steps
        trans.teacache_skipped_steps = 0
        trans.previous_residual_uncond = None
        trans.previous_residual_cond = None

    # VAE tiling, chosen from the total memory (in MB) of CUDA device 0:
    # no tiling from 28000 MB up, 256 from 8000 MB up, 128 below that.
    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
    if device_mem_capacity >= 28000:  # 81 frames 720p requires about 28 GB VRAM
        VAE_tile_size = 0
    elif device_mem_capacity >= 8000:
        VAE_tile_size = 256
    else:
        VAE_tile_size = 128

    print('Using VAE tile size of', VAE_tile_size)

    try:
        sample_frames = wan_model.generate(
            input_prompt=user_prompt,
            image_start=input_img,
            frame_num=frame_count,
            width=width,
            height=height,
            shift=args.flow_shift,
            sampling_steps=args.steps,
            guide_scale=args.guidance_scale,
            n_prompt=negative_prompt,
            seed=args.seed,
            offload_model=False,
            callback=None,
            enable_RIFLEx=args.riflex,
            VAE_tile_size=VAE_tile_size,
            joint_pass=slg_list is None,  # joint pass only when SLG is not in use
            slg_layers=slg_list,
            slg_start=args.slg_start,
            slg_end=args.slg_end,
        )
    except Exception as e:
        # Release model memory before reporting the failure.
        offloadobj.unload_all()
        gc.collect()
        torch.cuda.empty_cache()

        err_str = f"Generation failed with error: {e}"
        # Heuristic out-of-memory detection based on the error text.
        s = str(e).lower()
        if any(keyword in s for keyword in ["memory", "cuda", "alloc"]):
            raise RuntimeError("Likely out-of-VRAM or out-of-RAM error. " + err_str) from e
        traceback.print_exc()
        raise RuntimeError(err_str) from e

    # After generation
    offloadobj.unload_all()
    gc.collect()
    torch.cuda.empty_cache()

    if sample_frames is None:
        raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")

    if trans.enable_cache:
        print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")

    # ---- Save result
    sample_frames = sample_frames.cpu()
    # dirname is "" for a bare file name; fall back to the current directory.
    os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)

    cache_video(
        tensor=sample_frames[None],  # add a leading batch dimension
        save_file=args.output_file,
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1)
    )

    end_time = time.time()
    elapsed_s = end_time - start_time
    print(f"Done! Output written to {args.output_file}. Generation time: {elapsed_s:.1f} seconds.")

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()