import os
from tempfile import NamedTemporaryFile
from typing import Tuple
from zipfile import ZipFile

import gradio as gr
from accelerate import Accelerator
from huggingface_hub import hf_hub_download

from odcnn import ODCNN
from youtube import youtube

# Let Accelerate choose the device the models will be placed on.
accelerator = Accelerator()
device = accelerator.device

# Fetch both ODCNN checkpoints (the "don" and "ka" weights) from the same Hub
# repository. Note: this happens at import time, before the UI is built.
DON_MODEL, KA_MODEL = (
    hf_hub_download(repo_id="JacobLinCool/odcnn-320k-100", filename=weights)
    for weights in ("don_model.pth", "ka_model.pth")
)

# Registry of available models, keyed by the name offered in the UI's radio
# selector; `run` looks models up here by that key.
models = {"odcnn-320k-100": ODCNN(DON_MODEL, KA_MODEL, device)}


def run(file: str, model: str, delta: float, trim: bool) -> Tuple[str, str, str]:
    """Generate a TJA chart for an audio file and bundle both into a ZIP.

    Args:
        file: Path to the input audio file.
        model: Key into ``models`` selecting which ODCNN instance to use.
        delta: Forwarded to the model's ``run`` (exposed in the UI as the
            note-detection threshold).
        trim: Forwarded to the model's ``run`` (exposed in the UI as
            "Trim silence").

    Returns:
        ``(preview, tja, zip_path)``: the preview returned by the model, the
        TJA chart text, and the path of a ZIP archive containing the original
        audio file plus the chart. The ZIP is created with ``delete=False`` so
        it survives this call and can be served for download.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not file:
        # Show a readable message in the UI instead of failing inside the model.
        raise gr.Error("Please upload an audio file first.")

    preview, tja = models[model].run(file, delta, trim)

    audio_name = os.path.basename(file)
    # Write the archive through the temp file's own (binary) handle and store
    # the chart directly with ``writestr``. This avoids an intermediate .tja
    # temp file whose unflushed buffer could leave an empty/truncated entry in
    # the ZIP, and avoids re-opening an open temp file by name (fails on
    # Windows). ``writestr`` encodes ``str`` data as UTF-8.
    with NamedTemporaryFile(suffix=".zip", delete=False) as zfile:
        with ZipFile(zfile, "w") as z:
            z.write(file, audio_name)
            z.writestr(f"{audio_name}-{model}.tja", tja)

    return preview, tja, zfile.name


def from_youtube(
    url: str, model: str, delta: float, trim: bool
) -> Tuple[str, str, str, str]:
    """Download audio from a YouTube URL and feed it through :func:`run`.

    Returns:
        ``(audio_path, preview, tja, zip_path)``: the downloaded audio path
        (so the UI can display it in the upload widget) followed by every
        value :func:`run` returns, in order.
    """
    downloaded = youtube(url)
    results = run(downloaded, model, delta, trim)
    return (downloaded,) + tuple(results)


def _strip_front_matter(text: str) -> str:
    """Return *text* without a leading ``---``-delimited YAML front matter block.

    Only a block that starts at the very beginning of the text and has a
    closing ``---`` is removed; any other text (including ``---`` horizontal
    rules in the body) is returned unchanged.
    """
    if not text.startswith("---"):
        return text
    parts = text.split("---", 2)
    # parts == ["", front_matter, rest] when the block is properly closed.
    return parts[2] if len(parts) == 3 else text


with open(
    os.path.join(os.path.dirname(__file__), "README.md"), "r", encoding="utf-8"
) as f:
    README = _strip_front_matter(f.read())

with gr.Blocks() as app:
    gr.Markdown(README)

    # Input: either an uploaded audio file or a YouTube URL.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Upload an audio file")
            audio = gr.Audio(label="Upload an audio file", type="filepath")
        with gr.Column():
            gr.Markdown(
                "## or use a YouTube URL\n\nTry something on [The First Take](https://www.youtube.com/@The_FirstTake)?"
            )
            yt = gr.Textbox(
                label="YouTube URL", placeholder="https://www.youtube.com/watch?v=..."
            )
            yt_btn = gr.Button("Use this YouTube URL")

    with gr.Row():
        model = gr.Radio(
            label="Select a model",
            choices=list(models),
            value="odcnn-320k-100",
        )
        btn = gr.Button("Infer", variant="primary")

    # Outputs: preview audio, chart text, and a downloadable bundle.
    with gr.Row():
        with gr.Column():
            synthesized = gr.Audio(
                label="Synthesized Audio",
                format="mp3",
                type="filepath",
                interactive=False,
            )
        with gr.Column():
            tja = gr.Text(label="TJA", interactive=False)

    with gr.Row():
        # Named to avoid shadowing the builtin `zip` at module level.
        zip_output = gr.File(label="Download ZIP", type="filepath")

    with gr.Accordion("Advanced Options", open=False):
        delta = gr.Slider(
            label="Delta",
            value=0.02,
            minimum=0.01,
            maximum=0.5,
            step=0.01,
            info="Threshold for note detection (Ura)",
        )
        trim = gr.Checkbox(
            label="Trim silence",
            value=True,
            info="Trim silence from the start and end of the audio",
        )

    btn.click(
        fn=run,
        inputs=[audio, model, delta, trim],
        outputs=[synthesized, tja, zip_output],
        api_name="run",
    )

    # The YouTube flow also writes the downloaded audio back into the upload
    # widget (first output of `from_youtube`).
    yt_btn.click(
        fn=from_youtube,
        inputs=[yt, model, delta, trim],
        outputs=[audio, synthesized, tja, zip_output],
    )

# Enable the app's request queue, then start the server bound to all network
# interfaces (0.0.0.0) rather than only localhost.
app.queue().launch(server_name="0.0.0.0")