Spaces: Running on Zero
import json
import os
import shutil
import subprocess
import tempfile
from pathlib import Path

import spaces  # kept ahead of the other third-party imports, as ZeroGPU Spaces expect
import gradio as gr
import requests

from model import Jamify
from utils import json_to_text, text_to_json
def download_examples_data(
    repo_url="https://github.com/xhhhhang/jam-examples.git",
    examples_dir="examples",
):
    """Fetch a fresh copy of the examples repository.

    Any existing ``examples_dir`` is removed first, so every call re-clones
    the repository instead of reusing a previous checkout.

    Args:
        repo_url: Git URL of the examples repository.
        examples_dir: Destination directory for the clone.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits non-zero.
        OSError: If the existing destination cannot be removed.
    """
    examples_dir = Path(examples_dir)
    # Portable replacement for ``rm -rf``: remove a real directory tree, or
    # unlink a file/symlink (rmtree refuses symlinks). Unlike the unchecked
    # subprocess call, failures raise here instead of surfacing later as a
    # confusing "destination already exists" error from git.
    if examples_dir.is_dir() and not examples_dir.is_symlink():
        shutil.rmtree(examples_dir)
    elif examples_dir.exists() or examples_dir.is_symlink():
        examples_dir.unlink()
    # Only the working tree is used, so skip the history with a shallow clone.
    subprocess.run(
        ["git", "clone", "--depth", "1", repo_url, str(examples_dir)],
        check=True,
    )
# Refresh the example assets before anything reads them.
print('Downloading examples data...')
download_examples_data()

# Build the Jamify model a single time; generate_song reuses this instance
# for every request.
print("Initializing Jamify model...")
jamify_model = Jamify()
print("Jamify model ready.")

# NOTE(review): this registers the entire working directory (source files,
# the examples checkout, anything else stored here) as static, servable
# paths. Consider narrowing it to the directories that actually hold media.
_static_root = Path.cwd().absolute()
gr.set_static_paths(paths=[_static_root])
def generate_song(reference_audio, lyrics_text, style_prompt, duration):
    """Generate a song with Jamify from timestamped lyrics and a style reference.

    Args:
        reference_audio: Path of the uploaded reference audio file; an empty
            string or None means no audio reference was provided.
        lyrics_text: Lyrics in the ``word[start:end]`` text format, converted
            with ``text_to_json``.
        style_prompt: Free-text style description (may be empty).
        duration: Requested song length in seconds.

    Returns:
        The value returned by ``jamify_model.predict``, which is sent to the
        output audio component (presumably a path to the generated file).
    """
    # Treat "" and None (nothing uploaded) the same way: no reference audio.
    reference_audio = reference_audio or None

    # predict() takes the lyrics as a JSON file path, so convert the text
    # format and persist it to a temporary file that is always removed.
    lyrics_json = text_to_json(lyrics_text)
    fd, lyrics_file = tempfile.mkstemp(suffix=".json")
    try:
        # Write inside the try so a failing dump cannot leak the temp file.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(lyrics_json, f, indent=2)
        # NOTE(review): `spaces` is imported but no @spaces.GPU decorator is
        # visible in this file; confirm Jamify.predict is GPU-decorated in
        # model.py, otherwise ZeroGPU presumably won't attach a GPU here.
        return jamify_model.predict(
            reference_audio_path=reference_audio,
            lyrics_json_path=lyrics_file,
            style_prompt=style_prompt,
            duration=duration,
        )
    finally:
        if os.path.exists(lyrics_file):
            os.unlink(lyrics_file)
