import json
import os

import torch
import torchaudio
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import numpy as np
from safetensors.torch import load_file

# Imports from the jamify library
from jam.model.cfm import CFM
from jam.model.dit import DiT
from jam.model.vae import StableAudioOpenVAE
from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
from muq import MuQMuLan


# Helper functions adapted from jamify/src/jam/infer.py
def get_negative_style_prompt(device, file_path):
    """Load a NumPy array from ``file_path`` as a half-precision tensor on ``device``.

    Args:
        device: Target torch device for the returned tensor.
        file_path: Path to a ``.npy`` file readable by ``numpy.load``.

    Returns:
        The array converted to a ``torch.Tensor``, moved to ``device`` and cast
        to ``float16``.
    """
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)
    # NOTE(review): the result is always float16 while the models in this file
    # are never explicitly cast to half -- confirm that CFM.sample accepts the
    # mixed dtypes (especially on CPU).
    return vocal_style.half()


def normalize_audio(audio):
    """Centre ``audio`` along its last axis and scale it to a peak of about 1.

    The mean over the last dimension is subtracted, then the result is divided
    by its maximum absolute value over that dimension. A small epsilon keeps
    the division finite for all-zero (silent) input.

    Args:
        audio: Tensor whose last dimension is the one normalised over.

    Returns:
        A new tensor of the same shape; the input is not modified.
    """
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    return audio


