import gradio as gr
import re
import subprocess
import time
from tqdm import tqdm
from huggingface_hub import snapshot_download
import torch

# Pin in-process torch work to CPU; the generation subprocess is likewise
# steered to CPU via the --t5_cpu and --offload_model flags below.
device = torch.device("cpu")

# Download model
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="./Wan2.1-T2V-1.3B"
)
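# snapshot_download skips files already present in local_dir, so repeated
# launches reuse the downloaded checkpoint instead of re-fetching it.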
print("Model downloaded successfully.")

def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    # Configuration:
    total_process_steps = 11          # Total INFO messages expected
    irrelevant_steps = 4              # First 4 INFO messages are ignored  
    relevant_steps = total_process_steps - irrelevant_steps  # 7 overall steps
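
    # These counts are empirical: they assume generate.py emits exactly 11 INFO
    # lines, the first 4 of which arrive before the stages worth tracking.
    # Adjust them if the upstream script's logging changes.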

    # Create overall progress bar (Level 1)
    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
                       ncols=120, dynamic_ncols=False, leave=True)
    processed_steps = 0

    # Regex for video generation progress (Level 3)
    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
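    # Matches tqdm-style progress lines from the child process, e.g.:
    #   " 40%|████      | 20/50 [00:55<01:22,  2.75s/it]"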
    video_progress_bar = None

    # Variables for sub-step progress bar (Level 2). sub_ticks is a placeholder
    # for incremental ticking; as written, each sub-bar is filled to completion
    # when the next INFO message (or the video phase) begins.
    sub_bar = None
    sub_ticks = 0
    sub_tick_total = 1500
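
    # Bar layout: position 0 = per-step video-generation bar, position 1 =
    # overall bar, position 2 = current sub-step bar; tqdm renders each
    # position on its own terminal row.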

    # Command to run the video generation
    command = [
        "python", "-u", "-m", "generate",  # using -u for unbuffered output
        "--task", "t2v-1.3B",
        "--size", "480*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--t5_cpu",
        "--offload_model", "True",  # Change from True (bool) to "True" (str)
        "--save_file", "generated_video.mp4"
    ]
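    # The flag set mirrors the Wan2.1 repository's generate.py CLI;
    # --sample_shift 8 with --sample_guide_scale 6 is the combination
    # suggested there for the 1.3B model.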
    print("Starting video generation process...")

    process = subprocess.Popen(command,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.STDOUT,
                               text=True,
                               bufsize=1)
    
    # Stream the subprocess logs line by line. stderr was merged into stdout
    # above, so process.stderr is None and a single stream carries everything;
    # text=True with bufsize=1 gives line-buffered reads suitable for live output.
    stdout = process.stdout
    while True:
        line = stdout.readline()
        if not line:
            break
        stripped_line = line.strip()
        if not stripped_line:
            continue

        # Check for video generation progress (Level 3)
        progress_match = progress_pattern.search(stripped_line)
        
        if progress_match:
            if sub_bar is not None:
                if sub_ticks < sub_tick_total:
                    sub_bar.update(sub_tick_total - sub_ticks)
                sub_bar.close()
                overall_bar.update(1)
                overall_bar.refresh()
                sub_bar = None
                sub_ticks = 0
            current = int(progress_match.group(2))
            total = int(progress_match.group(3))
            if video_progress_bar is None:
                video_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
                                          ncols=120, dynamic_ncols=True, leave=True)
            video_progress_bar.update(current - video_progress_bar.n)
            video_progress_bar.refresh()
            if video_progress_bar.n >= video_progress_bar.total:
                overall_bar.update(1)
                overall_bar.refresh()
                video_progress_bar.close()
                video_progress_bar = None
            continue

        # Process INFO messages (Level 2 sub-step)
        if "INFO:" in stripped_line:
            parts = stripped_line.split("INFO:", 1)
            msg = parts[1].strip() if len(parts) > 1 else ""
            print(f"[INFO]: {msg}")  # Log the message

            # For the first 4 INFO messages, simply count them.
            if processed_steps < irrelevant_steps:
                processed_steps += 1
                continue
            else:
                # A new relevant INFO message has arrived.
                if sub_bar is not None:
                    if sub_ticks < sub_tick_total:
                        sub_bar.update(sub_tick_total - sub_ticks)
                    sub_bar.close()
                    overall_bar.update(1)
                    overall_bar.refresh()
                    sub_bar = None
                    sub_ticks = 0
                # Start a new sub-step bar for the current INFO message.
                sub_bar = tqdm(total=sub_tick_total, desc=msg, position=2,
                               ncols=120, dynamic_ncols=False, leave=True)
                sub_ticks = 0
            continue
        else:
            print(stripped_line)

    # Drain any remaining output (normally a no-op, since the loop above reads
    # to EOF; kept as a safeguard)
    for line in process.stdout:
        print(line.strip())

    process.wait()

    # Finalize progress bars
    if video_progress_bar is not None:
        video_progress_bar.close()
    if sub_bar is not None:
        sub_bar.close()
    overall_bar.close()

    # Add log for successful video generation
    if process.returncode == 0:
        print("Video generation completed successfully.")
        return "generated_video.mp4"
    else:
        print("Error executing command.")
        raise Exception("Error executing command")

# Gradio UI to trigger inference
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1 1.3B")
        gr.Markdown("Enjoy this simple working UI, duplicate the space to skip the queue :)")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    submit_btn.click(
        fn=infer,
        inputs=[prompt],
        outputs=[video_res]
    )

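# Queueing is what lets gr.Progress / track_tqdm stream live updates to the client.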
demo.queue().launch(show_error=True, show_api=False, ssr_mode=False)