# Load and cache examples
def load_examples(examples_file="examples/input.json"):
    """Load the example catalogue and pre-render each example's lyrics as text.

    Args:
        examples_file: Path to a JSON list of example records. Each record may
            carry ``id``, ``audio_path``, ``lrc_path`` and ``duration`` keys.

    Returns:
        A list of dicts with keys ``id`` (always a str), ``audio_path``
        (``None`` unless the file exists), ``lyrics_text`` (``""`` when the
        lyrics file is missing or unreadable) and ``duration`` (default 120).
        An empty list is returned when ``examples_file`` does not exist.
    """
    examples = []
    if not os.path.exists(examples_file):
        print(f"No examples file found at {examples_file}")
        return examples

    print("Loading and caching examples...")
    with open(examples_file, "r", encoding="utf-8") as f:
        examples_data = json.load(f)

    for example in examples_data:
        # Normalise the id to str: it is used for button labels (len/slicing).
        raw_id = example.get('id')
        example_id = "" if raw_id is None else str(raw_id)
        # `or ''` also covers explicit JSON nulls, which would otherwise make
        # os.path.exists(None) raise TypeError.
        audio_path = example.get('audio_path') or ''
        lrc_path = example.get('lrc_path') or ''
        duration = example.get('duration', 120)

        # Load lyrics and convert to text format (pre-computed/cached)
        lyrics_text = ""
        if os.path.exists(lrc_path):
            try:
                with open(lrc_path, "r", encoding="utf-8") as lf:
                    lyrics_json = json.load(lf)
                lyrics_text = json_to_text(lyrics_json)
                print(f"Cached example {example_id}: {len(lyrics_text)} chars")
            except Exception as e:
                # Best effort: one broken example must not prevent startup.
                print(f"Error loading lyrics from {lrc_path}: {e}")

        examples.append({
            'id': example_id,
            'audio_path': audio_path if os.path.exists(audio_path) else None,
            'lyrics_text': lyrics_text,
            'duration': duration,
        })

    print(f"Loaded {len(examples)} cached examples")
    return examples
def load_example(example_idx, examples):
    """Return ``(audio_path, lyrics_text, duration)`` for ``examples[example_idx]``.

    Any index outside ``0 <= example_idx < len(examples)`` (negative values
    included) yields the blank form values ``(None, "", 120)``.
    """
    in_range = 0 <= example_idx < len(examples)
    if not in_range:
        return None, "", 120
    chosen = examples[example_idx]
    return chosen['audio_path'], chosen['lyrics_text'], chosen['duration']
def clear_form():
    """Return blank ``(audio, lyrics, duration)`` values for the input form.

    Used when the user switches to "Make Your Own": no reference audio, empty
    lyrics, and the default 120-second duration.
    """
    blank_audio, blank_lyrics, default_seconds = None, "", 120
    return blank_audio, blank_lyrics, default_seconds
def update_button_styles(selected_idx, total_examples):
    """Build ``gr.update`` variant changes that highlight the selected button.

    Returns ``total_examples`` updates for the example buttons, in order,
    followed by one update for the "Make Your Own" button. The button whose
    position equals ``selected_idx`` becomes "primary" and all others
    "secondary"; ``selected_idx == -1`` highlights "Make Your Own".
    """
    def variant_for(is_selected):
        return "primary" if is_selected else "secondary"

    style_updates = [
        gr.update(variant=variant_for(pos == selected_idx))
        for pos in range(total_examples)
    ]
    # The trailing entry always targets the "Make Your Own" button.
    style_updates.append(gr.update(variant=variant_for(selected_idx == -1)))
    return style_updates
