File size: 17,572 Bytes
1dc20d9
4c10907
7d35d1e
a7334d4
 
4515b1f
 
6394bbc
bc3ffb2
1e7fc7e
 
deaa9a6
1e7fc7e
bc3ffb2
 
 
1e7fc7e
bc3ffb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e7fc7e
 
 
 
 
 
 
 
 
bc3ffb2
 
 
 
1e7fc7e
bc3ffb2
 
 
 
 
1dc20d9
e26fb40
4515b1f
 
21e5441
 
4515b1f
21e5441
4515b1f
e26fb40
 
 
072d1f9
e26fb40
 
21e5441
e26fb40
7d35d1e
 
 
 
 
eabc43b
6394bbc
deaa9a6
6394bbc
eabc43b
bc3ffb2
7d35d1e
d2cd103
eabc43b
bc3ffb2
 
 
 
 
 
 
a7334d4
 
df6c3f5
a7334d4
 
 
 
7d35d1e
a7334d4
 
 
 
 
65e9daa
a7334d4
 
 
 
 
 
7d35d1e
a7334d4
 
 
 
 
df6c3f5
a7334d4
 
 
 
 
 
 
 
 
 
bc3ffb2
a7334d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc3ffb2
 
a7334d4
 
 
 
 
bc3ffb2
a7334d4
 
 
bc3ffb2
 
 
 
 
 
 
 
 
 
a7334d4
 
bc3ffb2
 
 
a7334d4
bc3ffb2
a7334d4
8e872fa
 
bc3ffb2
8e872fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7334d4
 
30f9d01
a1ddd2f
 
 
 
bc3ffb2
 
 
 
a1ddd2f
30f9d01
7d35d1e
 
1e7fc7e
 
 
 
 
 
 
 
 
 
 
 
7d35d1e
8e872fa
 
 
bc3ffb2
 
 
 
a7334d4
 
 
 
 
 
1e7fc7e
 
8e872fa
 
 
 
 
 
 
 
 
a7334d4
8e872fa
 
 
 
 
 
 
 
 
a7334d4
7d35d1e
 
 
bc3ffb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7334d4
 
 
8e872fa
a1ddd2f
a7334d4
8e872fa
7d35d1e
bc3ffb2
 
 
7d35d1e
a1ddd2f
bc3ffb2
3db0011
bc3ffb2
 
 
7d35d1e
 
 
 
 
 
bc3ffb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d35d1e
 
bc3ffb2
7d35d1e
 
 
a7334d4
8e872fa
a7334d4
bc3ffb2
8e872fa
bc3ffb2
 
8e872fa
bc3ffb2
 
8e872fa
 
 
bc3ffb2
8e872fa
bc3ffb2
 
8e872fa
a7334d4
 
bc3ffb2
 
 
a7334d4
8e872fa
 
 
 
bc3ffb2
8e872fa
7d35d1e
a7334d4
 
 
 
 
 
 
7d35d1e
8e872fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
import spaces  # keep first: HF Spaces GPU decorator should patch the runtime before torch loads

# Standard library
import json
import os
import shutil
import subprocess
import tempfile
from pathlib import Path

# Third-party
import gradio as gr
import pyloudnorm as pyln
import requests
import torch
import torchaudio

# Local
from model import Jamify, normalize_audio
from utils import json_to_text, text_to_json, convert_text_time_to_beats, convert_text_beats_to_time, convert_text_beats_to_time_with_regrouping, text_to_words, beats_to_text_with_regrouping, round_to_quarter_beats

def crop_audio_to_30_seconds(audio_path):
    """Crop audio to its first 30 seconds, resample to 44.1 kHz, and normalize.

    Mirrors the preprocessing of the prediction pipeline so the displayed
    reference clip matches what the model will actually hear.

    Args:
        audio_path: Path to the source audio file, or None/"".

    Returns:
        Path to a temporary normalized .wav file, or None if the input is
        missing or processing fails.
    """
    if not audio_path or not os.path.exists(audio_path):
        return None

    temp_path = None
    try:
        waveform, sample_rate = torchaudio.load(audio_path)

        # Keep at most the first 30 seconds (full audio if shorter).
        target_samples = sample_rate * 30
        if waveform.shape[1] > target_samples:
            waveform = waveform[:, :target_samples]

        # Resample to 44100 Hz to match the prediction pipeline.
        if sample_rate != 44100:
            resampler = torchaudio.transforms.Resample(sample_rate, 44100)
            waveform = resampler(waveform)
            sample_rate = 44100

        # Apply the same normalization as the prediction pipeline.
        normalized_waveform = normalize_audio(waveform)

        # Reserve a temp filename, then let torchaudio write to it.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_path = temp_file.name
        torchaudio.save(temp_path, normalized_waveform, sample_rate)
        return temp_path

    except Exception as e:
        print(f"Error processing audio: {e}")
        # Don't leak the reserved temp file if loading/saving failed.
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)
        return None

