File size: 13,030 Bytes
a0e2cb7
 
 
 
 
 
 
 
 
2a7fe05
a0e2cb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a7fe05
a0e2cb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
import os
import sys
import torch
import torchaudio
import tempfile
import json
import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
import spaces

# Add the SongBloom module to the path: the directory containing this file is
# put first on sys.path so the bundled `SongBloom` package is importable
# regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Must stay BEFORE the SongBloom import below so the flag is already set when
# the package is imported. Presumably it makes SongBloom avoid flash-attention
# kernels — NOTE(review): confirm against SongBloom's source.
os.environ['DISABLE_FLASH_ATTN'] = "1"
from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler


class SongBloomApp:
    """Stateful wrapper around a lazily loaded SongBloom sampler for the Gradio UI.

    No model is loaded at construction time. ``load_model`` must succeed
    before ``generate_song`` will produce output.
    """

    def __init__(self):
        # The SongBloom sampler; stays None until load_model() fully succeeds.
        self.model = None
        # True while load_model() is running, so a second click is rejected.
        self.is_loading = False

    def hf_download(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache"):
        """Download model files from Hugging Face.

        Fetches the ``{model_name}.yaml`` config, the ``{model_name}.pt``
        checkpoint, the VAE config/checkpoint and the G2P vocabulary into
        ``local_dir``.

        Args:
            repo_id: Hugging Face repository to download from.
            model_name: Base name of the ``.yaml`` / ``.pt`` pair.
            local_dir: Directory the files are written to.

        Returns:
            Local path of the downloaded ``{model_name}.yaml`` config.
        """
        cfg_path = hf_hub_download(
            repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir)

        # These files are only fetched here, not returned. Presumably the
        # config refers to them by path (see the "dynamic_path" resolver in
        # load_config) — NOTE(review): confirm against the YAML.
        for filename in (
            f"{model_name}.pt",
            "stable_audio_1920_vae.json",
            "autoencoder_music_dsp1920.ckpt",
            "vocab_g2p.yaml",
        ):
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)

        return cfg_path

    def load_config(self, cfg_file, parent_dir="./"):
        """Register the custom OmegaConf resolvers and load a YAML config.

        Args:
            cfg_file: Path to the YAML config, or None for an empty config.
            parent_dir: Directory substituted for ``???`` by the
                ``dynamic_path`` resolver.

        Returns:
            The loaded OmegaConf config (empty when ``cfg_file`` is None).
        """
        # SECURITY NOTE: the "eval" resolver executes arbitrary Python
        # expressions found in the config file; only load trusted configs.
        resolvers = {
            "eval": lambda x: eval(x),
            "concat": lambda *x: [xxx for xx in x for xxx in xx],
            "get_fname": lambda x: os.path.splitext(os.path.basename(x))[0],
            "load_yaml": lambda x: OmegaConf.load(x),
            "dynamic_path": lambda x: x.replace("???", parent_dir),
        }
        # replace=True: OmegaConf raises ValueError when a resolver name is
        # registered twice, which made retrying load_model() after a failed
        # load impossible. Replacing also rebinds dynamic_path to this call's
        # parent_dir.
        for name, resolver in resolvers.items():
            OmegaConf.register_new_resolver(name, resolver, replace=True)

        # OmegaConf.load accepts a path directly, so no file handle is leaked.
        return OmegaConf.load(cfg_file) if cfg_file is not None else OmegaConf.create()

    def load_model(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", dtype="float32"):
        """Load the SongBloom model.

        Args:
            repo_id: Hugging Face repository holding the model files.
            model_name: Base name of the model config/checkpoint.
            dtype: "float32", or any other value for bfloat16.

        Returns:
            A human-readable status message. Errors are reported in the
            message rather than raised.
        """
        if self.is_loading:
            return "Model is already loading, please wait..."

        if self.model is not None:
            return "Model is already loaded!"

        self.is_loading = True
        try:
            local_dir = "./cache"

            # Download model files, then use the config path actually returned.
            cfg_path = self.hf_download(repo_id, model_name, local_dir)
            cfg = self.load_config(cfg_path, parent_dir=local_dir)

            # Load model
            dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
            model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype_torch)
            model.set_generation_params(**cfg.inference)

            # Publish the model only once it is fully configured. A failure
            # above therefore leaves self.model as None and a retry is possible.
            self.model = model
            return "Model loaded successfully!"

        except Exception as e:
            return f"Error loading model: {str(e)}"
        finally:
            self.is_loading = False

    @spaces.GPU
    def generate_song(self, lyrics, prompt_audio, n_samples=1, dtype="float32", progress=gr.Progress()):
        """Generate song from lyrics and audio prompt.

        Args:
            lyrics: Structured lyrics string (see the format notes in the UI).
            prompt_audio: Filepath of the uploaded prompt audio, or None.
            n_samples: Number of songs to generate. Gradio sliders may deliver
                a float, so the value is coerced to int.
            dtype: "float32", or any other value for bfloat16. Applied to the
                prompt waveform.
            progress: Gradio progress tracker.

        Returns:
            Tuple ``(output_paths, status_message)``. ``output_paths`` lists
            the generated ``.flac`` files and is empty on any failure.
        """
        if self.model is None:
            return [], "Please load the model first!"

        if not lyrics or not lyrics.strip():
            return [], "Please provide lyrics!"

        if prompt_audio is None:
            return [], "Please upload a prompt audio file!"

        try:
            # range() rejects floats, which a Gradio Slider may deliver.
            n_samples = int(n_samples)
            sample_rate = self.model.sample_rate

            progress(0.1, desc="Processing audio prompt...")

            # Load the prompt audio and resample it to the model's rate if needed.
            prompt_wav, sr = torchaudio.load(prompt_audio)
            if sr != sample_rate:
                prompt_wav = torchaudio.functional.resample(prompt_wav, sr, sample_rate)

            # Convert to mono by averaging dim 0, cast to the requested
            # precision and keep at most the first 10 seconds.
            dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype_torch)
            prompt_wav = prompt_wav[..., :10 * sample_rate]

            progress(0.3, desc="Generating song...")

            output_files = []
            for i in range(n_samples):
                progress(0.3 + (i / n_samples) * 0.6, desc=f"Generating sample {i+1}/{n_samples}...")

                wav = self.model.generate(lyrics, prompt_wav)

                # Create the output file, then close our handle before
                # torchaudio writes to it. Keeping a NamedTemporaryFile open
                # while reopening it by name fails on Windows.
                fd, out_path = tempfile.mkstemp(suffix='.flac')
                os.close(fd)
                torchaudio.save(out_path, wav[0].cpu().float(), sample_rate)
                output_files.append(out_path)

            progress(1.0, desc="Complete!")
            return output_files, f"Successfully generated {n_samples} song(s)!"

        except Exception as e:
            return [], f"Error generating song: {str(e)}"

    def format_lyrics_example(self, example_type):
        """Return example lyrics in SongBloom's expected format.

        Args:
            example_type: "Chinese" selects the Chinese example; any other
                value returns the English example.

        Returns:
            The example lyrics string.
        """
        if example_type == "Chinese":
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] 风轻轻吹过古道.岁月在墙上刻下记号.梦中你笑得多甜.醒来却只剩下寂寥.繁花似锦的春天.少了你的色彩也失了妖娆 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] 月儿弯弯照九州.你是否也在仰望同一片天空.灯火阑珊处.我寻觅你的影踪.回忆如波光粼粼.荡漾在心湖的每个角落 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
        else:  # English
            return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] City lights flicker through the car window. Dreams pass fast where the lost ones go. Neon signs echo stories untold. I chase shadows while the night grows cold , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] Footsteps loud in the tunnel of time. Regret and hope in a crooked rhyme. You held my hand when I slipped through the dark. Lit a match and you became my spark , [bridge] We were nothing and everything too. Lost in a moment. found in the view. Of all we broke and still survived. Somehow the flame stayed alive , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"


