JAM / app.py
renhang
Merge branch 'working' into pr/2
68f7846
raw
history blame
6.98 kB
# NOTE: `spaces` is kept first, as in the original file. Hugging Face ZeroGPU
# asks for it to be imported before any CUDA-related package.
import spaces
import ast
import json
import os
import tempfile
from pathlib import Path

import gradio as gr

from model import Jamify
from utils import json_to_text, text_to_json
def _load_jamify_model():
    """Construct the shared Jamify model, reporting progress on stdout."""
    print("Initializing Jamify model...")
    model = Jamify()
    print("Jamify model ready.")
    return model


# Build the model a single time at import; generate_song reuses this instance.
jamify_model = _load_jamify_model()
def parse(file, output_path='temp.json'):
    """Parse a word-timing text file and dump the result as JSON.

    Each non-blank line of *file* is expected to look like ``start end word``
    (whitespace-separated, with numeric start/end). Blank lines are skipped.
    The parsed entries are written to *output_path* as a JSON list of
    ``{"word": ..., "start": ..., "end": ...}`` dicts.

    Args:
        file: Path of the text file to read (decoded as UTF-8).
        output_path: Destination of the JSON list. Defaults to ``'temp.json'``
            in the current working directory, matching the previous behaviour.

    Returns:
        The raw text content of *file*.

    Raises:
        ValueError: If a non-blank line has fewer than three fields or a
            non-numeric start/end value.
    """
    with open(file, 'r', encoding='utf-8') as f:
        content = f.read()

    # Parse the text that was already read. Iterating ``f`` after ``f.read()``
    # yields nothing because the file is exhausted, which previously made this
    # function always write an empty list.
    parsed_data = []
    for raw_line in content.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        # maxsplit=2 tolerates repeated whitespace between fields and keeps
        # any spaces that appear inside the word itself.
        start, end, word = line.split(maxsplit=2)
        parsed_data.append({"word": word, "start": float(start), "end": float(end)})

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(parsed_data, f)
    return content
@spaces.GPU(duration=100)
def generate_song(reference_audio, lyrics_text, style_prompt, duration):
    """Generate a song with the shared Jamify model.

    Args:
        reference_audio: Path to an uploaded reference audio file. ``None`` or
            ``""`` means no audio was supplied; both are normalised to ``None``.
        lyrics_text: Lyrics as text, converted with ``text_to_json``
            (presumably the ``word[start:end] ...`` format shown in the UI
            placeholder; TODO confirm against ``utils.text_to_json``).
        style_prompt: Text description of the desired style, passed through.
        duration: Target length in seconds, passed as ``duration_sec``.

    Returns:
        Whatever ``jamify_model.predict`` returns. The UI routes it to a
        ``gr.Audio`` output, so presumably a path to the generated audio
        (TODO confirm).
    """
    # Empty string and None both mean "no reference audio"; the model gets None.
    reference_audio = reference_audio or None

    # Convert the text lyrics into the JSON structure the model consumes.
    lyrics_json = text_to_json(lyrics_text)

    # The model reads lyrics from a JSON *file*, so stage them in a private
    # temporary file. Unlike a fixed name such as 'temp.json', this cannot be
    # clobbered by a concurrent request. The file is created before the try so
    # the finally clause can always remove it, even if serialisation fails.
    fd, lyrics_file = tempfile.mkstemp(suffix='.json')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(lyrics_json, f, indent=2)
        return jamify_model.predict(
            reference_audio_path=reference_audio,
            lyrics_json_path=lyrics_file,
            style_prompt=style_prompt,
            duration_sec=duration,
        )
    finally:
        # Clean up the staged lyrics whether generation succeeded or not.
        if os.path.exists(lyrics_file):
            os.unlink(lyrics_file)
# Load and cache examples
def load_examples(examples_file="examples/input.json"):
    """Load the demo examples and pre-compute their lyrics as text.

    Args:
        examples_file: Path to a JSON list of example records. Each record may
            carry ``id`` (default ``''``), ``audio_path``, ``lrc_path`` (read
            with ``json.load``, so a JSON lyrics file despite the name) and
            ``duration`` (default 120). Defaults to ``"examples/input.json"``.

    Returns:
        A list of dicts with keys ``id``, ``audio_path`` (``None`` when that
        file does not exist), ``lyrics_text`` (``""`` when the lyrics file is
        missing or fails to load) and ``duration``. The list is empty when
        *examples_file* does not exist.
    """
    examples = []
    if not os.path.exists(examples_file):
        return examples

    print("Loading and caching examples...")
    with open(examples_file, 'r', encoding='utf-8') as f:
        examples_data = json.load(f)

    for example in examples_data:
        example_id = example.get('id', '')
        audio_path = example.get('audio_path', '')
        lrc_path = example.get('lrc_path', '')
        duration = example.get('duration', 120)

        # Convert the lyrics to text once here, so selecting an example later
        # is just a lookup.
        lyrics_text = ""
        if os.path.exists(lrc_path):
            try:
                with open(lrc_path, 'r', encoding='utf-8') as f:
                    lyrics_json = json.load(f)
                lyrics_text = json_to_text(lyrics_json)
                print(f"Cached example {example_id}: {len(lyrics_text)} chars")
            except Exception as e:
                # Best effort: one broken example must not stop the app from
                # starting; it is kept with empty lyrics and the error is logged.
                print(f"Error loading lyrics from {lrc_path}: {e}")

        examples.append({
            'id': example_id,
            'audio_path': audio_path if os.path.exists(audio_path) else None,
            'lyrics_text': lyrics_text,
            'duration': duration,
        })

    print(f"Loaded {len(examples)} cached examples")
    return examples
