import os
os.environ['HF_HOME'] = os.path.abspath(
    os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))
)

import gradio as gr
import torch
import traceback
import einops
import safetensors.torch as sf
import numpy as np
import math
import spaces
from PIL import Image

# Diffusers models
from diffusers import AutoencoderKLHunyuanVideo

# Transformers models
from transformers import (
    LlamaModel, CLIPTextModel,
    LlamaTokenizerFast, CLIPTokenizer,
    AutoImageProcessor, CLIPImageProcessor, CLIPVisionModel
)

# Local helper modules
from diffusers_helper.hunyuan import (
    encode_prompt_conds, vae_decode,
    vae_encode, vae_decode_fake
)

from diffusers_helper.utils import (
    save_bcthw_as_mp4, crop_or_pad_yield_mask,
    soft_append_bcthw, resize_and_center_crop,
    state_dict_weighted_merge, state_dict_offset_merge,
    generate_timestamp
)

from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket

# Thread utilities
from diffusers_helper.thread_utils import AsyncStream, async_run

# Gradio progress bar utils
from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html

# Set device to CPU
device = torch.device("cpu")

# Load models
text_encoder = LlamaModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder='text_encoder',
    torch_dtype=torch.float16
).to(device)

text_encoder_2 = CLIPTextModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder='text_encoder_2',
    torch_dtype=torch.float16
).to(device)

tokenizer = LlamaTokenizerFast.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder='tokenizer'
)

tokenizer_2 = CLIPTokenizer.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder='tokenizer_2'
)

vae = AutoencoderKLHunyuanVideo.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder='vae',
    torch_dtype=torch.float16
).to(device)

# Image feature extractor (CLIPImageProcessor used here in place of the original SiglipImageProcessor)
feature_extractor = CLIPImageProcessor.from_pretrained(
    "lllyasviel/flux_redux_bfl",
    subfolder='feature_extractor'
)

image_encoder = CLIPVisionModel.from_pretrained(
    "lllyasviel/flux_redux_bfl",
    subfolder='image_encoder',
    torch_dtype=torch.float16,
    ignore_mismatched_sizes=True
).to(device)

transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
    'lllyasviel/FramePack_F1_I2V_HY_20250503',
    torch_dtype=torch.bfloat16
).to(device)

# Evaluation mode
vae.eval()
text_encoder.eval()
text_encoder_2.eval()
image_encoder.eval()
transformer.eval()

# Move to correct dtype
transformer.to(dtype=torch.bfloat16)
vae.to(dtype=torch.float16)
image_encoder.to(dtype=torch.float16)
text_encoder.to(dtype=torch.float16)
text_encoder_2.to(dtype=torch.float16)

# No gradient
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)
transformer.requires_grad_(False)

stream = AsyncStream()
outputs_folder = './outputs/'
os.makedirs(outputs_folder, exist_ok=True)

examples = [
    ["img_examples/1.png", "The girl dances gracefully, with clear movements, full of charm."],
    ["img_examples/2.jpg", "The man dances flamboyantly, swinging his hips and striking bold poses with dramatic flair."],
    ["img_examples/3.png", "The woman dances elegantly among the blossoms, spinning slowly with flowing sleeves and graceful hand movements."]
]

def generate_examples(input_image, prompt):
    t2v=False
    n_prompt=""
    seed=31337
    total_second_length=60
    latent_window_size=9
    steps=25
    cfg=1.0
    gs=10.0
    rs=0.0
    gpu_memory_preservation=6  # unused
    use_teacache=True
    mp4_crf=16
    global stream
    if t2v:
        default_height, default_width = 640, 640
        input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
        print("No input image provided. Using a blank white image.")
    yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
    stream = AsyncStream()
    async_run(
        worker, input_image, prompt, n_prompt, seed,
        total_second_length, latent_window_size, steps,
        cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
    )
    output_filename = None
    while True:
        flag, data = stream.output_queue.next()
        if flag == 'file':
            output_filename = data
            yield (
                output_filename,
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(interactive=False),
                gr.update(interactive=True)
            )
        elif flag == 'progress':
            preview, desc, html = data
            yield (
                gr.update(),
                gr.update(visible=True, value=preview),
                desc,
                html,
                gr.update(interactive=False),
                gr.update(interactive=True)
            )
        elif flag == 'end':
            yield (
                output_filename,
                gr.update(visible=False),
                gr.update(),
                '',
                gr.update(interactive=True),
                gr.update(interactive=False)
            )
            break

@torch.no_grad()
def worker(
    input_image, prompt, n_prompt, seed,
    total_second_length, latent_window_size, steps,
    cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
):
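    # Each latent section decodes to roughly latent_window_size * 4 frames of 30 fps video,
    # so derive how many sections are needed to cover the requested duration.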
    total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))
    job_id = generate_timestamp()
    stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
    try:
        llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        if cfg == 1:
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
        else:
            llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
        H, W, C = input_image.shape
        height, width = find_nearest_bucket(H, W, resolution=640)
        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
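        # Normalize to [-1, 1] and reshape HWC -> BCTHW (batch, channel, time, height, width).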
        input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
        input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
        start_latent = vae_encode(input_image_pt, vae).to(device)
        image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
        llama_vec = llama_vec.to(transformer.dtype).to(device)
        llama_vec_n = llama_vec_n.to(transformer.dtype).to(device)
        clip_l_pooler = clip_l_pooler.to(transformer.dtype).to(device)
        clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype).to(device)
        image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype).to(device)
        rnd = torch.Generator("cpu").manual_seed(seed)
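        # Rolling latent history: 16 "4x", 2 "2x" and 1 "1x" context slots, with the encoded
        # start image appended as the first generated latent frame.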
        history_latents = torch.zeros(
            size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
            dtype=torch.float32
        ).to(device)
        history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
        history_pixels = None  # decoded pixel history, filled in after the first section
        total_generated_latent_frames = 1
        for section_index in range(total_latent_sections):
            if stream.input_queue.top() == 'end':
                stream.output_queue.push(('end', None))
                return
            if use_teacache:
                transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
            else:
                transformer.initialize_teacache(enable_teacache=False)
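            # Per-step callback: cheap approximate decode of the current latents for the UI
            # preview, plus progress reporting and user-requested cancellation.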
            def callback(d):
                preview = d['denoised']
                preview = vae_decode_fake(preview)
                preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
                preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
                if stream.input_queue.top() == 'end':
                    stream.output_queue.push(('end', None))
                    raise KeyboardInterrupt('User ends the task.')
                current_step = d['i'] + 1
                percentage = int(100.0 * current_step / steps)
                hint = f'Sampling {current_step}/{steps}'
                desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}'
                stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
                return
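            # Conditioning index layout: [1 start latent | 16 4x-compressed history |
            # 2 2x history | 1 clean history | latent_window_size frames to denoise].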
            indices = torch.arange(
                0, sum([1, 16, 2, 1, latent_window_size])
            ).unsqueeze(0)
            (
                clean_latent_indices_start,
                clean_latent_4x_indices,
                clean_latent_2x_indices,
                clean_latent_1x_indices,
                latent_indices
            ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
            clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[
                :, :, -sum([16, 2, 1]):, :, :
            ].split([16, 2, 1], dim=2)
            clean_latents = torch.cat(
                [start_latent.to(history_latents), clean_latents_1x],
                dim=2
            )
            generated_latents = sample_hunyuan(
                transformer=transformer,
                sampler='unipc',
                width=width,
                height=height,
                frames=latent_window_size * 4 - 3,
                real_guidance_scale=cfg,
                distilled_guidance_scale=gs,
                guidance_rescale=rs,
                num_inference_steps=steps,
                generator=rnd,
                prompt_embeds=llama_vec,
                prompt_embeds_mask=llama_attention_mask,
                prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n,
                negative_prompt_embeds_mask=llama_attention_mask_n,
                negative_prompt_poolers=clip_l_pooler_n,
                device=device,
                dtype=torch.bfloat16,
                image_embeddings=image_encoder_last_hidden_state,
                latent_indices=latent_indices,
                clean_latents=clean_latents,
                clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x,
                clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x,
                clean_latent_4x_indices=clean_latent_4x_indices,
                callback=callback,
            )
            total_generated_latent_frames += int(generated_latents.shape[2])
            history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
            real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
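            # Decode the newest section and blend it into the running pixel history with an
            # overlapped soft append to avoid seams between sections.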
            if history_pixels is None:
                history_pixels = vae_decode(real_history_latents, vae).cpu()
            else:
                section_latent_frames = latent_window_size * 2
                overlapped_frames = latent_window_size * 4 - 3
                current_pixels = vae_decode(
                    real_history_latents[:, :, -section_latent_frames:], vae
                ).cpu()
                history_pixels = soft_append_bcthw(
                    history_pixels, current_pixels, overlapped_frames
                )
            output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
            save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
            stream.output_queue.push(('file', output_filename))
    except Exception:
        traceback.print_exc()
    stream.output_queue.push(('end', None))
    return

def get_duration(
    input_image, prompt, t2v, n_prompt,
    seed, total_second_length, latent_window_size,
    steps, cfg, gs, rs, gpu_memory_preservation,
    use_teacache, mp4_crf, quality_radio=None, aspect_ratio=None
):
    # Accept extra arguments for compatibility with process()
    return total_second_length * 60

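# ZeroGPU: the duration callable receives the same arguments as process() and returns the
# GPU time budget (in seconds) to request for this call.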
@spaces.GPU(duration=get_duration)
def process(
    input_image, prompt, t2v=False, n_prompt="", seed=31337,
    total_second_length=60, latent_window_size=9, steps=25,
    cfg=1.0, gs=10.0, rs=0.0, gpu_memory_preservation=6,
    use_teacache=True, mp4_crf=16, quality_radio="640x360", aspect_ratio="1:1"
):
    global stream
    
    # Map quality options to actual resolutions
    quality_map = {
        "360p": (640, 360),
        "480p": (854, 480),
        "540p": (960, 540),
        "720p": (1280, 720),
        "640x360": (640, 360),  # fallback
    }
    
    # Map aspect ratio strings to width/height ratios
    aspect_map = {
        "1:1": (1, 1),
        "3:4": (3, 4),
        "4:3": (4, 3),
        "16:9": (16, 9),
        "9:16": (9, 16),
    }

    # Get target resolution based on selected quality
    target_width, target_height = quality_map.get(quality_radio, (640, 360))
    
    if t2v:
        ar_w, ar_h = aspect_map.get(aspect_ratio, (1, 1))
        # Recalculate based on aspect ratio
        if ar_w >= ar_h:
            target_height = int(round(target_width * ar_h / ar_w))
        else:
            target_width = int(round(target_height * ar_w / ar_h))
        input_image = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255
        print(f"Using blank white image for text-to-video mode, {target_width}x{target_height} ({aspect_ratio})")
    else:
        # Resize and crop the input image to the nearest bucket for the selected resolution
        # and save a reference copy (the worker performs its own resize on the raw image).
        H, W, C = input_image.shape
        height, width = find_nearest_bucket(H, W, resolution=target_width)
        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{generate_timestamp()}.png'))

    yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
    stream = AsyncStream()
    async_run(
        worker, input_image, prompt, n_prompt, seed,
        total_second_length, latent_window_size, steps,
        cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
    )
    output_filename = None
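    # Relay worker events to the UI: 'file' streams each partial MP4, 'progress' updates the
    # preview and progress bar, and 'end' re-enables the start button.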
    while True:
        flag, data = stream.output_queue.next()
        if flag == 'file':
            output_filename = data
            yield (
                output_filename,
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(interactive=False),
                gr.update(interactive=True)
            )
        elif flag == 'progress':
            preview, desc, html = data
            yield (
                gr.update(),
                gr.update(visible=True, value=preview),
                desc,
                html,
                gr.update(interactive=False),
                gr.update(interactive=True)
            )
        elif flag == 'end':
            yield (
                output_filename,
                gr.update(visible=False),
                gr.update(),
                '',
                gr.update(interactive=True),
                gr.update(interactive=False)
            )
            break

def end_process():
    stream.input_queue.push('end')

quick_prompts = [
    'The girl dances gracefully, with clear movements, full of charm.',
    'A character doing some simple body movements.'
]
quick_prompts = [[x] for x in quick_prompts]


def make_custom_css():
    base_progress_css = make_progress_bar_css()
    extra_css = """
    body {
        background: #1a1b1e !important;
        font-family: "Noto Sans", sans-serif;
        color: #e0e0e0;
    }
    #title-container {
        text-align: center;
        padding: 20px 0;
        margin-bottom: 30px;
    }
    #title-container h1 {
        color: #4b9ffa;
        font-size: 2.5rem;
        margin: 0;
        font-weight: 800;
    }
    #title-container p {
        color: #e0e0e0;
    }
    .three-column-container {
        display: flex;
        gap: 20px;
        min-height: 800px;
        max-width: 1600px;
        margin: 0 auto;
    }
    .settings-panel {
        flex: 0 0 150px;
        background: #2a2b2e;
        padding: 12px;
        border-radius: 15px;
        border: 1px solid #3a3b3e;
    }
    .settings-panel .gr-slider {
        width: calc(100% - 10px) !important;
    }
    .settings-panel label {
        color: #e0e0e0 !important;
    }
    .settings-panel label span:first-child {
        font-size: 0.9rem !important;
    }
    .main-panel {
        flex: 1;
        background: #2a2b2e;
        padding: 20px;
        border-radius: 15px;
        border: 1px solid #3a3b3e;
        display: flex;
        flex-direction: column;
        gap: 20px;
    }
    .output-panel {
        flex: 1;
        background: #2a2b2e;
        padding: 20px;
        border-radius: 15px;
        border: 1px solid #3a3b3e;
        display: flex;
        flex-direction: column;
        align-items: center;  /* Center output content */
        gap: 20px;
    }
    .output-panel > div {
        width: 100%;
        max-width: 640px;  /* Limit width for better centering */
    }
    .settings-panel h3 {
        color: #4b9ffa;
        margin-bottom: 15px;
        font-size: 1.1rem;
        border-bottom: 2px solid #4b9ffa;
        padding-bottom: 8px;
    }
    .prompt-container {
        min-height: 200px;
    }
    .quick-prompts {
        margin-top: 10px;
        padding: 10px;
        background: #1a1b1e;
        border-radius: 10px;
    }
    .button-container {
        display: flex;
        gap: 10px;
        margin: 15px 0;
        justify-content: center;
        width: 100%;
    }
    /* Override Gradio's default light theme */
    .gr-box {
        background: #2a2b2e !important;
        border-color: #3a3b3e !important;
    }
    .gr-input, .gr-textbox {
        background: #1a1b1e !important;
        border-color: #3a3b3e !important;
        color: #e0e0e0 !important;
    }
    .gr-form {
        background: transparent !important;
        border: none !important;
    }
    .gr-label {
        color: #e0e0e0 !important;
    }
    .gr-button {
        background: #4b9ffa !important;
        color: white !important;
    }
    .gr-button.secondary-btn {
        background: #ff4d4d !important;
    }
    """
    return base_progress_css + extra_css

css = make_custom_css()

block = gr.Blocks(css=css).queue()
with block:
    with gr.Group(elem_id="title-container"):
        gr.Markdown("<h1>FramePack</h1>")
        gr.Markdown(
            """Generate amazing animations from a single image using AI. 
            Just upload an image, write a prompt, and watch the magic happen!"""
        )

    with gr.Row(elem_classes="three-column-container"):
        # Left Column - Settings
        with gr.Column(elem_classes="settings-panel"):
            gr.Markdown("### Generation Settings")
            
            with gr.Group():
                total_second_length = gr.Slider(
                    label="Duration (Seconds)",
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    info='Length of generated video'
                )
                steps = gr.Slider(
                    label="Quality Steps",
                    minimum=1,
                    maximum=100,
                    value=15,
                    step=1,
                    info='25-30 recommended'
                )
                gs = gr.Slider(
                    label="Animation Strength",
                    minimum=1.0,
                    maximum=32.0,
                    value=10.0,
                    step=0.1,
                    info='8-12 recommended'
                )
                quality_radio = gr.Radio(
                    label="Video Quality (Resolution)",
                    choices=["360p", "480p", "540p", "720p"],
                    value="640x360",
                    info="Choose output video resolution"
                )
                # Aspect ratio dropdown, hidden by default
                aspect_ratio = gr.Dropdown(
                    label="Aspect Ratio",
                    choices=["1:1", "3:4", "4:3", "16:9", "9:16"],
                    value="1:1",
                    visible=False,
                    info="Only applies to Text to Video mode"
                )
            
            gr.Markdown("### Advanced")
            with gr.Group():
                t2v = gr.Checkbox(
                    label='Text to Video Mode',
                    value=False,
                    info='Generate without input image'
                )
                use_teacache = gr.Checkbox(
                    label='Fast Mode',
                    value=True,
                    info='Faster but may affect details'
                )
                gpu_memory_preservation = gr.Slider(
                    label="VRAM Usage",
                    minimum=6,
                    maximum=128,
                    value=6,
                    step=1
                )
                seed = gr.Number(
                    label="Seed",
                    value=31337,
                    precision=0
                )

            # Hidden settings
            n_prompt = gr.Textbox(visible=False, value="")
            latent_window_size = gr.Slider(visible=False, value=9)
            cfg = gr.Slider(visible=False, value=1.0)
            rs = gr.Slider(visible=False, value=0.0)
            mp4_crf = gr.Number(visible=False, value=16)  # MP4 CRF (quality) passed to the worker

        # Middle Column - Main Content
        with gr.Column(elem_classes="main-panel"):
            input_image = gr.Image(
                label="Upload Your Image",
                type="numpy",
                height=320
            )
            
            # Generate / Stop buttons
            with gr.Group(elem_classes="button-container"):
                start_button = gr.Button(
                    value="▶️ Generate Animation",
                    elem_classes=["primary-btn"]
                )
                stop_button = gr.Button(
                    value="⏹️ Stop",
                    elem_classes=["secondary-btn"],
                    interactive=False
                )
            
            with gr.Group(elem_classes="prompt-container"):
                prompt = gr.Textbox(
                    label="Describe the animation you want",
                    placeholder="E.g., The character dances gracefully with flowing movements...",
                    lines=4
                )
                
                with gr.Group(elem_classes="quick-prompts"):
                    gr.Markdown("### 💡 Quick Prompts")
                    example_quick_prompts = gr.Dataset(
                        samples=quick_prompts,
                        label='Click to use',
                        samples_per_page=3,
                        components=[prompt]
                    )
            
        # Right Column - Output
        with gr.Column(elem_classes="output-panel"):
            preview_image = gr.Image(
                label="Generation Preview",
                height=200,
                visible=False
            )
            result_video = gr.Video(
                label="Generated Animation",
                autoplay=True,
                show_share_button=True,
                height=400,
                loop=True
            )
            with gr.Group(elem_classes="progress-container"):
                progress_desc = gr.Markdown(
                    elem_classes='no-generating-animation'
                )
                progress_bar = gr.HTML(
                    elem_classes='no-generating-animation'
                )

    # Setup callbacks
    ips = [
        input_image, prompt, t2v, n_prompt, seed,
        total_second_length, latent_window_size,
        steps, cfg, gs, rs, gpu_memory_preservation,
        use_teacache, mp4_crf,
        quality_radio, aspect_ratio
    ]
    
    start_button.click(
        fn=process,
        inputs=ips,
        outputs=[
            result_video, preview_image,
            progress_desc, progress_bar,
            start_button, stop_button
        ]
    )
    
    stop_button.click(fn=end_process)
    
    example_quick_prompts.click(
        fn=lambda x: x[0],
        inputs=[example_quick_prompts],
        outputs=prompt,
        show_progress=False,
        queue=False
    )

    # Show/hide aspect ratio dropdown based on t2v checkbox
    def show_aspect_ratio(t2v_checked):
        return gr.update(visible=bool(t2v_checked))
    t2v.change(
        fn=show_aspect_ratio,
        inputs=[t2v],
        outputs=[aspect_ratio],
        queue=False
    )

block.launch(share=True)