# Spaces: Running on Zero
import json
import os
import sys
import tempfile

import gradio as gr
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

# Add the SongBloom module to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# NOTE(review): set before the SongBloom import below; presumably the package
# reads this flag at import time, so keep this ordering.
os.environ['DISABLE_FLASH_ATTN'] = "1"

from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
class SongBloomApp:
    """Stateful wrapper around a SongBloom sampler, used as the Gradio backend.

    The methods return plain strings / lists so they can be wired directly to
    Gradio components as event handlers.
    """

    # Example lyrics keyed by the language name shown in the UI dropdown.
    _EXAMPLE_LYRICS = {
        "Chinese": "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] 风轻轻吹过古道.岁月在墙上刻下记号.梦中你笑得多甜.醒来却只剩下寂寥.繁花似锦的春天.少了你的色彩也失了妖娆 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] 月儿弯弯照九州.你是否也在仰望同一片天空.灯火阑珊处.我寻觅你的影踪.回忆如波光粼粼.荡漾在心湖的每个角落 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]",
        "English": "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] City lights flicker through the car window. Dreams pass fast where the lost ones go. Neon signs echo stories untold. I chase shadows while the night grows cold , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] Footsteps loud in the tunnel of time. Regret and hope in a crooked rhyme. You held my hand when I slipped through the dark. Lit a match and you became my spark , [bridge] We were nothing and everything too. Lost in a moment. found in the view. Of all we broke and still survived. Somehow the flame stayed alive , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]",
    }

    def __init__(self):
        # Loaded sampler, or None until load_model() completes successfully.
        self.model = None
        # True while load_model() is running; used to reject a second request.
        self.is_loading = False

    @staticmethod
    def _resolve_dtype(dtype):
        """Map the UI precision name to a torch dtype (anything but 'float32' means bfloat16)."""
        return torch.float32 if dtype == 'float32' else torch.bfloat16

    def hf_download(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache"):
        """Download model files from Hugging Face.

        Args:
            repo_id: Hub repository to download from.
            model_name: Base name of the ``.yaml`` config and ``.pt`` checkpoint.
            local_dir: Directory the files are downloaded into.

        Returns:
            The local path of the downloaded ``{model_name}.yaml`` config.
        """
        cfg_path = hf_hub_download(
            repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir)
        # The remaining files only need to exist in local_dir; their returned
        # paths are not used by this class.
        companion_files = (
            f"{model_name}.pt",
            "stable_audio_1920_vae.json",
            "autoencoder_music_dsp1920.ckpt",
            "vocab_g2p.yaml",
        )
        for filename in companion_files:
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
        return cfg_path

    def load_config(self, cfg_file, parent_dir="./"):
        """Load model configuration.

        Registers the custom OmegaConf resolvers used by the config and then
        loads ``cfg_file``.

        Args:
            cfg_file: Path to a YAML config, or None for an empty config.
            parent_dir: Replacement for the ``???`` placeholder handled by the
                ``dynamic_path`` resolver.

        Returns:
            The loaded OmegaConf config (empty when ``cfg_file`` is None).
        """
        resolvers = {
            # SECURITY NOTE: "eval" executes arbitrary Python expressions found
            # in the YAML. The config is downloaded from a repository id the
            # user can edit in the UI, so only point this at trusted repos.
            "eval": lambda x: eval(x),
            "concat": lambda *x: [xxx for xx in x for xxx in xx],
            "get_fname": lambda x: os.path.splitext(os.path.basename(x))[0],
            "load_yaml": lambda x: OmegaConf.load(x),
            "dynamic_path": lambda x: x.replace("???", parent_dir),
        }
        for name, resolver in resolvers.items():
            # replace=True: without it a second call (e.g. retrying after a
            # failed load) raises because the resolver is already registered,
            # and "dynamic_path" would keep a stale parent_dir.
            OmegaConf.register_new_resolver(name, resolver, replace=True)

        if cfg_file is None:
            return OmegaConf.create()
        # Pass the path so OmegaConf opens and closes the file itself.
        return OmegaConf.load(cfg_file)

    def load_model(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", dtype="float32"):
        """Load the SongBloom model.

        Args:
            repo_id: Hub repository holding the model files.
            model_name: Base name of the config / checkpoint to load.
            dtype: 'float32' for full precision; any other value selects bfloat16.

        Returns:
            A human-readable status message; errors are reported in the
            message rather than raised.
        """
        if self.is_loading:
            return "Model is already loading, please wait..."
        if self.model is not None:
            return "Model is already loaded!"

        self.is_loading = True
        try:
            local_dir = "./cache"
            # Download model files
            cfg_path = self.hf_download(repo_id, model_name, local_dir)
            cfg = self.load_config(cfg_path, parent_dir=local_dir)
            # Load model
            model = SongBloom_Sampler.build_from_trainer(
                cfg, strict=True, dtype=self._resolve_dtype(dtype))
            model.set_generation_params(**cfg.inference)
            # Publish only a fully configured model, so a failure above does
            # not leave a half-initialised model reported as "already loaded".
            self.model = model
            return "Model loaded successfully!"
        except Exception as e:
            return f"Error loading model: {str(e)}"
        finally:
            self.is_loading = False

    def generate_song(self, lyrics, prompt_audio, n_samples=1, dtype="float32", progress=gr.Progress()):
        """Generate song from lyrics and audio prompt.

        Args:
            lyrics: Lyrics text in the SongBloom tag format.
            prompt_audio: Path to the prompt audio file.
            n_samples: Number of songs to generate (coerced to an int >= 1).
            dtype: 'float32' or anything else for bfloat16; applied to the prompt.
            progress: Gradio progress tracker.

        Returns:
            ``(paths, status)``: a list of generated ``.flac`` file paths (empty
            on failure) and a status message.
        """
        if self.model is None:
            return [], "Please load the model first!"
        if not lyrics or not lyrics.strip():
            return [], "Please provide lyrics!"
        if prompt_audio is None:
            return [], "Please upload a prompt audio file!"

        output_files = []
        try:
            # UI sliders may deliver floats; range() needs an int.
            n_samples = max(1, int(n_samples))
            sample_rate = self.model.sample_rate

            progress(0.1, desc="Processing audio prompt...")
            # Load and process the prompt audio
            prompt_wav, sr = torchaudio.load(prompt_audio)
            if sr != sample_rate:
                prompt_wav = torchaudio.functional.resample(prompt_wav, sr, sample_rate)
            # Convert to mono and limit to 10 seconds
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(self._resolve_dtype(dtype))
            prompt_wav = prompt_wav[..., :10 * sample_rate]

            progress(0.3, desc="Generating song...")
            # Generate samples
            for i in range(n_samples):
                progress(0.3 + (i / n_samples) * 0.6, desc=f"Generating sample {i+1}/{n_samples}...")
                wav = self.model.generate(lyrics, prompt_wav)
                # Reserve a unique path and close the handle before writing,
                # so the file is not opened twice at the same time.
                with tempfile.NamedTemporaryFile(suffix='.flac', delete=False) as tmp_file:
                    out_path = tmp_file.name
                output_files.append(out_path)
                torchaudio.save(out_path, wav[0].cpu().float(), sample_rate)

            progress(1.0, desc="Complete!")
            return output_files, f"Successfully generated {n_samples} song(s)!"
        except Exception as e:
            # The caller receives no paths on failure, so remove what was
            # already written instead of leaking temp files.
            for path in output_files:
                try:
                    os.remove(path)
                except OSError:
                    pass
            return [], f"Error generating song: {str(e)}"

    def format_lyrics_example(self, example_type):
        """Provide example lyrics in the correct format.

        Args:
            example_type: "Chinese" for the Chinese example; any other value
                returns the English example.

        Returns:
            The example lyrics string.
        """
        return self._EXAMPLE_LYRICS.get(example_type, self._EXAMPLE_LYRICS["English"])
# Initialize the app: a single module-level instance whose methods are used
# as the Gradio event handlers in create_interface().
app = SongBloomApp()
# Create Gradio interface | |
def create_interface():
    """Build the Gradio Blocks UI for SongBloom and wire it to ``app``.

    The left column holds model setup, lyrics, audio prompt and generation
    settings; the right column shows the generation status, the generated
    files and the lyric-format instructions.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks(title="SongBloom: AI Song Generation", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎵 SongBloom: AI Song Generation
        Generate full-length songs from lyrics and audio prompts using the SongBloom model.
        **How to use:**
        1. First, load the model (this may take a few minutes)
        2. Enter your lyrics in the specified format
        3. Upload a 10-second audio prompt (WAV format, 48kHz recommended)
        4. Click "Generate Song" to create your music
        """)
        with gr.Row():
            with gr.Column(scale=1):
                # Model Loading Section
                gr.Markdown("## 🤖 Model Setup")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Model not loaded",
                    interactive=False
                )
                with gr.Row():
                    repo_id = gr.Textbox(
                        label="Repository ID",
                        value="CypressYang/SongBloom",
                        interactive=True
                    )
                    model_name = gr.Textbox(
                        label="Model Name",
                        value="songbloom_full_150s",
                        interactive=True
                    )
                dtype_choice = gr.Dropdown(
                    choices=["float32", "bfloat16"],
                    value="float32",
                    label="Precision (use bfloat16 for lower VRAM)",
                    interactive=True
                )
                load_btn = gr.Button("Load Model", variant="primary")
                # Lyrics Input Section
                gr.Markdown("## 📝 Lyrics Input")
                # Example selector
                example_type = gr.Dropdown(
                    choices=["Chinese", "English"],
                    value="Chinese",
                    label="Load Example Lyrics",
                    interactive=True
                )
                lyrics_input = gr.Textbox(
                    label="Lyrics",
                    placeholder="Enter your lyrics in the specified format...",
                    lines=8,
                    max_lines=15
                )
                load_example_btn = gr.Button("Load Example", variant="secondary")
                # Audio Upload Section
                gr.Markdown("## 🎧 Audio Prompt")
                audio_input = gr.Audio(
                    label="Upload Audio Prompt (10-second WAV file recommended)",
                    type="filepath"
                )
                # Generation Settings
                gr.Markdown("## ⚙️ Generation Settings")
                n_samples = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of samples to generate"
                )
                generate_btn = gr.Button("🎵 Generate Song", variant="primary", size="lg")
            with gr.Column(scale=1):
                # Output Section
                gr.Markdown("## 🎶 Generated Songs")
                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Ready to generate",
                    interactive=False
                )
                # The handler returns a list of .flac paths. gr.Gallery is an
                # image/video component and cannot present audio, so the
                # results are exposed as a multi-file output instead.
                output_audio = gr.File(
                    label="Generated Audio Files",
                    show_label=True,
                    elem_id="gallery",  # kept so any existing CSS/JS hook still matches
                    file_count="multiple",
                    interactive=False
                )
                # Format Instructions
                gr.Markdown("""
                ## 📋 Lyric Format Instructions
                **Structure Tags:**
                - `[intro]`, `[verse]`, `[chorus]`, `[bridge]`, `[inst]`, `[outro]`
                - Repeat tags for duration (e.g., `[intro] [intro] [intro]` for ~3 seconds)
                **Text Rules:**
                - Use `.` to separate sentences
                - Use `,` to separate sections
                - Example: `[verse] First line. Second line , [chorus] Chorus text`
                **Audio Prompt:**
                - 10-second audio file
                - WAV format preferred
                - 48kHz sample rate recommended
                - Defines the musical style/genre
                """)
        # Event handlers (must be registered inside the Blocks context)
        load_btn.click(
            fn=app.load_model,
            inputs=[repo_id, model_name, dtype_choice],
            outputs=[model_status]
        )
        load_example_btn.click(
            fn=app.format_lyrics_example,
            inputs=[example_type],
            outputs=[lyrics_input]
        )
        generate_btn.click(
            fn=app.generate_song,
            inputs=[lyrics_input, audio_input, n_samples, dtype_choice],
            outputs=[output_audio, generation_status]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on all interfaces at
    # port 7860 without a public share link, surfacing handler errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)