def load_example(example_idx, examples):
    """Return ``(audio_path, lyrics_text, duration)`` for one cached example.

    Args:
        example_idx: Position of the wanted example in *examples*.
        examples: Example dicts, as produced by ``load_examples``.

    Returns:
        The chosen example's audio path, lyrics text and duration, or the
        fallback ``(None, "", 120)`` when *example_idx* is out of range
        (negative indices count as out of range).
    """
    if not 0 <= example_idx < len(examples):
        return None, "", 120
    chosen = examples[example_idx]
    return tuple(chosen[key] for key in ('audio_path', 'lyrics_text', 'duration'))
# Load examples at startup: build the example list once, before the UI is
# declared, because the "Sample Examples" buttons are created from it.
examples = load_examples()
# Gradio interface
def process_text_file(file_obj):
    """Return the text content of an uploaded file.

    Args:
        file_obj: A path (``str``/``bytes``/``os.PathLike``, which is what
            ``gr.File(type="filepath")`` provides), an object carrying the
            path in its ``.name`` attribute, or ``None``.

    Returns:
        The file's content decoded as UTF-8, or ``"No file uploaded."`` when
        *file_obj* is ``None``.
    """
    if file_obj is None:
        return "No file uploaded."
    # Plain paths have no usable ``.name`` (``Path.name`` is only the basename),
    # so use them directly; otherwise fall back to the wrapper's ``.name``.
    if isinstance(file_obj, (str, bytes, os.PathLike)):
        path = file_obj
    else:
        path = file_obj.name
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
with gr.Blocks() as demo:
    gr.Markdown("# Jamify: Music Generation from Lyrics and Style")
    gr.Markdown("Provide your lyrics, a style reference (either an audio file or a text prompt), and a desired duration to generate a song.")

    # Sample buttons section: one button per cached example. The list starts
    # empty so the wiring loop at the bottom needs no separate `if examples:`.
    example_buttons = []
    if examples:
        gr.Markdown("### Sample Examples")
        with gr.Row():
            for example in examples:
                example_buttons.append(
                    gr.Button(f"Example {example['id']}", variant="secondary", size="sm")
                )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Inputs")
            lyrics_text = gr.Textbox(
                label="Lyrics",
                lines=10,
                placeholder="Enter lyrics in format: word[start:end] word[start:end]...\nExample: It's[4.96:5.52] a[5.52:5.84] long[5.84:6.16] way[6.16:6.48]...",
                value="",
            )
            # A range slider only allows minimum + k*step. With minimum=5 and
            # step=30, the default 120 and the maximum 230 were both off-grid
            # (only 5, 35, ..., 215 were selectable). step=5 keeps both valid.
            duration_slider = gr.Slider(minimum=5, maximum=230, value=120, step=5, label="Duration (seconds)")
            with gr.Tab("Style from Audio"):
                reference_audio = gr.File(label="Reference Audio (.mp3, .wav)", type="filepath")
            with gr.Tab("Style from Text"):
                style_prompt = gr.Textbox(label="Style Prompt", lines=3, placeholder="e.g., A high-energy electronic dance track with a strong bassline and euphoric synths.")
            generate_button = gr.Button("Generate Song", variant="primary")
        with gr.Column():
            gr.Markdown("### Output")
            output_audio = gr.Audio(label="Generated Song")

    generate_button.click(
        fn=generate_song,
        inputs=[reference_audio, lyrics_text, style_prompt, duration_slider],
        outputs=output_audio,
        api_name="generate_song",
    )

    # Clicking an example fills the inputs. `idx=i` binds the current index
    # at definition time; a bare closure over `i` would only see its last value.
    for i, button in enumerate(example_buttons):
        button.click(
            fn=lambda idx=i: load_example(idx, examples),
            outputs=[reference_audio, lyrics_text, duration_slider],
        )
# Create Gradio's temporary directory before launching. Gradio uses
# GRADIO_TEMP_DIR when set, otherwise "<system temp dir>/gradio". Computing
# the path the same way avoids hard-coding "/tmp", which is not portable.
print("Creating temporary directories...")
gradio_temp_dir = os.environ.get("GRADIO_TEMP_DIR") or os.path.join(tempfile.gettempdir(), "gradio")
try:
    os.makedirs(gradio_temp_dir, exist_ok=True)
    print("Temporary directories created successfully.")
except OSError as e:
    # Best effort: launching may still succeed, so warn rather than abort.
    print(f"Warning: Could not create temporary directories: {e}")

# Only start the server when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    demo.queue().launch()