DiffuCoder / app.py
mrfakename's picture
init
a0e2cb7
raw
history blame
13 kB
import os
import sys
import torch
import torchaudio
import tempfile
import json
import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
# Add the SongBloom module to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
os.environ['DISABLE_FLASH_ATTN'] = "1"
from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
class SongBloomApp:
    """Wrap a lazily loaded SongBloom sampler behind UI-friendly callbacks.

    Every public method returns plain values (status strings, lists of file
    paths) instead of raising, so the methods can be used directly as
    event handlers.
    """

    def __init__(self):
        # Sampler built by load_model(); stays None until a load succeeds.
        self.model = None
        # True while load_model() is running, so a concurrent request is refused.
        self.is_loading = False

    def hf_download(
        self,
        repo_id="CypressYang/SongBloom",
        model_name="songbloom_full_150s",
        local_dir="./cache",
    ):
        """Download the model files from Hugging Face into *local_dir*.

        Args:
            repo_id: Hub repository that hosts the files.
            model_name: Base name of the ``.yaml`` config and ``.pt`` checkpoint.
            local_dir: Directory the files are downloaded into.

        Returns:
            Path of the downloaded ``{model_name}.yaml`` config file.
        """
        cfg_path = hf_hub_download(
            repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir)

        # The remaining files only need to exist in local_dir; their paths
        # are not used here (presumably the config refers to them -- confirm).
        companion_files = (
            f"{model_name}.pt",
            "stable_audio_1920_vae.json",
            "autoencoder_music_dsp1920.ckpt",
            "vocab_g2p.yaml",
        )
        for filename in companion_files:
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)

        return cfg_path

    def load_config(self, cfg_file, parent_dir="./"):
        """Register the custom OmegaConf resolvers and load *cfg_file*.

        Args:
            cfg_file: Path of the YAML config, or ``None`` for an empty config.
            parent_dir: Replacement for the ``???`` placeholder handled by
                the ``dynamic_path`` resolver.

        Returns:
            The loaded OmegaConf config (an empty one if *cfg_file* is ``None``).
        """
        resolvers = {
            # NOTE(security): this evaluates arbitrary Python embedded in the
            # config, and the config is downloaded from a caller-supplied
            # repository. Only load configs from repositories you trust.
            "eval": lambda x: eval(x),
            "concat": lambda *x: [xxx for xx in x for xxx in xx],
            "get_fname": lambda x: os.path.splitext(os.path.basename(x))[0],
            "load_yaml": lambda x: OmegaConf.load(x),
            "dynamic_path": lambda x: x.replace("???", parent_dir),
        }
        for name, resolver in resolvers.items():
            # replace=True lets this method be called more than once (e.g. a
            # retry after a failed load); without it OmegaConf raises because
            # the name is already registered. It also makes dynamic_path pick
            # up the parent_dir of the current call.
            OmegaConf.register_new_resolver(name, resolver, replace=True)

        if cfg_file is None:
            return OmegaConf.create()
        # Passing the path lets OmegaConf open and close the file itself.
        return OmegaConf.load(cfg_file)

    def load_model(
        self,
        repo_id="CypressYang/SongBloom",
        model_name="songbloom_full_150s",
        dtype="float32",
    ):
        """Download, build and configure the SongBloom model.

        Args:
            repo_id: Hub repository to download from.
            model_name: Base name of the config/checkpoint to load.
            dtype: ``"float32"`` for full precision; any other value selects
                ``torch.bfloat16``.

        Returns:
            A human-readable status message. Errors are reported in the
            message rather than raised.
        """
        if self.is_loading:
            return "Model is already loading, please wait..."
        if self.model is not None:
            return "Model is already loaded!"

        self.is_loading = True
        try:
            local_dir = "./cache"

            # Download model files and load the config from the path the
            # download actually reported.
            cfg_path = self.hf_download(repo_id, model_name, local_dir)
            cfg = self.load_config(cfg_path, parent_dir=local_dir)

            dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
            model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype_torch)
            model.set_generation_params(**cfg.inference)

            # Publish the model only once it is fully configured, so a failure
            # above cannot leave a half-initialised model marked as loaded.
            self.model = model
            return "Model loaded successfully!"
        except Exception as e:
            return f"Error loading model: {str(e)}"
        finally:
            self.is_loading = False

    def generate_song(self, lyrics, prompt_audio, n_samples=1, dtype="float32", progress=gr.Progress()):
        """Generate one or more songs from lyrics and an audio prompt.

        Args:
            lyrics: Lyrics text passed to the model.
            prompt_audio: Path of the prompt audio file.
            n_samples: Number of songs to generate.
            dtype: ``"float32"`` or anything else for ``torch.bfloat16``; the
                prompt waveform is cast to this dtype.
            progress: Progress callback, called with a fraction and a description.

        Returns:
            ``(output_files, status)`` where ``output_files`` is a list of
            paths to the generated ``.flac`` files (empty on failure) and
            ``status`` is a human-readable message.
        """
        if self.model is None:
            return [], "Please load the model first!"
        # `not lyrics` also covers None, which would otherwise crash on .strip().
        if not lyrics or not lyrics.strip():
            return [], "Please provide lyrics!"
        if prompt_audio is None:
            return [], "Please upload a prompt audio file!"

        try:
            # A numeric UI control may deliver a float; range() needs an int.
            n_samples = int(n_samples)
            sample_rate = self.model.sample_rate

            progress(0.1, desc="Processing audio prompt...")

            # Load the prompt and bring it to the model's sample rate.
            prompt_wav, sr = torchaudio.load(prompt_audio)
            if sr != sample_rate:
                prompt_wav = torchaudio.functional.resample(prompt_wav, sr, sample_rate)

            # Convert to mono and limit to 10 seconds.
            # NOTE(review): dtype is supplied by the caller independently of
            # the dtype given to load_model() -- confirm the two cannot diverge.
            dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype_torch)
            prompt_wav = prompt_wav[..., :10 * sample_rate]

            progress(0.3, desc="Generating song...")

            output_files = []
            for i in range(n_samples):
                progress(0.3 + (i / n_samples) * 0.6, desc=f"Generating sample {i+1}/{n_samples}...")
                wav = self.model.generate(lyrics, prompt_wav)

                # Reserve a unique path, then close the handle before writing:
                # saving by name while the handle is still open is not
                # portable across platforms.
                with tempfile.NamedTemporaryFile(suffix='.flac', delete=False) as tmp_file:
                    out_path = tmp_file.name
                torchaudio.save(out_path, wav[0].cpu().float(), sample_rate)
                output_files.append(out_path)

            progress(1.0, desc="Complete!")
            return output_files, f"Successfully generated {n_samples} song(s)!"
        except Exception as e:
            return [], f"Error generating song: {str(e)}"

    def format_lyrics_example(self, example_type):
        """Return example lyrics in the expected tag format.

        Args:
            example_type: ``"Chinese"`` selects the Chinese example; any
                other value selects the English one.

        Returns:
            The example lyrics as a single string.
        """
        if example_type == "Chinese":
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] 风轻轻吹过古道.岁月在墙上刻下记号.梦中你笑得多甜.醒来却只剩下寂寥.繁花似锦的春天.少了你的色彩也失了妖娆 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] 月儿弯弯照九州.你是否也在仰望同一片天空.灯火阑珊处.我寻觅你的影踪.回忆如波光粼粼.荡漾在心湖的每个角落 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
        else:  # English
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] City lights flicker through the car window. Dreams pass fast where the lost ones go. Neon signs echo stories untold. I chase shadows while the night grows cold , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] Footsteps loud in the tunnel of time. Regret and hope in a crooked rhyme. You held my hand when I slipped through the dark. Lit a match and you became my spark , [bridge] We were nothing and everything too. Lost in a moment. found in the view. Of all we broke and still survived. Somehow the flame stayed alive , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
# Module-level app instance; its bound methods (load_model, generate_song,
# format_lyrics_example) are wired to the Gradio event handlers in
# create_interface().
app = SongBloomApp()
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI and wire it to the module-level ``app``.

    The page has two columns: the left one holds model setup, lyrics input,
    the audio prompt and generation settings; the right one holds the
    generation status, the generated files and the lyric format help.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="SongBloom: AI Song Generation", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🎵 SongBloom: AI Song Generation
