import os
import sys
import torch
import torchaudio
import tempfile
import json
import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

# Add the SongBloom module to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Set before the SongBloom import below so the flag is already in the
# environment when those modules are imported.
os.environ['DISABLE_FLASH_ATTN'] = "1"

from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler


def _torch_dtype(name):
    """Map the UI precision string to a torch dtype.

    'float32' maps to torch.float32; any other value maps to torch.bfloat16
    (same rule the original code applied inline).
    """
    return torch.float32 if name == 'float32' else torch.bfloat16


class SongBloomApp:
    """Gradio-facing wrapper that downloads, loads and runs a SongBloom sampler."""

    def __init__(self):
        self.model = None
        # dtype the model was built with; prompt audio is cast to this so it
        # matches the model even if the UI precision dropdown changes later.
        self.dtype = None
        self.is_loading = False

    def hf_download(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache"):
        """Download model files from Hugging Face.

        Fetches ``{model_name}.yaml``, ``{model_name}.pt``, the VAE config and
        checkpoint, and ``vocab_g2p.yaml`` into ``local_dir``.

        Returns:
            Local path of the downloaded ``{model_name}.yaml`` config. The other
            files are downloaded only so they are present in ``local_dir``.
        """
        cfg_path = hf_hub_download(
            repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir)
        for filename in (
            f"{model_name}.pt",
            "stable_audio_1920_vae.json",
            "autoencoder_music_dsp1920.ckpt",
            "vocab_g2p.yaml",
        ):
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
        return cfg_path

    def load_config(self, cfg_file, parent_dir="./"):
        """Load model configuration, registering the custom resolvers it uses.

        Resolvers are registered with ``replace=True``. Without it, OmegaConf
        raises on re-registration, so a second ``load_model`` attempt (for
        example after a failed download) could never succeed.

        Args:
            cfg_file: Path to a YAML config, or None for an empty config.
            parent_dir: Substituted for "???" by the ``dynamic_path`` resolver.

        Returns:
            The loaded OmegaConf config (an empty one if ``cfg_file`` is None).
        """
        # SECURITY NOTE(review): the "eval" resolver executes arbitrary Python
        # expressions found in the YAML. The repo id is editable in the UI, so
        # only load configs from trusted repositories.
        OmegaConf.register_new_resolver("eval", lambda x: eval(x), replace=True)
        OmegaConf.register_new_resolver(
            "concat", lambda *x: [xxx for xx in x for xxx in xx], replace=True)
        OmegaConf.register_new_resolver(
            "get_fname", lambda x: os.path.splitext(os.path.basename(x))[0], replace=True)
        OmegaConf.register_new_resolver(
            "load_yaml", lambda x: OmegaConf.load(x), replace=True)
        OmegaConf.register_new_resolver(
            "dynamic_path", lambda x: x.replace("???", parent_dir), replace=True)
        if cfg_file is None:
            return OmegaConf.create()
        # Pass the path so OmegaConf opens and closes the file itself
        # (the original passed an open() handle that was never closed).
        return OmegaConf.load(cfg_file)

    def load_model(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", dtype="float32"):
        """Load the SongBloom model.

        Args:
            repo_id: Hugging Face repository to download from.
            model_name: Base name of the config/checkpoint files in the repo.
            dtype: "float32" or any other value for bfloat16.

        Returns:
            A status message for the UI. Errors are reported as a message,
            not raised.
        """
        if self.is_loading:
            return "Model is already loading, please wait..."
        if self.model is not None:
            return "Model is already loaded!"

        self.is_loading = True
        try:
            local_dir = "./cache"
            # Download model files
            cfg_path = self.hf_download(repo_id, model_name, local_dir)
            cfg = self.load_config(cfg_path, parent_dir=local_dir)

            # Build into a local first so a failure in set_generation_params
            # does not leave a half-configured model on self (which would make
            # every later load attempt report "already loaded").
            dtype_torch = _torch_dtype(dtype)
            model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype_torch)
            model.set_generation_params(**cfg.inference)
        except Exception as e:
            return f"Error loading model: {str(e)}"
        else:
            self.model = model
            self.dtype = dtype_torch
            return "Model loaded successfully!"
        finally:
            self.is_loading = False

    def generate_song(self, lyrics, prompt_audio, n_samples=1, dtype="float32", progress=gr.Progress()):
        """Generate songs from lyrics and an audio prompt.

        Args:
            lyrics: Lyrics string in SongBloom's tag format.
            prompt_audio: File path of the prompt audio. Only the first 10
                seconds (after resampling and mono down-mix) are used.
            n_samples: Number of songs to generate. Coerced to int, since the
                value comes from a UI slider.
            dtype: Precision string from the UI. Used only as a fallback when
                the load-time dtype is unknown; the prompt is otherwise cast to
                the dtype the model was built with.
            progress: Gradio progress tracker.

        Returns:
            (list of generated .flac file paths, status message). On any
            error the list is empty and the message describes the problem.
        """
        if self.model is None:
            return [], "Please load the model first!"
        if not lyrics or not lyrics.strip():
            return [], "Please provide lyrics!"
        if prompt_audio is None:
            return [], "Please upload a prompt audio file!"

        try:
            n_samples = int(n_samples)
            sample_rate = self.model.sample_rate
            dtype_torch = self.dtype if self.dtype is not None else _torch_dtype(dtype)

            progress(0.1, desc="Processing audio prompt...")
            # Load and process the prompt audio
            prompt_wav, sr = torchaudio.load(prompt_audio)
            if sr != sample_rate:
                prompt_wav = torchaudio.functional.resample(prompt_wav, sr, sample_rate)
            # Convert to mono and limit to 10 seconds
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype_torch)
            prompt_wav = prompt_wav[..., :10 * sample_rate]

            progress(0.3, desc="Generating song...")
            output_files = []
            for i in range(n_samples):
                progress(0.3 + (i / n_samples) * 0.6, desc=f"Generating sample {i+1}/{n_samples}...")
                wav = self.model.generate(lyrics, prompt_wav)
                # Reserve a unique file name, then write after the handle is
                # closed. On Windows an open NamedTemporaryFile cannot be
                # reopened by name.
                with tempfile.NamedTemporaryFile(suffix='.flac', delete=False) as tmp_file:
                    out_path = tmp_file.name
                torchaudio.save(out_path, wav[0].cpu().float(), sample_rate)
                output_files.append(out_path)

            progress(1.0, desc="Complete!")
            return output_files, f"Successfully generated {n_samples} song(s)!"
        except Exception as e:
            return [], f"Error generating song: {str(e)}"

    def format_lyrics_example(self, example_type):
        """Return example lyrics in the expected format.

        "Chinese" returns the Chinese example; any other value returns the
        English example.
        """
        if example_type == "Chinese":
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] 风轻轻吹过古道.岁月在墙上刻下记号.梦中你笑得多甜.醒来却只剩下寂寥.繁花似锦的春天.少了你的色彩也失了妖娆 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] 月儿弯弯照九州.你是否也在仰望同一片天空.灯火阑珊处.我寻觅你的影踪.回忆如波光粼粼.荡漾在心湖的每个角落 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
        else:  # English
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] City lights flicker through the car window. Dreams pass fast where the lost ones go. Neon signs echo stories untold. I chase shadows while the night grows cold , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] Footsteps loud in the tunnel of time. Regret and hope in a crooked rhyme. You held my hand when I slipped through the dark. Lit a match and you became my spark , [bridge] We were nothing and everything too. Lost in a moment. found in the view. Of all we broke and still survived. Somehow the flame stayed alive , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"


