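"""Gradio demo for SongBloom: generate a full-length song from structured
lyrics and a ~10-second audio prompt that defines the musical style."""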
import os
import sys
import torch
import torchaudio
import tempfile
import json
import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
import spaces
# Add the SongBloom module to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
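# Disable flash attention (presumably so inference also works where flash-attn is unavailable)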
os.environ['DISABLE_FLASH_ATTN'] = "1"
from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
class SongBloomApp:
    def __init__(self):
        self.model = None
        self.is_loading = False
    def hf_download(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache"):
        """Download model files from Hugging Face"""
        cfg_path = hf_hub_download(
            repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir)
        # Only the main config path is returned; the remaining checkpoints just need
        # to be present in local_dir, where the loaded config expects to find them.
        ckpt_path = hf_hub_download(
            repo_id=repo_id, filename=f"{model_name}.pt", local_dir=local_dir)
        vae_cfg_path = hf_hub_download(
            repo_id=repo_id, filename="stable_audio_1920_vae.json", local_dir=local_dir)
        vae_ckpt_path = hf_hub_download(
            repo_id=repo_id, filename="autoencoder_music_dsp1920.ckpt", local_dir=local_dir)
        g2p_path = hf_hub_download(
            repo_id=repo_id, filename="vocab_g2p.yaml", local_dir=local_dir)
        return cfg_path
    def load_config(self, cfg_file, parent_dir="./"):
        """Load model configuration"""
        # replace=True lets the resolvers be re-registered if loading is retried
        OmegaConf.register_new_resolver("eval", lambda x: eval(x), replace=True)
        OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx], replace=True)
        OmegaConf.register_new_resolver("get_fname", lambda x: os.path.splitext(os.path.basename(x))[0], replace=True)
        OmegaConf.register_new_resolver("load_yaml", lambda x: OmegaConf.load(x), replace=True)
        # "???" placeholders in the config are rewritten to point at the download directory
        OmegaConf.register_new_resolver("dynamic_path", lambda x: x.replace("???", parent_dir), replace=True)
        file_cfg = OmegaConf.load(cfg_file) if cfg_file is not None else OmegaConf.create()
        return file_cfg
    def load_model(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", dtype="float32"):
        """Load the SongBloom model"""
        if self.is_loading:
            return "Model is already loading, please wait..."
        if self.model is not None:
            return "Model is already loaded!"

        try:
            self.is_loading = True
            local_dir = "./cache"

            # Download model files
            cfg_path = self.hf_download(repo_id, model_name, local_dir)
            cfg = self.load_config(f"{local_dir}/{model_name}.yaml", parent_dir=local_dir)

            # Load model
            dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
            self.model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype_torch)
            self.model.set_generation_params(**cfg.inference)

            self.is_loading = False
            return "Model loaded successfully!"
        except Exception as e:
            self.is_loading = False
            return f"Error loading model: {str(e)}"
    @spaces.GPU
    def generate_song(self, lyrics, prompt_audio, n_samples=1, dtype="float32", progress=gr.Progress()):
        """Generate song from lyrics and audio prompt"""
        if self.model is None:
            return [], "Please load the model first!"
        if not lyrics.strip():
            return [], "Please provide lyrics!"
        if prompt_audio is None:
            return [], "Please upload a prompt audio file!"

        try:
            n_samples = int(n_samples)  # slider values may arrive as floats
            progress(0.1, desc="Processing audio prompt...")

            # Load and process the prompt audio
            prompt_wav, sr = torchaudio.load(prompt_audio)
            if sr != self.model.sample_rate:
                prompt_wav = torchaudio.functional.resample(prompt_wav, sr, self.model.sample_rate)

            # Convert to mono, match the model dtype, and limit to 10 seconds
            dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype_torch)
            prompt_wav = prompt_wav[..., :10 * self.model.sample_rate]

            progress(0.3, desc="Generating song...")
            output_files = []

            # Generate samples
            for i in range(n_samples):
                progress(0.3 + (i / n_samples) * 0.6, desc=f"Generating sample {i+1}/{n_samples}...")
                wav = self.model.generate(lyrics, prompt_wav)

                # Save each sample to a temporary FLAC file
                with tempfile.NamedTemporaryFile(suffix='.flac', delete=False) as tmp_file:
                    torchaudio.save(tmp_file.name, wav[0].cpu().float(), self.model.sample_rate)
                    output_files.append(tmp_file.name)

            progress(1.0, desc="Complete!")
            return output_files, f"Successfully generated {n_samples} song(s)!"
        except Exception as e:
            return [], f"Error generating song: {str(e)}"
    def format_lyrics_example(self, example_type):
        """Provide example lyrics in the correct format"""
        if example_type == "Chinese":
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] 风轻轻吹过古道.岁月在墙上刻下记号.梦中你笑得多甜.醒来却只剩下寂寥.繁花似锦的春天.少了你的色彩也失了妖娆 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] 月儿弯弯照九州.你是否也在仰望同一片天空.灯火阑珊处.我寻觅你的影踪.回忆如波光粼粼.荡漾在心湖的每个角落 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
        else:  # English
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] City lights flicker through the car window. Dreams pass fast where the lost ones go. Neon signs echo stories untold. I chase shadows while the night grows cold , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] Footsteps loud in the tunnel of time. Regret and hope in a crooked rhyme. You held my hand when I slipped through the dark. Lit a match and you became my spark , [bridge] We were nothing and everything too. Lost in a moment. found in the view. Of all we broke and still survived. Somehow the flame stayed alive , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
# Initialize the app
app = SongBloomApp()
# Create Gradio interface
def create_interface():
    with gr.Blocks(title="SongBloom: AI Song Generation", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎵 SongBloom: AI Song Generation

        Generate full-length songs from lyrics and audio prompts using the SongBloom model.

        **How to use:**
        1. First, load the model (this may take a few minutes)
        2. Enter your lyrics in the specified format
        3. Upload a 10-second audio prompt (WAV format, 48kHz recommended)
        4. Click "Generate Song" to create your music
        """)
        with gr.Row():
            with gr.Column(scale=1):
                # Model Loading Section
                gr.Markdown("## 🤖 Model Setup")

                model_status = gr.Textbox(
                    label="Model Status",
                    value="Model not loaded",
                    interactive=False
                )

                with gr.Row():
                    repo_id = gr.Textbox(
                        label="Repository ID",
                        value="CypressYang/SongBloom",
                        interactive=True
                    )
                    model_name = gr.Textbox(
                        label="Model Name",
                        value="songbloom_full_150s",
                        interactive=True
                    )

                dtype_choice = gr.Dropdown(
                    choices=["float32", "bfloat16"],
                    value="float32",
                    label="Precision (use bfloat16 for lower VRAM)",
                    interactive=True
                )

                load_btn = gr.Button("Load Model", variant="primary")

                # Lyrics Input Section
                gr.Markdown("## 📝 Lyrics Input")

                # Example selector
                example_type = gr.Dropdown(
                    choices=["Chinese", "English"],
                    value="Chinese",
                    label="Load Example Lyrics",
                    interactive=True
                )

                lyrics_input = gr.Textbox(
                    label="Lyrics",
                    placeholder="Enter your lyrics in the specified format...",
                    lines=8,
                    max_lines=15
                )

                load_example_btn = gr.Button("Load Example", variant="secondary")

                # Audio Upload Section
                gr.Markdown("## 🎧 Audio Prompt")

                audio_input = gr.Audio(
                    label="Upload Audio Prompt (10-second WAV file recommended)",
                    type="filepath"
                )

                # Generation Settings
                gr.Markdown("## ⚙️ Generation Settings")

                n_samples = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of samples to generate"
                )

                generate_btn = gr.Button("🎵 Generate Song", variant="primary", size="lg")
            with gr.Column(scale=1):
                # Output Section
                gr.Markdown("## 🎶 Generated Songs")

                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Ready to generate",
                    interactive=False
                )
                # Generated songs are offered for playback/download; gr.Gallery only
                # renders images and video, so a file list component is used instead.
                output_audio = gr.Files(
                    label="Generated Audio Files",
                    elem_id="gallery"
                )
                # Format Instructions
                gr.Markdown("""
                ## 📋 Lyric Format Instructions

                **Structure Tags:**
                - `[intro]`, `[verse]`, `[chorus]`, `[bridge]`, `[inst]`, `[outro]`
                - Repeat tags for duration (e.g., `[intro] [intro] [intro]` for ~3 seconds)

                **Text Rules:**
                - Use `.` to separate sentences
                - Use `,` to separate sections
                - Example: `[verse] First line. Second line , [chorus] Chorus text`

                **Audio Prompt:**
                - 10-second audio file
                - WAV format preferred
                - 48kHz sample rate recommended
                - Defines the musical style/genre
                """)
        # Event handlers
        load_btn.click(
            fn=app.load_model,
            inputs=[repo_id, model_name, dtype_choice],
            outputs=[model_status]
        )

        load_example_btn.click(
            fn=app.format_lyrics_example,
            inputs=[example_type],
            outputs=[lyrics_input]
        )

        generate_btn.click(
            fn=app.generate_song,
            inputs=[lyrics_input, audio_input, n_samples, dtype_choice],
            outputs=[output_audio, generation_status]
        )

    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )