# JAM / model.py
# Source: commit 7d35d1e ("Added jam space") by CY.
import torch
import torchaudio
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import numpy as np
import json
import os
from safetensors.torch import load_file
# Imports from the jamify library
from jam.model.cfm import CFM
from jam.model.dit import DiT
from jam.model.vae import StableAudioOpenVAE
from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
from muq import MuQMuLan
# Helper functions adapted from jamify/src/jam/infer.py
def get_negative_style_prompt(device, file_path):
    """Load a saved style embedding from a ``.npy`` file as a half tensor.

    Args:
        device: Torch device (or device string) to place the tensor on.
        file_path: Path to the ``.npy`` array to load.

    Returns:
        The loaded array as a ``torch.float16`` tensor on *device*.
    """
    stored = np.load(file_path)
    on_device = torch.from_numpy(stored).to(device)
    return on_device.half()
def normalize_audio(audio):
    """Remove the mean and peak-normalise *audio* along its last axis.

    Every slice along the last dimension is shifted to zero mean and then
    divided by its largest absolute value. The ``1e-8`` added to the peak
    keeps an all-zero slice at zero instead of producing NaN.
    """
    centred = audio - audio.mean(dim=-1, keepdim=True)
    peak = centred.abs().amax(dim=-1, keepdim=True)
    return centred / (peak + 1e-8)