def download_resources():
    """Fetch the example and public-resource repos from GitHub.

    Replaces any stale local checkout with a fresh clone so the demo always
    runs against the latest published data.

    Raises:
        subprocess.CalledProcessError: If either `git clone` fails.
    """
    repos = [
        ("https://github.com/xhhhhang/jam-examples.git", Path("examples")),
        ("https://github.com/xhhhhang/jam-public-resources.git", Path("public")),
    ]
    for repo_url, target_dir in repos:
        # shutil.rmtree is portable and raises on failure, unlike an
        # unchecked `rm -rf` subprocess call.
        if target_dir.exists():
            shutil.rmtree(target_dir)
        subprocess.run(["git", "clone", repo_url, str(target_dir)], check=True)

# One-time startup work: fetch demo assets, then load the model into memory
# so every request reuses the same instance.
print('Downloading examples data...')
download_resources()
# Initialize the Jamify model once
print("Initializing Jamify model...")
jamify_model = Jamify()
print("Jamify model ready.")



# Allow Gradio to serve files from the working directory (example audio, etc.).
gr.set_static_paths(paths=[Path.cwd().absolute()])

@spaces.GPU(duration=100)
def generate_song(reference_audio, lyrics_text, duration, mode="time", bpm=120, style_prompt=None):
    """Generate a song from timed lyrics and an optional audio style reference.

    Args:
        reference_audio: Path to the style reference audio file, or ""/None.
        lyrics_text: Lyrics in the word[start:end] text format (seconds in
            time mode, beats in beats mode).
        duration: Target song length in seconds.
        mode: "time" or "beats" — how the lyric timestamps are expressed.
        bpm: Tempo used to convert beats to seconds when mode == "beats".
        style_prompt: Optional text prompt describing the desired style.

    Returns:
        Path to the generated audio file.
    """
    # Normalize an empty upload ("" or None) to None so the model skips it.
    # (Conditional expression replaces the fragile `x and y or z` idiom.)
    reference_audio = reference_audio if reference_audio not in ("", None) else None

    # Beats-mode lyrics must be converted to seconds before prediction;
    # on conversion failure fall back to the text as-is (best effort).
    if mode == "beats" and lyrics_text:
        try:
            lyrics_text = convert_text_beats_to_time(lyrics_text, bpm)
        except Exception as e:
            print(f"Error converting beats to time: {e}")

    # The model consumes lyrics as a JSON file on disk.
    lyrics_json = text_to_json(lyrics_text)
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump(lyrics_json, f, indent=2)
        lyrics_file = f.name

    try:
        return jamify_model.predict(
            reference_audio_path=reference_audio,
            lyrics_json_path=lyrics_file,
            style_prompt=style_prompt,
            duration=duration
        )
    finally:
        # Always remove the temporary lyrics file, even if prediction fails.
        if os.path.exists(lyrics_file):
            os.unlink(lyrics_file)

# Load and cache examples
def load_examples():
    """Read examples/input.json and pre-compute the text form of each lyric file.

    Returns:
        List of dicts with keys: id, audio_path (None if missing),
        lyrics_text, duration, bpm.
    """
    examples_file = "examples/input.json"
    examples = []

    if os.path.exists(examples_file):
        print("Loading and caching examples...")
        with open(examples_file, 'r') as f:
            examples_data = json.load(f)

        for example in examples_data:
            example_id = example.get('id', '')
            audio_path = example.get('audio_path', '')
            lrc_path = example.get('lrc_path', '')

            # Pre-compute the editable text form of the lyrics once at startup
            # so example switching in the UI is instant.
            lyrics_text = ""
            if os.path.exists(lrc_path):
                try:
                    with open(lrc_path, 'r') as lrc_file:
                        lyrics_text = json_to_text(json.load(lrc_file))
                    print(f"Cached example {example_id}: {len(lyrics_text)} chars")
                except Exception as e:
                    print(f"Error loading lyrics from {lrc_path}: {e}")

            examples.append({
                'id': example_id,
                'audio_path': audio_path if os.path.exists(audio_path) else None,
                'lyrics_text': lyrics_text,
                'duration': example.get('duration', 120),
                'bpm': example.get('bpm', 120.0),  # default 120 when absent
            })

    print(f"Loaded {len(examples)} cached examples")
    return examples