# Read the example catalogue once, when the app starts.
examples = load_examples()
# Seed the form from the first example; load_example returns the blank
# defaults (None, "", 120) when the catalogue is empty.
default_audio, default_lyrics, default_duration = load_example(0, examples)
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Jamify: Music Generation from Lyrics and Style")
    gr.Markdown("Provide your lyrics, a style reference (either an audio file or a text prompt), and a desired duration to generate a song.")

    # Index of the highlighted example; -1 means "Make Your Own" is active.
    selected_example = gr.State(0 if examples else -1)

    # Selector row: one button per example plus "Make Your Own".
    if examples:
        gr.Markdown("### Sample Examples")
        with gr.Row():
            example_buttons = []
            for pos, ex in enumerate(examples):
                ex_id = ex['id']
                # Long ids (over 15 chars) are cut to 12 chars plus "...".
                label = ex_id if len(ex_id) <= 15 else ex_id[:12] + "..."
                example_buttons.append(
                    gr.Button(
                        label,
                        # The first example starts out selected.
                        variant="primary" if pos == 0 else "secondary",
                        size="sm",
                        scale=1,  # equal width for every button
                        min_width=80,
                    )
                )
            # Same sizing as the example buttons; unselected at start.
            make_your_own_button = gr.Button(
                "🎵 Make Your Own",
                variant="secondary",
                size="sm",
                scale=1,
                min_width=80,
            )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Inputs")
            lyrics_text = gr.Textbox(
                label="Lyrics",
                lines=10,
                placeholder=(
                    "Enter lyrics with timestamps: word[start_time:end_time] word[start_time:end_time]...\n\n"
                    "Example: Hello[0.0:1.2] world[1.5:2.8] this[3.0:3.8] is[4.2:4.6] my[5.0:5.8] song[6.2:7.0]\n\n"
                    "Format: Each word followed by [start_seconds:end_seconds] in brackets\n"
                    "Timestamps should be in seconds with up to 2 decimal places"
                ),
                value=default_lyrics,
            )
            duration_slider = gr.Slider(
                minimum=120,
                maximum=230,
                value=default_duration,
                step=1,
                label="Duration (seconds)",
            )
            with gr.Tab("Style from Audio"):
                reference_audio = gr.File(
                    label="Reference Audio (.mp3, .wav)",
                    type="filepath",
                    value=default_audio,
                )
            with gr.Tab("Style from Text"):
                style_prompt = gr.Textbox(
                    label="Style Prompt",
                    lines=3,
                    placeholder="e.g., A high-energy electronic dance track with a strong bassline and euphoric synths.",
                )
            generate_button = gr.Button("Generate Song", variant="primary")
        with gr.Column():
            gr.Markdown("### Output")
            output_audio = gr.Audio(label="Generated Song")

    generate_button.click(
        fn=generate_song,
        inputs=[reference_audio, lyrics_text, style_prompt, duration_slider],
        outputs=output_audio,
        api_name="generate_song",
    )

    # Selector buttons fill the form and move the highlight.
    if examples:
        def load_example_and_update_selection(idx):
            """Return form values, the new selection index and button variants for example ``idx``."""
            audio, lyrics, duration = load_example(idx, examples)
            return [audio, lyrics, duration, idx] + update_button_styles(idx, len(examples))

        def clear_form_and_update_selection():
            """Return blank form values, selection -1 and variants highlighting "Make Your Own"."""
            audio, lyrics, duration = clear_form()
            return [audio, lyrics, duration, -1] + update_button_styles(-1, len(examples))

        # Every selection event writes the same components, in this order.
        selection_outputs = (
            [reference_audio, lyrics_text, duration_slider, selected_example]
            + example_buttons
            + [make_your_own_button]
        )
        for idx, btn in enumerate(example_buttons):
            # Default argument pins this button's index (avoids late binding).
            btn.click(
                fn=lambda idx=idx: load_example_and_update_selection(idx),
                outputs=selection_outputs,
            )
        make_your_own_button.click(
            fn=clear_form_and_update_selection,
            outputs=selection_outputs,
        )
# Make sure Gradio's temporary file directory exists before launching.
print("Creating temporary directories...")
try:
    # Honour GRADIO_TEMP_DIR when set; otherwise use "<system temp>/gradio"
    # instead of hard-coding /tmp, which does not exist on every platform.
    gradio_tmp_dir = os.environ.get("GRADIO_TEMP_DIR") or os.path.join(
        tempfile.gettempdir(), "gradio"
    )
    os.makedirs(gradio_tmp_dir, exist_ok=True)
    print("Temporary directories created successfully.")
except OSError as e:
    # Best effort: report the problem but still attempt to launch.
    print(f"Warning: Could not create temporary directories: {e}")

# share=True is ignored (with a warning) on Hugging Face Spaces and, when run
# locally, would open a public tunnel to this app, so use the defaults.
demo.queue().launch()