import os
import sys
import time

import torch

# Put this script's directory and its two nearest ancestor directories on
# sys.path (each added only once), so the ``cogvideox`` imports below can be
# resolved through sys.path rather than relying on an installed package.
current_file_path = os.path.abspath(__file__)

# Walk upward three levels: the script's dir, its parent, its grandparent.
project_roots = []
_ancestor = current_file_path
for _ in range(3):
    _ancestor = os.path.dirname(_ancestor)
    project_roots.append(_ancestor)

# Each new entry is prepended, so the last root processed ends up first.
for project_root in project_roots:
    if project_root in sys.path:
        continue
    sys.path.insert(0, project_root)

from cogvideox.api.api import (
    infer_forward_api,
    update_diffusion_transformer_api,
    update_edition_api
)
from cogvideox.ui.controller import flow_scheduler_dict
from cogvideox.ui.wan_fun_ui import ui, ui_eas, ui_modelscope

if __name__ == "__main__":
    # Local import: only needed by the launcher to probe Gradio's launch() API.
    import inspect

    # --- Configuration ---

    # Choose the UI mode: "eas", "modelscope", or any other value to use the
    # default ``ui`` builder.
    ui_mode = "eas"

    # GPU memory mode (passed to the "modelscope" and default builders; the
    # "eas" builder does not receive it). Choices are:
    #   - "model_cpu_offload"
    #   - "model_cpu_offload_and_qfloat8"
    #   - "sequential_cpu_offload"
    GPU_memory_mode = "model_cpu_offload"

    # Weight dtype: bfloat16 when CUDA is available and the device supports
    # it, otherwise float16.
    weight_dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else torch.float16
    )

    # Path to your OmegaConf config for WAN2.1 (relative paths resolve against
    # the current working directory).
    config_path = "config/wan2.1/wan_civitai.yaml"

    # Server binding for Gradio
    server_name = "0.0.0.0"
    server_port = 7860

    # Model parameters. model_name and savedir_sample are used by both the
    # "modelscope" and "eas" builders; model_type only by "modelscope".
    model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
    model_type = "Inpaint"   # or "Control"
    savedir_sample = "samples"

    # --- Initialize UI & Controller ---

    if ui_mode == "modelscope":
        demo, controller = ui_modelscope(
            model_name,
            model_type,
            savedir_sample,
            GPU_memory_mode,
            flow_scheduler_dict,
            weight_dtype,
            config_path
        )
    elif ui_mode == "eas":
        demo, controller = ui_eas(
            model_name,
            flow_scheduler_dict,
            savedir_sample,
            config_path
        )
    else:
        demo, controller = ui(
            GPU_memory_mode,
            flow_scheduler_dict,
            weight_dtype,
            config_path
        )

    # --- Launch Gradio app ---

    launch_kwargs = dict(
        # share=False: no public share link (local/Colab use).
        share=False,
        server_name=server_name,
        server_port=server_port,
        # Don't block here: the API routes are mounted on the returned app
        # below, and the loop at the end keeps the process alive.
        prevent_thread_lock=True,
    )
    # Disable server-side rendering to avoid 405 errors. Only Gradio versions
    # whose launch() exposes ``ssr_mode`` accept it; passing it unconditionally
    # would raise TypeError on older releases, so probe the signature first.
    if "ssr_mode" in inspect.signature(demo.launch).parameters:
        launch_kwargs["ssr_mode"] = False

    # launch() returns a 3-tuple here; only the first element (the app the
    # API helpers attach to) is used.
    app, _, _ = demo.queue(status_update_rate=1).launch(**launch_kwargs)

    # --- Mount API endpoints ---

    infer_forward_api(None, app, controller)
    update_diffusion_transformer_api(None, app, controller)
    update_edition_api(None, app, controller)

    # Keep the script alive (launch() did not block). On Ctrl+C, shut the
    # Gradio server down instead of exiting with a raw traceback.
    try:
        while True:
            time.sleep(5)
    except KeyboardInterrupt:
        demo.close()