def load_example(example_idx, examples, mode="time"):
    """Return (audio_path, lyrics_text, duration, bpm) for one cached example.

    Falls back to empty defaults when the index is out of range. In beats
    mode the cached time-based lyrics are converted to beats notation.
    """
    # Guard clause: anything out of range yields the "empty form" defaults.
    if not (0 <= example_idx < len(examples)):
        return None, "", 120, 120.0

    example = examples[example_idx]
    lyrics = example['lyrics_text']
    bpm = example.get('bpm', 120.0)

    if mode == "beats" and lyrics:
        try:
            lyrics = beats_to_text_with_regrouping(lyrics, bpm, round_to_quarters=True)
        except Exception as e:
            print(f"Error converting to beats format: {e}")

    return example['audio_path'], lyrics, example['duration'], bpm

def clear_form():
    """Reset every form input so the user can build a song from scratch.

    Returns:
        Tuple of form defaults: (audio=None, lyrics="", duration=120, bpm=120.0).
    """
    empty_audio = None
    empty_lyrics = ""
    default_duration = 120
    default_bpm = 120.0
    return empty_audio, empty_lyrics, default_duration, default_bpm

def update_button_styles(selected_idx, total_examples):
    """Build gr.update objects highlighting the selected example button.

    Returns one update per example button ("primary" when selected,
    "secondary" otherwise), plus a trailing update for the "Make Your Own"
    button, which is highlighted when selected_idx is -1.
    """
    variants = [
        "primary" if i == selected_idx else "secondary"
        for i in range(total_examples)
    ]
    # "Make Your Own" is selected iff no example is (sentinel index -1).
    variants.append("primary" if selected_idx == -1 else "secondary")
    return [gr.update(variant=v) for v in variants]

# Load examples at startup
examples = load_examples()

# Get default values from first example (the UI opens with it pre-selected)
default_audio = examples[0]['audio_path'] if examples else None
default_lyrics = examples[0]['lyrics_text'] if examples else ""
default_duration = examples[0]['duration'] if examples else 120
default_bpm = examples[0]['bpm'] if examples else 120.0

# Create cropped version of default audio for display
# (the Audio widget shows only the 30-second slice the model will use)
default_audio_display = crop_audio_to_30_seconds(default_audio) if default_audio else None