Generate full-length songs from lyrics and audio prompts using the SongBloom model.
**How to use:**
1. First, load the model (this may take a few minutes)
2. Enter your lyrics in the specified format
3. Upload a 10-second audio prompt (WAV format, 48kHz recommended)
4. Click "Generate Song" to create your music
""")

        with gr.Row():
            with gr.Column(scale=1):
                # Model Loading Section
                gr.Markdown("## 🤖 Model Setup")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Model not loaded",
                    interactive=False
                )
                with gr.Row():
                    repo_id = gr.Textbox(
                        label="Repository ID",
                        value="CypressYang/SongBloom",
                        interactive=True
                    )
                    model_name = gr.Textbox(
                        label="Model Name",
                        value="songbloom_full_150s",
                        interactive=True
                    )
                dtype_choice = gr.Dropdown(
                    choices=["float32", "bfloat16"],
                    value="float32",
                    label="Precision (use bfloat16 for lower VRAM)",
                    interactive=True
                )
                load_btn = gr.Button("Load Model", variant="primary")

                # Lyrics Input Section
                gr.Markdown("## 📝 Lyrics Input")
                # Example selector
                example_type = gr.Dropdown(
                    choices=["Chinese", "English"],
                    value="Chinese",
                    label="Load Example Lyrics",
                    interactive=True
                )
                lyrics_input = gr.Textbox(
                    label="Lyrics",
                    placeholder="Enter your lyrics in the specified format...",
                    lines=8,
                    max_lines=15
                )
                load_example_btn = gr.Button("Load Example", variant="secondary")

                # Audio Upload Section
                gr.Markdown("## 🎧 Audio Prompt")
                audio_input = gr.Audio(
                    label="Upload Audio Prompt (10-second WAV file recommended)",
                    type="filepath"
                )

                # Generation Settings
                gr.Markdown("## ⚙️ Generation Settings")
                n_samples = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of samples to generate"
                )
                generate_btn = gr.Button("🎵 Generate Song", variant="primary", size="lg")

            with gr.Column(scale=1):
                # Output Section
                gr.Markdown("## 🎶 Generated Songs")
                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Ready to generate",
                    interactive=False
                )
                # The generator returns a list of .flac file paths. gr.Gallery
                # is an image/video grid and cannot present audio files, so a
                # read-only multi-file gr.File lists them for download instead.
                output_audio = gr.File(
                    label="Generated Audio Files",
                    show_label=True,
                    elem_id="gallery",
                    file_count="multiple",
                    type="filepath",
                    interactive=False
                )

                # Format Instructions
                gr.Markdown("""
## 📋 Lyric Format Instructions
**Structure Tags:**
- `[intro]`, `[verse]`, `[chorus]`, `[bridge]`, `[inst]`, `[outro]`
- Repeat tags for duration (e.g., `[intro] [intro] [intro]` for ~3 seconds)
**Text Rules:**
- Use `.` to separate sentences
- Use `,` to separate sections
- Example: `[verse] First line. Second line , [chorus] Chorus text`
**Audio Prompt:**
- 10-second audio file
- WAV format preferred
- 48kHz sample rate recommended
- Defines the musical style/genre
""")

        # Event handlers
        load_btn.click(
            fn=app.load_model,
            inputs=[repo_id, model_name, dtype_choice],
            outputs=[model_status]
        )
        load_example_btn.click(
            fn=app.format_lyrics_example,
            inputs=[example_type],
            outputs=[lyrics_input]
        )
        generate_btn.click(
            fn=app.generate_song,
            inputs=[lyrics_input, audio_input, n_samples, dtype_choice],
            outputs=[output_audio, generation_status]
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on all network
    # interfaces at port 7860, with no share link and show_error enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)