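"""Inference wrapper for JAM (declare-lab/jam-0.5).

Downloads the released checkpoint and auxiliary files from the Hugging Face Hub,
loads the Stable Audio Open VAE, the CFM/DiT generator, and the MuQ-MuLan style
encoder, and exposes Jamify.predict(), which turns word-level lyrics plus a style
reference (audio file or text prompt) into a generated MP3.
"""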
import torch
import torchaudio
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import numpy as np
import json
import os
from safetensors.torch import load_file

# Imports from the jamify library
from jam.model.cfm import CFM
from jam.model.dit import DiT
from jam.model.vae import StableAudioOpenVAE
from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
from muq import MuQMuLan


# Helper functions adapted from jamify/src/jam/infer.py
def get_negative_style_prompt(device, file_path):
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)
    return vocal_style.half()
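

# Remove any DC offset, then peak-normalize the waveform to roughly [-1, 1].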
def normalize_audio(audio):
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    return audio
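

# Wraps the full JAM inference stack (Stable Audio Open VAE, CFM/DiT generator,
# MuQ-MuLan style encoder) behind a single predict() call.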
class Jamify:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = torch.device(device)

        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")

        # Config and auxiliary files shipped with the downloaded model snapshot
        config_path = os.path.join(model_repo_path, "jam_infer.yaml")
        self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
        tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
        silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")

        print("Loading configuration...")
        self.config = OmegaConf.load(config_path)
        self.config.data.train_dataset.silence_latent_path = silence_latent_path
        # Override the relative tokenizer paths in the config with absolute paths into the snapshot
        self.config.data.train_dataset.tokenizer_path = tokenizer_path
        self.config.evaluation.dataset.tokenizer_path = tokenizer_path
        self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path
print("Loading VAE model...") | |
self.vae = StableAudioOpenVAE().to(self.device).eval() | |
print("Loading CFM model...") | |
self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path) | |
print("Loading MuQ style model...") | |
self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval() | |
print("Setting up dataset processor...") | |
dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset) | |
enhance_webdataset_config(dataset_cfg) | |
dataset_cfg.multiple_styles = False | |
self.dataset_processor = DiffusionWebDataset(**dataset_cfg) | |
print("Jamify model loaded successfully.") | |
    def _load_cfm_model(self, model_config, checkpoint_path):
        dit_config = model_config["dit"].copy()
        if "text_num_embeds" not in dit_config:
            dit_config["text_num_embeds"] = 256
        model = CFM(
            transformer=DiT(**dit_config),
            **model_config["cfm"]
        ).to(self.device)
        state_dict = load_file(checkpoint_path)
        model.load_state_dict(state_dict, strict=False)
        return model.eval()
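
    # Style conditioning: MuQ-MuLan maps a reference audio clip or a free-text
    # prompt into the shared 512-dim embedding space used by the generator.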
    def _generate_style_embedding_from_audio(self, audio_path):
        waveform, sample_rate = torchaudio.load(audio_path)
        # MuQ-MuLan expects 24 kHz mono input
        if sample_rate != 24000:
            resampler = torchaudio.transforms.Resample(sample_rate, 24000)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0).to(self.device)
        with torch.inference_mode():
            # Embed at most the first 30 seconds of the reference audio
            style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * 30])
        return style_embedding[0]

    def _generate_style_embedding_from_prompt(self, prompt):
        with torch.inference_mode():
            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
        return style_embedding
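
    # Full generation pipeline: style embedding -> lyrics processing -> CFM
    # latent sampling -> VAE decoding -> MP3 on disk.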
    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration_sec=30, steps=50):
        print("Starting prediction...")
        if reference_audio_path:
            print(f"Generating style from audio: {reference_audio_path}")
            style_embedding = self._generate_style_embedding_from_audio(reference_audio_path)
        elif style_prompt:
            print(f"Generating style from prompt: '{style_prompt}'")
            style_embedding = self._generate_style_embedding_from_prompt(style_prompt)
        else:
            print("No style provided, using zero embedding.")
            style_embedding = torch.zeros(512, device=self.device)

        print(f"Loading lyrics from: {lyrics_json_path}")
        with open(lyrics_json_path, 'r') as f:
            lrc_data = json.load(f)
        # Accept either {"word": [...]} or a bare list of word-level timings
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}
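
        # Build a placeholder sample at the assumed latent frame rate (~21.5
        # latent frames per second of audio); the random latent only sizes the
        # sample for the dataset processor and is not used as conditioning.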
        frame_rate = 21.5
        num_frames = int(duration_sec * frame_rate)
        fake_latent = torch.randn(128, num_frames)
        sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)

        print("Processing sample...")
        processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
        if processed_sample is None:
            raise ValueError("Failed to process the provided lyrics and style.")
        batch = self.dataset_processor.custom_collate_fn([processed_sample])
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(self.device)
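
        # Sample latents with the CFM model; conditioning comes from the
        # processed lyrics, the style embedding, and the duration features,
        # with the precomputed "vocal" embedding as the negative style prompt.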
print("Generating audio latent...") | |
with torch.inference_mode(): | |
batch_size = 1 | |
text = batch["lrc"] | |
style_prompt_tensor = batch["prompt"] | |
start_time = batch["start_time"] | |
duration_abs = batch["duration_abs"] | |
duration_rel = batch["duration_rel"] | |
cond = torch.zeros(batch_size, self.cfm_model.max_frames, 64).to(self.device) | |
pred_frames = [(0, self.cfm_model.max_frames)] | |
negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path) | |
negative_style_prompt = negative_style_prompt.repeat(batch_size, 1) | |
sample_kwargs = self.config.evaluation.sample_kwargs | |
sample_kwargs.steps = steps | |
latents, _ = self.cfm_model.sample( | |
cond=cond, text=text, style_prompt=style_prompt_tensor, | |
duration_abs=duration_abs, duration_rel=duration_rel, | |
negative_style_prompt=negative_style_prompt, start_time=start_time, | |
latent_pred_segments=pred_frames, **sample_kwargs) | |
latent = latents[0][0] | |
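
        # The VAE decodes to 44.1 kHz audio; trim the decoded waveform to the
        # requested duration before saving as MP3.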
print("Decoding latent to audio...") | |
latent_for_vae = latent.transpose(0, 1).unsqueeze(0) | |
pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu() | |
pred_audio = normalize_audio(pred_audio) | |
sample_rate = 44100 | |
trim_samples = int(duration_sec * sample_rate) | |
if pred_audio.shape[1] > trim_samples: | |
pred_audio = pred_audio[:, :trim_samples] | |
output_path = "generated_song.mp3" | |
print(f"Saving audio to {output_path}") | |
torchaudio.save(output_path, pred_audio, sample_rate, format="mp3") | |
return output_path | |
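

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original Space code).
    # "lyrics.json" and the style prompt are placeholders; the lyrics file is
    # expected to hold word-level timings in the format consumed by jamify's
    # DiffusionWebDataset (wrapped under a "word" key, or a bare list).
    model = Jamify()
    song_path = model.predict(
        reference_audio_path=None,
        lyrics_json_path="lyrics.json",
        style_prompt="dreamy synth-pop with soft female vocals",
        duration_sec=30,
        steps=50,
    )
    print(f"Generated song saved to {song_path}")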