#!/usr/bin/env python3
"""
Generate audio with the JAM model.
Reads from a filtered test set and generates audio using the CFM+DiT model.
"""
import os
import glob
import time
import json
import random
import sys
from huggingface_hub import snapshot_download
import torch
import torchaudio
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import accelerate
import pyloudnorm as pyln
from safetensors.torch import load_file
from muq import MuQMuLan
import numpy as np
from jam.dataset import enhance_webdataset_config, DiffusionWebDataset
from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE
# DiffRhythm imports for CFM+DiT model
from jam.model import CFM, DiT
def get_negative_style_prompt(device, file_path):
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)  # [1, 512]
    vocal_style = vocal_style.half()
    return vocal_style
def normalize_audio(audio, normalize_lufs=True):
audio = audio - audio.mean(-1, keepdim=True)
audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
if normalize_lufs:
meter = pyln.Meter(rate=44100)
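        # -14 LUFS matches the loudness target commonly used by streaming platforms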
target_lufs = -14.0
loudness = meter.integrated_loudness(audio.transpose(0, 1).numpy())
normalised = pyln.normalize.loudness(audio.transpose(0, 1).numpy(), loudness, target_lufs)
normalised = torch.from_numpy(normalised).transpose(0, 1)
else:
normalised = audio
return normalised
class FilteredTestSetDataset(Dataset):
"""Custom dataset for loading from filtered test set JSON"""
def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False):
with open(test_set_path, 'r') as f:
self.test_samples = json.load(f)
if num_samples is not None:
self.test_samples = self.test_samples[:num_samples]
self.diffusion_dataset = diffusion_dataset
self.muq_model = muq_model
self.random_crop_style = random_crop_style
self.num_style_secs = num_style_secs
self.use_prompt_style = use_prompt_style
if self.use_prompt_style:
print("Using prompt style instead of audio style.")
def __len__(self):
return len(self.test_samples)
def __getitem__(self, idx):
test_sample = self.test_samples[idx]
sample_id = test_sample["id"]
# Load LRC data
lrc_path = test_sample["lrc_path"]
with open(lrc_path, 'r') as f:
lrc_data = json.load(f)
        # Wrap bare word-level LRC lists in the expected {'word': ...} structure
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}
# Generate style embedding from original audio on-the-fly
audio_path = test_sample["audio_path"]
if self.use_prompt_style:
prompt_path = test_sample["prompt_path"]
            with open(prompt_path, 'r') as f:
                prompt = f.read()
if len(prompt) > 300:
print(f"Sample {sample_id} has prompt length {len(prompt)}")
prompt = prompt[:300]
print(prompt)
style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
else:
style_embedding = self.generate_style_embedding(audio_path)
duration = test_sample["duration"]
        # Create a fake latent with the correct length
        # Hardcoded latent frame rate (~21.5 latent frames per second for 44.1 kHz audio)
        frame_rate = 21.5
num_frames = int(duration * frame_rate)
fake_latent = torch.randn(128, num_frames) # 128 is latent dim
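        # Only the latent's length matters here: the downstream dataset code uses it to
        # derive frame counts and timing, and generation itself starts from noise.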
# Create sample tuple matching DiffusionWebDataset format
fake_sample = (
sample_id,
fake_latent, # latent with correct duration
style_embedding, # style from actual audio
lrc_data # actual LRC data
)
# Process through DiffusionWebDataset's process_sample_safely
processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)
# Add metadata
if processed_sample is not None:
processed_sample['test_metadata'] = {
'sample_id': sample_id,
'audio_path': audio_path,
'lrc_path': lrc_path,
'duration': duration,
'num_frames': num_frames
}
return processed_sample
def generate_style_embedding(self, audio_path):
"""Generate style embedding using MuQ model on the whole music"""
# Load audio
waveform, sample_rate = torchaudio.load(audio_path)
# Resample to 24kHz if needed (MuQ expects 24kHz)
if sample_rate != 24000:
resampler = torchaudio.transforms.Resample(sample_rate, 24000)
waveform = resampler(waveform)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
        # Collapse to 1D (time,): MuQ takes mono audio without a channel dimension
        waveform = waveform.squeeze(0)  # Now shape is (time,)
# Move to same device as model
waveform = waveform.to(self.muq_model.device)
# Generate embedding using MuQ model
with torch.inference_mode():
# MuQ expects batch dimension and 1D audio, returns (batch, embedding_dim)
            if self.random_crop_style:
                # Randomly crop num_style_secs seconds from the waveform
                total_samples = waveform.shape[0]
                target_samples = 24000 * self.num_style_secs  # e.g. 30 seconds at 24kHz
                # Guard against clips shorter than the crop window
                start_idx = random.randint(0, max(0, total_samples - target_samples))
                style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples])
else:
style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * self.num_style_secs])
# Keep shape as (embedding_dim,) not scalar
return style_embedding[0]
def custom_collate_fn_with_metadata(batch, base_collate_fn):
"""Custom collate function that preserves test_metadata"""
# Filter out None samples
batch = [item for item in batch if item is not None]
if not batch:
return None
# Extract test_metadata before collating
test_metadata = [item.pop('test_metadata') for item in batch]
# Use base collate function for the rest
collated = base_collate_fn(batch)
# Add test_metadata back
if collated is not None:
collated['test_metadata'] = test_metadata
return collated
def load_model(model_config, checkpoint_path, device):
"""
Load JAM CFM model from checkpoint (follows infer.py pattern)
"""
# Build CFM model from config
dit_config = model_config["dit"].copy()
# Add text_num_embeds if not specified - should be at least 64 for phoneme tokens
if "text_num_embeds" not in dit_config:
dit_config["text_num_embeds"] = 256 # Default value from DiT
cfm = CFM(
transformer=DiT(**dit_config),
**model_config["cfm"]
)
cfm = cfm.to(device)
# Load checkpoint - use the path from config
checkpoint = load_file(checkpoint_path)
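    # strict=False tolerates keys that differ between the checkpoint and this model definition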
cfm.load_state_dict(checkpoint, strict=False)
return cfm.eval()
def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
"""
Generate latent from batch data (follows infer.py pattern)
"""
with torch.inference_mode():
batch_size = len(batch["lrc"])
text = batch["lrc"].to(device)
style_prompt = batch["prompt"].to(device)
start_time = batch["start_time"].to(device)
duration_abs = batch["duration_abs"].to(device)
duration_rel = batch["duration_rel"].to(device)
        # Create a zero conditioning latent covering the full frame window
        max_frames = model.max_frames
cond = torch.zeros(batch_size, max_frames, 64).to(text.device)
pred_frames = [(0, max_frames)]
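        # Default sampling hyperparameters: cfg_strength is the classifier-free guidance
        # scale and steps the number of CFM sampling steps; sample_kwargs overrides them.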
default_sample_kwargs = {
"cfg_strength": 4,
"steps": 50,
"batch_infer_num": 1
}
sample_kwargs = {**default_sample_kwargs, **sample_kwargs}
if negative_style_prompt_path is None:
negative_style_prompt_path = 'public_checkpoints/vocal.npy'
negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
elif negative_style_prompt_path == 'zeros':
negative_style_prompt = torch.zeros(1, 512).to(text.device)
else:
negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
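        # Broadcast the single negative style prompt across the batch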
negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)
latents, _ = model.sample(
cond=cond,
text=text,
style_prompt=negative_style_prompt if ignore_style else style_prompt,
duration_abs=duration_abs,
duration_rel=duration_rel,
negative_style_prompt=negative_style_prompt,
start_time=start_time,
latent_pred_segments=pred_frames,
**sample_kwargs
)
return latents
class Jamify:
def __init__(self):
os.makedirs('outputs', exist_ok=True)
device = 'cuda'
config_path = 'jam_infer.yaml'
self.config = OmegaConf.load(config_path)
OmegaConf.resolve(self.config)
        # Download the released JAM checkpoint and point the config at it
        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
# Load VAE based on configuration
vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
if vae_type == 'diffrhythm':
vae = DiffRhythmVAE(device=device).to(device)
else:
vae = StableAudioOpenVAE().to(device)
self.vae = vae
self.vae_type = vae_type
self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, device)
self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(device).eval()
dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
enhance_webdataset_config(dataset_cfg)
# Override multiple_styles to False since we're generating single style embeddings
dataset_cfg.multiple_styles = False
self.base_dataset = DiffusionWebDataset(**dataset_cfg)
    def cleanup_old_files(self, sample_id):
        # Clean up old generated MP3s, keeping only the most recent ones
        old_mp3_files = sorted(glob.glob("outputs/*.mp3"))
        if len(old_mp3_files) >= 10:
            for old_file in old_mp3_files[:-9]:  # keep the 9 newest files, delete the rest
                try:
                    os.remove(old_file)
                    print(f"Cleaned up old file: {old_file}")
                except OSError:
                    pass
        # Remove the temporary test-set JSON created for this request
        os.unlink(f"outputs/{sample_id}.json")
def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
sample_id = str(int(time.time() * 1000000)) # microsecond timestamp for uniqueness
test_set = [{
"id": sample_id,
"audio_path": reference_audio_path,
"lrc_path": lyrics_json_path,
"duration": duration,
"prompt_path": style_prompt
}]
        with open(f"outputs/{sample_id}.json", "w") as f:
            json.dump(test_set, f)
# Create filtered test set dataset
test_dataset = FilteredTestSetDataset(
test_set_path=f"outputs/{sample_id}.json",
diffusion_dataset=self.base_dataset,
muq_model=self.muq_model,
num_samples=1,
random_crop_style=self.config.evaluation.random_crop_style,
num_style_secs=self.config.evaluation.num_style_secs,
use_prompt_style=self.config.evaluation.use_prompt_style
)
# Create dataloader with custom collate function
dataloader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn)
)
batch = next(iter(dataloader))
sample_kwargs = self.config.evaluation.sample_kwargs
latent = generate_latent(self.cfm_model, batch, sample_kwargs, self.config.evaluation.negative_style_prompt, self.config.evaluation.ignore_style)[0][0]
test_metadata = batch['test_metadata'][0]
sample_id = test_metadata['sample_id']
original_duration = test_metadata['duration']
# Decode audio
latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
# Use chunked decoding if configured (only for DiffRhythm VAE)
use_chunked = self.config.evaluation.get('use_chunked_decoding', True)
if self.vae_type == 'diffrhythm' and use_chunked:
pred_audio = self.vae.decode(
latent_for_vae,
chunked=True,
overlap=self.config.evaluation.get('chunked_overlap', 32),
chunk_size=self.config.evaluation.get('chunked_size', 128)
).sample.squeeze(0).detach().cpu()
else:
pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()
pred_audio = normalize_audio(pred_audio)
sample_rate = 44100
trim_samples = int(original_duration * sample_rate)
if pred_audio.shape[1] > trim_samples:
pred_audio_trimmed = pred_audio[:, :trim_samples]
else:
pred_audio_trimmed = pred_audio
output_path = f'outputs/{sample_id}.mp3'
torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")
self.cleanup_old_files(sample_id)
return output_path
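

# ---------------------------------------------------------------------------
# Minimal usage sketch. The paths below are hypothetical placeholders (the
# Space normally drives Jamify through its web UI); running this requires a
# CUDA device and downloads the declare-lab/jam-0.5 checkpoint plus the
# MuQ-MuLan encoder on first use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    jam = Jamify()
    output = jam.predict(
        reference_audio_path="examples/reference.mp3",  # hypothetical reference track
        lyrics_json_path="examples/lyrics.json",        # hypothetical word-level lyrics JSON
        style_prompt="examples/style_prompt.txt",       # hypothetical prompt file (used when use_prompt_style is enabled)
        duration=30.0,                                  # seconds of audio to generate
    )
    print(f"Generated audio saved to {output}")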