File size: 3,520 Bytes
9df2e22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from tempfile import NamedTemporaryFile
from typing import Tuple
from zipfile import ZipFile

import gradio as gr
from accelerate import Accelerator
from huggingface_hub import hf_hub_download

from odcnn import ODCNN
from youtube import youtube

# Pick the compute device (CPU/GPU) through Accelerate so the script runs unchanged on either.
accelerator = Accelerator()
device = accelerator.device

# Both checkpoints of the default model are fetched from the same Hugging Face repo;
# hf_hub_download returns the path of the locally cached file.
_ODCNN_REPO = "JacobLinCool/odcnn-320k-100"
DON_MODEL = hf_hub_download(repo_id=_ODCNN_REPO, filename="don_model.pth")
KA_MODEL = hf_hub_download(repo_id=_ODCNN_REPO, filename="ka_model.pth")


# Registry of selectable models: the keys populate the UI radio and are used by
# run() to look up the model instance to execute.
models = {"odcnn-320k-100": ODCNN(DON_MODEL, KA_MODEL, device)}


def run(file: str, model: str, delta: float, trim: bool) -> Tuple[str, str, str]:
    """Transcribe an audio file into a TJA chart and bundle both into a ZIP.

    Args:
        file: Path to the input audio file; may be ``None``/empty when the
            user has not uploaded anything.
        model: Key into ``models`` selecting which model instance to run.
        delta: Note-detection threshold forwarded to the model.
        trim: Whether silence is trimmed from the start and end of the audio.

    Returns:
        ``(preview, tja, zip_path)``: the model's preview output, the TJA chart
        text, and the path of a ZIP archive containing the original audio file
        (under its basename) and the chart (as ``<basename>-<model>.tja``).

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not file:
        # Show a readable message in the UI instead of failing inside the model.
        raise gr.Error("Please upload an audio file first.")

    preview, tja = models[model].run(file, delta, trim)

    base = os.path.basename(file)
    # delete=False keeps the archive on disk after this function returns, so the
    # returned path is still valid for the download component.
    with NamedTemporaryFile(suffix=".zip", delete=False) as zfile:
        with ZipFile(zfile, "w") as z:
            z.write(file, base)
            # Write the chart straight from memory rather than via a temp file:
            # zipping a still-buffered temp file by path can capture an empty or
            # truncated .tja.
            z.writestr(f"{base}-{model}.tja", tja)

    return preview, tja, zfile.name


def from_youtube(
    url: str, model: str, delta: float, trim: bool
) -> Tuple[str, str, str, str]:
    """Fetch audio for a YouTube URL and transcribe it with ``run``.

    Returns the value produced by ``youtube(url)`` (used as the audio file path)
    first, so the UI's audio input displays it. That value is followed by the
    three outputs of ``run``: preview, TJA text and ZIP path.
    """
    downloaded = youtube(url)
    preview, chart, archive = run(downloaded, model, delta, trim)
    return downloaded, preview, chart, archive


def _strip_front_matter(text: str) -> str:
    """Return *text* with a leading YAML front-matter block removed.

    Front matter is recognised only when the text starts with ``---``; it ends
    at the next ``---``. Text without front matter is returned unchanged,
    including text that merely contains ``---`` horizontal rules. Unterminated
    front matter is also returned unchanged.
    """
    if not text.startswith("---"):
        return text
    # maxsplit=2 -> ["", <yaml>, <body>]; any later "---" rules stay in the body.
    parts = text.split("---", 2)
    return parts[2] if len(parts) == 3 else text


with gr.Blocks() as app:
    # Render the README that sits next to this script (minus its YAML header)
    # at the top of the page.
    with open(
        os.path.join(os.path.dirname(__file__), "README.md"), "r", encoding="utf-8"
    ) as f:
        README = _strip_front_matter(f.read())

    gr.Markdown(README)

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Upload an audio file")
            audio = gr.Audio(label="Upload an audio file", type="filepath")
        with gr.Column():
            gr.Markdown(
                "## or use a YouTube URL\n\nTry something on [The First Take](https://www.youtube.com/@The_FirstTake)?"
            )
            yt = gr.Textbox(
                label="YouTube URL", placeholder="https://www.youtube.com/watch?v=..."
            )
            yt_btn = gr.Button("Use this YouTube URL")

    with gr.Row():
        model = gr.Radio(
            label="Select a model",
            choices=list(models),
            value="odcnn-320k-100",
        )
        btn = gr.Button("Infer", variant="primary")

    with gr.Row():
        with gr.Column():
            synthesized = gr.Audio(
                label="Synthesized Audio",
                format="mp3",
                type="filepath",
                interactive=False,
            )
        with gr.Column():
            tja = gr.Text(label="TJA", interactive=False)

    with gr.Row():
        # Named zip_file (not zip) so the builtin zip() is not shadowed module-wide.
        zip_file = gr.File(label="Download ZIP", type="filepath")

    with gr.Accordion("Advanced Options", open=False):
        delta = gr.Slider(
            label="Delta",
            value=0.02,
            minimum=0.01,
            maximum=0.5,
            step=0.01,
            info="Threshold for note detection (Ura)",
        )
        trim = gr.Checkbox(
            label="Trim silence",
            value=True,
            info="Trim silence from the start and end of the audio",
        )

    # Uploaded-file path: run() outputs (preview, tja, zip) map onto these components.
    btn.click(
        fn=run,
        inputs=[audio, model, delta, trim],
        outputs=[synthesized, tja, zip_file],
        api_name="run",
    )

    # YouTube path: from_youtube() also writes the downloaded audio back into the
    # upload component.
    yt_btn.click(
        fn=from_youtube,
        inputs=[yt, model, delta, trim],
        outputs=[audio, synthesized, tja, zip_file],
    )

# Enable request queueing and listen on all network interfaces.
app.queue().launch(server_name="0.0.0.0")