#!/usr/bin/env python3
"""
Generate audio using JAM model
Reads from filtered test set and generates audio using CFM+DiT model.
"""

import os
import glob
import time
import json
import random
from huggingface_hub import snapshot_download
import torch
import torchaudio
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
import pyloudnorm as pyln
from safetensors.torch import load_file
from muq import MuQMuLan
import numpy as np

from jam.dataset import enhance_webdataset_config, DiffusionWebDataset
from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE

# DiffRhythm imports for CFM+DiT model
from jam.model import CFM, DiT

def get_negative_style_prompt(device, file_path):
    vocal_style = np.load(file_path)

    vocal_style = torch.from_numpy(vocal_style).to(device)  # [1, 512]
    vocal_style = vocal_style.half()

    return vocal_style

def normalize_audio(audio, normalize_lufs=True):
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    if normalize_lufs:
        meter = pyln.Meter(rate=44100)
        target_lufs = -14.0
        loudness = meter.integrated_loudness(audio.transpose(0, 1).numpy())
        normalised = pyln.normalize.loudness(audio.transpose(0, 1).numpy(), loudness, target_lufs)
        normalised = torch.from_numpy(normalised).transpose(0, 1)
    else:
        normalised = audio
    return normalised

class FilteredTestSetDataset(Dataset):
    """Custom dataset for loading from filtered test set JSON"""
    def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False):
        with open(test_set_path, 'r') as f:
            self.test_samples = json.load(f)
        
        if num_samples is not None:
            self.test_samples = self.test_samples[:num_samples]
            
        self.diffusion_dataset = diffusion_dataset
        self.muq_model = muq_model
        self.random_crop_style = random_crop_style
        self.num_style_secs = num_style_secs
        self.use_prompt_style = use_prompt_style
        if self.use_prompt_style:
            print("Using prompt style instead of audio style.")

    def __len__(self):
        return len(self.test_samples)
    
    def __getitem__(self, idx):
        test_sample = self.test_samples[idx]
        sample_id = test_sample["id"]
        
        # Load LRC data
        lrc_path = test_sample["lrc_path"]
        with open(lrc_path, 'r') as f:
            lrc_data = json.load(f)
        # Wrap raw word-level LRC lists in the expected {'word': ...} structure
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}
        
        # Generate style embedding from original audio on-the-fly
        audio_path = test_sample["audio_path"]
        if self.use_prompt_style:
            prompt_path = test_sample["prompt_path"]
            with open(prompt_path, 'r') as f:
                prompt = f.read()
            if len(prompt) > 300:
                print(f"Sample {sample_id} has prompt length {len(prompt)}; truncating to 300 characters")
                prompt = prompt[:300]
            print(prompt)
            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
        else:
            style_embedding = self.generate_style_embedding(audio_path)
        
        duration = test_sample["duration"]
        
        # Create fake latent with correct length
        # Assuming frame_rate from config (typically 21.5 fps for 44.1kHz)
        frame_rate = 21.5
        num_frames = int(duration * frame_rate)
        fake_latent = torch.randn(128, num_frames)  # 128 is latent dim
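        # Only the temporal length of this latent matters: generate_latent conditions
        # on an all-zero tensor later, so the random values themselves are never used.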
        
        # Create sample tuple matching DiffusionWebDataset format
        fake_sample = (
            sample_id,
            fake_latent,     # latent with correct duration
            style_embedding, # style from actual audio
            lrc_data        # actual LRC data
        )
        
        # Process through DiffusionWebDataset's process_sample_safely
        processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)
        
        # Add metadata
        if processed_sample is not None:
            processed_sample['test_metadata'] = {
                'sample_id': sample_id,
                'audio_path': audio_path,
                'lrc_path': lrc_path,
                'duration': duration,
                'num_frames': num_frames
            }
        
        return processed_sample
    
    def generate_style_embedding(self, audio_path):
        """Generate style embedding using MuQ model on the whole music"""
        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)
        
        # Resample to 24kHz if needed (MuQ expects 24kHz)
        if sample_rate != 24000:
            resampler = torchaudio.transforms.Resample(sample_rate, 24000)
            waveform = resampler(waveform)
        
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        
        # Squeeze out the channel dimension - MuQ expects 1D audio of shape (time,)
        waveform = waveform.squeeze(0)
        
        # Move to same device as model
        waveform = waveform.to(self.muq_model.device)
        
        # Generate embedding using MuQ model
        with torch.inference_mode():
            # MuQ expects batch dimension and 1D audio, returns (batch, embedding_dim)
            if self.random_crop_style:
                # Randomly crop num_style_secs seconds from the waveform
                total_samples = waveform.shape[0]
                target_samples = 24000 * self.num_style_secs  # num_style_secs seconds at 24kHz

                # Guard against clips shorter than the crop window
                start_idx = random.randint(0, max(0, total_samples - target_samples))
                style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples])
            else:
                style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * self.num_style_secs])
        
        # Keep shape as (embedding_dim,) not scalar
        return style_embedding[0]


def custom_collate_fn_with_metadata(batch, base_collate_fn):
    """Custom collate function that preserves test_metadata"""
    # Filter out None samples
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    
    # Extract test_metadata before collating
    test_metadata = [item.pop('test_metadata') for item in batch]
    
    # Use base collate function for the rest
    collated = base_collate_fn(batch)
    
    # Add test_metadata back
    if collated is not None:
        collated['test_metadata'] = test_metadata
    
    return collated


def load_model(model_config, checkpoint_path, device):
    """
    Load JAM CFM model from checkpoint (follows infer.py pattern)
    """
    # Build CFM model from config
    dit_config = model_config["dit"].copy()
    # Add text_num_embeds if not specified - should be at least 64 for phoneme tokens
    if "text_num_embeds" not in dit_config:
        dit_config["text_num_embeds"] = 256  # Default value from DiT
    
    cfm = CFM(
        transformer=DiT(**dit_config),
        **model_config["cfm"]
    )
    cfm = cfm.to(device)
    
    # Load checkpoint - use the path from config
    checkpoint = load_file(checkpoint_path)
    cfm.load_state_dict(checkpoint, strict=False)
    
    return cfm.eval()


def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
    """
    Generate latent from batch data (follows infer.py pattern)
    """
    with torch.inference_mode():
        batch_size = len(batch["lrc"])
        text = batch["lrc"].to(device)
        style_prompt = batch["prompt"].to(device)
        start_time = batch["start_time"].to(device)
        duration_abs = batch["duration_abs"].to(device)
        duration_rel = batch["duration_rel"].to(device)
        
        # Create an all-zero conditioning latent spanning the full prediction window
        max_frames = model.max_frames
        cond = torch.zeros(batch_size, max_frames, 64).to(text.device)
        pred_frames = [(0, max_frames)]

        default_sample_kwargs = {
            "cfg_strength": 4,
            "steps": 50,
            "batch_infer_num": 1
        }
        sample_kwargs = {**default_sample_kwargs, **sample_kwargs}
        
        if negative_style_prompt_path is None:
            negative_style_prompt_path = 'public_checkpoints/vocal.npy'
            negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
        elif negative_style_prompt_path == 'zeros':
            negative_style_prompt = torch.zeros(1, 512).to(text.device)
        else:
            negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)

        negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

        latents, _ = model.sample(
            cond=cond,
            text=text,
            style_prompt=negative_style_prompt if ignore_style else style_prompt,
            duration_abs=duration_abs,
            duration_rel=duration_rel,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            latent_pred_segments=pred_frames,
            **sample_kwargs
        )
        
        return latents


class Jamify:
    def __init__(self):
        os.makedirs('outputs', exist_ok=True)
        
        device = 'cuda'
        config_path = 'jam_infer.yaml'
        self.config = OmegaConf.load(config_path)
        OmegaConf.resolve(self.config)

        # Download the main model checkpoint and point the config at it
        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
        
        # Load VAE based on configuration
        vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
        if vae_type == 'diffrhythm':
            vae = DiffRhythmVAE(device=device).to(device)
        else:
            vae = StableAudioOpenVAE().to(device)
        
        self.vae = vae
        self.vae_type = vae_type
        self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, device)
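        # MuQ-MuLan provides both the audio-style embeddings (wavs=...) and the
        # text-prompt embeddings (texts=...) used by FilteredTestSetDataset.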
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(device).eval()

        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        # Override multiple_styles to False since we're generating single style embeddings
        dataset_cfg.multiple_styles = False
        self.base_dataset = DiffusionWebDataset(**dataset_cfg)

    def cleanup_old_files(self, sample_id):
        # Clean up old generated MP3s, keeping only the nine most recent,
        # then remove this sample's temporary test-set JSON.
        old_mp3_files = sorted(glob.glob("outputs/*.mp3"))
        if len(old_mp3_files) >= 10:
            for old_file in old_mp3_files[:-9]:  # keep the last 9, delete older ones
                try:
                    os.remove(old_file)
                    print(f"Cleaned up old file: {old_file}")
                except OSError:
                    pass
        os.unlink(f"outputs/{sample_id}.json")
    
    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
        sample_id = str(int(time.time() * 1000000))  # microsecond timestamp for uniqueness
        test_set = [{
            "id": sample_id,
            "audio_path": reference_audio_path,
            "lrc_path": lyrics_json_path,
            "duration": duration,
            "prompt_path": style_prompt
        }]
        json.dump(test_set, open(f"outputs/{sample_id}.json", "w"))
        
        # Create filtered test set dataset
        test_dataset = FilteredTestSetDataset(
            test_set_path=f"outputs/{sample_id}.json",
            diffusion_dataset=self.base_dataset,
            muq_model=self.muq_model,
            num_samples=1,
            random_crop_style=self.config.evaluation.random_crop_style,
            num_style_secs=self.config.evaluation.num_style_secs,
            use_prompt_style=self.config.evaluation.use_prompt_style
        )
        
        # Create dataloader with custom collate function
        dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn)
        )
        
        batch = next(iter(dataloader))
        sample_kwargs = self.config.evaluation.sample_kwargs
        latent = generate_latent(self.cfm_model, batch, sample_kwargs, self.config.evaluation.negative_style_prompt, self.config.evaluation.ignore_style)[0][0]
        
        test_metadata = batch['test_metadata'][0]
        sample_id = test_metadata['sample_id']
        original_duration = test_metadata['duration']

        # Decode audio
        latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
        
        # Use chunked decoding if configured (only for DiffRhythm VAE)
        use_chunked = self.config.evaluation.get('use_chunked_decoding', True)
        if self.vae_type == 'diffrhythm' and use_chunked:
            pred_audio = self.vae.decode(
                latent_for_vae, 
                chunked=True, 
                overlap=self.config.evaluation.get('chunked_overlap', 32),
                chunk_size=self.config.evaluation.get('chunked_size', 128)
            ).sample.squeeze(0).detach().cpu()
        else:
            pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()
        
        pred_audio = normalize_audio(pred_audio)
        sample_rate = 44100
        trim_samples = int(original_duration * sample_rate)
        if pred_audio.shape[1] > trim_samples:
            pred_audio_trimmed = pred_audio[:, :trim_samples]
        else:
            pred_audio_trimmed = pred_audio
            
        output_path = f'outputs/{sample_id}.mp3'
        torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")
        self.cleanup_old_files(sample_id)
        return output_path
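

if __name__ == "__main__":
    # Minimal usage sketch. The paths below are placeholders (assumed example files,
    # not shipped with the repository); point them at your own reference audio,
    # word-level lyric JSON, and style-prompt text file before running.
    jam = Jamify()
    output = jam.predict(
        reference_audio_path="examples/reference.mp3",  # audio used for the style embedding
        lyrics_json_path="examples/lyrics.json",        # word-level LRC data
        style_prompt="examples/style_prompt.txt",       # text prompt file (used when use_prompt_style is enabled)
        duration=30.0,                                  # target duration in seconds
    )
    print(f"Generated audio written to {output}")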