mrfakename committed
Commit 9cc2d55 · verified · 1 Parent(s): 14f64c3

Update app.py

Files changed (1)
  1. app.py +78 -279
app.py CHANGED
@@ -1,291 +1,90 @@
- import os
- import sys
- import torch
- import torchaudio
- import tempfile
- import json
- import gradio as gr
- from omegaconf import OmegaConf
- from huggingface_hub import hf_hub_download
  import spaces
-
- # Add the SongBloom module to the path
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
-
- os.environ['DISABLE_FLASH_ATTN'] = "1"
- from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
-
-
- class SongBloomApp:
-     def __init__(self):
-         self.model = None
-         self.is_loading = False
-
-     def hf_download(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache"):
-         """Download model files from Hugging Face"""
-         cfg_path = hf_hub_download(
-             repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir)
-         ckpt_path = hf_hub_download(
-             repo_id=repo_id, filename=f"{model_name}.pt", local_dir=local_dir)
-
-         vae_cfg_path = hf_hub_download(
-             repo_id=repo_id, filename="stable_audio_1920_vae.json", local_dir=local_dir)
-         vae_ckpt_path = hf_hub_download(
-             repo_id=repo_id, filename="autoencoder_music_dsp1920.ckpt", local_dir=local_dir)
-
-         g2p_path = hf_hub_download(
-             repo_id=repo_id, filename="vocab_g2p.yaml", local_dir=local_dir)
-
-         return cfg_path
-
-     def load_config(self, cfg_file, parent_dir="./"):
-         """Load model configuration"""
-         OmegaConf.register_new_resolver("eval", lambda x: eval(x))
-         OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
-         OmegaConf.register_new_resolver("get_fname", lambda x: os.path.splitext(os.path.basename(x))[0])
-         OmegaConf.register_new_resolver("load_yaml", lambda x: OmegaConf.load(x))
-         OmegaConf.register_new_resolver("dynamic_path", lambda x: x.replace("???", parent_dir))
-
-         file_cfg = OmegaConf.load(open(cfg_file, 'r')) if cfg_file is not None else OmegaConf.create()
-         return file_cfg
-
-     def load_model(self, repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", dtype="float32"):
-         """Load the SongBloom model"""
-         if self.is_loading:
-             return "Model is already loading, please wait..."
-
-         if self.model is not None:
-             return "Model is already loaded!"
-
-         try:
-             self.is_loading = True
-             local_dir = "./cache"
-
-             # Download model files
-             cfg_path = self.hf_download(repo_id, model_name, local_dir)
-             cfg = self.load_config(f"{local_dir}/{model_name}.yaml", parent_dir=local_dir)
-
-             # Load model
-             dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
-             self.model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype_torch)
-             self.model.set_generation_params(**cfg.inference)
-
-             self.is_loading = False
-             return "Model loaded successfully!"
-
-         except Exception as e:
-             self.is_loading = False
-             return f"Error loading model: {str(e)}"
-
-     @spaces.GPU
-     def generate_song(self, lyrics, prompt_audio, n_samples=1, dtype="float32", progress=gr.Progress()):
-         """Generate song from lyrics and audio prompt"""
-         if self.model is None:
-             return [], "Please load the model first!"
-
-         if not lyrics.strip():
-             return [], "Please provide lyrics!"
-
-         if prompt_audio is None:
-             return [], "Please upload a prompt audio file!"
-
-         try:
-             progress(0.1, desc="Processing audio prompt...")
-
-             # Load and process the prompt audio
-             prompt_wav, sr = torchaudio.load(prompt_audio)
-             if sr != self.model.sample_rate:
-                 prompt_wav = torchaudio.functional.resample(prompt_wav, sr, self.model.sample_rate)
-
-             # Convert to mono and limit to 10 seconds
-             dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
-             prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype_torch)
-             prompt_wav = prompt_wav[..., :10*self.model.sample_rate]
-
-             progress(0.3, desc="Generating song...")
-
-             output_files = []
-
-             # Generate samples
-             for i in range(n_samples):
-                 progress(0.3 + (i / n_samples) * 0.6, desc=f"Generating sample {i+1}/{n_samples}...")
-
-                 wav = self.model.generate(lyrics, prompt_wav)
-
-                 # Save to temporary file
-                 with tempfile.NamedTemporaryFile(suffix='.flac', delete=False) as tmp_file:
-                     torchaudio.save(tmp_file.name, wav[0].cpu().float(), self.model.sample_rate)
-                     output_files.append(tmp_file.name)
-
-             progress(1.0, desc="Complete!")
-             return output_files, f"Successfully generated {n_samples} song(s)!"
-
-         except Exception as e:
-             return [], f"Error generating song: {str(e)}"
-
-     def format_lyrics_example(self, example_type):
-         """Provide example lyrics in the correct format"""
-         if example_type == "Chinese":
-             return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] 风轻轻吹过古道.岁月在墙上刻下记号.梦中你笑得多甜.醒来却只剩下寂寥.繁花似锦的春天.少了你的色彩也失了妖娆 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] 月儿弯弯照九州.你是否也在仰望同一片天空.灯火阑珊处.我寻觅你的影踪.回忆如波光粼粼.荡漾在心湖的每个角落 , [chorus] 想见你.在晨曦中.在月光下.每个瞬间都渴望.没有你.星辰也黯淡.花香也无味.只剩下思念的煎熬.想见你.穿越千山万水.只为那一瞥.你的容颜 , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
-         else:  # English
-             return "[intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] [intro] , [verse] City lights flicker through the car window. Dreams pass fast where the lost ones go. Neon signs echo stories untold. I chase shadows while the night grows cold , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] [inst] , [verse] Footsteps loud in the tunnel of time. Regret and hope in a crooked rhyme. You held my hand when I slipped through the dark. Lit a match and you became my spark , [bridge] We were nothing and everything too. Lost in a moment. found in the view. Of all we broke and still survived. Somehow the flame stayed alive , [chorus] Run with me down the empty street. Where silence and heartbeat always meet. Every breath. a whispered vow. We are forever. here and now , [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro] [outro]"
-
-
- # Initialize the app
- app = SongBloomApp()
-
  # Create Gradio interface
- def create_interface():
-     with gr.Blocks(title="SongBloom: AI Song Generation", theme=gr.themes.Soft()) as demo:
-         gr.Markdown("""
-         # 🎵 SongBloom: AI Song Generation
-
-         Generate full-length songs from lyrics and audio prompts using the SongBloom model.
-
-         **How to use:**
-         1. First, load the model (this may take a few minutes)
-         2. Enter your lyrics in the specified format
-         3. Upload a 10-second audio prompt (WAV format, 48kHz recommended)
-         4. Click "Generate Song" to create your music
-         """)
-
-         with gr.Row():
-             with gr.Column(scale=1):
-                 # Model Loading Section
-                 gr.Markdown("## 🤖 Model Setup")
-                 model_status = gr.Textbox(
-                     label="Model Status",
-                     value="Model not loaded",
-                     interactive=False
-                 )
-
-                 with gr.Row():
-                     repo_id = gr.Textbox(
-                         label="Repository ID",
-                         value="CypressYang/SongBloom",
-                         interactive=True
-                     )
-                     model_name = gr.Textbox(
-                         label="Model Name",
-                         value="songbloom_full_150s",
-                         interactive=True
-                     )
-
-                 dtype_choice = gr.Dropdown(
-                     choices=["float32", "bfloat16"],
-                     value="float32",
-                     label="Precision (use bfloat16 for lower VRAM)",
-                     interactive=True
-                 )
-
-                 load_btn = gr.Button("Load Model", variant="primary")
-
-                 # Lyrics Input Section
-                 gr.Markdown("## 📝 Lyrics Input")
-
-                 # Example selector
-                 example_type = gr.Dropdown(
-                     choices=["Chinese", "English"],
-                     value="Chinese",
-                     label="Load Example Lyrics",
-                     interactive=True
-                 )
-
-                 lyrics_input = gr.Textbox(
-                     label="Lyrics",
-                     placeholder="Enter your lyrics in the specified format...",
-                     lines=8,
-                     max_lines=15
-                 )
-
-                 load_example_btn = gr.Button("Load Example", variant="secondary")
-
-                 # Audio Upload Section
-                 gr.Markdown("## 🎧 Audio Prompt")
-                 audio_input = gr.Audio(
-                     label="Upload Audio Prompt (10-second WAV file recommended)",
-                     type="filepath"
-                 )
-
-                 # Generation Settings
-                 gr.Markdown("## ⚙️ Generation Settings")
-                 n_samples = gr.Slider(
-                     minimum=1,
-                     maximum=5,
-                     value=2,
-                     step=1,
-                     label="Number of samples to generate"
-                 )
-
-                 generate_btn = gr.Button("🎵 Generate Song", variant="primary", size="lg")
-
-             with gr.Column(scale=1):
-                 # Output Section
-                 gr.Markdown("## 🎶 Generated Songs")
-
-                 generation_status = gr.Textbox(
-                     label="Generation Status",
-                     value="Ready to generate",
-                     interactive=False
-                 )
-
-                 output_audio = gr.Gallery(
-                     label="Generated Audio Files",
-                     show_label=True,
-                     elem_id="gallery",
-                     columns=1,
-                     rows=3,
-                     object_fit="contain",
-                     height="auto",
-                     type="filepath"
-                 )
-
-                 # Format Instructions
-                 gr.Markdown("""
-                 ## 📋 Lyric Format Instructions
-
-                 **Structure Tags:**
-                 - `[intro]`, `[verse]`, `[chorus]`, `[bridge]`, `[inst]`, `[outro]`
-                 - Repeat tags for duration (e.g., `[intro] [intro] [intro]` for ~3 seconds)
-
-                 **Text Rules:**
-                 - Use `.` to separate sentences
-                 - Use `,` to separate sections
-                 - Example: `[verse] First line. Second line , [chorus] Chorus text`
-
-                 **Audio Prompt:**
-                 - 10-second audio file
-                 - WAV format preferred
-                 - 48kHz sample rate recommended
-                 - Defines the musical style/genre
-                 """)
-
-         # Event handlers
-         load_btn.click(
-             fn=app.load_model,
-             inputs=[repo_id, model_name, dtype_choice],
-             outputs=[model_status]
-         )
-
-         load_example_btn.click(
-             fn=app.format_lyrics_example,
-             inputs=[example_type],
-             outputs=[lyrics_input]
-         )
-
-         generate_btn.click(
-             fn=app.generate_song,
-             inputs=[lyrics_input, audio_input, n_samples, dtype_choice],
-             outputs=[output_audio, generation_status]
-         )
-
-     return demo
-
  if __name__ == "__main__":
-     demo = create_interface()
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=False,
-         show_error=True
-     )
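The removed loader hinged on custom OmegaConf resolvers, which let the YAML config compute values and rewrite paths when they are accessed. A minimal standalone sketch of that pattern (not part of this commit; the config keys below are invented for illustration):

from omegaconf import OmegaConf

# Resolvers are invoked lazily, the first time an interpolated value is read.
OmegaConf.register_new_resolver("eval", lambda x: eval(x))
OmegaConf.register_new_resolver("dynamic_path", lambda x: x.replace("???", "./cache"))

cfg = OmegaConf.create({
    "sample_rate": 48000,
    "prompt_len": "${eval:'10*48000'}",  # computed on access
    "vae_ckpt": "${dynamic_path:'???/autoencoder_music_dsp1920.ckpt'}",
})
print(cfg.prompt_len)  # 480000
print(cfg.vae_ckpt)    # ./cache/autoencoder_music_dsp1920.ckpt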
  import spaces
+ import gradio as gr
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ # Load model and tokenizer
+ model_path = "apple/DiffuCoder-7B-cpGRPO"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model = AutoModel.from_pretrained(
+     model_path,
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True
+ ).to(device).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
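One caveat in this load step: torch_dtype is pinned to bfloat16 even when device falls back to "cpu", where bf16 inference can be unsupported or very slow. A defensive variant (a sketch, not what the commit does) derives the dtype from the device:

import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32  # bf16 only on GPU

model = AutoModel.from_pretrained(
    "apple/DiffuCoder-7B-cpGRPO",
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()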
+
+ @spaces.GPU
+ def generate_code(query, temperature=0.4, top_p=0.95, max_new_tokens=256):
+     # Format prompt using chat template
+     prompt = f"""<|im_start|>system
+ You are a helpful coding assistant.<|im_end|>
+ <|im_start|>user
+ {query.strip()}<|im_end|>
+ <|im_start|>assistant
+ """
+
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs.input_ids.to(device)
+     attention_mask = inputs.attention_mask.to(device)
+
+     # Generate with token streaming
+     TOKEN_PER_STEP = 1
+     steps = max_new_tokens // TOKEN_PER_STEP
+
+     full_output = ""
+     for _ in range(steps):
+         output = model.diffusion_generate(
+             input_ids,
+             attention_mask=attention_mask,
+             max_new_tokens=TOKEN_PER_STEP,
+             output_history=True,
+             return_dict_in_generate=True,
+             steps=1,
+             temperature=temperature,
+             top_p=top_p,
+             alg="entropy",
+             alg_temp=0.,
+         )
+
+         # Decode new tokens
+         new_tokens = tokenizer.decode(
+             output.sequences[0, -TOKEN_PER_STEP:].tolist(),
+             skip_special_tokens=True
+         )
+
+         # Update input for next step
+         input_ids = output.sequences
+         attention_mask = torch.cat([
+             attention_mask,
+             torch.ones(1, 1, dtype=attention_mask.dtype, device=device)
+         ], dim=1)
+
+         # Append to full output and stream
+         full_output += new_tokens
+         yield full_output.split('<|dlm_pad|>')[0].strip()
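The prompt above is hand-assembled ChatML. If the checkpoint's tokenizer ships a chat template (Qwen-style tokenizers usually do), the same string can be built without hard-coding special tokens. A sketch, assuming such a template exists:

messages = [
    {"role": "system", "content": "You are a helpful coding assistant."},
    {"role": "user", "content": query.strip()},
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

And since generate_code is an ordinary Python generator, a quick smoke test outside Gradio is plain iteration:

for partial in generate_code("Write a Python function to calculate factorial"):
    print(partial)  # each yield is a progressively longer prefix of the answer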
+
  # Create Gradio interface
+ demo = gr.Interface(
+     fn=generate_code,
+     inputs=[
+         gr.Textbox(label="Code Request", lines=3,
+                    placeholder="Describe the code you want..."),
+         gr.Slider(0.1, 1.0, value=0.4, label="Temperature"),
+         gr.Slider(0.5, 1.0, value=0.95, label="Top-p"),
+         gr.Slider(32, 512, value=256, step=32, label="Max Tokens")
+     ],
+     outputs=gr.Textbox(label="Generated Code", lines=10),
+     title="🧠 DiffuCoder Code Generator",
+     description="Generate code with Apple's DiffuCoder-7B model",
+     examples=[
+         ["Write a Python function to calculate factorial"],
+         ["Create a function to merge two sorted lists"],
+         ["How to reverse a string in JavaScript?"]
+     ]
+ )
+
+ # Run the demo
  if __name__ == "__main__":
+     demo.queue().launch()
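On Spaces the bare launch() suffices; for a local run, one might restore the explicit binding options the previous version passed (a sketch using standard Gradio launch arguments):

demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)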