# Gradio interface — builds the page top-to-bottom; order of the gr.* calls
# below determines the visual layout.
with gr.Blocks() as demo:
    gr.Markdown("# Jamify: Music Generation from Lyrics and Style")
    gr.Markdown("Provide your lyrics, an audio style reference, and a desired duration to generate a song.")
    
    # Helpful reminder for users
    gr.Markdown("""
    💡 **Demo Tip**: Don't start from scratch! Use the sample examples below as templates:
    - Click any sample to load its lyrics and audio style
    - **Edit the lyrics**: Change words, modify timing, or adjust the structure
    - **Experiment with timing**: Try different word durations or beats
    - **Mix and match**: Use lyrics from one example with audio style from another
    
    This approach is much easier than creating everything from zero!
    """)
    
    # State to track selected example (-1 means "Make Your Own" is selected, 0 is first example)
    selected_example = gr.State(0 if examples else -1)
    
    # States for mode and BPM (kept in gr.State so event handlers can read
    # the current values without re-deriving them from widgets)
    input_mode = gr.State("time")
    current_bpm = gr.State(default_bpm)
    
    # Sample buttons section
    if examples:
        gr.Markdown("### Sample Examples")
        with gr.Row():
            example_buttons = []
            for i, example in enumerate(examples):
                # Use consistent button width with 10 character limit
                button_text = example['id'][:10] if len(example['id']) <= 10 else example['id'][:9] + "…"
                # First button starts as primary (selected), others as secondary
                initial_variant = "primary" if i == 0 else "secondary"
                button = gr.Button(
                    button_text, 
                    variant=initial_variant, 
                    size="sm",
                    scale=1,  # Equal width for all buttons
                    min_width=80  # Minimum consistent width
                )
                example_buttons.append(button)
            
            # Add "Make Your Own" button with same sizing (starts as secondary since first example is selected)
            make_your_own_button = gr.Button(
                "🎵 Make Your Own", 
                variant="secondary", 
                size="sm",
                scale=1,
                min_width=80
            )
    
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Inputs")
            
            # Mode switcher
            mode_radio = gr.Radio(
                choices=["Time Mode", "Beats Mode"],
                value="Time Mode",
                label="Input Format",
                info="Choose how to specify timing: seconds or musical beats"
            )
            
            # BPM input (initially hidden; shown only in Beats Mode)
            bpm_input = gr.Number(
                label="BPM (Beats Per Minute)",
                value=default_bpm,
                minimum=60,
                maximum=200,
                step=1,
                visible=False,
                info="Tempo for converting beats to time"
            )
            
            lyrics_text = gr.Textbox(
                label="Lyrics", 
                lines=10, 
                placeholder="Enter lyrics with timestamps: word[start_time:end_time] word[start_time:end_time]...\n\nExample: Hello[0.0:1.2] world[1.5:2.8] this[3.0:3.8] is[4.2:4.6] my[5.0:5.8] song[6.2:7.0]\n\nFormat: Each word followed by [start_seconds:end_seconds] in brackets\nTimestamps should be in seconds with up to 2 decimal places",
                value=default_lyrics
            )
            duration_slider = gr.Slider(minimum=120, maximum=230, value=default_duration, step=1, label="Duration (seconds)")
            
        with gr.Column():
            gr.Markdown("### Style & Generation")
            
            with gr.Tab("Style from Audio"):
                # File widget holds the raw upload; the Audio widget below
                # previews the cropped/normalized 30-second slice.
                reference_audio = gr.File(label="Reference Audio (.mp3, .wav)", type="filepath", value=default_audio)
                reference_audio_display = gr.Audio(
                    label="Reference Audio (Only first 30 seconds will be used for generation)", 
                    value=default_audio_display,
                    visible=default_audio_display is not None
                )
            
            generate_button = gr.Button("Generate Song", variant="primary")
            
            gr.Markdown("### Output")
            output_audio = gr.Audio(label="Generated Song")

    # Mode switching functions
    def switch_mode(mode_choice, current_lyrics, current_bpm_val):
        """Handle switching between time and beats mode.

        Converts any lyrics already in the textbox to the newly selected
        notation, swaps the placeholder/label, and toggles BPM visibility.
        Returns updates for (bpm_input, lyrics_text, input_mode state).
        """
        mode = "beats" if mode_choice == "Beats Mode" else "time"
        
        # Update BPM input visibility (BPM is only meaningful in beats mode)
        bpm_visible = (mode == "beats")
        
        # Update lyrics placeholder and convert existing text
        if mode == "time":
            placeholder = "Enter lyrics with timestamps: word[start_time:end_time] word[start_time:end_time]...\n\nExample: Hello[0.0:1.2] world[1.5:2.8] this[3.0:3.8] is[4.2:4.6] my[5.0:5.8] song[6.2:7.0]\n\nFormat: Each word followed by [start_seconds:end_seconds] in brackets\nTimestamps should be in seconds with up to 2 decimal places"
            label = "Lyrics"
            
            # Convert from beats to time if there's content; on failure the
            # text is left untouched (best effort)
            converted_lyrics = current_lyrics
            if current_lyrics.strip():
                try:
                    converted_lyrics = convert_text_beats_to_time_with_regrouping(current_lyrics, current_bpm_val)
                except Exception as e:
                    print(f"Error converting beats to time: {e}")
        else:
            placeholder = "Enter lyrics with beat timestamps: word[start_beat:end_beat] word[start_beat:end_beat]...\n\nExample: Hello[0:1] world[1.5:2.75] this[3:3.75] is[4.25:4.5] my[5:5.75] song[6.25:7]\n\nFormat: Each word followed by [start_beat:end_beat] in brackets\nBeats are in quarter notes (1 beat = quarter note, 0.25 = sixteenth note)"
            label = "Lyrics (Beats Format)"
            
            # Convert from time to beats if there's content
            converted_lyrics = current_lyrics
            if current_lyrics.strip():
                try:
                    converted_lyrics = beats_to_text_with_regrouping(current_lyrics, current_bpm_val, round_to_quarters=True)
                except Exception as e:
                    print(f"Error converting time to beats: {e}")
        
        return (
            gr.update(visible=bpm_visible),  # bpm_input visibility
            gr.update(placeholder=placeholder, label=label, value=converted_lyrics),  # lyrics_text
            mode  # input_mode state
        )
    
    def update_bpm_state(bpm_val):
        """Update the BPM state (pass-through from the Number widget)."""
        return bpm_val
    
    def update_reference_audio_display(audio_file):
        """Process and display the cropped reference audio.

        Hides the preview widget when there is no upload or cropping fails.
        """
        if audio_file is None:
            return gr.update(visible=False, value=None)
        
        cropped_path = crop_audio_to_30_seconds(audio_file)
        if cropped_path:
            return gr.update(visible=True, value=cropped_path)
        else:
            return gr.update(visible=False, value=None)
    
    # Connect mode switching
    mode_radio.change(
        fn=switch_mode,
        inputs=[mode_radio, lyrics_text, current_bpm],
        outputs=[bpm_input, lyrics_text, input_mode]
    )
    
    # Connect BPM changes
    bpm_input.change(
        fn=update_bpm_state,
        inputs=[bpm_input],
        outputs=[current_bpm]
    )
    
    # Connect reference audio file changes to display
    reference_audio.change(
        fn=update_reference_audio_display,
        inputs=[reference_audio],
        outputs=[reference_audio_display]
    )

    # Main generation action: input_mode/current_bpm states supply the
    # mode and bpm arguments of generate_song.
    generate_button.click(
        fn=generate_song,
        inputs=[reference_audio, lyrics_text, duration_slider, input_mode, current_bpm],
        outputs=output_audio,
        api_name="generate_song"
    )
    
    # Connect example buttons to load data and update selection
    if examples:
        def load_example_and_update_selection(idx, current_mode):
            """Load example data and update button selection state"""
            mode = "beats" if current_mode == "Beats Mode" else "time"
            audio, lyrics, duration, bpm = load_example(idx, examples, mode)
            button_updates = update_button_styles(idx, len(examples))
            audio_display_update = update_reference_audio_display(audio)
            return [audio, lyrics, duration, bpm, idx, audio_display_update] + button_updates
        
        def clear_form_and_update_selection():
            """Clear form and update button selection state"""
            audio, lyrics, duration, bpm = clear_form()
            # -1 is the sentinel meaning "Make Your Own" is selected
            button_updates = update_button_styles(-1, len(examples))
            audio_display_update = update_reference_audio_display(audio)
            return [audio, lyrics, duration, bpm, -1, audio_display_update] + button_updates
        
        for i, button in enumerate(example_buttons):
            # idx=i default binds the loop variable now, avoiding the
            # late-binding closure pitfall
            button.click(
                fn=lambda current_mode, idx=i: load_example_and_update_selection(idx, current_mode),
                inputs=[mode_radio],
                outputs=[reference_audio, lyrics_text, duration_slider, current_bpm, selected_example, reference_audio_display] + example_buttons + [make_your_own_button]
            )
        
        # Connect "Make Your Own" button to clear form and update selection
        make_your_own_button.click(
            fn=clear_form_and_update_selection,
            outputs=[reference_audio, lyrics_text, duration_slider, current_bpm, selected_example, reference_audio_display] + example_buttons + [make_your_own_button]
        )

# Create necessary temporary directories for Gradio
print("Creating temporary directories...")
try:
    os.makedirs("/tmp/gradio", exist_ok=True)
    print("Temporary directories created successfully.")
except Exception as e:
    # Non-fatal: the app can still launch; uploads may fail without this dir.
    print(f"Warning: Could not create temporary directories: {e}")

# queue() enables request queuing for the GPU-bound endpoint; share=True
# exposes a public URL.
demo.queue().launch(share=True)