Spaces: Running on Zero
CY committed · Commit 7d35d1e
1 Parent(s): a073f0a
Added jam space

Files changed:
- TangoFlux.py +0 -58
- app.py +64 -151
- gt0.json +1 -0
- model.py +162 -493
- requirements.txt +34 -10
TangoFlux.py
DELETED
@@ -1,58 +0,0 @@
-from diffusers import AutoencoderOobleck
-import torch
-from transformers import T5EncoderModel,T5TokenizerFast
-from diffusers import FluxTransformer2DModel
-from torch import nn
-from typing import List
-from diffusers import FlowMatchEulerDiscreteScheduler
-from diffusers.training_utils import compute_density_for_timestep_sampling
-import copy
-import torch.nn.functional as F
-import numpy as np
-from model import TangoFlux
-from huggingface_hub import snapshot_download
-from tqdm import tqdm
-from typing import Optional,Union,List
-from datasets import load_dataset, Audio
-from math import pi
-import json
-import inspect
-import yaml
-from safetensors.torch import load_file
-import os
-print(os.environ['HOME'])
-
-class TangoFluxInference:
-
-    def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
-
-        self.vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae',token=os.environ['HF_TOKEN'])
-
-        paths = snapshot_download(repo_id=name,token=os.environ['HF_TOKEN'])
-        weights = load_file("{}/model_1.safetensors".format(paths))
-
-        with open('{}/config.json'.format(paths),'r') as f:
-            config = json.load(f)
-        self.model = TangoFlux(config)
-        self.model.load_state_dict(weights,strict=False)
-        # _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
-        self.vae.to(device)
-        self.model.to(device)
-
-    def generate(self,prompt,steps=25,duration=10,guidance_scale=4.5):
-
-        with torch.no_grad():
-            latents = self.model.inference_flow(prompt,
-                                                duration=duration,
-                                                num_inference_steps=steps,
-                                                guidance_scale=guidance_scale)
-
-        wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
-        return wave
app.py
CHANGED
@@ -1,155 +1,68 @@
-import spaces
 import gradio as gr
-
-@spaces.GPU(duration=15)
-def gradio_generate(prompt, steps, guidance,duration=10):
-
-    output = tangoflux.generate(prompt,steps=steps,guidance_scale=guidance,duration=duration)
-    #output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
-
-    #wavio.write(output_filename, output_wave, rate=44100, sampwidth=2)
-    filename = 'temp.wav'
-    #print(f"Saving audio to file: {unique_filename}")
-
-    # Save to file
-    torchaudio.save(filename, output, 44100)
-
-    # Return the path to the generated audio file
-    return filename
-
-    #if (output_format == "mp3"):
-    #  AudioSegment.from_wav("temp.wav").export("temp.mp3", format = "mp3")
-    #  output_filename = "temp.mp3"
-
-    #return output_filename
-
-description_text = """
-Generate high quality and faithful audio in just a few seconds using <b>TangoFlux</b> by providing a text prompt. <b>TangoFlux</b> was trained from scratch and underwent alignment to follow human instructions using a new method called <b>Claped-Ranked Preference Optimization (CRPO)</b>.
-<div style="display: flex; gap: 10px; align-items: center;">
-    <a href="https://arxiv.org/abs/2412.21037">
-        <img src="https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf" alt="arXiv">
-    </a>
-    <a href="https://huggingface.co/declare-lab/TangoFlux">
-        <img src="https://img.shields.io/badge/TangoFlux-Huggingface-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-    </a>
-    <a href="https://tangoflux.github.io/">
-        <img src="https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat" alt="Static Badge">
-    </a>
-    <a href="https://huggingface.co/spaces/declare-lab/TangoFlux">
-        <img src="https://img.shields.io/badge/TangoFlux-Huggingface_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-    </a>
-    <a href="https://huggingface.co/datasets/declare-lab/CRPO">
-        <img src="https://img.shields.io/badge/TangoFlux_Dataset-Huggingface-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-    </a>
-    <a href="https://github.com/declare-lab/TangoFlux">
-        <img src="https://img.shields.io/badge/Github-brown?logo=github&link=https%3A%2F%2Fgithub.com%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-    </a>
-</div>
-"""
-# Gradio input and output components
-input_text = gr.Textbox(lines=2, label="Prompt")
-#output_format = gr.Radio(label = "Output format", info = "The file you can dowload", choices = "wav"], value = "wav")
-output_audio = gr.Audio(label="Generated Audio", type="filepath")
-denoising_steps = gr.Slider(minimum=10, maximum=100, value=25, step=5, label="Steps", interactive=True)
-guidance_scale = gr.Slider(minimum=1, maximum=10, value=4.5, step=0.5, label="Guidance Scale", interactive=True)
-duration_scale = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Duration", interactive=True)

 # Gradio interface
-        ["Massive stadium crowd cheering as thunder crashes and lightning strikes"],
-        ["Heavy helicopter blades chopping through air with engine and wind noise"],
-        ["Dog barking excitedly and man shouting as race car engine roars past"],
-        ["Quiet speech and then and airplane flying away"],
-        ["A bicycle peddling on dirt and gravel followed by a man speaking then laughing"],
-        ["Ducks quack and water splashes with some animal screeching in the background"],
-        ["Describe the sound of the ocean"],
-        ["A woman and a baby are having a conversation"],
-        ["A man speaks followed by a popping noise and laughter"],
-        ["A cup is filled from a faucet"],
-        ["An audience cheering and clapping"],
-        ["Rolling thunder with lightning strikes"],
-        ["A dog barking and a cat mewing and a racing car passes by"],
-        ["Gentle water stream, birds chirping and sudden gun shot"],
-        ["A dog barking"],
-        ["A cat meowing"],
-        ["Wooden table tapping sound while water pouring"],
-        ["Applause from a crowd with distant clicking and a man speaking over a loudspeaker"],
-        ["two gunshots followed by birds flying away while chirping"],
-        ["Whistling with birds chirping"],
-        ["A person snoring"],
-        ["Motor vehicles are driving with loud engines and a person whistles"],
-        ["People cheering in a stadium while thunder and lightning strikes"],
-        ["A helicopter is in flight"],
-        ["A dog barking and a man talking and a racing car passes by"],
-    ],
-    cache_examples="lazy", # Turn on to cache.
-)
-
-gr_interface.queue(15).launch()

 import gradio as gr
+import os
+from model import Jamify
+
+# Initialize the Jamify model once
+print("Initializing Jamify model...")
+jamify_model = Jamify()
+print("Jamify model ready.")
+
+def generate_song(reference_audio, lyrics_file, style_prompt, duration):
+    # We need to save the uploaded files to temporary paths to pass to the model
+    ref_audio_path = reference_audio.name if reference_audio else None
+    lyrics_path = lyrics_file.name
+
+    # The model expects paths, so we write the prompt to a temp file if needed
+    # (This part of the model could be improved to accept the string directly)
+
+    output_path = jamify_model.predict(
+        reference_audio_path=ref_audio_path,
+        lyrics_json_path=lyrics_path,
+        style_prompt=style_prompt,
+        duration_sec=duration
+    )
+
+    return output_path
+
 # Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Jamify: Music Generation from Lyrics and Style")
+    gr.Markdown("Provide your lyrics, a style reference (either an audio file or a text prompt), and a desired duration to generate a song.")
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("### Inputs")
+            lyrics_file = gr.File(label="Lyrics File (.json)", type="filepath")
+            duration_slider = gr.Slider(minimum=5, maximum=180, value=30, step=1, label="Duration (seconds)")
+
+            with gr.Tab("Style from Audio"):
+                reference_audio = gr.File(label="Reference Audio (.mp3, .wav)", type="filepath")
+            with gr.Tab("Style from Text"):
+                style_prompt = gr.Textbox(label="Style Prompt", lines=3, placeholder="e.g., A high-energy electronic dance track with a strong bassline and euphoric synths.")
+
+            generate_button = gr.Button("Generate Song", variant="primary")
+
+        with gr.Column():
+            gr.Markdown("### Output")
+            output_audio = gr.Audio(label="Generated Song")
+
+    generate_button.click(
+        fn=generate_song,
+        inputs=[reference_audio, lyrics_file, style_prompt, duration_slider],
+        outputs=output_audio,
+        api_name="generate_song"
+    )
+
+    gr.Markdown("### Example Usage")
+    gr.Examples(
+        examples=[
+            [None, "jamify/inputs/Jade Bird - Avalanche.json", "A sad, slow, acoustic country song", 30],
+            ["jamify/inputs/Rizzle Kicks, Rachel Chinouriri - Follow Excitement!.mp3", "jamify/inputs/Rizzle Kicks, Rachel Chinouriri - Follow Excitement!.json", "", 45],
+        ],
+        inputs=[reference_audio, lyrics_file, style_prompt, duration_slider],
+        outputs=output_audio,
+        fn=generate_song,
+        cache_examples=True
+    )
+
+demo.queue().launch()
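Because the new click handler is registered with api_name="generate_song", the Space can also be driven programmatically. A minimal sketch (not part of this commit), assuming a recent gradio_client that provides handle_file; the Space ID and local file names below are placeholders:

# Sketch only: call the "generate_song" endpoint remotely.
from gradio_client import Client, handle_file

client = Client("CY/jam-space")  # placeholder Space ID; substitute the real one
audio_path = client.predict(
    None,                                   # reference_audio (or handle_file("style.mp3"))
    handle_file("gt0.json"),                # lyrics_file: word-timing JSON like gt0.json below
    "A sad, slow, acoustic country song",   # style_prompt
    30,                                     # duration in seconds
    api_name="/generate_song",
)
print(audio_path)  # local path to the generated song returned by the Space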
gt0.json
ADDED
@@ -0,0 +1 @@
[{"word": "Every", "start_offset": 259, "end_offset": 267, "start": 20.72, "end": 21.36, "phoneme": "\u025bv\u025di|_"}, {"word": "night", "start_offset": 267, "end_offset": 275, "start": 21.36, "end": 22.0, "phoneme": "na\u026at|_"}, {"word": "in", "start_offset": 279, "end_offset": 283, "start": 22.32, "end": 22.64, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 283, "end_offset": 287, "start": 22.64, "end": 22.96, "phoneme": "ma\u026a|_"}, {"word": "dreams,", "start_offset": 287, "end_offset": 301, "start": 22.96, "end": 24.080000000000002, "phoneme": "dri\u02d0mz,"}, {"word": "I", "start_offset": 309, "end_offset": 313, "start": 24.72, "end": 25.04, "phoneme": "a\u026a|_"}, {"word": "see", "start_offset": 317, "end_offset": 321, "start": 25.36, "end": 25.68, "phoneme": "si\u02d0|_"}, {"word": "you,", "start_offset": 321, "end_offset": 325, "start": 25.68, "end": 26.0, "phoneme": "ju\u02d0,"}, {"word": "I", "start_offset": 340, "end_offset": 344, "start": 27.2, "end": 27.52, "phoneme": "a\u026a|_"}, {"word": "feel", "start_offset": 348, "end_offset": 352, "start": 27.84, "end": 28.16, "phoneme": "fi\u02d0l|_"}, {"word": "you.", "start_offset": 358, "end_offset": 362, "start": 28.64, "end": 28.96, "phoneme": "ju\u02d0."}, {"word": "That", "start_offset": 377, "end_offset": 381, "start": 30.16, "end": 30.48, "phoneme": "\u00f0\u00e6t|_"}, {"word": "is", "start_offset": 385, "end_offset": 389, "start": 30.8, "end": 31.12, "phoneme": "\u026az"}, {"word": "how", "start_offset": 393, "end_offset": 397, "start": 31.44, "end": 31.76, "phoneme": "ha\u028a|_"}, {"word": "I", "start_offset": 401, "end_offset": 405, "start": 32.08, "end": 32.4, "phoneme": "a\u026a|_"}, {"word": "know", "start_offset": 405, "end_offset": 409, "start": 32.4, "end": 32.72, "phoneme": "no\u028a|_"}, {"word": "you", "start_offset": 413, "end_offset": 417, "start": 33.04, "end": 33.36, "phoneme": "ju\u02d0|_"}, {"word": "go", "start_offset": 428, "end_offset": 431, "start": 34.24, "end": 34.480000000000004, "phoneme": "go\u028a|_"}, {"word": "far", "start_offset": 495, "end_offset": 503, "start": 39.6, "end": 40.24, "phoneme": "f\u0251\u02d0r"}, {"word": "across", "start_offset": 507, "end_offset": 517, "start": 40.56, "end": 41.36, "phoneme": "\u0259kr\u0254s|_"}, {"word": "the", "start_offset": 519, "end_offset": 523, "start": 41.52, "end": 41.84, "phoneme": "\u00f0\u0259|_"}, {"word": "distance", "start_offset": 527, "end_offset": 538, "start": 42.160000000000004, "end": 43.04, "phoneme": "d\u026ast\u0259ns|_"}, {"word": "and", "start_offset": 552, "end_offset": 556, "start": 44.160000000000004, "end": 44.480000000000004, "phoneme": "\u0259nd"}, {"word": "spaces", "start_offset": 556, "end_offset": 572, "start": 44.480000000000004, "end": 45.76, "phoneme": "spe\u026as\u0259z"}, {"word": "between", "start_offset": 583, "end_offset": 587, "start": 46.64, "end": 46.96, "phoneme": "b\u026atwi\u02d0n|_"}, {"word": "us.", "start_offset": 602, "end_offset": 606, "start": 48.160000000000004, "end": 48.480000000000004, "phoneme": "\u028cs."}, {"word": "You", "start_offset": 621, "end_offset": 625, "start": 49.68, "end": 50.0, "phoneme": "ju\u02d0|_"}, {"word": "have", "start_offset": 629, "end_offset": 633, "start": 50.32, "end": 50.64, "phoneme": "h\u00e6v"}, {"word": "come", "start_offset": 633, "end_offset": 637, "start": 50.64, "end": 50.96, "phoneme": "k\u028cm|_"}, {"word": "to", "start_offset": 641, "end_offset": 645, "start": 51.28, "end": 51.6, "phoneme": "tu\u02d0|_"}, {"word": "show", "start_offset": 649, 
"end_offset": 653, "start": 51.92, "end": 52.24, "phoneme": "\u0283o\u028a|_"}, {"word": "you", "start_offset": 655, "end_offset": 659, "start": 52.4, "end": 52.72, "phoneme": "ju\u02d0|_"}, {"word": "go", "start_offset": 673, "end_offset": 676, "start": 53.84, "end": 54.08, "phoneme": "go\u028a|_"}, {"word": "near,", "start_offset": 738, "end_offset": 745, "start": 59.04, "end": 59.6, "phoneme": "n\u026ar,"}, {"word": "far,", "start_offset": 768, "end_offset": 776, "start": 61.44, "end": 62.08, "phoneme": "f\u0251\u02d0r,"}, {"word": "wherever", "start_offset": 794, "end_offset": 806, "start": 63.52, "end": 64.48, "phoneme": "w\u025br\u025bv\u025d"}, {"word": "you", "start_offset": 822, "end_offset": 826, "start": 65.76, "end": 66.08, "phoneme": "ju\u02d0|_"}, {"word": "are.", "start_offset": 826, "end_offset": 830, "start": 66.08, "end": 66.4, "phoneme": "\u0251\u02d0r."}, {"word": "I", "start_offset": 849, "end_offset": 852, "start": 67.92, "end": 68.16, "phoneme": "a\u026a|_"}, {"word": "believe", "start_offset": 856, "end_offset": 868, "start": 68.48, "end": 69.44, "phoneme": "b\u026ali\u02d0v"}, {"word": "that", "start_offset": 875, "end_offset": 878, "start": 70.0, "end": 70.24, "phoneme": "\u00f0\u00e6t|_"}, {"word": "the", "start_offset": 886, "end_offset": 890, "start": 70.88, "end": 71.2, "phoneme": "\u00f0\u0259|_"}, {"word": "heart", "start_offset": 890, "end_offset": 898, "start": 71.2, "end": 71.84, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "does", "start_offset": 898, "end_offset": 901, "start": 71.84, "end": 72.08, "phoneme": "d\u028cz"}, {"word": "go", "start_offset": 916, "end_offset": 920, "start": 73.28, "end": 73.60000000000001, "phoneme": "go\u028a|_"}, {"word": "on", "start_offset": 982, "end_offset": 985, "start": 78.56, "end": 78.8, "phoneme": "\u0251\u02d0n|_"}, {"word": "small.", "start_offset": 1009, "end_offset": 1017, "start": 80.72, "end": 81.36, "phoneme": "sm\u0254l."}, {"word": "You", "start_offset": 1037, "end_offset": 1041, "start": 82.96000000000001, "end": 83.28, "phoneme": "ju\u02d0|_"}, {"word": "open", "start_offset": 1045, "end_offset": 1049, "start": 83.60000000000001, "end": 83.92, "phoneme": "o\u028ap\u0259n|_"}, {"word": "the", "start_offset": 1065, "end_offset": 1069, "start": 85.2, "end": 85.52, "phoneme": "\u00f0\u0259|_"}, {"word": "door,", "start_offset": 1069, "end_offset": 1076, "start": 85.52, "end": 86.08, "phoneme": "d\u0254r,"}, {"word": "and", "start_offset": 1090, "end_offset": 1094, "start": 87.2, "end": 87.52, "phoneme": "\u0259nd"}, {"word": "you'll", "start_offset": 1094, "end_offset": 1100, "start": 87.52, "end": 88.0, "phoneme": "j\u028c\u028al|_"}, {"word": "hear", "start_offset": 1103, "end_offset": 1108, "start": 88.24, "end": 88.64, "phoneme": "hi\u02d0r"}, {"word": "in", "start_offset": 1119, "end_offset": 1122, "start": 89.52, "end": 89.76, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 1126, "end_offset": 1130, "start": 90.08, "end": 90.4, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 1130, "end_offset": 1138, "start": 90.4, "end": 91.04, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 1141, "end_offset": 1145, "start": 91.28, "end": 91.60000000000001, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 1157, "end_offset": 1161, "start": 92.56, "end": 92.88, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 1165, "end_offset": 1173, "start": 93.2, "end": 93.84, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 1173, "end_offset": 1177, "start": 
93.84, "end": 94.16, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 1185, "end_offset": 1189, "start": 94.8, "end": 95.12, "phoneme": "go\u028a|_"}, {"word": "and", "start_offset": 1211, "end_offset": 1215, "start": 96.88, "end": 97.2, "phoneme": "\u0259nd"}, {"word": "dawn.", "start_offset": 1223, "end_offset": 1233, "start": 97.84, "end": 98.64, "phoneme": "d\u0254n."}, {"word": "Love", "start_offset": 1345, "end_offset": 1353, "start": 107.60000000000001, "end": 108.24000000000001, "phoneme": "l\u028cv"}, {"word": "can", "start_offset": 1356, "end_offset": 1360, "start": 108.48, "end": 108.8, "phoneme": "k\u00e6n|_"}, {"word": "touch", "start_offset": 1360, "end_offset": 1366, "start": 108.8, "end": 109.28, "phoneme": "t\u028ct\u0283|_"}, {"word": "us", "start_offset": 1369, "end_offset": 1373, "start": 109.52, "end": 109.84, "phoneme": "\u028cs|_"}, {"word": "one", "start_offset": 1376, "end_offset": 1380, "start": 110.08, "end": 110.4, "phoneme": "w\u028cn|_"}, {"word": "time", "start_offset": 1384, "end_offset": 1388, "start": 110.72, "end": 111.04, "phoneme": "ta\u026am|_"}, {"word": "and", "start_offset": 1399, "end_offset": 1402, "start": 111.92, "end": 112.16, "phoneme": "\u0259nd"}, {"word": "last", "start_offset": 1406, "end_offset": 1410, "start": 112.48, "end": 112.8, "phoneme": "l\u00e6st|_"}, {"word": "for", "start_offset": 1416, "end_offset": 1420, "start": 113.28, "end": 113.60000000000001, "phoneme": "f\u0254r"}, {"word": "a", "start_offset": 1431, "end_offset": 1435, "start": 114.48, "end": 114.8, "phoneme": "\u0259|_"}, {"word": "lifetime", "start_offset": 1435, "end_offset": 1458, "start": 114.8, "end": 116.64, "phoneme": "la\u026afta\u026am|_"}, {"word": "and", "start_offset": 1471, "end_offset": 1475, "start": 117.68, "end": 118.0, "phoneme": "\u0259nd"}, {"word": "never", "start_offset": 1479, "end_offset": 1483, "start": 118.32000000000001, "end": 118.64, "phoneme": "n\u025bv\u025d"}, {"word": "let", "start_offset": 1487, "end_offset": 1491, "start": 118.96000000000001, "end": 119.28, "phoneme": "l\u025bt|_"}, {"word": "go", "start_offset": 1495, "end_offset": 1499, "start": 119.60000000000001, "end": 119.92, "phoneme": "go\u028a|_"}, {"word": "till", "start_offset": 1503, "end_offset": 1511, "start": 120.24000000000001, "end": 120.88, "phoneme": "t\u026al|_"}, {"word": "we're", "start_offset": 1521, "end_offset": 1528, "start": 121.68, "end": 122.24000000000001, "phoneme": "w\u025d\u02d0|_"}, {"word": "gone.", "start_offset": 1528, "end_offset": 1536, "start": 122.24000000000001, "end": 122.88, "phoneme": "g\u0254n."}, {"word": "Love", "start_offset": 1587, "end_offset": 1596, "start": 126.96000000000001, "end": 127.68, "phoneme": "l\u028cv"}, {"word": "was", "start_offset": 1599, "end_offset": 1603, "start": 127.92, "end": 128.24, "phoneme": "w\u0251\u02d0z"}, {"word": "when", "start_offset": 1607, "end_offset": 1611, "start": 128.56, "end": 128.88, "phoneme": "w\u025bn|_"}, {"word": "I", "start_offset": 1611, "end_offset": 1615, "start": 128.88, "end": 129.2, "phoneme": "a\u026a|_"}, {"word": "loved", "start_offset": 1615, "end_offset": 1626, "start": 129.2, "end": 130.08, "phoneme": "l\u028cvd"}, {"word": "you", "start_offset": 1626, "end_offset": 1630, "start": 130.08, "end": 130.4, "phoneme": "ju\u02d0|_"}, {"word": "one", "start_offset": 1641, "end_offset": 1644, "start": 131.28, "end": 131.52, "phoneme": "w\u028cn|_"}, {"word": "true", "start_offset": 1648, "end_offset": 1656, "start": 131.84, "end": 132.48, "phoneme": "tru\u02d0|_"}, {"word": 
"time.", "start_offset": 1656, "end_offset": 1660, "start": 132.48, "end": 132.8, "phoneme": "ta\u026am."}, {"word": "I", "start_offset": 1672, "end_offset": 1675, "start": 133.76, "end": 134.0, "phoneme": "a\u026a|_"}, {"word": "hold", "start_offset": 1679, "end_offset": 1687, "start": 134.32, "end": 134.96, "phoneme": "ho\u028ald"}, {"word": "to", "start_offset": 1691, "end_offset": 1693, "start": 135.28, "end": 135.44, "phoneme": "tu\u02d0|_"}, {"word": "in", "start_offset": 1712, "end_offset": 1716, "start": 136.96, "end": 137.28, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 1720, "end_offset": 1724, "start": 137.6, "end": 137.92000000000002, "phoneme": "ma\u026a|_"}, {"word": "life", "start_offset": 1724, "end_offset": 1728, "start": 137.92000000000002, "end": 138.24, "phoneme": "la\u026af|_"}, {"word": "will", "start_offset": 1731, "end_offset": 1733, "start": 138.48, "end": 138.64000000000001, "phoneme": "w\u026al|_"}, {"word": "always", "start_offset": 1743, "end_offset": 1747, "start": 139.44, "end": 139.76, "phoneme": "\u0254lwe\u026az"}, {"word": "go", "start_offset": 1763, "end_offset": 1767, "start": 141.04, "end": 141.36, "phoneme": "go\u028a|_"}, {"word": "near", "start_offset": 1830, "end_offset": 1836, "start": 146.4, "end": 146.88, "phoneme": "n\u026ar"}, {"word": "far", "start_offset": 1859, "end_offset": 1867, "start": 148.72, "end": 149.36, "phoneme": "f\u0251\u02d0r"}, {"word": "wherever", "start_offset": 1884, "end_offset": 1896, "start": 150.72, "end": 151.68, "phoneme": "w\u025br\u025bv\u025d"}, {"word": "you", "start_offset": 1914, "end_offset": 1918, "start": 153.12, "end": 153.44, "phoneme": "ju\u02d0|_"}, {"word": "are.", "start_offset": 1918, "end_offset": 1922, "start": 153.44, "end": 153.76, "phoneme": "\u0251\u02d0r."}, {"word": "I", "start_offset": 1940, "end_offset": 1943, "start": 155.20000000000002, "end": 155.44, "phoneme": "a\u026a|_"}, {"word": "believe", "start_offset": 1947, "end_offset": 1959, "start": 155.76, "end": 156.72, "phoneme": "b\u026ali\u02d0v"}, {"word": "that", "start_offset": 1966, "end_offset": 1970, "start": 157.28, "end": 157.6, "phoneme": "\u00f0\u00e6t|_"}, {"word": "the", "start_offset": 1974, "end_offset": 1977, "start": 157.92000000000002, "end": 158.16, "phoneme": "\u00f0\u0259|_"}, {"word": "heart", "start_offset": 1981, "end_offset": 1986, "start": 158.48, "end": 158.88, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "does", "start_offset": 1990, "end_offset": 1993, "start": 159.20000000000002, "end": 159.44, "phoneme": "d\u028cz"}, {"word": "go", "start_offset": 2008, "end_offset": 2011, "start": 160.64000000000001, "end": 160.88, "phoneme": "go\u028a|_"}, {"word": "small.", "start_offset": 2099, "end_offset": 2111, "start": 167.92000000000002, "end": 168.88, "phoneme": "sm\u0254l."}, {"word": "You", "start_offset": 2127, "end_offset": 2131, "start": 170.16, "end": 170.48, "phoneme": "ju\u02d0|_"}, {"word": "open", "start_offset": 2136, "end_offset": 2140, "start": 170.88, "end": 171.20000000000002, "phoneme": "o\u028ap\u0259n|_"}, {"word": "the", "start_offset": 2156, "end_offset": 2160, "start": 172.48, "end": 172.8, "phoneme": "\u00f0\u0259|_"}, {"word": "door", "start_offset": 2160, "end_offset": 2167, "start": 172.8, "end": 173.36, "phoneme": "d\u0254r"}, {"word": "and", "start_offset": 2181, "end_offset": 2185, "start": 174.48, "end": 174.8, "phoneme": "\u0259nd"}, {"word": "you", "start_offset": 2185, "end_offset": 2187, "start": 174.8, "end": 174.96, "phoneme": "ju\u02d0|_"}, {"word": "hear", "start_offset": 
2195, "end_offset": 2203, "start": 175.6, "end": 176.24, "phoneme": "hi\u02d0r"}, {"word": "in", "start_offset": 2209, "end_offset": 2213, "start": 176.72, "end": 177.04, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 2217, "end_offset": 2221, "start": 177.36, "end": 177.68, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 2221, "end_offset": 2230, "start": 177.68, "end": 178.4, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 2232, "end_offset": 2236, "start": 178.56, "end": 178.88, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 2248, "end_offset": 2251, "start": 179.84, "end": 180.08, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2255, "end_offset": 2263, "start": 180.4, "end": 181.04, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2263, "end_offset": 2266, "start": 181.04, "end": 181.28, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 2278, "end_offset": 2282, "start": 182.24, "end": 182.56, "phoneme": "go\u028a|_"}, {"word": "on.", "start_offset": 2286, "end_offset": 2289, "start": 182.88, "end": 183.12, "phoneme": "\u0251\u02d0n."}, {"word": "You", "start_offset": 2557, "end_offset": 2559, "start": 204.56, "end": 204.72, "phoneme": "ju\u02d0|_"}, {"word": "hear", "start_offset": 2587, "end_offset": 2594, "start": 206.96, "end": 207.52, "phoneme": "hi\u02d0r"}, {"word": "there's", "start_offset": 2610, "end_offset": 2620, "start": 208.8, "end": 209.6, "phoneme": "\u00f0\u025brz"}, {"word": "nothing", "start_offset": 2620, "end_offset": 2632, "start": 209.6, "end": 210.56, "phoneme": "n\u028c\u03b8\u026a\u014b|_"}, {"word": "I", "start_offset": 2640, "end_offset": 2644, "start": 211.20000000000002, "end": 211.52, "phoneme": "a\u026a|_"}, {"word": "fear,", "start_offset": 2644, "end_offset": 2651, "start": 211.52, "end": 212.08, "phoneme": "f\u026ar,"}, {"word": "and", "start_offset": 2666, "end_offset": 2669, "start": 213.28, "end": 213.52, "phoneme": "\u0259nd"}, {"word": "I", "start_offset": 2673, "end_offset": 2677, "start": 213.84, "end": 214.16, "phoneme": "a\u026a|_"}, {"word": "know", "start_offset": 2677, "end_offset": 2681, "start": 214.16, "end": 214.48000000000002, "phoneme": "no\u028a|_"}, {"word": "that", "start_offset": 2693, "end_offset": 2697, "start": 215.44, "end": 215.76, "phoneme": "\u00f0\u00e6t|_"}, {"word": "my", "start_offset": 2701, "end_offset": 2705, "start": 216.08, "end": 216.4, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2705, "end_offset": 2713, "start": 216.4, "end": 217.04, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2717, "end_offset": 2721, "start": 217.36, "end": 217.68, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 2733, "end_offset": 2736, "start": 218.64000000000001, "end": 218.88, "phoneme": "go\u028a|_"}, {"word": "forever", "start_offset": 2852, "end_offset": 2863, "start": 228.16, "end": 229.04, "phoneme": "f\u025d\u025bv\u025d"}, {"word": "this", "start_offset": 2881, "end_offset": 2883, "start": 230.48000000000002, "end": 230.64000000000001, "phoneme": "\u00f0\u026as|_"}, {"word": "way.", "start_offset": 2888, "end_offset": 2892, "start": 231.04, "end": 231.36, "phoneme": "we\u026a."}, {"word": "You", "start_offset": 2908, "end_offset": 2911, "start": 232.64000000000001, "end": 232.88, "phoneme": "ju\u02d0|_"}, {"word": "are", "start_offset": 2914, "end_offset": 2918, "start": 233.12, "end": 233.44, "phoneme": "\u0251\u02d0r"}, {"word": "safe", "start_offset": 2928, "end_offset": 2935, "start": 234.24, "end": 
234.8, "phoneme": "se\u026af|_"}, {"word": "in", "start_offset": 2938, "end_offset": 2942, "start": 235.04, "end": 235.36, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 2942, "end_offset": 2946, "start": 235.36, "end": 235.68, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 2950, "end_offset": 2957, "start": 236.0, "end": 236.56, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 2959, "end_offset": 2963, "start": 236.72, "end": 237.04, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 2975, "end_offset": 2978, "start": 238.0, "end": 238.24, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2982, "end_offset": 2990, "start": 238.56, "end": 239.20000000000002, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2990, "end_offset": 2994, "start": 239.20000000000002, "end": 239.52, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 3002, "end_offset": 3005, "start": 240.16, "end": 240.4, "phoneme": "go\u028a|_"}, {"word": "on", "start_offset": 3009, "end_offset": 3012, "start": 240.72, "end": 240.96, "phoneme": "\u0251\u02d0n|_"}, {"word": "there.", "start_offset": 3028, "end_offset": 3032, "start": 242.24, "end": 242.56, "phoneme": "\u00f0\u025br."}]
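The added file is a single JSON array of word-timing entries, each carrying the word, sample start_offset/end_offset, start/end times in seconds, and a phoneme string. A minimal sketch (not part of the commit) of how the new Jamify.predict in model.py below consumes such a file, wrapping a bare list under a "word" key before handing it to the dataset processor:

import json

# Sketch: load a word-timing lyrics file such as gt0.json.
with open("gt0.json", "r") as f:
    lrc_data = json.load(f)  # list of {"word", "start_offset", "end_offset", "start", "end", "phoneme"}

# Jamify.predict wraps a bare list under the "word" key before processing.
if "word" not in lrc_data:
    lrc_data = {"word": lrc_data}

first = lrc_data["word"][0]
print(first["word"], first["start"], first["end"])  # e.g. "Every" 20.72 21.36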
model.py
CHANGED
@@ -1,511 +1,180 @@
-from transformers import T5EncoderModel,T5TokenizerFast
-import torch
-from diffusers import FluxTransformer2DModel
-from torch import nn
-import torch.nn.functional as F
 import numpy as np
-
-        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
-        if not accept_sigmas:
-            raise ValueError(
-                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
-                f" sigmas schedules. Please check whether you are using the correct scheduler."
-            )
-        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-        num_inference_steps = len(timesteps)
-    else:
-        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-    return timesteps, num_inference_steps
-
-class TangoFlux(nn.Module):
-
-    def __init__(self,config,initialize_reference_model=False):
-
-        super().__init__()
-
-        self.num_layers = config.get('num_layers', 6)
-        self.num_single_layers = config.get('num_single_layers', 18)
-        self.in_channels = config.get('in_channels', 64)
-        self.attention_head_dim = config.get('attention_head_dim', 128)
-        self.joint_attention_dim = config.get('joint_attention_dim', 1024)
-        self.num_attention_heads = config.get('num_attention_heads', 8)
-        self.audio_seq_len = config.get('audio_seq_len', 645)
-        self.max_duration = config.get('max_duration', 30)
-        self.uncondition = config.get('uncondition', False)
-        self.text_encoder_name = config.get('text_encoder_name', "google/flan-t5-large")
-
-        self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
-        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
-        self.max_text_seq_len = 64
-        self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
-        self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
-        self.text_embedding_dim = self.text_encoder.config.d_model
-
-        self.fc = nn.Sequential(nn.Linear(self.text_embedding_dim,self.joint_attention_dim),nn.ReLU())
-        self.duration_emebdder = DurationEmbedder(self.text_embedding_dim,min_value=0,max_value=self.max_duration)
-
-        self.transformer = FluxTransformer2DModel(
-            in_channels=self.in_channels,
-            num_layers=self.num_layers,
-            num_single_layers=self.num_single_layers,
-            attention_head_dim=self.attention_head_dim,
-            num_attention_heads=self.num_attention_heads,
-            joint_attention_dim=self.joint_attention_dim,
-            pooled_projection_dim=self.text_embedding_dim,
-            guidance_embeds=False)
-
-        self.beta_dpo = 2000 ## this is used for dpo training
-
-    def get_sigmas(self,timesteps, n_dim=3, dtype=torch.float32):
-        device = self.text_encoder.device
-        sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
-
-        schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
-        timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < n_dim:
-            sigma = sigma.unsqueeze(-1)
-        return sigma
-
-    def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
-        device = self.text_encoder.device
-        batch = self.tokenizer(
-            prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
-        )
-        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
-        with torch.no_grad():
-            prompt_embeds = self.text_encoder(
-                input_ids=input_ids, attention_mask=attention_mask
-            )[0]
-
-        prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
-        attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
-        # get unconditional embeddings for classifier free guidance
-        uncond_tokens = [""]
-
-        max_length = prompt_embeds.shape[1]
-        uncond_batch = self.tokenizer(
-            uncond_tokens, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt",
-        )
-        uncond_input_ids = uncond_batch.input_ids.to(device)
-        uncond_attention_mask = uncond_batch.attention_mask.to(device)
-
-        with torch.no_grad():
-            negative_prompt_embeds = self.text_encoder(
-                input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
-            )[0]
-
-        negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
-        uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
-        # For classifier free guidance, we need to do two forward passes.
-        # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
-
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-        prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
-        boolean_prompt_mask = (prompt_mask == 1).to(device)
-
-        return prompt_embeds, boolean_prompt_mask
-
-    @torch.no_grad()
-    def encode_text(self, prompt):
-        device = self.text_encoder.device
-        batch = self.tokenizer(
-            prompt, max_length=self.max_text_seq_len, padding=True, truncation=True, return_tensors="pt")
-        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
-        encoder_hidden_states = self.text_encoder(
-            input_ids=input_ids, attention_mask=attention_mask)[0]
-
-        boolean_encoder_mask = (attention_mask == 1).to(device)
-
-        return encoder_hidden_states, boolean_encoder_mask
-
-    def encode_duration(self,duration):
-        return self.duration_emebdder(duration)
-
-    @torch.no_grad()
-    def inference_flow(self, prompt,
-                       num_inference_steps=50,
-                       timesteps=None,
-                       guidance_scale=3,
-                       duration=10,
-                       disable_progress=False,
-                       num_samples_per_prompt=1):
-
-        '''Only tested for single inference. Haven't test for batch inference'''
-
-        bsz = num_samples_per_prompt
-        device = self.transformer.device
-        scheduler = self.noise_scheduler
-
-        if not isinstance(prompt,list):
-            prompt = [prompt]
-        if not isinstance(duration,torch.Tensor):
-            duration = torch.tensor([duration],device=device)
-        classifier_free_guidance = guidance_scale > 1.0
-        duration_hidden_states = self.encode_duration(duration)
-        if classifier_free_guidance:
-            bsz = 2 * num_samples_per_prompt
-
-            encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt=num_samples_per_prompt)
-            duration_hidden_states = duration_hidden_states.repeat(bsz,1,1)
-
         else:
-
-        pooled = torch.nanmean(masked_data, dim=1)
-        pooled_projection = self.fc(pooled)
-
-        encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
-
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        timesteps, num_inference_steps = retrieve_timesteps(
-            scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas
-        )
-
-        latents = torch.randn(num_samples_per_prompt,self.audio_seq_len,64)
-        weight_dtype = latents.dtype
-
-        progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)
-
-        txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
-        audio_ids = torch.arange(self.audio_seq_len).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
-
-        latents = latents.to(device)
-        encoder_hidden_states = encoder_hidden_states.to(device)
-
-            latents_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
-
-                img_ids=audio_ids,
-                return_dict=False,
-            )[0]
-
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-        return latents
-
-    def forward(self,
-                latents,
-                prompt,
-                duration=torch.tensor([10]),
-                sft=True
-                ):
-
-        duration_hidden_states = self.encode_duration(duration)
-
-        masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
-        pooled = torch.nanmean(masked_data, dim=1)
-        pooled_projection = self.fc(pooled)
-
-        txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
-        audio_ids = torch.arange(audio_seq_length).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
-
-        if self.uncondition:
-            mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
-            if len(mask_indices) > 0:
-                encoder_hidden_states[mask_indices] = 0
-
-            noise = torch.randn_like(latents)
-
-            u = compute_density_for_timestep_sampling(
-                weighting_scheme='logit_normal',
-                batch_size=bsz,
-                logit_mean=0,
-                logit_std=1,
-                mode_scale=None,
-            )
-
-            indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
-            timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
-            sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
-
-            noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
-
-            model_pred = self.transformer(
-                hidden_states=noisy_model_input,
-                encoder_hidden_states=encoder_hidden_states,
-                pooled_projections=pooled_projection,
-                img_ids=audio_ids,
-                txt_ids=txt_ids,
-                guidance=None,
-                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                timestep=timesteps/1000,
-                return_dict=False)[0]
-
-            target = noise - latents
-            loss = torch.mean(
-                ( (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
-                1,
-            )
-            loss = loss.mean()
-            raw_model_loss, raw_ref_loss,implicit_acc,epsilon_diff = 0,0,0,0 ## default this to 0 if doing sft
-
-        else:
-            encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
-            pooled_projection = pooled_projection.repeat(2,1)
-            noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1) ## Have to sample same noise for preferred and rejected
-            u = compute_density_for_timestep_sampling(
-                weighting_scheme='logit_normal',
-                batch_size=bsz//2,
-                logit_mean=0,
-                logit_std=1,
-                mode_scale=None,
-            )
-
-            indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
-            timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
-            timesteps = timesteps.repeat(2)
-            sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
-
-            noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
-
-            model_pred = self.transformer(
-                hidden_states=noisy_model_input,
-                encoder_hidden_states=encoder_hidden_states,
-                pooled_projections=pooled_projection,
-                img_ids=audio_ids,
-                txt_ids=txt_ids,
-                guidance=None,
-                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                timestep=timesteps/1000,
-                return_dict=False)[0]
-            target = noise - latents
-
-            model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-            model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
-            model_losses_w, model_losses_l = model_losses.chunk(2)
-            model_diff = model_losses_w - model_losses_l
-            raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
-
-            with torch.no_grad():
-                ref_preds = self.ref_transformer(
-                    hidden_states=noisy_model_input,
-                    encoder_hidden_states=encoder_hidden_states,
-                    pooled_projections=pooled_projection,
-                    img_ids=audio_ids,
-                    txt_ids=txt_ids,
-                    guidance=None,
-                    timestep=timesteps/1000,
-                    return_dict=False)[0]
-
-                ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
-                ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
-
-                ref_losses_w, ref_losses_l = ref_loss.chunk(2)
-                ref_diff = ref_losses_w - ref_losses_l
-                raw_ref_loss = ref_loss.mean()
-
-                epsilon_diff = torch.max(torch.zeros_like(model_losses_w),
-                                         ref_losses_w-model_losses_w).mean()
-
-            scale_term = -0.5 * self.beta_dpo
-            inside_term = scale_term * (model_diff - ref_diff)
-            implicit_acc = (scale_term * (model_diff - ref_diff) > 0).sum().float() / inside_term.size(0)
-            loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()

-        return
+import torch
+import torchaudio
+from omegaconf import OmegaConf
+from huggingface_hub import snapshot_download
 import numpy as np
+import json
+import os
+from safetensors.torch import load_file
+
+# Imports from the jamify library
+from jam.model.cfm import CFM
+from jam.model.dit import DiT
+from jam.model.vae import StableAudioOpenVAE
+from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
+from muq import MuQMuLan
+
+# Helper functions adapted from jamify/src/jam/infer.py
+def get_negative_style_prompt(device, file_path):
+    vocal_style = np.load(file_path)
+    vocal_style = torch.from_numpy(vocal_style).to(device)
+    return vocal_style.half()
+
+def normalize_audio(audio):
+    audio = audio - audio.mean(-1, keepdim=True)
+    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
+    return audio
+
+class Jamify:
+    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
+        self.device = torch.device(device)
+
+        # --- FIX: Point to the local jamify repository for config and public files ---
+        #jamify_repo_path = "/Users/cy/Desktop/JAM/jamify"
+
+        print("Downloading main model checkpoint...")
+        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
+        self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
+
+        # Use local config and data files
+        config_path = os.path.join(model_repo_path, "jam_infer.yaml")
+        self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
+        tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
+        silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")
+        print("Loading configuration...")
+        self.config = OmegaConf.load(config_path)
+        self.config.data.train_dataset.silence_latent_path = silence_latent_path
+
+        # --- FIX: Override the relative paths in the config with absolute paths ---
+        self.config.data.train_dataset.tokenizer_path = tokenizer_path
+        self.config.evaluation.dataset.tokenizer_path = tokenizer_path
+        self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path
+
+        print("Loading VAE model...")
+        self.vae = StableAudioOpenVAE().to(self.device).eval()
+
+        print("Loading CFM model...")
+        self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path)
+
+        print("Loading MuQ style model...")
+        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()
+
+        print("Setting up dataset processor...")
+        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
+        enhance_webdataset_config(dataset_cfg)
+        dataset_cfg.multiple_styles = False
+        self.dataset_processor = DiffusionWebDataset(**dataset_cfg)
+
+        print("Jamify model loaded successfully.")
+
+    def _load_cfm_model(self, model_config, checkpoint_path):
+        dit_config = model_config["dit"].copy()
+        if "text_num_embeds" not in dit_config:
+            dit_config["text_num_embeds"] = 256
+
+        model = CFM(
+            transformer=DiT(**dit_config),
+            **model_config["cfm"]
+        ).to(self.device)
+
+        state_dict = load_file(checkpoint_path)
+        model.load_state_dict(state_dict, strict=False)
+        return model.eval()
+
+    def _generate_style_embedding_from_audio(self, audio_path):
+        waveform, sample_rate = torchaudio.load(audio_path)
+        if sample_rate != 24000:
+            resampler = torchaudio.transforms.Resample(sample_rate, 24000)
+            waveform = resampler(waveform)
+        if waveform.shape[0] > 1:
+            waveform = waveform.mean(dim=0, keepdim=True)
+
+        waveform = waveform.squeeze(0).to(self.device)
+
+        with torch.inference_mode():
+            style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * 30])
+        return style_embedding[0]
+
+    def _generate_style_embedding_from_prompt(self, prompt):
+        with torch.inference_mode():
+            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
+        return style_embedding
+
+    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration_sec=30, steps=50):
+        print("Starting prediction...")
+
+        if reference_audio_path:
+            print(f"Generating style from audio: {reference_audio_path}")
+            style_embedding = self._generate_style_embedding_from_audio(reference_audio_path)
+        elif style_prompt:
+            print(f"Generating style from prompt: '{style_prompt}'")
+            style_embedding = self._generate_style_embedding_from_prompt(style_prompt)
+        else:
+            print("No style provided, using zero embedding.")
+            style_embedding = torch.zeros(512, device=self.device)

+        print(f"Loading lyrics from: {lyrics_json_path}")
+        with open(lyrics_json_path, 'r') as f:
+            lrc_data = json.load(f)
+        if 'word' not in lrc_data:
+            lrc_data = {'word': lrc_data}

+        frame_rate = 21.5
+        num_frames = int(duration_sec * frame_rate)
+        fake_latent = torch.randn(128, num_frames)

+        sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)

+        print("Processing sample...")
+        processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
+        if processed_sample is None:
+            raise ValueError("Failed to process the provided lyrics and style.")

+        batch = self.dataset_processor.custom_collate_fn([processed_sample])

+        for key, value in batch.items():
+            if isinstance(value, torch.Tensor):
+                batch[key] = value.to(self.device)

+        print("Generating audio latent...")
+        with torch.inference_mode():
+            batch_size = 1
+            text = batch["lrc"]
+            style_prompt_tensor = batch["prompt"]
+            start_time = batch["start_time"]
+            duration_abs = batch["duration_abs"]
+            duration_rel = batch["duration_rel"]

+            cond = torch.zeros(batch_size, self.cfm_model.max_frames, 64).to(self.device)
+            pred_frames = [(0, self.cfm_model.max_frames)]

+            negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path)
+            negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

+            sample_kwargs = self.config.evaluation.sample_kwargs
+            sample_kwargs.steps = steps
+            latents, _ = self.cfm_model.sample(
+                cond=cond, text=text, style_prompt=style_prompt_tensor,
+                duration_abs=duration_abs, duration_rel=duration_rel,
+                negative_style_prompt=negative_style_prompt, start_time=start_time,
+                latent_pred_segments=pred_frames, **sample_kwargs)

+        latent = latents[0][0]

+        print("Decoding latent to audio...")
+        latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
+        pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()

+        pred_audio = normalize_audio(pred_audio)

+        sample_rate = 44100
+        trim_samples = int(duration_sec * sample_rate)
+        if pred_audio.shape[1] > trim_samples:
+            pred_audio = pred_audio[:, :trim_samples]

+        output_path = "generated_song.mp3"
+        print(f"Saving audio to {output_path}")
+        torchaudio.save(output_path, pred_audio, sample_rate, format="mp3")

+        return output_path
requirements.txt
CHANGED
@@ -1,11 +1,35 @@
-torch
-torchaudio
+torch
+torchaudio
+diffusers
+accelerate
+safetensors
+wandb
+gpustat
+soundfile
+muq
+pyloudnorm
+mutagen
+torchdiffeq
+x_transformers
+ema_pytorch
 librosa
+jiwer
+demucs
+audiobox-aesthetics
+
+# WebDataset
+webdataset
+webdatasetng
+wids
+omegaconf
+
+# DeepPhonemizer
+unidecode
+inflect
+
+# duration prediction
+openai
+pyphen
+syllables
+git+https://github.com/declare-lab/jamify.git
+git+https://github.com/xhhhhang/DeepPhonemizer.git