# Initialize the app: a single shared instance whose bound methods back every
# UI event handler below.
app = SongBloomApp()

# Create Gradio interface
def create_interface():
    """Build the SongBloom Gradio Blocks UI wired to the module-level ``app``.

    Returns:
        The constructed ``gr.Blocks`` demo; the caller launches it.
    """
    with gr.Blocks(title="SongBloom: AI Song Generation", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎵 SongBloom: AI Song Generation
        
        Generate full-length songs from lyrics and audio prompts using the SongBloom model.
        
        **How to use:**
        1. First, load the model (this may take a few minutes)
        2. Enter your lyrics in the specified format
        3. Upload a 10-second audio prompt (WAV format, 48kHz recommended)
        4. Click "Generate Song" to create your music
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                # Model Loading Section
                gr.Markdown("## 🤖 Model Setup")
                model_status = gr.Textbox(
                    label="Model Status", 
                    value="Model not loaded",
                    interactive=False
                )
                
                with gr.Row():
                    repo_id = gr.Textbox(
                        label="Repository ID", 
                        value="CypressYang/SongBloom",
                        interactive=True
                    )
                    model_name = gr.Textbox(
                        label="Model Name", 
                        value="songbloom_full_150s",
                        interactive=True
                    )
                
                dtype_choice = gr.Dropdown(
                    choices=["float32", "bfloat16"],
                    value="float32",
                    label="Precision (use bfloat16 for lower VRAM)",
                    interactive=True
                )
                
                load_btn = gr.Button("Load Model", variant="primary")
                
                # Lyrics Input Section
                gr.Markdown("## 📝 Lyrics Input")
                
                # Example selector
                example_type = gr.Dropdown(
                    choices=["Chinese", "English"],
                    value="Chinese",
                    label="Load Example Lyrics",
                    interactive=True
                )
                
                lyrics_input = gr.Textbox(
                    label="Lyrics",
                    placeholder="Enter your lyrics in the specified format...",
                    lines=8,
                    max_lines=15
                )
                
                load_example_btn = gr.Button("Load Example", variant="secondary")
                
                # Audio Upload Section
                gr.Markdown("## 🎧 Audio Prompt")
                audio_input = gr.Audio(
                    label="Upload Audio Prompt (10-second WAV file recommended)",
                    type="filepath"
                )
                
                # Generation Settings
                gr.Markdown("## ⚙️ Generation Settings")
                n_samples = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of samples to generate"
                )
                
                generate_btn = gr.Button("🎵 Generate Song", variant="primary", size="lg")
                
            with gr.Column(scale=1):
                # Output Section
                gr.Markdown("## 🎶 Generated Songs")
                
                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Ready to generate",
                    interactive=False
                )
                
                # generate_song returns a list of .flac filepaths. gr.Gallery
                # only renders images/videos, so the audio could not be played
                # there. A multi-file output accepts the same list of paths and
                # lets users play or download every generated sample.
                output_audio = gr.File(
                    label="Generated Audio Files",
                    file_count="multiple",
                    interactive=False
                )
                
                # Format Instructions
                gr.Markdown("""
                ## 📋 Lyric Format Instructions
                
                **Structure Tags:**
                - `[intro]`, `[verse]`, `[chorus]`, `[bridge]`, `[inst]`, `[outro]`
                - Repeat tags for duration (e.g., `[intro] [intro] [intro]` for ~3 seconds)
                
                **Text Rules:**
                - Use `.` to separate sentences
                - Use `,` to separate sections
                - Example: `[verse] First line. Second line , [chorus] Chorus text`
                
                **Audio Prompt:**
                - 10-second audio file
                - WAV format preferred
                - 48kHz sample rate recommended
                - Defines the musical style/genre
                """)
        
        # Event handlers
        load_btn.click(
            fn=app.load_model,
            inputs=[repo_id, model_name, dtype_choice],
            outputs=[model_status]
        )
        
        load_example_btn.click(
            fn=app.format_lyrics_example,
            inputs=[example_type],
            outputs=[lyrics_input]
        )
        
        generate_btn.click(
            fn=app.generate_song,
            inputs=[lyrics_input, audio_input, n_samples, dtype_choice],
            outputs=[output_audio, generation_status]
        )
    
    return demo

if __name__ == "__main__":
    # Listen on every interface at Gradio's conventional port, keep the app
    # local (no public share link) and surface Python errors in the browser.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)