File size: 3,508 Bytes
1dc20d9
4c10907
7d35d1e
6394bbc
85aaa31
7d35d1e
85aaa31
7d35d1e
 
 
 
 
eabc43b
6394bbc
85aaa31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6394bbc
eabc43b
7d35d1e
 
eabc43b
 
7d35d1e
 
 
85aaa31
 
 
7d35d1e
4474af1
85aaa31
7d35d1e
 
85aaa31
7d35d1e
df6c3f5
7d35d1e
30f9d01
 
85aaa31
 
 
 
 
 
 
 
 
 
7d35d1e
 
 
 
 
 
 
85aaa31
 
 
 
 
 
7d35d1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85aaa31
7d35d1e
 
 
85aaa31
7d35d1e
85aaa31
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import spaces
import gradio as gr
import os
from pathlib import Path
import ast
from model import Jamify
import json
# Initialize the Jamify model once, at import time, so every request handled
# by `generate_song` below reuses this single module-level instance instead of
# rebuilding the model per call. The prints bracket what may be a slow load
# (presumably weight loading — TODO confirm in model.Jamify).
print("Initializing Jamify model...")
jamify_model = Jamify()
print("Jamify model ready.")




def parse(file, output_path='temp.json'):
    """Parse an uploaded word-alignment text file and dump it as lyrics JSON.

    Each non-blank line must look like ``<start> <end> <word>``: two numeric
    timestamps (presumably seconds — confirm against the model) followed by the
    word. Fields may be separated by any run of whitespace; everything after
    the second field is kept as the word. The entries are written to
    ``output_path`` as a JSON list of
    ``{"word": str, "start": float, "end": float}`` objects.

    Args:
        file: Path of the uploaded file, or an object exposing that path as
            ``.name`` (older Gradio versions pass a tempfile wrapper).
        output_path: Where to write the JSON list. Defaults to ``'temp.json'``.

    Returns:
        The raw text of the uploaded file, for display in the UI.

    Raises:
        ValueError: If a non-blank line does not have three fields or its
            timestamps are not numbers.
    """
    path = getattr(file, 'name', file)
    # utf-8-sig also strips a leading BOM (common in Windows editors), which
    # would otherwise make the first timestamp unparsable by float().
    with open(path, 'r', encoding='utf-8-sig') as f:
        content = f.read()

    parsed_data = []
    # Iterate over the text already read: looping over the file handle after
    # f.read() would see an exhausted stream and produce no entries at all.
    for line_no, raw_line in enumerate(content.splitlines(), start=1):
        line = raw_line.strip()
        if not line:
            continue
        fields = line.split(maxsplit=2)
        if len(fields) != 3:
            raise ValueError(
                f"line {line_no}: expected '<start> <end> <word>', got {line!r}"
            )
        start, end, word = fields
        try:
            start_sec, end_sec = float(start), float(end)
        except ValueError as err:
            raise ValueError(
                f"line {line_no}: timestamps must be numbers, got {line!r}"
            ) from err
        parsed_data.append({"word": word, "start": start_sec, "end": end_sec})

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(parsed_data, f)

    return content

def _write_lyrics_json(lyrics_text):
    """Parse ``<start> <end> <word>`` lines into a fresh temporary JSON file.

    Blank lines are skipped. The JSON is a list of
    ``{"word": str, "start": float, "end": float}`` objects.

    Returns:
        Path of the written file; the caller must delete it.

    Raises:
        ValueError: If a non-blank line is malformed.
    """
    import tempfile

    words = []
    for line_no, line in enumerate(lyrics_text.splitlines(), start=1):
        fields = line.split(maxsplit=2)
        if not fields:
            continue
        if len(fields) != 3:
            raise ValueError(
                f"line {line_no}: expected '<start> <end> <word>', got {line.strip()!r}"
            )
        start, end, word = fields
        words.append({"word": word, "start": float(start), "end": float(end)})

    # A unique file per request, so concurrent users never overwrite each
    # other's lyrics (a single shared 'temp.json' would race).
    fd, path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(words, f)
    return path