class Jamify:
    """Lyrics-and-style to song generator built on the ``declare-lab/jam-0.5`` checkpoint.

    Construction downloads the model snapshot and loads the VAE, the CFM/DiT
    generator, the MuQ-MuLan style encoder and the dataset processor that turns
    lyrics plus a style embedding into a model batch.
    """

    # MuQ-MuLan input: audio is resampled to this rate and cut to this many seconds.
    _STYLE_SAMPLE_RATE = 24000
    _STYLE_MAX_SECONDS = 30
    # Latent frames per second of requested duration (sizes the placeholder latent).
    _LATENT_FRAME_RATE = 21.5
    # Rate used to trim and save the decoded waveform.
    _OUTPUT_SAMPLE_RATE = 44100

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Download the JAM snapshot and load every sub-model onto *device*.

        Args:
            device: Torch device string; defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.
        """
        self.device = torch.device(device)

        # Checkpoint, inference config and auxiliary files all come from the HF snapshot.
        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
        config_path = os.path.join(model_repo_path, "jam_infer.yaml")
        self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
        tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
        # NOTE(review): spelling presumably matches the file shipped in the
        # snapshot — confirm before "correcting" it.
        silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")

        print("Loading configuration...")
        self.config = OmegaConf.load(config_path)
        # Point the config's file paths at the absolute paths of the downloaded files.
        self.config.data.train_dataset.silence_latent_path = silence_latent_path
        self.config.data.train_dataset.tokenizer_path = tokenizer_path
        self.config.evaluation.dataset.tokenizer_path = tokenizer_path
        self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path

        print("Loading VAE model...")
        self.vae = StableAudioOpenVAE().to(self.device).eval()

        print("Loading CFM model...")
        self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path)

        print("Loading MuQ style model...")
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()

        print("Setting up dataset processor...")
        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        dataset_cfg.multiple_styles = False
        self.dataset_processor = DiffusionWebDataset(**dataset_cfg)
        print("Jamify model loaded successfully.")

    def _load_cfm_model(self, model_config, checkpoint_path):
        """Build a CFM around a DiT from *model_config* and load *checkpoint_path* into it.

        Loading is non-strict; any missing or unexpected parameter names are
        reported rather than silently ignored. Returns the model in eval mode.
        """
        dit_config = model_config["dit"].copy()
        if "text_num_embeds" not in dit_config:
            # Default used when the config does not specify the text vocabulary size.
            dit_config["text_num_embeds"] = 256
        model = CFM(
            transformer=DiT(**dit_config),
            **model_config["cfm"],
        ).to(self.device)
        state_dict = load_file(checkpoint_path)
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(
                f"Warning: checkpoint/model mismatch - "
                f"{len(result.missing_keys)} missing key(s), "
                f"{len(result.unexpected_keys)} unexpected key(s)."
            )
        return model.eval()

    def _generate_style_embedding_from_audio(self, audio_path):
        """Embed a reference recording with MuQ-MuLan.

        The audio is resampled to 24 kHz, averaged down to one channel and cut
        to its first 30 seconds before encoding. Returns the first row of the
        encoder output.
        """
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != self._STYLE_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, self._STYLE_SAMPLE_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        max_samples = self._STYLE_SAMPLE_RATE * self._STYLE_MAX_SECONDS
        clip = waveform[:, :max_samples].to(self.device)
        with torch.inference_mode():
            style_embedding = self.muq_model(wavs=clip)
        return style_embedding[0]

    def _generate_style_embedding_from_prompt(self, prompt):
        """Embed a free-text style description with MuQ-MuLan."""
        with torch.inference_mode():
            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
        return style_embedding

    def _resolve_style_embedding(self, reference_audio_path, style_prompt):
        """Choose the style source: reference audio, else text prompt, else a zero vector."""
        if reference_audio_path:
            print(f"Generating style from audio: {reference_audio_path}")
            return self._generate_style_embedding_from_audio(reference_audio_path)
        if style_prompt:
            print(f"Generating style from prompt: '{style_prompt}'")
            return self._generate_style_embedding_from_prompt(style_prompt)
        print("No style provided, using zero embedding.")
        return torch.zeros(512, device=self.device)

    @staticmethod
    def _load_lyrics(lyrics_json_path):
        """Read the lyrics JSON, wrapping it as ``{"word": payload}`` when it lacks a ``"word"`` entry.

        The file is decoded as UTF-8 explicitly so non-ASCII lyrics do not
        depend on the platform's default encoding.
        """
        with open(lyrics_json_path, "r", encoding="utf-8") as f:
            lrc_data = json.load(f)
        if "word" not in lrc_data:
            lrc_data = {"word": lrc_data}
        return lrc_data

    def _build_batch(self, style_embedding, lrc_data, duration_sec):
        """Turn lyrics plus style into a collated batch with tensors on ``self.device``.

        Raises:
            ValueError: If the dataset processor rejects the sample.
        """
        num_frames = int(duration_sec * self._LATENT_FRAME_RATE)
        # Random placeholder latent; NOTE(review): presumably only its length
        # matters to the processor at inference time — confirm.
        fake_latent = torch.randn(128, num_frames)
        sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)
        print("Processing sample...")
        processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
        if processed_sample is None:
            raise ValueError("Failed to process the provided lyrics and style.")
        batch = self.dataset_processor.custom_collate_fn([processed_sample])
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def _sample_latent(self, batch, steps):
        """Sample one generated latent from the CFM model conditioned on *batch*."""
        print("Generating audio latent...")
        # Per-call copy: the step override must not be written back into the shared config.
        sample_kwargs = dict(self.config.evaluation.sample_kwargs)
        sample_kwargs["steps"] = steps
        with torch.inference_mode():
            batch_size = 1
            max_frames = self.cfm_model.max_frames
            # All-zero conditioning latent; the whole frame range is predicted.
            cond = torch.zeros(batch_size, max_frames, 64, device=self.device)
            negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path)
            negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)
            latents, _ = self.cfm_model.sample(
                cond=cond,
                text=batch["lrc"],
                style_prompt=batch["prompt"],
                duration_abs=batch["duration_abs"],
                duration_rel=batch["duration_rel"],
                negative_style_prompt=negative_style_prompt,
                start_time=batch["start_time"],
                latent_pred_segments=[(0, max_frames)],
                **sample_kwargs,
            )
        return latents[0][0]

    def _decode_latent(self, latent, duration_sec):
        """Decode *latent* to a normalised CPU waveform trimmed to *duration_sec* seconds."""
        print("Decoding latent to audio...")
        with torch.inference_mode():
            # The VAE receives the latent with its two axes swapped and a leading batch axis.
            latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
            pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()
        pred_audio = normalize_audio(pred_audio)
        trim_samples = int(duration_sec * self._OUTPUT_SAMPLE_RATE)
        # Slicing beyond the end is a no-op, so shorter audio is returned unchanged.
        return pred_audio[:, :trim_samples]

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt,
                duration_sec=30, steps=50, output_path="generated_song.mp3"):
        """Generate a song from lyrics and a style source and save it as MP3.

        Args:
            reference_audio_path: Optional audio file whose style is embedded;
                takes precedence over *style_prompt*.
            lyrics_json_path: JSON lyrics file; a payload without a ``"word"``
                entry is wrapped as ``{"word": payload}``.
            style_prompt: Optional text style description, used when no
                reference audio is given. With neither, a zero embedding is used.
            duration_sec: Target length in seconds; sizes the placeholder latent
                and the output is trimmed to it.
            steps: Step count overriding ``steps`` in the sampler kwargs.
            output_path: File to write (defaults to ``"generated_song.mp3"`` in
                the working directory, as before).

        Returns:
            The path the audio was saved to.

        Raises:
            ValueError: If the dataset processor rejects the lyrics/style sample.
        """
        print("Starting prediction...")
        style_embedding = self._resolve_style_embedding(reference_audio_path, style_prompt)

        print(f"Loading lyrics from: {lyrics_json_path}")
        lrc_data = self._load_lyrics(lyrics_json_path)

        batch = self._build_batch(style_embedding, lrc_data, duration_sec)
        latent = self._sample_latent(batch, steps)
        pred_audio = self._decode_latent(latent, duration_sec)

        print(f"Saving audio to {output_path}")
        torchaudio.save(output_path, pred_audio, self._OUTPUT_SAMPLE_RATE, format="mp3")
        return output_path