DiffuCoder / app.py
mrfakename's picture
Update app.py
2a7fe05 verified
raw
history blame
13 kB
import os
import sys
import torch
import torchaudio
import tempfile
import json
import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
import spaces
# Add the SongBloom module to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
os.environ['DISABLE_FLASH_ATTN'] = "1"
from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
class SongBloomApp:
    """Stateful holder for a SongBloom sampler, driven by the Gradio callbacks.

    Attributes are documented inline in ``__init__``. All public methods
    return plain strings / lists so they can be used directly as Gradio
    event handlers.
    """

    def __init__(self):
        # Loaded sampler; None until load_model() succeeds.
        self.model = None
        # True while a load_model() call is in progress, to reject overlapping loads.
        self.is_loading = False
        # torch dtype the current model was built with; None until a model is loaded.
        self._dtype = None

    @staticmethod
    def _resolve_dtype(dtype):
        """Map a precision name to a torch dtype.

        ``"float32"`` selects ``torch.float32``; any other value selects
        ``torch.bfloat16``.
        """
        return torch.float32 if dtype == 'float32' else torch.bfloat16

    def hf_download(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache"):
        """Download model files from Hugging Face.

        Fetches the model config, the model checkpoint, the VAE config and
        checkpoint, and the G2P vocabulary from ``repo_id`` into ``local_dir``.

        Args:
            repo_id: Hugging Face repository to download from.
            model_name: Base name of the ``.yaml`` config and ``.pt`` checkpoint.
            local_dir: Directory the files are downloaded into.

        Returns:
            The local path of the downloaded ``{model_name}.yaml`` config.
        """
        cfg_path = hf_hub_download(
            repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir)

        # The remaining files are only needed on disk (presumably referenced
        # from the YAML config -- TODO confirm); their paths are not used here.
        auxiliary_files = (
            f"{model_name}.pt",
            "stable_audio_1920_vae.json",
            "autoencoder_music_dsp1920.ckpt",
            "vocab_g2p.yaml",
        )
        for filename in auxiliary_files:
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)

        return cfg_path

    def load_config(self, cfg_file, parent_dir="./"):
        """Load model configuration.

        Registers the custom OmegaConf resolvers the config relies on and then
        parses ``cfg_file``.

        Args:
            cfg_file: Path to a YAML config, or None for an empty config.
            parent_dir: Replacement for the ``???`` placeholder handled by the
                ``dynamic_path`` resolver.

        Returns:
            The loaded OmegaConf config (an empty one when ``cfg_file`` is None).
        """
        resolvers = {
            # NOTE(security): this evaluates arbitrary Python expressions taken
            # from the config file. Because the config may be downloaded from a
            # caller-supplied repository, only load configs from trusted sources.
            "eval": lambda x: eval(x),
            "concat": lambda *x: [xxx for xx in x for xxx in xx],
            "get_fname": lambda x: os.path.splitext(os.path.basename(x))[0],
            "load_yaml": lambda x: OmegaConf.load(x),
            "dynamic_path": lambda x: x.replace("???", parent_dir),
        }
        for name, resolver in resolvers.items():
            # replace=True: without it a second call (e.g. retrying after a
            # failed load) raises because the resolver is already registered,
            # and "dynamic_path" must be rebound to the current parent_dir.
            OmegaConf.register_new_resolver(name, resolver, replace=True)

        if cfg_file is None:
            return OmegaConf.create()

        # Context manager so the file handle is closed once the YAML is parsed.
        with open(cfg_file, 'r') as cfg_stream:
            return OmegaConf.load(cfg_stream)

    def load_model(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", dtype="float32"):
        """Load the SongBloom model.

        Downloads the required files, loads the config, builds the sampler and
        applies the generation parameters from ``cfg.inference``.

        Args:
            repo_id: Hugging Face repository holding the model files.
            model_name: Base name of the config/checkpoint to load.
            dtype: ``"float32"`` for full precision; anything else selects bfloat16.

        Returns:
            A human-readable status message (success, already loaded/loading,
            or the error text). Exceptions are reported in the message rather
            than raised, since this is used as a UI callback.
        """
        if self.is_loading:
            return "Model is already loading, please wait..."
        if self.model is not None:
            return "Model is already loaded!"

        self.is_loading = True
        try:
            local_dir = "./cache"

            # Download model files
            cfg_path = self.hf_download(repo_id, model_name, local_dir)
            cfg = self.load_config(cfg_path, parent_dir=local_dir)

            # Load model
            dtype_torch = self._resolve_dtype(dtype)
            model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype_torch)
            model.set_generation_params(**cfg.inference)

            # Publish the model only once it is fully configured, so a failure
            # above never leaves a half-initialised model marked as loaded.
            self.model = model
            self._dtype = dtype_torch
            return "Model loaded successfully!"
        except Exception as e:
            return f"Error loading model: {str(e)}"
        finally:
            # Always clear the flag, whichever path we leave through.
            self.is_loading = False

    # NOTE: the gr.Progress() default is how Gradio injects a progress tracker;
    # it must stay a default argument.
    @spaces.GPU
    def generate_song(self, lyrics, prompt_audio, n_samples=1, dtype="float32", progress=gr.Progress()):
        """Generate song from lyrics and audio prompt.

        Args:
            lyrics: Lyrics text passed to the model.
            prompt_audio: Path of the prompt audio file.
            n_samples: Number of songs to generate (coerced to int).
            dtype: Precision name used for the prompt tensor when the precision
                of the loaded model is not known; otherwise the precision the
                model was loaded with takes precedence.
            progress: Gradio progress tracker.

        Returns:
            A ``(paths, message)`` tuple: the list of generated ``.flac`` file
            paths (empty on failure) and a status message.
        """
        if self.model is None:
            return [], "Please load the model first!"
        if not lyrics or not lyrics.strip():
            return [], "Please provide lyrics!"
        if prompt_audio is None:
            return [], "Please upload a prompt audio file!"

        try:
            # UI sliders may deliver the count as a float; range() needs an int.
            n_samples = int(n_samples)
            sample_rate = self.model.sample_rate

            progress(0.1, desc="Processing audio prompt...")

            # Load and process the prompt audio
            prompt_wav, sr = torchaudio.load(prompt_audio)
            if sr != sample_rate:
                prompt_wav = torchaudio.functional.resample(prompt_wav, sr, sample_rate)

            # Prefer the precision the model was actually loaded with, so a
            # later change of the caller's dtype choice cannot yield a prompt
            # tensor whose dtype disagrees with the loaded model.
            dtype_torch = self._dtype if self._dtype is not None else self._resolve_dtype(dtype)

            # Convert to mono and limit to 10 seconds
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype_torch)
            prompt_wav = prompt_wav[..., :10 * sample_rate]

            progress(0.3, desc="Generating song...")

            output_files = []

            # Generate samples
            for i in range(n_samples):
                progress(0.3 + (i / n_samples) * 0.6, desc=f"Generating sample {i+1}/{n_samples}...")

                wav = self.model.generate(lyrics, prompt_wav)

                # Reserve a temp path and close the handle before writing, so
                # the audio backend is not writing to a file we still hold open.
                with tempfile.NamedTemporaryFile(suffix='.flac', delete=False) as tmp_file:
                    out_path = tmp_file.name
                # wav[0]: first entry of the model output (presumably the batch
                # dimension -- TODO confirm); saved as float32 on the CPU.
                torchaudio.save(out_path, wav[0].cpu().float(), sample_rate)
                output_files.append(out_path)

            progress(1.0, desc="Complete!")
            return output_files, f"Successfully generated {n_samples} song(s)!"

        except Exception as e:
            # UI boundary: surface the failure as a status message.
            return [], f"Error generating song: {str(e)}"

    def format_lyrics_example(self, example_type):
        """Provide example lyrics in the correct format.

        Returns the Chinese example when ``example_type`` is ``"Chinese"``,
        otherwise the English example.
        """
        if example_type == "Chinese":
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] 风轻轻吹过古道.岁月在墙上刻下记号.梦中你笑得多甜.醒来却只剩下寂寥.繁花似锦的春天.少了你的色彩也失了妖娆 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] 月儿弯弯照九州.你是否也在仰望同一片天空.灯火阑珊处.我寻觅你的影踪.回忆如波光粼粼.荡漾在心湖的每个角落 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
        else:  # English
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] City lights flicker through the car window. Dreams pass fast where the lost ones go. Neon signs echo stories untold. I chase shadows while the night grows cold , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] Footsteps loud in the tunnel of time. Regret and hope in a crooked rhyme. You held my hand when I slipped through the dark. Lit a match and you became my spark , [bridge] We were nothing and everything too. Lost in a moment. found in the view. Of all we broke and still survived. Somehow the flame stayed alive , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
# Module-level application instance; create_interface() binds its methods
# (load_model, format_lyrics_example, generate_song) to the UI event handlers.
app = SongBloomApp()
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI and wire it to the module-level ``app``.

    The left column holds model setup, lyrics input, the audio prompt and the
    generation settings; the right column shows the generation status, the
    generated files and the lyric-format instructions.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="SongBloom: AI Song Generation", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🎵 SongBloom: AI Song Generation