@spaces.GPU(duration=100)
def generate_song(reference_audio, lyrics_file, style_prompt, duration):
    """Generate a song with the shared Jamify model.

    Args:
        reference_audio: Filepath of an uploaded style reference, or a falsy
            value when none was provided (normalised to ``None``).
        lyrics_file: Text of the lyrics textbox, expected as
            ``<start> <end> <word>`` lines. When empty or ``None``, the
            previously written ``temp.json`` (from the upload handler) is used.
        style_prompt: Free-text style description passed to the model.
        duration: Requested length, forwarded as ``duration_sec``.

    Returns:
        Whatever ``jamify_model.predict`` returns; the UI treats it as audio.

    Raises:
        ValueError: If a non-blank lyrics line is malformed.
    """
    reference_audio = reference_audio if reference_audio else None

    # The model expects a JSON file path, so the (possibly user-edited) lyrics
    # text is serialised per call rather than trusting a shared file on disk.
    if lyrics_file and lyrics_file.strip():
        lyrics_json_path = _write_lyrics_json(lyrics_file)
        owns_json = True
    else:
        lyrics_json_path = 'temp.json'
        owns_json = False

    try:
        output_path = jamify_model.predict(
            reference_audio_path=reference_audio,
            lyrics_json_path=lyrics_json_path,
            style_prompt=style_prompt,
            duration_sec=duration,
        )
    finally:
        if owns_json:
            os.remove(lyrics_json_path)

    return output_path

# Helper that returns an uploaded text file's contents (not referenced by the
# Blocks UI defined below).
def process_text_file(file_obj):
    """Return the text content of an uploaded file.

    Args:
        file_obj: Path of the uploaded file, or an object exposing the path as
            ``.name`` (what older Gradio ``gr.File`` components pass), or
            ``None`` when nothing was uploaded.

    Returns:
        The file's contents, or ``"No file uploaded."`` when ``file_obj`` is
        ``None``.
    """
    if file_obj is None:
        return "No file uploaded."
    # Newer Gradio passes a plain filepath string, which has no `.name`.
    path = getattr(file_obj, "name", file_obj)
    # utf-8-sig decodes UTF-8 and drops a leading BOM if present.
    with open(path, 'r', encoding='utf-8-sig') as f:
        return f.read()

# Gradio UI. Components are laid out in creation order: a left column of
# inputs and a right column holding the generated audio.
with gr.Blocks() as demo:
    gr.Markdown("# Jamify: Music Generation from Lyrics and Style")
    gr.Markdown("Provide your lyrics, a style reference (either an audio file or a text prompt), and a desired duration to generate a song.")
    
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Inputs")
            # Earlier JSON-upload variants, kept for reference:
            #lyrics_file = gr.JSON(label="Lyrics File (.json)")
            #file = gr.File(label="Upload JSON File", file_types=[".json"])
            # Lyrics/alignment text file upload.
            file_input = gr.File(file_types=["text"], file_count="single", label="Upload your text file here")
            # Shows the uploaded file's text; this textbox's value is what is
            # passed to generate_song as its `lyrics_file` argument.
            output_text = gr.Textbox(label="File Content", lines=10)
            
            # On upload, `parse` writes the lyrics JSON to disk and returns the
            # file text, which fills `output_text`.
            file_input.upload(parse, inputs=file_input, outputs=output_text)
            duration_slider = gr.Slider(minimum=5, maximum=180, value=30, step=1, label="Duration (seconds)")
            
            # Both tabs' values are always sent to generate_song (see the
            # click wiring below), regardless of which tab is visible.
            with gr.Tab("Style from Audio"):
                reference_audio = gr.File(label="Reference Audio (.mp3, .wav)", type="filepath")
            with gr.Tab("Style from Text"):
                style_prompt = gr.Textbox(label="Style Prompt", lines=3, placeholder="e.g., A high-energy electronic dance track with a strong bassline and euphoric synths.")
            
            generate_button = gr.Button("Generate Song", variant="primary")
            
        with gr.Column():
            gr.Markdown("### Output")
            output_audio = gr.Audio(label="Generated Song")

    # Input order must match generate_song's positional parameters:
    # (reference_audio, lyrics_file, style_prompt, duration).
    generate_button.click(
        fn=generate_song,
        inputs=[reference_audio, output_text, style_prompt, duration_slider],
        outputs=output_audio,
        api_name="generate_song"
    )
 

# Enable Gradio's request queue and start the server. This runs at import time
# because there is no __main__ guard.
demo.queue().launch()
# Alternative for self-hosting: listen on all interfaces, port 7860.
#demo.launch(server_name="0.0.0.0", server_port=7860)