# Initialize the app
app = SongBloomApp()


# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI wired to the module-level ``app`` instance."""
    with gr.Blocks(title="SongBloom: AI Song Generation", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎵 SongBloom: AI Song Generation

        Generate full-length songs from lyrics and audio prompts using the SongBloom model.

        **How to use:**
        1. First, load the model (this may take a few minutes)
        2. Enter your lyrics in the specified format
        3. Upload a 10-second audio prompt (WAV format, 48kHz recommended)
        4. Click "Generate Song" to create your music
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Model Loading Section
                gr.Markdown("## 🤖 Model Setup")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Model not loaded",
                    interactive=False
                )
                with gr.Row():
                    repo_id = gr.Textbox(
                        label="Repository ID",
                        value="CypressYang/SongBloom",
                        interactive=True
                    )
                    model_name = gr.Textbox(
                        label="Model Name",
                        value="songbloom_full_150s",
                        interactive=True
                    )
                dtype_choice = gr.Dropdown(
                    choices=["float32", "bfloat16"],
                    value="float32",
                    label="Precision (use bfloat16 for lower VRAM)",
                    interactive=True
                )
                load_btn = gr.Button("Load Model", variant="primary")

                # Lyrics Input Section
                gr.Markdown("## 📝 Lyrics Input")
                # Example selector
                example_type = gr.Dropdown(
                    choices=["Chinese", "English"],
                    value="Chinese",
                    label="Load Example Lyrics",
                    interactive=True
                )
                lyrics_input = gr.Textbox(
                    label="Lyrics",
                    placeholder="Enter your lyrics in the specified format...",
                    lines=8,
                    max_lines=15
                )
                load_example_btn = gr.Button("Load Example", variant="secondary")

                # Audio Upload Section
                gr.Markdown("## 🎧 Audio Prompt")
                audio_input = gr.Audio(
                    label="Upload Audio Prompt (10-second WAV file recommended)",
                    type="filepath"
                )

                # Generation Settings
                gr.Markdown("## ⚙️ Generation Settings")
                n_samples = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of samples to generate"
                )
                generate_btn = gr.Button("🎵 Generate Song", variant="primary", size="lg")

            with gr.Column(scale=1):
                # Output Section
                gr.Markdown("## 🎶 Generated Songs")
                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Ready to generate",
                    interactive=False
                )
                # gr.Gallery renders only images/videos, so the generated .flac
                # files were not playable there. gr.File accepts the list of
                # paths returned by generate_song and lets users download them.
                output_audio = gr.File(
                    label="Generated Audio Files",
                    show_label=True,
                    file_count="multiple",
                    interactive=False
                )

        # Format Instructions
        gr.Markdown("""
        ## 📋 Lyric Format Instructions

        **Structure Tags:**
        - `[intro]`, `[verse]`, `[chorus]`, `[bridge]`, `[inst]`, `[outro]`
        - Repeat tags for duration (e.g., `[intro] [intro] [intro]` for ~3 seconds)

        **Text Rules:**
        - Use `.` to separate sentences
        - Use `,` to separate sections
        - Example: `[verse] First line. Second line , [chorus] Chorus text`

        **Audio Prompt:**
        - 10-second audio file
        - WAV format preferred
        - 48kHz sample rate recommended
        - Defines the musical style/genre
        """)

        # Event handlers
        load_btn.click(
            fn=app.load_model,
            inputs=[repo_id, model_name, dtype_choice],
            outputs=[model_status]
        )
        load_example_btn.click(
            fn=app.format_lyrics_example,
            inputs=[example_type],
            outputs=[lyrics_input]
        )
        generate_btn.click(
            fn=app.generate_song,
            inputs=[lyrics_input, audio_input, n_samples, dtype_choice],
            outputs=[output_audio, generation_status]
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )