import os
import random
import shutil
import subprocess
from typing import List

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from inference_ig2mv_sdxl import (
    prepare_pipeline,
    preprocess_image,
    remove_bg,
    run_pipeline,
)
from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image

# Install extra runtime dependencies needed later by the texture pipeline.
subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16
MAX_SEED = np.iinfo(np.int32).max
NUM_VIEWS = 6
HEIGHT = 768
WIDTH = 768

TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)


HEADER = """
# 🔮 Image to Texture with [MV-Adapter](https://github.com/huanngzh/MV-Adapter)
## State-of-the-art Open Source Texture Generation Using a Multi-View Diffusion Model
<p style="font-size: 1.1em;">By <a href="https://www.tripo3d.ai/" style="color: #1E90FF; text-decoration: none; font-weight: bold;">Tripo</a></p>
"""

EXAMPLES = [
    ["examples/001.jpeg", "examples/001.glb"],
    ["examples/002.jpeg", "examples/002.glb"],
]

# Build the MV-Adapter pipeline on top of SDXL with the fp16-fixed VAE.
pipe = prepare_pipeline(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    unet_model=None,
    lora_model=None,
    adapter_path="huanngzh/mv-adapter",
    scheduler=None,
    num_views=NUM_VIEWS,
    device=DEVICE,
    dtype=DTYPE,
)
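# BiRefNet segmentation model, used to remove the reference image background.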
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(DEVICE)
# BiRefNet preprocessing: resize to 1024x1024, convert to tensor, and apply
# ImageNet normalization.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def remove_bg_fn(image):
    return remove_bg(image, birefnet, transform_image, DEVICE)

if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
    hf_hub_download(
        "dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints"
    )
if not os.path.exists("checkpoints/big-lama.pt"):
    subprocess.run(
        "wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
        shell=True,
        check=True,
    )


def start_session(req: gr.Request):
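    """Create a per-session temp directory keyed by the Gradio session hash."""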
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(save_dir, exist_ok=True)
    print("start session, mkdir", save_dir)


def end_session(req: gr.Request):
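    """Remove the session's temp directory when the session ends."""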
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(save_dir, ignore_errors=True)


def get_random_hex():
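    """Return an 8-byte random hex string used to make file names unique."""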
    random_bytes = os.urandom(8)
    random_hex = random_bytes.hex()
    return random_hex


def get_random_seed(randomize_seed, seed):
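    """Return a random seed when randomization is enabled, else the given seed."""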
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@spaces.GPU(duration=90)
@torch.no_grad()
def run_mvadapter(
    mesh_path,
    prompt,
    image,
    seed=42,
    guidance_scale=3.0,
    num_inference_steps=30,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    progress=gr.Progress(track_tqdm=True),
):
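    """Generate multi-view images of the input mesh with MV-Adapter, guided by
    the reference image and the optional text prompt."""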
    # Preprocess the reference image: load it if a path was given, remove the
    # background with BiRefNet, then resize/pad to the target resolution.
    image = Image.open(image).convert("RGB") if isinstance(image, str) else image
    image = remove_bg_fn(image)
    image = preprocess_image(image, HEIGHT, WIDTH)

    # The seed may arrive as a string from the UI; fall back to 42 if it
    # cannot be parsed.
    if isinstance(seed, str):
        try:
            seed = int(seed.strip())
        except ValueError:
            seed = 42

    images, _, _, _ = run_pipeline(
        pipe,
        mesh_path=mesh_path,
        num_views=NUM_VIEWS,
        text=prompt,
        image=image,
        height=HEIGHT,
        width=WIDTH,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        remove_bg_fn=None,  # the background was already removed above
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        device=DEVICE,
    )

    torch.cuda.empty_cache()

    return images, image


@spaces.GPU(duration=90)
@torch.no_grad()
def run_texturing(
    mesh_path: str,
    mv_images: List[Image.Image],
    uv_unwarp: bool,
    preprocess_mesh: bool,
    uv_size: int,
    req: gr.Request,
):
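    """Bake the generated multi-view images into a texture and export a GLB."""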
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
    # The gallery yields (image, caption) tuples; keep only the images and
    # save them as a single-row grid for the texture pipeline.
    mv_images = [item[0] for item in mv_images]
    make_image_grid(mv_images, rows=1).save(mv_image_path)

    # Imported lazily so the heavy texturing dependencies (installed at
    # runtime above) are only loaded when a texture is actually requested.
    from texture import ModProcessConfig, TexturePipeline

    texture_pipe = TexturePipeline(
        upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
        inpaint_ckpt_path="checkpoints/big-lama.pt",
        device=DEVICE,
    )

    textured_glb_path = texture_pipe(
        mesh_path=mesh_path,
        save_dir=save_dir,
        save_name=f"texture_mesh_{get_random_hex()}",
        uv_unwarp=uv_unwarp,
        preprocess_mesh=preprocess_mesh,
        uv_size=uv_size,
        rgb_path=mv_image_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        # Offset the per-view azimuths by -90 degrees to match the texture
        # pipeline's camera convention.
        camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
    ).shaded_model_save_path

    torch.cuda.empty_cache()

    return textured_glb_path, textured_glb_path


with gr.Blocks(title="MVAdapter") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_mesh = gr.Model3D(label="Input 3D mesh")
                image_prompt = gr.Image(label="Input Image", type="pil")

            with gr.Accordion("Generation Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt (Optional)",
                    placeholder="Enter your prompt",
                    value="high quality",
                )
                seed = gr.Slider(
                    label="Seed", minimum=0, maximum=MAX_SEED, step=0, value=0
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=25,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=3.0,
                )
                reference_conditioning_scale = gr.Slider(
                    label="Image conditioning scale",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )

            with gr.Accordion("Texture Settings", open=False):
                with gr.Row():
                    uv_unwarp = gr.Checkbox(label="Unwrap UV", value=True)
                    preprocess_mesh = gr.Checkbox(label="Preprocess Mesh", value=False)
                uv_size = gr.Slider(
                    label="UV Size", minimum=1024, maximum=8192, step=512, value=4096
                )

            gen_button = gr.Button("Generate Texture", variant="primary")

            examples = gr.Examples(
                examples=EXAMPLES,
                inputs=[image_prompt, input_mesh],
            )

        with gr.Column():
            mv_result = gr.Gallery(
                label="Multi-View Results",
                show_label=False,
                columns=[3],
                rows=[2],
                object_fit="contain",
                height="auto",
                type="pil",
            )
            textured_model_output = gr.Model3D(label="Textured GLB", interactive=False)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    # Event chain: resolve the seed, generate the multi-view images, bake the
    # texture onto the mesh, then enable the download button.
    gen_button.click(
        get_random_seed, inputs=[randomize_seed, seed], outputs=[seed]
    ).then(
        run_mvadapter,
        inputs=[
            input_mesh,
            prompt,
            image_prompt,
            seed,
            guidance_scale,
            num_inference_steps,
            reference_conditioning_scale,
        ],
        outputs=[mv_result, image_prompt],
    ).then(
        run_texturing,
        inputs=[input_mesh, mv_result, uv_unwarp, preprocess_mesh, uv_size],
        outputs=[textured_model_output, download_glb],
    ).then(
        lambda: gr.DownloadButton(interactive=True), outputs=[download_glb]
    )

    # Create and clean up the per-session temp directory on page load/unload.
    demo.load(start_session)
    demo.unload(end_session)

demo.launch()