CY committed
Commit 7d35d1e · 1 Parent(s): a073f0a

Added jam space

Files changed (5)
  1. TangoFlux.py +0 -58
  2. app.py +64 -151
  3. gt0.json +1 -0
  4. model.py +162 -493
  5. requirements.txt +34 -10
TangoFlux.py DELETED
@@ -1,58 +0,0 @@
- from diffusers import AutoencoderOobleck
- import torch
- from transformers import T5EncoderModel,T5TokenizerFast
- from diffusers import FluxTransformer2DModel
- from torch import nn
- from typing import List
- from diffusers import FlowMatchEulerDiscreteScheduler
- from diffusers.training_utils import compute_density_for_timestep_sampling
- import copy
- import torch.nn.functional as F
- import numpy as np
- from model import TangoFlux
- from huggingface_hub import snapshot_download
- from tqdm import tqdm
- from typing import Optional,Union,List
- from datasets import load_dataset, Audio
- from math import pi
- import json
- import inspect
- import yaml
- from safetensors.torch import load_file
- import os
- print(os.environ['HOME'])
-
- class TangoFluxInference:
-
-     def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
-
-         self.vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae',token=os.environ['HF_TOKEN'])
-
-         paths = snapshot_download(repo_id=name,token=os.environ['HF_TOKEN'])
-         weights = load_file("{}/model_1.safetensors".format(paths))
-
-         with open('{}/config.json'.format(paths),'r') as f:
-             config = json.load(f)
-         self.model = TangoFlux(config)
-         self.model.load_state_dict(weights,strict=False)
-         # _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
-         self.vae.to(device)
-         self.model.to(device)
-
-     def generate(self,prompt,steps=25,duration=10,guidance_scale=4.5):
-
-         with torch.no_grad():
-             latents = self.model.inference_flow(prompt,
-                                                 duration=duration,
-                                                 num_inference_steps=steps,
-                                                 guidance_scale=guidance_scale)
-
-         wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
-         return wave
-
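For context, a minimal sketch of how the deleted TangoFluxInference wrapper was driven by the old app.py (the 25 steps, 10 s duration, 4.5 guidance, and 44.1 kHz save rate mirror its defaults); this is an illustrative sketch only:

```python
# Minimal sketch of how TangoFluxInference was used before its removal.
# Requires HF_TOKEN in the environment (see __init__ above) and a CUDA device.
import torchaudio
from TangoFlux import TangoFluxInference

tangoflux = TangoFluxInference(name="declare-lab/TangoFlux", device="cuda")
wave = tangoflux.generate("A dog barking", steps=25, duration=10, guidance_scale=4.5)
torchaudio.save("temp.wav", wave, 44100)  # the old app saved WAV at 44100 Hz
```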
app.py CHANGED
@@ -1,155 +1,68 @@
- import spaces
  import gradio as gr
- import json
- import torch
- import wavio
- from tqdm import tqdm
- from huggingface_hub import snapshot_download
- from pydub import AudioSegment
- from gradio import Markdown
- import uuid
- import torch
- from diffusers import DiffusionPipeline,AudioPipelineOutput
- from transformers import CLIPTextModel, T5EncoderModel, AutoModel, T5Tokenizer, T5TokenizerFast
- from typing import Union
- from diffusers.utils.torch_utils import randn_tensor
- from tqdm import tqdm
- from TangoFlux import TangoFluxInference
- import torchaudio
-
-
- tangoflux = TangoFluxInference(name="techneto/tangoflux-music-base-24k-120k-steps")
-
-
- @spaces.GPU(duration=15)
- def gradio_generate(prompt, steps, guidance,duration=10):
-
-     output = tangoflux.generate(prompt,steps=steps,guidance_scale=guidance,duration=duration)
-     #output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

-
-     #wavio.write(output_filename, output_wave, rate=44100, sampwidth=2)
-     filename = 'temp.wav'
-     #print(f"Saving audio to file: {unique_filename}")
-
-     # Save to file
-     torchaudio.save(filename, output, 44100)
-
-     # Return the path to the generated audio file
-     return filename
-
-     #if (output_format == "mp3"):
-     #    AudioSegment.from_wav("temp.wav").export("temp.mp3", format = "mp3")
-     #    output_filename = "temp.mp3"
-
-     #return output_filename
-
- description_text = """
- Generate high quality and faithful audio in just a few seconds using <b>TangoFlux</b> by providing a text prompt. <b>TangoFlux</b> was trained from scratch and underwent alignment to follow human instructions using a new method called <b>Claped-Ranked Preference Optimization (CRPO)</b>.
- <div style="display: flex; gap: 10px; align-items: center;">
-     <a href="https://arxiv.org/abs/2412.21037">
-         <img src="https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf" alt="arXiv">
-     </a>
-     <a href="https://huggingface.co/declare-lab/TangoFlux">
-         <img src="https://img.shields.io/badge/TangoFlux-Huggingface-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-     </a>
-     <a href="https://tangoflux.github.io/">
-         <img src="https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat" alt="Static Badge">
-     </a>
-     <a href="https://huggingface.co/spaces/declare-lab/TangoFlux">
-         <img src="https://img.shields.io/badge/TangoFlux-Huggingface_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-     </a>
-     <a href="https://huggingface.co/datasets/declare-lab/CRPO">
-         <img src="https://img.shields.io/badge/TangoFlux_Dataset-Huggingface-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-     </a>
-     <a href="https://github.com/declare-lab/TangoFlux">
-         <img src="https://img.shields.io/badge/Github-brown?logo=github&link=https%3A%2F%2Fgithub.com%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-     </a>
- </div>
- """
- # Gradio input and output components
- input_text = gr.Textbox(lines=2, label="Prompt")
- #output_format = gr.Radio(label = "Output format", info = "The file you can dowload", choices = "wav"], value = "wav")
- output_audio = gr.Audio(label="Generated Audio", type="filepath")
- denoising_steps = gr.Slider(minimum=10, maximum=100, value=25, step=5, label="Steps", interactive=True)
- guidance_scale = gr.Slider(minimum=1, maximum=10, value=4.5, step=0.5, label="Guidance Scale", interactive=True)
- duration_scale = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Duration", interactive=True)
-
  # Gradio interface
- gr_interface = gr.Interface(
-     fn=gradio_generate,
-     inputs=[input_text, denoising_steps, guidance_scale,duration_scale],
-     outputs=output_audio,
-     title="TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization",
-     description=description_text,
-     allow_flagging=False,
-     examples=[
-         ["A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion"],
-         ["A soccer ball hits a goalpost with a metallic clang, followed by cheers, clapping, and the distant hum of a commentator’s voice"],
-         ["The deep growl of an alligator ripples through the swamp as reeds sway with a soft rustle and a turtle splashes into the murky water"],
-         ["A basketball bounces rhythmically on a court, shoes squeak against the floor, and a referee’s whistle cuts through the air"],
-         ["A train conductor blows a sharp whistle, metal wheels screech on the rails, and passengers murmur while settling into their seats"],
-         ["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background"],
-         ["Alarms blare with rising urgency as fragments clatter against a metallic hull, interrupted by a faint hiss of escaping air"],
-         ["Tiny pops and hisses of chemical reactions intermingle with the rhythmic pumping of a centrifuge and the soft whirr of air filtration"],
-         ["A train conductor blows a sharp whistle, metal wheels screech on the rails, and passengers murmur while settling into their seats"],
-         ["Simulate a forest ambiance with birds chirping and wind rustling through the leaves"],
-         ["Quiet whispered conversation gradually fading into distant jet engine roar diminishing into silence"],
-         ["Clear sound of bicycle tires crunching on loose gravel and dirt, followed by deep male laughter echoing"],
-         ["Multiple ducks quacking loudly with splashing water and piercing wild animal shriek in background"],
-         ["Create a serene soundscape of a quiet beach at sunset"],
-         ["A pile of coins spills onto a wooden table with a metallic clatter, followed by the hushed murmur of a tavern crowd and the creak of a swinging door"],
-         ["Generate an energetic and bustling city street scene with distant traffic and close conversations"],
-         ["Powerful ocean waves crashing and receding on sandy beach with distant seagulls"],
-         ["Gentle female voice cooing and baby responding with happy gurgles and giggles"],
-         ["Clear male voice speaking, sharp popping sound, followed by genuine group laughter"],
-         ["Stream of water hitting empty ceramic cup, pitch rising as cup fills up"],
-         ["Massive crowd erupting in thunderous applause and excited cheering"],
-         ["Deep rolling thunder with bright lightning strikes crackling through sky"],
-         ["Aggressive dog barking and distressed cat meowing as racing car roars past at high speed"],
-         ["Peaceful stream bubbling and birds singing, interrupted by sudden explosive gunshot"],
-         ["Man speaking outdoors, goat bleating loudly, metal gate scraping closed, ducks quacking frantically, wind howling into microphone"],
-         ["Series of loud aggressive dog barks echoing"],
-         ["Multiple distinct cat meows at different pitches"],
-         ["Rhythmic wooden table tapping overlaid with steady water pouring sound"],
-         ["Sustained crowd applause with camera clicks and amplified male announcer voice"],
-         ["Two sharp gunshots followed by panicked birds taking flight with rapid wing flaps"],
-         ["Melodic human whistling harmonizing with natural birdsong"],
-         ["Deep rhythmic snoring with clear breathing patterns"],
-         ["Multiple racing engines revving and accelerating with sharp whistle piercing through"],
-         ["Massive stadium crowd cheering as thunder crashes and lightning strikes"],
-         ["Heavy helicopter blades chopping through air with engine and wind noise"],
-         ["Dog barking excitedly and man shouting as race car engine roars past"],
-         ["Quiet speech and then and airplane flying away"],
-         ["A bicycle peddling on dirt and gravel followed by a man speaking then laughing"],
-         ["Ducks quack and water splashes with some animal screeching in the background"],
-         ["Describe the sound of the ocean"],
-         ["A woman and a baby are having a conversation"],
-         ["A man speaks followed by a popping noise and laughter"],
-         ["A cup is filled from a faucet"],
-         ["An audience cheering and clapping"],
-         ["Rolling thunder with lightning strikes"],
-         ["A dog barking and a cat mewing and a racing car passes by"],
-         ["Gentle water stream, birds chirping and sudden gun shot"],
-         ["A dog barking"],
-         ["A cat meowing"],
-         ["Wooden table tapping sound while water pouring"],
-         ["Applause from a crowd with distant clicking and a man speaking over a loudspeaker"],
-         ["two gunshots followed by birds flying away while chirping"],
-         ["Whistling with birds chirping"],
-         ["A person snoring"],
-         ["Motor vehicles are driving with loud engines and a person whistles"],
-         ["People cheering in a stadium while thunder and lightning strikes"],
-         ["A helicopter is in flight"],
-         ["A dog barking and a man talking and a racing car passes by"],
-     ],
-     cache_examples="lazy", # Turn on to cache.
- )
-
- gr_interface.queue(15).launch()

  import gradio as gr
+ import os
+ from model import Jamify
+
+ # Initialize the Jamify model once
+ print("Initializing Jamify model...")
+ jamify_model = Jamify()
+ print("Jamify model ready.")
+
+ def generate_song(reference_audio, lyrics_file, style_prompt, duration):
+     # We need to save the uploaded files to temporary paths to pass to the model
+     ref_audio_path = reference_audio.name if reference_audio else None
+     lyrics_path = lyrics_file.name
+
+     # The model expects paths, so we write the prompt to a temp file if needed
+     # (This part of the model could be improved to accept the string directly)
+
+     output_path = jamify_model.predict(
+         reference_audio_path=ref_audio_path,
+         lyrics_json_path=lyrics_path,
+         style_prompt=style_prompt,
+         duration_sec=duration
+     )
+
+     return output_path
+
  # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Jamify: Music Generation from Lyrics and Style")
+     gr.Markdown("Provide your lyrics, a style reference (either an audio file or a text prompt), and a desired duration to generate a song.")
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### Inputs")
+             lyrics_file = gr.File(label="Lyrics File (.json)", type="filepath")
+             duration_slider = gr.Slider(minimum=5, maximum=180, value=30, step=1, label="Duration (seconds)")
+
+             with gr.Tab("Style from Audio"):
+                 reference_audio = gr.File(label="Reference Audio (.mp3, .wav)", type="filepath")
+             with gr.Tab("Style from Text"):
+                 style_prompt = gr.Textbox(label="Style Prompt", lines=3, placeholder="e.g., A high-energy electronic dance track with a strong bassline and euphoric synths.")
+
+             generate_button = gr.Button("Generate Song", variant="primary")
+
+         with gr.Column():
+             gr.Markdown("### Output")
+             output_audio = gr.Audio(label="Generated Song")
+
+     generate_button.click(
+         fn=generate_song,
+         inputs=[reference_audio, lyrics_file, style_prompt, duration_slider],
+         outputs=output_audio,
+         api_name="generate_song"
+     )
+
+     gr.Markdown("### Example Usage")
+     gr.Examples(
+         examples=[
+             [None, "jamify/inputs/Jade Bird - Avalanche.json", "A sad, slow, acoustic country song", 30],
+             ["jamify/inputs/Rizzle Kicks, Rachel Chinouriri - Follow Excitement!.mp3", "jamify/inputs/Rizzle Kicks, Rachel Chinouriri - Follow Excitement!.json", "", 45],
+         ],
+         inputs=[reference_audio, lyrics_file, style_prompt, duration_slider],
+         outputs=output_audio,
+         fn=generate_song,
+         cache_examples=True
+     )
+
+ demo.queue().launch()
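Because the click handler above registers api_name="generate_song", the Space can also be called programmatically. A rough sketch with gradio_client; the Space id and file path are placeholders, and the exact file-argument wrapping depends on the installed gradio_client version:

```python
# Hypothetical client-side call to the endpoint registered above (api_name="generate_song").
from gradio_client import Client, handle_file

client = Client("<owner>/<space-name>")          # placeholder Space id
result = client.predict(
    None,                                        # reference_audio (no style audio)
    handle_file("lyrics.json"),                  # lyrics_file: word-aligned lyrics JSON
    "A sad, slow, acoustic country song",        # style_prompt
    30,                                          # duration in seconds
    api_name="/generate_song",
)
print(result)  # local path of the generated song returned by the Space
```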
gt0.json ADDED
@@ -0,0 +1 @@
+ [{"word": "Every", "start_offset": 259, "end_offset": 267, "start": 20.72, "end": 21.36, "phoneme": "\u025bv\u025di|_"}, {"word": "night", "start_offset": 267, "end_offset": 275, "start": 21.36, "end": 22.0, "phoneme": "na\u026at|_"}, {"word": "in", "start_offset": 279, "end_offset": 283, "start": 22.32, "end": 22.64, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 283, "end_offset": 287, "start": 22.64, "end": 22.96, "phoneme": "ma\u026a|_"}, {"word": "dreams,", "start_offset": 287, "end_offset": 301, "start": 22.96, "end": 24.080000000000002, "phoneme": "dri\u02d0mz,"}, {"word": "I", "start_offset": 309, "end_offset": 313, "start": 24.72, "end": 25.04, "phoneme": "a\u026a|_"}, {"word": "see", "start_offset": 317, "end_offset": 321, "start": 25.36, "end": 25.68, "phoneme": "si\u02d0|_"}, {"word": "you,", "start_offset": 321, "end_offset": 325, "start": 25.68, "end": 26.0, "phoneme": "ju\u02d0,"}, {"word": "I", "start_offset": 340, "end_offset": 344, "start": 27.2, "end": 27.52, "phoneme": "a\u026a|_"}, {"word": "feel", "start_offset": 348, "end_offset": 352, "start": 27.84, "end": 28.16, "phoneme": "fi\u02d0l|_"}, {"word": "you.", "start_offset": 358, "end_offset": 362, "start": 28.64, "end": 28.96, "phoneme": "ju\u02d0."}, {"word": "That", "start_offset": 377, "end_offset": 381, "start": 30.16, "end": 30.48, "phoneme": "\u00f0\u00e6t|_"}, {"word": "is", "start_offset": 385, "end_offset": 389, "start": 30.8, "end": 31.12, "phoneme": "\u026az"}, {"word": "how", "start_offset": 393, "end_offset": 397, "start": 31.44, "end": 31.76, "phoneme": "ha\u028a|_"}, {"word": "I", "start_offset": 401, "end_offset": 405, "start": 32.08, "end": 32.4, "phoneme": "a\u026a|_"}, {"word": "know", "start_offset": 405, "end_offset": 409, "start": 32.4, "end": 32.72, "phoneme": "no\u028a|_"}, {"word": "you", "start_offset": 413, "end_offset": 417, "start": 33.04, "end": 33.36, "phoneme": "ju\u02d0|_"}, {"word": "go", "start_offset": 428, "end_offset": 431, "start": 34.24, "end": 34.480000000000004, "phoneme": "go\u028a|_"}, {"word": "far", "start_offset": 495, "end_offset": 503, "start": 39.6, "end": 40.24, "phoneme": "f\u0251\u02d0r"}, {"word": "across", "start_offset": 507, "end_offset": 517, "start": 40.56, "end": 41.36, "phoneme": "\u0259kr\u0254s|_"}, {"word": "the", "start_offset": 519, "end_offset": 523, "start": 41.52, "end": 41.84, "phoneme": "\u00f0\u0259|_"}, {"word": "distance", "start_offset": 527, "end_offset": 538, "start": 42.160000000000004, "end": 43.04, "phoneme": "d\u026ast\u0259ns|_"}, {"word": "and", "start_offset": 552, "end_offset": 556, "start": 44.160000000000004, "end": 44.480000000000004, "phoneme": "\u0259nd"}, {"word": "spaces", "start_offset": 556, "end_offset": 572, "start": 44.480000000000004, "end": 45.76, "phoneme": "spe\u026as\u0259z"}, {"word": "between", "start_offset": 583, "end_offset": 587, "start": 46.64, "end": 46.96, "phoneme": "b\u026atwi\u02d0n|_"}, {"word": "us.", "start_offset": 602, "end_offset": 606, "start": 48.160000000000004, "end": 48.480000000000004, "phoneme": "\u028cs."}, {"word": "You", "start_offset": 621, "end_offset": 625, "start": 49.68, "end": 50.0, "phoneme": "ju\u02d0|_"}, {"word": "have", "start_offset": 629, "end_offset": 633, "start": 50.32, "end": 50.64, "phoneme": "h\u00e6v"}, {"word": "come", "start_offset": 633, "end_offset": 637, "start": 50.64, "end": 50.96, "phoneme": "k\u028cm|_"}, {"word": "to", "start_offset": 641, "end_offset": 645, "start": 51.28, "end": 51.6, "phoneme": "tu\u02d0|_"}, {"word": "show", "start_offset": 649, 
"end_offset": 653, "start": 51.92, "end": 52.24, "phoneme": "\u0283o\u028a|_"}, {"word": "you", "start_offset": 655, "end_offset": 659, "start": 52.4, "end": 52.72, "phoneme": "ju\u02d0|_"}, {"word": "go", "start_offset": 673, "end_offset": 676, "start": 53.84, "end": 54.08, "phoneme": "go\u028a|_"}, {"word": "near,", "start_offset": 738, "end_offset": 745, "start": 59.04, "end": 59.6, "phoneme": "n\u026ar,"}, {"word": "far,", "start_offset": 768, "end_offset": 776, "start": 61.44, "end": 62.08, "phoneme": "f\u0251\u02d0r,"}, {"word": "wherever", "start_offset": 794, "end_offset": 806, "start": 63.52, "end": 64.48, "phoneme": "w\u025br\u025bv\u025d"}, {"word": "you", "start_offset": 822, "end_offset": 826, "start": 65.76, "end": 66.08, "phoneme": "ju\u02d0|_"}, {"word": "are.", "start_offset": 826, "end_offset": 830, "start": 66.08, "end": 66.4, "phoneme": "\u0251\u02d0r."}, {"word": "I", "start_offset": 849, "end_offset": 852, "start": 67.92, "end": 68.16, "phoneme": "a\u026a|_"}, {"word": "believe", "start_offset": 856, "end_offset": 868, "start": 68.48, "end": 69.44, "phoneme": "b\u026ali\u02d0v"}, {"word": "that", "start_offset": 875, "end_offset": 878, "start": 70.0, "end": 70.24, "phoneme": "\u00f0\u00e6t|_"}, {"word": "the", "start_offset": 886, "end_offset": 890, "start": 70.88, "end": 71.2, "phoneme": "\u00f0\u0259|_"}, {"word": "heart", "start_offset": 890, "end_offset": 898, "start": 71.2, "end": 71.84, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "does", "start_offset": 898, "end_offset": 901, "start": 71.84, "end": 72.08, "phoneme": "d\u028cz"}, {"word": "go", "start_offset": 916, "end_offset": 920, "start": 73.28, "end": 73.60000000000001, "phoneme": "go\u028a|_"}, {"word": "on", "start_offset": 982, "end_offset": 985, "start": 78.56, "end": 78.8, "phoneme": "\u0251\u02d0n|_"}, {"word": "small.", "start_offset": 1009, "end_offset": 1017, "start": 80.72, "end": 81.36, "phoneme": "sm\u0254l."}, {"word": "You", "start_offset": 1037, "end_offset": 1041, "start": 82.96000000000001, "end": 83.28, "phoneme": "ju\u02d0|_"}, {"word": "open", "start_offset": 1045, "end_offset": 1049, "start": 83.60000000000001, "end": 83.92, "phoneme": "o\u028ap\u0259n|_"}, {"word": "the", "start_offset": 1065, "end_offset": 1069, "start": 85.2, "end": 85.52, "phoneme": "\u00f0\u0259|_"}, {"word": "door,", "start_offset": 1069, "end_offset": 1076, "start": 85.52, "end": 86.08, "phoneme": "d\u0254r,"}, {"word": "and", "start_offset": 1090, "end_offset": 1094, "start": 87.2, "end": 87.52, "phoneme": "\u0259nd"}, {"word": "you'll", "start_offset": 1094, "end_offset": 1100, "start": 87.52, "end": 88.0, "phoneme": "j\u028c\u028al|_"}, {"word": "hear", "start_offset": 1103, "end_offset": 1108, "start": 88.24, "end": 88.64, "phoneme": "hi\u02d0r"}, {"word": "in", "start_offset": 1119, "end_offset": 1122, "start": 89.52, "end": 89.76, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 1126, "end_offset": 1130, "start": 90.08, "end": 90.4, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 1130, "end_offset": 1138, "start": 90.4, "end": 91.04, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 1141, "end_offset": 1145, "start": 91.28, "end": 91.60000000000001, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 1157, "end_offset": 1161, "start": 92.56, "end": 92.88, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 1165, "end_offset": 1173, "start": 93.2, "end": 93.84, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 1173, "end_offset": 1177, "start": 
93.84, "end": 94.16, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 1185, "end_offset": 1189, "start": 94.8, "end": 95.12, "phoneme": "go\u028a|_"}, {"word": "and", "start_offset": 1211, "end_offset": 1215, "start": 96.88, "end": 97.2, "phoneme": "\u0259nd"}, {"word": "dawn.", "start_offset": 1223, "end_offset": 1233, "start": 97.84, "end": 98.64, "phoneme": "d\u0254n."}, {"word": "Love", "start_offset": 1345, "end_offset": 1353, "start": 107.60000000000001, "end": 108.24000000000001, "phoneme": "l\u028cv"}, {"word": "can", "start_offset": 1356, "end_offset": 1360, "start": 108.48, "end": 108.8, "phoneme": "k\u00e6n|_"}, {"word": "touch", "start_offset": 1360, "end_offset": 1366, "start": 108.8, "end": 109.28, "phoneme": "t\u028ct\u0283|_"}, {"word": "us", "start_offset": 1369, "end_offset": 1373, "start": 109.52, "end": 109.84, "phoneme": "\u028cs|_"}, {"word": "one", "start_offset": 1376, "end_offset": 1380, "start": 110.08, "end": 110.4, "phoneme": "w\u028cn|_"}, {"word": "time", "start_offset": 1384, "end_offset": 1388, "start": 110.72, "end": 111.04, "phoneme": "ta\u026am|_"}, {"word": "and", "start_offset": 1399, "end_offset": 1402, "start": 111.92, "end": 112.16, "phoneme": "\u0259nd"}, {"word": "last", "start_offset": 1406, "end_offset": 1410, "start": 112.48, "end": 112.8, "phoneme": "l\u00e6st|_"}, {"word": "for", "start_offset": 1416, "end_offset": 1420, "start": 113.28, "end": 113.60000000000001, "phoneme": "f\u0254r"}, {"word": "a", "start_offset": 1431, "end_offset": 1435, "start": 114.48, "end": 114.8, "phoneme": "\u0259|_"}, {"word": "lifetime", "start_offset": 1435, "end_offset": 1458, "start": 114.8, "end": 116.64, "phoneme": "la\u026afta\u026am|_"}, {"word": "and", "start_offset": 1471, "end_offset": 1475, "start": 117.68, "end": 118.0, "phoneme": "\u0259nd"}, {"word": "never", "start_offset": 1479, "end_offset": 1483, "start": 118.32000000000001, "end": 118.64, "phoneme": "n\u025bv\u025d"}, {"word": "let", "start_offset": 1487, "end_offset": 1491, "start": 118.96000000000001, "end": 119.28, "phoneme": "l\u025bt|_"}, {"word": "go", "start_offset": 1495, "end_offset": 1499, "start": 119.60000000000001, "end": 119.92, "phoneme": "go\u028a|_"}, {"word": "till", "start_offset": 1503, "end_offset": 1511, "start": 120.24000000000001, "end": 120.88, "phoneme": "t\u026al|_"}, {"word": "we're", "start_offset": 1521, "end_offset": 1528, "start": 121.68, "end": 122.24000000000001, "phoneme": "w\u025d\u02d0|_"}, {"word": "gone.", "start_offset": 1528, "end_offset": 1536, "start": 122.24000000000001, "end": 122.88, "phoneme": "g\u0254n."}, {"word": "Love", "start_offset": 1587, "end_offset": 1596, "start": 126.96000000000001, "end": 127.68, "phoneme": "l\u028cv"}, {"word": "was", "start_offset": 1599, "end_offset": 1603, "start": 127.92, "end": 128.24, "phoneme": "w\u0251\u02d0z"}, {"word": "when", "start_offset": 1607, "end_offset": 1611, "start": 128.56, "end": 128.88, "phoneme": "w\u025bn|_"}, {"word": "I", "start_offset": 1611, "end_offset": 1615, "start": 128.88, "end": 129.2, "phoneme": "a\u026a|_"}, {"word": "loved", "start_offset": 1615, "end_offset": 1626, "start": 129.2, "end": 130.08, "phoneme": "l\u028cvd"}, {"word": "you", "start_offset": 1626, "end_offset": 1630, "start": 130.08, "end": 130.4, "phoneme": "ju\u02d0|_"}, {"word": "one", "start_offset": 1641, "end_offset": 1644, "start": 131.28, "end": 131.52, "phoneme": "w\u028cn|_"}, {"word": "true", "start_offset": 1648, "end_offset": 1656, "start": 131.84, "end": 132.48, "phoneme": "tru\u02d0|_"}, {"word": 
"time.", "start_offset": 1656, "end_offset": 1660, "start": 132.48, "end": 132.8, "phoneme": "ta\u026am."}, {"word": "I", "start_offset": 1672, "end_offset": 1675, "start": 133.76, "end": 134.0, "phoneme": "a\u026a|_"}, {"word": "hold", "start_offset": 1679, "end_offset": 1687, "start": 134.32, "end": 134.96, "phoneme": "ho\u028ald"}, {"word": "to", "start_offset": 1691, "end_offset": 1693, "start": 135.28, "end": 135.44, "phoneme": "tu\u02d0|_"}, {"word": "in", "start_offset": 1712, "end_offset": 1716, "start": 136.96, "end": 137.28, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 1720, "end_offset": 1724, "start": 137.6, "end": 137.92000000000002, "phoneme": "ma\u026a|_"}, {"word": "life", "start_offset": 1724, "end_offset": 1728, "start": 137.92000000000002, "end": 138.24, "phoneme": "la\u026af|_"}, {"word": "will", "start_offset": 1731, "end_offset": 1733, "start": 138.48, "end": 138.64000000000001, "phoneme": "w\u026al|_"}, {"word": "always", "start_offset": 1743, "end_offset": 1747, "start": 139.44, "end": 139.76, "phoneme": "\u0254lwe\u026az"}, {"word": "go", "start_offset": 1763, "end_offset": 1767, "start": 141.04, "end": 141.36, "phoneme": "go\u028a|_"}, {"word": "near", "start_offset": 1830, "end_offset": 1836, "start": 146.4, "end": 146.88, "phoneme": "n\u026ar"}, {"word": "far", "start_offset": 1859, "end_offset": 1867, "start": 148.72, "end": 149.36, "phoneme": "f\u0251\u02d0r"}, {"word": "wherever", "start_offset": 1884, "end_offset": 1896, "start": 150.72, "end": 151.68, "phoneme": "w\u025br\u025bv\u025d"}, {"word": "you", "start_offset": 1914, "end_offset": 1918, "start": 153.12, "end": 153.44, "phoneme": "ju\u02d0|_"}, {"word": "are.", "start_offset": 1918, "end_offset": 1922, "start": 153.44, "end": 153.76, "phoneme": "\u0251\u02d0r."}, {"word": "I", "start_offset": 1940, "end_offset": 1943, "start": 155.20000000000002, "end": 155.44, "phoneme": "a\u026a|_"}, {"word": "believe", "start_offset": 1947, "end_offset": 1959, "start": 155.76, "end": 156.72, "phoneme": "b\u026ali\u02d0v"}, {"word": "that", "start_offset": 1966, "end_offset": 1970, "start": 157.28, "end": 157.6, "phoneme": "\u00f0\u00e6t|_"}, {"word": "the", "start_offset": 1974, "end_offset": 1977, "start": 157.92000000000002, "end": 158.16, "phoneme": "\u00f0\u0259|_"}, {"word": "heart", "start_offset": 1981, "end_offset": 1986, "start": 158.48, "end": 158.88, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "does", "start_offset": 1990, "end_offset": 1993, "start": 159.20000000000002, "end": 159.44, "phoneme": "d\u028cz"}, {"word": "go", "start_offset": 2008, "end_offset": 2011, "start": 160.64000000000001, "end": 160.88, "phoneme": "go\u028a|_"}, {"word": "small.", "start_offset": 2099, "end_offset": 2111, "start": 167.92000000000002, "end": 168.88, "phoneme": "sm\u0254l."}, {"word": "You", "start_offset": 2127, "end_offset": 2131, "start": 170.16, "end": 170.48, "phoneme": "ju\u02d0|_"}, {"word": "open", "start_offset": 2136, "end_offset": 2140, "start": 170.88, "end": 171.20000000000002, "phoneme": "o\u028ap\u0259n|_"}, {"word": "the", "start_offset": 2156, "end_offset": 2160, "start": 172.48, "end": 172.8, "phoneme": "\u00f0\u0259|_"}, {"word": "door", "start_offset": 2160, "end_offset": 2167, "start": 172.8, "end": 173.36, "phoneme": "d\u0254r"}, {"word": "and", "start_offset": 2181, "end_offset": 2185, "start": 174.48, "end": 174.8, "phoneme": "\u0259nd"}, {"word": "you", "start_offset": 2185, "end_offset": 2187, "start": 174.8, "end": 174.96, "phoneme": "ju\u02d0|_"}, {"word": "hear", "start_offset": 
2195, "end_offset": 2203, "start": 175.6, "end": 176.24, "phoneme": "hi\u02d0r"}, {"word": "in", "start_offset": 2209, "end_offset": 2213, "start": 176.72, "end": 177.04, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 2217, "end_offset": 2221, "start": 177.36, "end": 177.68, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 2221, "end_offset": 2230, "start": 177.68, "end": 178.4, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 2232, "end_offset": 2236, "start": 178.56, "end": 178.88, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 2248, "end_offset": 2251, "start": 179.84, "end": 180.08, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2255, "end_offset": 2263, "start": 180.4, "end": 181.04, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2263, "end_offset": 2266, "start": 181.04, "end": 181.28, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 2278, "end_offset": 2282, "start": 182.24, "end": 182.56, "phoneme": "go\u028a|_"}, {"word": "on.", "start_offset": 2286, "end_offset": 2289, "start": 182.88, "end": 183.12, "phoneme": "\u0251\u02d0n."}, {"word": "You", "start_offset": 2557, "end_offset": 2559, "start": 204.56, "end": 204.72, "phoneme": "ju\u02d0|_"}, {"word": "hear", "start_offset": 2587, "end_offset": 2594, "start": 206.96, "end": 207.52, "phoneme": "hi\u02d0r"}, {"word": "there's", "start_offset": 2610, "end_offset": 2620, "start": 208.8, "end": 209.6, "phoneme": "\u00f0\u025brz"}, {"word": "nothing", "start_offset": 2620, "end_offset": 2632, "start": 209.6, "end": 210.56, "phoneme": "n\u028c\u03b8\u026a\u014b|_"}, {"word": "I", "start_offset": 2640, "end_offset": 2644, "start": 211.20000000000002, "end": 211.52, "phoneme": "a\u026a|_"}, {"word": "fear,", "start_offset": 2644, "end_offset": 2651, "start": 211.52, "end": 212.08, "phoneme": "f\u026ar,"}, {"word": "and", "start_offset": 2666, "end_offset": 2669, "start": 213.28, "end": 213.52, "phoneme": "\u0259nd"}, {"word": "I", "start_offset": 2673, "end_offset": 2677, "start": 213.84, "end": 214.16, "phoneme": "a\u026a|_"}, {"word": "know", "start_offset": 2677, "end_offset": 2681, "start": 214.16, "end": 214.48000000000002, "phoneme": "no\u028a|_"}, {"word": "that", "start_offset": 2693, "end_offset": 2697, "start": 215.44, "end": 215.76, "phoneme": "\u00f0\u00e6t|_"}, {"word": "my", "start_offset": 2701, "end_offset": 2705, "start": 216.08, "end": 216.4, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2705, "end_offset": 2713, "start": 216.4, "end": 217.04, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2717, "end_offset": 2721, "start": 217.36, "end": 217.68, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 2733, "end_offset": 2736, "start": 218.64000000000001, "end": 218.88, "phoneme": "go\u028a|_"}, {"word": "forever", "start_offset": 2852, "end_offset": 2863, "start": 228.16, "end": 229.04, "phoneme": "f\u025d\u025bv\u025d"}, {"word": "this", "start_offset": 2881, "end_offset": 2883, "start": 230.48000000000002, "end": 230.64000000000001, "phoneme": "\u00f0\u026as|_"}, {"word": "way.", "start_offset": 2888, "end_offset": 2892, "start": 231.04, "end": 231.36, "phoneme": "we\u026a."}, {"word": "You", "start_offset": 2908, "end_offset": 2911, "start": 232.64000000000001, "end": 232.88, "phoneme": "ju\u02d0|_"}, {"word": "are", "start_offset": 2914, "end_offset": 2918, "start": 233.12, "end": 233.44, "phoneme": "\u0251\u02d0r"}, {"word": "safe", "start_offset": 2928, "end_offset": 2935, "start": 234.24, "end": 
234.8, "phoneme": "se\u026af|_"}, {"word": "in", "start_offset": 2938, "end_offset": 2942, "start": 235.04, "end": 235.36, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 2942, "end_offset": 2946, "start": 235.36, "end": 235.68, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 2950, "end_offset": 2957, "start": 236.0, "end": 236.56, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 2959, "end_offset": 2963, "start": 236.72, "end": 237.04, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 2975, "end_offset": 2978, "start": 238.0, "end": 238.24, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2982, "end_offset": 2990, "start": 238.56, "end": 239.20000000000002, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2990, "end_offset": 2994, "start": 239.20000000000002, "end": 239.52, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 3002, "end_offset": 3005, "start": 240.16, "end": 240.4, "phoneme": "go\u028a|_"}, {"word": "on", "start_offset": 3009, "end_offset": 3012, "start": 240.72, "end": 240.96, "phoneme": "\u0251\u02d0n|_"}, {"word": "there.", "start_offset": 3028, "end_offset": 3032, "start": 242.24, "end": 242.56, "phoneme": "\u00f0\u025br."}]
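The added gt0.json is a flat list of word-level alignment records (word, start_offset/end_offset, start/end in seconds, and an IPA phoneme string). The new model.py wraps such a bare list under a "word" key before handing it to the dataset processor; a minimal sketch of reading it the same way:

```python
# Minimal sketch: load gt0.json the way Jamify.predict() does.
# Each record has "word", "start_offset", "end_offset", "start", "end", "phoneme".
import json

with open("gt0.json", "r") as f:
    lrc_data = json.load(f)

if "word" not in lrc_data:        # lrc_data is a bare list here, so it gets wrapped
    lrc_data = {"word": lrc_data}

first = lrc_data["word"][0]
print(first["word"], first["start"], first["end"], first["phoneme"])
```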
model.py CHANGED
@@ -1,511 +1,180 @@
- from transformers import T5EncoderModel,T5TokenizerFast
- import torch
- from diffusers import FluxTransformer2DModel
- from torch import nn

- from typing import List
- from diffusers import FlowMatchEulerDiscreteScheduler
- from diffusers.training_utils import compute_density_for_timestep_sampling
- import copy
- import torch.nn.functional as F
  import numpy as np
- from tqdm import tqdm
-
- from typing import Optional,Union,List
- from datasets import load_dataset, Audio
- from math import pi
- import inspect
- import yaml
-
-
- class StableAudioPositionalEmbedding(nn.Module):
-     """Used for continuous time
-
-     Adapted from stable audio open.
-
-     """
-
-     def __init__(self, dim: int):
-         super().__init__()
-         assert (dim % 2) == 0
-         half_dim = dim // 2
-         self.weights = nn.Parameter(torch.randn(half_dim))
-
-     def forward(self, times: torch.Tensor) -> torch.Tensor:
-         times = times[..., None]
-         freqs = times * self.weights[None] * 2 * pi
-         fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
-         fouriered = torch.cat((times, fouriered), dim=-1)
-         return fouriered
-
- class DurationEmbedder(nn.Module):
-     """
-     A simple linear projection model to map numbers to a latent space.
-
-     Code is adapted from
-     https://github.com/Stability-AI/stable-audio-tools
-
-     Args:
-         number_embedding_dim (`int`):
-             Dimensionality of the number embeddings.
-         min_value (`int`):
-             The minimum value of the seconds number conditioning modules.
-         max_value (`int`):
-             The maximum value of the seconds number conditioning modules
-         internal_dim (`int`):
-             Dimensionality of the intermediate number hidden states.
-     """
-
-     def __init__(
-         self,
-         number_embedding_dim,
-         min_value,
-         max_value,
-         internal_dim: Optional[int] = 256,
-     ):
-         super().__init__()
-         self.time_positional_embedding = nn.Sequential(
-             StableAudioPositionalEmbedding(internal_dim),
-             nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
-         )
-
-         self.number_embedding_dim = number_embedding_dim
-         self.min_value = min_value
-         self.max_value = max_value
-         self.dtype = torch.float32
-
-     def forward(
-         self,
-         floats: torch.Tensor,
-     ):
-         floats = floats.clamp(self.min_value, self.max_value)
-
-         normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)
-
-         # Cast floats to same type as embedder
-         embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
-         normalized_floats = normalized_floats.to(embedder_dtype)
-
-         embedding = self.time_positional_embedding(normalized_floats)
-         float_embeds = embedding.view(-1, 1, self.number_embedding_dim)
-
-         return float_embeds
-
-
- def retrieve_timesteps(
-     scheduler,
-     num_inference_steps: Optional[int] = None,
-     device: Optional[Union[str, torch.device]] = None,
-     timesteps: Optional[List[int]] = None,
-     sigmas: Optional[List[float]] = None,
-     **kwargs,
- ):
-
-     if timesteps is not None and sigmas is not None:
-         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
-     if timesteps is not None:
-         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
-         if not accepts_timesteps:
-             raise ValueError(
-                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
-                 f" timestep schedules. Please check whether you are using the correct scheduler."
-             )
-         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
-         timesteps = scheduler.timesteps
-         num_inference_steps = len(timesteps)
-     elif sigmas is not None:
-         accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
-         if not accept_sigmas:
-             raise ValueError(
-                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
-                 f" sigmas schedules. Please check whether you are using the correct scheduler."
-             )
-         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
-         timesteps = scheduler.timesteps
-         num_inference_steps = len(timesteps)
-     else:
-         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-         timesteps = scheduler.timesteps
-     return timesteps, num_inference_steps
-
-
- class TangoFlux(nn.Module):
-
-     def __init__(self,config,initialize_reference_model=False):
-
-         super().__init__()
-
-         self.num_layers = config.get('num_layers', 6)
-         self.num_single_layers = config.get('num_single_layers', 18)
-         self.in_channels = config.get('in_channels', 64)
-         self.attention_head_dim = config.get('attention_head_dim', 128)
-         self.joint_attention_dim = config.get('joint_attention_dim', 1024)
-         self.num_attention_heads = config.get('num_attention_heads', 8)
-         self.audio_seq_len = config.get('audio_seq_len', 645)
-         self.max_duration = config.get('max_duration', 30)
-         self.uncondition = config.get('uncondition', False)
-         self.text_encoder_name = config.get('text_encoder_name', "google/flan-t5-large")
-
-         self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
-         self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
-         self.max_text_seq_len = 64
-         self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
-         self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
-         self.text_embedding_dim = self.text_encoder.config.d_model
-
-         self.fc = nn.Sequential(nn.Linear(self.text_embedding_dim,self.joint_attention_dim),nn.ReLU())
-         self.duration_emebdder = DurationEmbedder(self.text_embedding_dim,min_value=0,max_value=self.max_duration)
-
-         self.transformer = FluxTransformer2DModel(
-             in_channels=self.in_channels,
-             num_layers=self.num_layers,
-             num_single_layers=self.num_single_layers,
-             attention_head_dim=self.attention_head_dim,
-             num_attention_heads=self.num_attention_heads,
-             joint_attention_dim=self.joint_attention_dim,
-             pooled_projection_dim=self.text_embedding_dim,
-             guidance_embeds=False)
-
-         self.beta_dpo = 2000 ## this is used for dpo training
-
-
-     def get_sigmas(self,timesteps, n_dim=3, dtype=torch.float32):
-         device = self.text_encoder.device
-         sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
-
-         schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
-         timesteps = timesteps.to(device)
-         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-         sigma = sigmas[step_indices].flatten()
-         while len(sigma.shape) < n_dim:
-             sigma = sigma.unsqueeze(-1)
-         return sigma
-
-
-     def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
-         device = self.text_encoder.device
-         batch = self.tokenizer(
-             prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
-         )
-         input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
-         with torch.no_grad():
-             prompt_embeds = self.text_encoder(
-                 input_ids=input_ids, attention_mask=attention_mask
-             )[0]
-
-         prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
-         attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
-         # get unconditional embeddings for classifier free guidance
-         uncond_tokens = [""]
-
-         max_length = prompt_embeds.shape[1]
-         uncond_batch = self.tokenizer(
-             uncond_tokens, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt",
-         )
-         uncond_input_ids = uncond_batch.input_ids.to(device)
-         uncond_attention_mask = uncond_batch.attention_mask.to(device)
-
-         with torch.no_grad():
-             negative_prompt_embeds = self.text_encoder(
-                 input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
-             )[0]
-
-         negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
-         uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
-         # For classifier free guidance, we need to do two forward passes.
-         # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
-
-         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-         prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
-         boolean_prompt_mask = (prompt_mask == 1).to(device)
-
-         return prompt_embeds, boolean_prompt_mask
-
-     @torch.no_grad()
-     def encode_text(self, prompt):
-         device = self.text_encoder.device
-         batch = self.tokenizer(
-             prompt, max_length=self.max_text_seq_len, padding=True, truncation=True, return_tensors="pt")
-         input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
-         encoder_hidden_states = self.text_encoder(
-             input_ids=input_ids, attention_mask=attention_mask)[0]
-
-         boolean_encoder_mask = (attention_mask == 1).to(device)
-
-         return encoder_hidden_states, boolean_encoder_mask
-
-     def encode_duration(self,duration):
-         return self.duration_emebdder(duration)
-
-
-     @torch.no_grad()
-     def inference_flow(self, prompt,
-                        num_inference_steps=50,
-                        timesteps=None,
-                        guidance_scale=3,
-                        duration=10,
-                        disable_progress=False,
-                        num_samples_per_prompt=1):
-
-         '''Only tested for single inference. Haven't test for batch inference'''
-
-         bsz = num_samples_per_prompt
-         device = self.transformer.device
-         scheduler = self.noise_scheduler
-
-         if not isinstance(prompt,list):
-             prompt = [prompt]
-         if not isinstance(duration,torch.Tensor):
-             duration = torch.tensor([duration],device=device)
-         classifier_free_guidance = guidance_scale > 1.0
-         duration_hidden_states = self.encode_duration(duration)
-         if classifier_free_guidance:
-             bsz = 2 * num_samples_per_prompt
-
-             encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt=num_samples_per_prompt)
-             duration_hidden_states = duration_hidden_states.repeat(bsz,1,1)
-
          else:
-             encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt,num_samples_per_prompt=num_samples_per_prompt)
-
-         mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
-         masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
-
-         pooled = torch.nanmean(masked_data, dim=1)
-         pooled_projection = self.fc(pooled)
-
-         encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
-
-         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-         timesteps, num_inference_steps = retrieve_timesteps(
-             scheduler,
-             num_inference_steps,
-             device,
-             timesteps,
-             sigmas
-         )
-
-         latents = torch.randn(num_samples_per_prompt,self.audio_seq_len,64)
-         weight_dtype = latents.dtype
-
-         progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)
-
-         txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
-         audio_ids = torch.arange(self.audio_seq_len).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
-
-         timesteps = timesteps.to(device)
-         latents = latents.to(device)
-         encoder_hidden_states = encoder_hidden_states.to(device)
-
-         for i, t in enumerate(timesteps):
-
-             latents_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
-
-             noise_pred = self.transformer(
-                 hidden_states=latents_input,
-                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                 timestep=torch.tensor([t/1000],device=device),
-                 guidance = None,
-                 pooled_projections=pooled_projection,
-                 encoder_hidden_states=encoder_hidden_states,
-                 txt_ids=txt_ids,
-                 img_ids=audio_ids,
-                 return_dict=False,
-             )[0]

-             if classifier_free_guidance:
-                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-             latents = scheduler.step(noise_pred, t, latents).prev_sample
-
-         return latents
-
-     def forward(self,
-                 latents,
-                 prompt,
-                 duration=torch.tensor([10]),
-                 sft=True
-                 ):

-         device = latents.device
-         audio_seq_length = self.audio_seq_len
-         bsz = latents.shape[0]
-
-         encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
-         duration_hidden_states = self.encode_duration(duration)

-         mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
-         masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
-         pooled = torch.nanmean(masked_data, dim=1)
-         pooled_projection = self.fc(pooled)

-         ## Add duration hidden states to encoder hidden states
-         encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
-
-         txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
-         audio_ids = torch.arange(audio_seq_length).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)

-         if sft:
-
-             if self.uncondition:
-                 mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
-                 if len(mask_indices) > 0:
-                     encoder_hidden_states[mask_indices] = 0
-
-             noise = torch.randn_like(latents)
-
-             u = compute_density_for_timestep_sampling(
-                 weighting_scheme='logit_normal',
-                 batch_size=bsz,
-                 logit_mean=0,
-                 logit_std=1,
-                 mode_scale=None,
-             )
-
-             indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
-             timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
-             sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
-
-             noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
-
-             model_pred = self.transformer(
-                 hidden_states=noisy_model_input,
-                 encoder_hidden_states=encoder_hidden_states,
-                 pooled_projections=pooled_projection,
-                 img_ids=audio_ids,
-                 txt_ids=txt_ids,
-                 guidance=None,
-                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                 timestep=timesteps/1000,
-                 return_dict=False)[0]
-
-             target = noise - latents
-             loss = torch.mean(
-                 ( (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
-                 1,
-             )
-             loss = loss.mean()
-             raw_model_loss, raw_ref_loss,implicit_acc,epsilon_diff = 0,0,0,0 ## default this to 0 if doing sft
-
-         else:
-             encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
-             pooled_projection = pooled_projection.repeat(2,1)
-             noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1) ## Have to sample same noise for preferred and rejected
-             u = compute_density_for_timestep_sampling(
-                 weighting_scheme='logit_normal',
-                 batch_size=bsz//2,
-                 logit_mean=0,
-                 logit_std=1,
-                 mode_scale=None,
-             )
-
-             indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
-             timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
-             timesteps = timesteps.repeat(2)
-             sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
-
-             noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
-
-             model_pred = self.transformer(
-                 hidden_states=noisy_model_input,
-                 encoder_hidden_states=encoder_hidden_states,
-                 pooled_projections=pooled_projection,
-                 img_ids=audio_ids,
-                 txt_ids=txt_ids,
-                 guidance=None,
-                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                 timestep=timesteps/1000,
-                 return_dict=False)[0]
-             target = noise - latents
-
-             model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-             model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
-             model_losses_w, model_losses_l = model_losses.chunk(2)
-             model_diff = model_losses_w - model_losses_l
-             raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
-
-             with torch.no_grad():
-                 ref_preds = self.ref_transformer(
-                     hidden_states=noisy_model_input,
-                     encoder_hidden_states=encoder_hidden_states,
-                     pooled_projections=pooled_projection,
-                     img_ids=audio_ids,
-                     txt_ids=txt_ids,
-                     guidance=None,
-                     timestep=timesteps/1000,
-                     return_dict=False)[0]
-
-                 ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
-                 ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
-
-                 ref_losses_w, ref_losses_l = ref_loss.chunk(2)
-                 ref_diff = ref_losses_w - ref_losses_l
-                 raw_ref_loss = ref_loss.mean()
-
-             epsilon_diff = torch.max(torch.zeros_like(model_losses_w),
-                                      ref_losses_w-model_losses_w).mean()
-
-             scale_term = -0.5 * self.beta_dpo
-             inside_term = scale_term * (model_diff - ref_diff)
-             implicit_acc = (scale_term * (model_diff - ref_diff) > 0).sum().float() / inside_term.size(0)
-             loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()

-         return loss, raw_model_loss, raw_ref_loss, implicit_acc,epsilon_diff
-
-

+ import torch
+ import torchaudio
+ from omegaconf import OmegaConf
+ from huggingface_hub import snapshot_download

  import numpy as np
+ import json
+ import os
+ from safetensors.torch import load_file
+
+ # Imports from the jamify library
+ from jam.model.cfm import CFM
+ from jam.model.dit import DiT
+ from jam.model.vae import StableAudioOpenVAE
+ from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
+ from muq import MuQMuLan
+
+ # Helper functions adapted from jamify/src/jam/infer.py
+ def get_negative_style_prompt(device, file_path):
+     vocal_style = np.load(file_path)
+     vocal_style = torch.from_numpy(vocal_style).to(device)
+     return vocal_style.half()
+
+ def normalize_audio(audio):
+     audio = audio - audio.mean(-1, keepdim=True)
+     audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
+     return audio
+
+ class Jamify:
+     def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
+         self.device = torch.device(device)
+
+         # --- FIX: Point to the local jamify repository for config and public files ---
+         #jamify_repo_path = "/Users/cy/Desktop/JAM/jamify"
+
+         print("Downloading main model checkpoint...")
+         model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
+         self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
+
+         # Use local config and data files
+         config_path = os.path.join(model_repo_path, "jam_infer.yaml")
+         self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
+         tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
+         silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")
+         print("Loading configuration...")
+         self.config = OmegaConf.load(config_path)
+         self.config.data.train_dataset.silence_latent_path = silence_latent_path
+
+         # --- FIX: Override the relative paths in the config with absolute paths ---
+         self.config.data.train_dataset.tokenizer_path = tokenizer_path
+         self.config.evaluation.dataset.tokenizer_path = tokenizer_path
+         self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path
+
+         print("Loading VAE model...")
+         self.vae = StableAudioOpenVAE().to(self.device).eval()
+
+         print("Loading CFM model...")
+         self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path)
+
+         print("Loading MuQ style model...")
+         self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()
+
+         print("Setting up dataset processor...")
+         dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
+         enhance_webdataset_config(dataset_cfg)
+         dataset_cfg.multiple_styles = False
+         self.dataset_processor = DiffusionWebDataset(**dataset_cfg)
+
+         print("Jamify model loaded successfully.")
+
+     def _load_cfm_model(self, model_config, checkpoint_path):
+         dit_config = model_config["dit"].copy()
+         if "text_num_embeds" not in dit_config:
+             dit_config["text_num_embeds"] = 256
+
+         model = CFM(
+             transformer=DiT(**dit_config),
+             **model_config["cfm"]
+         ).to(self.device)
+
+         state_dict = load_file(checkpoint_path)
+         model.load_state_dict(state_dict, strict=False)
+         return model.eval()
+
+     def _generate_style_embedding_from_audio(self, audio_path):
+         waveform, sample_rate = torchaudio.load(audio_path)
+         if sample_rate != 24000:
+             resampler = torchaudio.transforms.Resample(sample_rate, 24000)
+             waveform = resampler(waveform)
+         if waveform.shape[0] > 1:
+             waveform = waveform.mean(dim=0, keepdim=True)
+
+         waveform = waveform.squeeze(0).to(self.device)
+
+         with torch.inference_mode():
+             style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * 30])
+         return style_embedding[0]
+
+     def _generate_style_embedding_from_prompt(self, prompt):
+         with torch.inference_mode():
+             style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
+         return style_embedding
+
+     def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration_sec=30, steps=50):
+         print("Starting prediction...")
+
+         if reference_audio_path:
+             print(f"Generating style from audio: {reference_audio_path}")
+             style_embedding = self._generate_style_embedding_from_audio(reference_audio_path)
+         elif style_prompt:
+             print(f"Generating style from prompt: '{style_prompt}'")
+             style_embedding = self._generate_style_embedding_from_prompt(style_prompt)
          else:
+             print("No style provided, using zero embedding.")
+             style_embedding = torch.zeros(512, device=self.device)

+         print(f"Loading lyrics from: {lyrics_json_path}")
+         with open(lyrics_json_path, 'r') as f:
+             lrc_data = json.load(f)
+         if 'word' not in lrc_data:
+             lrc_data = {'word': lrc_data}

+         frame_rate = 21.5
+         num_frames = int(duration_sec * frame_rate)
+         fake_latent = torch.randn(128, num_frames)

+         sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)

+         print("Processing sample...")
+         processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
+         if processed_sample is None:
+             raise ValueError("Failed to process the provided lyrics and style.")

+         batch = self.dataset_processor.custom_collate_fn([processed_sample])

+         for key, value in batch.items():
+             if isinstance(value, torch.Tensor):
+                 batch[key] = value.to(self.device)

+         print("Generating audio latent...")
+         with torch.inference_mode():
+             batch_size = 1
+             text = batch["lrc"]
+             style_prompt_tensor = batch["prompt"]
+             start_time = batch["start_time"]
+             duration_abs = batch["duration_abs"]
+             duration_rel = batch["duration_rel"]

+             cond = torch.zeros(batch_size, self.cfm_model.max_frames, 64).to(self.device)
+             pred_frames = [(0, self.cfm_model.max_frames)]

+             negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path)
+             negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

+             sample_kwargs = self.config.evaluation.sample_kwargs
+             sample_kwargs.steps = steps
+             latents, _ = self.cfm_model.sample(
+                 cond=cond, text=text, style_prompt=style_prompt_tensor,
+                 duration_abs=duration_abs, duration_rel=duration_rel,
+                 negative_style_prompt=negative_style_prompt, start_time=start_time,
+                 latent_pred_segments=pred_frames, **sample_kwargs)

+             latent = latents[0][0]

+         print("Decoding latent to audio...")
+         latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
+         pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()

+         pred_audio = normalize_audio(pred_audio)

+         sample_rate = 44100
+         trim_samples = int(duration_sec * sample_rate)
+         if pred_audio.shape[1] > trim_samples:
+             pred_audio = pred_audio[:, :trim_samples]
+
+         output_path = "generated_song.mp3"
+         print(f"Saving audio to {output_path}")
+         torchaudio.save(output_path, pred_audio, sample_rate, format="mp3")

+         return output_path

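For reference, a minimal sketch of driving the Jamify class above directly (outside Gradio), mirroring the first gr.Examples row in the new app.py; the lyrics path assumes the Space's bundled example files are present:

```python
# Minimal usage sketch for the Jamify class defined above (text-prompt style path).
from model import Jamify

jam = Jamify()  # downloads declare-lab/jam-0.5 and loads the VAE, CFM and MuQ models
song_path = jam.predict(
    reference_audio_path=None,
    lyrics_json_path="jamify/inputs/Jade Bird - Avalanche.json",  # word-aligned lyrics, as in gt0.json
    style_prompt="A sad, slow, acoustic country song",
    duration_sec=30,
    steps=50,
)
print(song_path)  # "generated_song.mp3"
```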
requirements.txt CHANGED
@@ -1,11 +1,35 @@
- torch==2.4.0
- torchaudio===2.4.0
- torchlibrosa==0.1.0
- torchvision==0.19.0
- transformers==4.44.0
- diffusers==0.32.0
- accelerate==0.34.2
- datasets==2.21.0
  librosa
- tqdm
- wavio==0.0.7

+ torch
+ torchaudio
+ diffusers
+ accelerate
+ safetensors
+ wandb
+ gpustat
+ soundfile
+ muq
+ pyloudnorm
+ mutagen
+ torchdiffeq
+ x_transformers
+ ema_pytorch
  librosa
+ jiwer
+ demucs
+ audiobox-aesthetics
+
+ # WebDataset
+ webdataset
+ webdatasetng
+ wids
+ omegaconf
+
+ # DeepPhonemizer
+ unidecode
+ inflect
+
+ # duration prediction
+ openai
+ pyphen
+ syllables
+ git+https://github.com/declare-lab/jamify.git
+ git+https://github.com/xhhhhang/DeepPhonemizer.git