Spaces: Running on Zero

Commit · 597cecf ("pt 1")
1 Parent(s): 39d2f14

This view is limited to 50 files because it contains too many changes.
- app.py +143 -248
- ctcmodel.py +73 -75
- discriminator_conformer.py +100 -63
- dmd_trainer.py +338 -151
- duration_predictor.py +38 -17
- duration_trainer.py +181 -116
- duration_trainer_with_prompt.py +164 -94
- ecapa_tdnn.py +155 -50
- f5_tts/api.py +29 -18
- f5_tts/eval/ecapa_tdnn.py +91 -19
- f5_tts/eval/eval_infer_batch.py +34 -15
- f5_tts/eval/eval_librispeech_test_clean.py +20 -8
- f5_tts/eval/eval_seedtts_testset.py +17 -7
- f5_tts/eval/eval_utmos.py +9 -3
- f5_tts/eval/utils_eval.py +43 -13
- f5_tts/infer/infer_cli.py +53 -32
- f5_tts/infer/infer_gradio.py +178 -57
- f5_tts/infer/speech_edit.py +30 -14
- f5_tts/infer/utils_infer.py +106 -33
- f5_tts/model/__init__.py +2 -5
- f5_tts/model/backbones/dit.py +63 -30
- f5_tts/model/backbones/mmdit.py +82 -65
- f5_tts/model/backbones/unett.py +45 -23
- f5_tts/model/cfm.py +61 -25
- f5_tts/model/dataset.py +100 -82
- f5_tts/model/modules.py +105 -38
- f5_tts/model/trainer.py +163 -78
- f5_tts/model/utils.py +31 -18
- f5_tts/model_new/__init__.py +0 -1
- f5_tts/model_new/backbones/dit.py +65 -26
- f5_tts/model_new/backbones/mmdit.py +42 -20
- f5_tts/model_new/backbones/unett.py +68 -27
- f5_tts/model_new/cfm.py +41 -18
- f5_tts/model_new/dataset.py +31 -9
- f5_tts/model_new/modules.py +126 -39
- f5_tts/model_new/trainer.py +142 -43
- f5_tts/model_new/utils.py +29 -13
- f5_tts/runtime/triton_trtllm/benchmark.py +106 -28
- f5_tts/runtime/triton_trtllm/client_grpc.py +42 -13
- f5_tts/runtime/triton_trtllm/client_http.py +26 -5
- f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py +125 -33
- f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py +62 -17
- f5_tts/runtime/triton_trtllm/patch/__init__.py +7 -11
- f5_tts/runtime/triton_trtllm/patch/f5tts/model.py +13 -4
- f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py +98 -43
- f5_tts/runtime/triton_trtllm/scripts/conv_stft.py +3 -2
- f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py +76 -22
- f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py +12 -5
- f5_tts/runtime/triton_trtllm/scripts/fill_template.py +6 -2
- f5_tts/scripts/count_max_epoch.py +6 -2
app.py CHANGED
@@ -1,130 +1,51 @@
## IMPORTS ##
import os
import tempfile
import time
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
import torch
import torchaudio
from cached_path import cached_path
from huggingface_hub import hf_hub_download
from transformers import pipeline

from infer import DMOInference

## CUDA DEVICE ##
device = "cuda" if torch.cuda.is_available() else "cpu"

## LOAD MODELS ##
asr_pipe = pipeline(
    "automatic-speech-recognition", model="openai/whisper-large-v3-turbo", device=device
)
model = DMOInference(
    student_checkpoint_path=str(cached_path("hf://yl4579/DMOSpeech2/model_85000.pt")),
    duration_predictor_path=str(cached_path("hf://yl4579/DMOSpeech2/model_1500.pt")),
    device=device,
    model_type="F5TTS_Base",
)


def transcribe(ref_audio, language=None):
    """Transcribe audio using the pre-loaded ASR pipeline."""
    return asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs=(
            {"task": "transcribe", "language": language}
            if language
            else {"task": "transcribe"}
        ),
        return_timestamps=False,
    )["text"].strip()


@spaces.GPU(duration=120)
def generate_speech(
    prompt_audio,
    prompt_text,
@@ -134,128 +55,115 @@ def generate_speech(
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
    verbose,
):
    if prompt_audio is None:
        raise gr.Error("Please upload a reference audio!")

    if not target_text:
        raise gr.Error("Please enter text to generate!")

    if not prompt_text and prompt_text != "":
        prompt_text = transcribe(prompt_audio)


    if mode == "Student Only (4 steps)":
        teacher_steps = 0
        student_start_step = 0
        teacher_stopping_time = 1.0
    elif mode == "Teacher-Guided (8 steps)":
        teacher_steps = 16
        teacher_stopping_time = 0.07
        student_start_step = 1
    elif mode == "High Diversity (16 steps)":
        teacher_steps = 24
        teacher_stopping_time = 0.3
        student_start_step = 2
    else:  # Custom
        teacher_steps = custom_teacher_steps
        teacher_stopping_time = custom_teacher_stopping_time
        student_start_step = custom_student_start_step

    # Generate speech
    generated_audio = model.generate(
        gen_text=target_text,
        audio_path=prompt_audio,
        prompt_text=prompt_text if prompt_text else None,
        teacher_steps=teacher_steps,
        teacher_stopping_time=teacher_stopping_time,
        student_start_step=student_start_step,
        temperature=temperature,
        verbose=verbose,
    )


    # Save audio
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        output_path = tmp_file.name

    if isinstance(generated_audio, np.ndarray):
        generated_audio = torch.from_numpy(generated_audio)

    if generated_audio.dim() == 1:
        generated_audio = generated_audio.unsqueeze(0)

    torchaudio.save(output_path, generated_audio, 24000)

    return (
        output_path,
        "Success!",
        (
            f"Mode: {mode} | Transcribed: {prompt_text[:50]}..."
            if not prompt_text
            else f"Mode: {mode}"
        ),
    )


# Create Gradio interface
with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
    gr.Markdown(
        f"""
    # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech

    Generate natural speech in any voice with just a short reference audio!
    """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Reference audio input
            prompt_audio = gr.Audio(
                label="📎 Reference Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )

            prompt_text = gr.Textbox(
                label="📝 Reference Text (leave empty for auto-transcription)",
                placeholder="The text spoken in the reference audio...",
                lines=2,
            )

            target_text = gr.Textbox(
                label="✍️ Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=4,
            )

            # Generation mode
            mode = gr.Radio(
                choices=[
                    "Student Only (4 steps)",
                    "Teacher-Guided (8 steps)",
                    "High Diversity (16 steps)",
                    "Custom",
                ],
                value="Teacher-Guided (8 steps)",
                label="🚀 Generation Mode",
                info="Choose speed vs quality/diversity tradeoff",
            )

            # Advanced settings (collapsible)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                temperature = gr.Slider(
@@ -264,9 +172,9 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
                    value=0.0,
                    step=0.1,
                    label="Duration Temperature",
                    info="0 = deterministic, >0 = more variation in speech rhythm",
                )

                with gr.Group(visible=False) as custom_settings:
                    gr.Markdown("### Custom Mode Settings")
                    custom_teacher_steps = gr.Slider(
@@ -275,60 +183,50 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
                        value=16,
                        step=1,
                        label="Teacher Steps",
                        info="More steps = higher quality",
                    )

                    custom_teacher_stopping_time = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.07,
                        step=0.01,
                        label="Teacher Stopping Time",
                        info="When to switch to student",
                    )

                    custom_student_start_step = gr.Slider(
                        minimum=0,
                        maximum=4,
                        value=1,
                        step=1,
                        label="Student Start Step",
                        info="Which student step to start from",
                    )

                    verbose = gr.Checkbox(
                        value=False,
                        label="Verbose Output",
                        info="Show detailed generation steps",
                    )

            generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output
            output_audio = gr.Audio(
                label="🔊 Generated Speech", type="filepath", autoplay=True
            )

            status = gr.Textbox(label="Status", interactive=False)

            metrics = gr.Textbox(label="Performance Metrics", interactive=False)

            info = gr.Textbox(label="Generation Info", interactive=False)

    # Tips
    gr.Markdown(
        """
    ### 💡 Quick Tips:

    - **Auto-transcription**: Leave reference text empty to auto-transcribe
@@ -341,8 +239,9 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
    - Student Only: ~0.05x (20x faster than real-time)
    - Teacher-Guided: ~0.10x (10x faster)
    - High Diversity: ~0.20x (5x faster)
    """
    )

    # Event handler
    generate_btn.click(
        generate_speech,
@@ -355,21 +254,17 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
            verbose,
        ],
        outputs=[output_audio, status, metrics, info],
    )

    # Update visibility of custom settings based on mode
    def update_custom_visibility(mode):
        is_custom = mode == "Custom"
        return gr.update(visible=is_custom)

    mode.change(update_custom_visibility, inputs=[mode], outputs=[custom_settings])

# Launch the app
if __name__ == "__main__":
@@ -377,5 +272,5 @@ if __name__ == "__main__":
        print(f"Warning: Model failed to load - {status_message}")
    if not asr_pipe:
        print("Warning: ASR pipeline not available - auto-transcription disabled")

    demo.launch()
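Aside from the UI changes, the refactored app.py makes the whole inference path importable: the ASR pipeline and the DMOInference model are built at module load, and generate_speech() only maps the radio-button presets onto generate() arguments. The sketch below (not part of the commit) drives the same path without Gradio; the reference-audio and output paths are placeholders, and the keyword arguments are only those that appear in the diff.

# Sketch only (not part of the commit): run the refactored inference path without the UI.
# "ref.wav" and "out.wav" are illustrative paths.
import torch
import torchaudio
from cached_path import cached_path

from infer import DMOInference

device = "cuda" if torch.cuda.is_available() else "cpu"
tts = DMOInference(
    student_checkpoint_path=str(cached_path("hf://yl4579/DMOSpeech2/model_85000.pt")),
    duration_predictor_path=str(cached_path("hf://yl4579/DMOSpeech2/model_1500.pt")),
    device=device,
    model_type="F5TTS_Base",
)

# "Teacher-Guided (8 steps)" preset from generate_speech()
wav = tts.generate(
    gen_text="Hello from DMOSpeech 2!",
    audio_path="ref.wav",
    prompt_text=None,  # None lets the model handle the prompt text, as the app does
    teacher_steps=16,
    teacher_stopping_time=0.07,
    student_start_step=1,
    temperature=0.0,
    verbose=False,
)
if not torch.is_tensor(wav):
    wav = torch.from_numpy(wav)
torchaudio.save("out.wav", wav.unsqueeze(0) if wav.dim() == 1 else wav, 24000)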
ctcmodel.py CHANGED
@@ -1,36 +1,24 @@
import copy
from pathlib import Path

import torch
from torch import nn
from torchaudio.models import Conformer

from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
                                list_str_to_tensor, mask_from_frac_lengths)


class ResBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p)
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
@@ -41,70 +29,71 @@ class ResBlock(nn.Module):

    def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
        layers = [
            nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            nn.ReLU(),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)


class ConformerCTC(nn.Module):
    def __init__(self, vocab_size, mel_dim=100, num_heads=8, d_hid=512, nlayers=6):
        super().__init__()

        self.mel_proj = nn.Conv1d(mel_dim, d_hid, kernel_size=3, padding=1)

        self.d_hid = d_hid

        self.resblock1 = nn.Sequential(
            ResBlock(d_hid), nn.GroupNorm(num_groups=1, num_channels=d_hid)
        )

        self.resblock2 = nn.Sequential(
            ResBlock(d_hid), nn.GroupNorm(num_groups=1, num_channels=d_hid)
        )

        self.conf_pre = torch.nn.ModuleList(
            [
                Conformer(
                    input_dim=d_hid,
                    num_heads=num_heads,
                    ffn_dim=d_hid * 2,
                    num_layers=1,
                    depthwise_conv_kernel_size=15,
                    use_group_norm=True,
                )
                for _ in range(nlayers // 2)
            ]
        )

        self.conf_after = torch.nn.ModuleList(
            [
                Conformer(
                    input_dim=d_hid,
                    num_heads=num_heads,
                    ffn_dim=d_hid * 2,
                    num_layers=1,
                    depthwise_conv_kernel_size=7,
                    use_group_norm=True,
                )
                for _ in range(nlayers // 2)
            ]
        )

        self.out = nn.Linear(d_hid, 1 + vocab_size)  # 1 for blank

        self.ctc_loss = nn.CTCLoss(blank=vocab_size, zero_infinity=True).cuda()

    def forward(self, latent, text=None, text_lens=None):
        layers = []

@@ -125,20 +114,24 @@ class ConformerCTC(nn.Module):

        batch_size, time_steps, _ = x.shape
        # Create a dummy lengths tensor (all sequences are assumed to be full length).
        input_lengths = torch.full(
            (batch_size,), time_steps, device=x.device, dtype=torch.int64
        )

        for layer in self.conf_pre:
            x, _ = layer(x, input_lengths)
            layers.append(x.transpose(1, 2))

        for layer in self.conf_after:
            x, _ = layer(x, input_lengths)
            layers.append(x.transpose(1, 2))

        x = self.out(x)

        if text_lens is not None and text is not None:
            loss = self.ctc_loss(
                x.log_softmax(dim=2).transpose(0, 1), text, input_lengths, text_lens
            )
            return x, layers, loss
        else:
            return x, layers
@@ -147,9 +140,8 @@ class ConformerCTC(nn.Module):
if __name__ == "__main__":
    from f5_tts.model.utils import get_tokenizer

    bsz = 16

    tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
    tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
    dataset_name = "Emilia_ZH_EN"
@@ -158,15 +150,17 @@
    else:
        tokenizer_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    model = ConformerCTC(
        vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6
    ).cuda()

    text = ["hello world"] * bsz
    lens = torch.randint(1, 1000, (bsz,)).cuda()
    inp = torch.randn(bsz, lens.max(), 80).cuda()

    batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device

    # handle text as string
    text_lens = torch.tensor([len(t) for t in text], device=device)
    if isinstance(text, list):
@@ -198,7 +192,6 @@

    char_vocab_map = list(vocab_char_map.keys())

    for batch in best_path:
        decoded_sequence = []
        previous_token = None
@@ -212,10 +205,15 @@
        decoded_sequences.append(decoded_sequence)

    # Convert token indices to characters
    decoded_texts = [
        "".join([char_vocab_map[token] for token in sequence])
        for sequence in decoded_sequences
    ]
    gt_texts = []
    for i in range(text_lens.size(0)):
        gt_texts.append(
            "".join([char_vocab_map[token] for token in text[i, : text_lens[i]]])
        )

    print(decoded_texts)
    print(gt_texts)
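Besides the CTC objective, forward() exposes the per-block Conformer features (layers), and it is these intermediate features that the discriminator in the next file consumes: it keeps the last three blocks and upsamples them 4x along time. A minimal sketch of that feature-extraction use, with random mel features standing in for a real batch and an illustrative vocabulary size, is:

# Sketch only: the random input and vocab size are placeholders, not values from the commit.
import torch
import torch.nn.functional as F

from ctcmodel import ConformerCTC

ctc = ConformerCTC(vocab_size=2546, mel_dim=80, num_heads=8, d_hid=512, nlayers=6).cuda()
mel = torch.randn(4, 400, 80).cuda()  # (batch, frames, mel bins)

logits, feats = ctc(mel)  # logits: (B, T', 1 + vocab_size); feats: list of (B, d_hid, T')
feats = feats[-3:]        # last three Conformer blocks
feats = [F.interpolate(f, mode="nearest", scale_factor=4).transpose(-1, -2) for f in feats]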
discriminator_conformer.py CHANGED
@@ -2,30 +2,28 @@

from __future__ import annotations

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans
from torchaudio.models import Conformer

from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
                                list_str_to_tensor, mask_from_frac_lengths)


class ResBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p)
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
@@ -36,46 +34,67 @@ class ResBlock(nn.Module):

    def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
        layers = [
            nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            nn.ReLU(),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)


class ConformerDiscirminator(nn.Module):
    def __init__(
        self,
        input_dim,
        channels=512,
        num_layers=3,
        num_heads=8,
        depthwise_conv_kernel_size=15,
        use_group_norm=True,
    ):
        super().__init__()

        self.input_layer = nn.Conv1d(input_dim, channels, kernel_size=3, padding=1)

        self.resblock1 = nn.Sequential(
            ResBlock(channels), nn.GroupNorm(num_groups=1, num_channels=channels)
        )

        self.resblock2 = nn.Sequential(
            ResBlock(channels), nn.GroupNorm(num_groups=1, num_channels=channels)
        )

        self.conformer1 = Conformer(
            **{
                "input_dim": channels,
                "num_heads": num_heads,
                "ffn_dim": channels * 2,
                "num_layers": 1,
                "depthwise_conv_kernel_size": depthwise_conv_kernel_size // 2,
                "use_group_norm": use_group_norm,
            }
        )

        self.conformer2 = Conformer(
            **{
                "input_dim": channels,
                "num_heads": num_heads,
                "ffn_dim": channels * 2,
                "num_layers": num_layers - 1,
                "depthwise_conv_kernel_size": depthwise_conv_kernel_size,
                "use_group_norm": use_group_norm,
            }
        )

        self.linear = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x):
@@ -89,12 +108,14 @@ class ConformerDiscirminator(nn.Module):
        x = nn.functional.avg_pool1d(x, 2)
        x = self.resblock2(x)
        x = nn.functional.avg_pool1d(x, 2)

        # Transpose to (B, T, C) for the conformer.
        x = x.transpose(1, 2)
        batch_size, time_steps, _ = x.shape
        # Create a dummy lengths tensor (all sequences are assumed to be full length).
        lengths = torch.full(
            (batch_size,), time_steps, device=x.device, dtype=torch.int64
        )
        # The built-in Conformer returns (output, output_lengths); we discard lengths.

        x, _ = self.conformer1(x, lengths)
@@ -107,12 +128,13 @@ class ConformerDiscirminator(nn.Module):

        return out


if __name__ == "__main__":
    from f5_tts.model import DiT
    from f5_tts.model.utils import get_tokenizer

    bsz = 2

    tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
    tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
    dataset_name = "Emilia_ZH_EN"
@@ -121,20 +143,28 @@ if __name__ == "__main__":
    else:
        tokenizer_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    fake_unet = DiT(
        dim=1024,
        depth=22,
        heads=16,
        ff_mult=2,
        text_dim=512,
        conv_layers=4,
        text_num_embeds=vocab_size,
        mel_dim=80,
    )

    fake_unet = fake_unet.cuda()

    text = ["hello world"] * bsz
    lens = torch.randint(1, 1000, (bsz,)).cuda()
    inp = torch.randn(bsz, lens.max(), 80).cuda()

    batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device

    batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device

    # handle text as string
    if isinstance(text, list):
        if exists(vocab_char_map):
@@ -147,13 +177,17 @@ if __name__ == "__main__":
    if not exists(lens):
        lens = torch.full((batch,), seq_len, device=device)

    mask = lens_to_mask(
        lens, length=seq_len
    )  # useless here, as collate_fn will pad to max length in batch
    frac_lengths_mask = (0.7, 1.0)

    # get a random span to mask out for training conditionally
    frac_lengths = (
        torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask)
    )
    rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

    if exists(mask):
        rand_span_mask &= mask

@@ -163,38 +197,41 @@ if __name__ == "__main__":
    x1 = inp
    x0 = torch.randn_like(x1)
    t = time.unsqueeze(-1).unsqueeze(-1)

    phi = (1 - t) * x0 + t * x1
    flow = x1 - x0
    cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)

    layers = fake_unet(
        x=phi,
        cond=cond,
        text=text,
        time=time,
        drop_audio_cond=False,
        drop_text=False,
        classify_mode=True,
    )

    # layers = torch.stack(layers, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
    # print(layers.shape)

    from ctcmodel import ConformerCTC

    ctcmodel = ConformerCTC(
        vocab_size=vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6
    ).cuda()
    real_out, layer = ctcmodel(inp)
    layer = layer[-3:]  # only use the last 3 layers
    layer = [
        F.interpolate(l, mode="nearest", scale_factor=4).transpose(-1, -2)
        for l in layer
    ]
    if layer[0].size(1) < layers[0].size(1):
        layer = [F.pad(l, (0, 0, 0, layers[0].size(1) - l.size(1))) for l in layer]

    layers = layer + layers

    model = ConformerDiscirminator(input_dim=23 * 1024 + 3 * 512, channels=512)

    model = model.cuda()
    print(model)
dmd_trainer.py CHANGED
@@ -1,28 +1,26 @@
from __future__ import annotations

import gc
import math
import os

import torch
import torch.nn as nn
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm

from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
from unimodel import UniModel

# trainer


class RunningStats:
    def __init__(self):
@@ -41,7 +39,7 @@ class RunningStats:
    @property
    def variance(self):
        """Return the sample variance. Returns NaN if fewer than two samples."""
        return self.M2 / (self.count - 1) if self.count > 1 else float("nan")

    @property
    def std(self):
@@ -49,7 +47,6 @@ class RunningStats:
        return math.sqrt(self.variance)


class Trainer:
    def __init__(
        self,
@@ -74,7 +71,6 @@ class Trainer:
        accelerate_kwargs: dict = dict(),
        bnb_optimizer: bool = False,
        scale: float = 1.0,
        # training parameters for DMDSpeech
        num_student_step: int = 1,
        gen_update_ratio: int = 5,
@@ -82,7 +78,6 @@ class Trainer:
        lambda_generator_loss: float = 1.0,
        lambda_ctc_loss: float = 1.0,
        lambda_sim_loss: float = 1.0,
        num_GAN: int = 5000,
        num_D: int = 500,
        num_ctc: int = 5000,
@@ -103,7 +98,13 @@ class Trainer:

        if logger == "wandb":
            if exists(wandb_resume_id):
                init_kwargs = {
                    "wandb": {
                        "resume": "allow",
                        "name": wandb_run_name,
                        "id": wandb_resume_id,
                    }
                }
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
110 |
self.accelerator.init_trackers(
|
|
|
131 |
self.epochs = epochs
|
132 |
self.num_warmup_updates = num_warmup_updates
|
133 |
self.save_per_updates = save_per_updates
|
134 |
+
self.last_per_steps = default(
|
135 |
+
last_per_steps, save_per_updates * grad_accumulation_steps
|
136 |
+
)
|
137 |
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
138 |
|
139 |
self.batch_size = batch_size
|
|
|
145 |
self.noise_scheduler = noise_scheduler
|
146 |
|
147 |
self.duration_predictor = duration_predictor
|
148 |
+
|
149 |
self.log_step = log_step
|
150 |
|
151 |
+
self.gen_update_ratio = gen_update_ratio # number of generator updates per guidance (fake score function and discriminator) update
|
152 |
+
self.lambda_discriminator_loss = (
|
153 |
+
lambda_discriminator_loss # weight for discriminator loss (L_adv)
|
154 |
+
)
|
155 |
+
self.lambda_generator_loss = (
|
156 |
+
lambda_generator_loss # weight for generator loss (L_adv)
|
157 |
+
)
|
158 |
+
self.lambda_ctc_loss = lambda_ctc_loss # weight for ctc loss
|
159 |
+
self.lambda_sim_loss = lambda_sim_loss # weight for similarity loss
|
160 |
+
|
161 |
# create distillation schedule for student model
|
162 |
+
self.student_steps = torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]
|
163 |
+
|
164 |
+
self.GAN = model.guidance_model.gen_cls_loss # whether to use GAN training
|
165 |
+
self.num_GAN = num_GAN # number of steps before adversarial training
|
166 |
+
self.num_D = num_D # number of steps to train the discriminator before adversarial training
|
167 |
+
self.num_ctc = num_ctc # number of steps before CTC training
|
168 |
+
self.num_sim = num_sim # number of steps before similarity training
|
169 |
+
self.num_simu = num_simu # number of steps before using simulated data
|
|
|
170 |
|
171 |
# Assuming `self.model.fake_unet.parameters()` and `self.model.guidance_model.parameters()` are accessible
|
172 |
if bnb_optimizer:
|
173 |
import bitsandbytes as bnb
|
174 |
+
|
175 |
+
self.optimizer_generator = bnb.optim.AdamW8bit(
|
176 |
+
self.model.feedforward_model.parameters(), lr=learning_rate
|
177 |
+
)
|
178 |
+
self.optimizer_guidance = bnb.optim.AdamW8bit(
|
179 |
+
self.model.guidance_model.parameters(), lr=learning_rate
|
180 |
+
)
|
181 |
else:
|
182 |
+
self.optimizer_generator = AdamW(
|
183 |
+
self.model.feedforward_model.parameters(), lr=learning_rate, eps=1e-7
|
184 |
+
)
|
185 |
+
self.optimizer_guidance = AdamW(
|
186 |
+
self.model.guidance_model.parameters(), lr=learning_rate, eps=1e-7
|
187 |
+
)
|
188 |
|
189 |
+
self.model, self.optimizer_generator, self.optimizer_guidance = (
|
190 |
+
self.accelerator.prepare(
|
191 |
+
self.model, self.optimizer_generator, self.optimizer_guidance
|
192 |
+
)
|
193 |
+
)
|
194 |
|
195 |
self.generator_norm = RunningStats()
|
196 |
self.guidance_norm = RunningStats()
|
197 |
|
|
|
198 |
@property
|
199 |
def is_main(self):
|
200 |
return self.accelerator.is_main_process
|
|
|
204 |
if self.is_main:
|
205 |
checkpoint = dict(
|
206 |
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
207 |
+
optimizer_generator_state_dict=self.accelerator.unwrap_model(
|
208 |
+
self.optimizer_generator
|
209 |
+
).state_dict(),
|
210 |
+
optimizer_guidance_state_dict=self.accelerator.unwrap_model(
|
211 |
+
self.optimizer_guidance
|
212 |
+
).state_dict(),
|
213 |
scheduler_generator_state_dict=self.scheduler_generator.state_dict(),
|
214 |
scheduler_guidance_state_dict=self.scheduler_guidance.state_dict(),
|
215 |
step=step,
|
|
|
218 |
if not os.path.exists(self.checkpoint_path):
|
219 |
os.makedirs(self.checkpoint_path)
|
220 |
if last:
|
221 |
+
self.accelerator.save(
|
222 |
+
checkpoint, f"{self.checkpoint_path}/model_last.pt"
|
223 |
+
)
|
224 |
print(f"Saved last checkpoint at step {step}")
|
225 |
else:
|
226 |
+
self.accelerator.save(
|
227 |
+
checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
|
228 |
+
)
|
229 |
|
230 |
def load_checkpoint(self):
|
231 |
if (
|
|
|
244 |
key=lambda x: int("".join(filter(str.isdigit, x))),
|
245 |
)[-1]
|
246 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
247 |
+
checkpoint = torch.load(
|
248 |
+
f"{self.checkpoint_path}/{latest_checkpoint}",
|
249 |
+
weights_only=True,
|
250 |
+
map_location="cpu",
|
251 |
+
)
|
252 |
|
253 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(
|
254 |
+
checkpoint["model_state_dict"], strict=False
|
255 |
+
)
|
256 |
# self.accelerator.unwrap_model(self.optimizer_generator).load_state_dict(checkpoint["optimizer_generator_state_dict"])
|
257 |
# self.accelerator.unwrap_model(self.optimizer_guidance).load_state_dict(checkpoint["optimizer_guidance_state_dict"])
|
258 |
# if self.scheduler_guidance:
|
|
|
264 |
del checkpoint
|
265 |
gc.collect()
|
266 |
return step
|
|
|
267 |
|
268 |
+
def train(
|
269 |
+
self,
|
270 |
+
train_dataset: Dataset,
|
271 |
+
num_workers=64,
|
272 |
+
resumable_with_seed: int = None,
|
273 |
+
vocoder: nn.Module = None,
|
274 |
+
):
|
275 |
if exists(resumable_with_seed):
|
276 |
generator = torch.Generator()
|
277 |
generator.manual_seed(resumable_with_seed)
|
|
|
293 |
self.accelerator.even_batches = False
|
294 |
sampler = SequentialSampler(train_dataset)
|
295 |
batch_sampler = DynamicBatchSampler(
|
296 |
+
sampler,
|
297 |
+
self.batch_size,
|
298 |
+
max_samples=self.max_samples,
|
299 |
+
random_seed=resumable_with_seed,
|
300 |
+
drop_last=False,
|
301 |
)
|
302 |
train_dataloader = DataLoader(
|
303 |
train_dataset,
|
|
|
308 |
batch_sampler=batch_sampler,
|
309 |
)
|
310 |
else:
|
311 |
+
raise ValueError(
|
312 |
+
f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
|
313 |
+
)
|
314 |
|
315 |
# accelerator.prepare() dispatches batches to devices;
|
316 |
# which means the length of dataloader calculated before, should consider the number of devices
|
317 |
+
warmup_steps = self.num_warmup_updates * self.accelerator.num_processes
|
318 |
+
|
|
|
|
|
319 |
# consider a fixed warmup steps while using accelerate multi-gpu ddp
|
320 |
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
321 |
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
322 |
decay_steps = total_steps - warmup_steps
|
|
|
|
|
|
|
|
|
323 |
|
324 |
+
warmup_scheduler_generator = LinearLR(
|
325 |
+
self.optimizer_generator,
|
326 |
+
start_factor=1e-8,
|
327 |
+
end_factor=1.0,
|
328 |
+
total_iters=warmup_steps
|
329 |
+
// (self.gen_update_ratio * self.grad_accumulation_steps),
|
330 |
+
)
|
331 |
+
decay_scheduler_generator = LinearLR(
|
332 |
+
self.optimizer_generator,
|
333 |
+
start_factor=1.0,
|
334 |
+
end_factor=1e-8,
|
335 |
+
total_iters=decay_steps
|
336 |
+
// (self.gen_update_ratio * self.grad_accumulation_steps),
|
337 |
+
)
|
338 |
+
self.scheduler_generator = SequentialLR(
|
339 |
+
self.optimizer_generator,
|
340 |
+
schedulers=[warmup_scheduler_generator, decay_scheduler_generator],
|
341 |
+
milestones=[
|
342 |
+
warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps)
|
343 |
+
],
|
344 |
+
)
|
345 |
+
|
346 |
+
warmup_scheduler_guidance = LinearLR(
|
347 |
+
self.optimizer_guidance,
|
348 |
+
start_factor=1e-8,
|
349 |
+
end_factor=1.0,
|
350 |
+
total_iters=warmup_steps,
|
351 |
+
)
|
352 |
+
decay_scheduler_guidance = LinearLR(
|
353 |
+
self.optimizer_guidance,
|
354 |
+
start_factor=1.0,
|
355 |
+
end_factor=1e-8,
|
356 |
+
total_iters=decay_steps,
|
357 |
+
)
|
358 |
+
self.scheduler_guidance = SequentialLR(
|
359 |
+
self.optimizer_guidance,
|
360 |
+
schedulers=[warmup_scheduler_guidance, decay_scheduler_guidance],
|
361 |
+
milestones=[warmup_steps],
|
362 |
+
)
|
363 |
|
364 |
+
train_dataloader, self.scheduler_generator, self.scheduler_guidance = (
|
365 |
+
self.accelerator.prepare(
|
366 |
+
train_dataloader, self.scheduler_generator, self.scheduler_guidance
|
367 |
+
)
|
368 |
) # actual steps = 1 gpu steps / gpus
|
369 |
start_step = self.load_checkpoint()
|
370 |
global_step = start_step
|
|
|
373 |
orig_epoch_step = len(train_dataloader)
|
374 |
skipped_epoch = int(start_step // orig_epoch_step)
|
375 |
skipped_batch = start_step % orig_epoch_step
|
376 |
+
skipped_dataloader = self.accelerator.skip_first_batches(
|
377 |
+
train_dataloader, num_batches=skipped_batch
|
378 |
+
)
|
379 |
else:
|
380 |
skipped_epoch = 0
|
381 |
|
|
|
400 |
|
401 |
for batch in progress_bar:
|
402 |
update_generator = global_step % self.gen_update_ratio == 0
|
403 |
+
|
404 |
with self.accelerator.accumulate(self.model):
|
405 |
metrics = {}
|
406 |
text_inputs = batch["text"]
|
407 |
mel_spec = batch["mel"].permute(0, 2, 1)
|
408 |
mel_lengths = batch["mel_lengths"]
|
409 |
+
|
410 |
mel_spec = mel_spec / self.scale
|
411 |
+
|
412 |
+
guidance_loss_dict, guidance_log_dict = self.model(
|
413 |
+
inp=mel_spec,
|
414 |
+
text=text_inputs,
|
415 |
+
lens=mel_lengths,
|
416 |
+
student_steps=self.student_steps,
|
417 |
+
update_generator=False,
|
418 |
+
use_simulated=global_step >= self.num_simu,
|
419 |
+
)
|
420 |
|
421 |
# if self.GAN and update_generator:
|
422 |
# # only add discriminator loss if GAN is enabled and generator is being updated
|
423 |
# guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (self.lambda_discriminator_loss if global_step >= self.num_GAN and update_generator else 0)
|
424 |
# metrics['loss/discriminator_loss'] = guidance_loss_dict["guidance_cls_loss"]
|
425 |
# self.accelerator.backward(guidance_cls_loss, retain_graph=True)
|
426 |
+
|
427 |
# if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
428 |
# metrics['grad_norm_guidance'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
429 |
|
430 |
guidance_loss = 0
|
431 |
guidance_loss += guidance_loss_dict["loss_fake_mean"]
|
432 |
+
metrics["loss/fake_score"] = guidance_loss_dict["loss_fake_mean"]
|
433 |
metrics["loss/guidance_loss"] = guidance_loss
|
434 |
|
435 |
if self.GAN and update_generator:
|
436 |
# only add discriminator loss if GAN is enabled and generator is being updated
|
437 |
+
guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (
|
438 |
+
self.lambda_discriminator_loss
|
439 |
+
if global_step >= self.num_GAN and update_generator
|
440 |
+
else 0
|
441 |
+
)
|
442 |
+
metrics["loss/discriminator_loss"] = guidance_loss_dict[
|
443 |
+
"guidance_cls_loss"
|
444 |
+
]
|
445 |
|
446 |
guidance_loss += guidance_cls_loss
|
447 |
+
|
448 |
self.accelerator.backward(guidance_loss)
|
449 |
|
450 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
451 |
+
metrics["grad_norm_guidance"] = (
|
452 |
+
self.accelerator.clip_grad_norm_(
|
453 |
+
self.model.parameters(), self.max_grad_norm
|
454 |
+
)
|
455 |
+
)
|
456 |
|
457 |
# if self.guidance_norm.count < 100:
|
458 |
# self.guidance_norm.update(metrics['grad_norm_guidance'])
|
|
|
464 |
# elif self.guidance_norm.count >= 100:
|
465 |
# self.guidance_norm.update(metrics['grad_norm_guidance'])
|
466 |
|
|
|
467 |
self.optimizer_guidance.step()
|
468 |
self.scheduler_guidance.step()
|
469 |
self.optimizer_guidance.zero_grad()
|
470 |
self.optimizer_generator.zero_grad() # zero out the generator's gradient as well
|
471 |
+
|
472 |
if update_generator:
|
473 |
+
generator_loss_dict, generator_log_dict = self.model(
|
474 |
+
inp=mel_spec,
|
475 |
+
text=text_inputs,
|
476 |
+
lens=mel_lengths,
|
477 |
+
student_steps=self.student_steps,
|
478 |
+
update_generator=True,
|
479 |
+
use_simulated=global_step >= self.num_ctc,
|
480 |
+
)
|
481 |
# if self.GAN:
|
482 |
# gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) and update_generator else 0)
|
483 |
# metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"]
|
|
|
490 |
generator_loss = 0
|
491 |
generator_loss += generator_loss_dict["loss_dm"]
|
492 |
if "loss_mse" in generator_loss_dict:
|
493 |
+
generator_loss += generator_loss_dict["loss_mse"]
|
494 |
+
generator_loss += generator_loss_dict["loss_ctc"] * (
|
495 |
+
self.lambda_ctc_loss if global_step >= self.num_ctc else 0
|
496 |
+
)
|
497 |
+
generator_loss += generator_loss_dict["loss_sim"] * (
|
498 |
+
self.lambda_sim_loss if global_step >= self.num_sim else 0
|
499 |
+
)
|
500 |
+
generator_loss += generator_loss_dict["loss_kl"] * (
|
501 |
+
self.lambda_ctc_loss if global_step >= self.num_ctc else 0
|
502 |
+
)
|
503 |
if self.GAN:
|
504 |
+
gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (
|
505 |
+
self.lambda_generator_loss
|
506 |
+
if global_step >= (self.num_GAN + self.num_D)
|
507 |
+
and update_generator
|
508 |
+
else 0
|
509 |
+
)
|
510 |
+
metrics["loss/gen_cls_loss"] = generator_loss_dict[
|
511 |
+
"gen_cls_loss"
|
512 |
+
]
|
513 |
generator_loss += gen_cls_loss
|
514 |
|
515 |
+
metrics["loss/dm_loss"] = generator_loss_dict["loss_dm"]
|
516 |
+
metrics["loss/ctc_loss"] = generator_loss_dict["loss_ctc"]
|
517 |
+
|
518 |
+
metrics["loss/similarity_loss"] = generator_loss_dict[
|
519 |
+
"loss_sim"
|
520 |
+
]
|
521 |
+
metrics["loss/generator_loss"] = generator_loss
|
522 |
+
|
523 |
+
if (
|
524 |
+
"loss_mse" in generator_loss_dict
|
525 |
+
and generator_loss_dict["loss_mse"] != 0
|
526 |
+
):
|
527 |
+
metrics["loss/mse_loss"] = generator_loss_dict["loss_mse"]
|
528 |
+
if (
|
529 |
+
"loss_kl" in generator_loss_dict
|
530 |
+
and generator_loss_dict["loss_kl"] != 0
|
531 |
+
):
|
532 |
+
metrics["loss/kl_loss"] = generator_loss_dict["loss_kl"]
|
533 |
|
534 |
self.accelerator.backward(generator_loss)
|
535 |
|
536 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
537 |
+
metrics["grad_norm_generator"] = (
|
538 |
+
self.accelerator.clip_grad_norm_(
|
539 |
+
self.model.parameters(), self.max_grad_norm
|
540 |
+
)
|
541 |
+
)
|
542 |
# self.generator_norm.update(metrics['grad_norm_generator'])
|
543 |
+
|
544 |
# if metrics['grad_norm_generator'] > self.generator_norm.mean + 15 * self.generator_norm.std:
|
545 |
# self.optimizer_generator.zero_grad()
|
546 |
# self.optimizer_guidance.zero_grad()
|
|
|
553 |
self.optimizer_generator.zero_grad()
|
554 |
self.optimizer_guidance.zero_grad() # zero out the guidance's gradient as well
|
555 |
|
|
|
556 |
global_step += 1
|
557 |
|
558 |
if self.accelerator.is_local_main_process:
|
559 |
+
self.accelerator.log(
|
560 |
+
{
|
561 |
+
**metrics,
|
562 |
+
"lr_generator": self.scheduler_generator.get_last_lr()[0],
|
563 |
+
"lr_guidance": self.scheduler_guidance.get_last_lr()[0],
|
564 |
+
},
|
565 |
+
step=global_step,
|
566 |
+
)
|
567 |
+
|
568 |
+
if (
|
569 |
+
global_step % self.log_step == 0
|
570 |
+
and self.accelerator.is_local_main_process
|
571 |
+
and vocoder is not None
|
572 |
+
):
|
573 |
# log the first batch of the epoch
|
574 |
with torch.no_grad():
|
575 |
+
generator_input = (
|
576 |
+
generator_log_dict["generator_input"][0]
|
577 |
+
.unsqueeze(0)
|
578 |
+
.permute(0, 2, 1)
|
579 |
+
* self.scale
|
580 |
+
)
|
581 |
generator_input = vocoder.decode(generator_input.float().cpu())
|
582 |
generator_input = wandb.Audio(
|
583 |
generator_input.float().numpy().squeeze(),
|
584 |
sample_rate=24000,
|
585 |
+
caption="time: "
|
586 |
+
+ str(generator_log_dict["time"][0].float().cpu().numpy()),
|
587 |
)
|
588 |
|
589 |
+
generator_output = (
|
590 |
+
generator_log_dict["generator_output"][0]
|
591 |
+
.unsqueeze(0)
|
592 |
+
.permute(0, 2, 1)
|
593 |
+
* self.scale
|
594 |
+
)
|
595 |
+
generator_output = vocoder.decode(
|
596 |
+
generator_output.float().cpu()
|
597 |
+
)
|
598 |
generator_output = wandb.Audio(
|
599 |
generator_output.float().numpy().squeeze(),
|
600 |
sample_rate=24000,
|
601 |
+
caption="time: "
|
602 |
+
+ str(generator_log_dict["time"][0].float().cpu().numpy()),
|
603 |
+
)
|
604 |
+
|
605 |
+
generator_cond = (
|
606 |
+
generator_log_dict["generator_cond"][0]
|
607 |
+
.unsqueeze(0)
|
608 |
+
.permute(0, 2, 1)
|
609 |
+
* self.scale
|
610 |
)
|
|
|
|
|
611 |
generator_cond = vocoder.decode(generator_cond.float().cpu())
|
612 |
generator_cond = wandb.Audio(
|
613 |
generator_cond.float().numpy().squeeze(),
|
614 |
sample_rate=24000,
|
615 |
+
caption="time: "
|
616 |
+
+ str(generator_log_dict["time"][0].float().cpu().numpy()),
|
617 |
+
)
|
618 |
+
|
619 |
+
ground_truth = (
|
620 |
+
generator_log_dict["ground_truth"][0]
|
621 |
+
.unsqueeze(0)
|
622 |
+
.permute(0, 2, 1)
|
623 |
+
* self.scale
|
624 |
)
|
|
|
|
|
625 |
ground_truth = vocoder.decode(ground_truth.float().cpu())
|
626 |
ground_truth = wandb.Audio(
|
627 |
ground_truth.float().numpy().squeeze(),
|
628 |
sample_rate=24000,
|
629 |
+
caption="time: "
|
630 |
+
+ str(generator_log_dict["time"][0].float().cpu().numpy()),
|
631 |
+
)
|
632 |
+
|
633 |
+
dmtrain_noisy_inp = (
|
634 |
+
generator_log_dict["dmtrain_noisy_inp"][0]
|
635 |
+
.unsqueeze(0)
|
636 |
+
.permute(0, 2, 1)
|
637 |
+
* self.scale
|
638 |
+
)
|
639 |
+
dmtrain_noisy_inp = vocoder.decode(
|
640 |
+
dmtrain_noisy_inp.float().cpu()
|
641 |
)
|
|
|
|
|
|
|
642 |
dmtrain_noisy_inp = wandb.Audio(
|
643 |
dmtrain_noisy_inp.float().numpy().squeeze(),
|
644 |
sample_rate=24000,
|
645 |
+
caption="dmtrain_time: "
|
646 |
+
+ str(
|
647 |
+
generator_log_dict["dmtrain_time"][0]
|
648 |
+
.float()
|
649 |
+
.cpu()
|
650 |
+
.numpy()
|
651 |
+
),
|
652 |
+
)
|
653 |
+
|
654 |
+
dmtrain_pred_real_image = (
|
655 |
+
generator_log_dict["dmtrain_pred_real_image"][0]
|
656 |
+
.unsqueeze(0)
|
657 |
+
.permute(0, 2, 1)
|
658 |
+
* self.scale
|
659 |
+
)
|
660 |
+
dmtrain_pred_real_image = vocoder.decode(
|
661 |
+
dmtrain_pred_real_image.float().cpu()
|
662 |
)
|
|
|
|
|
|
|
663 |
dmtrain_pred_real_image = wandb.Audio(
|
664 |
dmtrain_pred_real_image.float().numpy().squeeze(),
|
665 |
sample_rate=24000,
|
666 |
+
caption="dmtrain_time: "
|
667 |
+
+ str(
|
668 |
+
generator_log_dict["dmtrain_time"][0]
|
669 |
+
.float()
|
670 |
+
.cpu()
|
671 |
+
.numpy()
|
672 |
+
),
|
673 |
+
)
|
674 |
+
|
675 |
+
dmtrain_pred_fake_image = (
|
676 |
+
generator_log_dict["dmtrain_pred_fake_image"][0]
|
677 |
+
.unsqueeze(0)
|
678 |
+
.permute(0, 2, 1)
|
679 |
+
* self.scale
|
680 |
+
)
|
681 |
+
dmtrain_pred_fake_image = vocoder.decode(
|
682 |
+
dmtrain_pred_fake_image.float().cpu()
|
683 |
)
|
|
|
|
|
|
|
684 |
dmtrain_pred_fake_image = wandb.Audio(
|
685 |
dmtrain_pred_fake_image.float().numpy().squeeze(),
|
686 |
sample_rate=24000,
|
687 |
+
caption="dmtrain_time: "
|
688 |
+
+ str(
|
689 |
+
generator_log_dict["dmtrain_time"][0]
|
690 |
+
.float()
|
691 |
+
.cpu()
|
692 |
+
.numpy()
|
693 |
+
),
|
694 |
+
)
|
695 |
+
|
696 |
+
self.accelerator.log(
|
697 |
+
{
|
698 |
+
"noisy_input": generator_input,
|
699 |
+
"output": generator_output,
|
700 |
+
"cond": generator_cond,
|
701 |
+
"ground_truth": ground_truth,
|
702 |
+
"dmtrain_noisy_inp": dmtrain_noisy_inp,
|
703 |
+
"dmtrain_pred_real_image": dmtrain_pred_real_image,
|
704 |
+
"dmtrain_pred_fake_image": dmtrain_pred_fake_image,
|
705 |
+
},
|
706 |
+
step=global_step,
|
707 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
|
709 |
progress_bar.set_postfix(step=str(global_step), metrics=metrics)
|
710 |
|
711 |
+
if (
|
712 |
+
global_step % (self.save_per_updates * self.grad_accumulation_steps)
|
713 |
+
== 0
|
714 |
+
):
|
715 |
self.save_checkpoint(global_step)
|
716 |
|
717 |
if global_step % self.last_per_steps == 0:
|
|
|
720 |
self.save_checkpoint(global_step, last=True)
|
721 |
|
722 |
self.accelerator.end_training()
|
|
|
|
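As a reading aid for the step-gated loss schedule above, here is a minimal standalone sketch (not part of the repository; function name and defaults are illustrative) of how the generator losses are combined, with the CTC, similarity, KL, and adversarial terms switched on only after their respective step thresholds.

import torch


def combine_generator_losses(loss_dict, global_step, *,
                             lambda_ctc=1.0, lambda_sim=1.0, lambda_gen=1.0,
                             num_ctc=5000, num_sim=5000, num_GAN=5000, num_D=500,
                             use_gan=True):
    # Distribution-matching loss is always on.
    total = loss_dict["loss_dm"]
    # Optional MSE term, if the guidance model returned one.
    total = total + loss_dict.get("loss_mse", torch.tensor(0.0))
    # Auxiliary terms are gated by step thresholds, mirroring the Trainer above.
    total = total + loss_dict["loss_ctc"] * (lambda_ctc if global_step >= num_ctc else 0)
    total = total + loss_dict["loss_sim"] * (lambda_sim if global_step >= num_sim else 0)
    total = total + loss_dict["loss_kl"] * (lambda_ctc if global_step >= num_ctc else 0)
    if use_gan:
        # Adversarial loss joins only after GAN warm-up plus discriminator pretraining.
        total = total + loss_dict["gen_cls_loss"] * (
            lambda_gen if global_step >= (num_GAN + num_D) else 0
        )
    return total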
duration_predictor.py
CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn

# from tts_encode import tts_encode

+
def calculate_remaining_lengths(mel_lengths):
    B = mel_lengths.shape[0]
    max_L = mel_lengths.max().item()  # Get the maximum length in the batch
@@ -21,64 +22,84 @@ class PositionalEncoding(nn.Module):
        super().__init__()
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, hidden_dim, 2).float()
+            * (-torch.log(torch.tensor(10000.0)) / hidden_dim)
+        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # Shape: (1, max_len, hidden_dim)

    def forward(self, x):
+        x = x + self.pe[:, : x.size(1)].to(x.device)
        return x


class SpeechLengthPredictor(nn.Module):

+    def __init__(
+        self,
+        vocab_size=2545,
+        n_mel=100,
+        hidden_dim=256,
+        n_text_layer=4,
+        n_cross_layer=4,
+        n_head=8,
        output_dim=1,
    ):
        super().__init__()

        # Text Encoder: Embedding + Transformer Layers
+        self.text_embedder = nn.Embedding(
+            vocab_size + 1, hidden_dim, padding_idx=vocab_size
+        )
        self.text_pe = PositionalEncoding(hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=hidden_dim,
+            nhead=n_head,
+            dim_feedforward=hidden_dim * 2,
+            batch_first=True,
+        )
+        self.text_encoder = nn.TransformerEncoder(
+            encoder_layer, num_layers=n_text_layer
        )

        # Mel Spectrogram Embedder
        self.mel_embedder = nn.Linear(n_mel, hidden_dim)
        self.mel_pe = PositionalEncoding(hidden_dim)

        # Transformer Decoder Layers with Cross-Attention in Every Layer
        decoder_layer = nn.TransformerDecoderLayer(
+            d_model=hidden_dim,
+            nhead=n_head,
+            dim_feedforward=hidden_dim * 2,
+            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer)

        # Final Classification Layer
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def forward(self, text_ids, mel):
        # Encode text
        text_embedded = self.text_pe(self.text_embedder(text_ids))
+        text_features = self.text_encoder(text_embedded)  # (B, L_text, D)

        # Encode Mel spectrogram
        mel_features = self.mel_pe(self.mel_embedder(mel))  # (B, L_mel, D)

        # Causal Masking for Decoder
        seq_len = mel_features.size(1)
+        causal_mask = (
+            torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(mel.device)
+        )
        # causal_mask = torch.triu(
        #     torch.full((seq_len, seq_len), float('-inf'), device=mel.device), diagonal=1
        # )

        # Transformer Decoder with Cross-Attention in Each Layer
        decoder_out = self.decoder(mel_features, text_features, tgt_mask=causal_mask)

        # Length Prediction
        length_logits = self.predictor(decoder_out).squeeze(-1)
        return length_logits
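A hypothetical usage sketch for the class above (shapes and hyperparameters taken from the constructor defaults, not from the repository's own training scripts): it shows that the model maps padded text token ids and a mel spectrogram to one remaining-length logit per mel frame.

import torch

from duration_predictor import SpeechLengthPredictor

model = SpeechLengthPredictor(vocab_size=2545, n_mel=100, hidden_dim=256)
model.eval()

text_ids = torch.randint(0, 2545, (2, 32))  # (B, L_text) token ids
mel = torch.randn(2, 200, 100)              # (B, L_mel, n_mel) mel frames

with torch.no_grad():
    length_logits = model(text_ids=text_ids, mel=mel)

print(length_logits.shape)  # torch.Size([2, 200]): one estimate per mel frame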
duration_trainer.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
import gc
|
4 |
-
import os
|
5 |
-
|
6 |
import math
|
|
|
7 |
|
8 |
import torch
|
|
|
9 |
import torchaudio
|
10 |
import wandb
|
11 |
from accelerate import Accelerator
|
@@ -13,37 +13,28 @@ from accelerate.utils import DistributedDataParallelKwargs
|
|
13 |
from ema_pytorch import EMA
|
14 |
from torch.optim import AdamW
|
15 |
from torch.optim.lr_scheduler import LinearLR, SequentialLR
|
16 |
-
from torch.utils.data import
|
|
|
17 |
from tqdm import tqdm
|
18 |
|
19 |
-
import torch.nn.functional as F
|
20 |
-
|
21 |
-
from f5_tts.model import CFM
|
22 |
-
from f5_tts.model.dataset import collate_fn, DynamicBatchSampler
|
23 |
-
from f5_tts.model.utils import default, exists
|
24 |
-
|
25 |
from duration_predictor import calculate_remaining_lengths
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# trainer
|
28 |
|
29 |
-
from f5_tts.model.utils import (
|
30 |
-
default,
|
31 |
-
exists,
|
32 |
-
list_str_to_idx,
|
33 |
-
list_str_to_tensor,
|
34 |
-
lens_to_mask,
|
35 |
-
mask_from_frac_lengths,
|
36 |
-
)
|
37 |
|
38 |
SAMPLE_RATE = 24_000
|
39 |
|
40 |
|
41 |
def masked_l1_loss(est_lengths, tar_lengths):
|
42 |
-
first_zero_idx = (tar_lengths == 0).int().argmax(dim=1)
|
43 |
B, L = tar_lengths.shape
|
44 |
-
range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L)
|
45 |
mask = range_tensor <= first_zero_idx[:, None] # Include the first 0
|
46 |
-
loss = F.l1_loss(est_lengths, tar_lengths, reduction=
|
47 |
loss = loss * mask # Zero out ignored positions
|
48 |
loss = loss.sum() / mask.sum() # Normalize by valid elements
|
49 |
return loss
|
@@ -55,9 +46,9 @@ def masked_cross_entropy_loss(est_length_logits, tar_length_labels):
|
|
55 |
range_tensor = torch.arange(L, device=tar_length_labels.device).expand(B, L)
|
56 |
mask = range_tensor <= first_zero_idx[:, None] # Include the first 0
|
57 |
loss = F.cross_entropy(
|
58 |
-
est_length_logits.reshape(-1, est_length_logits.size(-1)),
|
59 |
-
tar_length_labels.reshape(-1),
|
60 |
-
reduction=
|
61 |
).reshape(B, L)
|
62 |
loss = loss * mask
|
63 |
loss = loss.sum() / mask.sum()
|
@@ -71,7 +62,7 @@ class Trainer:
|
|
71 |
vocab_size,
|
72 |
vocab_char_map,
|
73 |
process_token_to_id=True,
|
74 |
-
loss_fn=
|
75 |
lambda_L1=1,
|
76 |
gumbel_tau=0.5,
|
77 |
n_class=301,
|
@@ -110,7 +101,13 @@ class Trainer:
|
|
110 |
self.logger = logger
|
111 |
if self.logger == "wandb":
|
112 |
if exists(wandb_resume_id):
|
113 |
-
init_kwargs = {
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
else:
|
115 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
116 |
|
@@ -139,7 +136,7 @@ class Trainer:
|
|
139 |
self.vocab_size = vocab_size
|
140 |
self.vocab_char_map = vocab_char_map
|
141 |
self.process_token_to_id = process_token_to_id
|
142 |
-
assert loss_fn in [
|
143 |
self.loss_fn = loss_fn
|
144 |
self.lambda_L1 = lambda_L1
|
145 |
self.n_class = n_class
|
@@ -149,7 +146,9 @@ class Trainer:
|
|
149 |
self.epochs = epochs
|
150 |
self.num_warmup_updates = num_warmup_updates
|
151 |
self.save_per_updates = save_per_updates
|
152 |
-
self.last_per_steps = default(
|
|
|
|
|
153 |
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
154 |
|
155 |
self.batch_size = batch_size
|
@@ -164,33 +163,44 @@ class Trainer:
|
|
164 |
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
165 |
else:
|
166 |
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
167 |
-
self.model, self.optimizer = self.accelerator.prepare(
|
168 |
-
|
|
|
|
|
169 |
@property
|
170 |
def is_main(self):
|
171 |
return self.accelerator.is_main_process
|
172 |
|
173 |
def save_checkpoint(self, step, last=False):
|
174 |
self.accelerator.wait_for_everyone()
|
175 |
-
if self.is_main:
|
176 |
checkpoint = dict(
|
177 |
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
178 |
-
optimizer_state_dict=self.accelerator.unwrap_model(
|
|
|
|
|
179 |
scheduler_state_dict=self.scheduler.state_dict(),
|
180 |
step=step,
|
181 |
)
|
182 |
if not os.path.exists(self.checkpoint_path):
|
183 |
os.makedirs(self.checkpoint_path)
|
184 |
if last:
|
185 |
-
self.accelerator.save(
|
|
|
|
|
186 |
else:
|
187 |
-
self.accelerator.save(
|
|
|
|
|
188 |
|
189 |
def load_checkpoint(self):
|
190 |
if (
|
191 |
not exists(self.checkpoint_path)
|
192 |
or not os.path.exists(self.checkpoint_path)
|
193 |
-
or not any(
|
|
|
|
|
|
|
194 |
):
|
195 |
return 0
|
196 |
|
@@ -203,21 +213,32 @@ class Trainer:
|
|
203 |
key=lambda x: int("".join(filter(str.isdigit, x))),
|
204 |
)[-1]
|
205 |
|
206 |
-
print(f
|
207 |
|
208 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
209 |
-
checkpoint = torch.load(
|
|
|
|
|
|
|
|
|
210 |
|
211 |
-
print(f
|
212 |
|
213 |
if "step" in checkpoint:
|
214 |
# patch for backward compatibility, 305e3ea
|
215 |
-
for key in [
|
|
|
|
|
|
|
216 |
if key in checkpoint["model_state_dict"]:
|
217 |
del checkpoint["model_state_dict"][key]
|
218 |
|
219 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(
|
220 |
-
|
|
|
|
|
|
|
|
|
221 |
if self.scheduler:
|
222 |
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
223 |
step = checkpoint["step"]
|
@@ -227,17 +248,18 @@ class Trainer:
|
|
227 |
for k, v in checkpoint["ema_model_state_dict"].items()
|
228 |
if k not in ["initted", "step"]
|
229 |
}
|
230 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(
|
|
|
|
|
231 |
step = 0
|
232 |
-
|
233 |
del checkpoint
|
234 |
gc.collect()
|
235 |
|
236 |
-
print(f
|
237 |
|
238 |
return step
|
239 |
|
240 |
-
|
241 |
def validate(self, valid_dataloader, global_step):
|
242 |
"""
|
243 |
Runs evaluation on the validation set, computes the average loss,
|
@@ -251,54 +273,61 @@ class Trainer:
|
|
251 |
with torch.no_grad():
|
252 |
for batch in valid_dataloader:
|
253 |
# Inputs
|
254 |
-
mel = batch[
|
255 |
-
text = batch[
|
256 |
|
257 |
if self.process_token_to_id:
|
258 |
text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device)
|
259 |
-
text_ids = text_ids.masked_fill(text_ids
|
260 |
else:
|
261 |
text_ids = text
|
262 |
|
263 |
# Targets
|
264 |
-
mel_lengths = batch[
|
265 |
tar_lengths = calculate_remaining_lengths(mel_lengths)
|
266 |
predictions = self.model(text_ids=text_ids, mel=mel)
|
267 |
|
268 |
-
if self.loss_fn ==
|
269 |
est_lengths = predictions
|
270 |
loss = masked_l1_loss(
|
271 |
est_lengths=est_lengths, tar_lengths=tar_lengths
|
272 |
)
|
273 |
frame_error = loss
|
274 |
|
275 |
-
elif self.loss_fn ==
|
276 |
-
tar_length_labels = (tar_lengths // self.n_frame_per_class)
|
277 |
-
|
|
|
278 |
est_length_logtis = predictions
|
279 |
est_length_labels = torch.argmax(est_length_logtis, dim=-1)
|
280 |
loss = masked_cross_entropy_loss(
|
281 |
-
est_length_logits=est_length_logtis,
|
|
|
282 |
)
|
283 |
est_lengths = est_length_labels * self.n_frame_per_class
|
284 |
frame_error = masked_l1_loss(
|
285 |
est_lengths=est_lengths, tar_lengths=tar_lengths
|
286 |
)
|
287 |
|
288 |
-
elif self.loss_fn ==
|
289 |
-
tar_length_labels = (tar_lengths // self.n_frame_per_class)
|
290 |
-
|
|
|
291 |
est_length_logtis = predictions
|
292 |
est_length_1hots = F.gumbel_softmax(
|
293 |
est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1
|
294 |
)
|
295 |
-
length_values =
|
296 |
-
|
297 |
-
|
|
|
|
|
|
|
298 |
est_lengths = (est_length_1hots * length_values).sum(-1)
|
299 |
|
300 |
loss_CE = masked_cross_entropy_loss(
|
301 |
-
est_length_logits=est_length_logtis,
|
|
|
302 |
)
|
303 |
|
304 |
loss_L1 = masked_l1_loss(
|
@@ -321,18 +350,19 @@ class Trainer:
|
|
321 |
avg_valid_sec_error = total_sec_error / count if count > 0 else 0.0
|
322 |
# Log validation metrics
|
323 |
self.accelerator.log(
|
324 |
-
{
|
325 |
-
|
326 |
-
f"valid_sec_error": avg_valid_sec_error
|
327 |
-
},
|
328 |
-
step=global_step
|
329 |
)
|
330 |
-
|
331 |
-
self.model.train()
|
332 |
|
|
|
333 |
|
334 |
-
def train(
|
335 |
-
|
|
|
|
|
|
|
|
|
|
|
336 |
if exists(resumable_with_seed):
|
337 |
generator = torch.Generator()
|
338 |
generator.manual_seed(resumable_with_seed)
|
@@ -366,7 +396,11 @@ class Trainer:
|
|
366 |
|
367 |
sampler = SequentialSampler(train_dataset)
|
368 |
batch_sampler = DynamicBatchSampler(
|
369 |
-
sampler,
|
|
|
|
|
|
|
|
|
370 |
)
|
371 |
train_dataloader = DataLoader(
|
372 |
train_dataset,
|
@@ -379,20 +413,26 @@ class Trainer:
|
|
379 |
|
380 |
sampler = SequentialSampler(valid_dataset)
|
381 |
batch_sampler = DynamicBatchSampler(
|
382 |
-
sampler,
|
|
|
|
|
|
|
|
|
383 |
)
|
384 |
# Create validation dataloader (always sequential, no shuffling)
|
385 |
valid_dataloader = DataLoader(
|
386 |
valid_dataset,
|
387 |
collate_fn=collate_fn,
|
388 |
num_workers=num_workers,
|
389 |
-
pin_memory=True,
|
390 |
persistent_workers=True,
|
391 |
batch_sampler=batch_sampler,
|
392 |
)
|
393 |
else:
|
394 |
-
raise ValueError(
|
395 |
-
|
|
|
|
|
396 |
# accelerator.prepare() dispatches batches to devices;
|
397 |
# which means the length of dataloader calculated before, should consider the number of devices
|
398 |
warmup_steps = (
|
@@ -401,10 +441,16 @@ class Trainer:
|
|
401 |
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
402 |
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
403 |
decay_steps = total_steps - warmup_steps
|
404 |
-
warmup_scheduler = LinearLR(
|
405 |
-
|
|
|
|
|
|
|
|
|
406 |
self.scheduler = SequentialLR(
|
407 |
-
self.optimizer,
|
|
|
|
|
408 |
)
|
409 |
train_dataloader, self.scheduler = self.accelerator.prepare(
|
410 |
train_dataloader, self.scheduler
|
@@ -418,7 +464,9 @@ class Trainer:
|
|
418 |
orig_epoch_step = len(train_dataloader)
|
419 |
skipped_epoch = int(start_step // orig_epoch_step)
|
420 |
skipped_batch = start_step % orig_epoch_step
|
421 |
-
skipped_dataloader = self.accelerator.skip_first_batches(
|
|
|
|
|
422 |
else:
|
423 |
skipped_epoch = 0
|
424 |
|
@@ -444,21 +492,23 @@ class Trainer:
|
|
444 |
for batch in progress_bar:
|
445 |
with self.accelerator.accumulate(self.model):
|
446 |
# Inputs
|
447 |
-
mel = batch[
|
448 |
-
text = batch[
|
449 |
|
450 |
if self.process_token_to_id:
|
451 |
-
text_ids = list_str_to_idx(text, self.vocab_char_map).to(
|
452 |
-
|
|
|
|
|
453 |
else:
|
454 |
text_ids = text
|
455 |
|
456 |
# Targets
|
457 |
-
mel_lengths = batch[
|
458 |
tar_lengths = calculate_remaining_lengths(mel_lengths)
|
459 |
predictions = self.model(text_ids=text_ids, mel=mel)
|
460 |
|
461 |
-
if self.loss_fn ==
|
462 |
est_lengths = predictions
|
463 |
loss = masked_l1_loss(
|
464 |
est_lengths=est_lengths, tar_lengths=tar_lengths
|
@@ -469,19 +519,23 @@ class Trainer:
|
|
469 |
sec_error = frame_error * 256 / 24000
|
470 |
|
471 |
log_dict = {
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
elif self.loss_fn ==
|
479 |
-
tar_length_labels = (
|
480 |
-
|
|
|
|
|
|
|
481 |
est_length_logtis = predictions
|
482 |
est_length_labels = torch.argmax(est_length_logtis, dim=-1)
|
483 |
loss = masked_cross_entropy_loss(
|
484 |
-
est_length_logits=est_length_logtis,
|
|
|
485 |
)
|
486 |
with torch.no_grad():
|
487 |
est_lengths = est_length_labels * self.n_frame_per_class
|
@@ -491,29 +545,36 @@ class Trainer:
|
|
491 |
sec_error = frame_error * 256 / 24000
|
492 |
|
493 |
log_dict = {
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
elif self.loss_fn ==
|
501 |
-
tar_length_labels = (
|
502 |
-
|
|
|
|
|
|
|
503 |
est_length_logtis = predictions
|
504 |
est_length_1hots = F.gumbel_softmax(
|
505 |
est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1
|
506 |
)
|
507 |
-
length_values =
|
508 |
-
|
509 |
-
|
|
|
|
|
|
|
510 |
est_lengths = (est_length_1hots * length_values).sum(-1)
|
511 |
|
512 |
loss_CE = masked_cross_entropy_loss(
|
513 |
-
est_length_logits=est_length_logtis,
|
|
|
514 |
)
|
515 |
|
516 |
-
loss_L1 =
|
517 |
est_lengths=est_lengths, tar_lengths=tar_lengths
|
518 |
)
|
519 |
|
@@ -524,21 +585,22 @@ class Trainer:
|
|
524 |
sec_error = frame_error * 256 / 24000
|
525 |
|
526 |
log_dict = {
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
}
|
533 |
|
534 |
else:
|
535 |
raise NotImplementedError(self.loss_fn)
|
536 |
|
537 |
-
|
538 |
self.accelerator.backward(loss)
|
539 |
|
540 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
541 |
-
self.accelerator.clip_grad_norm_(
|
|
|
|
|
542 |
|
543 |
self.optimizer.step()
|
544 |
self.scheduler.step()
|
@@ -550,7 +612,10 @@ class Trainer:
|
|
550 |
self.accelerator.log(log_dict, step=global_step)
|
551 |
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
|
552 |
|
553 |
-
if
|
|
|
|
|
|
|
554 |
self.save_checkpoint(global_step)
|
555 |
# if self.log_samples and self.accelerator.is_local_main_process:
|
556 |
# Run validation at the end of each epoch (only on the main process)
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
import gc
|
|
|
|
|
4 |
import math
|
5 |
+
import os
|
6 |
|
7 |
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
import torchaudio
|
10 |
import wandb
|
11 |
from accelerate import Accelerator
|
|
|
13 |
from ema_pytorch import EMA
|
14 |
from torch.optim import AdamW
|
15 |
from torch.optim.lr_scheduler import LinearLR, SequentialLR
|
16 |
+
from torch.utils.data import Dataset # <-- Added Subset import
|
17 |
+
from torch.utils.data import DataLoader, SequentialSampler, Subset
|
18 |
from tqdm import tqdm
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
from duration_predictor import calculate_remaining_lengths
|
21 |
+
from f5_tts.model import CFM
|
22 |
+
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
23 |
+
from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
|
24 |
+
list_str_to_tensor, mask_from_frac_lengths)
|
25 |
|
26 |
# trainer
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
SAMPLE_RATE = 24_000
|
30 |
|
31 |
|
32 |
def masked_l1_loss(est_lengths, tar_lengths):
|
33 |
+
first_zero_idx = (tar_lengths == 0).int().argmax(dim=1)
|
34 |
B, L = tar_lengths.shape
|
35 |
+
range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L)
|
36 |
mask = range_tensor <= first_zero_idx[:, None] # Include the first 0
|
37 |
+
loss = F.l1_loss(est_lengths, tar_lengths, reduction="none") # (B, L)
|
38 |
loss = loss * mask # Zero out ignored positions
|
39 |
loss = loss.sum() / mask.sum() # Normalize by valid elements
|
40 |
return loss
|
|
|
46 |
range_tensor = torch.arange(L, device=tar_length_labels.device).expand(B, L)
|
47 |
mask = range_tensor <= first_zero_idx[:, None] # Include the first 0
|
48 |
loss = F.cross_entropy(
|
49 |
+
est_length_logits.reshape(-1, est_length_logits.size(-1)),
|
50 |
+
tar_length_labels.reshape(-1),
|
51 |
+
reduction="none",
|
52 |
).reshape(B, L)
|
53 |
loss = loss * mask
|
54 |
loss = loss.sum() / mask.sum()
|
|
|
62 |
vocab_size,
|
63 |
vocab_char_map,
|
64 |
process_token_to_id=True,
|
65 |
+
loss_fn="L1",
|
66 |
lambda_L1=1,
|
67 |
gumbel_tau=0.5,
|
68 |
n_class=301,
|
|
|
101 |
self.logger = logger
|
102 |
if self.logger == "wandb":
|
103 |
if exists(wandb_resume_id):
|
104 |
+
init_kwargs = {
|
105 |
+
"wandb": {
|
106 |
+
"resume": "allow",
|
107 |
+
"name": wandb_run_name,
|
108 |
+
"id": wandb_resume_id,
|
109 |
+
}
|
110 |
+
}
|
111 |
else:
|
112 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
113 |
|
|
|
136 |
self.vocab_size = vocab_size
|
137 |
self.vocab_char_map = vocab_char_map
|
138 |
self.process_token_to_id = process_token_to_id
|
139 |
+
assert loss_fn in ["L1", "CE", "L1_and_CE"]
|
140 |
self.loss_fn = loss_fn
|
141 |
self.lambda_L1 = lambda_L1
|
142 |
self.n_class = n_class
|
|
|
146 |
self.epochs = epochs
|
147 |
self.num_warmup_updates = num_warmup_updates
|
148 |
self.save_per_updates = save_per_updates
|
149 |
+
self.last_per_steps = default(
|
150 |
+
last_per_steps, save_per_updates * grad_accumulation_steps
|
151 |
+
)
|
152 |
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
153 |
|
154 |
self.batch_size = batch_size
|
|
|
163 |
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
164 |
else:
|
165 |
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
166 |
+
self.model, self.optimizer = self.accelerator.prepare(
|
167 |
+
self.model, self.optimizer
|
168 |
+
)
|
169 |
+
|
170 |
@property
|
171 |
def is_main(self):
|
172 |
return self.accelerator.is_main_process
|
173 |
|
174 |
def save_checkpoint(self, step, last=False):
|
175 |
self.accelerator.wait_for_everyone()
|
176 |
+
if self.is_main:
|
177 |
checkpoint = dict(
|
178 |
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
179 |
+
optimizer_state_dict=self.accelerator.unwrap_model(
|
180 |
+
self.optimizer
|
181 |
+
).state_dict(),
|
182 |
scheduler_state_dict=self.scheduler.state_dict(),
|
183 |
step=step,
|
184 |
)
|
185 |
if not os.path.exists(self.checkpoint_path):
|
186 |
os.makedirs(self.checkpoint_path)
|
187 |
if last:
|
188 |
+
self.accelerator.save(
|
189 |
+
checkpoint, f"{self.checkpoint_path}/model_last.pt"
|
190 |
+
)
|
191 |
else:
|
192 |
+
self.accelerator.save(
|
193 |
+
checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
|
194 |
+
)
|
195 |
|
196 |
def load_checkpoint(self):
|
197 |
if (
|
198 |
not exists(self.checkpoint_path)
|
199 |
or not os.path.exists(self.checkpoint_path)
|
200 |
+
or not any(
|
201 |
+
filename.endswith(".pt")
|
202 |
+
for filename in os.listdir(self.checkpoint_path)
|
203 |
+
)
|
204 |
):
|
205 |
return 0
|
206 |
|
|
|
213 |
key=lambda x: int("".join(filter(str.isdigit, x))),
|
214 |
)[-1]
|
215 |
|
216 |
+
print(f"To load from {latest_checkpoint}.")
|
217 |
|
218 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
219 |
+
checkpoint = torch.load(
|
220 |
+
f"{self.checkpoint_path}/{latest_checkpoint}",
|
221 |
+
weights_only=True,
|
222 |
+
map_location="cpu",
|
223 |
+
)
|
224 |
|
225 |
+
print(f"Loaded from {latest_checkpoint}.")
|
226 |
|
227 |
if "step" in checkpoint:
|
228 |
# patch for backward compatibility, 305e3ea
|
229 |
+
for key in [
|
230 |
+
"mel_spec.mel_stft.mel_scale.fb",
|
231 |
+
"mel_spec.mel_stft.spectrogram.window",
|
232 |
+
]:
|
233 |
if key in checkpoint["model_state_dict"]:
|
234 |
del checkpoint["model_state_dict"][key]
|
235 |
|
236 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(
|
237 |
+
checkpoint["model_state_dict"]
|
238 |
+
)
|
239 |
+
self.accelerator.unwrap_model(self.optimizer).load_state_dict(
|
240 |
+
checkpoint["optimizer_state_dict"]
|
241 |
+
)
|
242 |
if self.scheduler:
|
243 |
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
244 |
step = checkpoint["step"]
|
|
|
248 |
for k, v in checkpoint["ema_model_state_dict"].items()
|
249 |
if k not in ["initted", "step"]
|
250 |
}
|
251 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(
|
252 |
+
checkpoint["model_state_dict"]
|
253 |
+
)
|
254 |
step = 0
|
255 |
+
|
256 |
del checkpoint
|
257 |
gc.collect()
|
258 |
|
259 |
+
print(f"Exit load_checkpoint.")
|
260 |
|
261 |
return step
|
262 |
|
|
|
263 |
def validate(self, valid_dataloader, global_step):
|
264 |
"""
|
265 |
Runs evaluation on the validation set, computes the average loss,
|
|
|
273 |
with torch.no_grad():
|
274 |
for batch in valid_dataloader:
|
275 |
# Inputs
|
276 |
+
mel = batch["mel"].permute(0, 2, 1) # (B, L_mel, D)
|
277 |
+
text = batch["text"]
|
278 |
|
279 |
if self.process_token_to_id:
|
280 |
text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device)
|
281 |
+
text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size)
|
282 |
else:
|
283 |
text_ids = text
|
284 |
|
285 |
# Targets
|
286 |
+
mel_lengths = batch["mel_lengths"]
|
287 |
tar_lengths = calculate_remaining_lengths(mel_lengths)
|
288 |
predictions = self.model(text_ids=text_ids, mel=mel)
|
289 |
|
290 |
+
if self.loss_fn == "L1":
|
291 |
est_lengths = predictions
|
292 |
loss = masked_l1_loss(
|
293 |
est_lengths=est_lengths, tar_lengths=tar_lengths
|
294 |
)
|
295 |
frame_error = loss
|
296 |
|
297 |
+
elif self.loss_fn == "CE":
|
298 |
+
tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp(
|
299 |
+
min=0, max=self.n_class - 1
|
300 |
+
) # [0, 1, ..., n_class-1]
|
301 |
est_length_logtis = predictions
|
302 |
est_length_labels = torch.argmax(est_length_logtis, dim=-1)
|
303 |
loss = masked_cross_entropy_loss(
|
304 |
+
est_length_logits=est_length_logtis,
|
305 |
+
tar_length_labels=tar_length_labels,
|
306 |
)
|
307 |
est_lengths = est_length_labels * self.n_frame_per_class
|
308 |
frame_error = masked_l1_loss(
|
309 |
est_lengths=est_lengths, tar_lengths=tar_lengths
|
310 |
)
|
311 |
|
312 |
+
elif self.loss_fn == "L1_and_CE":
|
313 |
+
tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp(
|
314 |
+
min=0, max=self.n_class - 1
|
315 |
+
) # [0, 1, ..., n_class-1]
|
316 |
est_length_logtis = predictions
|
317 |
est_length_1hots = F.gumbel_softmax(
|
318 |
est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1
|
319 |
)
|
320 |
+
length_values = (
|
321 |
+
torch.arange(
|
322 |
+
self.n_class, device=est_length_1hots.device
|
323 |
+
).float()
|
324 |
+
* self.n_frame_per_class
|
325 |
+
)
|
326 |
est_lengths = (est_length_1hots * length_values).sum(-1)
|
327 |
|
328 |
loss_CE = masked_cross_entropy_loss(
|
329 |
+
est_length_logits=est_length_logtis,
|
330 |
+
tar_length_labels=tar_length_labels,
|
331 |
)
|
332 |
|
333 |
loss_L1 = masked_l1_loss(
|
|
|
350 |
avg_valid_sec_error = total_sec_error / count if count > 0 else 0.0
|
351 |
# Log validation metrics
|
352 |
self.accelerator.log(
|
353 |
+
{f"valid_loss": avg_valid_loss, f"valid_sec_error": avg_valid_sec_error},
|
354 |
+
step=global_step,
|
|
|
|
|
|
|
355 |
)
|
|
|
|
|
356 |
|
357 |
+
self.model.train()
|
358 |
|
359 |
+
def train(
|
360 |
+
self,
|
361 |
+
train_dataset: Dataset,
|
362 |
+
valid_dataset: Dataset,
|
363 |
+
num_workers=64,
|
364 |
+
resumable_with_seed: int = None,
|
365 |
+
):
|
366 |
if exists(resumable_with_seed):
|
367 |
generator = torch.Generator()
|
368 |
generator.manual_seed(resumable_with_seed)
|
|
|
396 |
|
397 |
sampler = SequentialSampler(train_dataset)
|
398 |
batch_sampler = DynamicBatchSampler(
|
399 |
+
sampler,
|
400 |
+
self.batch_size,
|
401 |
+
max_samples=self.max_samples,
|
402 |
+
random_seed=resumable_with_seed,
|
403 |
+
drop_last=False,
|
404 |
)
|
405 |
train_dataloader = DataLoader(
|
406 |
train_dataset,
|
|
|
413 |
|
414 |
sampler = SequentialSampler(valid_dataset)
|
415 |
batch_sampler = DynamicBatchSampler(
|
416 |
+
sampler,
|
417 |
+
self.batch_size,
|
418 |
+
max_samples=self.max_samples,
|
419 |
+
random_seed=resumable_with_seed,
|
420 |
+
drop_last=False,
|
421 |
)
|
422 |
# Create validation dataloader (always sequential, no shuffling)
|
423 |
valid_dataloader = DataLoader(
|
424 |
valid_dataset,
|
425 |
collate_fn=collate_fn,
|
426 |
num_workers=num_workers,
|
427 |
+
pin_memory=True,
|
428 |
persistent_workers=True,
|
429 |
batch_sampler=batch_sampler,
|
430 |
)
|
431 |
else:
|
432 |
+
raise ValueError(
|
433 |
+
f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
|
434 |
+
)
|
435 |
+
|
436 |
# accelerator.prepare() dispatches batches to devices;
|
437 |
# which means the length of dataloader calculated before, should consider the number of devices
|
438 |
warmup_steps = (
|
|
|
441 |
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
442 |
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
443 |
decay_steps = total_steps - warmup_steps
|
444 |
+
warmup_scheduler = LinearLR(
|
445 |
+
self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps
|
446 |
+
)
|
447 |
+
decay_scheduler = LinearLR(
|
448 |
+
self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps
|
449 |
+
)
|
450 |
self.scheduler = SequentialLR(
|
451 |
+
self.optimizer,
|
452 |
+
schedulers=[warmup_scheduler, decay_scheduler],
|
453 |
+
milestones=[warmup_steps],
|
454 |
)
|
455 |
         train_dataloader, self.scheduler = self.accelerator.prepare(
             train_dataloader, self.scheduler
 ...
             orig_epoch_step = len(train_dataloader)
             skipped_epoch = int(start_step // orig_epoch_step)
             skipped_batch = start_step % orig_epoch_step
+            skipped_dataloader = self.accelerator.skip_first_batches(
+                train_dataloader, num_batches=skipped_batch
+            )
         else:
             skipped_epoch = 0
 ...
             for batch in progress_bar:
                 with self.accelerator.accumulate(self.model):
                     # Inputs
+                    mel = batch["mel"].permute(0, 2, 1)  # (B, L_mel, D)
+                    text = batch["text"]

                     if self.process_token_to_id:
+                        text_ids = list_str_to_idx(text, self.vocab_char_map).to(
+                            mel.device
+                        )
+                        text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size)
                     else:
                         text_ids = text

                     # Targets
+                    mel_lengths = batch["mel_lengths"]
                     tar_lengths = calculate_remaining_lengths(mel_lengths)
                     predictions = self.model(text_ids=text_ids, mel=mel)

+                    if self.loss_fn == "L1":
                         est_lengths = predictions
                         loss = masked_l1_loss(
                             est_lengths=est_lengths, tar_lengths=tar_lengths
 ...
                         sec_error = frame_error * 256 / 24000

                         log_dict = {
+                            "loss": loss.item(),
+                            "loss_L1": loss.item(),
+                            "sec_error": sec_error.item(),
+                            "lr": self.scheduler.get_last_lr()[0],
+                        }
+
+                    elif self.loss_fn == "CE":
+                        tar_length_labels = (
+                            tar_lengths // self.n_frame_per_class
+                        ).clamp(
+                            min=0, max=self.n_class - 1
+                        )  # [0, 1, ..., n_class-1]
                         est_length_logtis = predictions
                         est_length_labels = torch.argmax(est_length_logtis, dim=-1)
                         loss = masked_cross_entropy_loss(
+                            est_length_logits=est_length_logtis,
+                            tar_length_labels=tar_length_labels,
                         )
                         with torch.no_grad():
                             est_lengths = est_length_labels * self.n_frame_per_class
 ...
                         sec_error = frame_error * 256 / 24000

                         log_dict = {
+                            "loss": loss.item(),
+                            "loss_CE": loss.item(),
+                            "sec_error": sec_error.item(),
+                            "lr": self.scheduler.get_last_lr()[0],
+                        }
+
+                    elif self.loss_fn == "L1_and_CE":
+                        tar_length_labels = (
+                            tar_lengths // self.n_frame_per_class
+                        ).clamp(
+                            min=0, max=self.n_class - 1
+                        )  # [0, 1, ..., n_class-1]
                         est_length_logtis = predictions
                         est_length_1hots = F.gumbel_softmax(
                             est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1
                         )
+                        length_values = (
+                            torch.arange(
+                                self.n_class, device=est_length_1hots.device
+                            ).float()
+                            * self.n_frame_per_class
+                        )
                         est_lengths = (est_length_1hots * length_values).sum(-1)

                         loss_CE = masked_cross_entropy_loss(
+                            est_length_logits=est_length_logtis,
+                            tar_length_labels=tar_length_labels,
                         )

+                        loss_L1 = masked_l1_loss(
                             est_lengths=est_lengths, tar_lengths=tar_lengths
                         )
 ...
                         sec_error = frame_error * 256 / 24000

                         log_dict = {
+                            "loss": loss.item(),
+                            "loss_L1": loss_L1.item(),
+                            "loss_CE": loss_CE.item(),
+                            "sec_error": sec_error.item(),
+                            "lr": self.scheduler.get_last_lr()[0],
                         }

                     else:
                         raise NotImplementedError(self.loss_fn)

                 self.accelerator.backward(loss)

                 if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
+                    self.accelerator.clip_grad_norm_(
+                        self.model.parameters(), self.max_grad_norm
+                    )

                 self.optimizer.step()
                 self.scheduler.step()
 ...
                 self.accelerator.log(log_dict, step=global_step)
                 progress_bar.set_postfix(step=str(global_step), loss=loss.item())

+                if (
+                    global_step % (self.save_per_updates * self.grad_accumulation_steps)
+                    == 0
+                ):
                     self.save_checkpoint(global_step)
                 # if self.log_samples and self.accelerator.is_local_main_process:
                 # Run validation at the end of each epoch (only on the main process)
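The CE branch above quantizes target mel lengths into duration classes and reports the error back in seconds. A minimal standalone sketch of that bookkeeping (not part of the commit; `n_frame_per_class`, the 256-sample hop and 24 kHz rate follow the values used in the diff, and the function name is illustrative):

# Standalone sketch of the CE duration bookkeeping, under the assumptions above.
import torch
import torch.nn.functional as F

def ce_duration_loss(logits: torch.Tensor, tar_lengths: torch.Tensor,
                     n_class: int = 301, n_frame_per_class: int = 10):
    # logits: (B, n_class) class scores; tar_lengths: (B,) target lengths in mel frames
    labels = (tar_lengths // n_frame_per_class).clamp(min=0, max=n_class - 1)
    loss = F.cross_entropy(logits, labels)
    # map the predicted class back to a frame count, then to seconds (hop 256 @ 24 kHz)
    est_lengths = torch.argmax(logits, dim=-1) * n_frame_per_class
    sec_error = (est_lengths.float() - tar_lengths.float()).abs().mean() * 256 / 24000
    return loss, sec_error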
duration_trainer_with_prompt.py
CHANGED
@@ -1,11 +1,11 @@
 from __future__ import annotations

 import gc
-import os
-
 import math
+import os

 import torch
+import torch.nn.functional as F
 import torchaudio
 import wandb
 from accelerate import Accelerator
@@ -13,25 +13,17 @@ from accelerate.utils import DistributedDataParallelKwargs
 from ema_pytorch import EMA
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import LinearLR, SequentialLR
-from torch.utils.data import
+from torch.utils.data import Dataset  # <-- Added Subset import
+from torch.utils.data import DataLoader, SequentialSampler, Subset
 from tqdm import tqdm

-import torch.nn.functional as F
-
 from f5_tts.model import CFM
-from f5_tts.model.dataset import
-from f5_tts.model.utils import default, exists
+from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
+from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
+                                list_str_to_tensor, mask_from_frac_lengths)

 # trainer

-from f5_tts.model.utils import (
-    default,
-    exists,
-    list_str_to_idx,
-    list_str_to_tensor,
-    lens_to_mask,
-    mask_from_frac_lengths,
-)
-
 SAMPLE_RATE = 24_000
@@ -43,7 +35,7 @@ class Trainer:
         vocab_size,
         vocab_char_map,
         process_token_to_id=True,
-        loss_fn=
+        loss_fn="L1",
         lambda_L1=1,
         gumbel_tau=0.5,
         n_class=301,
@@ -83,7 +75,13 @@ class Trainer:
         self.logger = logger
         if self.logger == "wandb":
             if exists(wandb_resume_id):
-                init_kwargs = {
+                init_kwargs = {
+                    "wandb": {
+                        "resume": "allow",
+                        "name": wandb_run_name,
+                        "id": wandb_resume_id,
+                    }
+                }
             else:
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
@@ -112,7 +110,7 @@ class Trainer:
         self.vocab_size = vocab_size
         self.vocab_char_map = vocab_char_map
         self.process_token_to_id = process_token_to_id
-        assert loss_fn in [
+        assert loss_fn in ["L1", "CE", "L1_and_CE"]
         self.loss_fn = loss_fn
         self.lambda_L1 = lambda_L1
         self.n_class = n_class
@@ -122,7 +120,9 @@ class Trainer:
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
-        self.last_per_steps = default(
+        self.last_per_steps = default(
+            last_per_steps, save_per_updates * grad_accumulation_steps
+        )
         self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

         self.batch_size = batch_size
@@ -137,33 +137,44 @@ class Trainer:
             self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
         else:
             self.optimizer = AdamW(model.parameters(), lr=learning_rate)
-        self.model, self.optimizer = self.accelerator.prepare(
+        self.model, self.optimizer = self.accelerator.prepare(
+            self.model, self.optimizer
+        )
+
     @property
     def is_main(self):
         return self.accelerator.is_main_process

     def save_checkpoint(self, step, last=False):
         self.accelerator.wait_for_everyone()
         if self.is_main:
             checkpoint = dict(
                 model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
-                optimizer_state_dict=self.accelerator.unwrap_model(
+                optimizer_state_dict=self.accelerator.unwrap_model(
+                    self.optimizer
+                ).state_dict(),
                 scheduler_state_dict=self.scheduler.state_dict(),
                 step=step,
             )
             if not os.path.exists(self.checkpoint_path):
                 os.makedirs(self.checkpoint_path)
             if last:
-                self.accelerator.save(
+                self.accelerator.save(
+                    checkpoint, f"{self.checkpoint_path}/model_last.pt"
+                )
             else:
-                self.accelerator.save(
+                self.accelerator.save(
+                    checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
+                )

     def load_checkpoint(self):
         if (
             not exists(self.checkpoint_path)
             or not os.path.exists(self.checkpoint_path)
-            or not any(
+            or not any(
+                filename.endswith(".pt")
+                for filename in os.listdir(self.checkpoint_path)
+            )
         ):
             return 0
@@ -176,21 +187,32 @@ class Trainer:
             key=lambda x: int("".join(filter(str.isdigit, x))),
         )[-1]

-        print(f
+        print(f"To load from {latest_checkpoint}.")

         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
-        checkpoint = torch.load(
+        checkpoint = torch.load(
+            f"{self.checkpoint_path}/{latest_checkpoint}",
+            weights_only=True,
+            map_location="cpu",
+        )

-        print(f
+        print(f"Loaded from {latest_checkpoint}.")

         if "step" in checkpoint:
             # patch for backward compatibility, 305e3ea
-            for key in [
+            for key in [
+                "mel_spec.mel_stft.mel_scale.fb",
+                "mel_spec.mel_stft.spectrogram.window",
+            ]:
                 if key in checkpoint["model_state_dict"]:
                     del checkpoint["model_state_dict"][key]

-            self.accelerator.unwrap_model(self.model).load_state_dict(
+            self.accelerator.unwrap_model(self.model).load_state_dict(
+                checkpoint["model_state_dict"]
+            )
+            self.accelerator.unwrap_model(self.optimizer).load_state_dict(
+                checkpoint["optimizer_state_dict"]
+            )
             if self.scheduler:
                 self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
             step = checkpoint["step"]
@@ -200,17 +222,18 @@ class Trainer:
                 for k, v in checkpoint["ema_model_state_dict"].items()
                 if k not in ["initted", "step"]
             }
-            self.accelerator.unwrap_model(self.model).load_state_dict(
+            self.accelerator.unwrap_model(self.model).load_state_dict(
+                checkpoint["model_state_dict"]
+            )
             step = 0

         del checkpoint
         gc.collect()

-        print(f
+        print(f"Exit load_checkpoint.")

         return step

     def validate(self, valid_dataloader, global_step):
         """
         Runs evaluation on the validation set, computes the average loss,
@@ -226,31 +249,40 @@ class Trainer:
         for batch in valid_dataloader:

             # Inputs
-            prompt_mel = batch[
-            prompt_text = batch[
-            text = batch[
+            prompt_mel = batch["pmt_mel_specs"].permute(0, 2, 1)  # (B, L_mel, D)
+            prompt_text = batch["pmt_text"]
+            text = batch["text"]

-            target_ids = list_str_to_idx(text, self.vocab_char_map).to(
+            target_ids = list_str_to_idx(text, self.vocab_char_map).to(
+                prompt_mel.device
+            )
+            target_ids = target_ids.masked_fill(target_ids == -1, vocab_size)

-            prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(
+            prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(
+                prompt_mel.device
+            )
+            prompt_ids = prompt_ids.masked_fill(prompt_ids == -1, vocab_size)

             # Targets
-            tar_lengths = batch[
+            tar_lengths = batch["mel_lengths"]

             # Forward
-            predictions = SLP(
+            predictions = SLP(
+                target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel
+            )  # (B, C)
+
+            if self.loss_fn == "CE":
+                tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp(
+                    min=0, max=self.n_class - 1
+                )  # [0, 1, ..., n_class-1]
                 est_length_logtis = predictions
                 est_length_labels = torch.argmax(est_length_logtis, dim=-1)
                 loss = F.cross_entropy(est_length_logtis, tar_length_labels)

                 est_lengths = est_length_labels * self.n_frame_per_class
-                frame_error = (
+                frame_error = (
+                    (est_lengths.float() - tar_lengths.float()).abs().mean()
+                )
                 sec_error = frame_error * 256 / 24000

                 total_sec_error += sec_error.item()
@@ -262,18 +294,19 @@ class Trainer:

         # Log validation metrics
         self.accelerator.log(
-            {
-            f"valid_sec_error": avg_valid_sec_error
-            },
-            step=global_step
+            {f"valid_loss": avg_valid_loss, f"valid_sec_error": avg_valid_sec_error},
+            step=global_step,
         )

         self.model.train()

-    def train(
+    def train(
+        self,
+        train_dataset: Dataset,
+        valid_dataset: Dataset,
+        num_workers=64,
+        resumable_with_seed: int = None,
+    ):
         if exists(resumable_with_seed):
             generator = torch.Generator()
             generator.manual_seed(resumable_with_seed)
@@ -307,7 +340,11 @@ class Trainer:
             sampler = SequentialSampler(train_dataset)
             batch_sampler = DynamicBatchSampler(
-                sampler,
+                sampler,
+                self.batch_size,
+                max_samples=self.max_samples,
+                random_seed=resumable_with_seed,
+                drop_last=False,
             )
             train_dataloader = DataLoader(
                 train_dataset,
@@ -320,20 +357,26 @@ class Trainer:
             sampler = SequentialSampler(valid_dataset)
             batch_sampler = DynamicBatchSampler(
-                sampler,
+                sampler,
+                self.batch_size,
+                max_samples=self.max_samples,
+                random_seed=resumable_with_seed,
+                drop_last=False,
             )
             # Create validation dataloader (always sequential, no shuffling)
             valid_dataloader = DataLoader(
                 valid_dataset,
                 collate_fn=collate_fn,
                 num_workers=num_workers,
                 pin_memory=True,
                 persistent_workers=True,
                 batch_sampler=batch_sampler,
             )
         else:
-            raise ValueError(
+            raise ValueError(
+                f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
+            )
+
         # accelerator.prepare() dispatches batches to devices;
         # which means the length of dataloader calculated before, should consider the number of devices
         warmup_steps = (
@@ -342,10 +385,16 @@ class Trainer:
         # otherwise by default with split_batches=False, warmup steps change with num_processes
         total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
         decay_steps = total_steps - warmup_steps
-        warmup_scheduler = LinearLR(
+        warmup_scheduler = LinearLR(
+            self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps
+        )
+        decay_scheduler = LinearLR(
+            self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps
+        )
         self.scheduler = SequentialLR(
-            self.optimizer,
+            self.optimizer,
+            schedulers=[warmup_scheduler, decay_scheduler],
+            milestones=[warmup_steps],
         )
         train_dataloader, self.scheduler = self.accelerator.prepare(
             train_dataloader, self.scheduler
@@ -359,7 +408,9 @@ class Trainer:
             orig_epoch_step = len(train_dataloader)
             skipped_epoch = int(start_step // orig_epoch_step)
             skipped_batch = start_step % orig_epoch_step
-            skipped_dataloader = self.accelerator.skip_first_batches(
+            skipped_dataloader = self.accelerator.skip_first_batches(
+                train_dataloader, num_batches=skipped_batch
+            )
         else:
             skipped_epoch = 0
@@ -385,49 +436,65 @@ class Trainer:
             for batch in progress_bar:
                 with self.accelerator.accumulate(self.model):
                     # Inputs
-                    prompt_mel = batch[
+                    prompt_mel = batch["pmt_mel_specs"].permute(
+                        0, 2, 1
+                    )  # (B, L_mel, D)
+                    prompt_text = batch["pmt_text"]
+                    text = batch["text"]
+
+                    target_ids = list_str_to_idx(text, self.vocab_char_map).to(
+                        prompt_mel.device
+                    )
+                    target_ids = target_ids.masked_fill(target_ids == -1, vocab_size)
+
+                    prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(
+                        prompt_mel.device
+                    )
+                    prompt_ids = prompt_ids.masked_fill(prompt_ids == -1, vocab_size)

                     # Targets
-                    tar_lengths = batch[
+                    tar_lengths = batch["mel_lengths"]

                     # Forward
-                    predictions = SLP(
+                    predictions = SLP(
+                        target_ids=target_ids,
+                        prompt_ids=prompt_ids,
+                        prompt_mel=prompt_mel,
+                    )  # (B, C)
+
+                    if self.loss_fn == "CE":
+                        tar_length_labels = (
+                            tar_lengths // self.n_frame_per_class
+                        ).clamp(
+                            min=0, max=self.n_class - 1
+                        )  # [0, 1, ..., n_class-1]
                         est_length_logtis = predictions
                         est_length_labels = torch.argmax(est_length_logtis, dim=-1)
                         loss = F.cross_entropy(est_length_logtis, tar_length_labels)

                         with torch.no_grad():
                             est_lengths = est_length_labels * self.n_frame_per_class
-                            frame_error = (
+                            frame_error = (
+                                (est_lengths.float() - tar_lengths.float()).abs().mean()
+                            )
                             sec_error = frame_error * 256 / 24000

                         log_dict = {
+                            "loss": loss.item(),
+                            "loss_CE": loss.item(),
+                            "sec_error": sec_error.item(),
+                            "lr": self.scheduler.get_last_lr()[0],
+                        }

                     else:
                         raise NotImplementedError(self.loss_fn)

                 self.accelerator.backward(loss)

                 if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
-                    self.accelerator.clip_grad_norm_(
+                    self.accelerator.clip_grad_norm_(
+                        self.model.parameters(), self.max_grad_norm
+                    )

                 self.optimizer.step()
                 self.scheduler.step()
@@ -439,7 +506,10 @@ class Trainer:
                 self.accelerator.log(log_dict, step=global_step)
                 progress_bar.set_postfix(step=str(global_step), loss=loss.item())

-                if
+                if (
+                    global_step % (self.save_per_updates * self.grad_accumulation_steps)
+                    == 0
+                ):
                     self.save_checkpoint(global_step)
                 # if self.log_samples and self.accelerator.is_local_main_process:
                 # Run validation at the end of each epoch (only on the main process)
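The `masked_fill(ids == -1, vocab_size)` calls above assume a particular padding convention for the character-id tensors. An illustrative sketch of that convention (not the project's own `list_str_to_idx`; the helper name and use of `pad_sequence` are assumptions):

# Illustrative sketch: pad variable-length id lists with -1, then remap the padding
# index to vocab_size so it selects a dedicated filler embedding row.
import torch
from torch.nn.utils.rnn import pad_sequence

def to_padded_ids(token_lists, vocab_char_map, vocab_size):
    ids = [torch.tensor([vocab_char_map[c] for c in seq], dtype=torch.long)
           for seq in token_lists]
    batch = pad_sequence(ids, batch_first=True, padding_value=-1)  # (B, L_max)
    return batch.masked_fill(batch == -1, vocab_size)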
ecapa_tdnn.py
CHANGED
@@ -1,23 +1,34 @@
 # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN

+# from ctcmodel_nopool import ConformerCTC as ConformerCTCNoPool
+from pathlib import Path
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio.transforms as trans
+
 from ctcmodel import ConformerCTC
-# from ctcmodel_nopool import ConformerCTC as ConformerCTCNoPool
-from pathlib import Path

+""" Res2Conv1d + BatchNorm1d + ReLU
+"""

 class Res2Conv1dReluBn(nn.Module):
+    """
     in_channels == out_channels == channels
+    """
+
+    def __init__(
+        self,
+        channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+        scale=4,
+    ):
         super().__init__()
         assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
         self.scale = scale
@@ -27,7 +38,17 @@ class Res2Conv1dReluBn(nn.Module):
         self.convs = []
         self.bns = []
         for i in range(self.nums):
-            self.convs.append(
+            self.convs.append(
+                nn.Conv1d(
+                    self.width,
+                    self.width,
+                    kernel_size,
+                    stride,
+                    padding,
+                    dilation,
+                    bias=bias,
+                )
+            )
             self.bns.append(nn.BatchNorm1d(self.width))
         self.convs = nn.ModuleList(self.convs)
         self.bns = nn.ModuleList(self.bns)
@@ -51,22 +72,33 @@ class Res2Conv1dReluBn(nn.Module):
         return out

+""" Conv1d + BatchNorm1d + ReLU
+"""

 class Conv1dReluBn(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+    ):
         super().__init__()
-        self.conv = nn.Conv1d(
+        self.conv = nn.Conv1d(
+            in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
+        )
         self.bn = nn.BatchNorm1d(out_channels)

     def forward(self, x):
         return self.bn(F.relu(self.conv(x)))

+""" The SE connection of 1D case.
+"""

 class SE_Connect(nn.Module):
@@ -84,15 +116,32 @@ class SE_Connect(nn.Module):
         return out

+""" SE-Res2Block of the ECAPA-TDNN architecture.
+"""

 class SE_Res2Block(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        scale,
+        se_bottleneck_dim,
+    ):
         super().__init__()
-        self.Conv1dReluBn1 = Conv1dReluBn(
+        self.Conv1dReluBn1 = Conv1dReluBn(
+            in_channels, out_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.Res2Conv1dReluBn = Res2Conv1dReluBn(
+            out_channels, kernel_size, stride, padding, dilation, scale=scale
+        )
+        self.Conv1dReluBn2 = Conv1dReluBn(
+            out_channels, out_channels, kernel_size=1, stride=1, padding=0
+        )
         self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

         self.shortcut = None
@@ -116,8 +165,9 @@ class SE_Res2Block(nn.Module):
         return x + residual

+""" Attentive weighted mean and standard deviation pooling.
+"""

 class AttentiveStatsPool(nn.Module):
     def __init__(self, in_dim, attention_channels=128, global_context_att=False):
@@ -126,16 +176,24 @@ class AttentiveStatsPool(nn.Module):
         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
         if global_context_att:
-            self.linear1 = nn.Conv1d(
+            self.linear1 = nn.Conv1d(
+                in_dim * 3, attention_channels, kernel_size=1
+            )  # equals W and b in the paper
         else:
-            self.linear1 = nn.Conv1d(
+            self.linear1 = nn.Conv1d(
+                in_dim, attention_channels, kernel_size=1
+            )  # equals W and b in the paper
+        self.linear2 = nn.Conv1d(
+            attention_channels, in_dim, kernel_size=1
+        )  # equals V and k in the paper

     def forward(self, x):

         if self.global_context_att:
             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
-            context_std = torch.sqrt(
+            context_std = torch.sqrt(
+                torch.var(x, dim=-1, keepdim=True) + 1e-10
+            ).expand_as(x)
             x_in = torch.cat((x, context_mean, context_std), dim=1)
         else:
             x_in = x
@@ -145,42 +203,52 @@ class AttentiveStatsPool(nn.Module):
         # alpha = F.relu(self.linear1(x_in))
         alpha = torch.softmax(self.linear2(alpha), dim=2)
         mean = torch.sum(alpha * x, dim=2)
-        residuals = torch.sum(alpha * (x
+        residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
         std = torch.sqrt(residuals.clamp(min=1e-9))
         return torch.cat([mean, std], dim=1)

 class ECAPA_TDNN(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        channels=512,
+        emb_dim=512,
+        global_context_att=False,
+        use_fp16=True,
         ctc_cls=ConformerCTC,
-        ctc_path=
-        ctc_args={
+        ctc_path="/data4/F5TTS/ckpts/F5TTS_norm_ASR_vocos_pinyin_Emilia_ZH_EN/model_last.pt",
+        ctc_args={
+            "vocab_size": 2545,
+            "mel_dim": 100,
+            "num_heads": 8,
+            "d_hid": 512,
+            "nlayers": 6,
+        },
+        ctc_no_grad=False,
     ):
         super().__init__()
         if ctc_path != None:
             ctc_path = Path(ctc_path)
             model = ctc_cls(**ctc_args)
-            state_dict = torch.load(ctc_path, map_location=
-            model.load_state_dict(state_dict[
+            state_dict = torch.load(ctc_path, map_location="cpu")
+            model.load_state_dict(state_dict["model_state_dict"])
             print(f"Initialized pretrained ConformerCTC backbone from {ctc_path}.")
         else:
             raise ValueError(ctc_path)

         self.ctc_model = model
         self.ctc_model.out.requires_grad_(False)
+
         if ctc_cls == ConformerCTC:
-            self.feat_num = ctc_args[
+            self.feat_num = ctc_args["nlayers"] + 2 + 1
         # elif ctc_cls == ConformerCTCNoPool:
         #     self.feat_num = ctc_args['nlayers'] + 1
         else:
             raise ValueError(ctc_cls)
-        feat_dim = ctc_args[
+        feat_dim = ctc_args["d_hid"]

         self.emb_dim = emb_dim
+
         self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
         self.instance_norm = nn.InstanceNorm1d(feat_dim)
@@ -188,14 +256,45 @@ class ECAPA_TDNN(nn.Module):
         self.channels = [channels] * 4 + [1536]

         self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
-        self.layer2 = SE_Res2Block(
+        self.layer2 = SE_Res2Block(
+            self.channels[0],
+            self.channels[1],
+            kernel_size=3,
+            stride=1,
+            padding=2,
+            dilation=2,
+            scale=8,
+            se_bottleneck_dim=128,
+        )
+        self.layer3 = SE_Res2Block(
+            self.channels[1],
+            self.channels[2],
+            kernel_size=3,
+            stride=1,
+            padding=3,
+            dilation=3,
+            scale=8,
+            se_bottleneck_dim=128,
+        )
+        self.layer4 = SE_Res2Block(
+            self.channels[2],
+            self.channels[3],
+            kernel_size=3,
+            stride=1,
+            padding=4,
+            dilation=4,
+            scale=8,
+            se_bottleneck_dim=128,
+        )

         # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
         cat_channels = channels * 3
         self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
-        self.pooling = AttentiveStatsPool(
+        self.pooling = AttentiveStatsPool(
+            self.channels[-1],
+            attention_channels=128,
+            global_context_att=global_context_att,
+        )
         self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
         self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
@@ -206,21 +305,26 @@ class ECAPA_TDNN(nn.Module):
         else:
             self.ctc_model = self.ctc_model.train()
         self.ctc_no_grad = ctc_no_grad
-        print(
+        print("ctc_no_grad: ", self.ctc_no_grad)

-    def forward(self, latent, input_lengths,
+    def forward(self, latent, input_lengths, return_asr=False):
         if self.ctc_no_grad:
             with torch.no_grad():
                 asr, h = self.ctc_model(latent, input_lengths)
         else:
             asr, h = self.ctc_model(latent, input_lengths)
+
         x = torch.stack(h, dim=0)
-        norm_weights =
+        norm_weights = (
+            F.softmax(self.feature_weight, dim=-1)
+            .unsqueeze(-1)
+            .unsqueeze(-1)
+            .unsqueeze(-1)
+        )
         x = (norm_weights * x).sum(dim=0)
         x = x + 1e-6
         # x = torch.transpose(x, 1, 2) + 1e-6
+
         x = self.instance_norm(x)
         # x = torch.transpose(x, 1, 2)
@@ -238,9 +342,10 @@ class ECAPA_TDNN(nn.Module):
             return out, asr
         return out

+
 if __name__ == "__main__":
-    from diffspeech.ldm.model import DiT
     from diffspeech.data.collate import get_mask_from_lengths
+    from diffspeech.ldm.model import DiT
     from diffspeech.tools.text.vocab import IPA

     bsz = 3
@@ -265,4 +370,4 @@ if __name__ == "__main__":

     emb = model(latent, latent_mask.sum(axis=-1))

-    print(emb.shape)
+    print(emb.shape)
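The forward pass above fuses the per-layer Conformer features with a learnable softmax weighting before the TDNN layers. A short sketch of that layer-weighting step in isolation (assumes `h` is a list of per-layer features shaped (B, D, T); names here are illustrative):

# Sketch of the learnable layer-weighted sum used in ECAPA_TDNN.forward above.
import torch
import torch.nn.functional as F

def weighted_layer_sum(h, feature_weight):
    x = torch.stack(h, dim=0)              # (L, B, D, T)
    w = F.softmax(feature_weight, dim=-1)  # (L,) learnable per-layer weights
    w = w.view(-1, 1, 1, 1)                # broadcast over (B, D, T)
    return (w * x).sum(dim=0)              # (B, D, T)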
f5_tts/api.py
CHANGED
@@ -8,15 +8,10 @@ from cached_path import cached_path
 from hydra.utils import get_class
 from omegaconf import OmegaConf

-from f5_tts.infer.utils_infer import (
-    preprocess_ref_audio_text,
-    remove_silence_for_generated_wav,
-    save_spectrogram,
-    transcribe,
-)
+from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
+                                      preprocess_ref_audio_text,
+                                      remove_silence_for_generated_wav,
+                                      save_spectrogram, transcribe)
 from f5_tts.model.utils import seed_everything

@@ -32,7 +27,9 @@ class F5TTS:
         device=None,
         hf_cache_dir=None,
     ):
-        model_cfg = OmegaConf.load(
+        model_cfg = OmegaConf.load(
+            str(files("f5_tts").joinpath(f"configs/{model}.yaml"))
+        )
         model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
         model_arc = model_cfg.model.arch

@@ -50,16 +47,20 @@ class F5TTS:
         self.device = (
             "cuda"
             if torch.cuda.is_available()
-            else
+            else (
+                "xpu"
+                if torch.xpu.is_available()
+                else "mps" if torch.backends.mps.is_available() else "cpu"
+            )
         )

         # Load models
         self.vocoder = load_vocoder(
-            self.mel_spec_type,
+            self.mel_spec_type,
+            vocoder_local_path is not None,
+            vocoder_local_path,
+            self.device,
+            hf_cache_dir,
         )

         repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
@@ -77,10 +78,20 @@ class F5TTS:

         if not ckpt_file:
             ckpt_file = str(
-                cached_path(
+                cached_path(
+                    f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}",
+                    cache_dir=hf_cache_dir,
+                )
             )
         self.ema_model = load_model(
-            model_cls,
+            model_cls,
+            model_arc,
+            ckpt_file,
+            self.mel_spec_type,
+            vocab_file,
+            self.ode_method,
+            self.use_ema,
+            self.device,
         )

     def transcribe(self, ref_audio, language=None):
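The nested conditional expression above just walks a fallback chain of backends. A minimal sketch of the same logic written as a plain function (the getattr guard for `torch.xpu`, which only exists in recent PyTorch builds, is an added assumption):

# Sketch of the cuda -> xpu -> mps -> cpu device fallback used in F5TTS.__init__ above.
import torch

def pick_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"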
f5_tts/eval/ecapa_tdnn.py
CHANGED
@@ -9,7 +9,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-
 """ Res2Conv1d + BatchNorm1d + ReLU
 """

@@ -19,7 +18,16 @@ class Res2Conv1dReluBn(nn.Module):
     in_channels == out_channels == channels
     """

-    def __init__(
+    def __init__(
+        self,
+        channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+        scale=4,
+    ):
         super().__init__()
         assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
         self.scale = scale
@@ -29,7 +37,17 @@ class Res2Conv1dReluBn(nn.Module):
         self.convs = []
         self.bns = []
         for i in range(self.nums):
-            self.convs.append(
+            self.convs.append(
+                nn.Conv1d(
+                    self.width,
+                    self.width,
+                    kernel_size,
+                    stride,
+                    padding,
+                    dilation,
+                    bias=bias,
+                )
+            )
             self.bns.append(nn.BatchNorm1d(self.width))
         self.convs = nn.ModuleList(self.convs)
         self.bns = nn.ModuleList(self.bns)
@@ -58,9 +76,20 @@ class Res2Conv1dReluBn(nn.Module):

 class Conv1dReluBn(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+    ):
         super().__init__()
-        self.conv = nn.Conv1d(
+        self.conv = nn.Conv1d(
+            in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
+        )
         self.bn = nn.BatchNorm1d(out_channels)

     def forward(self, x):
@@ -99,11 +128,27 @@ class SE_Connect(nn.Module):

 class SE_Res2Block(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        scale,
+        se_bottleneck_dim,
+    ):
         super().__init__()
-        self.Conv1dReluBn1 = Conv1dReluBn(
+        self.Conv1dReluBn1 = Conv1dReluBn(
+            in_channels, out_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.Res2Conv1dReluBn = Res2Conv1dReluBn(
+            out_channels, kernel_size, stride, padding, dilation, scale=scale
+        )
+        self.Conv1dReluBn2 = Conv1dReluBn(
+            out_channels, out_channels, kernel_size=1, stride=1, padding=0
+        )
         self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

         self.shortcut = None
@@ -138,15 +183,23 @@ class AttentiveStatsPool(nn.Module):

         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
         if global_context_att:
-            self.linear1 = nn.Conv1d(
+            self.linear1 = nn.Conv1d(
+                in_dim * 3, attention_channels, kernel_size=1
+            )  # equals W and b in the paper
         else:
-            self.linear1 = nn.Conv1d(
+            self.linear1 = nn.Conv1d(
+                in_dim, attention_channels, kernel_size=1
+            )  # equals W and b in the paper
+        self.linear2 = nn.Conv1d(
+            attention_channels, in_dim, kernel_size=1
+        )  # equals V and k in the paper

     def forward(self, x):
         if self.global_context_att:
             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
-            context_std = torch.sqrt(
+            context_std = torch.sqrt(
+                torch.var(x, dim=-1, keepdim=True) + 1e-10
+            ).expand_as(x)
             x_in = torch.cat((x, context_mean, context_std), dim=1)
         else:
             x_in = x
@@ -184,24 +237,36 @@ class ECAPA_TDNN(nn.Module):
         torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
         try:
             local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
-            self.feature_extract = torch.hub.load(
+            self.feature_extract = torch.hub.load(
+                local_s3prl_path, feat_type, source="local", config_path=config_path
+            )
         except:  # noqa: E722
             self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)

         if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
             self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
         ):
-            self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention =
+            self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = (
+                False
+            )
         if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
             self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
         ):
-            self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention =
+            self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = (
+                False
+            )

         self.feat_num = self.get_feat_num()
         self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

         if feat_type != "fbank" and feat_type != "mfcc":
-            freeze_list = [
+            freeze_list = [
+                "final_proj",
+                "label_embs_concat",
+                "mask_emb",
+                "project_q",
+                "quantizer",
+            ]
             for name, param in self.feature_extract.named_parameters():
                 for freeze_val in freeze_list:
                     if freeze_val in name:
@@ -252,7 +317,9 @@ class ECAPA_TDNN(nn.Module):
         cat_channels = channels * 3
         self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
         self.pooling = AttentiveStatsPool(
-            self.channels[-1],
+            self.channels[-1],
+            attention_channels=128,
+            global_context_att=global_context_att,
         )
         self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
         self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
@@ -287,7 +354,12 @@ class ECAPA_TDNN(nn.Module):
             x = torch.stack(x, dim=0)
         else:
             x = x.unsqueeze(0)
-        norm_weights =
+        norm_weights = (
+            F.softmax(self.feature_weight, dim=-1)
+            .unsqueeze(-1)
+            .unsqueeze(-1)
+            .unsqueeze(-1)
+        )
         x = (norm_weights * x).sum(dim=0)
         x = torch.transpose(x, 1, 2) + 1e-6
f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -1,7 +1,6 @@
 import os
 import sys

-
 sys.path.append(os.getcwd())

 import argparse
@@ -15,16 +14,13 @@ from hydra.utils import get_class
 from omegaconf import OmegaConf
 from tqdm import tqdm

-from f5_tts.eval.utils_eval import (
-    get_seedtts_testset_metainfo,
-)
+from f5_tts.eval.utils_eval import (get_inference_prompt,
+                                    get_librispeech_test_clean_metainfo,
+                                    get_seedtts_testset_metainfo)
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
 from f5_tts.model import CFM
 from f5_tts.model.utils import get_tokenizer

-
 accelerator = Accelerator()
 device = f"cuda:{accelerator.process_index}"

@@ -67,7 +63,9 @@ def main():
     use_truth_duration = False
     no_ref_audio = False

-    model_cfg = OmegaConf.load(
+    model_cfg = OmegaConf.load(
+        str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))
+    )
     model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
     model_arc = model_cfg.model.arch

@@ -83,8 +81,12 @@ def main():

     if testset == "ls_pc_test_clean":
         metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
-        librispeech_test_clean_path =
+        librispeech_test_clean_path = (
+            "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
+        )
+        metainfo = get_librispeech_test_clean_metainfo(
+            metalst, librispeech_test_clean_path
+        )

     elif testset == "seedtts_test_zh":
         metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
@@ -126,14 +128,18 @@ def main():
         vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
     elif mel_spec_type == "bigvgan":
         vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-    vocoder = load_vocoder(
+    vocoder = load_vocoder(
+        vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path
+    )

     # Tokenizer
     vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

     # Model
     model = CFM(
-        transformer=model_cls(
+        transformer=model_cls(
+            **model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels
+        ),
         mel_spec_kwargs=dict(
             n_fft=n_fft,
             hop_length=hop_length,
@@ -154,7 +160,9 @@ def main():
     elif os.path.exists(ckpt_prefix + ".safetensors"):
         ckpt_path = ckpt_prefix + ".safetensors"
     else:
-        print(
+        print(
+            "Loading from self-organized training checkpoints rather than released pretrained."
+        )
         ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"

     dtype = torch.float32 if mel_spec_type == "bigvgan" else None
@@ -169,7 +177,14 @@ def main():

     with accelerator.split_between_processes(prompts_all) as prompts:
         for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
-
+            (
+                utts,
+                ref_rms_list,
+                ref_mels,
+                ref_mel_lens,
+                total_mel_lens,
+                final_text_list,
+            ) = prompt
             ref_mels = ref_mels.to(device)
             ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
             total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
@@ -198,7 +213,11 @@ def main():

             if ref_rms_list[i] < target_rms:
                 generated_wave = generated_wave * ref_rms_list[i] / target_rms
-            torchaudio.save(
+            torchaudio.save(
+                f"{output_dir}/{utts[i]}.wav",
+                generated_wave,
+                target_sample_rate,
+            )

     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
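The block above rescales each generated waveform back to the loudness of its reference prompt before saving. A small sketch of that step on its own (the 0.1 default for `target_rms` is an assumption, not taken from this diff):

# Sketch of the RMS matching applied before torchaudio.save above.
import torch

def match_reference_rms(generated_wave: torch.Tensor, ref_rms: float,
                        target_rms: float = 0.1) -> torch.Tensor:
    # If the reference prompt was quieter than the normalization target,
    # scale the generated audio back down to the reference level.
    if ref_rms < target_rms:
        generated_wave = generated_wave * ref_rms / target_rms
    return generated_wave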
f5_tts/eval/eval_librispeech_test_clean.py
CHANGED
@@ -5,7 +5,6 @@ import json
 import os
 import sys

-
 sys.path.append(os.getcwd())

 import multiprocessing as mp
@@ -15,18 +14,23 @@ import numpy as np

 from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim

-
 rel_path = str(files("f5_tts").joinpath("../../"))


 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]
+    )
     parser.add_argument("-l", "--lang", type=str, default="en")
     parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
     parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
+    parser.add_argument(
+        "-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use"
+    )
+    parser.add_argument(
+        "--local", action="store_true", help="Use local custom checkpoint directory"
+    )
     return parser.parse_args()


@@ -39,7 +43,9 @@ def main():
     metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"

     gpus = list(range(args.gpu_nums))
+    test_set = get_librispeech_test(
+        metalst, gen_wav_dir, gpus, librispeech_test_clean_path
+    )

     ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
     ## leading to a low similarity for the ground truth in some cases.
@@ -59,13 +65,19 @@ def main():

     if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
+            args = [
+                (rank, lang, sub_test_set, asr_ckpt_dir)
+                for (rank, sub_test_set) in test_set
+            ]
             results = pool.map(run_asr_wer, args)
             for r in results:
                 full_results.extend(r)
     elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
+            args = [
+                (rank, sub_test_set, wavlm_ckpt_dir)
+                for (rank, sub_test_set) in test_set
+            ]
             results = pool.map(run_sim, args)
             for r in results:
                 full_results.extend(r)
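
Note: the WER/SIM evaluation above shards the test set by GPU rank and hands one (rank, shard, checkpoint) tuple per worker to mp.Pool.map. A minimal standalone sketch of that dispatch pattern, assuming a toy score_shard worker that is not part of this repo:

import multiprocessing as mp

def score_shard(args):
    # one worker per GPU rank; `shard` is a list of (gen_wav, ref_wav, text) items
    rank, shard = args
    # a real worker would pin itself to f"cuda:{rank}" and run ASR or speaker-sim there
    return [len(text) for _, _, text in shard]  # placeholder per-item scores

if __name__ == "__main__":
    gpus = [0, 1]
    test_set = [(rank, [("gen.wav", "ref.wav", "hello world")]) for rank in gpus]
    with mp.Pool(processes=len(gpus)) as pool:
        results = pool.map(score_shard, test_set)
    full_results = [r for shard_scores in results for r in shard_scores]
    print(full_results)

The scripts above follow the same shape: build the argument tuples, map them over the pool, then flatten each worker's per-utterance results into full_results.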
f5_tts/eval/eval_seedtts_testset.py
CHANGED
@@ -5,7 +5,6 @@ import json
 import os
 import sys

-
 sys.path.append(os.getcwd())

 import multiprocessing as mp
@@ -15,17 +14,22 @@ import numpy as np

 from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim

-
 rel_path = str(files("f5_tts").joinpath("../../"))


 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]
+    )
     parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
     parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
+    parser.add_argument(
+        "-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use"
+    )
+    parser.add_argument(
+        "--local", action="store_true", help="Use local custom checkpoint directory"
+    )
     return parser.parse_args()


@@ -58,13 +62,19 @@ def main():

     if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
+            args = [
+                (rank, lang, sub_test_set, asr_ckpt_dir)
+                for (rank, sub_test_set) in test_set
+            ]
             results = pool.map(run_asr_wer, args)
             for r in results:
                 full_results.extend(r)
     elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
+            args = [
+                (rank, sub_test_set, wavlm_ckpt_dir)
+                for (rank, sub_test_set) in test_set
+            ]
             results = pool.map(run_sim, args)
             for r in results:
                 full_results.extend(r)
f5_tts/eval/eval_utmos.py
CHANGED
@@ -13,9 +13,15 @@ def main():
     parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
     args = parser.parse_args()

+    device = (
+        "cuda"
+        if torch.cuda.is_available()
+        else "xpu" if torch.xpu.is_available() else "cpu"
+    )
+
+    predictor = torch.hub.load(
+        "tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True
+    )
     predictor = predictor.to(device)

     audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
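
Note: the predictor loaded above is the tarepan/SpeechMOS torch.hub model. A minimal usage sketch, assuming a placeholder audio path and network access for the first hub download; the call signature predictor(waveform, sample_rate) mirrors how this script scores each file:

import torch
import librosa

device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = torch.hub.load(
    "tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True
).to(device)

wave, sr = librosa.load("sample.wav", sr=None, mono=True)  # placeholder path
with torch.no_grad():
    # batch dimension first: shape (1, num_samples)
    score = predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr)
print(f"UTMOS: {score.item():.3f}")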
f5_tts/eval/utils_eval.py
CHANGED
@@ -43,11 +43,15 @@ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):

         # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
         ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
+        ref_wav = os.path.join(
+            librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac"
+        )

         # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
         gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
+        gen_wav = os.path.join(
+            librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac"
+        )

         metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))

@@ -106,13 +110,17 @@ def get_inference_prompt(
         mel_spec_type=mel_spec_type,
     )

+    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
+        metainfo, desc="Processing prompts..."
+    ):
         # Audio
         ref_audio, ref_sr = torchaudio.load(prompt_wav)
         ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
         if ref_rms < target_rms:
             ref_audio = ref_audio * target_rms / ref_rms
+        assert (
+            ref_audio.shape[-1] > 5000
+        ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
         if ref_sr != target_sample_rate:
             resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
             ref_audio = resampler(ref_audio)
@@ -145,14 +153,18 @@ def get_inference_prompt(
         else:
             ref_text_len = len(prompt_text.encode("utf-8"))
             gen_text_len = len(gt_text.encode("utf-8"))
+            total_mel_len = ref_mel_len + int(
+                ref_mel_len / ref_text_len * gen_text_len / speed
+            )

         # deal with batch
         assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
+        assert (
+            min_tokens <= total_mel_len <= max_tokens
+        ), f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
+        bucket_i = math.floor(
+            (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
+        )

         utts[bucket_i].append(utt)
         ref_rms_list[bucket_i].append(ref_rms)
@@ -183,7 +195,14 @@ def get_inference_prompt(
                 ref_mel_lens[bucket_i],
                 total_mel_lens[bucket_i],
                 final_text_list[bucket_i],
+            ) = (
+                [],
+                [],
+                [],
+                [],
+                [],
+                [],
+            )

     # add residual
     for bucket_i, bucket_frames in enumerate(batch_accum):
@@ -244,7 +263,9 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
 # get librispeech test-clean cross sentence test


+def get_librispeech_test(
+    metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False
+):
     f = open(metalst)
     lines = f.readlines()
     f.close()
@@ -255,14 +276,21 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path

         if eval_ground_truth:
             gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
+            gen_wav = os.path.join(
+                librispeech_test_clean_path,
+                gen_spk_id,
+                gen_chaptr_id,
+                gen_utt + ".flac",
+            )
         else:
             if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
                 raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
             gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")

         ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
+        ref_wav = os.path.join(
+            librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac"
+        )

         test_set_.append((gen_wav, ref_wav, gen_txt))

@@ -382,7 +410,9 @@ def run_sim(args):
     device = f"cuda:{rank}"

     model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
+    state_dict = torch.load(
+        ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage
+    )
     model.load_state_dict(state_dict["model"], strict=False)

     use_gpu = True if torch.cuda.is_available() else False
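
Note on get_inference_prompt above: the target mel length is extrapolated from the prompt's frames-per-UTF-8-byte rate, then mapped to one of num_buckets equal-width bins between min_tokens and max_tokens. A worked sketch of just that arithmetic, with illustrative values rather than the repo's defaults:

import math

# illustrative: a ~3 s prompt, text lengths counted in UTF-8 bytes
ref_mel_len, ref_text_len, gen_text_len, speed = 282, 40, 80, 1.0
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)  # 282 + 564 = 846

min_secs, max_secs, num_buckets = 0, 30, 200        # assumed bucketing bounds
target_sample_rate, hop_length = 24000, 256          # assumed mel settings
min_tokens = min_secs * target_sample_rate // hop_length   # 0
max_tokens = max_secs * target_sample_rate // hop_length   # 2812
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
print(total_mel_len, bucket_i)  # 846 60

Prompts in the same bucket have similar total lengths, so they can be batched with little padding before being split across processes.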
f5_tts/infer/infer_cli.py
CHANGED
@@ -14,23 +14,12 @@ from hydra.utils import get_class
 from omegaconf import OmegaConf
 from unidecode import unidecode

+from f5_tts.infer.utils_infer import (cfg_strength, cross_fade_duration,
+                                      device, fix_duration, infer_process,
+                                      load_model, load_vocoder, mel_spec_type,
+                                      nfe_step, preprocess_ref_audio_text,
+                                      remove_silence_for_generated_wav, speed,
+                                      sway_sampling_coef, target_rms)

 parser = argparse.ArgumentParser(
     prog="python3 infer-cli.py",
@@ -41,7 +30,9 @@ parser.add_argument(
     "-c",
     "--config",
     type=str,
+    default=os.path.join(
+        files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"
+    ),
     help="The configuration file, default see infer/examples/basic/basic.toml",
 )

@@ -188,13 +179,17 @@ model = args.model or config.get("model", "F5TTS_v1_Base")
 ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
 vocab_file = args.vocab_file or config.get("vocab_file", "")

+ref_audio = args.ref_audio or config.get(
+    "ref_audio", "infer/examples/basic/basic_ref_en.wav"
+)
 ref_text = (
     args.ref_text
     if args.ref_text is not None
     else config.get("ref_text", "Some call me nature, others call me mother nature.")
 )
+gen_text = args.gen_text or config.get(
+    "gen_text", "Here we generate something just for test."
+)
 gen_file = args.gen_file or config.get("gen_file", "")

 output_dir = args.output_dir or config.get("output_dir", "tests")
@@ -203,21 +198,29 @@ output_file = args.output_file or config.get(
 )

 save_chunk = args.save_chunk or config.get("save_chunk", False)
+use_legacy_text = args.no_legacy_text or config.get(
+    "no_legacy_text", False
+)  # no_legacy_text is a store_false arg
 if save_chunk and use_legacy_text:
     print(
         "\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
     )

 remove_silence = args.remove_silence or config.get("remove_silence", False)
+load_vocoder_from_local = args.load_vocoder_from_local or config.get(
+    "load_vocoder_from_local", False
+)

 vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
 target_rms = args.target_rms or config.get("target_rms", target_rms)
+cross_fade_duration = args.cross_fade_duration or config.get(
+    "cross_fade_duration", cross_fade_duration
+)
 nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
 cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
+sway_sampling_coef = args.sway_sampling_coef or config.get(
+    "sway_sampling_coef", sway_sampling_coef
+)
 speed = args.speed or config.get("speed", speed)
 fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
 device = args.device or config.get("device", device)
@@ -232,7 +235,9 @@ if "voices" in config:
     for voice in config["voices"]:
         voice_ref_audio = config["voices"][voice]["ref_audio"]
         if "infer/examples/" in voice_ref_audio:
+            config["voices"][voice]["ref_audio"] = str(
+                files("f5_tts").joinpath(f"{voice_ref_audio}")
+            )


 # ignore gen_text if gen_file provided
@@ -259,14 +264,18 @@ elif vocoder_name == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

 vocoder = load_vocoder(
+    vocoder_name=vocoder_name,
+    is_local=load_vocoder_from_local,
+    local_path=vocoder_local_path,
+    device=device,
 )


 # load TTS model

 model_cfg = OmegaConf.load(
+    args.model_cfg
+    or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
 )
 model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
 model_arc = model_cfg.model.arch
@@ -288,11 +297,18 @@ elif model == "E2TTS_Base":
     ckpt_step = 1200000

 if not ckpt_file:
+    ckpt_file = str(
+        cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")
+    )

 print(f"Using {model}...")
 ema_model = load_model(
+    model_cls,
+    model_arc,
+    ckpt_file,
+    mel_spec_type=vocoder_name,
+    vocab_file=vocab_file,
+    device=device,
 )

@@ -309,8 +325,10 @@ def main():
     for voice in voices:
         print("Voice:", voice)
         print("ref_audio ", voices[voice]["ref_audio"])
+        voices[voice]["ref_audio"], voices[voice]["ref_text"] = (
+            preprocess_ref_audio_text(
+                voices[voice]["ref_audio"], voices[voice]["ref_text"]
+            )
         )
         print("ref_audio_", voices[voice]["ref_audio"], "\n\n")

@@ -360,7 +378,10 @@ def main():
                 if use_legacy_text:
                     gen_text_ = unidecode(gen_text_)
                 sf.write(
+                    os.path.join(
+                        output_chunk_dir,
+                        f"{len(generated_audio_segments) - 1}_{gen_text_}.wav",
+                    ),
                     audio_segment,
                     final_sample_rate,
                 )
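
Note on the settings block above: every option is resolved as `args.x or config.get("x", default)`, so a command-line flag wins over the TOML config, and the config wins over the library default. A minimal standalone sketch of that precedence, with a dict standing in for the parsed TOML and illustrative values:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--nfe_step", type=int, default=None)  # None means "not given on the CLI"
args = parser.parse_args([])                               # simulate: no CLI override

config = {"nfe_step": 16}   # stands in for the parsed config file
nfe_step_default = 32       # stands in for the library default (illustrative)

# CLI flag wins over config, config wins over default
nfe_step = args.nfe_step or config.get("nfe_step", nfe_step_default)
print(nfe_step)  # 16

One caveat of the `or` idiom, visible in the script as well: a falsy CLI value such as 0 or 0.0 is treated as "not given" and falls back to the config value.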
f5_tts/infer/infer_gradio.py
CHANGED
@@ -19,7 +19,6 @@ import torchaudio
 from cached_path import cached_path
 from transformers import AutoModelForCausalLM, AutoTokenizer

-
 try:
     import spaces

@@ -35,25 +34,21 @@ def gpu_decorator(func):
     return func


+from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
+                                      preprocess_ref_audio_text,
+                                      remove_silence_for_generated_wav,
+                                      save_spectrogram, tempfile_kwargs)
 from f5_tts.model import DiT, UNetT

 DEFAULT_TTS_MODEL = "F5-TTS_v1"
 tts_model_choice = DEFAULT_TTS_MODEL

 DEFAULT_TTS_MODEL_CFG = [
     "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
     "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
+    json.dumps(
+        dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+    ),
 ]

@@ -69,8 +64,12 @@ def load_f5tts():


 def load_e2tts():
+    ckpt_path = str(
+        cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")
+    )
+    E2TTS_model_cfg = dict(
+        dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1
+    )
     return load_model(UNetT, E2TTS_model_cfg, ckpt_path)

@@ -113,7 +112,8 @@ def chat_model_inference(messages, model, tokenizer):
     )

     generated_ids = [
+        output_ids[len(input_ids) :]
+        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
     ]
     return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

@@ -157,7 +157,9 @@ def infer(
         gr.Warning("Please enter text to generate or upload a text file.")
         return gr.update(), gr.update(), ref_text

+    ref_audio, ref_text = preprocess_ref_audio_text(
+        ref_audio_orig, ref_text, show_info=show_info
+    )

     if model == DEFAULT_TTS_MODEL:
         ema_model = F5TTS_ema_model
@@ -172,7 +174,9 @@ def infer(
         global custom_ema_model, pre_custom_path
         if pre_custom_path != model[1]:
             show_info("Loading Custom TTS model...")
+            custom_ema_model = load_custom(
+                model[1], vocab_path=model[2], model_cfg=model[3]
+            )
             pre_custom_path = model[1]
         ema_model = custom_ema_model

@@ -202,7 +206,9 @@ def infer(
     final_wave = final_wave.squeeze().cpu().numpy()

     # Save the spectrogram
+    with tempfile.NamedTemporaryFile(
+        suffix=".png", **tempfile_kwargs
+    ) as tmp_spectrogram:
         spectrogram_path = tmp_spectrogram.name
     save_spectrogram(combined_spectrogram, spectrogram_path)

@@ -219,7 +225,9 @@ with gr.Blocks() as app_tts:
             max_lines=40,
             scale=4,
         )
+        gen_text_file = gr.File(
+            label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1
+        )
     generate_btn = gr.Button("Synthesize", variant="primary")
     with gr.Accordion("Advanced Settings", open=False):
         with gr.Row():
@@ -229,7 +237,11 @@ with gr.Blocks() as app_tts:
                 lines=2,
                 scale=4,
             )
+            ref_text_file = gr.File(
+                label="Load Reference Text from File (.txt)",
+                file_types=[".txt"],
+                scale=1,
+            )
         with gr.Row():
             randomize_seed = gr.Checkbox(
                 label="Randomize Seed",
@@ -417,13 +429,25 @@ with gr.Blocks() as app_multistyle:
             regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
             with gr.Row():
                 regular_seed_slider = gr.Slider(
+                    show_label=False,
+                    minimum=-1,
+                    maximum=999,
+                    value=-1,
+                    step=1,
+                    info="Seed, -1 for random",
                 )
                 regular_speed_slider = gr.Slider(
+                    show_label=False,
+                    minimum=0.3,
+                    maximum=2.0,
+                    value=1.0,
+                    step=0.1,
+                    info="Adjust the speed",
                 )
         with gr.Column(scale=1, min_width=160):
+            regular_ref_text_file = gr.File(
+                label="Load Reference Text from File (.txt)", file_types=[".txt"]
+            )

     # Regular speech type (max 100)
     max_speech_types = 100
@@ -450,13 +474,25 @@ with gr.Blocks() as app_multistyle:
                 ref_text_input = gr.Textbox(label="Reference Text", lines=4)
                 with gr.Row():
                     seed_input = gr.Slider(
+                        show_label=False,
+                        minimum=-1,
+                        maximum=999,
+                        value=-1,
+                        step=1,
+                        info="Seed. -1 for random",
                     )
                     speed_input = gr.Slider(
+                        show_label=False,
+                        minimum=0.3,
+                        maximum=2.0,
+                        value=1.0,
+                        step=0.1,
+                        info="Adjust the speed",
                     )
             with gr.Column(scale=1, min_width=160):
+                ref_text_file_input = gr.File(
+                    label="Load Reference Text from File (.txt)", file_types=[".txt"]
+                )
         speech_type_rows.append(row)
         speech_type_names.append(name_input)
         speech_type_audios.append(audio_input)
@@ -494,7 +530,9 @@ with gr.Blocks() as app_multistyle:
                 row_updates[speech_type_count] = gr.update(visible=True)
                 speech_type_count += 1
             else:
+                gr.Warning(
+                    "Exhausted maximum number of speech types. Consider restart the app."
+                )
         return row_updates

     add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)
@@ -525,10 +563,14 @@ with gr.Blocks() as app_multistyle:
            scale=4,
            placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
        )
+        gen_text_file_multistyle = gr.File(
+            label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1
+        )

     def make_insert_speech_type_fn(index):
+        def insert_speech_type_fn(
+            current_text, speech_type_name, speech_type_seed, speech_type_speed
+        ):
             current_text = current_text or ""
             if not speech_type_name:
                 gr.Warning("Please enter speech type name before insert.")
@@ -547,7 +589,12 @@ with gr.Blocks() as app_multistyle:
         insert_fn = make_insert_speech_type_fn(i)
         insert_btn.click(
             insert_fn,
+            inputs=[
+                gen_text_input_multistyle,
+                speech_type_names[i],
+                speech_type_seeds[i],
+                speech_type_speeds[i],
+            ],
             outputs=gen_text_input_multistyle,
         )

@@ -567,7 +614,9 @@ with gr.Blocks() as app_multistyle:
     )

     # Generate button
+    generate_multistyle_btn = gr.Button(
+        "Generate Multi-Style Speech", variant="primary"
+    )

     # Output audio
     audio_output_multistyle = gr.Audio(label="Synthesized Audio")
@@ -613,7 +662,10 @@ with gr.Blocks() as app_multistyle:
             speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
         ):
             if name_input and audio_input:
+                speech_types[name_input] = {
+                    "audio": audio_input,
+                    "ref_text": ref_text_input,
+                }
             else:
                 speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
             ref_text_idx += 1
@@ -635,14 +687,22 @@ with gr.Blocks() as app_multistyle:
             if name in speech_types:
                 current_type_name = name
             else:
+                gr.Warning(
+                    f"Type {name} is not available, will use Regular as default."
+                )
                 current_type_name = "Regular"

             try:
                 ref_audio = speech_types[current_type_name]["audio"]
             except KeyError:
+                gr.Warning(
+                    f"Please provide reference audio for type {current_type_name}."
+                )
+                return (
+                    [None]
+                    + [speech_types[name]["ref_text"] for name in speech_types]
+                    + [None]
+                )
             ref_text = speech_types[current_type_name].get("ref_text", "")

             if seed_input == -1:
@@ -664,7 +724,9 @@ with gr.Blocks() as app_multistyle:

             generated_audio_segments.append(audio_data)
             speech_types[current_type_name]["ref_text"] = ref_text_out
+            inference_meta_data += (
+                json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
+            )

         # Concatenate all audio segments
         if generated_audio_segments:
@@ -676,7 +738,11 @@ with gr.Blocks() as app_multistyle:
             )
         else:
             gr.Warning("No audio generated.")
+            return (
+                [None]
+                + [speech_types[name]["ref_text"] for name in speech_types]
+                + [None]
+            )

     generate_multistyle_btn.click(
         generate_multistyle_speech,
@@ -689,7 +755,9 @@ with gr.Blocks() as app_multistyle:
         + [
             remove_silence_multistyle,
         ],
+        outputs=[audio_output_multistyle]
+        + speech_type_ref_texts
+        + [cherrypick_interface_multistyle],
     )

     # Validation function to disable Generate button if speech types are missing
@@ -753,7 +821,9 @@ Have a conversation with an AI using your reference voice!
             torch.cuda.empty_cache()

         show_info(f"Loading chat model: {chat_model_name}")
+        chat_model_state = AutoModelForCausalLM.from_pretrained(
+            chat_model_name, torch_dtype="auto", device_map="auto"
+        )
         chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
         show_info(f"Chat model {chat_model_name} loaded successfully!")

@@ -769,7 +839,9 @@ Have a conversation with an AI using your reference voice!
             info="Enter the name of a HuggingFace chat model",
             allow_custom_value=not USING_SPACES,
         )
+        load_chat_model_btn = gr.Button(
+            "Load Chat Model", variant="primary", visible=not USING_SPACES
+        )
     chat_interface_container = gr.Column(visible=USING_SPACES)

     chat_model_name_input.change(
@@ -779,7 +851,9 @@ Have a conversation with an AI using your reference voice!
         show_progress="hidden",
     )
     load_chat_model_btn.click(
+        load_chat_model,
+        inputs=[chat_model_name_input],
+        outputs=[load_chat_model_btn, chat_interface_container],
     )

     with chat_interface_container:
@@ -796,7 +870,9 @@ Have a conversation with an AI using your reference voice!
                     scale=3,
                 )
                 ref_text_file_chat = gr.File(
+                    label="Load Reference Text from File (.txt)",
+                    file_types=[".txt"],
+                    scale=1,
                 )
             with gr.Row():
                 randomize_seed_chat = gr.Checkbox(
@@ -805,7 +881,9 @@ Have a conversation with an AI using your reference voice!
                     info="Uncheck to use the seed specified.",
                     scale=3,
                 )
+                seed_input_chat = gr.Number(
+                    show_label=False, value=0, precision=0, scale=1
+                )
                 remove_silence_chat = gr.Checkbox(
                     label="Remove Silences",
                     value=True,
@@ -855,13 +933,17 @@ Have a conversation with an AI using your reference voice!
             """Generate text response from AI"""

             system_prompt_state = [{"role": "system", "content": system_prompt}]
+            response = chat_model_inference(
+                system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state
+            )

             conv_state.append({"role": "assistant", "content": response})
             return conv_state

         @gpu_decorator
+        def generate_audio_response(
+            conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input
+        ):
             """Generate TTS audio for AI response"""
             if not conv_state or not ref_audio:
                 return None, ref_text, seed_input
@@ -896,7 +978,11 @@ Have a conversation with an AI using your reference voice!
             outputs=[ref_text_chat],
         )

+        for user_operation in [
+            audio_input_chat.stop_recording,
+            text_input_chat.submit,
+            send_btn_chat.click,
+        ]:
             user_operation(
                 process_audio_input,
                 inputs=[chatbot_interface, audio_input_chat, text_input_chat],
@@ -923,7 +1009,11 @@ Have a conversation with an AI using your reference voice!
         )

         # Handle clear button or system prompt change and reset conversation
+        for user_operation in [
+            clear_btn_chat.click,
+            system_prompt_chat.change,
+            chatbot_interface.clear,
+        ]:
             user_operation(
                 clear_conversation,
                 outputs=[chatbot_interface, audio_output_chat],
@@ -931,13 +1021,15 @@ Have a conversation with an AI using your reference voice!


 with gr.Blocks() as app_credits:
+    gr.Markdown(
+        """
 # Credits

 * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
 * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
 * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
+    """
+    )


 with gr.Blocks() as app:
@@ -958,7 +1050,9 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
     """
     )

+    last_used_custom = files("f5_tts").joinpath(
+        "infer/.cache/last_used_custom_model_info_v1.txt"
+    )

     def load_last_used_custom():
         try:
@@ -974,8 +1068,15 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
     def switch_tts_model(new_choice):
         global tts_model_choice
         if new_choice == "Custom":  # override in case webpage is refreshed
+            custom_ckpt_path, custom_vocab_path, custom_model_cfg = (
+                load_last_used_custom()
+            )
+            tts_model_choice = (
+                "Custom",
+                custom_ckpt_path,
+                custom_vocab_path,
+                custom_model_cfg,
+            )
             return (
                 gr.update(visible=True, value=custom_ckpt_path),
                 gr.update(visible=True, value=custom_vocab_path),
@@ -983,22 +1084,42 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
             )
         else:
             tts_model_choice = new_choice
+            return (
+                gr.update(visible=False),
+                gr.update(visible=False),
+                gr.update(visible=False),
+            )

     def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
         global tts_model_choice
+        tts_model_choice = (
+            "Custom",
+            custom_ckpt_path,
+            custom_vocab_path,
+            custom_model_cfg,
+        )
         with open(last_used_custom, "w", encoding="utf-8") as f:
+            f.write(
+                custom_ckpt_path
+                + "\n"
+                + custom_vocab_path
+                + "\n"
+                + custom_model_cfg
+                + "\n"
+            )

     with gr.Row():
         if not USING_SPACES:
             choose_tts_model = gr.Radio(
+                choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"],
+                label="Choose TTS Model",
+                value=DEFAULT_TTS_MODEL,
             )
         else:
             choose_tts_model = gr.Radio(
+                choices=[DEFAULT_TTS_MODEL, "E2-TTS"],
+                label="Choose TTS Model",
+                value=DEFAULT_TTS_MODEL,
             )
         custom_ckpt_path = gr.Dropdown(
             choices=[DEFAULT_TTS_MODEL_CFG[0]],
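
Note on the chat tab above: the same callback is bound to several Gradio triggers by iterating over the event registrars (stop_recording, submit, click). A minimal standalone sketch of that pattern, using a toy echo handler that is not part of this repo:

import gradio as gr

def echo(message):
    return message

with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Input")
    send_btn = gr.Button("Send")
    text_out = gr.Textbox(label="Output")

    # bind one handler to several trigger events, as the app does for the chat flow
    for trigger in [text_in.submit, send_btn.click]:
        trigger(echo, inputs=[text_in], outputs=[text_out])

# demo.launch()  # uncomment to run locally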
f5_tts/infer/speech_edit.py
CHANGED
@@ -1,6 +1,5 @@
 import os

-
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility

 from importlib.resources import files
@@ -12,19 +11,19 @@ from cached_path import cached_path
 from hydra.utils import get_class
 from omegaconf import OmegaConf

+from f5_tts.infer.utils_infer import (load_checkpoint, load_vocoder,
+                                      save_spectrogram)
 from f5_tts.model import CFM
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer

 device = (
     "cuda"
     if torch.cuda.is_available()
+    else (
+        "xpu"
+        if torch.xpu.is_available()
+        else "mps" if torch.backends.mps.is_available() else "cpu"
+    )
 )

@@ -59,7 +58,9 @@ n_fft = model_cfg.model.mel_spec.n_fft

 # ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
+ckpt_path = str(
+    cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors")
+)
 output_dir = "tests"

@@ -103,14 +104,18 @@ if mel_spec_type == "vocos":
     vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
 elif mel_spec_type == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
+vocoder = load_vocoder(
+    vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path
+)

 # Tokenizer
 vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

 # Model
 model = CFM(
+    transformer=model_cls(
+        **model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels
+    ),
     mel_spec_kwargs=dict(
         n_fft=n_fft,
         hop_length=hop_length,
@@ -146,7 +151,14 @@ for part in parts_to_edit:
     part_dur = end - start if fix_duration is None else fix_duration.pop(0)
     part_dur = part_dur * target_sample_rate
     start = start * target_sample_rate
+    audio_ = torch.cat(
+        (
+            audio_,
+            audio[:, round(offset) : round(start)],
+            torch.zeros(1, round(part_dur)),
+        ),
+        dim=-1,
+    )
     edit_mask = torch.cat(
         (
             edit_mask,
@@ -157,7 +169,9 @@ for part in parts_to_edit:
     )
     offset = end * target_sample_rate
 audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
+edit_mask = F.pad(
+    edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True
+)
 audio = audio.to(device)
 edit_mask = edit_mask.to(device)

@@ -201,5 +215,7 @@ with torch.inference_mode():
     generated_wave = generated_wave * rms / target_rms

 save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
+torchaudio.save(
+    f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate
+)
 print(f"Generated wav: {generated_wave.shape}")
f5_tts/infer/utils_infer.py
CHANGED
@@ -4,9 +4,10 @@ import os
 import sys
 from concurrent.futures import ThreadPoolExecutor

-
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
-sys.path.append(

 import hashlib
 import re
@@ -15,7 +16,6 @@ from importlib.resources import files

 import matplotlib

-
 matplotlib.use("Agg")

 import matplotlib.pylab as plt
@@ -31,21 +31,22 @@ from vocos import Vocos
 from f5_tts.model import CFM
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer

-
 _ref_audio_cache = {}
 _ref_text_cache = {}

 device = (
     "cuda"
     if torch.cuda.is_available()
-    else
 )

-tempfile_kwargs =

 # -----------------------------------------

@@ -87,12 +88,23 @@ def chunk_text(text, max_chars=135):
     sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)

     for sentence in sentences:
-        if
-            current_chunk
         else:
             if current_chunk:
                 chunks.append(current_chunk.strip())
-            current_chunk =

     if current_chunk:
         chunks.append(current_chunk.strip())
@@ -101,7 +113,13 @@ def chunk_text(text, max_chars=135):


 # load vocoder
-def load_vocoder(
     if vocoder_name == "vocos":
         # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
         if is_local:
@@ -111,8 +129,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
         else:
             print("Download Vocos from huggingface charactr/vocos-mel-24khz")
             repo_id = "charactr/vocos-mel-24khz"
-            config_path = hf_hub_download(
             vocoder = Vocos.from_hparams(config_path)
             state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
             from vocos.feature_extractors import EncodecFeatures
@@ -129,13 +151,17 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
         try:
             from third_party.BigVGAN import bigvgan
         except ImportError:
-            print(
         if is_local:
             # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
         else:
             vocoder = bigvgan.BigVGAN.from_pretrained(
-                "nvidia/bigvgan_v2_24khz_100band_256x",
             )

         vocoder.remove_weight_norm()
@@ -177,7 +203,11 @@ def transcribe(ref_audio, language=None):
         ref_audio,
         chunk_length_s=30,
         batch_size=128,
-        generate_kwargs=
         return_timestamps=False,
     )["text"].strip()

@@ -214,7 +244,10 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
         }

         # patch for backward compatibility, 305e3ea
-        for key in [
             if key in checkpoint["model_state_dict"]:
                 del checkpoint["model_state_dict"][key]

@@ -253,7 +286,9 @@ def load_model(
     vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
     model = CFM(
-        transformer=model_cls(
         mel_spec_kwargs=dict(
             n_fft=n_fft,
             hop_length=hop_length,
@@ -276,7 +311,9 @@ def load_model(

 def remove_silence_edges(audio, silence_threshold=-42):
     # Remove silence from the start
-    non_silent_start_idx = silence.detect_leading_silence(
     audio = audio[non_silent_start_idx:]

     # Remove silence from the end
@@ -315,11 +352,18 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):

     # 1. try to find long silence for clipping
     non_silent_segs = silence.split_on_silence(
-        aseg,
     )
     non_silent_wave = AudioSegment.silent(duration=0)
     for non_silent_seg in non_silent_segs:
-        if
             show_info("Audio is over 12s, clipping short. (1)")
             break
         non_silent_wave += non_silent_seg
|
@@ -327,11 +371,18 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
|
327 |
# 2. try to find short silence for clipping if 1. failed
|
328 |
if len(non_silent_wave) > 12000:
|
329 |
non_silent_segs = silence.split_on_silence(
|
330 |
-
aseg,
|
|
|
|
|
|
|
|
|
331 |
)
|
332 |
non_silent_wave = AudioSegment.silent(duration=0)
|
333 |
for non_silent_seg in non_silent_segs:
|
334 |
-
if
|
|
|
|
|
|
|
335 |
show_info("Audio is over 12s, clipping short. (2)")
|
336 |
break
|
337 |
non_silent_wave += non_silent_seg
|
@@ -399,7 +450,12 @@ def infer_process(
|
|
399 |
):
|
400 |
# Split the input text into batches
|
401 |
audio, sr = torchaudio.load(ref_audio)
|
402 |
-
max_chars = int(
|
|
|
|
|
|
|
|
|
|
|
403 |
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
404 |
for i, gen_text in enumerate(gen_text_batches):
|
405 |
print(f"gen_text {i}", gen_text)
|
@@ -483,7 +539,9 @@ def infer_batch_process(
|
|
483 |
# Calculate duration
|
484 |
ref_text_len = len(ref_text.encode("utf-8"))
|
485 |
gen_text_len = len(gen_text.encode("utf-8"))
|
486 |
-
duration = ref_audio_len + int(
|
|
|
|
|
487 |
|
488 |
# inference
|
489 |
with torch.inference_mode():
|
@@ -519,12 +577,19 @@ def infer_batch_process(
|
|
519 |
yield generated_wave, generated_cpu
|
520 |
|
521 |
if streaming:
|
522 |
-
for gen_text in
|
|
|
|
|
|
|
|
|
523 |
for chunk in process_batch(gen_text):
|
524 |
yield chunk
|
525 |
else:
|
526 |
with ThreadPoolExecutor() as executor:
|
527 |
-
futures = [
|
|
|
|
|
|
|
528 |
for future in progress.tqdm(futures) if progress is not None else futures:
|
529 |
result = future.result()
|
530 |
if result:
|
@@ -545,7 +610,9 @@ def infer_batch_process(
|
|
545 |
|
546 |
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
|
547 |
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
|
548 |
-
cross_fade_samples = min(
|
|
|
|
|
549 |
|
550 |
if cross_fade_samples <= 0:
|
551 |
# No overlap possible, concatenate
|
@@ -561,11 +628,17 @@ def infer_batch_process(
|
|
561 |
fade_in = np.linspace(0, 1, cross_fade_samples)
|
562 |
|
563 |
# Cross-faded overlap
|
564 |
-
cross_faded_overlap =
|
|
|
|
|
565 |
|
566 |
# Combine
|
567 |
new_wave = np.concatenate(
|
568 |
-
[
|
|
|
|
|
|
|
|
|
569 |
)
|
570 |
|
571 |
final_wave = new_wave
|
|
|
4 |
import sys
|
5 |
from concurrent.futures import ThreadPoolExecutor
|
6 |
|
|
|
7 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
8 |
+
sys.path.append(
|
9 |
+
f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/"
|
10 |
+
)
|
11 |
|
12 |
import hashlib
|
13 |
import re
|
|
|
16 |
|
17 |
import matplotlib
|
18 |
|
|
|
19 |
matplotlib.use("Agg")
|
20 |
|
21 |
import matplotlib.pylab as plt
|
|
|
31 |
from f5_tts.model import CFM
|
32 |
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
33 |
|
|
|
34 |
_ref_audio_cache = {}
|
35 |
_ref_text_cache = {}
|
36 |
|
37 |
device = (
|
38 |
"cuda"
|
39 |
if torch.cuda.is_available()
|
40 |
+
else (
|
41 |
+
"xpu"
|
42 |
+
if torch.xpu.is_available()
|
43 |
+
else "mps" if torch.backends.mps.is_available() else "cpu"
|
44 |
+
)
|
45 |
)
|
46 |
|
47 |
+
tempfile_kwargs = (
|
48 |
+
{"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
|
49 |
+
)
|
50 |
|
51 |
# -----------------------------------------
|
52 |
|
|
|
88 |
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
89 |
|
90 |
for sentence in sentences:
|
91 |
+
if (
|
92 |
+
len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8"))
|
93 |
+
<= max_chars
|
94 |
+
):
|
95 |
+
current_chunk += (
|
96 |
+
sentence + " "
|
97 |
+
if sentence and len(sentence[-1].encode("utf-8")) == 1
|
98 |
+
else sentence
|
99 |
+
)
|
100 |
else:
|
101 |
if current_chunk:
|
102 |
chunks.append(current_chunk.strip())
|
103 |
+
current_chunk = (
|
104 |
+
sentence + " "
|
105 |
+
if sentence and len(sentence[-1].encode("utf-8")) == 1
|
106 |
+
else sentence
|
107 |
+
)
|
108 |
|
109 |
if current_chunk:
|
110 |
chunks.append(current_chunk.strip())
|
|
|
113 |
|
114 |
|
115 |
# load vocoder
|
116 |
+
def load_vocoder(
|
117 |
+
vocoder_name="vocos",
|
118 |
+
is_local=False,
|
119 |
+
local_path="",
|
120 |
+
device=device,
|
121 |
+
hf_cache_dir=None,
|
122 |
+
):
|
123 |
if vocoder_name == "vocos":
|
124 |
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
|
125 |
if is_local:
|
|
|
129 |
else:
|
130 |
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
131 |
repo_id = "charactr/vocos-mel-24khz"
|
132 |
+
config_path = hf_hub_download(
|
133 |
+
repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml"
|
134 |
+
)
|
135 |
+
model_path = hf_hub_download(
|
136 |
+
repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin"
|
137 |
+
)
|
138 |
vocoder = Vocos.from_hparams(config_path)
|
139 |
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
140 |
from vocos.feature_extractors import EncodecFeatures
|
|
|
151 |
try:
|
152 |
from third_party.BigVGAN import bigvgan
|
153 |
except ImportError:
|
154 |
+
print(
|
155 |
+
"You need to follow the README to init submodule and change the BigVGAN source code."
|
156 |
+
)
|
157 |
if is_local:
|
158 |
# download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
|
159 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
160 |
else:
|
161 |
vocoder = bigvgan.BigVGAN.from_pretrained(
|
162 |
+
"nvidia/bigvgan_v2_24khz_100band_256x",
|
163 |
+
use_cuda_kernel=False,
|
164 |
+
cache_dir=hf_cache_dir,
|
165 |
)
|
166 |
|
167 |
vocoder.remove_weight_norm()
|
|
|
203 |
ref_audio,
|
204 |
chunk_length_s=30,
|
205 |
batch_size=128,
|
206 |
+
generate_kwargs=(
|
207 |
+
{"task": "transcribe", "language": language}
|
208 |
+
if language
|
209 |
+
else {"task": "transcribe"}
|
210 |
+
),
|
211 |
return_timestamps=False,
|
212 |
)["text"].strip()
|
213 |
|
|
|
244 |
}
|
245 |
|
246 |
# patch for backward compatibility, 305e3ea
|
247 |
+
for key in [
|
248 |
+
"mel_spec.mel_stft.mel_scale.fb",
|
249 |
+
"mel_spec.mel_stft.spectrogram.window",
|
250 |
+
]:
|
251 |
if key in checkpoint["model_state_dict"]:
|
252 |
del checkpoint["model_state_dict"][key]
|
253 |
|
|
|
286 |
|
287 |
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
|
288 |
model = CFM(
|
289 |
+
transformer=model_cls(
|
290 |
+
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
291 |
+
),
|
292 |
mel_spec_kwargs=dict(
|
293 |
n_fft=n_fft,
|
294 |
hop_length=hop_length,
|
|
|
311 |
|
312 |
def remove_silence_edges(audio, silence_threshold=-42):
|
313 |
# Remove silence from the start
|
314 |
+
non_silent_start_idx = silence.detect_leading_silence(
|
315 |
+
audio, silence_threshold=silence_threshold
|
316 |
+
)
|
317 |
audio = audio[non_silent_start_idx:]
|
318 |
|
319 |
# Remove silence from the end
|
|
|
352 |
|
353 |
# 1. try to find long silence for clipping
|
354 |
non_silent_segs = silence.split_on_silence(
|
355 |
+
aseg,
|
356 |
+
min_silence_len=1000,
|
357 |
+
silence_thresh=-50,
|
358 |
+
keep_silence=1000,
|
359 |
+
seek_step=10,
|
360 |
)
|
361 |
non_silent_wave = AudioSegment.silent(duration=0)
|
362 |
for non_silent_seg in non_silent_segs:
|
363 |
+
if (
|
364 |
+
len(non_silent_wave) > 6000
|
365 |
+
and len(non_silent_wave + non_silent_seg) > 12000
|
366 |
+
):
|
367 |
show_info("Audio is over 12s, clipping short. (1)")
|
368 |
break
|
369 |
non_silent_wave += non_silent_seg
|
|
|
371 |
# 2. try to find short silence for clipping if 1. failed
|
372 |
if len(non_silent_wave) > 12000:
|
373 |
non_silent_segs = silence.split_on_silence(
|
374 |
+
aseg,
|
375 |
+
min_silence_len=100,
|
376 |
+
silence_thresh=-40,
|
377 |
+
keep_silence=1000,
|
378 |
+
seek_step=10,
|
379 |
)
|
380 |
non_silent_wave = AudioSegment.silent(duration=0)
|
381 |
for non_silent_seg in non_silent_segs:
|
382 |
+
if (
|
383 |
+
len(non_silent_wave) > 6000
|
384 |
+
and len(non_silent_wave + non_silent_seg) > 12000
|
385 |
+
):
|
386 |
show_info("Audio is over 12s, clipping short. (2)")
|
387 |
break
|
388 |
non_silent_wave += non_silent_seg
|
|
|
450 |
):
|
451 |
# Split the input text into batches
|
452 |
audio, sr = torchaudio.load(ref_audio)
|
453 |
+
max_chars = int(
|
454 |
+
len(ref_text.encode("utf-8"))
|
455 |
+
/ (audio.shape[-1] / sr)
|
456 |
+
* (22 - audio.shape[-1] / sr)
|
457 |
+
* speed
|
458 |
+
)
|
459 |
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
460 |
for i, gen_text in enumerate(gen_text_batches):
|
461 |
print(f"gen_text {i}", gen_text)
|
|
|
539 |
# Calculate duration
|
540 |
ref_text_len = len(ref_text.encode("utf-8"))
|
541 |
gen_text_len = len(gen_text.encode("utf-8"))
|
542 |
+
duration = ref_audio_len + int(
|
543 |
+
ref_audio_len / ref_text_len * gen_text_len / local_speed
|
544 |
+
)
|
545 |
|
546 |
# inference
|
547 |
with torch.inference_mode():
|
|
|
577 |
yield generated_wave, generated_cpu
|
578 |
|
579 |
if streaming:
|
580 |
+
for gen_text in (
|
581 |
+
progress.tqdm(gen_text_batches)
|
582 |
+
if progress is not None
|
583 |
+
else gen_text_batches
|
584 |
+
):
|
585 |
for chunk in process_batch(gen_text):
|
586 |
yield chunk
|
587 |
else:
|
588 |
with ThreadPoolExecutor() as executor:
|
589 |
+
futures = [
|
590 |
+
executor.submit(process_batch, gen_text)
|
591 |
+
for gen_text in gen_text_batches
|
592 |
+
]
|
593 |
for future in progress.tqdm(futures) if progress is not None else futures:
|
594 |
result = future.result()
|
595 |
if result:
|
|
|
610 |
|
611 |
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
|
612 |
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
|
613 |
+
cross_fade_samples = min(
|
614 |
+
cross_fade_samples, len(prev_wave), len(next_wave)
|
615 |
+
)
|
616 |
|
617 |
if cross_fade_samples <= 0:
|
618 |
# No overlap possible, concatenate
|
|
|
628 |
fade_in = np.linspace(0, 1, cross_fade_samples)
|
629 |
|
630 |
# Cross-faded overlap
|
631 |
+
cross_faded_overlap = (
|
632 |
+
prev_overlap * fade_out + next_overlap * fade_in
|
633 |
+
)
|
634 |
|
635 |
# Combine
|
636 |
new_wave = np.concatenate(
|
637 |
+
[
|
638 |
+
prev_wave[:-cross_fade_samples],
|
639 |
+
cross_faded_overlap,
|
640 |
+
next_wave[cross_fade_samples:],
|
641 |
+
]
|
642 |
)
|
643 |
|
644 |
final_wave = new_wave
|
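The batch-joining code above linearly cross-fades the overlap between consecutive generated waves. A self-contained sketch of that blend with synthetic stand-in arrays; the symmetric fade_out ramp is assumed from context, since only fade_in appears in the hunk shown here.

import numpy as np

target_sample_rate, cross_fade_duration = 24000, 0.15
prev_wave = np.random.randn(24000).astype(np.float32)
next_wave = np.random.randn(24000).astype(np.float32)

cross_fade_samples = min(int(cross_fade_duration * target_sample_rate), len(prev_wave), len(next_wave))
if cross_fade_samples <= 0:
    new_wave = np.concatenate([prev_wave, next_wave])  # no overlap possible, just concatenate
else:
    prev_overlap = prev_wave[-cross_fade_samples:]
    next_overlap = next_wave[:cross_fade_samples]
    fade_out = np.linspace(1, 0, cross_fade_samples)   # assumed mirror of fade_in
    fade_in = np.linspace(0, 1, cross_fade_samples)
    cross_faded = prev_overlap * fade_out + next_overlap * fade_in
    new_wave = np.concatenate([prev_wave[:-cross_fade_samples], cross_faded, next_wave[cross_fade_samples:]])
print(new_wave.shape)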
f5_tts/model/__init__.py
CHANGED
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.cfm import CFM
from f5_tts.model.trainer import Trainer

__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
f5_tts/model/backbones/dit.py
CHANGED
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model.modules import (AdaLayerNormZero_Final, ConvNeXtV2Block,
                                  ConvPositionEmbedding, DiTBlock,
                                  TimestepEmbedding, get_pos_embed_indices,
                                  precompute_freqs_cis)

# Text embedding


class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer(
                "freqs_cis",
                precompute_freqs_cis(text_dim, self.precompute_max_pos),
                persistent=False,
            )
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

...

        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(
        self,
        x: float["b n d"],
        cond: float["b n d"],
        text_embed: float["b n d"],
        drop_audio_cond=False,
    ):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

...

        if second_time:
            self.time_embed2 = TimestepEmbedding(dim)
            # Zero-init the weights and biases of the first and last Linear layers in time_mlp
            nn.init.zeros_(self.time_embed2.time_mlp[0].weight)   # First Linear layer weights
            nn.init.zeros_(self.time_embed2.time_mlp[0].bias)     # First Linear layer bias
            nn.init.zeros_(self.time_embed2.time_mlp[-1].weight)  # Last Linear layer weights
            nn.init.zeros_(self.time_embed2.time_mlp[-1].bias)    # Last Linear layer bias
        else:
            self.time_embed2 = None

        if text_dim is None:
            text_dim = mel_dim
        self.vocab_size = text_num_embeds
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

...

        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout)
                for _ in range(depth)
            ]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

...

        if second_time is not None and self.time_embed2 is not None:
            t2 = self.time_embed2(second_time)
            t = t + t2

        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

...

        for block in self.transformer_blocks:
            if self.checkpoint_activations:
                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
            else:
                x = block(x, t, mask=mask, rope=rope)
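TextEmbedding above relies on a collation convention: character indices are batch-padded with -1, and the `text + 1` shift turns that padding into filler token 0 while real tokens start at 1, so the embedding table needs one extra row. A minimal sketch of that convention; the toy vocabulary and the assumed behaviour of list_str_to_idx (padding with -1) are illustrative, not taken from the source.

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

vocab = {"h": 0, "e": 1, "l": 2, "o": 3, " ": 4, "w": 5, "r": 6, "d": 7}  # toy vocab
texts = ["hello", "world"]

# pad variable-length index lists with -1, mirroring the described list_str_to_idx() preprocessing
idx = [torch.tensor([vocab[c] for c in t], dtype=torch.long) for t in texts]
text = pad_sequence(idx, padding_value=-1, batch_first=True)

seq_len = 10                      # mel-frame length the text sequence is aligned to
text = text + 1                   # -1 padding -> filler token 0, real ids start at 1
text = text[:, :seq_len]          # curtail if there are more characters than mel frames
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)

embed = torch.nn.Embedding(len(vocab) + 1, 16)  # +1 row for the filler token 0
print(embed(text).shape)          # (batch, seq_len, 16)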
f5_tts/model/backbones/mmdit.py
CHANGED
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model.modules import (AdaLayerNormZero_Final,
                                  ConvPositionEmbedding, DiTBlock, MMDiTBlock,
                                  TimestepEmbedding, get_pos_embed_indices,
                                  precompute_freqs_cis)
from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
                                list_str_to_tensor, mask_from_frac_lengths)

# text embedding


class TextEmbedding(nn.Module):
    def __init__(self, out_dim, text_num_embeds):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token

        self.precompute_max_pos = 1024
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(out_dim, self.precompute_max_pos),
            persistent=False,
        )

    def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:  # noqa: F722
        text = text + 1
        if drop_text:
            text = torch.zeros_like(text)

...

        # sinus pos emb
        batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
        batch_text_len = text.shape[1]
        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
        text_pos_embed = self.freqs_cis[pos_idx]

        text = text + text_pos_embed

...

        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:
            cond = torch.zeros_like(cond)
        x = torch.cat((x, cond), dim=-1)

...

        mel_dim=100,
        checkpoint_activations=False,
        text_encoder=True,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if text_encoder:
            self.text_encoder = TextEncoder(
                text_num_embeds=text_num_embeds,
                text_dim=dim,
                depth=text_depth,
                heads=heads,
                dim_head=dim_head,
                ff_mult=ff_mult,
                dropout=dropout,
            )
        else:
            self.text_encoder = None
            self.text_embed = TextEmbedding(dim, text_num_embeds)

        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

...

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        self.checkpoint_activations = checkpoint_activations

    def forward(
        self,

...

            c = self.text_encoder(text, t, mask=text_mask, drop_text=drop_text)
        else:
            c = self.text_embed(text, drop_text=drop_text)

        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)

        seq_len = x.shape[1]
        text_len = text.shape[1]
        rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
        rope_text = self.rotary_embed.forward_from_seq_len(text_len)

        # if mask is not None:
        #     rope_audio = self.rotary_embed.forward_from_seq_len(seq_len + 1)
        #     (commented-out experiment: append a dummy, always-masked token to x and c,
        #      pad mask / text_mask with a False column, and extend rope by one position)
        # if text_mask is not None:
        #     rope_text = self.rotary_embed.forward_from_seq_len(text_len + 1)

        for block in self.transformer_blocks:
            c, x = block(x, c, t, mask=mask, src_mask=text_mask, rope=rope_audio, c_rope=rope_text)

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output


class TextEncoder(nn.Module):
    def __init__(
        self,

...

        # Embeddings
        self.text_embed = TextEmbedding(text_dim, text_num_embeds)
        self.rotary_embed = RotaryEmbedding(dim_head)

        # Example stack of DiTBlocks or any custom blocks
        self.transformer_blocks = nn.ModuleList(
            [

...

        text: int["b nt"],  # noqa: F821
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        mask: bool["b nt"] | None = None,  # noqa: F821 F722
        drop_text: bool = False,
    ):
        """
        Encode text into hidden states of shape [b, nt, d].

...

        # Basic embedding
        hidden_states = self.text_embed(text, seq_len)  # [b, nt, d]

        # lens and mask
        rope = self.rotary_embed.forward_from_seq_len(seq_len)

...

            # Here, you likely want standard self-attn, so no cross-attn
            hidden_states = block(
                x=hidden_states,
                t=time,     # no time embedding for the text encoder by default
                mask=mask,  # or pass a text mask if needed
                rope=rope,  # pass a rope if you want rotary embeddings for text
            )
        return hidden_states


if __name__ == "__main__":
    from f5_tts.model.utils import get_tokenizer

    bsz = 16

    tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
    tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
    dataset_name = "Emilia_ZH_EN"

...

    else:
        tokenizer_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    text = ["hello world"] * bsz
    text_lens = torch.ones((bsz,), dtype=torch.long) * len("hello world")
    text_lens[-1] = 5
    device = "cuda"
    batch = bsz
    time_embed = TimestepEmbedding(512).to(device)

    # handle text as string
    if isinstance(text, list):
        if exists(vocab_char_map):
            text = list_str_to_idx(text, vocab_char_map).to(device)
        else:
            text = list_str_to_tensor(text).to(device)
        assert text.shape[0] == batch

    time = torch.rand((batch,), device=device)
    text_mask = lens_to_mask(text_lens).to(device)

...

    # ).to('cuda')
    # hidden_states = text_encoder(text, time_embed(time), mask)
    # print(hidden_states.shape)  # [bsz, seq_len, text_dim]

    # test MMDiT
    mel_dim = 80
    model = MMDiT(

...

        dropout=0.1,
        ff_mult=4,
        text_num_embeds=vocab_size,
        mel_dim=mel_dim,
    ).to(device)

    x = torch.rand((batch, 100, mel_dim), device=device)
    cond = torch.rand((batch, 100, mel_dim), device=device)
    lens = torch.ones((batch,), dtype=torch.long) * 100
    mask = lens_to_mask(lens).to(device)

    output = model(
        x,
        cond,
        text,
        time,
        drop_audio_cond=False,
        drop_text=False,
        mask=mask,
        text_mask=text_mask,
    )

    print(output.shape)  # [bsz, seq_len, mel_dim]
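The smoke test above leans on lens_to_mask to turn per-sample lengths into boolean padding masks. A minimal stand-in is shown below under the assumption that the helper follows the usual mask[b, n] = n < lens[b] convention; this is an assumption about its semantics, not the verbatim source of f5_tts.model.utils.

import torch

def lens_to_mask_sketch(lens, length=None):
    # Assumed semantics: True for the first lens[b] positions of row b, False for padding.
    length = int(lens.max()) if length is None else length
    return torch.arange(length, device=lens.device)[None, :] < lens[:, None]

print(lens_to_mask_sketch(torch.tensor([3, 5, 2]), length=6))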
f5_tts/model/backbones/unett.py
CHANGED
from __future__ import annotations

from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model.modules import (Attention, AttnProcessor, ConvNeXtV2Block,
                                  ConvPositionEmbedding, FeedForward,
                                  TimestepEmbedding, get_pos_embed_indices,
                                  precompute_freqs_cis)

# Text embedding


class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer(
                "freqs_cis",
                precompute_freqs_cis(text_dim, self.precompute_max_pos),
                persistent=False,
            )
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)

...

        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

...

        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(
        self,
        x: float["b n d"],
        cond: float["b n d"],
        text_embed: float["b n d"],
        drop_audio_cond=False,
    ):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

...

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

...

            ff_norm = RMSNorm(dim)
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

            skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None

            self.layers.append(
                nn.ModuleList(

...

        # flat unet transformer
        skip_connect_type = self.skip_connect_type
        skips = []
        for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
            layer = idx + 1

            # skip connection logic
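In the flat UNet transformer above, later-half layers may fuse the current states with states popped from the skip stack and project the concatenation back to the model width ("concat" skip type, via skip_proj). A small sketch of that pattern with toy shapes; this is an illustration of the fusion step only, not the full layer loop.

import torch
from torch import nn

dim, batch, seq = 64, 2, 10
skip_proj = nn.Linear(dim * 2, dim, bias=False)  # only built for later-half layers needing a 'concat' skip

x = torch.randn(batch, seq, dim)      # current hidden states
skip = torch.randn(batch, seq, dim)   # states saved by the mirrored earlier layer

x = skip_proj(torch.cat((x, skip), dim=-1))  # fuse and project back to model width
print(x.shape)                               # (batch, seq, dim)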
f5_tts/model/cfm.py
CHANGED
from torchdiffeq import odeint

from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
                                list_str_to_tensor, mask_from_frac_lengths)


class CFM(nn.Module):

...

        # vocab map for tokenization
        self.vocab_char_map = vocab_char_map

        self.scale = scale

    @property

...

        assert cond.shape[-1] == self.num_channels

        cond = cond.to(next(self.parameters()).dtype)

        print(self.scale)

        cond = cond / self.scale

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)

...

        if exists(text):
            text_lens = (text != -1).sum(dim=-1)
            lens = torch.maximum(text_lens, lens)  # make sure lengths are at least those of the text characters

        # duration

...

        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device=device, dtype=torch.long)

        duration = torch.maximum(lens + 1, duration)  # just add one token so something is generated
        duration = duration.clamp(max=max_duration)
        max_duration = duration.amax()

        # duplicate test corner for inner time step oberservation
        if duplicate_test:
            test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)

        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
        if no_ref_audio:
            cond = torch.zeros_like(cond)

        cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
        cond_mask = cond_mask.unsqueeze(-1)
        step_cond = torch.where(
            cond_mask, cond, torch.zeros_like(cond)

...

            # predict flow
            pred = self.transformer(
                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
            )
            if cfg_strength < 1e-5:
                return pred

            null_pred = self.transformer(
                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
            )
            return pred + (pred - null_pred) * cfg_strength

...

        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)
            y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
        y0 = pad_sequence(y0, padding_value=0, batch_first=True)

        t_start = 0

...

            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

...

        out = torch.where(cond_mask, cond, out)

        out = out * self.scale

        if exists(vocoder):
            out = out.permute(0, 2, 1)
            out = vocoder(out)

...

        inp = inp.permute(0, 2, 1)
        assert inp.shape[-1] == self.num_channels

        batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma

        # handle text as string
        if isinstance(text, list):

...

        if not exists(lens):
            lens = torch.full((batch,), seq_len, device=device)

        mask = lens_to_mask(lens, length=seq_len)  # useless here, as collate_fn will pad to max length in batch

        # get a random span to mask out for training conditionally
        frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        if exists(mask):

...

        # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
        pred = self.transformer(
            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
        )

        # flow matching loss
        loss = F.mse_loss(pred, flow, reduction="none")
        loss = loss[rand_span_mask]

        return loss.mean(), cond, pred, t
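The sampling loop above warps the uniform ODE time grid with the sway term t + s * (cos(pi/2 * t) - 1 + t); the endpoints 0 and 1 are preserved, and negative coefficients concentrate steps near t = 0. A standalone look at the effect on the step spacing; the coefficient value here is just an example, not a claim about the project's default.

import torch

steps, sway_sampling_coef = 8, -1.0
t = torch.linspace(0, 1, steps + 1)
t_sway = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
print(t.tolist())
print(t_sway.tolist())  # same endpoints, but a denser grid near t = 0 for negative coefficients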
f5_tts/model/dataset.py
CHANGED
import json
import random
from importlib.resources import files

import torch

...

from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import default


def get_speaker_id(path):
    parts = path.split(
    speaker_id = parts[-3]
    return speaker_id

...

        return_wavform=False,
        remove_starting_space=True,
        need_prompt_speech=False,
        prompt_repository: dict = None,
    ):
        self.data = custom_dataset
        self.durations = durations

...

                mel_spec_type=mel_spec_type,
            ),
        )
        self.validation = validation
        self.validation_num = validation_num

        if (not validation) and data_augmentation:
            print(
            self.augment = Compose([
                ...
                min_amplitude=0.001,
                ...
                p=0
                ...
                ApplyImpulseResponse(ir_path="/data5/Audio", p=1.0),
                Aliasing(min_sample_rate=4000, max_sample_rate=30000, p=0.3),
                BandPassFilter(min_center_freq=100.0, max_center_freq=6000, p=0.2),
                SevenBandParametricEQ(p=0.2),
                TanhDistortion(
                    min_distortion=0.01,
                    max_distortion=0.7,
                    p=0.2
                ),
            ])
        else:
            print(
            self.augment = None

        self.return_wavform = return_wavform

...

                text = row["text"]
                duration = row["duration"]
                spk_id = get_speaker_id(audio_path)
                assert spk_id != None and spk_id !=
                if spk_id not in self.prompt_repository:
                    self.prompt_repository[spk_id] = [row]
                else:

...

            else:
                self.prompt_repository = prompt_repository

            print(
            self.need_prompt_speech = True

        else:
            self.need_prompt_speech = False

    def get_frame_len(self, index):
        if self.validation:
            index += len(self.data) - self.validation_num

...

            index = (index + 1) % len(self.data)

        if self.remove_starting_space:
            while len(text) > 1 and text[0] ==
                text = text[1:]

        if self.preprocessed_mel:
            mel_spec = torch.tensor(row["mel_spec"])
        else:

...

            # resample if necessary
            if source_sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(
                audio = resampler(audio)

            if not self.validation:
                if self.augment != None:
                    audio = self.augment(
                    audio = torch.from_numpy(audio).float().unsqueeze(0)

            # to mel spectrogram
            mel_spec = self.mel_spectrogram(audio)
            mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'

        out[
        out[
        out[
        out[

        if self.return_wavform:
            out[

        if return_path:
            out[

        if return_row:
            out[

        # Sample a prompt speech of the same speaker
        # From prompt_repository

...

            _count = 100
            while True:
                pmt_row = random.choice(spk_repository)
                pmt_audio_path = pmt_row[
                pmt_text = pmt_row[
                pmt_duration = pmt_row[

                if not isinstance(pmt_text, list):
                    pmt_text = list(pmt_text)

...

                if 0.3 <= pmt_duration <= 30 and (0 < len(pmt_text) < 2048):
                    if pmt_text != text:
                        break
                _count =
                if _count <= 0:
                    break

            if self.remove_starting_space:
                while len(pmt_text) > 1 and pmt_text[0] ==
                    pmt_text = pmt_text[1:]

            if self.preprocessed_mel:
                pmt_mel_spec = torch.tensor(pmt_row["mel_spec"])
            else:

...

                # resample if necessary
                if source_sample_rate != self.target_sample_rate:
                    resampler = torchaudio.transforms.Resample(
                    pmt_audio = resampler(pmt_audio)

                if not self.validation:
                    if self.augment != None:
                        pmt_audio = self.augment(
                        pmt_audio = torch.from_numpy(pmt_audio).float().unsqueeze(0)

                # to mel spectrogram
                pmt_mel_spec = self.mel_spectrogram(pmt_audio)
                pmt_mel_spec = pmt_mel_spec.squeeze(0)  # '1 d t -> d t'

            out[
            out[
            out[

            if self.return_wavform:
                out[

            if return_path:
                out[

            if return_row:
                out[

        return out

...

class DynamicBatchSampler(Sampler[list[int]]):
    """

    def __init__(
        self,
    ):
        self.sampler = sampler
        self.frames_threshold = frames_threshold

...

        # indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
|
303 |
# ):
|
304 |
for idx, frame_len in indices:
|
305 |
-
if batch_frames + frame_len <= self.frames_threshold and (
|
|
|
|
|
306 |
batch.append(idx)
|
307 |
batch_frames += frame_len
|
308 |
else:
|
@@ -337,6 +347,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
|
|
337 |
|
338 |
# Load dataset
|
339 |
|
|
|
340 |
def load_dataset(
|
341 |
dataset_name: str,
|
342 |
tokenizer: str = "pinyin",
|
@@ -349,7 +360,7 @@ def load_dataset(
|
|
349 |
return_wavform: bool = False,
|
350 |
remove_starting_space: bool = True,
|
351 |
need_prompt_speech: bool = False,
|
352 |
-
prompt_repository: dict = None
|
353 |
) -> CustomDataset:
|
354 |
"""
|
355 |
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
|
@@ -359,9 +370,13 @@ def load_dataset(
|
|
359 |
print("Loading dataset ...")
|
360 |
|
361 |
if dataset_type == "CustomDataset":
|
362 |
-
rel_data_path = str(
|
363 |
-
|
364 |
-
|
|
|
|
|
|
|
|
|
365 |
if audio_type == "raw":
|
366 |
try:
|
367 |
train_dataset = load_from_disk(f"{rel_data_path}/raw")
|
@@ -385,7 +400,7 @@ def load_dataset(
|
|
385 |
return_wavform=return_wavform,
|
386 |
remove_starting_space=remove_starting_space,
|
387 |
need_prompt_speech=need_prompt_speech,
|
388 |
-
prompt_repository=prompt_repository
|
389 |
)
|
390 |
|
391 |
elif dataset_type == "CustomDatasetPath":
|
@@ -398,7 +413,10 @@ def load_dataset(
|
|
398 |
data_dict = json.load(f)
|
399 |
durations = data_dict["duration"]
|
400 |
train_dataset = CustomDataset(
|
401 |
-
train_dataset,
|
|
|
|
|
|
|
402 |
)
|
403 |
|
404 |
return train_dataset
|
@@ -410,7 +428,7 @@ def collate_fn(batch):
|
|
410 |
mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
|
411 |
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
|
412 |
max_mel_length = mel_lengths.amax()
|
413 |
-
|
414 |
# Pad mel_specs
|
415 |
padded_mel_specs = []
|
416 |
for spec in mel_specs: # TODO. maybe records mask for attention here
|
@@ -419,8 +437,8 @@ def collate_fn(batch):
|
|
419 |
padded_mel_specs.append(padded_spec)
|
420 |
mel_specs = torch.stack(padded_mel_specs)
|
421 |
|
422 |
-
text = [item[
|
423 |
-
target_text = [item[
|
424 |
|
425 |
text_lengths = torch.LongTensor([len(item) for item in text])
|
426 |
|
@@ -432,26 +450,26 @@ def collate_fn(batch):
|
|
432 |
target_text=target_text,
|
433 |
)
|
434 |
|
435 |
-
if
|
436 |
pmt_mel_specs = [item["pmt_mel_spec"].squeeze(0) for item in batch]
|
437 |
pmt_mel_lengths = torch.LongTensor([spec.shape[-1] for spec in pmt_mel_specs])
|
438 |
max_pmt_mel_length = pmt_mel_lengths.amax()
|
439 |
-
|
440 |
# Pad mel_specs
|
441 |
padded_pmt_mel_specs = []
|
442 |
-
for spec in pmt_mel_specs:
|
443 |
padding = (0, max_pmt_mel_length - spec.size(-1))
|
444 |
padded_spec = F.pad(spec, padding, value=0)
|
445 |
padded_pmt_mel_specs.append(padded_spec)
|
446 |
pmt_mel_specs = torch.stack(padded_pmt_mel_specs)
|
447 |
|
448 |
-
out[
|
449 |
|
450 |
-
if
|
451 |
-
pmt_text = [item[
|
452 |
pmt_text_lengths = torch.LongTensor([len(item) for item in pmt_text])
|
453 |
|
454 |
-
out[
|
455 |
-
out[
|
456 |
|
457 |
-
return out
|
|
|
|
|
1 |
import json
|
2 |
import random
|
3 |
+
import re
|
4 |
from importlib.resources import files
|
5 |
|
6 |
import torch
|
|
|
15 |
from f5_tts.model.modules import MelSpec
|
16 |
from f5_tts.model.utils import default
|
17 |
|
18 |
+
|
19 |
def get_speaker_id(path):
|
20 |
+
parts = path.split("/")
|
21 |
speaker_id = parts[-3]
|
22 |
return speaker_id
|
23 |
|
|
|
41 |
return_wavform=False,
|
42 |
remove_starting_space=True,
|
43 |
need_prompt_speech=False,
|
44 |
+
prompt_repository: dict = None,
|
45 |
):
|
46 |
self.data = custom_dataset
|
47 |
self.durations = durations
|
|
|
64 |
mel_spec_type=mel_spec_type,
|
65 |
),
|
66 |
)
|
67 |
+
|
68 |
self.validation = validation
|
69 |
self.validation_num = validation_num
|
70 |
|
71 |
if (not validation) and data_augmentation:
|
72 |
+
print("Using data augmentation.")
|
73 |
+
self.augment = Compose(
|
74 |
+
[
|
75 |
+
AddBackgroundNoise(
|
76 |
+
sounds_path="/data5/ESC-50-master",
|
77 |
+
min_snr_db=3.0,
|
78 |
+
max_snr_db=30.0,
|
79 |
+
noise_transform=PolarityInversion(),
|
80 |
+
p=0.5,
|
81 |
+
),
|
82 |
+
AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
|
83 |
+
PitchShift(min_semitones=-12.0, max_semitones=12.0, p=0.8),
|
84 |
+
ApplyImpulseResponse(ir_path="/data5/Audio", p=1.0),
|
85 |
+
Aliasing(min_sample_rate=4000, max_sample_rate=30000, p=0.3),
|
86 |
+
BandPassFilter(min_center_freq=100.0, max_center_freq=6000, p=0.2),
|
87 |
+
SevenBandParametricEQ(p=0.2),
|
88 |
+
TanhDistortion(min_distortion=0.01, max_distortion=0.7, p=0.2),
|
89 |
+
]
|
90 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
else:
|
92 |
+
print("No data augmentation.")
|
93 |
self.augment = None
|
94 |
|
95 |
self.return_wavform = return_wavform
|
|
|
103 |
text = row["text"]
|
104 |
duration = row["duration"]
|
105 |
spk_id = get_speaker_id(audio_path)
|
106 |
+
assert spk_id != None and spk_id != "mp3"
|
107 |
if spk_id not in self.prompt_repository:
|
108 |
self.prompt_repository[spk_id] = [row]
|
109 |
else:
|
|
|
111 |
else:
|
112 |
self.prompt_repository = prompt_repository
|
113 |
|
114 |
+
print(
|
115 |
+
f"Grouped samples into {len(self.prompt_repository.keys())} speakers."
|
116 |
+
)
|
117 |
self.need_prompt_speech = True
|
118 |
|
119 |
else:
|
120 |
self.need_prompt_speech = False
|
121 |
|
|
|
122 |
def get_frame_len(self, index):
|
123 |
if self.validation:
|
124 |
index += len(self.data) - self.validation_num
|
|
|
156 |
index = (index + 1) % len(self.data)
|
157 |
|
158 |
if self.remove_starting_space:
|
159 |
+
while len(text) > 1 and text[0] == " ":
|
160 |
text = text[1:]
|
161 |
+
|
162 |
if self.preprocessed_mel:
|
163 |
mel_spec = torch.tensor(row["mel_spec"])
|
164 |
else:
|
|
|
170 |
|
171 |
# resample if necessary
|
172 |
if source_sample_rate != self.target_sample_rate:
|
173 |
+
resampler = torchaudio.transforms.Resample(
|
174 |
+
source_sample_rate, self.target_sample_rate
|
175 |
+
)
|
176 |
audio = resampler(audio)
|
177 |
|
178 |
if not self.validation:
|
179 |
if self.augment != None:
|
180 |
+
audio = self.augment(
|
181 |
+
audio.squeeze().numpy(), sample_rate=self.target_sample_rate
|
182 |
+
)
|
183 |
audio = torch.from_numpy(audio).float().unsqueeze(0)
|
184 |
|
185 |
# to mel spectrogram
|
186 |
mel_spec = self.mel_spectrogram(audio)
|
187 |
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
|
188 |
|
189 |
+
out["mel_spec"] = mel_spec
|
190 |
+
out["text"] = text
|
191 |
+
out["duration"] = duration
|
192 |
+
out["target_text"] = self.data[(index + len(self.data) // 2) % len(self.data)][
|
193 |
+
"text"
|
194 |
+
]
|
195 |
|
196 |
if self.return_wavform:
|
197 |
+
out["wav"] = audio
|
198 |
|
199 |
if return_path:
|
200 |
+
out["path"] = audio_path
|
201 |
|
202 |
if return_row:
|
203 |
+
out["row"] = row
|
204 |
|
205 |
# Sample a prompt speech of the same speaker
|
206 |
# From prompt_repository
|
|
|
210 |
_count = 100
|
211 |
while True:
|
212 |
pmt_row = random.choice(spk_repository)
|
213 |
+
pmt_audio_path = pmt_row["audio_path"]
|
214 |
+
pmt_text = pmt_row["text"]
|
215 |
+
pmt_duration = pmt_row["duration"]
|
216 |
|
217 |
if not isinstance(pmt_text, list):
|
218 |
pmt_text = list(pmt_text)
|
|
|
221 |
if 0.3 <= pmt_duration <= 30 and (0 < len(pmt_text) < 2048):
|
222 |
if pmt_text != text:
|
223 |
break
|
224 |
+
_count = _count - 1
|
225 |
if _count <= 0:
|
226 |
break
|
227 |
|
228 |
if self.remove_starting_space:
|
229 |
+
while len(pmt_text) > 1 and pmt_text[0] == " ":
|
230 |
pmt_text = pmt_text[1:]
|
231 |
+
|
232 |
if self.preprocessed_mel:
|
233 |
pmt_mel_spec = torch.tensor(pmt_row["mel_spec"])
|
234 |
else:
|
|
|
240 |
|
241 |
# resample if necessary
|
242 |
if source_sample_rate != self.target_sample_rate:
|
243 |
+
resampler = torchaudio.transforms.Resample(
|
244 |
+
source_sample_rate, self.target_sample_rate
|
245 |
+
)
|
246 |
pmt_audio = resampler(pmt_audio)
|
247 |
|
248 |
if not self.validation:
|
249 |
if self.augment != None:
|
250 |
+
pmt_audio = self.augment(
|
251 |
+
pmt_audio.squeeze().numpy(),
|
252 |
+
sample_rate=self.target_sample_rate,
|
253 |
+
)
|
254 |
pmt_audio = torch.from_numpy(pmt_audio).float().unsqueeze(0)
|
255 |
|
256 |
# to mel spectrogram
|
257 |
pmt_mel_spec = self.mel_spectrogram(pmt_audio)
|
258 |
pmt_mel_spec = pmt_mel_spec.squeeze(0) # '1 d t -> d t'
|
259 |
|
260 |
+
out["pmt_mel_spec"] = pmt_mel_spec
|
261 |
+
out["pmt_text"] = pmt_text
|
262 |
+
out["pmt_duration"] = pmt_duration
|
263 |
|
264 |
if self.return_wavform:
|
265 |
+
out["pmt_wav"] = pmt_audio
|
266 |
|
267 |
if return_path:
|
268 |
+
out["pmt_path"] = pmt_audio_path
|
269 |
|
270 |
if return_row:
|
271 |
+
out["pmt_row"] = pmt_row
|
272 |
|
273 |
return out
|
274 |
|
|
|
283 |
"""
|
284 |
|
285 |
def __init__(
|
286 |
+
self,
|
287 |
+
sampler: Sampler[int],
|
288 |
+
frames_threshold: int,
|
289 |
+
max_samples=0,
|
290 |
+
random_seed=None,
|
291 |
+
drop_last: bool = False,
|
292 |
):
|
293 |
self.sampler = sampler
|
294 |
self.frames_threshold = frames_threshold
|
|
|
310 |
# indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
|
311 |
# ):
|
312 |
for idx, frame_len in indices:
|
313 |
+
if batch_frames + frame_len <= self.frames_threshold and (
|
314 |
+
max_samples == 0 or len(batch) < max_samples
|
315 |
+
):
|
316 |
batch.append(idx)
|
317 |
batch_frames += frame_len
|
318 |
else:
|
|
|
347 |
|
348 |
# Load dataset
|
349 |
|
350 |
+
|
351 |
def load_dataset(
|
352 |
dataset_name: str,
|
353 |
tokenizer: str = "pinyin",
|
|
|
360 |
return_wavform: bool = False,
|
361 |
remove_starting_space: bool = True,
|
362 |
need_prompt_speech: bool = False,
|
363 |
+
prompt_repository: dict = None,
|
364 |
) -> CustomDataset:
|
365 |
"""
|
366 |
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
|
|
|
370 |
print("Loading dataset ...")
|
371 |
|
372 |
if dataset_type == "CustomDataset":
|
373 |
+
rel_data_path = str(
|
374 |
+
f"/home/yl4579/F5-TTS-diff/F5-TTS-DMD-flow-ds/data/{dataset_name}_{tokenizer}"
|
375 |
+
)
|
376 |
+
if "LibriTTS_100_360_500_char_pinyin" in rel_data_path:
|
377 |
+
rel_data_path = rel_data_path.replace(
|
378 |
+
"LibriTTS_100_360_500_char_pinyin", "LibriTTS_100_360_500_char"
|
379 |
+
)
|
380 |
if audio_type == "raw":
|
381 |
try:
|
382 |
train_dataset = load_from_disk(f"{rel_data_path}/raw")
|
|
|
400 |
return_wavform=return_wavform,
|
401 |
remove_starting_space=remove_starting_space,
|
402 |
need_prompt_speech=need_prompt_speech,
|
403 |
+
prompt_repository=prompt_repository,
|
404 |
)
|
405 |
|
406 |
elif dataset_type == "CustomDatasetPath":
|
|
|
413 |
data_dict = json.load(f)
|
414 |
durations = data_dict["duration"]
|
415 |
train_dataset = CustomDataset(
|
416 |
+
train_dataset,
|
417 |
+
durations=durations,
|
418 |
+
preprocessed_mel=preprocessed_mel,
|
419 |
+
**mel_spec_kwargs,
|
420 |
)
|
421 |
|
422 |
return train_dataset
|
|
|
428 |
mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
|
429 |
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
|
430 |
max_mel_length = mel_lengths.amax()
|
431 |
+
|
432 |
# Pad mel_specs
|
433 |
padded_mel_specs = []
|
434 |
for spec in mel_specs: # TODO. maybe records mask for attention here
|
|
|
437 |
padded_mel_specs.append(padded_spec)
|
438 |
mel_specs = torch.stack(padded_mel_specs)
|
439 |
|
440 |
+
text = [item["text"] for item in batch]
|
441 |
+
target_text = [item["target_text"] for item in batch]
|
442 |
|
443 |
text_lengths = torch.LongTensor([len(item) for item in text])
|
444 |
|
|
|
450 |
target_text=target_text,
|
451 |
)
|
452 |
|
453 |
+
if "pmt_mel_spec" in batch[0]:
|
454 |
pmt_mel_specs = [item["pmt_mel_spec"].squeeze(0) for item in batch]
|
455 |
pmt_mel_lengths = torch.LongTensor([spec.shape[-1] for spec in pmt_mel_specs])
|
456 |
max_pmt_mel_length = pmt_mel_lengths.amax()
|
457 |
+
|
458 |
# Pad mel_specs
|
459 |
padded_pmt_mel_specs = []
|
460 |
+
for spec in pmt_mel_specs:
|
461 |
padding = (0, max_pmt_mel_length - spec.size(-1))
|
462 |
padded_spec = F.pad(spec, padding, value=0)
|
463 |
padded_pmt_mel_specs.append(padded_spec)
|
464 |
pmt_mel_specs = torch.stack(padded_pmt_mel_specs)
|
465 |
|
466 |
+
out["pmt_mel_specs"] = pmt_mel_specs
|
467 |
|
468 |
+
if "pmt_text" in batch[0]:
|
469 |
+
pmt_text = [item["pmt_text"] for item in batch]
|
470 |
pmt_text_lengths = torch.LongTensor([len(item) for item in pmt_text])
|
471 |
|
472 |
+
out["pmt_text"] = pmt_text
|
473 |
+
out["pmt_text_lengths"] = pmt_text_lengths
|
474 |
|
475 |
+
return out
|
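collate_fn above pads every mel spectrogram along its time axis to the longest one in the batch and stacks the result, keeping the original lengths so padding can be masked later. The same pattern in isolation, with made-up sizes:

import torch
import torch.nn.functional as F

# variable-length mel spectrograms, each shaped (n_mels, time)
mel_specs = [torch.randn(100, 180), torch.randn(100, 240), torch.randn(100, 150)]

mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
max_mel_length = int(mel_lengths.amax())

padded_mel_specs = []
for spec in mel_specs:
    padding = (0, max_mel_length - spec.size(-1))  # right-pad the time axis only
    padded_mel_specs.append(F.pad(spec, padding, value=0))

mel = torch.stack(padded_mel_specs)  # (batch, n_mels, max_time)
print(mel.shape, mel_lengths)        # lengths let downstream code ignore the padding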
f5_tts/model/modules.py
CHANGED
@@ -19,7 +19,6 @@ from librosa.filters import mel as librosa_mel_fn
|
|
19 |
from torch import nn
|
20 |
from x_transformers.x_transformers import apply_rotary_pos_emb
|
21 |
|
22 |
-
|
23 |
# raw wav to mel spec
|
24 |
|
25 |
|
@@ -42,15 +41,25 @@ def get_bigvgan_mel_spectrogram(
|
|
42 |
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
43 |
|
44 |
if key not in mel_basis_cache:
|
45 |
-
mel = librosa_mel_fn(
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
48 |
|
49 |
mel_basis = mel_basis_cache[key]
|
50 |
hann_window = hann_window_cache[key]
|
51 |
|
52 |
padding = (n_fft - hop_length) // 2
|
53 |
-
waveform = torch.nn.functional.pad(
|
|
|
|
|
54 |
|
55 |
spec = torch.stft(
|
56 |
waveform,
|
@@ -112,7 +121,9 @@ class MelSpec(nn.Module):
|
|
112 |
mel_spec_type="vocos",
|
113 |
):
|
114 |
super().__init__()
|
115 |
-
assert mel_spec_type in ["vocos", "bigvgan"], print(
|
|
|
|
|
116 |
|
117 |
self.n_fft = n_fft
|
118 |
self.hop_length = hop_length
|
@@ -193,7 +204,9 @@ class ConvPositionEmbedding(nn.Module):
|
|
193 |
# rotary positional embedding related
|
194 |
|
195 |
|
196 |
-
def precompute_freqs_cis(
|
|
|
|
|
197 |
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
198 |
# has some connection to NTK literature
|
199 |
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
@@ -209,10 +222,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
|
|
209 |
|
210 |
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
211 |
# length = length if isinstance(length, int) else length.max()
|
212 |
-
scale = scale * torch.ones_like(
|
|
|
|
|
213 |
pos = (
|
214 |
start.unsqueeze(1)
|
215 |
-
+ (
|
|
|
|
|
|
|
216 |
)
|
217 |
# avoid extra long error.
|
218 |
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
@@ -251,7 +269,9 @@ class ConvNeXtV2Block(nn.Module):
|
|
251 |
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
252 |
) # depthwise conv
|
253 |
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
254 |
-
self.pwconv1 = nn.Linear(
|
|
|
|
|
255 |
self.act = nn.GELU()
|
256 |
self.grn = GRN(intermediate_dim)
|
257 |
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
@@ -284,7 +304,9 @@ class AdaLayerNormZero(nn.Module):
|
|
284 |
|
285 |
def forward(self, x, emb=None):
|
286 |
emb = self.linear(self.silu(emb))
|
287 |
-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
|
|
|
|
|
288 |
|
289 |
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
290 |
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
@@ -315,14 +337,18 @@ class AdaLayerNormZero_Final(nn.Module):
|
|
315 |
|
316 |
|
317 |
class FeedForward(nn.Module):
|
318 |
-
def __init__(
|
|
|
|
|
319 |
super().__init__()
|
320 |
inner_dim = int(dim * mult)
|
321 |
dim_out = dim_out if dim_out is not None else dim
|
322 |
|
323 |
activation = nn.GELU(approximate=approximate)
|
324 |
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
325 |
-
self.ff = nn.Sequential(
|
|
|
|
|
326 |
|
327 |
def forward(self, x):
|
328 |
return self.ff(x)
|
@@ -346,7 +372,9 @@ class Attention(nn.Module):
|
|
346 |
super().__init__()
|
347 |
|
348 |
if not hasattr(F, "scaled_dot_product_attention"):
|
349 |
-
raise ImportError(
|
|
|
|
|
350 |
|
351 |
self.processor = processor
|
352 |
|
@@ -385,7 +413,9 @@ class Attention(nn.Module):
|
|
385 |
c_rope=None, # rotary position embedding for c
|
386 |
) -> torch.Tensor:
|
387 |
if c is not None:
|
388 |
-
return self.processor(
|
|
|
|
|
389 |
else:
|
390 |
return self.processor(self, x, mask=mask, rope=rope)
|
391 |
|
@@ -414,7 +444,9 @@ class AttnProcessor:
|
|
414 |
# apply rotary position embedding
|
415 |
if rope is not None:
|
416 |
freqs, xpos_scale = rope
|
417 |
-
q_xpos_scale, k_xpos_scale = (
|
|
|
|
|
418 |
|
419 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
420 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
@@ -430,11 +462,15 @@ class AttnProcessor:
|
|
430 |
if mask is not None:
|
431 |
attn_mask = mask
|
432 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
433 |
-
attn_mask = attn_mask.expand(
|
|
|
|
|
434 |
else:
|
435 |
attn_mask = None
|
436 |
|
437 |
-
x = F.scaled_dot_product_attention(
|
|
|
|
|
438 |
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
439 |
x = x.to(query.dtype)
|
440 |
|
@@ -461,12 +497,12 @@ class JointAttnProcessor:
|
|
461 |
def __call__(
|
462 |
self,
|
463 |
attn: Attention,
|
464 |
-
x: float["b n d"],
|
465 |
-
c: float["b nt d"] = None,
|
466 |
mask: bool["b n"] | None = None,
|
467 |
src_mask: bool["b nt"] | None = None,
|
468 |
-
rope=None,
|
469 |
-
c_rope=None,
|
470 |
) -> torch.FloatTensor:
|
471 |
residual = x
|
472 |
batch_size = c.shape[0]
|
@@ -484,14 +520,18 @@ class JointAttnProcessor:
|
|
484 |
# apply rope for x
|
485 |
if rope is not None:
|
486 |
freqs, xpos_scale = rope
|
487 |
-
q_xpos_scale, k_xpos_scale = (
|
|
|
|
|
488 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
489 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
490 |
|
491 |
# apply rope for c
|
492 |
if c_rope is not None:
|
493 |
freqs, xpos_scale = c_rope
|
494 |
-
q_xpos_scale, k_xpos_scale = (
|
|
|
|
|
495 |
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
496 |
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
497 |
|
@@ -515,17 +555,23 @@ class JointAttnProcessor:
|
|
515 |
attn_mask_c = F.pad(src_mask, (x.shape[1], 0), value=True)
|
516 |
attn_mask = attn_mask & attn_mask_c
|
517 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
|
518 |
-
attn_mask = attn_mask.expand(
|
|
|
|
|
519 |
else:
|
520 |
if src_mask is not None:
|
521 |
# if there's no mask for x but there's src_mask
|
522 |
attn_mask = F.pad(src_mask, (x.shape[1], 0), value=True)
|
523 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
|
524 |
-
attn_mask = attn_mask.expand(
|
|
|
|
|
525 |
else:
|
526 |
attn_mask = None
|
527 |
|
528 |
-
x = F.scaled_dot_product_attention(
|
|
|
|
|
529 |
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
530 |
x = x.to(query.dtype)
|
531 |
|
@@ -546,7 +592,6 @@ class JointAttnProcessor:
|
|
546 |
return x, c
|
547 |
|
548 |
|
549 |
-
|
550 |
# DiT Block
|
551 |
|
552 |
|
@@ -564,7 +609,9 @@ class DiTBlock(nn.Module):
|
|
564 |
)
|
565 |
|
566 |
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
567 |
-
self.ff = FeedForward(
|
|
|
|
|
568 |
|
569 |
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
570 |
# pre-norm & modulation for attention input
|
@@ -596,12 +643,16 @@ class MMDiTBlock(nn.Module):
|
|
596 |
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
597 |
"""
|
598 |
|
599 |
-
def __init__(
|
|
|
|
|
600 |
super().__init__()
|
601 |
|
602 |
self.context_pre_only = context_pre_only
|
603 |
|
604 |
-
self.attn_norm_c =
|
|
|
|
|
605 |
self.attn_norm_x = AdaLayerNormZero(dim)
|
606 |
self.attn = Attention(
|
607 |
processor=JointAttnProcessor(),
|
@@ -615,23 +666,35 @@ class MMDiTBlock(nn.Module):
|
|
615 |
|
616 |
if not context_pre_only:
|
617 |
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
618 |
-
self.ff_c = FeedForward(
|
|
|
|
|
619 |
else:
|
620 |
self.ff_norm_c = None
|
621 |
self.ff_c = None
|
622 |
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
623 |
-
self.ff_x = FeedForward(
|
|
|
|
|
624 |
|
625 |
-
def forward(
|
|
|
|
|
626 |
# pre-norm & modulation for attention input
|
627 |
if self.context_pre_only:
|
628 |
norm_c = self.attn_norm_c(c, t)
|
629 |
else:
|
630 |
-
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
|
631 |
-
|
|
|
|
|
|
|
|
|
632 |
|
633 |
# attention
|
634 |
-
x_attn_output, c_attn_output = self.attn(
|
|
|
|
|
635 |
|
636 |
# process attention output for context c
|
637 |
if self.context_pre_only:
|
@@ -639,7 +702,9 @@ class MMDiTBlock(nn.Module):
|
|
639 |
else: # if not last layer
|
640 |
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
641 |
|
642 |
-
norm_c =
|
|
|
|
|
643 |
c_ff_output = self.ff_c(norm_c)
|
644 |
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
645 |
|
@@ -660,7 +725,9 @@ class TimestepEmbedding(nn.Module):
|
|
660 |
def __init__(self, dim, freq_embed_dim=256):
|
661 |
super().__init__()
|
662 |
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
663 |
-
self.time_mlp = nn.Sequential(
|
|
|
|
|
664 |
|
665 |
def forward(self, timestep: float["b"]): # noqa: F821
|
666 |
time_hidden = self.time_embed(timestep)
|
|
|
19 |
from torch import nn
|
20 |
from x_transformers.x_transformers import apply_rotary_pos_emb
|
21 |
|
|
|
22 |
# raw wav to mel spec
|
23 |
|
24 |
|
|
|
41 |
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
42 |
|
43 |
if key not in mel_basis_cache:
|
44 |
+
mel = librosa_mel_fn(
|
45 |
+
sr=target_sample_rate,
|
46 |
+
n_fft=n_fft,
|
47 |
+
n_mels=n_mel_channels,
|
48 |
+
fmin=fmin,
|
49 |
+
fmax=fmax,
|
50 |
+
)
|
51 |
+
mel_basis_cache[key] = (
|
52 |
+
torch.from_numpy(mel).float().to(device)
|
53 |
+
) # TODO: why they need .float()?
|
54 |
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
55 |
|
56 |
mel_basis = mel_basis_cache[key]
|
57 |
hann_window = hann_window_cache[key]
|
58 |
|
59 |
padding = (n_fft - hop_length) // 2
|
60 |
+
waveform = torch.nn.functional.pad(
|
61 |
+
waveform.unsqueeze(1), (padding, padding), mode="reflect"
|
62 |
+
).squeeze(1)
|
63 |
|
64 |
spec = torch.stft(
|
65 |
waveform,
|
|
|
121 |
mel_spec_type="vocos",
|
122 |
):
|
123 |
super().__init__()
|
124 |
+
assert mel_spec_type in ["vocos", "bigvgan"], print(
|
125 |
+
"We only support two extract mel backend: vocos or bigvgan"
|
126 |
+
)
|
127 |
|
128 |
self.n_fft = n_fft
|
129 |
self.hop_length = hop_length
|
|
|
204 |
# rotary positional embedding related
|
205 |
|
206 |
|
207 |
+
def precompute_freqs_cis(
|
208 |
+
dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
|
209 |
+
):
|
210 |
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
211 |
# has some connection to NTK literature
|
212 |
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
|
|
222 |
|
223 |
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
224 |
# length = length if isinstance(length, int) else length.max()
|
225 |
+
scale = scale * torch.ones_like(
|
226 |
+
start, dtype=torch.float32
|
227 |
+
) # in case scale is a scalar
|
228 |
pos = (
|
229 |
start.unsqueeze(1)
|
230 |
+
+ (
|
231 |
+
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
|
232 |
+
* scale.unsqueeze(1)
|
233 |
+
).long()
|
234 |
)
|
235 |
# avoid extra long error.
|
236 |
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
|
|
269 |
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
270 |
) # depthwise conv
|
271 |
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
272 |
+
self.pwconv1 = nn.Linear(
|
273 |
+
dim, intermediate_dim
|
274 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
275 |
self.act = nn.GELU()
|
276 |
self.grn = GRN(intermediate_dim)
|
277 |
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
|
|
304 |
|
305 |
def forward(self, x, emb=None):
|
306 |
emb = self.linear(self.silu(emb))
|
307 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
|
308 |
+
emb, 6, dim=1
|
309 |
+
)
|
310 |
|
311 |
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
312 |
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
|
|
337 |
|
338 |
|
339 |
class FeedForward(nn.Module):
|
340 |
+
def __init__(
|
341 |
+
self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
|
342 |
+
):
|
343 |
super().__init__()
|
344 |
inner_dim = int(dim * mult)
|
345 |
dim_out = dim_out if dim_out is not None else dim
|
346 |
|
347 |
activation = nn.GELU(approximate=approximate)
|
348 |
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
349 |
+
self.ff = nn.Sequential(
|
350 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
351 |
+
)
|
352 |
|
353 |
def forward(self, x):
|
354 |
return self.ff(x)
|
|
|
372 |
super().__init__()
|
373 |
|
374 |
if not hasattr(F, "scaled_dot_product_attention"):
|
375 |
+
raise ImportError(
|
376 |
+
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
377 |
+
)
|
378 |
|
379 |
self.processor = processor
|
380 |
|
|
|
413 |
c_rope=None, # rotary position embedding for c
|
414 |
) -> torch.Tensor:
|
415 |
if c is not None:
|
416 |
+
return self.processor(
|
417 |
+
self, x, c=c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope
|
418 |
+
)
|
419 |
else:
|
420 |
return self.processor(self, x, mask=mask, rope=rope)
|
421 |
|
|
|
444 |
# apply rotary position embedding
|
445 |
if rope is not None:
|
446 |
freqs, xpos_scale = rope
|
447 |
+
q_xpos_scale, k_xpos_scale = (
|
448 |
+
(xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
449 |
+
)
|
450 |
|
451 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
452 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
|
|
462 |
if mask is not None:
|
463 |
attn_mask = mask
|
464 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
465 |
+
attn_mask = attn_mask.expand(
|
466 |
+
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
467 |
+
)
|
468 |
else:
|
469 |
attn_mask = None
|
470 |
|
471 |
+
x = F.scaled_dot_product_attention(
|
472 |
+
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
473 |
+
)
|
474 |
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
475 |
x = x.to(query.dtype)
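Above, a per-token boolean mask of shape (b, n) is broadcast to the (b, heads, n, n) layout expected by F.scaled_dot_product_attention, where True means "may attend". A toy version of that expansion with made-up sizes (requires PyTorch 2.0+, as the code itself asserts):

import torch
import torch.nn.functional as F

batch, heads, n, head_dim = 2, 4, 6, 8
query = torch.randn(batch, heads, n, head_dim)
key = torch.randn(batch, heads, n, head_dim)
value = torch.randn(batch, heads, n, head_dim)

mask = torch.ones(batch, n, dtype=torch.bool)
mask[1, 4:] = False  # pretend the last two positions of sample 1 are padding

attn_mask = mask.unsqueeze(1).unsqueeze(1)                                  # 'b n' -> 'b 1 1 n'
attn_mask = attn_mask.expand(batch, heads, query.shape[-2], key.shape[-2])  # (b, h, n, n)

x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(x.shape)  # (batch, heads, n, head_dim)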
|
476 |
|
|
|
497 |
def __call__(
|
498 |
self,
|
499 |
attn: Attention,
|
500 |
+
x: float["b n d"], # noised input x
|
501 |
+
c: float["b nt d"] = None, # context c, here text
|
502 |
mask: bool["b n"] | None = None,
|
503 |
src_mask: bool["b nt"] | None = None,
|
504 |
+
rope=None, # rotary position embedding for x
|
505 |
+
c_rope=None, # rotary position embedding for c
|
506 |
) -> torch.FloatTensor:
|
507 |
residual = x
|
508 |
batch_size = c.shape[0]
|
|
|
520 |
# apply rope for x
|
521 |
if rope is not None:
|
522 |
freqs, xpos_scale = rope
|
523 |
+
q_xpos_scale, k_xpos_scale = (
|
524 |
+
(xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
525 |
+
)
|
526 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
527 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
528 |
|
529 |
# apply rope for c
|
530 |
if c_rope is not None:
|
531 |
freqs, xpos_scale = c_rope
|
532 |
+
q_xpos_scale, k_xpos_scale = (
|
533 |
+
(xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
534 |
+
)
|
535 |
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
536 |
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
537 |
|
|
|
555 |
attn_mask_c = F.pad(src_mask, (x.shape[1], 0), value=True)
|
556 |
attn_mask = attn_mask & attn_mask_c
|
557 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
|
558 |
+
attn_mask = attn_mask.expand(
|
559 |
+
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
560 |
+
)
|
561 |
else:
|
562 |
if src_mask is not None:
|
563 |
# if there's no mask for x but there's src_mask
|
564 |
attn_mask = F.pad(src_mask, (x.shape[1], 0), value=True)
|
565 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
|
566 |
+
attn_mask = attn_mask.expand(
|
567 |
+
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
568 |
+
)
|
569 |
else:
|
570 |
attn_mask = None
|
571 |
|
572 |
+
x = F.scaled_dot_product_attention(
|
573 |
+
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
574 |
+
)
|
575 |
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
576 |
x = x.to(query.dtype)
|
577 |
|
|
|
592 |
return x, c
|
593 |
|
594 |
|
|
|
595 |
# DiT Block
|
596 |
|
597 |
|
|
|
609 |
)
|
610 |
|
611 |
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
612 |
+
self.ff = FeedForward(
|
613 |
+
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
614 |
+
)
|
615 |
|
616 |
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
617 |
# pre-norm & modulation for attention input
|
|
|
643 |
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
644 |
"""
|
645 |
|
646 |
+
def __init__(
|
647 |
+
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False
|
648 |
+
):
|
649 |
super().__init__()
|
650 |
|
651 |
self.context_pre_only = context_pre_only
|
652 |
|
653 |
+
self.attn_norm_c = (
|
654 |
+
AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
655 |
+
)
|
656 |
self.attn_norm_x = AdaLayerNormZero(dim)
|
657 |
self.attn = Attention(
|
658 |
processor=JointAttnProcessor(),
|
|
|
666 |
|
667 |
if not context_pre_only:
|
668 |
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
669 |
+
self.ff_c = FeedForward(
|
670 |
+
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
671 |
+
)
|
672 |
else:
|
673 |
self.ff_norm_c = None
|
674 |
self.ff_c = None
|
675 |
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
676 |
+
self.ff_x = FeedForward(
|
677 |
+
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
678 |
+
)
|
679 |
|
680 |
+
def forward(
|
681 |
+
self, x, c, t, mask=None, src_mask=None, rope=None, c_rope=None
|
682 |
+
): # x: noised input, c: context, t: time embedding
|
683 |
# pre-norm & modulation for attention input
|
684 |
if self.context_pre_only:
|
685 |
norm_c = self.attn_norm_c(c, t)
|
686 |
else:
|
687 |
+
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
|
688 |
+
c, emb=t
|
689 |
+
)
|
690 |
+
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
|
691 |
+
x, emb=t
|
692 |
+
)
|
693 |
|
694 |
# attention
|
695 |
+
x_attn_output, c_attn_output = self.attn(
|
696 |
+
x=norm_x, c=norm_c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope
|
697 |
+
)
|
698 |
|
699 |
# process attention output for context c
|
700 |
if self.context_pre_only:
|
|
|
702 |
else: # if not last layer
|
703 |
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
704 |
|
705 |
+
norm_c = (
|
706 |
+
self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
707 |
+
)
|
708 |
c_ff_output = self.ff_c(norm_c)
|
709 |
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
710 |
|
|
|
725 |
def __init__(self, dim, freq_embed_dim=256):
|
726 |
super().__init__()
|
727 |
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
728 |
+
self.time_mlp = nn.Sequential(
|
729 |
+
nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
|
730 |
+
)
|
731 |
|
732 |
def forward(self, timestep: float["b"]): # noqa: F821
|
733 |
time_hidden = self.time_embed(timestep)
|
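TimestepEmbedding above is a sinusoidal frequency embedding followed by a Linear-SiLU-Linear MLP. A self-contained sketch of that pattern, with the sinusoidal part written inline and illustrative dimensions:

import math
import torch
from torch import nn

class ToyTimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.freq_embed_dim = freq_embed_dim
        self.time_mlp = nn.Sequential(
            nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
        )

    def forward(self, timestep):  # timestep: (b,)
        half = self.freq_embed_dim // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half
        ).to(timestep.device)
        args = timestep.float().unsqueeze(1) * freqs.unsqueeze(0)            # (b, half)
        time_hidden = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (b, freq_embed_dim)
        return self.time_mlp(time_hidden)

emb = ToyTimestepEmbedding(dim=512)
print(emb(torch.rand(4)).shape)  # torch.Size([4, 512])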
f5_tts/model/trainer.py
CHANGED
@@ -67,7 +67,13 @@ class Trainer:
|
|
67 |
self.logger = logger
|
68 |
if self.logger == "wandb":
|
69 |
if exists(wandb_resume_id):
|
70 |
-
init_kwargs = {
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
else:
|
72 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
73 |
|
@@ -102,7 +108,9 @@ class Trainer:
|
|
102 |
self.epochs = epochs
|
103 |
self.num_warmup_updates = num_warmup_updates
|
104 |
self.save_per_updates = save_per_updates
|
105 |
-
self.last_per_steps = default(
|
|
|
|
|
106 |
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
107 |
|
108 |
self.batch_size = batch_size
|
@@ -126,8 +134,10 @@ class Trainer:
|
|
126 |
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
127 |
else:
|
128 |
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
129 |
-
self.model, self.optimizer = self.accelerator.prepare(
|
130 |
-
|
|
|
|
|
131 |
self.scale = None
|
132 |
self.count = 0
|
133 |
|
@@ -137,10 +147,12 @@ class Trainer:
|
|
137 |
|
138 |
def save_checkpoint(self, step, last=False):
|
139 |
self.accelerator.wait_for_everyone()
|
140 |
-
if self.is_main:
|
141 |
checkpoint = dict(
|
142 |
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
143 |
-
optimizer_state_dict=self.accelerator.unwrap_model(
|
|
|
|
|
144 |
ema_model_state_dict=self.ema_model.state_dict(),
|
145 |
scheduler_state_dict=self.scheduler.state_dict(),
|
146 |
step=step,
|
@@ -150,16 +162,23 @@ class Trainer:
|
|
150 |
if not os.path.exists(self.checkpoint_path):
|
151 |
os.makedirs(self.checkpoint_path)
|
152 |
if last:
|
153 |
-
self.accelerator.save(
|
|
|
|
|
154 |
print(f"Saved last checkpoint at step {step}")
|
155 |
else:
|
156 |
-
self.accelerator.save(
|
|
|
|
|
157 |
|
158 |
def load_checkpoint(self):
|
159 |
if (
|
160 |
not exists(self.checkpoint_path)
|
161 |
or not os.path.exists(self.checkpoint_path)
|
162 |
-
or not any(
|
|
|
|
|
|
|
163 |
):
|
164 |
return 0
|
165 |
|
@@ -172,10 +191,17 @@ class Trainer:
|
|
172 |
key=lambda x: int("".join(filter(str.isdigit, x))),
|
173 |
)[-1]
|
174 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
175 |
-
checkpoint = torch.load(
|
|
|
|
|
|
|
|
|
176 |
|
177 |
# patch for backward compatibility, 305e3ea
|
178 |
-
for key in [
|
|
|
|
|
|
|
179 |
if key in checkpoint["ema_model_state_dict"]:
|
180 |
del checkpoint["ema_model_state_dict"][key]
|
181 |
|
@@ -184,12 +210,19 @@ class Trainer:
|
|
184 |
|
185 |
if "step" in checkpoint:
|
186 |
# patch for backward compatibility, 305e3ea
|
187 |
-
for key in [
|
|
|
|
|
|
|
188 |
if key in checkpoint["model_state_dict"]:
|
189 |
del checkpoint["model_state_dict"][key]
|
190 |
|
191 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(
|
192 |
-
|
|
|
|
|
|
|
|
|
193 |
if self.scheduler:
|
194 |
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
195 |
step = checkpoint["step"]
|
@@ -199,28 +232,37 @@ class Trainer:
|
|
199 |
for k, v in checkpoint["ema_model_state_dict"].items()
|
200 |
if k not in ["initted", "step"]
|
201 |
}
|
202 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(
|
|
|
|
|
203 |
step = 0
|
204 |
|
205 |
if "scale" in checkpoint:
|
206 |
self.scale = float(checkpoint["scale"])
|
207 |
self.model.scale = self.scale
|
208 |
-
|
209 |
if "count" in checkpoint:
|
210 |
self.count = int(checkpoint["count"])
|
211 |
-
|
212 |
del checkpoint
|
213 |
gc.collect()
|
214 |
return step
|
215 |
|
216 |
-
def train(
|
|
|
|
|
217 |
if self.log_samples:
|
218 |
-
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
|
|
|
219 |
|
220 |
vocoder = load_vocoder(
|
221 |
-
vocoder_name=self.vocoder_name,
|
|
|
|
|
222 |
)
|
223 |
-
target_sample_rate = self.accelerator.unwrap_model(
|
|
|
|
|
224 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
225 |
os.makedirs(log_samples_path, exist_ok=True)
|
226 |
|
@@ -245,7 +287,11 @@ class Trainer:
|
|
245 |
self.accelerator.even_batches = False
|
246 |
sampler = SequentialSampler(train_dataset)
|
247 |
batch_sampler = DynamicBatchSampler(
|
248 |
-
sampler,
|
|
|
|
|
|
|
|
|
249 |
)
|
250 |
train_dataloader = DataLoader(
|
251 |
train_dataset,
|
@@ -256,7 +302,9 @@ class Trainer:
|
|
256 |
batch_sampler=batch_sampler,
|
257 |
)
|
258 |
else:
|
259 |
-
raise ValueError(
|
|
|
|
|
260 |
|
261 |
# accelerator.prepare() dispatches batches to devices;
|
262 |
# which means the length of dataloader calculated before, should consider the number of devices
|
@@ -266,10 +314,16 @@ class Trainer:
|
|
266 |
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
267 |
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
268 |
decay_steps = total_steps - warmup_steps
|
269 |
-
warmup_scheduler = LinearLR(
|
270 |
-
|
|
|
|
|
|
|
|
|
271 |
self.scheduler = SequentialLR(
|
272 |
-
self.optimizer,
|
|
|
|
|
273 |
)
|
274 |
train_dataloader, self.scheduler = self.accelerator.prepare(
|
275 |
train_dataloader, self.scheduler
|
@@ -281,7 +335,9 @@ class Trainer:
|
|
281 |
orig_epoch_step = len(train_dataloader)
|
282 |
skipped_epoch = int(start_step // orig_epoch_step)
|
283 |
skipped_batch = start_step % orig_epoch_step
|
284 |
-
skipped_dataloader = self.accelerator.skip_first_batches(
|
|
|
|
|
285 |
else:
|
286 |
skipped_epoch = 0
|
287 |
|
@@ -309,28 +365,40 @@ class Trainer:
|
|
309 |
text_inputs = batch["text"]
|
310 |
mel_spec = batch["mel"].permute(0, 2, 1)
|
311 |
mel_lengths = batch["mel_lengths"]
|
312 |
-
|
313 |
self.count += 1
|
314 |
-
|
315 |
if self.scale is None:
|
316 |
self.scale = mel_spec.std()
|
317 |
else:
|
318 |
self.scale += (mel_spec.std() - self.scale) / self.count
|
319 |
-
|
320 |
-
mel_spec = mel_spec / self.scale
|
321 |
-
|
322 |
# TODO. add duration predictor training
|
323 |
-
if
|
324 |
-
|
325 |
-
self.accelerator.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
loss, cond, pred, t = self.model(
|
328 |
-
mel_spec,
|
|
|
|
|
|
|
329 |
)
|
330 |
self.accelerator.backward(loss)
|
331 |
|
332 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
333 |
-
self.accelerator.clip_grad_norm_(
|
|
|
|
|
334 |
|
335 |
self.optimizer.step()
|
336 |
self.scheduler.step()
|
@@ -342,18 +410,30 @@ class Trainer:
|
|
342 |
global_step += 1
|
343 |
|
344 |
if self.accelerator.is_local_main_process:
|
345 |
-
self.accelerator.log(
|
|
|
|
|
|
|
346 |
if self.logger == "tensorboard":
|
347 |
self.writer.add_scalar("loss", loss.item(), global_step)
|
348 |
-
self.writer.add_scalar(
|
|
|
|
|
349 |
|
350 |
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
|
351 |
|
352 |
-
if
|
|
|
|
|
|
|
353 |
self.save_checkpoint(global_step)
|
354 |
if self.log_samples and self.accelerator.is_local_main_process:
|
355 |
-
gen_mel_spec =
|
356 |
-
|
|
|
|
|
|
|
|
|
357 |
with torch.inference_mode():
|
358 |
if self.vocoder_name == "vocos":
|
359 |
gen_audio = vocoder.decode(gen_mel_spec).cpu()
|
@@ -361,51 +441,56 @@ class Trainer:
|
|
361 |
elif self.vocoder_name == "bigvgan":
|
362 |
gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
|
363 |
ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
|
364 |
-
|
365 |
gen_audio = wandb.Audio(
|
366 |
gen_audio.float().numpy().squeeze(),
|
367 |
sample_rate=24000,
|
368 |
-
caption="time: "
|
|
|
369 |
)
|
370 |
ref_audio = wandb.Audio(
|
371 |
ref_audio.float().numpy().squeeze(),
|
372 |
sample_rate=24000,
|
373 |
-
caption="time: "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
)
|
375 |
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
#
|
382 |
-
#
|
383 |
-
#
|
384 |
-
#
|
385 |
-
#
|
386 |
-
#
|
387 |
-
# #
|
388 |
-
# #
|
389 |
-
# #
|
390 |
-
# #
|
391 |
-
# #
|
392 |
-
# #
|
393 |
-
#
|
394 |
-
#
|
395 |
-
#
|
396 |
-
#
|
397 |
-
#
|
398 |
-
#
|
399 |
-
#
|
400 |
-
#
|
401 |
-
|
402 |
-
#
|
403 |
-
#
|
404 |
-
# gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
|
405 |
-
# ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
|
406 |
-
|
407 |
-
# torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
|
408 |
-
# torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
|
409 |
|
410 |
if global_step % self.last_per_steps == 0:
|
411 |
self.save_checkpoint(global_step, last=True)
|
|
|
67 |
self.logger = logger
|
68 |
if self.logger == "wandb":
|
69 |
if exists(wandb_resume_id):
|
70 |
+
init_kwargs = {
|
71 |
+
"wandb": {
|
72 |
+
"resume": "allow",
|
73 |
+
"name": wandb_run_name,
|
74 |
+
"id": wandb_resume_id,
|
75 |
+
}
|
76 |
+
}
|
77 |
else:
|
78 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
79 |
|
|
|
108 |
self.epochs = epochs
|
109 |
self.num_warmup_updates = num_warmup_updates
|
110 |
self.save_per_updates = save_per_updates
|
111 |
+
self.last_per_steps = default(
|
112 |
+
last_per_steps, save_per_updates * grad_accumulation_steps
|
113 |
+
)
|
114 |
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
115 |
|
116 |
self.batch_size = batch_size
|
|
|
134 |
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
135 |
else:
|
136 |
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
137 |
+
self.model, self.optimizer = self.accelerator.prepare(
|
138 |
+
self.model, self.optimizer
|
139 |
+
)
|
140 |
+
|
141 |
self.scale = None
|
142 |
self.count = 0
|
143 |
|
|
|
147 |
|
148 |
def save_checkpoint(self, step, last=False):
|
149 |
self.accelerator.wait_for_everyone()
|
150 |
+
if self.is_main:
|
151 |
checkpoint = dict(
|
152 |
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
153 |
+
optimizer_state_dict=self.accelerator.unwrap_model(
|
154 |
+
self.optimizer
|
155 |
+
).state_dict(),
|
156 |
ema_model_state_dict=self.ema_model.state_dict(),
|
157 |
scheduler_state_dict=self.scheduler.state_dict(),
|
158 |
step=step,
|
|
|
162 |
if not os.path.exists(self.checkpoint_path):
|
163 |
os.makedirs(self.checkpoint_path)
|
164 |
if last:
|
165 |
+
self.accelerator.save(
|
166 |
+
checkpoint, f"{self.checkpoint_path}/model_last.pt"
|
167 |
+
)
|
168 |
print(f"Saved last checkpoint at step {step}")
|
169 |
else:
|
170 |
+
self.accelerator.save(
|
171 |
+
checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
|
172 |
+
)
|
173 |
|
174 |
def load_checkpoint(self):
|
175 |
if (
|
176 |
not exists(self.checkpoint_path)
|
177 |
or not os.path.exists(self.checkpoint_path)
|
178 |
+
or not any(
|
179 |
+
filename.endswith(".pt")
|
180 |
+
for filename in os.listdir(self.checkpoint_path)
|
181 |
+
)
|
182 |
):
|
183 |
return 0
|
184 |
|
|
|
191 |
key=lambda x: int("".join(filter(str.isdigit, x))),
|
192 |
)[-1]
|
193 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
194 |
+
checkpoint = torch.load(
|
195 |
+
f"{self.checkpoint_path}/{latest_checkpoint}",
|
196 |
+
weights_only=True,
|
197 |
+
map_location="cpu",
|
198 |
+
)
|
199 |
|
200 |
# patch for backward compatibility, 305e3ea
|
201 |
+
for key in [
|
202 |
+
"ema_model.mel_spec.mel_stft.mel_scale.fb",
|
203 |
+
"ema_model.mel_spec.mel_stft.spectrogram.window",
|
204 |
+
]:
|
205 |
if key in checkpoint["ema_model_state_dict"]:
|
206 |
del checkpoint["ema_model_state_dict"][key]
|
207 |
|
|
|
210 |
|
211 |
if "step" in checkpoint:
|
212 |
# patch for backward compatibility, 305e3ea
|
213 |
+
for key in [
|
214 |
+
"mel_spec.mel_stft.mel_scale.fb",
|
215 |
+
"mel_spec.mel_stft.spectrogram.window",
|
216 |
+
]:
|
217 |
if key in checkpoint["model_state_dict"]:
|
218 |
del checkpoint["model_state_dict"][key]
|
219 |
|
220 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(
|
221 |
+
checkpoint["model_state_dict"]
|
222 |
+
)
|
223 |
+
self.accelerator.unwrap_model(self.optimizer).load_state_dict(
|
224 |
+
checkpoint["optimizer_state_dict"]
|
225 |
+
)
|
226 |
if self.scheduler:
|
227 |
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
228 |
step = checkpoint["step"]
|
|
|
232 |
for k, v in checkpoint["ema_model_state_dict"].items()
|
233 |
if k not in ["initted", "step"]
|
234 |
}
|
235 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(
|
236 |
+
checkpoint["model_state_dict"]
|
237 |
+
)
|
238 |
step = 0
|
239 |
|
240 |
if "scale" in checkpoint:
|
241 |
self.scale = float(checkpoint["scale"])
|
242 |
self.model.scale = self.scale
|
243 |
+
|
244 |
if "count" in checkpoint:
|
245 |
self.count = int(checkpoint["count"])
|
246 |
+
|
247 |
del checkpoint
|
248 |
gc.collect()
|
249 |
return step
|
250 |
|
251 |
+
def train(
|
252 |
+
self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None
|
253 |
+
):
|
254 |
if self.log_samples:
|
255 |
+
from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
|
256 |
+
nfe_step, sway_sampling_coef)
|
257 |
|
258 |
vocoder = load_vocoder(
|
259 |
+
vocoder_name=self.vocoder_name,
|
260 |
+
is_local=self.is_local_vocoder,
|
261 |
+
local_path=self.local_vocoder_path,
|
262 |
)
|
263 |
+
target_sample_rate = self.accelerator.unwrap_model(
|
264 |
+
self.model
|
265 |
+
).mel_spec.target_sample_rate
|
266 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
267 |
os.makedirs(log_samples_path, exist_ok=True)
|
268 |
|
|
|
287 |
self.accelerator.even_batches = False
|
288 |
sampler = SequentialSampler(train_dataset)
|
289 |
batch_sampler = DynamicBatchSampler(
|
290 |
+
sampler,
|
291 |
+
self.batch_size,
|
292 |
+
max_samples=self.max_samples,
|
293 |
+
random_seed=resumable_with_seed,
|
294 |
+
drop_last=False,
|
295 |
)
|
296 |
train_dataloader = DataLoader(
|
297 |
train_dataset,
|
|
|
302 |
batch_sampler=batch_sampler,
|
303 |
)
|
304 |
else:
|
305 |
+
raise ValueError(
|
306 |
+
f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
|
307 |
+
)
|
308 |
|
309 |
# accelerator.prepare() dispatches batches to devices;
|
310 |
# which means the length of dataloader calculated before, should consider the number of devices
|
|
|
314 |
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
315 |
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
316 |
decay_steps = total_steps - warmup_steps
|
317 |
+
warmup_scheduler = LinearLR(
|
318 |
+
self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps
|
319 |
+
)
|
320 |
+
decay_scheduler = LinearLR(
|
321 |
+
self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps
|
322 |
+
)
|
323 |
self.scheduler = SequentialLR(
|
324 |
+
self.optimizer,
|
325 |
+
schedulers=[warmup_scheduler, decay_scheduler],
|
326 |
+
milestones=[warmup_steps],
|
327 |
)
|
328 |
train_dataloader, self.scheduler = self.accelerator.prepare(
|
329 |
train_dataloader, self.scheduler
|
|
|
335 |
orig_epoch_step = len(train_dataloader)
|
336 |
skipped_epoch = int(start_step // orig_epoch_step)
|
337 |
skipped_batch = start_step % orig_epoch_step
|
338 |
+
skipped_dataloader = self.accelerator.skip_first_batches(
|
339 |
+
train_dataloader, num_batches=skipped_batch
|
340 |
+
)
|
341 |
else:
|
342 |
skipped_epoch = 0
|
343 |
|
|
|
365 |
text_inputs = batch["text"]
|
366 |
mel_spec = batch["mel"].permute(0, 2, 1)
|
367 |
mel_lengths = batch["mel_lengths"]
|
368 |
+
|
369 |
self.count += 1
|
370 |
+
|
371 |
if self.scale is None:
|
372 |
self.scale = mel_spec.std()
|
373 |
else:
|
374 |
self.scale += (mel_spec.std() - self.scale) / self.count
|
375 |
+
|
376 |
+
mel_spec = mel_spec / self.scale # normalize mel spectrogram
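The two lines above keep a running mean of the per-batch mel standard deviation and use it to normalize the spectrograms. The same incremental-mean update in isolation, with toy numbers:

values = [2.0, 4.0, 9.0]  # e.g. mel_spec.std() from successive batches

scale, count = None, 0
for v in values:
    count += 1
    if scale is None:
        scale = v
    else:
        scale += (v - scale) / count  # incremental mean: new_mean = old + (x - old) / n

print(scale)  # 5.0, the mean of the observed values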
|
377 |
+
|
378 |
# TODO. add duration predictor training
|
379 |
+
if (
|
380 |
+
self.duration_predictor is not None
|
381 |
+
and self.accelerator.is_local_main_process
|
382 |
+
):
|
383 |
+
dur_loss = self.duration_predictor(
|
384 |
+
mel_spec, lens=batch.get("durations")
|
385 |
+
)
|
386 |
+
self.accelerator.log(
|
387 |
+
{"duration loss": dur_loss.item()}, step=global_step
|
388 |
+
)
|
389 |
|
390 |
loss, cond, pred, t = self.model(
|
391 |
+
mel_spec,
|
392 |
+
text=text_inputs,
|
393 |
+
lens=mel_lengths,
|
394 |
+
noise_scheduler=self.noise_scheduler,
|
395 |
)
|
396 |
self.accelerator.backward(loss)
|
397 |
|
398 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
399 |
+
self.accelerator.clip_grad_norm_(
|
400 |
+
self.model.parameters(), self.max_grad_norm
|
401 |
+
)
|
402 |
|
403 |
self.optimizer.step()
|
404 |
self.scheduler.step()
|
|
|
410 |
global_step += 1
|
411 |
|
412 |
if self.accelerator.is_local_main_process:
|
413 |
+
self.accelerator.log(
|
414 |
+
{"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]},
|
415 |
+
step=global_step,
|
416 |
+
)
|
417 |
if self.logger == "tensorboard":
|
418 |
self.writer.add_scalar("loss", loss.item(), global_step)
|
419 |
+
self.writer.add_scalar(
|
420 |
+
"lr", self.scheduler.get_last_lr()[0], global_step
|
421 |
+
)
|
422 |
|
423 |
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
|
424 |
|
425 |
+
if (
|
426 |
+
global_step % (self.save_per_updates * self.grad_accumulation_steps)
|
427 |
+
== 0
|
428 |
+
):
|
429 |
self.save_checkpoint(global_step)
|
430 |
if self.log_samples and self.accelerator.is_local_main_process:
|
431 |
+
gen_mel_spec = (
|
432 |
+
pred[0].unsqueeze(0).permute(0, 2, 1) * self.scale
|
433 |
+
)
|
434 |
+
ref_mel_spec = (
|
435 |
+
cond[0].unsqueeze(0).permute(0, 2, 1) * self.scale
|
436 |
+
)
|
437 |
with torch.inference_mode():
|
438 |
if self.vocoder_name == "vocos":
|
439 |
gen_audio = vocoder.decode(gen_mel_spec).cpu()
|
|
|
441 |
elif self.vocoder_name == "bigvgan":
|
442 |
gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
|
443 |
ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
|
444 |
+
|
445 |
gen_audio = wandb.Audio(
|
446 |
gen_audio.float().numpy().squeeze(),
|
447 |
sample_rate=24000,
|
448 |
+
caption="time: "
|
449 |
+
+ str(t[0].squeeze().float().cpu().numpy()),
|
450 |
)
|
451 |
ref_audio = wandb.Audio(
|
452 |
ref_audio.float().numpy().squeeze(),
|
453 |
sample_rate=24000,
|
454 |
+
caption="time: "
|
455 |
+
+ str(t[0].squeeze().float().cpu().numpy()),
|
456 |
+
)
|
457 |
+
|
458 |
+
self.accelerator.log(
|
459 |
+
{
|
460 |
+
"gen_audio": gen_audio,
|
461 |
+
"ref_audio": ref_audio,
|
462 |
+
},
|
463 |
+
step=global_step,
|
464 |
)
|
465 |
|
466 |
+
# if self.log_samples and self.accelerator.is_local_main_process:
|
467 |
+
# ref_audio_len = mel_lengths[0]
|
468 |
+
# infer_text = [
|
469 |
+
# text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
|
470 |
+
# ]
|
471 |
+
# with torch.inference_mode():
|
472 |
+
# # generated, _ = self.accelerator.unwrap_model(self.model).sample(
|
473 |
+
# # cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
|
474 |
+
# # text=infer_text,
|
475 |
+
# # duration=ref_audio_len * 2,
|
476 |
+
# # steps=nfe_step,
|
477 |
+
# # cfg_strength=cfg_strength,
|
478 |
+
# # sway_sampling_coef=sway_sampling_coef,
|
479 |
+
# # )
|
480 |
+
# # generated = generated.to(torch.float32)
|
481 |
+
# # gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
|
482 |
+
# # ref_mel_spec = batch["mel"][0].unsqueeze(0)
|
483 |
+
# gen_mel_spec = pred[0].unsqueeze(0).permute(0, 2, 1)
|
484 |
+
# ref_mel_spec = cond[0].unsqueeze(0).permute(0, 2, 1)
|
485 |
+
# if self.vocoder_name == "vocos":
|
486 |
+
# gen_audio = vocoder.decode(gen_mel_spec).cpu()
|
487 |
+
# ref_audio = vocoder.decode(ref_mel_spec).cpu()
|
488 |
+
# elif self.vocoder_name == "bigvgan":
|
489 |
+
# gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
|
490 |
+
# ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
|
491 |
+
|
492 |
+
# torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
|
493 |
+
# torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
|
|
|
|
|
|
|
|
|
|
|
494 |
|
495 |
if global_step % self.last_per_steps == 0:
|
496 |
self.save_checkpoint(global_step, last=True)
|
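Earlier in this trainer.py diff, the learning-rate schedule is reformatted into a linear warmup followed by a linear decay, chained with SequentialLR. A runnable sketch of that schedule with a toy model and hypothetical step counts (the trainer derives its counts from the dataloader and epochs):

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(4, 4)        # stand-in model
optimizer = AdamW(model.parameters(), lr=1e-4)

warmup_steps, decay_steps = 10, 90   # hypothetical values
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(warmup_steps + decay_steps):
    optimizer.step()      # normally preceded by backward()
    scheduler.step()
print(scheduler.get_last_lr())  # close to 1e-4 * 1e-8 after the full decay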
f5_tts/model/utils.py
CHANGED
@@ -5,13 +5,11 @@ import random
from collections import defaultdict
from importlib.resources import files

import jieba
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence

# seed everything

@@ -39,7 +37,9 @@ def default(v, d):
# tensor helpers


def lens_to_mask(
    t: int["b"], length: int | None = None
) -> bool["b n"]:  # noqa: F722 F821
    if not exists(length):
        length = t.amax()

@@ -47,7 +47,9 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa
    return seq[None, :] < t[:, None]


def mask_from_start_end_indices(
    seq_len: int["b"], start: int["b"], end: int["b"]
):  # noqa: F722 F821
    max_seq_len = seq_len.max().item()
    seq = torch.arange(max_seq_len, device=start.device).long()
    start_mask = seq[None, :] >= start[:, None]

@@ -55,7 +57,9 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"
    return start_mask & end_mask


def mask_from_frac_lengths(
    seq_len: int["b"], frac_lengths: float["b"]
):  # noqa: F722 F821
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

@@ -66,7 +70,9 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa
    return mask_from_start_end_indices(seq_len, start, end)


def maybe_masked_mean(
    t: float["b n d"], mask: bool["b n"] = None
) -> float["b d"]:  # noqa: F722
    if not exists(mask):
        return t.mean(dim=1)

@@ -90,7 +96,9 @@ def list_str_to_idx(
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
) -> int["b nt"]:  # noqa: F722
    list_idx_tensors = [
        torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
    ]  # pinyin or char style
    text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
    return text

@@ -109,13 +117,17 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
    - if use "byte", set to 256 (unicode byte range)
    """
    if tokenizer in ["pinyin", "char"]:
        tokenizer_path = os.path.join(
            files("f5_tts").joinpath("../data"), f"{dataset_name}_{tokenizer}/vocab.txt"
        )
        with open(tokenizer_path, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)
        assert (
            vocab_char_map[" "] == 0
        ), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"

    elif tokenizer == "byte":
        vocab_char_map = None

@@ -131,7 +143,6 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
    return vocab_char_map, vocab_size


# convert char to pinyin

jieba.initialize()

@@ -145,9 +156,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return "\u3100" <= c <= "\u9fff"  # common chinese characters

    for text in text_list:
        char_list = []

@@ -158,7 +167,9 @@ def convert_char_to_pinyin(text_list, polyphone=True):
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(
                seg
            ):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for i, c in enumerate(seg):
                    if is_chinese(c):

@@ -170,7 +181,9 @@ def convert_char_to_pinyin(text_list, polyphone=True):
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(
                            lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
                        )
                    else:
                        char_list.append(c)
        final_text_list.append(char_list)

@@ -224,7 +237,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema=True):
def sample_consecutive_steps(float_list):
    idx = torch.randint(0, len(float_list), size=(1,))
    next_idx = idx - 1

    if next_idx < 0:
        next_idx = 0
    else:
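A small usage sketch of two helpers reformatted above (assumes `f5_tts`, `jieba`, and `pypinyin` are installed; the pinyin output in the comment is illustrative, not exact).

import torch
from f5_tts.model.utils import convert_char_to_pinyin, lens_to_mask

tokens = convert_char_to_pinyin(["Hello 世界"])  # e.g. ['H', 'e', 'l', 'l', 'o', ' ', 'shi4', ' ', 'jie4']
mask = lens_to_mask(torch.tensor([2, 4]))        # bool mask of shape (2, 4): [[T, T, F, F], [T, T, T, T]]
print(tokens, mask)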
f5_tts/model_new/__init__.py
CHANGED
@@ -4,5 +4,4 @@ from f5_tts.model_new.backbones.unett import UNetT
from f5_tts.model_new.cfm import CFM
from f5_tts.model_new.trainer import Trainer

__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
f5_tts/model_new/backbones/dit.py
CHANGED
@@ -14,40 +14,49 @@ import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model_new.modules import (AdaLayerNorm_Final, ConvNeXtV2Block,
                                      ConvPositionEmbedding, DiTBlock,
                                      TimestepEmbedding, get_pos_embed_indices,
                                      precompute_freqs_cis)

# Text embedding


class TextEmbedding(nn.Module):
    def __init__(
        self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2
    ):
        super().__init__()
        self.text_embed = nn.Embedding(
            text_num_embeds + 1, text_dim
        )  # use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer(
                "freqs_cis",
                precompute_freqs_cis(text_dim, self.precompute_max_pos),
                persistent=False,
            )
            self.text_blocks = nn.Sequential(
                *[
                    ConvNeXtV2Block(text_dim, text_dim * conv_mult)
                    for _ in range(conv_layers)
                ]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        text = (
            text + 1
        )  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[
            :, :seq_len
        ]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)
        if self.mask_padding:

@@ -62,16 +71,22 @@ class TextEmbedding(nn.Module):
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(
                batch_start, seq_len, max_pos=self.precompute_max_pos
            )
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            if self.mask_padding:
                text = text.masked_fill(
                    text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
                )
                for block in self.text_blocks:
                    text = block(text)
                    text = text.masked_fill(
                        text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
                    )
            else:
                text = self.text_blocks(text)

@@ -87,7 +102,13 @@ class InputEmbedding(nn.Module):
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(
        self,
        x: float["b n d"],
        cond: float["b n d"],
        text_embed: float["b n d"],
        drop_audio_cond=False,
    ):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

@@ -127,7 +148,10 @@ class DiT(nn.Module):
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(
            text_num_embeds,
            text_dim,
            mask_padding=text_mask_padding,
            conv_layers=conv_layers,
        )
        self.text_cond, self.text_uncond = None, None  # text cache
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

@@ -153,7 +177,9 @@ class DiT(nn.Module):
                for _ in range(depth)
            ]
        )
        self.long_skip_connection = (
            nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
        )

        self.norm_out = AdaLayerNorm_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

@@ -230,13 +256,24 @@ class DiT(nn.Module):
        # t: conditioning time, text: text, x: noised audio + cond audio + text
        t = self.time_embed(time)
        if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
            x_cond = self.get_input_embed(
                x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache
            )
            x_uncond = self.get_input_embed(
                x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache
            )
            x = torch.cat((x_cond, x_uncond), dim=0)
            t = torch.cat((t, t), dim=0)
            mask = torch.cat((mask, mask), dim=0) if mask is not None else None
        else:
            x = self.get_input_embed(
                x,
                cond,
                text,
                drop_audio_cond=drop_audio_cond,
                drop_text=drop_text,
                cache=cache,
            )

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

@@ -246,7 +283,9 @@ class DiT(nn.Module):
        for block in self.transformer_blocks:
            if self.checkpoint_activations:
                # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
                x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False
                )
            else:
                x = block(x, t, mask=mask, rope=rope)
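The `cfg_infer` branch above packs the conditional and unconditional passes into one batch. A sketch of how such a packed prediction is typically split and combined downstream (the actual guidance step lives in the CFM sampling code, not in this file):

import torch

def cfg_combine(pred_packed: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    # first half of the batch is conditional, second half unconditional,
    # mirroring torch.cat((x_cond, x_uncond), dim=0) in the hunk above
    pred_cond, pred_uncond = torch.chunk(pred_packed, 2, dim=0)
    return pred_cond + (pred_cond - pred_uncond) * cfg_strength

print(cfg_combine(torch.randn(4, 10, 100), cfg_strength=2.0).shape)  # torch.Size([2, 10, 100])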
f5_tts/model_new/backbones/mmdit.py
CHANGED
@@ -13,15 +13,10 @@ import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model_new.modules import (AdaLayerNorm_Final,
                                      ConvPositionEmbedding, MMDiTBlock,
                                      TimestepEmbedding, get_pos_embed_indices,
                                      precompute_freqs_cis)

# text embedding

@@ -29,15 +24,25 @@ from f5_tts.model_new.modules import (
class TextEmbedding(nn.Module):
    def __init__(self, out_dim, text_num_embeds, mask_padding=True):
        super().__init__()
        self.text_embed = nn.Embedding(
            text_num_embeds + 1, out_dim
        )  # will use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not

        self.precompute_max_pos = 1024
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(out_dim, self.precompute_max_pos),
            persistent=False,
        )

    def forward(
        self, text: int["b nt"], drop_text=False
    ) -> int["b nt d"]:  # noqa: F722
        text = (
            text + 1
        )  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        if self.mask_padding:
            text_mask = text == 0

@@ -49,13 +54,17 @@ class TextEmbedding(nn.Module):
        # sinus pos emb
        batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
        batch_text_len = text.shape[1]
        pos_idx = get_pos_embed_indices(
            batch_start, batch_text_len, max_pos=self.precompute_max_pos
        )
        text_pos_embed = self.freqs_cis[pos_idx]

        text = text + text_pos_embed

        if self.mask_padding:
            text = text.masked_fill(
                text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
            )

        return text

@@ -69,7 +78,9 @@ class AudioEmbedding(nn.Module):
        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(
        self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False
    ):  # noqa: F722
        if drop_audio_cond:
            cond = torch.zeros_like(cond)
        x = torch.cat((x, cond), dim=-1)

@@ -99,7 +110,9 @@ class MMDiT(nn.Module):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(
            dim, text_num_embeds, mask_padding=text_mask_padding
        )
        self.text_cond, self.text_uncond = None, None  # text cache
        self.audio_embed = AudioEmbedding(mel_dim, dim)

@@ -187,15 +200,24 @@ class MMDiT(nn.Module):
        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
            x_cond, c_cond = self.get_input_embed(
                x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache
            )
            x_uncond, c_uncond = self.get_input_embed(
                x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache
            )
            x = torch.cat((x_cond, x_uncond), dim=0)
            c = torch.cat((c_cond, c_uncond), dim=0)
            t = torch.cat((t, t), dim=0)
            mask = torch.cat((mask, mask), dim=0) if mask is not None else None
        else:
            x, c = self.get_input_embed(
                x,
                cond,
                text,
                drop_audio_cond=drop_audio_cond,
                drop_text=drop_text,
                cache=cache,
            )

        seq_len = x.shape[1]
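A toy illustration of the padding convention `TextEmbedding` above relies on: batches are padded with -1 (see `list_str_to_idx`) and then shifted by +1 so that index 0 becomes the reserved filler token. The three-entry vocab is hypothetical.

import torch
from torch.nn.utils.rnn import pad_sequence

vocab = {" ": 0, "a": 1, "b": 2}                               # hypothetical toy vocab
seqs = [torch.tensor([vocab[c] for c in s]) for s in ["ab", "a"]]
text = pad_sequence(seqs, padding_value=-1, batch_first=True)  # pad with -1
print(text + 1)                                                # tensor([[2, 3], [2, 0]]) -> 0 is the filler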
f5_tts/model_new/backbones/unett.py
CHANGED
@@ -17,41 +17,50 @@ from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model_new.modules import (Attention, AttnProcessor,
                                      ConvNeXtV2Block, ConvPositionEmbedding,
                                      FeedForward, TimestepEmbedding,
                                      get_pos_embed_indices,
                                      precompute_freqs_cis)

# Text embedding


class TextEmbedding(nn.Module):
    def __init__(
        self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2
    ):
        super().__init__()
        self.text_embed = nn.Embedding(
            text_num_embeds + 1, text_dim
        )  # use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer(
                "freqs_cis",
                precompute_freqs_cis(text_dim, self.precompute_max_pos),
                persistent=False,
            )
            self.text_blocks = nn.Sequential(
                *[
                    ConvNeXtV2Block(text_dim, text_dim * conv_mult)
                    for _ in range(conv_layers)
                ]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        text = (
            text + 1
        )  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[
            :, :seq_len
        ]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)
        if self.mask_padding:

@@ -66,16 +75,22 @@ class TextEmbedding(nn.Module):
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(
                batch_start, seq_len, max_pos=self.precompute_max_pos
            )
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            if self.mask_padding:
                text = text.masked_fill(
                    text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
                )
                for block in self.text_blocks:
                    text = block(text)
                    text = text.masked_fill(
                        text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
                    )
            else:
                text = self.text_blocks(text)

@@ -91,7 +106,13 @@ class InputEmbedding(nn.Module):
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(
        self,
        x: float["b n d"],
        cond: float["b n d"],
        text_embed: float["b n d"],
        drop_audio_cond=False,
    ):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

@@ -129,7 +150,10 @@ class UNetT(nn.Module):
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(
            text_num_embeds,
            text_dim,
            mask_padding=text_mask_padding,
            conv_layers=conv_layers,
        )
        self.text_cond, self.text_uncond = None, None  # text cache
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

@@ -161,7 +185,11 @@ class UNetT(nn.Module):
            ff_norm = RMSNorm(dim)
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

            skip_proj = (
                nn.Linear(dim * 2, dim, bias=False)
                if needs_skip_proj and is_later_half
                else None
            )

            self.layers.append(
                nn.ModuleList(

@@ -226,13 +254,24 @@ class UNetT(nn.Module):
        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
            x_cond = self.get_input_embed(
                x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache
            )
            x_uncond = self.get_input_embed(
                x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache
            )
            x = torch.cat((x_cond, x_uncond), dim=0)
            t = torch.cat((t, t), dim=0)
            mask = torch.cat((mask, mask), dim=0) if mask is not None else None
        else:
            x = self.get_input_embed(
                x,
                cond,
                text,
                drop_audio_cond=drop_audio_cond,
                drop_text=drop_text,
                cache=cache,
            )

        # postfix time t to input x, [b n d] -> [b n+1 d]
        x = torch.cat([t.unsqueeze(1), x], dim=1)  # pack t to x

@@ -244,7 +283,9 @@ class UNetT(nn.Module):
        # flat unet transformer
        skip_connect_type = self.skip_connect_type
        skips = []
        for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(
            self.layers
        ):
            layer = idx + 1

            # skip connection logic
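A minimal sketch of the "concat" skip connection the layer list above prepares: a later-half layer concatenates the activation saved by its mirrored early-half layer and projects the result back to `dim`.

import torch
from torch import nn

dim = 8
skip_proj = nn.Linear(dim * 2, dim, bias=False)  # same shape as skip_proj above
x = torch.randn(2, 16, dim)                      # current hidden states
skip = torch.randn(2, 16, dim)                   # activation popped from the skips stack
x = skip_proj(torch.cat((x, skip), dim=-1))
print(x.shape)  # torch.Size([2, 16, 8])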
f5_tts/model_new/cfm.py
CHANGED
@@ -19,15 +19,9 @@ from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint

from f5_tts.model_new.modules import MelSpec
from f5_tts.model_new.utils import (default, exists, get_epss_timesteps,
                                    lens_to_mask, list_str_to_idx,
                                    list_str_to_tensor, mask_from_frac_lengths)


class CFM(nn.Module):

@@ -139,13 +133,17 @@ class CFM(nn.Module):
        # duplicate test corner for inner time step oberservation
        if duplicate_test:
            test_cond = F.pad(
                cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
            )

        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
        if no_ref_audio:
            cond = torch.zeros_like(cond)

        cond_mask = F.pad(
            cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
        )
        cond_mask = cond_mask.unsqueeze(-1)
        step_cond = torch.where(
            cond_mask, cond, torch.zeros_like(cond)

@@ -196,7 +194,11 @@ class CFM(nn.Module):
        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)
            y0.append(
                torch.randn(
                    dur, self.num_channels, device=self.device, dtype=step_cond.dtype
                )
            )
        y0 = pad_sequence(y0, padding_value=0, batch_first=True)

        t_start = 0

@@ -207,10 +209,14 @@ class CFM(nn.Module):
            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        if (
            t_start == 0 and use_epss
        ):  # use Empirically Pruned Step Sampling for low NFE
            t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
        else:
            t = torch.linspace(
                t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
            )
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

@@ -241,7 +247,12 @@ class CFM(nn.Module):
        inp = inp.permute(0, 2, 1)
        assert inp.shape[-1] == self.num_channels

        batch, seq_len, dtype, device, _σ1 = (
            *inp.shape[:2],
            inp.dtype,
            self.device,
            self.sigma,
        )

        # handle text as string
        if isinstance(text, list):

@@ -255,10 +266,16 @@ class CFM(nn.Module):
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device=device)

        mask = lens_to_mask(
            lens, length=seq_len
        )  # useless here, as collate_fn will pad to max length in batch

        # get a random span to mask out for training conditionally
        frac_lengths = (
            torch.zeros((batch,), device=self.device)
            .float()
            .uniform_(*self.frac_lengths_mask)
        )
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        if exists(mask):

@@ -292,7 +309,13 @@ class CFM(nn.Module):
        # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
        pred = self.transformer(
            x=φ,
            cond=cond,
            text=text,
            time=time,
            drop_audio_cond=drop_audio_cond,
            drop_text=drop_text,
            mask=mask,
        )

        # flow matching loss
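The sway-sampling line above bends the timestep grid; computed in isolation, a negative coefficient concentrates steps near t = 0 while keeping both endpoints fixed.

import torch

steps, sway_sampling_coef = 8, -1.0
t = torch.linspace(0, 1, steps + 1)
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
print(t)  # runs from 0.0 to 1.0, with values pulled toward 0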
f5_tts/model_new/dataset.py
CHANGED
@@ -62,7 +62,9 @@ class HFDataset(Dataset):
        audio_tensor = torch.from_numpy(audio).float()

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                sample_rate, self.target_sample_rate
            )
            audio_tensor = resampler(audio_tensor)

        audio_tensor = audio_tensor.unsqueeze(0)  # 't -> 1 t')

@@ -149,7 +151,9 @@ class CustomDataset(Dataset):
        # resample if necessary
        if source_sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                source_sample_rate, self.target_sample_rate
            )
            audio = resampler(audio)

        # to mel spectrogram

@@ -173,7 +177,12 @@ class DynamicBatchSampler(Sampler[list[int]]):
    """

    def __init__(
        self,
        sampler: Sampler[int],
        frames_threshold: int,
        max_samples=0,
        random_seed=None,
        drop_residual: bool = False,
    ):
        self.sampler = sampler
        self.frames_threshold = frames_threshold

@@ -185,7 +194,8 @@ class DynamicBatchSampler(Sampler[list[int]]):
        data_source = self.sampler.data_source

        for idx in tqdm(
            self.sampler,
            desc="Sorting with sampler... if slow, check whether dataset is provided with duration",
        ):
            indices.append((idx, data_source.get_frame_len(idx)))
        indices.sort(key=lambda elem: elem[1])

@@ -193,9 +203,12 @@ class DynamicBatchSampler(Sampler[list[int]]):
        batch = []
        batch_frames = 0
        for idx, frame_len in tqdm(
            indices,
            desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu",
        ):
            if batch_frames + frame_len <= self.frames_threshold and (
                max_samples == 0 or len(batch) < max_samples
            ):
                batch.append(idx)
                batch_frames += frame_len
            else:

@@ -256,7 +269,9 @@ def load_dataset(
    print("Loading dataset ...")

    if dataset_type == "CustomDataset":
        rel_data_path = str(
            files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}")
        )
        if audio_type == "raw":
            try:
                train_dataset = load_from_disk(f"{rel_data_path}/raw")

@@ -287,7 +302,10 @@ def load_dataset(
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(
            train_dataset,
            durations=durations,
            preprocessed_mel=preprocessed_mel,
            **mel_spec_kwargs,
        )

    elif dataset_type == "HFDataset":

@@ -297,7 +315,11 @@ def load_dataset(
        )
        pre, post = dataset_name.split("_")
        train_dataset = HFDataset(
            load_dataset(
                f"{pre}/{pre}",
                split=f"train.{post}",
                cache_dir=str(files("f5_tts").joinpath("../../data")),
            ),
        )

    return train_dataset
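A usage sketch for the `DynamicBatchSampler` signature shown above; the toy dataset and its `get_frame_len` are stand-ins for a real `CustomDataset`, and the exact iteration behaviour depends on parts of the class not visible in this hunk.

import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler

from f5_tts.model_new.dataset import DynamicBatchSampler  # import path is an assumption


class ToyFrameDataset(Dataset):
    # stand-in: only __len__, __getitem__, and get_frame_len are needed here
    def __init__(self, frame_lens):
        self.frame_lens = frame_lens

    def __len__(self):
        return len(self.frame_lens)

    def __getitem__(self, idx):
        return torch.zeros(self.frame_lens[idx])

    def get_frame_len(self, idx):
        return self.frame_lens[idx]


ds = ToyFrameDataset([120, 300, 450, 200, 90, 610])
batch_sampler = DynamicBatchSampler(
    SequentialSampler(ds), frames_threshold=600, max_samples=4, random_seed=666
)
for batch in DataLoader(ds, batch_sampler=batch_sampler, collate_fn=list):
    print(len(batch))  # variable batch sizes, each capped near 600 total frames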
f5_tts/model_new/modules.py
CHANGED
@@ -6,6 +6,7 @@ nt - text sequence
|
|
6 |
nw - raw wave length
|
7 |
d - dimension
|
8 |
"""
|
|
|
9 |
# flake8: noqa
|
10 |
|
11 |
from __future__ import annotations
|
@@ -22,7 +23,6 @@ from x_transformers.x_transformers import apply_rotary_pos_emb
|
|
22 |
|
23 |
from f5_tts.model_new.utils import is_package_available
|
24 |
|
25 |
-
|
26 |
# raw wav to mel spec
|
27 |
|
28 |
|
@@ -45,15 +45,25 @@ def get_bigvgan_mel_spectrogram(
|
|
45 |
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
46 |
|
47 |
if key not in mel_basis_cache:
|
48 |
-
mel = librosa_mel_fn(
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
51 |
|
52 |
mel_basis = mel_basis_cache[key]
|
53 |
hann_window = hann_window_cache[key]
|
54 |
|
55 |
padding = (n_fft - hop_length) // 2
|
56 |
-
waveform = torch.nn.functional.pad(
|
|
|
|
|
57 |
|
58 |
spec = torch.stft(
|
59 |
waveform,
|
@@ -115,7 +125,9 @@ class MelSpec(nn.Module):
|
|
115 |
mel_spec_type="vocos",
|
116 |
):
|
117 |
super().__init__()
|
118 |
-
assert mel_spec_type in ["vocos", "bigvgan"], print(
|
|
|
|
|
119 |
|
120 |
self.n_fft = n_fft
|
121 |
self.hop_length = hop_length
|
@@ -196,7 +208,9 @@ class ConvPositionEmbedding(nn.Module):
|
|
196 |
# rotary positional embedding related
|
197 |
|
198 |
|
199 |
-
def precompute_freqs_cis(
|
|
|
|
|
200 |
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
201 |
# has some connection to NTK literature
|
202 |
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
@@ -212,10 +226,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
|
|
212 |
|
213 |
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
214 |
# length = length if isinstance(length, int) else length.max()
|
215 |
-
scale = scale * torch.ones_like(
|
|
|
|
|
216 |
pos = (
|
217 |
start.unsqueeze(1)
|
218 |
-
+ (
|
|
|
|
|
|
|
219 |
)
|
220 |
# avoid extra long error.
|
221 |
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
@@ -254,7 +273,9 @@ class ConvNeXtV2Block(nn.Module):
|
|
254 |
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
255 |
) # depthwise conv
|
256 |
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
257 |
-
self.pwconv1 = nn.Linear(
|
|
|
|
|
258 |
self.act = nn.GELU()
|
259 |
self.grn = GRN(intermediate_dim)
|
260 |
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
@@ -286,7 +307,9 @@ class RMSNorm(nn.Module):
|
|
286 |
if self.native_rms_norm:
|
287 |
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
288 |
x = x.to(self.weight.dtype)
|
289 |
-
x = F.rms_norm(
|
|
|
|
|
290 |
else:
|
291 |
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
292 |
x = x * torch.rsqrt(variance + self.eps)
|
@@ -312,7 +335,9 @@ class AdaLayerNorm(nn.Module):
|
|
312 |
|
313 |
def forward(self, x, emb=None):
|
314 |
emb = self.linear(self.silu(emb))
|
315 |
-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
|
|
|
|
|
316 |
|
317 |
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
318 |
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
@@ -343,14 +368,18 @@ class AdaLayerNorm_Final(nn.Module):
|
|
343 |
|
344 |
|
345 |
class FeedForward(nn.Module):
|
346 |
-
def __init__(
|
|
|
|
|
347 |
super().__init__()
|
348 |
inner_dim = int(dim * mult)
|
349 |
dim_out = dim_out if dim_out is not None else dim
|
350 |
|
351 |
activation = nn.GELU(approximate=approximate)
|
352 |
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
353 |
-
self.ff = nn.Sequential(
|
|
|
|
|
354 |
|
355 |
def forward(self, x):
|
356 |
return self.ff(x)
|
@@ -375,7 +404,9 @@ class Attention(nn.Module):
|
|
375 |
super().__init__()
|
376 |
|
377 |
if not hasattr(F, "scaled_dot_product_attention"):
|
378 |
-
raise ImportError(
|
|
|
|
|
379 |
|
380 |
self.processor = processor
|
381 |
|
@@ -435,19 +466,23 @@ class Attention(nn.Module):
|
|
435 |
# Attention processor
|
436 |
|
437 |
if is_package_available("flash_attn"):
|
|
|
438 |
from flash_attn.bert_padding import pad_input, unpad_input
|
439 |
-
from flash_attn import flash_attn_varlen_func, flash_attn_func
|
440 |
|
441 |
|
442 |
class AttnProcessor:
|
443 |
def __init__(
|
444 |
self,
|
445 |
-
pe_attn_head:
|
|
|
|
|
446 |
attn_backend: str = "torch", # "torch" or "flash_attn"
|
447 |
attn_mask_enabled: bool = True,
|
448 |
):
|
449 |
if attn_backend == "flash_attn":
|
450 |
-
assert is_package_available(
|
|
|
|
|
451 |
|
452 |
self.pe_attn_head = pe_attn_head
|
453 |
self.attn_backend = attn_backend
|
@@ -483,12 +518,18 @@ class AttnProcessor:
|
|
483 |
# apply rotary position embedding
|
484 |
if rope is not None:
|
485 |
freqs, xpos_scale = rope
|
486 |
-
q_xpos_scale, k_xpos_scale = (
|
|
|
|
|
487 |
|
488 |
if self.pe_attn_head is not None:
|
489 |
pn = self.pe_attn_head
|
490 |
-
query[:, :pn, :, :] = apply_rotary_pos_emb(
|
491 |
-
|
|
|
|
|
|
|
|
|
492 |
else:
|
493 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
494 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
@@ -498,10 +539,14 @@ class AttnProcessor:
|
|
498 |
if self.attn_mask_enabled and mask is not None:
|
499 |
attn_mask = mask
|
500 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
501 |
-
attn_mask = attn_mask.expand(
|
|
|
|
|
502 |
else:
|
503 |
attn_mask = None
|
504 |
-
x = F.scaled_dot_product_attention(
|
|
|
|
|
505 |
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
506 |
|
507 |
elif self.attn_backend == "flash_attn":
|
@@ -509,7 +554,9 @@ class AttnProcessor:
|
|
509 |
key = key.transpose(1, 2)
|
510 |
value = value.transpose(1, 2)
|
511 |
if self.attn_mask_enabled and mask is not None:
|
512 |
-
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(
|
|
|
|
|
513 |
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
|
514 |
value, _, _, _, _ = unpad_input(value, mask)
|
515 |
x = flash_attn_varlen_func(
|
@@ -595,12 +642,16 @@ class JointAttnProcessor:
|
|
595 |
# apply rope for context and noised input independently
|
596 |
if rope is not None:
|
597 |
freqs, xpos_scale = rope
|
598 |
-
q_xpos_scale, k_xpos_scale = (
|
|
|
|
|
599 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
600 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
601 |
if c_rope is not None:
|
602 |
freqs, xpos_scale = c_rope
|
603 |
-
q_xpos_scale, k_xpos_scale = (
|
|
|
|
|
604 |
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
605 |
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
606 |
|
@@ -613,11 +664,15 @@ class JointAttnProcessor:
|
|
613 |
if mask is not None:
|
614 |
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
615 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
616 |
-
attn_mask = attn_mask.expand(
|
|
|
|
|
617 |
else:
|
618 |
attn_mask = None
|
619 |
|
620 |
-
x = F.scaled_dot_product_attention(
|
|
|
|
|
621 |
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
622 |
x = x.to(query.dtype)
|
623 |
|
@@ -675,7 +730,9 @@ class DiTBlock(nn.Module):
|
|
675 |
)
|
676 |
|
677 |
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
678 |
-
self.ff = FeedForward(
|
|
|
|
|
679 |
|
680 |
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
681 |
# pre-norm & modulation for attention input
|
@@ -708,14 +765,26 @@ class MMDiTBlock(nn.Module):
|
|
708 |
"""
|
709 |
|
710 |
def __init__(
|
711 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
712 |
):
|
713 |
super().__init__()
|
714 |
if context_dim is None:
|
715 |
context_dim = dim
|
716 |
self.context_pre_only = context_pre_only
|
717 |
|
718 |
-
self.attn_norm_c =
|
|
|
|
|
|
|
|
|
719 |
self.attn_norm_x = AdaLayerNorm(dim)
|
720 |
self.attn = Attention(
|
721 |
processor=JointAttnProcessor(),
|
@@ -729,24 +798,38 @@ class MMDiTBlock(nn.Module):
|
|
729 |
)
|
730 |
|
731 |
if not context_pre_only:
|
732 |
-
self.ff_norm_c = nn.LayerNorm(
|
733 |
-
|
|
|
|
|
|
|
|
|
734 |
else:
|
735 |
self.ff_norm_c = None
|
736 |
self.ff_c = None
|
737 |
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
738 |
-
self.ff_x = FeedForward(
|
|
|
|
|
739 |
|
740 |
-
def forward(
|
|
|
|
|
741 |
# pre-norm & modulation for attention input
|
742 |
if self.context_pre_only:
|
743 |
norm_c = self.attn_norm_c(c, t)
|
744 |
else:
|
745 |
-
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
|
746 |
-
|
|
|
|
|
|
|
|
|
747 |
|
748 |
# attention
|
749 |
-
x_attn_output, c_attn_output = self.attn(
|
|
|
|
|
750 |
|
751 |
# process attention output for context c
|
752 |
if self.context_pre_only:
|
@@ -754,7 +837,9 @@ class MMDiTBlock(nn.Module):
|
|
754 |
else: # if not last layer
|
755 |
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
756 |
|
757 |
-
norm_c =
|
|
|
|
|
758 |
c_ff_output = self.ff_c(norm_c)
|
759 |
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
760 |
|
@@ -775,7 +860,9 @@ class TimestepEmbedding(nn.Module):
|
|
775 |
def __init__(self, dim, freq_embed_dim=256):
|
776 |
super().__init__()
|
777 |
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
778 |
-
self.time_mlp = nn.Sequential(
|
|
|
|
|
779 |
|
780 |
def forward(self, timestep: float["b"]):
|
781 |
time_hidden = self.time_embed(timestep)
|
|
|
6 |
nw - raw wave length
|
7 |
d - dimension
|
8 |
"""
|
9 |
+
|
10 |
# flake8: noqa
|
11 |
|
12 |
from __future__ import annotations
|
|
|
23 |
|
24 |
from f5_tts.model_new.utils import is_package_available
|
25 |
|
|
|
26 |
# raw wav to mel spec
|
27 |
|
28 |
|
|
|
45 |
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
46 |
|
47 |
if key not in mel_basis_cache:
|
48 |
+
mel = librosa_mel_fn(
|
49 |
+
sr=target_sample_rate,
|
50 |
+
n_fft=n_fft,
|
51 |
+
n_mels=n_mel_channels,
|
52 |
+
fmin=fmin,
|
53 |
+
fmax=fmax,
|
54 |
+
)
|
55 |
+
mel_basis_cache[key] = (
|
56 |
+
torch.from_numpy(mel).float().to(device)
|
57 |
+
) # TODO: why they need .float()?
|
58 |
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
59 |
|
60 |
mel_basis = mel_basis_cache[key]
|
61 |
hann_window = hann_window_cache[key]
|
62 |
|
63 |
padding = (n_fft - hop_length) // 2
|
64 |
+
waveform = torch.nn.functional.pad(
|
65 |
+
waveform.unsqueeze(1), (padding, padding), mode="reflect"
|
66 |
+
).squeeze(1)
|
67 |
|
68 |
spec = torch.stft(
|
69 |
waveform,
|
|
|
125 |
mel_spec_type="vocos",
|
126 |
):
|
127 |
super().__init__()
|
128 |
+
assert mel_spec_type in ["vocos", "bigvgan"], print(
|
129 |
+
"We only support two extract mel backend: vocos or bigvgan"
|
130 |
+
)
|
131 |
|
132 |
self.n_fft = n_fft
|
133 |
self.hop_length = hop_length
|
|
|
208 |
# rotary positional embedding related
|
209 |
|
210 |
|
211 |
+
def precompute_freqs_cis(
|
212 |
+
dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
|
213 |
+
):
|
214 |
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
215 |
# has some connection to NTK literature
|
216 |
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
|
|
226 |
|
227 |
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
228 |
# length = length if isinstance(length, int) else length.max()
|
229 |
+
scale = scale * torch.ones_like(
|
230 |
+
start, dtype=torch.float32
|
231 |
+
) # in case scale is a scalar
|
232 |
pos = (
|
233 |
start.unsqueeze(1)
|
234 |
+
+ (
|
235 |
+
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
|
236 |
+
* scale.unsqueeze(1)
|
237 |
+
).long()
|
238 |
)
|
239 |
# avoid extra long error.
|
240 |
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
|
|
273 |
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
274 |
) # depthwise conv
|
275 |
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
276 |
+
self.pwconv1 = nn.Linear(
|
277 |
+
dim, intermediate_dim
|
278 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
279 |
self.act = nn.GELU()
|
280 |
self.grn = GRN(intermediate_dim)
|
281 |
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
|
|
307 |
if self.native_rms_norm:
|
308 |
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
309 |
x = x.to(self.weight.dtype)
|
310 |
+
x = F.rms_norm(
|
311 |
+
x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps
|
312 |
+
)
|
313 |
else:
|
314 |
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
315 |
x = x * torch.rsqrt(variance + self.eps)
|
|
|
335 |
|
336 |
def forward(self, x, emb=None):
|
337 |
emb = self.linear(self.silu(emb))
|
338 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
|
339 |
+
emb, 6, dim=1
|
340 |
+
)
|
341 |
|
342 |
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
343 |
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
|
|
368 |
|
369 |
|
370 |
class FeedForward(nn.Module):
|
371 |
+
def __init__(
|
372 |
+
self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
|
373 |
+
):
|
374 |
super().__init__()
|
375 |
inner_dim = int(dim * mult)
|
376 |
dim_out = dim_out if dim_out is not None else dim
|
377 |
|
378 |
activation = nn.GELU(approximate=approximate)
|
379 |
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
380 |
+
self.ff = nn.Sequential(
|
381 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
382 |
+
)
|
383 |
|
384 |
def forward(self, x):
|
385 |
return self.ff(x)
|
|
|
404 |
super().__init__()
|
405 |
|
406 |
if not hasattr(F, "scaled_dot_product_attention"):
|
407 |
+
raise ImportError(
|
408 |
+
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
409 |
+
)
|
410 |
|
411 |
self.processor = processor
|
412 |
|
|
|
466 |
# Attention processor
|
467 |
|
468 |
if is_package_available("flash_attn"):
|
469 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
470 |
from flash_attn.bert_padding import pad_input, unpad_input
|
|
|
471 |
|
472 |
|
473 |
class AttnProcessor:
|
474 |
def __init__(
|
475 |
self,
|
476 |
+
pe_attn_head: (
|
477 |
+
int | None
|
478 |
+
) = None, # number of attention head to apply rope, None for all
|
479 |
attn_backend: str = "torch", # "torch" or "flash_attn"
|
480 |
attn_mask_enabled: bool = True,
|
481 |
):
|
482 |
if attn_backend == "flash_attn":
|
483 |
+
assert is_package_available(
|
484 |
+
"flash_attn"
|
485 |
+
), "Please install flash-attn first."
|
486 |
|
487 |
self.pe_attn_head = pe_attn_head
|
488 |
self.attn_backend = attn_backend
|
|
|
518 |
# apply rotary position embedding
|
519 |
if rope is not None:
|
520 |
freqs, xpos_scale = rope
|
521 |
+
q_xpos_scale, k_xpos_scale = (
|
522 |
+
(xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
523 |
+
)
|
524 |
|
525 |
if self.pe_attn_head is not None:
|
526 |
pn = self.pe_attn_head
|
527 |
+
query[:, :pn, :, :] = apply_rotary_pos_emb(
|
528 |
+
query[:, :pn, :, :], freqs, q_xpos_scale
|
529 |
+
)
|
530 |
+
key[:, :pn, :, :] = apply_rotary_pos_emb(
|
531 |
+
key[:, :pn, :, :], freqs, k_xpos_scale
|
532 |
+
)
|
533 |
else:
|
534 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
535 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
|
|
539 |
if self.attn_mask_enabled and mask is not None:
|
540 |
attn_mask = mask
|
541 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
542 |
+
attn_mask = attn_mask.expand(
|
543 |
+
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
544 |
+
)
|
545 |
else:
|
546 |
attn_mask = None
|
547 |
+
x = F.scaled_dot_product_attention(
|
548 |
+
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
549 |
+
)
|
550 |
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
551 |
|
552 |
elif self.attn_backend == "flash_attn":
|
|
|
554 |
key = key.transpose(1, 2)
|
555 |
value = value.transpose(1, 2)
|
556 |
if self.attn_mask_enabled and mask is not None:
|
557 |
+
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(
|
558 |
+
query, mask
|
559 |
+
)
|
560 |
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
|
561 |
value, _, _, _, _ = unpad_input(value, mask)
|
562 |
x = flash_attn_varlen_func(
|
|
|
642 |
# apply rope for context and noised input independently
|
643 |
if rope is not None:
|
644 |
freqs, xpos_scale = rope
|
645 |
+
q_xpos_scale, k_xpos_scale = (
|
646 |
+
(xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
647 |
+
)
|
648 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
649 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
650 |
if c_rope is not None:
|
651 |
freqs, xpos_scale = c_rope
|
652 |
+
q_xpos_scale, k_xpos_scale = (
|
653 |
+
(xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
654 |
+
)
|
655 |
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
656 |
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
657 |
|
|
|
664 |
if mask is not None:
|
665 |
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
666 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
667 |
+
attn_mask = attn_mask.expand(
|
668 |
+
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
669 |
+
)
|
670 |
else:
|
671 |
attn_mask = None
|
672 |
|
673 |
+
x = F.scaled_dot_product_attention(
|
674 |
+
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
675 |
+
)
|
676 |
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
677 |
x = x.to(query.dtype)
|
678 |
|
|
|
730 |
)
|
731 |
|
732 |
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
733 |
+
self.ff = FeedForward(
|
734 |
+
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
735 |
+
)
|
736 |
|
737 |
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
738 |
# pre-norm & modulation for attention input
|
|
|
765 |
"""
|
766 |
|
767 |
def __init__(
|
768 |
+
self,
|
769 |
+
dim,
|
770 |
+
heads,
|
771 |
+
dim_head,
|
772 |
+
ff_mult=4,
|
773 |
+
dropout=0.1,
|
774 |
+
context_dim=None,
|
775 |
+
context_pre_only=False,
|
776 |
+
qk_norm=None,
|
777 |
):
|
778 |
super().__init__()
|
779 |
if context_dim is None:
|
780 |
context_dim = dim
|
781 |
self.context_pre_only = context_pre_only
|
782 |
|
783 |
+
self.attn_norm_c = (
|
784 |
+
AdaLayerNorm_Final(context_dim)
|
785 |
+
if context_pre_only
|
786 |
+
else AdaLayerNorm(context_dim)
|
787 |
+
)
|
788 |
self.attn_norm_x = AdaLayerNorm(dim)
|
789 |
self.attn = Attention(
|
790 |
processor=JointAttnProcessor(),
|
|
|
798 |
)
|
799 |
|
800 |
if not context_pre_only:
|
801 |
+
self.ff_norm_c = nn.LayerNorm(
|
802 |
+
context_dim, elementwise_affine=False, eps=1e-6
|
803 |
+
)
|
804 |
+
self.ff_c = FeedForward(
|
805 |
+
dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
806 |
+
)
|
807 |
else:
|
808 |
self.ff_norm_c = None
|
809 |
self.ff_c = None
|
810 |
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
811 |
+
self.ff_x = FeedForward(
|
812 |
+
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
813 |
+
)
|
814 |
|
815 |
+
def forward(
|
816 |
+
self, x, c, t, mask=None, rope=None, c_rope=None
|
817 |
+
): # x: noised input, c: context, t: time embedding
|
818 |
# pre-norm & modulation for attention input
|
819 |
if self.context_pre_only:
|
820 |
norm_c = self.attn_norm_c(c, t)
|
821 |
else:
|
822 |
+
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
|
823 |
+
c, emb=t
|
824 |
+
)
|
825 |
+
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
|
826 |
+
x, emb=t
|
827 |
+
)
|
828 |
|
829 |
# attention
|
830 |
+
x_attn_output, c_attn_output = self.attn(
|
831 |
+
x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope
|
832 |
+
)
|
833 |
|
834 |
# process attention output for context c
|
835 |
if self.context_pre_only:
|
|
|
837 |
else: # if not last layer
|
838 |
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
839 |
|
840 |
+
norm_c = (
|
841 |
+
self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
842 |
+
)
|
843 |
c_ff_output = self.ff_c(norm_c)
|
844 |
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
845 |
|
|
|
860 |
def __init__(self, dim, freq_embed_dim=256):
|
861 |
super().__init__()
|
862 |
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
863 |
+
self.time_mlp = nn.Sequential(
|
864 |
+
nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
|
865 |
+
)
|
866 |
|
867 |
def forward(self, timestep: float["b"]):
|
868 |
time_hidden = self.time_embed(timestep)
|
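The last modules.py hunk above wraps the `TimestepEmbedding` MLP. The underlying idea is a fixed sinusoidal encoding of the flow-matching timestep followed by a small Linear-SiLU-Linear projection. A standalone sketch, assuming a 256-dim frequency embedding and a 1000x time scale (these constants are illustrative, not read from the repo config):

```python
import math
import torch
import torch.nn as nn

class SinusTimeEmbedding(nn.Module):
    """Sinusoidal timestep encoding + MLP, mirroring the pattern in the hunk above."""

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.freq_embed_dim = freq_embed_dim
        self.mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t):  # t: (b,) timesteps in [0, 1]
        half = self.freq_embed_dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / (half - 1))
        args = t[:, None].float() * freqs[None] * 1000.0  # assumed scale factor
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        return self.mlp(emb)

print(SinusTimeEmbedding(dim=512)(torch.rand(4)).shape)  # torch.Size([4, 512])
```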
f5_tts/model_new/trainer.py
CHANGED
@@ -19,7 +19,6 @@ from f5_tts.model import CFM
|
|
19 |
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
20 |
from f5_tts.model.utils import default, exists
|
21 |
|
22 |
-
|
23 |
# trainer
|
24 |
|
25 |
|
@@ -70,7 +69,13 @@ class Trainer:
|
|
70 |
self.logger = logger
|
71 |
if self.logger == "wandb":
|
72 |
if exists(wandb_resume_id):
|
73 |
-
init_kwargs = {
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
else:
|
75 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
76 |
|
@@ -138,7 +143,9 @@ class Trainer:
|
|
138 |
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
139 |
else:
|
140 |
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
141 |
-
self.model, self.optimizer = self.accelerator.prepare(
|
|
|
|
|
142 |
|
143 |
@property
|
144 |
def is_main(self):
|
@@ -157,12 +164,16 @@ class Trainer:
|
|
157 |
if not os.path.exists(self.checkpoint_path):
|
158 |
os.makedirs(self.checkpoint_path)
|
159 |
if last:
|
160 |
-
self.accelerator.save(
|
|
|
|
|
161 |
print(f"Saved last checkpoint at update {update}")
|
162 |
else:
|
163 |
if self.keep_last_n_checkpoints == 0:
|
164 |
return
|
165 |
-
self.accelerator.save(
|
|
|
|
|
166 |
if self.keep_last_n_checkpoints > 0:
|
167 |
# Updated logic to exclude pretrained model from rotation
|
168 |
checkpoints = [
|
@@ -183,7 +194,10 @@ class Trainer:
|
|
183 |
if (
|
184 |
not exists(self.checkpoint_path)
|
185 |
or not os.path.exists(self.checkpoint_path)
|
186 |
-
or not any(
|
|
|
|
|
|
|
187 |
):
|
188 |
return 0
|
189 |
|
@@ -195,11 +209,16 @@ class Trainer:
|
|
195 |
all_checkpoints = [
|
196 |
f
|
197 |
for f in os.listdir(self.checkpoint_path)
|
198 |
-
if (f.startswith("model_") or f.startswith("pretrained_"))
|
|
|
199 |
]
|
200 |
|
201 |
# First try to find regular training checkpoints
|
202 |
-
training_checkpoints = [
|
|
|
|
|
|
|
|
|
203 |
if training_checkpoints:
|
204 |
latest_checkpoint = sorted(
|
205 |
training_checkpoints,
|
@@ -207,21 +226,30 @@ class Trainer:
|
|
207 |
)[-1]
|
208 |
else:
|
209 |
# If no training checkpoints, use pretrained model
|
210 |
-
latest_checkpoint = next(
|
|
|
|
|
211 |
|
212 |
if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
|
213 |
from safetensors.torch import load_file
|
214 |
|
215 |
-
checkpoint = load_file(
|
|
|
|
|
216 |
checkpoint = {"ema_model_state_dict": checkpoint}
|
217 |
elif latest_checkpoint.endswith(".pt"):
|
218 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
219 |
checkpoint = torch.load(
|
220 |
-
f"{self.checkpoint_path}/{latest_checkpoint}",
|
|
|
|
|
221 |
)
|
222 |
|
223 |
# patch for backward compatibility, 305e3ea
|
224 |
-
for key in [
|
|
|
|
|
|
|
225 |
if key in checkpoint["ema_model_state_dict"]:
|
226 |
del checkpoint["ema_model_state_dict"][key]
|
227 |
|
@@ -231,17 +259,24 @@ class Trainer:
|
|
231 |
if "update" in checkpoint or "step" in checkpoint:
|
232 |
# patch for backward compatibility, with before f992c4e
|
233 |
if "step" in checkpoint:
|
234 |
-
checkpoint["update"] =
|
|
|
|
|
235 |
if self.grad_accumulation_steps > 1 and self.is_main:
|
236 |
print(
|
237 |
"F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
|
238 |
)
|
239 |
# patch for backward compatibility, 305e3ea
|
240 |
-
for key in [
|
|
|
|
|
|
|
241 |
if key in checkpoint["model_state_dict"]:
|
242 |
del checkpoint["model_state_dict"][key]
|
243 |
|
244 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(
|
|
|
|
|
245 |
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
246 |
if self.scheduler:
|
247 |
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
@@ -252,21 +287,30 @@ class Trainer:
|
|
252 |
for k, v in checkpoint["ema_model_state_dict"].items()
|
253 |
if k not in ["initted", "update", "step"]
|
254 |
}
|
255 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(
|
|
|
|
|
256 |
update = 0
|
257 |
|
258 |
del checkpoint
|
259 |
gc.collect()
|
260 |
return update
|
261 |
|
262 |
-
def train(
|
|
|
|
|
263 |
if self.log_samples:
|
264 |
-
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder,
|
|
|
265 |
|
266 |
vocoder = load_vocoder(
|
267 |
-
vocoder_name=self.vocoder_name,
|
|
|
|
|
268 |
)
|
269 |
-
target_sample_rate = self.accelerator.unwrap_model(
|
|
|
|
|
270 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
271 |
os.makedirs(log_samples_path, exist_ok=True)
|
272 |
|
@@ -306,7 +350,9 @@ class Trainer:
|
|
306 |
batch_sampler=batch_sampler,
|
307 |
)
|
308 |
else:
|
309 |
-
raise ValueError(
|
|
|
|
|
310 |
|
311 |
# accelerator.prepare() dispatches batches to devices;
|
312 |
# which means the length of dataloader calculated before, should consider the number of devices
|
@@ -314,12 +360,24 @@ class Trainer:
|
|
314 |
self.num_warmup_updates * self.accelerator.num_processes
|
315 |
) # consider a fixed warmup steps while using accelerate multi-gpu ddp
|
316 |
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
317 |
-
total_updates =
|
|
|
|
|
|
|
318 |
decay_updates = total_updates - warmup_updates
|
319 |
-
warmup_scheduler = LinearLR(
|
320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
self.scheduler = SequentialLR(
|
322 |
-
self.optimizer,
|
|
|
|
|
323 |
)
|
324 |
train_dataloader, self.scheduler = self.accelerator.prepare(
|
325 |
train_dataloader, self.scheduler
|
@@ -332,21 +390,27 @@ class Trainer:
|
|
332 |
start_step = start_update * self.grad_accumulation_steps
|
333 |
skipped_epoch = int(start_step // orig_epoch_step)
|
334 |
skipped_batch = start_step % orig_epoch_step
|
335 |
-
skipped_dataloader = self.accelerator.skip_first_batches(
|
|
|
|
|
336 |
else:
|
337 |
skipped_epoch = 0
|
338 |
|
339 |
for epoch in range(skipped_epoch, self.epochs):
|
340 |
self.model.train()
|
341 |
if exists(resumable_with_seed) and epoch == skipped_epoch:
|
342 |
-
progress_bar_initial = math.ceil(
|
|
|
|
|
343 |
current_dataloader = skipped_dataloader
|
344 |
else:
|
345 |
progress_bar_initial = 0
|
346 |
current_dataloader = train_dataloader
|
347 |
|
348 |
# Set epoch for the batch sampler if it exists
|
349 |
-
if hasattr(train_dataloader, "batch_sampler") and hasattr(
|
|
|
|
|
350 |
train_dataloader.batch_sampler.set_epoch(epoch)
|
351 |
|
352 |
progress_bar = tqdm(
|
@@ -364,17 +428,29 @@ class Trainer:
|
|
364 |
mel_lengths = batch["mel_lengths"]
|
365 |
|
366 |
# TODO. add duration predictor training
|
367 |
-
if
|
368 |
-
|
369 |
-
self.accelerator.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
loss, cond, pred = self.model(
|
372 |
-
mel_spec,
|
|
|
|
|
|
|
373 |
)
|
374 |
self.accelerator.backward(loss)
|
375 |
|
376 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
377 |
-
self.accelerator.clip_grad_norm_(
|
|
|
|
|
378 |
|
379 |
self.optimizer.step()
|
380 |
self.scheduler.step()
|
@@ -386,29 +462,44 @@ class Trainer:
|
|
386 |
|
387 |
global_update += 1
|
388 |
progress_bar.update(1)
|
389 |
-
progress_bar.set_postfix(
|
|
|
|
|
390 |
|
391 |
if self.accelerator.is_local_main_process:
|
392 |
self.accelerator.log(
|
393 |
-
{"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]},
|
|
|
394 |
)
|
395 |
if self.logger == "tensorboard":
|
396 |
self.writer.add_scalar("loss", loss.item(), global_update)
|
397 |
-
self.writer.add_scalar(
|
|
|
|
|
398 |
|
399 |
-
if
|
|
|
|
|
|
|
400 |
self.save_checkpoint(global_update, last=True)
|
401 |
|
402 |
-
if
|
|
|
|
|
|
|
403 |
self.save_checkpoint(global_update)
|
404 |
|
405 |
if self.log_samples and self.accelerator.is_local_main_process:
|
406 |
ref_audio_len = mel_lengths[0]
|
407 |
infer_text = [
|
408 |
-
text_inputs[0]
|
|
|
|
|
409 |
]
|
410 |
with torch.inference_mode():
|
411 |
-
generated, _ = self.accelerator.unwrap_model(
|
|
|
|
|
412 |
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
|
413 |
text=infer_text,
|
414 |
duration=ref_audio_len * 2,
|
@@ -417,7 +508,11 @@ class Trainer:
|
|
417 |
sway_sampling_coef=sway_sampling_coef,
|
418 |
)
|
419 |
generated = generated.to(torch.float32)
|
420 |
-
gen_mel_spec =
|
|
|
|
|
|
|
|
|
421 |
ref_mel_spec = batch["mel"][0].unsqueeze(0)
|
422 |
if self.vocoder_name == "vocos":
|
423 |
gen_audio = vocoder.decode(gen_mel_spec).cpu()
|
@@ -427,10 +522,14 @@ class Trainer:
|
|
427 |
ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
|
428 |
|
429 |
torchaudio.save(
|
430 |
-
f"{log_samples_path}/update_{global_update}_gen.wav",
|
|
|
|
|
431 |
)
|
432 |
torchaudio.save(
|
433 |
-
f"{log_samples_path}/update_{global_update}_ref.wav",
|
|
|
|
|
434 |
)
|
435 |
self.model.train()
|
436 |
|
|
|
19 |
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
20 |
from f5_tts.model.utils import default, exists
|
21 |
|
|
|
22 |
# trainer
|
23 |
|
24 |
|
|
|
69 |
self.logger = logger
|
70 |
if self.logger == "wandb":
|
71 |
if exists(wandb_resume_id):
|
72 |
+
init_kwargs = {
|
73 |
+
"wandb": {
|
74 |
+
"resume": "allow",
|
75 |
+
"name": wandb_run_name,
|
76 |
+
"id": wandb_resume_id,
|
77 |
+
}
|
78 |
+
}
|
79 |
else:
|
80 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
81 |
|
|
|
143 |
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
144 |
else:
|
145 |
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
146 |
+
self.model, self.optimizer = self.accelerator.prepare(
|
147 |
+
self.model, self.optimizer
|
148 |
+
)
|
149 |
|
150 |
@property
|
151 |
def is_main(self):
|
|
|
164 |
if not os.path.exists(self.checkpoint_path):
|
165 |
os.makedirs(self.checkpoint_path)
|
166 |
if last:
|
167 |
+
self.accelerator.save(
|
168 |
+
checkpoint, f"{self.checkpoint_path}/model_last.pt"
|
169 |
+
)
|
170 |
print(f"Saved last checkpoint at update {update}")
|
171 |
else:
|
172 |
if self.keep_last_n_checkpoints == 0:
|
173 |
return
|
174 |
+
self.accelerator.save(
|
175 |
+
checkpoint, f"{self.checkpoint_path}/model_{update}.pt"
|
176 |
+
)
|
177 |
if self.keep_last_n_checkpoints > 0:
|
178 |
# Updated logic to exclude pretrained model from rotation
|
179 |
checkpoints = [
|
|
|
194 |
if (
|
195 |
not exists(self.checkpoint_path)
|
196 |
or not os.path.exists(self.checkpoint_path)
|
197 |
+
or not any(
|
198 |
+
filename.endswith((".pt", ".safetensors"))
|
199 |
+
for filename in os.listdir(self.checkpoint_path)
|
200 |
+
)
|
201 |
):
|
202 |
return 0
|
203 |
|
|
|
209 |
all_checkpoints = [
|
210 |
f
|
211 |
for f in os.listdir(self.checkpoint_path)
|
212 |
+
if (f.startswith("model_") or f.startswith("pretrained_"))
|
213 |
+
and f.endswith((".pt", ".safetensors"))
|
214 |
]
|
215 |
|
216 |
# First try to find regular training checkpoints
|
217 |
+
training_checkpoints = [
|
218 |
+
f
|
219 |
+
for f in all_checkpoints
|
220 |
+
if f.startswith("model_") and f != "model_last.pt"
|
221 |
+
]
|
222 |
if training_checkpoints:
|
223 |
latest_checkpoint = sorted(
|
224 |
training_checkpoints,
|
|
|
226 |
)[-1]
|
227 |
else:
|
228 |
# If no training checkpoints, use pretrained model
|
229 |
+
latest_checkpoint = next(
|
230 |
+
f for f in all_checkpoints if f.startswith("pretrained_")
|
231 |
+
)
|
232 |
|
233 |
if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
|
234 |
from safetensors.torch import load_file
|
235 |
|
236 |
+
checkpoint = load_file(
|
237 |
+
f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu"
|
238 |
+
)
|
239 |
checkpoint = {"ema_model_state_dict": checkpoint}
|
240 |
elif latest_checkpoint.endswith(".pt"):
|
241 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
242 |
checkpoint = torch.load(
|
243 |
+
f"{self.checkpoint_path}/{latest_checkpoint}",
|
244 |
+
weights_only=True,
|
245 |
+
map_location="cpu",
|
246 |
)
|
247 |
|
248 |
# patch for backward compatibility, 305e3ea
|
249 |
+
for key in [
|
250 |
+
"ema_model.mel_spec.mel_stft.mel_scale.fb",
|
251 |
+
"ema_model.mel_spec.mel_stft.spectrogram.window",
|
252 |
+
]:
|
253 |
if key in checkpoint["ema_model_state_dict"]:
|
254 |
del checkpoint["ema_model_state_dict"][key]
|
255 |
|
|
|
259 |
if "update" in checkpoint or "step" in checkpoint:
|
260 |
# patch for backward compatibility, with before f992c4e
|
261 |
if "step" in checkpoint:
|
262 |
+
checkpoint["update"] = (
|
263 |
+
checkpoint["step"] // self.grad_accumulation_steps
|
264 |
+
)
|
265 |
if self.grad_accumulation_steps > 1 and self.is_main:
|
266 |
print(
|
267 |
"F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
|
268 |
)
|
269 |
# patch for backward compatibility, 305e3ea
|
270 |
+
for key in [
|
271 |
+
"mel_spec.mel_stft.mel_scale.fb",
|
272 |
+
"mel_spec.mel_stft.spectrogram.window",
|
273 |
+
]:
|
274 |
if key in checkpoint["model_state_dict"]:
|
275 |
del checkpoint["model_state_dict"][key]
|
276 |
|
277 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(
|
278 |
+
checkpoint["model_state_dict"]
|
279 |
+
)
|
280 |
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
281 |
if self.scheduler:
|
282 |
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
|
|
287 |
for k, v in checkpoint["ema_model_state_dict"].items()
|
288 |
if k not in ["initted", "update", "step"]
|
289 |
}
|
290 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(
|
291 |
+
checkpoint["model_state_dict"]
|
292 |
+
)
|
293 |
update = 0
|
294 |
|
295 |
del checkpoint
|
296 |
gc.collect()
|
297 |
return update
|
298 |
|
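The `load_checkpoint` hunks above reformat two behaviours: `.safetensors` files are read with safetensors and wrapped as an EMA state dict, `.pt` files are read with `torch.load(weights_only=True)` on CPU, and the old mel-spectrogram buffers are dropped for backward compatibility. A condensed sketch of that branch (the example path is illustrative, not a real checkpoint):

```python
import torch

def load_ckpt_state(path):
    """Condensed sketch of the checkpoint-loading branch shown in the diff above."""
    if path.endswith(".safetensors"):
        from safetensors.torch import load_file
        state = {"ema_model_state_dict": load_file(path, device="cpu")}
    else:
        state = torch.load(path, weights_only=True, map_location="cpu")
    # drop mel-spectrogram buffers kept only for backward compatibility
    for key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
        state.get("model_state_dict", {}).pop(key, None)
    return state

# e.g. state = load_ckpt_state("ckpts/model_last.pt")  # hypothetical path
```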
299 |
+
def train(
|
300 |
+
self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None
|
301 |
+
):
|
302 |
if self.log_samples:
|
303 |
+
from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
|
304 |
+
nfe_step, sway_sampling_coef)
|
305 |
|
306 |
vocoder = load_vocoder(
|
307 |
+
vocoder_name=self.vocoder_name,
|
308 |
+
is_local=self.is_local_vocoder,
|
309 |
+
local_path=self.local_vocoder_path,
|
310 |
)
|
311 |
+
target_sample_rate = self.accelerator.unwrap_model(
|
312 |
+
self.model
|
313 |
+
).mel_spec.target_sample_rate
|
314 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
315 |
os.makedirs(log_samples_path, exist_ok=True)
|
316 |
|
|
|
350 |
batch_sampler=batch_sampler,
|
351 |
)
|
352 |
else:
|
353 |
+
raise ValueError(
|
354 |
+
f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
|
355 |
+
)
|
356 |
|
357 |
# accelerator.prepare() dispatches batches to devices;
|
358 |
# which means the length of dataloader calculated before, should consider the number of devices
|
|
|
360 |
self.num_warmup_updates * self.accelerator.num_processes
|
361 |
) # consider a fixed warmup steps while using accelerate multi-gpu ddp
|
362 |
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
363 |
+
total_updates = (
|
364 |
+
math.ceil(len(train_dataloader) / self.grad_accumulation_steps)
|
365 |
+
* self.epochs
|
366 |
+
)
|
367 |
decay_updates = total_updates - warmup_updates
|
368 |
+
warmup_scheduler = LinearLR(
|
369 |
+
self.optimizer,
|
370 |
+
start_factor=1e-8,
|
371 |
+
end_factor=1.0,
|
372 |
+
total_iters=warmup_updates,
|
373 |
+
)
|
374 |
+
decay_scheduler = LinearLR(
|
375 |
+
self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates
|
376 |
+
)
|
377 |
self.scheduler = SequentialLR(
|
378 |
+
self.optimizer,
|
379 |
+
schedulers=[warmup_scheduler, decay_scheduler],
|
380 |
+
milestones=[warmup_updates],
|
381 |
)
|
382 |
train_dataloader, self.scheduler = self.accelerator.prepare(
|
383 |
train_dataloader, self.scheduler
|
|
|
390 |
start_step = start_update * self.grad_accumulation_steps
|
391 |
skipped_epoch = int(start_step // orig_epoch_step)
|
392 |
skipped_batch = start_step % orig_epoch_step
|
393 |
+
skipped_dataloader = self.accelerator.skip_first_batches(
|
394 |
+
train_dataloader, num_batches=skipped_batch
|
395 |
+
)
|
396 |
else:
|
397 |
skipped_epoch = 0
|
398 |
|
399 |
for epoch in range(skipped_epoch, self.epochs):
|
400 |
self.model.train()
|
401 |
if exists(resumable_with_seed) and epoch == skipped_epoch:
|
402 |
+
progress_bar_initial = math.ceil(
|
403 |
+
skipped_batch / self.grad_accumulation_steps
|
404 |
+
)
|
405 |
current_dataloader = skipped_dataloader
|
406 |
else:
|
407 |
progress_bar_initial = 0
|
408 |
current_dataloader = train_dataloader
|
409 |
|
410 |
# Set epoch for the batch sampler if it exists
|
411 |
+
if hasattr(train_dataloader, "batch_sampler") and hasattr(
|
412 |
+
train_dataloader.batch_sampler, "set_epoch"
|
413 |
+
):
|
414 |
train_dataloader.batch_sampler.set_epoch(epoch)
|
415 |
|
416 |
progress_bar = tqdm(
|
|
|
428 |
mel_lengths = batch["mel_lengths"]
|
429 |
|
430 |
# TODO. add duration predictor training
|
431 |
+
if (
|
432 |
+
self.duration_predictor is not None
|
433 |
+
and self.accelerator.is_local_main_process
|
434 |
+
):
|
435 |
+
dur_loss = self.duration_predictor(
|
436 |
+
mel_spec, lens=batch.get("durations")
|
437 |
+
)
|
438 |
+
self.accelerator.log(
|
439 |
+
{"duration loss": dur_loss.item()}, step=global_update
|
440 |
+
)
|
441 |
|
442 |
loss, cond, pred = self.model(
|
443 |
+
mel_spec,
|
444 |
+
text=text_inputs,
|
445 |
+
lens=mel_lengths,
|
446 |
+
noise_scheduler=self.noise_scheduler,
|
447 |
)
|
448 |
self.accelerator.backward(loss)
|
449 |
|
450 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
451 |
+
self.accelerator.clip_grad_norm_(
|
452 |
+
self.model.parameters(), self.max_grad_norm
|
453 |
+
)
|
454 |
|
455 |
self.optimizer.step()
|
456 |
self.scheduler.step()
|
|
|
462 |
|
463 |
global_update += 1
|
464 |
progress_bar.update(1)
|
465 |
+
progress_bar.set_postfix(
|
466 |
+
update=str(global_update), loss=loss.item()
|
467 |
+
)
|
468 |
|
469 |
if self.accelerator.is_local_main_process:
|
470 |
self.accelerator.log(
|
471 |
+
{"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]},
|
472 |
+
step=global_update,
|
473 |
)
|
474 |
if self.logger == "tensorboard":
|
475 |
self.writer.add_scalar("loss", loss.item(), global_update)
|
476 |
+
self.writer.add_scalar(
|
477 |
+
"lr", self.scheduler.get_last_lr()[0], global_update
|
478 |
+
)
|
479 |
|
480 |
+
if (
|
481 |
+
global_update % self.last_per_updates == 0
|
482 |
+
and self.accelerator.sync_gradients
|
483 |
+
):
|
484 |
self.save_checkpoint(global_update, last=True)
|
485 |
|
486 |
+
if (
|
487 |
+
global_update % self.save_per_updates == 0
|
488 |
+
and self.accelerator.sync_gradients
|
489 |
+
):
|
490 |
self.save_checkpoint(global_update)
|
491 |
|
492 |
if self.log_samples and self.accelerator.is_local_main_process:
|
493 |
ref_audio_len = mel_lengths[0]
|
494 |
infer_text = [
|
495 |
+
text_inputs[0]
|
496 |
+
+ ([" "] if isinstance(text_inputs[0], list) else " ")
|
497 |
+
+ text_inputs[0]
|
498 |
]
|
499 |
with torch.inference_mode():
|
500 |
+
generated, _ = self.accelerator.unwrap_model(
|
501 |
+
self.model
|
502 |
+
).sample(
|
503 |
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
|
504 |
text=infer_text,
|
505 |
duration=ref_audio_len * 2,
|
|
|
508 |
sway_sampling_coef=sway_sampling_coef,
|
509 |
)
|
510 |
generated = generated.to(torch.float32)
|
511 |
+
gen_mel_spec = (
|
512 |
+
generated[:, ref_audio_len:, :]
|
513 |
+
.permute(0, 2, 1)
|
514 |
+
.to(self.accelerator.device)
|
515 |
+
)
|
516 |
ref_mel_spec = batch["mel"][0].unsqueeze(0)
|
517 |
if self.vocoder_name == "vocos":
|
518 |
gen_audio = vocoder.decode(gen_mel_spec).cpu()
|
|
|
522 |
ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
|
523 |
|
524 |
torchaudio.save(
|
525 |
+
f"{log_samples_path}/update_{global_update}_gen.wav",
|
526 |
+
gen_audio,
|
527 |
+
target_sample_rate,
|
528 |
)
|
529 |
torchaudio.save(
|
530 |
+
f"{log_samples_path}/update_{global_update}_ref.wav",
|
531 |
+
ref_audio,
|
532 |
+
target_sample_rate,
|
533 |
)
|
534 |
self.model.train()
|
535 |
|
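The scheduler hunk above chains a linear warmup with a linear decay through `SequentialLR`, taking one `scheduler.step()` per optimizer update. A standalone sketch of that schedule, using a placeholder model and made-up update counts (the real counts come from the dataloader length, epochs, and grad-accumulation setting):

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(8, 8)                      # placeholder model
optimizer = AdamW(model.parameters(), lr=7.5e-5)   # illustrative learning rate

warmup_updates, decay_updates = 1000, 9000
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_updates])

for _ in range(5):       # one scheduler step per optimizer update
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])  # still ramping up during warmup
```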
f5_tts/model_new/utils.py
CHANGED
@@ -10,7 +10,6 @@ import torch
|
|
10 |
from pypinyin import Style, lazy_pinyin
|
11 |
from torch.nn.utils.rnn import pad_sequence
|
12 |
|
13 |
-
|
14 |
# seed everything
|
15 |
|
16 |
|
@@ -48,7 +47,9 @@ def is_package_available(package_name: str) -> bool:
|
|
48 |
# tensor helpers
|
49 |
|
50 |
|
51 |
-
def lens_to_mask(
|
|
|
|
|
52 |
if not exists(length):
|
53 |
length = t.amax()
|
54 |
|
@@ -56,7 +57,9 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa
|
|
56 |
return seq[None, :] < t[:, None]
|
57 |
|
58 |
|
59 |
-
def mask_from_start_end_indices(
|
|
|
|
|
60 |
max_seq_len = seq_len.max().item()
|
61 |
seq = torch.arange(max_seq_len, device=start.device).long()
|
62 |
start_mask = seq[None, :] >= start[:, None]
|
@@ -64,7 +67,9 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"
|
|
64 |
return start_mask & end_mask
|
65 |
|
66 |
|
67 |
-
def mask_from_frac_lengths(
|
|
|
|
|
68 |
lengths = (frac_lengths * seq_len).long()
|
69 |
max_start = seq_len - lengths
|
70 |
|
@@ -75,7 +80,9 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa
|
|
75 |
return mask_from_start_end_indices(seq_len, start, end)
|
76 |
|
77 |
|
78 |
-
def maybe_masked_mean(
|
|
|
|
|
79 |
if not exists(mask):
|
80 |
return t.mean(dim=1)
|
81 |
|
@@ -99,7 +106,9 @@ def list_str_to_idx(
|
|
99 |
vocab_char_map: dict[str, int], # {char: idx}
|
100 |
padding_value=-1,
|
101 |
) -> int["b nt"]: # noqa: F722
|
102 |
-
list_idx_tensors = [
|
|
|
|
|
103 |
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
104 |
return text
|
105 |
|
@@ -118,13 +127,18 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
|
118 |
- if use "byte", set to 256 (unicode byte range)
|
119 |
"""
|
120 |
if tokenizer in ["pinyin", "char"]:
|
121 |
-
tokenizer_path = os.path.join(
|
|
|
|
|
|
|
122 |
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
123 |
vocab_char_map = {}
|
124 |
for i, char in enumerate(f):
|
125 |
vocab_char_map[char[:-1]] = i
|
126 |
vocab_size = len(vocab_char_map)
|
127 |
-
assert
|
|
|
|
|
128 |
|
129 |
elif tokenizer == "byte":
|
130 |
vocab_char_map = None
|
@@ -154,9 +168,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
|
|
154 |
) # add custom trans here, to address oov
|
155 |
|
156 |
def is_chinese(c):
|
157 |
-
return
|
158 |
-
"\u3100" <= c <= "\u9fff" # common chinese characters
|
159 |
-
)
|
160 |
|
161 |
for text in text_list:
|
162 |
char_list = []
|
@@ -167,7 +179,9 @@ def convert_char_to_pinyin(text_list, polyphone=True):
|
|
167 |
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
168 |
char_list.append(" ")
|
169 |
char_list.extend(seg)
|
170 |
-
elif polyphone and seg_byte_len == 3 * len(
|
|
|
|
|
171 |
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
172 |
for i, c in enumerate(seg):
|
173 |
if is_chinese(c):
|
@@ -179,7 +193,9 @@ def convert_char_to_pinyin(text_list, polyphone=True):
|
|
179 |
char_list.extend(c)
|
180 |
elif is_chinese(c):
|
181 |
char_list.append(" ")
|
182 |
-
char_list.extend(
|
|
|
|
|
183 |
else:
|
184 |
char_list.append(c)
|
185 |
final_text_list.append(char_list)
|
|
|
10 |
from pypinyin import Style, lazy_pinyin
|
11 |
from torch.nn.utils.rnn import pad_sequence
|
12 |
|
|
|
13 |
# seed everything
|
14 |
|
15 |
|
|
|
47 |
# tensor helpers
|
48 |
|
49 |
|
50 |
+
def lens_to_mask(
|
51 |
+
t: int["b"], length: int | None = None
|
52 |
+
) -> bool["b n"]: # noqa: F722 F821
|
53 |
if not exists(length):
|
54 |
length = t.amax()
|
55 |
|
|
|
57 |
return seq[None, :] < t[:, None]
|
58 |
|
59 |
|
60 |
+
def mask_from_start_end_indices(
|
61 |
+
seq_len: int["b"], start: int["b"], end: int["b"]
|
62 |
+
): # noqa: F722 F821
|
63 |
max_seq_len = seq_len.max().item()
|
64 |
seq = torch.arange(max_seq_len, device=start.device).long()
|
65 |
start_mask = seq[None, :] >= start[:, None]
|
|
|
67 |
return start_mask & end_mask
|
68 |
|
69 |
|
70 |
+
def mask_from_frac_lengths(
|
71 |
+
seq_len: int["b"], frac_lengths: float["b"]
|
72 |
+
): # noqa: F722 F821
|
73 |
lengths = (frac_lengths * seq_len).long()
|
74 |
max_start = seq_len - lengths
|
75 |
|
|
|
80 |
return mask_from_start_end_indices(seq_len, start, end)
|
81 |
|
82 |
|
83 |
+
def maybe_masked_mean(
|
84 |
+
t: float["b n d"], mask: bool["b n"] = None
|
85 |
+
) -> float["b d"]: # noqa: F722
|
86 |
if not exists(mask):
|
87 |
return t.mean(dim=1)
|
88 |
|
|
|
106 |
vocab_char_map: dict[str, int], # {char: idx}
|
107 |
padding_value=-1,
|
108 |
) -> int["b nt"]: # noqa: F722
|
109 |
+
list_idx_tensors = [
|
110 |
+
torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
|
111 |
+
] # pinyin or char style
|
112 |
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
113 |
return text
|
114 |
|
|
|
127 |
- if use "byte", set to 256 (unicode byte range)
|
128 |
"""
|
129 |
if tokenizer in ["pinyin", "char"]:
|
130 |
+
tokenizer_path = os.path.join(
|
131 |
+
files("f5_tts").joinpath("../../data"),
|
132 |
+
f"{dataset_name}_{tokenizer}/vocab.txt",
|
133 |
+
)
|
134 |
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
135 |
vocab_char_map = {}
|
136 |
for i, char in enumerate(f):
|
137 |
vocab_char_map[char[:-1]] = i
|
138 |
vocab_size = len(vocab_char_map)
|
139 |
+
assert (
|
140 |
+
vocab_char_map[" "] == 0
|
141 |
+
), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
|
142 |
|
143 |
elif tokenizer == "byte":
|
144 |
vocab_char_map = None
|
|
|
168 |
) # add custom trans here, to address oov
|
169 |
|
170 |
def is_chinese(c):
|
171 |
+
return "\u3100" <= c <= "\u9fff" # common chinese characters
|
|
|
|
|
172 |
|
173 |
for text in text_list:
|
174 |
char_list = []
|
|
|
179 |
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
180 |
char_list.append(" ")
|
181 |
char_list.extend(seg)
|
182 |
+
elif polyphone and seg_byte_len == 3 * len(
|
183 |
+
seg
|
184 |
+
): # if pure east asian characters
|
185 |
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
186 |
for i, c in enumerate(seg):
|
187 |
if is_chinese(c):
|
|
|
193 |
char_list.extend(c)
|
194 |
elif is_chinese(c):
|
195 |
char_list.append(" ")
|
196 |
+
char_list.extend(
|
197 |
+
lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
|
198 |
+
)
|
199 |
else:
|
200 |
char_list.append(c)
|
201 |
final_text_list.append(char_list)
|
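The utils.py hunks above only reflow the signatures of the mask helpers. The central one, `lens_to_mask`, turns per-sample lengths into a boolean padding mask; a runnable sketch with toy lengths:

```python
import torch

def lens_to_mask(lens, max_len=None):
    """Per-sample lengths (b,) -> boolean mask (b, n), True for valid frames."""
    if max_len is None:
        max_len = int(lens.amax())
    seq = torch.arange(max_len, device=lens.device)
    return seq[None, :] < lens[:, None]

print(lens_to_mask(torch.tensor([3, 5])))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```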
f5_tts/runtime/triton_trtllm/benchmark.py
CHANGED
@@ -51,7 +51,6 @@ from torch.utils.data import DataLoader, DistributedSampler
|
|
51 |
from tqdm import tqdm
|
52 |
from vocos import Vocos
|
53 |
|
54 |
-
|
55 |
torch.manual_seed(0)
|
56 |
|
57 |
|
@@ -64,7 +63,9 @@ def get_args():
|
|
64 |
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
65 |
help="huggingface dataset split name",
|
66 |
)
|
67 |
-
parser.add_argument(
|
|
|
|
|
68 |
parser.add_argument(
|
69 |
"--vocab-file",
|
70 |
required=True,
|
@@ -89,8 +90,12 @@ def get_args():
|
|
89 |
type=int,
|
90 |
help="batch size (per-device) for inference",
|
91 |
)
|
92 |
-
parser.add_argument(
|
93 |
-
|
|
|
|
|
|
|
|
|
94 |
parser.add_argument(
|
95 |
"--vocoder",
|
96 |
default="vocos",
|
@@ -105,8 +110,16 @@ def get_args():
|
|
105 |
)
|
106 |
parser.add_argument("--enable-warmup", action="store_true")
|
107 |
parser.add_argument("--remove-input-padding", action="store_true")
|
108 |
-
parser.add_argument(
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
args = parser.parse_args()
|
111 |
return args
|
112 |
|
@@ -126,7 +139,13 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
|
126 |
torch.cuda.nvtx.range_push("data_collator")
|
127 |
target_sample_rate = 24000
|
128 |
target_rms = 0.1
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
[],
|
131 |
[],
|
132 |
[],
|
@@ -170,7 +189,14 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
|
170 |
ref_mel_len_list.append(ref_mel_len)
|
171 |
|
172 |
estimated_reference_target_mel_len.append(
|
173 |
-
int(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
)
|
175 |
|
176 |
max_seq_len = max(estimated_reference_target_mel_len)
|
@@ -182,12 +208,22 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
|
182 |
|
183 |
for i, item in enumerate(text_pad_sequence):
|
184 |
text_pad_sequence[i] = F.pad(
|
185 |
-
item,
|
|
|
|
|
|
|
186 |
)
|
187 |
-
text_pad_sequence[
|
188 |
-
|
|
|
|
|
|
|
|
|
189 |
text_pad_sequence = F.pad(
|
190 |
-
text_pad_sequence,
|
|
|
|
|
|
|
191 |
)
|
192 |
if use_perf:
|
193 |
torch.cuda.nvtx.range_pop()
|
@@ -252,7 +288,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
|
|
252 |
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
253 |
char_list.append(" ")
|
254 |
char_list.extend(seg)
|
255 |
-
elif polyphone and seg_byte_len == 3 * len(
|
|
|
|
|
256 |
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
257 |
for i, c in enumerate(seg):
|
258 |
if is_chinese(c):
|
@@ -264,7 +302,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
|
|
264 |
char_list.extend(c)
|
265 |
elif is_chinese(c):
|
266 |
char_list.append(" ")
|
267 |
-
char_list.extend(
|
|
|
|
|
268 |
else:
|
269 |
char_list.append(c)
|
270 |
final_reference_target_texts_list.append(char_list)
|
@@ -277,13 +317,20 @@ def list_str_to_idx(
|
|
277 |
vocab_char_map: Dict[str, int], # {char: idx}
|
278 |
padding_value=-1,
|
279 |
):
|
280 |
-
list_idx_tensors = [
|
|
|
|
|
281 |
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
282 |
return list_idx_tensors
|
283 |
|
284 |
|
285 |
def load_vocoder(
|
286 |
-
vocoder_name="vocos",
|
|
|
|
|
|
|
|
|
|
|
287 |
):
|
288 |
if vocoder_name == "vocos":
|
289 |
if vocoder_trt_engine_path is not None:
|
@@ -297,8 +344,14 @@ def load_vocoder(
|
|
297 |
else:
|
298 |
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
299 |
repo_id = "charactr/vocos-mel-24khz"
|
300 |
-
config_path = hf_hub_download(
|
301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
302 |
vocoder = Vocos.from_hparams(config_path)
|
303 |
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
304 |
from vocos.feature_extractors import EncodecFeatures
|
@@ -343,14 +396,21 @@ class VocosTensorRT:
|
|
343 |
with open(engine_path, "rb") as f:
|
344 |
engine_buffer = f.read()
|
345 |
self.session = Session.from_serialized_engine(engine_buffer)
|
346 |
-
self.stream =
|
|
|
|
|
347 |
|
348 |
def decode(self, mels):
|
349 |
mels = mels.contiguous()
|
350 |
inputs = {"mel": mels}
|
351 |
-
output_info = self.session.infer_shapes(
|
|
|
|
|
352 |
outputs = {
|
353 |
-
t.name: torch.empty(
|
|
|
|
|
|
|
354 |
}
|
355 |
ok = self.session.run(inputs, outputs, self.stream)
|
356 |
|
@@ -376,12 +436,18 @@ def main():
|
|
376 |
config = json.load(f)
|
377 |
if args.backend_type == "trt":
|
378 |
model = F5TTS(
|
379 |
-
config,
|
|
|
|
|
|
|
|
|
380 |
)
|
381 |
elif args.backend_type == "pytorch":
|
382 |
import sys
|
383 |
|
384 |
-
sys.path.append(
|
|
|
|
|
385 |
from f5_tts.infer.utils_infer import load_model
|
386 |
from f5_tts.model import DiT
|
387 |
|
@@ -398,7 +464,9 @@ def main():
|
|
398 |
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
|
399 |
|
400 |
vocoder = load_vocoder(
|
401 |
-
vocoder_name=args.vocoder,
|
|
|
|
|
402 |
)
|
403 |
|
404 |
dataset = load_dataset(
|
@@ -411,7 +479,9 @@ def main():
|
|
411 |
prompt_audio_len = example["prompt_audio"]["array"].shape[0]
|
412 |
scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
|
413 |
estimated_duration = prompt_audio_len * scale_factor
|
414 |
-
example["estimated_duration"] =
|
|
|
|
|
415 |
return example
|
416 |
|
417 |
dataset = dataset.map(add_estimated_duration)
|
@@ -442,12 +512,18 @@ def main():
|
|
442 |
|
443 |
if args.enable_warmup:
|
444 |
for batch in dataloader:
|
445 |
-
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
|
|
|
|
|
446 |
text_pad_seq = batch["text_pad_sequence"].to(device)
|
447 |
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
448 |
if args.backend_type == "trt":
|
449 |
_ = model.sample(
|
450 |
-
text_pad_seq,
|
|
|
|
|
|
|
|
|
451 |
)
|
452 |
elif args.backend_type == "pytorch":
|
453 |
with torch.inference_mode():
|
@@ -475,7 +551,9 @@ def main():
|
|
475 |
for batch in dataloader:
|
476 |
if args.use_perf:
|
477 |
torch.cuda.nvtx.range_push("data sample")
|
478 |
-
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
|
|
|
|
|
479 |
text_pad_seq = batch["text_pad_sequence"].to(device)
|
480 |
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
481 |
|
|
|
51 |
from tqdm import tqdm
|
52 |
from vocos import Vocos
|
53 |
|
|
|
54 |
torch.manual_seed(0)
|
55 |
|
56 |
|
|
|
63 |
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
64 |
help="huggingface dataset split name",
|
65 |
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--output-dir", required=True, type=str, help="dir to save result"
|
68 |
+
)
|
69 |
parser.add_argument(
|
70 |
"--vocab-file",
|
71 |
required=True,
|
|
|
90 |
type=int,
|
91 |
help="batch size (per-device) for inference",
|
92 |
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--num-workers", type=int, default=0, help="workers for dataloader"
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--prefetch", type=int, default=None, help="prefetch for dataloader"
|
98 |
+
)
|
99 |
parser.add_argument(
|
100 |
"--vocoder",
|
101 |
default="vocos",
|
|
|
110 |
)
|
111 |
parser.add_argument("--enable-warmup", action="store_true")
|
112 |
parser.add_argument("--remove-input-padding", action="store_true")
|
113 |
+
parser.add_argument(
|
114 |
+
"--use-perf", action="store_true", help="use nvtx to record performance"
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"--backend-type",
|
118 |
+
type=str,
|
119 |
+
default="triton",
|
120 |
+
choices=["trt", "pytorch"],
|
121 |
+
help="backend type",
|
122 |
+
)
|
123 |
args = parser.parse_args()
|
124 |
return args
|
125 |
|
|
|
139 |
torch.cuda.nvtx.range_push("data_collator")
|
140 |
target_sample_rate = 24000
|
141 |
target_rms = 0.1
|
142 |
+
(
|
143 |
+
ids,
|
144 |
+
ref_mel_list,
|
145 |
+
ref_mel_len_list,
|
146 |
+
estimated_reference_target_mel_len,
|
147 |
+
reference_target_texts_list,
|
148 |
+
) = (
|
149 |
[],
|
150 |
[],
|
151 |
[],
|
|
|
189 |
ref_mel_len_list.append(ref_mel_len)
|
190 |
|
191 |
estimated_reference_target_mel_len.append(
|
192 |
+
int(
|
193 |
+
ref_mel.shape[0]
|
194 |
+
* (
|
195 |
+
1
|
196 |
+
+ len(target_text.encode("utf-8"))
|
197 |
+
/ len(prompt_text.encode("utf-8"))
|
198 |
+
)
|
199 |
+
)
|
200 |
)
|
201 |
|
202 |
max_seq_len = max(estimated_reference_target_mel_len)
|
|
|
208 |
|
209 |
for i, item in enumerate(text_pad_sequence):
|
210 |
text_pad_sequence[i] = F.pad(
|
211 |
+
item,
|
212 |
+
(0, estimated_reference_target_mel_len[i] - len(item)),
|
213 |
+
mode="constant",
|
214 |
+
value=-1,
|
215 |
)
|
216 |
+
text_pad_sequence[
|
217 |
+
i
|
218 |
+
] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
|
219 |
+
text_pad_sequence = pad_sequence(
|
220 |
+
text_pad_sequence, padding_value=-1, batch_first=True
|
221 |
+
).to(device)
|
222 |
text_pad_sequence = F.pad(
|
223 |
+
text_pad_sequence,
|
224 |
+
(0, max_seq_len - text_pad_sequence.shape[1]),
|
225 |
+
mode="constant",
|
226 |
+
value=-1,
|
227 |
)
|
228 |
if use_perf:
|
229 |
torch.cuda.nvtx.range_pop()
|
|
|
288 |
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
289 |
char_list.append(" ")
|
290 |
char_list.extend(seg)
|
291 |
+
elif polyphone and seg_byte_len == 3 * len(
|
292 |
+
seg
|
293 |
+
): # if pure east asian characters
|
294 |
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
295 |
for i, c in enumerate(seg):
|
296 |
if is_chinese(c):
|
|
|
302 |
char_list.extend(c)
|
303 |
elif is_chinese(c):
|
304 |
char_list.append(" ")
|
305 |
+
char_list.extend(
|
306 |
+
lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
|
307 |
+
)
|
308 |
else:
|
309 |
char_list.append(c)
|
310 |
final_reference_target_texts_list.append(char_list)
|
|
|
317 |
vocab_char_map: Dict[str, int], # {char: idx}
|
318 |
padding_value=-1,
|
319 |
):
|
320 |
+
list_idx_tensors = [
|
321 |
+
torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
|
322 |
+
] # pinyin or char style
|
323 |
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
324 |
return list_idx_tensors
|
325 |
|
326 |
|
327 |
def load_vocoder(
|
328 |
+
vocoder_name="vocos",
|
329 |
+
is_local=False,
|
330 |
+
local_path="",
|
331 |
+
device="cuda",
|
332 |
+
hf_cache_dir=None,
|
333 |
+
vocoder_trt_engine_path=None,
|
334 |
):
|
335 |
if vocoder_name == "vocos":
|
336 |
if vocoder_trt_engine_path is not None:
|
|
|
344 |
else:
|
345 |
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
346 |
repo_id = "charactr/vocos-mel-24khz"
|
347 |
+
config_path = hf_hub_download(
|
348 |
+
repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml"
|
349 |
+
)
|
350 |
+
model_path = hf_hub_download(
|
351 |
+
repo_id=repo_id,
|
352 |
+
cache_dir=hf_cache_dir,
|
353 |
+
filename="pytorch_model.bin",
|
354 |
+
)
|
355 |
vocoder = Vocos.from_hparams(config_path)
|
356 |
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
357 |
from vocos.feature_extractors import EncodecFeatures
|
|
|
396 |
with open(engine_path, "rb") as f:
|
397 |
engine_buffer = f.read()
|
398 |
self.session = Session.from_serialized_engine(engine_buffer)
|
399 |
+
self.stream = (
|
400 |
+
stream if stream is not None else torch.cuda.current_stream().cuda_stream
|
401 |
+
)
|
402 |
|
403 |
def decode(self, mels):
|
404 |
mels = mels.contiguous()
|
405 |
inputs = {"mel": mels}
|
406 |
+
output_info = self.session.infer_shapes(
|
407 |
+
[TensorInfo("mel", trt.DataType.FLOAT, mels.shape)]
|
408 |
+
)
|
409 |
outputs = {
|
410 |
+
t.name: torch.empty(
|
411 |
+
tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda"
|
412 |
+
)
|
413 |
+
for t in output_info
|
414 |
}
|
415 |
ok = self.session.run(inputs, outputs, self.stream)
|
416 |
|
|
|
436 |
config = json.load(f)
|
437 |
if args.backend_type == "trt":
|
438 |
model = F5TTS(
|
439 |
+
config,
|
440 |
+
debug_mode=False,
|
441 |
+
tllm_model_dir=tllm_model_dir,
|
442 |
+
model_path=args.model_path,
|
443 |
+
vocab_size=vocab_size,
|
444 |
)
|
445 |
elif args.backend_type == "pytorch":
|
446 |
import sys
|
447 |
|
448 |
+
sys.path.append(
|
449 |
+
f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/"
|
450 |
+
)
|
451 |
from f5_tts.infer.utils_infer import load_model
|
452 |
from f5_tts.model import DiT
|
453 |
|
|
|
464 |
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
|
465 |
|
466 |
vocoder = load_vocoder(
|
467 |
+
vocoder_name=args.vocoder,
|
468 |
+
device=device,
|
469 |
+
vocoder_trt_engine_path=args.vocoder_trt_engine_path,
|
470 |
)
|
471 |
|
472 |
dataset = load_dataset(
|
|
|
479 |
prompt_audio_len = example["prompt_audio"]["array"].shape[0]
|
480 |
scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
|
481 |
estimated_duration = prompt_audio_len * scale_factor
|
482 |
+
example["estimated_duration"] = (
|
483 |
+
estimated_duration / example["prompt_audio"]["sampling_rate"]
|
484 |
+
)
|
485 |
return example
|
486 |
|
487 |
dataset = dataset.map(add_estimated_duration)
|
|
|
512 |
|
513 |
if args.enable_warmup:
|
514 |
for batch in dataloader:
|
515 |
+
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
|
516 |
+
"ref_mel_len_batch"
|
517 |
+
].to(device)
|
518 |
text_pad_seq = batch["text_pad_sequence"].to(device)
|
519 |
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
520 |
if args.backend_type == "trt":
|
521 |
_ = model.sample(
|
522 |
+
text_pad_seq,
|
523 |
+
ref_mels,
|
524 |
+
ref_mel_lens,
|
525 |
+
total_mel_lens,
|
526 |
+
remove_input_padding=args.remove_input_padding,
|
527 |
)
|
528 |
elif args.backend_type == "pytorch":
|
529 |
with torch.inference_mode():
|
|
|
551 |
for batch in dataloader:
|
552 |
if args.use_perf:
|
553 |
torch.cuda.nvtx.range_push("data sample")
|
554 |
+
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
|
555 |
+
"ref_mel_len_batch"
|
556 |
+
].to(device)
|
557 |
text_pad_seq = batch["text_pad_sequence"].to(device)
|
558 |
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
559 |
|
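The collator hunk in benchmark.py above estimates how many mel frames the generated audio will need by scaling the reference mel length with the byte-length ratio of target text to prompt text. The heuristic in isolation, with made-up numbers:

```python
def estimate_total_mel_len(ref_mel_len, prompt_text, target_text):
    """Reference length scaled by 1 + len(target bytes) / len(prompt bytes)."""
    scale = 1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))
    return int(ref_mel_len * scale)

# a 250-frame reference whose target text is roughly 2.6x the prompt's byte length
print(estimate_total_mel_len(250, "hello there", "hello there, nice to meet you"))  # 909
```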
f5_tts/runtime/triton_trtllm/client_grpc.py
CHANGED
@@ -64,8 +64,12 @@ def write_triton_stats(stats, summary_file):
|
|
64 |
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
65 |
)
|
66 |
summary_f.write("To learn more about the log, please refer to: \n")
|
67 |
-
summary_f.write(
|
68 |
-
|
|
|
|
|
|
|
|
|
69 |
summary_f.write(
|
70 |
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
71 |
)
|
@@ -86,7 +90,9 @@ def write_triton_stats(stats, summary_file):
|
|
86 |
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
87 |
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
88 |
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
89 |
-
total_output_time_s =
|
|
|
|
|
90 |
summary_f.write(
|
91 |
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
92 |
)
|
@@ -97,7 +103,11 @@ def write_triton_stats(stats, summary_file):
|
|
97 |
compute_output = batch["compute_output"]
|
98 |
compute_infer = batch["compute_infer"]
|
99 |
batch_count = int(compute_infer["count"])
|
100 |
-
assert
|
|
|
|
|
|
|
|
|
101 |
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
102 |
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
103 |
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
@@ -113,7 +123,9 @@ def write_triton_stats(stats, summary_file):
|
|
113 |
|
114 |
|
115 |
def get_args():
|
116 |
-
parser = argparse.ArgumentParser(
|
|
|
|
|
117 |
|
118 |
parser.add_argument(
|
119 |
"--server-addr",
|
@@ -254,7 +266,9 @@ async def send(
|
|
254 |
for i, item in enumerate(manifest_item_list):
|
255 |
if i % log_interval == 0:
|
256 |
print(f"{name}: {i}/{len(manifest_item_list)}")
|
257 |
-
waveform, sample_rate = load_audio(
|
|
|
|
|
258 |
duration = len(waveform) / sample_rate
|
259 |
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
260 |
|
@@ -269,7 +283,10 @@ async def send(
|
|
269 |
1,
|
270 |
padding_duration
|
271 |
* sample_rate
|
272 |
-
* (
|
|
|
|
|
|
|
273 |
),
|
274 |
dtype=np.float32,
|
275 |
)
|
@@ -281,8 +298,12 @@ async def send(
|
|
281 |
samples = samples.reshape(1, -1).astype(np.float32)
|
282 |
|
283 |
inputs = [
|
284 |
-
protocol_client.InferInput(
|
285 |
-
|
|
|
|
|
|
|
|
|
286 |
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
287 |
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
288 |
]
|
@@ -301,13 +322,17 @@ async def send(
|
|
301 |
|
302 |
sequence_id = 100000000 + i + task_id * 10
|
303 |
start = time.time()
|
304 |
-
response = await triton_client.infer(
|
|
|
|
|
305 |
|
306 |
audio = response.as_numpy("waveform").reshape(-1)
|
307 |
|
308 |
end = time.time() - start
|
309 |
|
310 |
-
audio_save_path = os.path.join(
|
|
|
|
|
311 |
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
312 |
|
313 |
actual_duration = len(audio) / save_sample_rate
|
@@ -341,7 +366,9 @@ def load_manifests(manifest_path):
|
|
341 |
def split_data(data, k):
|
342 |
n = len(data)
|
343 |
if n < k:
|
344 |
-
print(
|
|
|
|
|
345 |
k = n
|
346 |
|
347 |
quotient = n // k
|
@@ -461,7 +488,9 @@ async def main():
|
|
461 |
stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
|
462 |
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
463 |
|
464 |
-
metadata = await triton_client.get_model_config(
|
|
|
|
|
465 |
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
466 |
json.dump(metadata, f, indent=4)
|
467 |
|
|
|
64 |
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
65 |
)
|
66 |
summary_f.write("To learn more about the log, please refer to: \n")
|
67 |
+
summary_f.write(
|
68 |
+
"1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n"
|
69 |
+
)
|
70 |
+
summary_f.write(
|
71 |
+
"2. https://github.com/triton-inference-server/server/issues/5374 \n\n"
|
72 |
+
)
|
73 |
summary_f.write(
|
74 |
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
75 |
)
|
|
|
90 |
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
91 |
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
92 |
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
93 |
+
total_output_time_s = (
|
94 |
+
int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
95 |
+
)
|
96 |
summary_f.write(
|
97 |
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
98 |
)
|
|
|
103 |
compute_output = batch["compute_output"]
|
104 |
compute_infer = batch["compute_infer"]
|
105 |
batch_count = int(compute_infer["count"])
|
106 |
+
assert (
|
107 |
+
compute_infer["count"]
|
108 |
+
== compute_output["count"]
|
109 |
+
== compute_input["count"]
|
110 |
+
)
|
111 |
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
112 |
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
113 |
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
|
|
123 |
|
124 |
|
125 |
def get_args():
|
126 |
+
parser = argparse.ArgumentParser(
|
127 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
128 |
+
)
|
129 |
|
130 |
parser.add_argument(
|
131 |
"--server-addr",
|
|
|
266 |
for i, item in enumerate(manifest_item_list):
|
267 |
if i % log_interval == 0:
|
268 |
print(f"{name}: {i}/{len(manifest_item_list)}")
|
269 |
+
waveform, sample_rate = load_audio(
|
270 |
+
item["audio_filepath"], target_sample_rate=24000
|
271 |
+
)
|
272 |
duration = len(waveform) / sample_rate
|
273 |
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
274 |
|
|
|
283 |
1,
|
284 |
padding_duration
|
285 |
* sample_rate
|
286 |
+
* (
|
287 |
+
(int(estimated_target_duration + duration) // padding_duration)
|
288 |
+
+ 1
|
289 |
+
),
|
290 |
),
|
291 |
dtype=np.float32,
|
292 |
)
|
|
|
298 |
samples = samples.reshape(1, -1).astype(np.float32)
|
299 |
|
300 |
inputs = [
|
301 |
+
protocol_client.InferInput(
|
302 |
+
"reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)
|
303 |
+
),
|
304 |
+
protocol_client.InferInput(
|
305 |
+
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
|
306 |
+
),
|
307 |
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
308 |
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
309 |
]
|
|
|
322 |
|
323 |
sequence_id = 100000000 + i + task_id * 10
|
324 |
start = time.time()
|
325 |
+
response = await triton_client.infer(
|
326 |
+
model_name, inputs, request_id=str(sequence_id), outputs=outputs
|
327 |
+
)
|
328 |
|
329 |
audio = response.as_numpy("waveform").reshape(-1)
|
330 |
|
331 |
end = time.time() - start
|
332 |
|
333 |
+
audio_save_path = os.path.join(
|
334 |
+
audio_save_dir, f"{item['target_audio_path']}.wav"
|
335 |
+
)
|
336 |
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
337 |
|
338 |
actual_duration = len(audio) / save_sample_rate
|
|
|
366 |
def split_data(data, k):
|
367 |
n = len(data)
|
368 |
if n < k:
|
369 |
+
print(
|
370 |
+
f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}."
|
371 |
+
)
|
372 |
k = n
|
373 |
|
374 |
quotient = n // k
|
|
|
488 |
stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
|
489 |
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
490 |
|
491 |
+
metadata = await triton_client.get_model_config(
|
492 |
+
model_name=args.model_name, as_json=True
|
493 |
+
)
|
494 |
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
495 |
json.dump(metadata, f, indent=4)
|
496 |
|
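The stats writer in client_grpc.py above mostly gained line wrapping; the bookkeeping it performs is converting Triton's cumulative nanosecond counters into seconds before logging. A self-contained sketch with a fabricated stats dict shaped like the fields the script reads:

```python
def summarize_model_stats(model_inference_stats):
    """Convert Triton's cumulative ns counters to seconds (fields as used above)."""
    to_s = lambda ns: int(ns) / 1e9
    return {
        "queue_s": to_s(model_inference_stats["queue"]["ns"]),
        "infer_s": to_s(model_inference_stats["compute_infer"]["ns"]),
        "input_s": to_s(model_inference_stats["compute_input"]["ns"]),
        "output_s": to_s(model_inference_stats["compute_output"]["ns"]),
    }

print(summarize_model_stats({
    "queue": {"ns": "1500000000"}, "compute_infer": {"ns": "750000000"},
    "compute_input": {"ns": "20000000"}, "compute_output": {"ns": "30000000"},
}))  # {'queue_s': 1.5, 'infer_s': 0.75, 'input_s': 0.02, 'output_s': 0.03}
```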
f5_tts/runtime/triton_trtllm/client_http.py
CHANGED
@@ -31,7 +31,9 @@ import soundfile as sf
|
|
31 |
|
32 |
|
33 |
def get_args():
|
34 |
-
parser = argparse.ArgumentParser(
|
|
|
|
|
35 |
|
36 |
parser.add_argument(
|
37 |
"--server-url",
|
@@ -91,15 +93,30 @@ def prepare_request(
|
|
91 |
|
92 |
data = {
|
93 |
"inputs": [
|
94 |
-
{
|
|
|
|
|
|
|
|
|
|
|
95 |
{
|
96 |
"name": "reference_wav_len",
|
97 |
"shape": lengths.shape,
|
98 |
"datatype": "INT32",
|
99 |
"data": lengths.tolist(),
|
100 |
},
|
101 |
-
{
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
]
|
104 |
}
|
105 |
|
@@ -135,7 +152,11 @@ if __name__ == "__main__":
|
|
135 |
data = prepare_request(samples, args.reference_text, args.target_text)
|
136 |
|
137 |
rsp = requests.post(
|
138 |
-
url,
|
|
|
|
|
|
|
|
|
139 |
)
|
140 |
result = rsp.json()
|
141 |
audio = result["outputs"][0]["data"]
|
|
|
31 |
|
32 |
|
33 |
def get_args():
|
34 |
+
parser = argparse.ArgumentParser(
|
35 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
36 |
+
)
|
37 |
|
38 |
parser.add_argument(
|
39 |
"--server-url",
|
|
|
93 |
|
94 |
data = {
|
95 |
"inputs": [
|
96 |
+
{
|
97 |
+
"name": "reference_wav",
|
98 |
+
"shape": samples.shape,
|
99 |
+
"datatype": "FP32",
|
100 |
+
"data": samples.tolist(),
|
101 |
+
},
|
102 |
{
|
103 |
"name": "reference_wav_len",
|
104 |
"shape": lengths.shape,
|
105 |
"datatype": "INT32",
|
106 |
"data": lengths.tolist(),
|
107 |
},
|
108 |
+
{
|
109 |
+
"name": "reference_text",
|
110 |
+
"shape": [1, 1],
|
111 |
+
"datatype": "BYTES",
|
112 |
+
"data": [reference_text],
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"name": "target_text",
|
116 |
+
"shape": [1, 1],
|
117 |
+
"datatype": "BYTES",
|
118 |
+
"data": [target_text],
|
119 |
+
},
|
120 |
]
|
121 |
}
|
122 |
|
|
|
152 |
data = prepare_request(samples, args.reference_text, args.target_text)
|
153 |
|
154 |
rsp = requests.post(
|
155 |
+
url,
|
156 |
+
headers={"Content-Type": "application/json"},
|
157 |
+
json=data,
|
158 |
+
verify=False,
|
159 |
+
params={"request_id": "0"},
|
160 |
)
|
161 |
result = rsp.json()
|
162 |
audio = result["outputs"][0]["data"]
|
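client_http.py above posts a KServe-style JSON body with four input tensors (FP32 waveform, INT32 length, and two BYTES strings). A sketch of just the payload construction, using a short dummy waveform in place of real reference audio:

```python
import numpy as np

def build_payload(samples, reference_text, target_text):
    """Assemble the Triton HTTP inference body with the four inputs used above."""
    lengths = np.array([[samples.shape[-1]]], dtype=np.int32)
    return {
        "inputs": [
            {"name": "reference_wav", "shape": list(samples.shape), "datatype": "FP32",
             "data": samples.tolist()},
            {"name": "reference_wav_len", "shape": list(lengths.shape), "datatype": "INT32",
             "data": lengths.tolist()},
            {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
            {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
        ]
    }

payload = build_payload(np.zeros((1, 16), dtype=np.float32), "ref text", "target text")
print(len(payload["inputs"]))  # 4
```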
f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py
CHANGED
@@ -17,7 +17,9 @@ from tensorrt_llm.runtime.session import Session
 def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
     # Audio tensor case: batch, seq_len, feature_len
     # position_ids case: batch, seq_len
-    assert
+    assert (
+        input_tensor_lengths is not None
+    ), "input_tensor_lengths must be provided for 3D input_tensor"

     # Initialize a list to collect valid sequences
     valid_sequences = []
@@ -32,11 +34,29 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):


 class TextEmbedding(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        text_num_embeds,
+        text_dim,
+        conv_layers=0,
+        conv_mult=2,
+        precompute_max_pos=4096,
+    ):
         super().__init__()
-        self.text_embed = nn.Embedding(
+        self.text_embed = nn.Embedding(
+            text_num_embeds + 1, text_dim
+        )  # use 0 as filler token
+        self.register_buffer(
+            "freqs_cis",
+            precompute_freqs_cis(text_dim, precompute_max_pos),
+            persistent=False,
+        )
+        self.text_blocks = nn.Sequential(
+            *[
+                ConvNeXtV2Block(text_dim, text_dim * conv_mult)
+                for _ in range(conv_layers)
+            ]
+        )

     def forward(self, text):
         # only keep tensors with value not -1
@@ -80,7 +100,9 @@ class ConvNeXtV2Block(nn.Module):
             dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
         )  # depthwise conv
         self.norm = nn.LayerNorm(dim, eps=1e-6)
-        self.pwconv1 = nn.Linear(
+        self.pwconv1 = nn.Linear(
+            dim, intermediate_dim
+        )  # pointwise/1x1 convs, implemented with linear layers
         self.act = nn.GELU()
         self.grn = GRN(intermediate_dim)
         self.pwconv2 = nn.Linear(intermediate_dim, dim)
@@ -98,7 +120,9 @@ class ConvNeXtV2Block(nn.Module):
         return residual + x


-def precompute_freqs_cis(
+def precompute_freqs_cis(
+    dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
+):
     # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
     # has some connection to NTK literature
     # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -125,7 +149,9 @@ def load_checkpoint(ckpt_path, use_ema=True):
     for key in dict_state.keys():
         # transformer.text_embed.text_embed.weight -> text_embed.weight
         if "text_embed" in key:
-            text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[
+            text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[
+                key
+            ]
     return text_embed_dict

@@ -148,7 +174,12 @@ class F5TTS(object):
         pp_size = config["pretrained_config"]["mapping"]["pp_size"]
         assert pp_size == 1
         self.mapping = tensorrt_llm.Mapping(
-            world_size=world_size,
+            world_size=world_size,
+            rank=rank,
+            cp_size=cp_size,
+            tp_size=tp_size,
+            pp_size=1,
+            gpus_per_node=1,
         )

         local_rank = rank % self.mapping.gpus_per_node
@@ -176,10 +207,23 @@ class F5TTS(object):
         self.outputs = {}
         self.buffer_allocated = False

-        expected_tensor_names = [
+        expected_tensor_names = [
+            "noise",
+            "cond",
+            "time",
+            "rope_cos",
+            "rope_sin",
+            "input_lengths",
+            "denoised",
+        ]
+
+        found_tensor_names = [
+            self.session.engine.get_tensor_name(i)
+            for i in range(self.session.engine.num_io_tensors)
+        ]
+        if not self.debug_mode and set(expected_tensor_names) != set(
+            found_tensor_names
+        ):
             logger.error(
                 f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
             )
@@ -190,11 +234,16 @@ class F5TTS(object):
             logger.error(f"Found tensor names: {found_tensor_names}")
             raise RuntimeError("Tensor names in engine are not the same as expected.")
         if self.debug_mode:
-            self.debug_tensors = list(
+            self.debug_tensors = list(
+                set(found_tensor_names) - set(expected_tensor_names)
+            )

         self.max_mel_len = 4096
         self.text_embedding = TextEmbedding(
-            text_num_embeds=vocab_size,
+            text_num_embeds=vocab_size,
+            text_dim=512,
+            conv_layers=4,
+            precompute_max_pos=self.max_mel_len,
         ).to(self.device)
         self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
@@ -208,9 +257,16 @@ class F5TTS(object):
         self.head_dim = 64
         self.base_rescale_factor = 1.0
         self.interpolation_factor = 1.0
-        base = 10000.0 * self.base_rescale_factor ** (
+        base = 10000.0 * self.base_rescale_factor ** (
+            self.head_dim / (self.head_dim - 2)
+        )
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)
+        )
+        freqs = (
+            torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq)
+            / self.interpolation_factor
+        )
         self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
         self.rope_cos = self.freqs.cos().half()
         self.rope_sin = self.freqs.sin().half()
@@ -223,7 +279,9 @@ class F5TTS(object):
         time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
         half_dim = tmp_dim // 2
         emb_factor = math.log(10000) / (half_dim - 1)
-        emb_factor = 1000.0 * torch.exp(
+        emb_factor = 1000.0 * torch.exp(
+            torch.arange(half_dim, dtype=torch.float32) * -emb_factor
+        )
         for i in range(self.nfe_steps):
             emb = time_step[i] * emb_factor
             time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
@@ -242,7 +300,9 @@ class F5TTS(object):
             shape = list(self.session.engine.get_tensor_shape(name))
             shape[0] = batch_size
             shape[1] = seq_len
-            self.outputs[name] = torch.empty(
+            self.outputs[name] = torch.empty(
+                shape, dtype=self._tensor_dtype(name), device=self.device
+            )

         self.buffer_allocated = True
@@ -356,17 +416,29 @@ class F5TTS(object):
         max_seq_len = ref_mel_batch.shape[1]

         text_pad_sequence_drop = torch.cat(
-            (
+            (
+                text_pad_sequence,
+                torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(
+                    self.device
+                ),
+            ),
+            dim=0,
         )

         text_embedding_drop_list = []
         for i in range(batch + 1):
-            text_embedding_drop_list.append(
+            text_embedding_drop_list.append(
+                self.text_embedding(
+                    text_pad_sequence_drop[i].unsqueeze(0).to(self.device)
+                )
+            )
         text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)

         text_embedding = text_embedding_drop_condition[:-1]
         # text_embedding_drop B,T,C batch should be the same
-        text_embedding_drop =
+        text_embedding_drop = (
+            text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
+        )

         noise = torch.randn_like(ref_mel_batch).to(self.device)
         rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
@@ -375,7 +447,9 @@ class F5TTS(object):
         cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
         cat_mel_text_drop = torch.cat(
             (
-                torch.zeros(
+                torch.zeros(
+                    (batch, max_seq_len, self.n_mel_channels), dtype=torch.float32
+                ).to(self.device),
                 text_embedding_drop,
             ),
             dim=-1,
@@ -384,7 +458,9 @@ class F5TTS(object):
         time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()

         # Convert estimated_reference_target_mel_len to tensor
-        input_lengths = torch.tensor(
+        input_lengths = torch.tensor(
+            estimated_reference_target_mel_len, dtype=torch.int32
+        )

         # combine above along the batch dimension
         inputs = {
@@ -393,20 +469,34 @@ class F5TTS(object):
             "time_expand": time_expand,
             "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
             "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
-            "input_lengths": torch.cat(
+            "input_lengths": torch.cat(
+                (input_lengths, input_lengths), dim=0
+            ).contiguous(),
             "delta_t": self.delta_t,
         }
         if use_perf and remove_input_padding:
             torch.cuda.nvtx.range_push("remove input padding")
         if remove_input_padding:
             max_seq_len = inputs["cond"].shape[1]
-            inputs["noise"] = remove_tensor_padding(
+            inputs["noise"] = remove_tensor_padding(
+                inputs["noise"], inputs["input_lengths"]
+            )
+            inputs["cond"] = remove_tensor_padding(
+                inputs["cond"], inputs["input_lengths"]
+            )
             # for time_expand, convert from B,D to B,T,D by repeat
-            inputs["time_expand"] =
-            inputs["
+            inputs["time_expand"] = (
+                inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
+            )
+            inputs["time_expand"] = remove_tensor_padding(
+                inputs["time_expand"], inputs["input_lengths"]
+            )
+            inputs["rope_cos"] = remove_tensor_padding(
+                inputs["rope_cos"], inputs["input_lengths"]
+            )
+            inputs["rope_sin"] = remove_tensor_padding(
+                inputs["rope_sin"], inputs["input_lengths"]
+            )
         if use_perf and remove_input_padding:
             torch.cuda.nvtx.range_pop()
         for key in inputs:
@@ -422,7 +512,9 @@ class F5TTS(object):
             denoised_list = []
             start_idx = 0
             for i in range(batch):
-                denoised_list.append(
+                denoised_list.append(
+                    denoised[start_idx : start_idx + inputs["input_lengths"][i]]
+                )
                 start_idx += inputs["input_lengths"][i]
             if use_perf and remove_input_padding:
                 torch.cuda.nvtx.range_pop()
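The rotary tables built in F5TTS.__init__ above reduce to a few lines of plain PyTorch. A minimal sketch, assuming the same constants as in the diff (head_dim=64, max_mel_len=4096, rescale and interpolation factors of 1.0):

import torch

head_dim, max_mel_len = 64, 4096
base_rescale_factor, interpolation_factor = 1.0, 1.0

# NTK-style rescaled base, then standard rotary inverse frequencies
base = 10000.0 * base_rescale_factor ** (head_dim / (head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_mel_len, dtype=torch.float32), inv_freq) / interpolation_factor

# duplicate each frequency so cos/sin tables match the full head dimension
freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)   # (1, max_mel_len, head_dim)
rope_cos, rope_sin = freqs.cos().half(), freqs.sin().half()
print(rope_cos.shape)  # torch.Size([1, 4096, 64])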
f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py
CHANGED
@@ -73,7 +73,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
            if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                char_list.append(" ")
            char_list.extend(seg)
-        elif polyphone and seg_byte_len == 3 * len(
+        elif polyphone and seg_byte_len == 3 * len(
+            seg
+        ):  # if pure east asian characters
            seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
            for i, c in enumerate(seg):
                if is_chinese(c):
@@ -85,7 +87,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
                    char_list.extend(c)
                elif is_chinese(c):
                    char_list.append(" ")
-                    char_list.extend(
+                    char_list.extend(
+                        lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
+                    )
                else:
                    char_list.append(c)
    final_reference_target_texts_list.append(char_list)
@@ -98,7 +102,9 @@ def list_str_to_idx(
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
):  # noqa: F722
-    list_idx_tensors = [
+    list_idx_tensors = [
+        torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
+    ]  # pinyin or char style
    return list_idx_tensors

@@ -121,7 +127,9 @@ class TritonPythonModel:

        self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
        self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
-        self.resampler = torchaudio.transforms.Resample(
+        self.resampler = torchaudio.transforms.Resample(
+            self.reference_sample_rate, self.target_audio_sample_rate
+        )

        self.tllm_model_dir = parameters["tllm_model_dir"]
        config_file = os.path.join(self.tllm_model_dir, "config.json")
@@ -163,13 +171,17 @@ class TritonPythonModel:
        input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))

        inference_request = pb_utils.InferenceRequest(
-            model_name="vocoder",
+            model_name="vocoder",
+            requested_output_names=["waveform"],
+            inputs=[input_tensor_0],
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        else:
-            waveform = pb_utils.get_output_tensor_by_name(
+            waveform = pb_utils.get_output_tensor_by_name(
+                inference_response, "waveform"
+            )
            waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()

        return waveform
@@ -181,7 +193,13 @@ class TritonPythonModel:
            reference_target_texts_list,
            estimated_reference_target_mel_len,
            reference_mel_len,
-        ) =
+        ) = (
+            [],
+            [],
+            [],
+            [],
+            [],
+        )
        mel_features_list = []
        if self.use_perf:
            torch.cuda.nvtx.range_push("preprocess")
@@ -189,10 +207,14 @@ class TritonPythonModel:
            wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
            wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

-            reference_text = pb_utils.get_input_tensor_by_name(
+            reference_text = pb_utils.get_input_tensor_by_name(
+                request, "reference_text"
+            ).as_numpy()
            reference_text = reference_text[0][0].decode("utf-8")
            reference_text_list.append(reference_text)
-            target_text = pb_utils.get_input_tensor_by_name(
+            target_text = pb_utils.get_input_tensor_by_name(
+                request, "target_text"
+            ).as_numpy()
            target_text = target_text[0][0].decode("utf-8")
            target_text_list.append(target_text)
@@ -221,30 +243,49 @@ class TritonPythonModel:
            reference_mel_len.append(mel_features.shape[1])
            estimated_reference_target_mel_len.append(
                int(
-                    mel_features.shape[1]
+                    mel_features.shape[1]
+                    * (
+                        1
+                        + len(target_text.encode("utf-8"))
+                        / len(reference_text.encode("utf-8"))
+                    )
                )
            )

        max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)

        batch = len(requests)
-        mel_features = torch.zeros(
+        mel_features = torch.zeros(
+            (batch, max_seq_len, self.n_mel_channels), dtype=torch.float16
+        ).to(self.device)
        for i, mel in enumerate(mel_features_list):
            mel_features[i, : mel.shape[1], :] = mel

        reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)

-        pinyin_list = convert_char_to_pinyin(
+        pinyin_list = convert_char_to_pinyin(
+            reference_target_texts_list, polyphone=True
+        )
        text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)

        for i, item in enumerate(text_pad_sequence):
            text_pad_sequence[i] = F.pad(
-                item,
+                item,
+                (0, estimated_reference_target_mel_len[i] - len(item)),
+                mode="constant",
+                value=-1,
            )
-            text_pad_sequence[
+            text_pad_sequence[
+                i
+            ] += 1  # WAR: 0 is reserved for padding token, hard coding in F5-TTS
+        text_pad_sequence = pad_sequence(
+            text_pad_sequence, padding_value=-1, batch_first=True
+        ).to(self.device)
        text_pad_sequence = F.pad(
-            text_pad_sequence,
+            text_pad_sequence,
+            (0, max_seq_len - text_pad_sequence.shape[1]),
+            mode="constant",
+            value=-1,
        )
        if self.use_perf:
            torch.cuda.nvtx.range_pop()
@@ -264,7 +305,11 @@ class TritonPythonModel:
        for i in range(batch):
            ref_me_len = reference_mel_len[i]
            estimated_mel_len = estimated_reference_target_mel_len[i]
-            denoised_one_item =
+            denoised_one_item = (
+                denoised[i, ref_me_len:estimated_mel_len, :]
+                .unsqueeze(0)
+                .transpose(1, 2)
+            )
            audio = self.forward_vocoder(denoised_one_item)
            rms = torch.sqrt(torch.mean(torch.square(audio)))
            if rms < self.target_rms:
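The target duration used above is a simple byte-length heuristic: the reference mel length is scaled by one plus the UTF-8 byte ratio of target to reference text. A minimal sketch of that calculation:

# Sketch of the duration heuristic from TritonPythonModel.execute above.
def estimate_target_mel_len(ref_mel_len: int, reference_text: str, target_text: str) -> int:
    ratio = len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8"))
    return int(ref_mel_len * (1 + ratio))

print(estimate_target_mel_len(250, "short prompt", "a noticeably longer sentence"))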
f5_tts/runtime/triton_trtllm/patch/__init__.py
CHANGED
@@ -13,14 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .baichuan.model import BaichuanForCausalLM
-from .bert.model import (
-    RobertaForQuestionAnswering,
-    RobertaForSequenceClassification,
-    RobertaModel,
-)
+from .bert.model import (BertForQuestionAnswering,
+                         BertForSequenceClassification, BertModel,
+                         RobertaForQuestionAnswering,
+                         RobertaForSequenceClassification, RobertaModel)
 from .bloom.model import BloomForCausalLM, BloomModel
 from .chatglm.config import ChatGLMConfig
 from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
@@ -51,17 +47,17 @@ from .mamba.model import MambaForCausalLM
 from .medusa.config import MedusaConfig
 from .medusa.model import MedusaForCausalLm
 from .mllama.model import MLLaMAModel
-from .modeling_utils import PretrainedConfig, PretrainedModel,
+from .modeling_utils import (PretrainedConfig, PretrainedModel,
+                             SpeculativeDecodingMode)
 from .mpt.model import MPTForCausalLM, MPTModel
 from .nemotron_nas.model import DeciLMForCausalLM
 from .opt.model import OPTForCausalLM, OPTModel
-from .phi.model import PhiForCausalLM, PhiModel
 from .phi3.model import Phi3ForCausalLM, Phi3Model
+from .phi.model import PhiForCausalLM, PhiModel
 from .qwen.model import QWenForCausalLM
 from .recurrentgemma.model import RecurrentGemmaForCausalLM
 from .redrafter.model import ReDrafterForCausalLM

-
 __all__ = [
     "BertModel",
     "BertForQuestionAnswering",
f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
CHANGED
@@ -13,8 +13,8 @@ from ...layers import Linear
 from ...module import Module, ModuleList
 from ...plugin import current_all_reduce_helper
 from ..modeling_utils import PretrainedConfig, PretrainedModel
-from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock,
+from .modules import (AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock,
+                      TimestepEmbedding)

 current_file_path = os.path.abspath(__file__)
 parent_dir = os.path.dirname(current_file_path)
@@ -38,7 +38,9 @@ class F5TTS(PretrainedModel):
         self.dtype = str_dtype_to_trt(config.dtype)

         self.time_embed = TimestepEmbedding(config.hidden_size)
-        self.input_embed = InputEmbedding(
+        self.input_embed = InputEmbedding(
+            config.mel_dim, config.text_dim, config.hidden_size
+        )

         self.dim = config.hidden_size
         self.depth = config.num_hidden_layers
@@ -71,7 +73,14 @@ class F5TTS(PretrainedModel):
         t = self.time_embed(time)
         x = self.input_embed(noise, cond)
         for block in self.transformer_blocks:
-            x = block(
+            x = block(
+                x,
+                t,
+                rope_cos=rope_cos,
+                rope_sin=rope_sin,
+                input_lengths=input_lengths,
+                scale=scale,
+            )
         denoise = self.proj_out(self.norm_out(x, t))
         denoise.mark_output("denoised", self.dtype)
         return denoise
f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
CHANGED
@@ -9,28 +9,10 @@ import torch.nn.functional as F
 from tensorrt_llm._common import default_net

 from ..._utils import str_dtype_to_trt, trt_dtype_to_np
-from ...functional import (
-    chunk,
-    concat,
-    constant,
-    expand,
-    expand_dims,
-    expand_dims_like,
-    expand_mask,
-    gelu,
-    matmul,
-    permute,
-    shape,
-    silu,
-    slice,
-    softmax,
-    squeeze,
-    unsqueeze,
-    view,
-)
+from ...functional import (Tensor, bert_attention, cast, chunk, concat,
+                           constant, expand, expand_dims, expand_dims_like,
+                           expand_mask, gelu, matmul, permute, shape, silu,
+                           slice, softmax, squeeze, unsqueeze, view)
 from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
 from ...module import Module

@@ -57,7 +39,9 @@ class AdaLayerNormZero(Module):

     def forward(self, x, emb=None):
         emb = self.linear(silu(emb))
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(
+            emb, 6, dim=1
+        )
         x = self.norm(x)
         ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
         if default_net().plugin_config.remove_input_padding:
@@ -91,8 +75,12 @@ class ConvPositionEmbedding(Module):
     def __init__(self, dim, kernel_size=31, groups=16):
         super().__init__()
         assert kernel_size % 2 != 0
-        self.conv1d1 = Conv1d(
+        self.conv1d1 = Conv1d(
+            dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
+        )
+        self.conv1d2 = Conv1d(
+            dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
+        )
         self.mish = Mish()

     def forward(self, x, mask=None):  # noqa: F722
@@ -120,7 +108,9 @@ class Attention(Module):
         super().__init__()

         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
+            raise ImportError(
+                "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )

         self.processor = processor

@@ -191,16 +181,32 @@ class Attention(Module):
         c_rope=None,  # rotary position embedding for c
     ) -> torch.Tensor:
         if c is not None:
-            return self.processor(
+            return self.processor(
+                self,
+                x,
+                c=c,
+                input_lengths=input_lengths,
+                scale=scale,
+                rope=rope,
+                c_rope=c_rope,
+            )
         else:
             return self.processor(
-                self,
+                self,
+                x,
+                rope_cos=rope_cos,
+                rope_sin=rope_sin,
+                input_lengths=input_lengths,
+                scale=scale,
             )


 def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
     shape_tensor = concat(
-        [
+        [
+            shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i)
+            for i in range(tensor.ndim())
+        ]
     )
     if default_net().plugin_config.remove_input_padding:
         assert tensor.ndim() == 2
@@ -208,7 +214,9 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
         x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
         x1 = expand_dims(x1, 2)
         x2 = expand_dims(x2, 2)
-        zero = constant(
+        zero = constant(
+            np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))
+        )
         x2 = zero - x2
         x = concat([x2, x1], 2)
         out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
@@ -219,7 +227,9 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
         x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
         x1 = expand_dims(x1, 3)
         x2 = expand_dims(x2, 3)
-        zero = constant(
+        zero = constant(
+            np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))
+        )
         x2 = zero - x2
         x = concat([x2, x1], 3)
         out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
@@ -235,15 +245,23 @@ def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
         end_dim = shape(x, -1) - shape(rope_cos, -1)
         new_t_unrotated_shape = concat([shape(x, 0), end_dim])  # (2, -1, 960)
         x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
-        out = concat(
+        out = concat(
+            [x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1
+        )
     else:
         rot_dim = shape(rope_cos, 2)  # 64
         new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim])  # (2, -1, 64)
         x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
         end_dim = shape(x, 2) - shape(rope_cos, 2)
-        new_t_unrotated_shape = concat(
+        new_t_unrotated_shape = concat(
+            [shape(x, 0), shape(x, 1), end_dim]
+        )  # (2, -1, 960)
+        x_unrotated = slice(
+            x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1]
+        )
+        out = concat(
+            [x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1
+        )
     return out

@@ -279,8 +297,12 @@ class AttnProcessor:
         seq_len_2d = concat([1, N])
         max_position_embeddings = 4096
         # create position ids
-        position_ids_buffer = constant(
+        position_ids_buffer = constant(
+            np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0)
+        )
+        tmp_position_ids = slice(
+            position_ids_buffer, starts=[0, 0], sizes=seq_len_2d
+        )
         tmp_position_ids = expand(tmp_position_ids, concat([B, N]))  # BxL
         tmp_input_lengths = unsqueeze(input_lengths, 1)  # Bx1
         tmp_input_lengths = expand(tmp_input_lengths, concat([B, N]))  # BxL
@@ -315,14 +337,28 @@ class AttnProcessor:
             assert not default_net().plugin_config.remove_input_padding

             def transpose_for_scores(x):
-                new_x_shape = concat(
+                new_x_shape = concat(
+                    [
+                        shape(x, 0),
+                        shape(x, 1),
+                        attn.num_attention_heads,
+                        attn.attention_head_size,
+                    ]
+                )

                 y = x.view(new_x_shape)
                 y = y.transpose(1, 2)
                 return y

             def transpose_for_scores_k(x):
-                new_x_shape = concat(
+                new_x_shape = concat(
+                    [
+                        shape(x, 0),
+                        shape(x, 1),
+                        attn.num_attention_heads,
+                        attn.attention_head_size,
+                    ]
+                )

                 y = x.view(new_x_shape)
                 y = y.permute([0, 2, 3, 1])
@@ -342,7 +378,11 @@ class AttnProcessor:
             attention_probs = softmax(attention_scores, dim=-1)

             context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
-            context = context.view(
+            context = context.view(
+                concat(
+                    [shape(context, 0), shape(context, 1), attn.attention_hidden_size]
+                )
+            )
             context = attn.to_out(context)
             if mask is not None:
                 mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
@@ -370,13 +410,26 @@ class DiTBlock(Module):
         self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)

     def forward(
-        self,
+        self,
+        x,
+        t,
+        rope_cos,
+        rope_sin,
+        input_lengths,
+        scale=1.0,
+        rope=ModuleNotFoundError,
     ):  # x: noised input, t: time embedding
         # pre-norm & modulation for attention input
         norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
         # attention
         # norm ----> (2,1226,1024)
-        attn_output = self.attn(
+        attn_output = self.attn(
+            x=norm,
+            rope_cos=rope_cos,
+            rope_sin=rope_sin,
+            input_lengths=input_lengths,
+            scale=scale,
+        )

         # process attention output for input x
         if default_net().plugin_config.remove_input_padding:
@@ -387,7 +440,9 @@ class DiTBlock(Module):
         if default_net().plugin_config.remove_input_padding:
             norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
         else:
-            norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(
+            norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(
+                shift_mlp, 1
+            )
         # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
         ff_output = self.ff(norm)
         if default_net().plugin_config.remove_input_padding:
f5_tts/runtime/triton_trtllm/scripts/conv_stft.py
CHANGED
@@ -40,7 +40,6 @@ import torch as th
 import torch.nn.functional as F
 from scipy.signal import check_COLA, get_window

-
 support_clp_op = None
 if th.__version__ >= "1.7.0":
     from torch.fft import rfft as fft
@@ -124,7 +123,9 @@ class STFT(th.nn.Module):
         ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
         window = get_window(self.win_type, self.win_len)

-        self.perfect_reconstruct = check_COLA(
+        self.perfect_reconstruct = check_COLA(
+            window, self.win_len, self.win_len - self.win_hop
+        )
         window = th.FloatTensor(window)
         if self.mode == "continue":
             left_pad = (self.fft_len - self.win_len) // 2
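check_COLA comes from scipy.signal and verifies that the analysis window and hop satisfy the constant-overlap-add condition needed for perfect reconstruction in an STFT/iSTFT pair. A quick standalone check, with window length and hop chosen here purely for illustration:

from scipy.signal import check_COLA, get_window

win_len, win_hop = 1024, 256                           # assumed parameters, not taken from the diff
window = get_window("hann", win_len)
print(check_COLA(window, win_len, win_len - win_hop))  # True: a Hann window at 75% overlap satisfies COLA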
f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py
CHANGED
@@ -179,19 +179,47 @@ def parse_arguments():
     )  # TODO: support F5TTS_v1_Base
     parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
     parser.add_argument(
-        "--output_dir",
+        "--output_dir",
+        type=str,
+        default="./tllm_checkpoint",
+        help="The path to save the TensorRT-LLM checkpoint",
+    )
-    parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
-    parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
-    parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
+    parser.add_argument(
+        "--hidden_size", type=int, default=1024, help="The hidden size of DiT"
+    )
+    parser.add_argument(
+        "--depth", type=int, default=22, help="The number of DiTBlock layers"
+    )
+    parser.add_argument(
+        "--num_heads",
+        type=int,
+        default=16,
+        help="The number of heads of attention module",
     )
     parser.add_argument("--cfg_scale", type=float, default=4.0)
-    parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
-    parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
-    parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
-    parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
-    parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
     parser.add_argument(
-        "--
+        "--tp_size", type=int, default=1, help="N-way tensor parallelism size"
+    )
+    parser.add_argument(
+        "--cp_size", type=int, default=1, help="Context parallelism size"
+    )
+    parser.add_argument(
+        "--pp_size", type=int, default=1, help="N-way pipeline parallelism size"
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="float16",
+        choices=["float32", "bfloat16", "float16"],
+    )
+    parser.add_argument(
+        "--fp8_linear", action="store_true", help="Whether use FP8 for linear layers"
+    )
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=1,
+        help="The number of workers for converting checkpoint in parallel",
     )
     args = parser.parse_args()
     return args
@@ -205,10 +233,15 @@ def convert_timm_dit(args, mapping, dtype="float32"):

     model_params = dict(torch.load(args.timm_ckpt))
     model_params = {
-        k: v
+        k: v
+        for k, v in model_params["ema_model_state_dict"].items()
+        if k.startswith("ema_model.transformer")
     }
     prefix = "ema_model.transformer."
-    model_params = {
+    model_params = {
+        key[len(prefix) :] if key.startswith(prefix) else key: value
+        for key, value in model_params.items()
+    }

     timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING

@@ -223,8 +256,13 @@ def convert_timm_dit(args, mapping, dtype="float32"):

     weights = dict()
     for name, param in model_params.items():
-        if
+        if (
+            name == "input_embed.conv_pos_embed.conv1d.0.weight"
+            or name == "input_embed.conv_pos_embed.conv1d.2.weight"
+        ):
+            weights[get_trtllm_name(name)] = (
+                param.contiguous().to(torch_dtype).unsqueeze(-1)
+            )
         else:
             weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)

@@ -239,25 +277,37 @@ def convert_timm_dit(args, mapping, dtype="float32"):
     for k, v in weights.items():
         if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
             weights[k] *= scale_factor
-            weights[k] = split_q_tp(
+            weights[k] = split_q_tp(
+                v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+            )

         elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
             weights[k] *= scale_factor
-            weights[k] = split_q_bias_tp(
+            weights[k] = split_q_bias_tp(
+                v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+            )

         elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
-            weights[k] = split_q_tp(
+            weights[k] = split_q_tp(
+                v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+            )
             weights[k] *= scale_factor

         elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
-            weights[k] = split_q_bias_tp(
+            weights[k] = split_q_bias_tp(
+                v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+            )
             weights[k] *= scale_factor

         elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
-            weights[k] = split_q_tp(
+            weights[k] = split_q_tp(
+                v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+            )

         elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
-            weights[k] = split_q_bias_tp(
+            weights[k] = split_q_bias_tp(
+                v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+            )

         elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
             weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)
@@ -317,7 +367,9 @@ def covert_and_save(args, rank):

     weights = convert_timm_dit(args, mapping, dtype=args.dtype)

-    safetensors.torch.save_file(
+    safetensors.torch.save_file(
+        weights, os.path.join(args.output_dir, f"rank{rank}.safetensors")
+    )


 def execute(workers, func, args):
@@ -334,7 +386,9 @@ def execute(workers, func, args):
         except Exception as e:
             traceback.print_exc()
             exceptions.append(e)
-    assert
+    assert (
+        len(exceptions) == 0
+    ), "Checkpoint conversion failed, please check error log."


 def main():
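The state-dict handling in convert_timm_dit boils down to keeping only the EMA transformer weights and stripping their prefix before the name mapping is applied. A sketch with a toy dictionary (the key names below are made up for illustration, not taken from a real checkpoint):

prefix = "ema_model.transformer."
state = {
    "ema_model.transformer.time_embed.mlp.0.weight": 1,
    "ema_model.mel_spec.window": 2,  # dropped: not a transformer weight
}

# keep only transformer entries, then strip the prefix
model_params = {k: v for k, v in state.items() if k.startswith("ema_model.transformer")}
model_params = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in model_params.items()
}
print(model_params)  # {'time_embed.mlp.0.weight': 1}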
f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py
CHANGED
@@ -20,12 +20,13 @@ from conv_stft import STFT
 from huggingface_hub import hf_hub_download
 from vocos import Vocos

-
 opset_version = 17


 def get_args():
-    parser = argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
     parser.add_argument(
         "--vocoder",
         type=str,
@@ -108,7 +109,9 @@ def export_VocosVocoder(vocos_vocoder, output_path, verbose):
     print("Exported to {}".format(output_path))


-def load_vocoder(
+def load_vocoder(
+    vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None
+):
     if vocoder_name == "vocos":
         # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
         if is_local:
@@ -118,8 +121,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cp
         else:
             print("Download Vocos from huggingface charactr/vocos-mel-24khz")
             repo_id = "charactr/vocos-mel-24khz"
-            config_path = hf_hub_download(
+            config_path = hf_hub_download(
+                repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml"
+            )
+            model_path = hf_hub_download(
+                repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin"
+            )
         vocoder = Vocos.from_hparams(config_path)
         state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
         vocoder.load_state_dict(state_dict)
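For reference, the non-local branch of load_vocoder is equivalent to the one-liner mentioned in its commented-out line. A hedged usage sketch (requires network access; the mel tensor below is a dummy input and its layout is an assumption):

import torch
from vocos import Vocos

vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
mel = torch.randn(1, 100, 256)   # (batch, n_mels, frames) -- dummy features
with torch.no_grad():
    audio = vocoder.decode(mel)
print(audio.shape)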
f5_tts/runtime/triton_trtllm/scripts/fill_template.py
CHANGED
@@ -29,8 +29,12 @@ if __name__ == "__main__":
         "substitutions",
         help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
     )
-    parser.add_argument(
+    parser.add_argument(
+        "--in_place", "-i", action="store_true", help="do the operation in-place"
+    )
+    parser.add_argument(
+        "--participant_ids", help="Participant IDs for the model", default=""
+    )
     args = parser.parse_args()

     main(**vars(args))
f5_tts/scripts/count_max_epoch.py
CHANGED
@@ -24,10 +24,14 @@ updates_per_epoch = total_hours / mini_batch_hours

 # result
 epochs = wanted_max_updates / updates_per_epoch
-print(
+print(
+    f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})"
+)
 print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
 # print(f" or approx. 0/{steps_per_epoch:.0f} steps")

 # others
 print(f"total {total_hours:.0f} hours")
-print(
+print(
+    f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch"
+)