Generate full-length songs from lyrics and audio prompts using the SongBloom model.
**How to use:**
1. First, load the model (this may take a few minutes)
2. Enter your lyrics in the specified format
3. Upload a 10-second audio prompt (WAV format, 48kHz recommended)
4. Click "Generate Song" to create your music
""")

        with gr.Row():
            with gr.Column(scale=1):
                # Model Loading Section
                gr.Markdown("## 🤖 Model Setup")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Model not loaded",
                    interactive=False
                )
                with gr.Row():
                    repo_id = gr.Textbox(
                        label="Repository ID",
                        value="CypressYang/SongBloom",
                        interactive=True
                    )
                    model_name = gr.Textbox(
                        label="Model Name",
                        value="songbloom_full_150s",
                        interactive=True
                    )
                dtype_choice = gr.Dropdown(
                    choices=["float32", "bfloat16"],
                    value="float32",
                    label="Precision (use bfloat16 for lower VRAM)",
                    interactive=True
                )
                load_btn = gr.Button("Load Model", variant="primary")

                # Lyrics Input Section
                gr.Markdown("## 📝 Lyrics Input")
                # Example selector
                example_type = gr.Dropdown(
                    choices=["Chinese", "English"],
                    value="Chinese",
                    label="Load Example Lyrics",
                    interactive=True
                )
                lyrics_input = gr.Textbox(
                    label="Lyrics",
                    placeholder="Enter your lyrics in the specified format...",
                    lines=8,
                    max_lines=15
                )
                load_example_btn = gr.Button("Load Example", variant="secondary")

                # Audio Upload Section
                gr.Markdown("## 🎧 Audio Prompt")
                audio_input = gr.Audio(
                    label="Upload Audio Prompt (10-second WAV file recommended)",
                    type="filepath"
                )

                # Generation Settings
                gr.Markdown("## ⚙️ Generation Settings")
                n_samples = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of samples to generate"
                )
                generate_btn = gr.Button("🎵 Generate Song", variant="primary", size="lg")

            with gr.Column(scale=1):
                # Output Section
                gr.Markdown("## 🎶 Generated Songs")
                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Ready to generate",
                    interactive=False
                )
                # The handler returns a list of .flac file paths. A Gallery is
                # an image/video component and cannot present audio files, so
                # a multi-file File output is used to expose every result.
                output_audio = gr.File(
                    label="Generated Audio Files",
                    show_label=True,
                    elem_id="gallery",
                    file_count="multiple",
                    type="filepath",
                    interactive=False
                )

                # Format Instructions
                gr.Markdown("""
## 📋 Lyric Format Instructions
**Structure Tags:**
- `[intro]`, `[verse]`, `[chorus]`, `[bridge]`, `[inst]`, `[outro]`
- Repeat tags for duration (e.g., `[intro] [intro] [intro]` for ~3 seconds)
**Text Rules:**
- Use `.` to separate sentences
- Use `,` to separate sections
- Example: `[verse] First line. Second line , [chorus] Chorus text`
**Audio Prompt:**
- 10-second audio file
- WAV format preferred
- 48kHz sample rate recommended
- Defines the musical style/genre
""")

        # Event handlers
        load_btn.click(
            fn=app.load_model,
            inputs=[repo_id, model_name, dtype_choice],
            outputs=[model_status]
        )
        load_example_btn.click(
            fn=app.format_lyrics_example,
            inputs=[example_type],
            outputs=[lyrics_input]
        )
        generate_btn.click(
            fn=app.generate_song,
            inputs=[lyrics_input, audio_input, n_samples, dtype_choice],
            outputs=[output_audio, generation_status]
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on all network
    # interfaces at port 7860, without a public share link and with errors
    # shown in the browser.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)