class Jamify:
    """Inference wrapper around the JAM model, its VAE and the MuQ style encoder.

    Constructing an instance downloads the ``declare-lab/jam-0.5`` snapshot,
    loads every model onto the requested device and builds the dataset
    processor that ``predict`` uses to turn lyrics and style into a batch.
    """

    # Sample rate the style audio is resampled to before it is embedded.
    STYLE_SAMPLE_RATE = 24000
    # Reference audio is truncated to this many seconds before it is embedded.
    STYLE_MAX_SECONDS = 30
    # Frames per second used to size the placeholder latent.
    LATENT_FRAME_RATE = 21.5
    # Sample rate used for trimming and saving the generated audio.
    OUTPUT_SAMPLE_RATE = 44100
    # Length of the zero vector used when no style input is supplied.
    STYLE_EMBEDDING_DIM = 512

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Download the checkpoint snapshot and load all models.

        Args:
            device: Anything accepted by ``torch.device``. Defaults to CUDA
                when it is available, otherwise CPU.
        """
        self.device = torch.device(device)

        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")

        # Config and auxiliary data files are all resolved inside the snapshot.
        config_path = os.path.join(model_repo_path, "jam_infer.yaml")
        self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
        tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
        silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")

        print("Loading configuration...")
        self.config = OmegaConf.load(config_path)
        self.config.data.train_dataset.silence_latent_path = silence_latent_path
        # Override the paths stored in the config with paths inside the snapshot.
        self.config.data.train_dataset.tokenizer_path = tokenizer_path
        self.config.evaluation.dataset.tokenizer_path = tokenizer_path
        self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path

        print("Loading VAE model...")
        self.vae = StableAudioOpenVAE().to(self.device).eval()

        print("Loading CFM model...")
        self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path)

        print("Loading MuQ style model...")
        self.muq_model = (
            MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()
        )

        print("Setting up dataset processor...")
        # Later arguments win in OmegaConf.merge, so evaluation settings
        # override the training dataset settings.
        dataset_cfg = OmegaConf.merge(
            self.config.data.train_dataset, self.config.evaluation.dataset
        )
        enhance_webdataset_config(dataset_cfg)
        dataset_cfg.multiple_styles = False
        self.dataset_processor = DiffusionWebDataset(**dataset_cfg)

        print("Jamify model loaded successfully.")

    def _load_cfm_model(self, model_config, checkpoint_path):
        """Build the CFM/DiT model and load weights from a safetensors file.

        Args:
            model_config: Config node providing ``"dit"`` and ``"cfm"`` sections.
            checkpoint_path: Path to the ``.safetensors`` checkpoint.

        Returns:
            The model on ``self.device`` in eval mode.
        """
        dit_config = model_config["dit"].copy()
        if "text_num_embeds" not in dit_config:
            dit_config["text_num_embeds"] = 256

        model = CFM(
            transformer=DiT(**dit_config),
            **model_config["cfm"],
        ).to(self.device)

        state_dict = load_file(checkpoint_path)
        # strict=False tolerates key mismatches; report them so a partially
        # loaded model does not go unnoticed.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            print(
                f"Warning: {len(load_result.missing_keys)} model keys "
                "were missing from the checkpoint."
            )
        if load_result.unexpected_keys:
            print(
                f"Warning: {len(load_result.unexpected_keys)} checkpoint keys "
                "were not used by the model."
            )
        return model.eval()

    def _generate_style_embedding_from_audio(self, audio_path):
        """Embed a reference audio file with the MuQ style model.

        The audio is resampled to ``STYLE_SAMPLE_RATE`` if needed, averaged
        down to one channel, and truncated to ``STYLE_MAX_SECONDS`` seconds.

        Args:
            audio_path: Path to an audio file readable by ``torchaudio.load``.

        Returns:
            The first element of the model output for the single-item batch.
        """
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != self.STYLE_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                sample_rate, self.STYLE_SAMPLE_RATE
            )
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0).to(self.device)

        max_samples = self.STYLE_SAMPLE_RATE * self.STYLE_MAX_SECONDS
        with torch.inference_mode():
            style_embedding = self.muq_model(
                wavs=waveform.unsqueeze(0)[..., :max_samples]
            )
        return style_embedding[0]

    def _generate_style_embedding_from_prompt(self, prompt):
        """Embed a text prompt with the MuQ style model.

        Args:
            prompt: Text passed to the model as a single-item list.

        Returns:
            The model output with its leading dimension squeezed.
        """
        with torch.inference_mode():
            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
        return style_embedding

    def _resolve_style_embedding(self, reference_audio_path, style_prompt):
        """Pick the style embedding: reference audio, then prompt, then zeros."""
        if reference_audio_path:
            print(f"Generating style from audio: {reference_audio_path}")
            return self._generate_style_embedding_from_audio(reference_audio_path)
        if style_prompt:
            print(f"Generating style from prompt: '{style_prompt}'")
            return self._generate_style_embedding_from_prompt(style_prompt)
        print("No style provided, using zero embedding.")
        return torch.zeros(self.STYLE_EMBEDDING_DIM, device=self.device)

    @staticmethod
    def _load_lyrics(lyrics_json_path):
        """Read the lyrics JSON and wrap it under a ``'word'`` key if absent."""
        print(f"Loading lyrics from: {lyrics_json_path}")
        # Explicit UTF-8 so non-ASCII lyrics load regardless of the locale.
        with open(lyrics_json_path, "r", encoding="utf-8") as f:
            lrc_data = json.load(f)
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}
        return lrc_data

    def predict(
        self,
        reference_audio_path,
        lyrics_json_path,
        style_prompt,
        duration_sec=30,
        steps=50,
        output_path="generated_song.mp3",
    ):
        """Generate a song from lyrics and a style, and save it as an MP3.

        Args:
            reference_audio_path: Optional path to style reference audio. When
                truthy it takes precedence over ``style_prompt``.
            lyrics_json_path: Path to a JSON file with the lyrics data. If the
                loaded object has no ``'word'`` entry it is wrapped as
                ``{'word': <object>}``.
            style_prompt: Optional text style description, used only when no
                reference audio is given. If both are falsy, a zero vector of
                length ``STYLE_EMBEDDING_DIM`` is used.
            duration_sec: Target length in seconds. Sizes the placeholder
                latent and is the maximum length of the saved audio.
            steps: Value passed to the sampler as ``steps``.
            output_path: Destination file. The audio is always encoded as MP3,
                whatever the file extension.

        Returns:
            ``output_path``.

        Raises:
            ValueError: If ``duration_sec`` is not positive, or if the dataset
                processor returns ``None`` for the sample.
        """
        if duration_sec <= 0:
            raise ValueError("duration_sec must be positive.")

        print("Starting prediction...")
        style_embedding = self._resolve_style_embedding(
            reference_audio_path, style_prompt
        )
        lrc_data = self._load_lyrics(lyrics_json_path)

        # The processor takes a latent as part of the sample tuple; a random
        # placeholder of the requested length is supplied.
        # NOTE(review): presumably only its length matters -- confirm against
        # DiffusionWebDataset.
        num_frames = int(duration_sec * self.LATENT_FRAME_RATE)
        fake_latent = torch.randn(128, num_frames)
        sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)

        print("Processing sample...")
        processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
        if processed_sample is None:
            raise ValueError("Failed to process the provided lyrics and style.")

        batch = self.dataset_processor.custom_collate_fn([processed_sample])
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(self.device)

        print("Generating audio latent...")
        with torch.inference_mode():
            batch_size = 1
            max_frames = self.cfm_model.max_frames
            cond = torch.zeros(batch_size, max_frames, 64).to(self.device)
            pred_frames = [(0, max_frames)]

            negative_style_prompt = get_negative_style_prompt(
                self.device, self.negative_style_prompt_path
            )
            negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

            # Work on a per-call copy so the shared config is not mutated.
            sample_kwargs = dict(self.config.evaluation.sample_kwargs)
            sample_kwargs["steps"] = steps

            latents, _ = self.cfm_model.sample(
                cond=cond,
                text=batch["lrc"],
                style_prompt=batch["prompt"],
                duration_abs=batch["duration_abs"],
                duration_rel=batch["duration_rel"],
                negative_style_prompt=negative_style_prompt,
                start_time=batch["start_time"],
                latent_pred_segments=pred_frames,
                **sample_kwargs,
            )
            latent = latents[0][0]

        print("Decoding latent to audio...")
        # Decode under inference mode as well: ``latent`` was created there.
        with torch.inference_mode():
            latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
            decoded = self.vae.decode(latent_for_vae).sample
        pred_audio = decoded.squeeze(0).detach().cpu()
        pred_audio = normalize_audio(pred_audio)

        sample_rate = self.OUTPUT_SAMPLE_RATE
        trim_samples = int(duration_sec * sample_rate)
        if pred_audio.shape[1] > trim_samples:
            pred_audio = pred_audio[:, :trim_samples]

        print(f"Saving audio to {output_path}")
        torchaudio.save(output_path, pred_audio, sample_rate, format="mp3")
        return output_path