mrfakename committed
Commit 597cecf · 1 Parent(s): 39d2f14
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +143 -248
  2. ctcmodel.py +73 -75
  3. discriminator_conformer.py +100 -63
  4. dmd_trainer.py +338 -151
  5. duration_predictor.py +38 -17
  6. duration_trainer.py +181 -116
  7. duration_trainer_with_prompt.py +164 -94
  8. ecapa_tdnn.py +155 -50
  9. f5_tts/api.py +29 -18
  10. f5_tts/eval/ecapa_tdnn.py +91 -19
  11. f5_tts/eval/eval_infer_batch.py +34 -15
  12. f5_tts/eval/eval_librispeech_test_clean.py +20 -8
  13. f5_tts/eval/eval_seedtts_testset.py +17 -7
  14. f5_tts/eval/eval_utmos.py +9 -3
  15. f5_tts/eval/utils_eval.py +43 -13
  16. f5_tts/infer/infer_cli.py +53 -32
  17. f5_tts/infer/infer_gradio.py +178 -57
  18. f5_tts/infer/speech_edit.py +30 -14
  19. f5_tts/infer/utils_infer.py +106 -33
  20. f5_tts/model/__init__.py +2 -5
  21. f5_tts/model/backbones/dit.py +63 -30
  22. f5_tts/model/backbones/mmdit.py +82 -65
  23. f5_tts/model/backbones/unett.py +45 -23
  24. f5_tts/model/cfm.py +61 -25
  25. f5_tts/model/dataset.py +100 -82
  26. f5_tts/model/modules.py +105 -38
  27. f5_tts/model/trainer.py +163 -78
  28. f5_tts/model/utils.py +31 -18
  29. f5_tts/model_new/__init__.py +0 -1
  30. f5_tts/model_new/backbones/dit.py +65 -26
  31. f5_tts/model_new/backbones/mmdit.py +42 -20
  32. f5_tts/model_new/backbones/unett.py +68 -27
  33. f5_tts/model_new/cfm.py +41 -18
  34. f5_tts/model_new/dataset.py +31 -9
  35. f5_tts/model_new/modules.py +126 -39
  36. f5_tts/model_new/trainer.py +142 -43
  37. f5_tts/model_new/utils.py +29 -13
  38. f5_tts/runtime/triton_trtllm/benchmark.py +106 -28
  39. f5_tts/runtime/triton_trtllm/client_grpc.py +42 -13
  40. f5_tts/runtime/triton_trtllm/client_http.py +26 -5
  41. f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py +125 -33
  42. f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py +62 -17
  43. f5_tts/runtime/triton_trtllm/patch/__init__.py +7 -11
  44. f5_tts/runtime/triton_trtllm/patch/f5tts/model.py +13 -4
  45. f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py +98 -43
  46. f5_tts/runtime/triton_trtllm/scripts/conv_stft.py +3 -2
  47. f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py +76 -22
  48. f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py +12 -5
  49. f5_tts/runtime/triton_trtllm/scripts/fill_template.py +6 -2
  50. f5_tts/scripts/count_max_epoch.py +6 -2
app.py CHANGED
@@ -1,130 +1,51 @@
1
- import gradio as gr
2
- import torch
3
- import torchaudio
4
- import numpy as np
5
  import tempfile
6
  import time
7
  from pathlib import Path
8
- from huggingface_hub import hf_hub_download
9
- import os
 
10
  import spaces
11
  from transformers import pipeline
12
 
13
- # Import the inference module
14
  from infer import DMOInference
15
 
16
- # Global variables
17
- model = None
18
- asr_pipe = None
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- # Initialize ASR pipeline
22
- def initialize_asr_pipeline(device=device, dtype=None):
23
- """Initialize the ASR pipeline on startup."""
24
- global asr_pipe
25
-
26
- if dtype is None:
27
- dtype = (
28
- torch.float16
29
- if "cuda" in device
30
- and torch.cuda.is_available()
31
- and torch.cuda.get_device_properties(device).major >= 7
32
- and not torch.cuda.get_device_name().endswith("[ZLUDA]")
33
- else torch.float32
34
- )
35
-
36
- print("Initializing ASR pipeline...")
37
- try:
38
- asr_pipe = pipeline(
39
- "automatic-speech-recognition",
40
- model="openai/whisper-large-v3-turbo",
41
- torch_dtype=dtype,
42
- device="cpu" # Keep ASR on CPU to save GPU memory
43
- )
44
- print("ASR pipeline initialized successfully")
45
- except Exception as e:
46
- print(f"Error initializing ASR pipeline: {e}")
47
- asr_pipe = None
48
-
49
- # Transcribe function
50
  def transcribe(ref_audio, language=None):
51
  """Transcribe audio using the pre-loaded ASR pipeline."""
52
- global asr_pipe
53
-
54
- if asr_pipe is None:
55
- return "" # Return empty string if ASR is not available
56
-
57
- try:
58
- result = asr_pipe(
59
- ref_audio,
60
- chunk_length_s=30,
61
- batch_size=128,
62
- generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
63
- return_timestamps=False,
64
- )
65
- return result["text"].strip()
66
- except Exception as e:
67
- print(f"Transcription error: {e}")
68
- return ""
69
-
70
- def download_models():
71
- """Download models from HuggingFace Hub."""
72
- try:
73
- print("Downloading models from HuggingFace...")
74
-
75
- # Download student model
76
- student_path = hf_hub_download(
77
- repo_id="yl4579/DMOSpeech2",
78
- filename="model_85000.pt",
79
- cache_dir="./models"
80
- )
81
-
82
- # Download duration predictor
83
- duration_path = hf_hub_download(
84
- repo_id="yl4579/DMOSpeech2",
85
- filename="model_1500.pt",
86
- cache_dir="./models"
87
- )
88
-
89
- print(f"Student model: {student_path}")
90
- print(f"Duration model: {duration_path}")
91
-
92
- return student_path, duration_path
93
-
94
- except Exception as e:
95
- print(f"Error downloading models: {e}")
96
- return None, None
97
-
98
- def initialize_model():
99
- """Initialize the model on startup."""
100
- global model
101
-
102
- try:
103
- # Download models
104
- student_path, duration_path = download_models()
105
-
106
- if not student_path or not duration_path:
107
- return False, "Failed to download models from HuggingFace"
108
-
109
- # Initialize model
110
- model = DMOInference(
111
- student_checkpoint_path=student_path,
112
- duration_predictor_path=duration_path,
113
- device=device,
114
- model_type="F5TTS_Base"
115
- )
116
-
117
- return True, f"Model loaded successfully on {device.upper()}"
118
-
119
- except Exception as e:
120
- return False, f"Error initializing model: {str(e)}"
121
-
122
- # Initialize models on startup
123
- print("Initializing models...")
124
- model_loaded, status_message = initialize_model()
125
- initialize_asr_pipeline() # Initialize ASR pipeline
126
-
127
- @spaces.GPU(duration=120) # Request GPU for up to 120 seconds
128
  def generate_speech(
129
  prompt_audio,
130
  prompt_text,
@@ -134,128 +55,115 @@ def generate_speech(
134
  custom_teacher_steps,
135
  custom_teacher_stopping_time,
136
  custom_student_start_step,
137
- verbose
138
  ):
139
- """Generate speech with different configurations."""
140
-
141
- if not model_loaded or model is None:
142
- return None, "Model not loaded! Please refresh the page.", "", ""
143
-
144
  if prompt_audio is None:
145
- return None, "Please upload a reference audio!", "", ""
146
-
147
  if not target_text:
148
- return None, "Please enter text to generate!", "", ""
149
-
150
- try:
151
- # Auto-transcribe if prompt_text is empty
152
- if not prompt_text and prompt_text != "":
153
- print("Auto-transcribing reference audio...")
154
- prompt_text = transcribe(prompt_audio)
155
- print(f"Transcribed: {prompt_text}")
156
-
157
- start_time = time.time()
158
-
159
- # Configure parameters based on mode
160
- if mode == "Student Only (4 steps)":
161
- teacher_steps = 0
162
- student_start_step = 0
163
- teacher_stopping_time = 1.0
164
- elif mode == "Teacher-Guided (8 steps)":
165
- # Default configuration from the notebook
166
- teacher_steps = 16
167
- teacher_stopping_time = 0.07
168
- student_start_step = 1
169
- elif mode == "High Diversity (16 steps)":
170
- teacher_steps = 24
171
- teacher_stopping_time = 0.3
172
- student_start_step = 2
173
- else: # Custom
174
- teacher_steps = custom_teacher_steps
175
- teacher_stopping_time = custom_teacher_stopping_time
176
- student_start_step = custom_student_start_step
177
-
178
- # Generate speech
179
- generated_audio = model.generate(
180
- gen_text=target_text,
181
- audio_path=prompt_audio,
182
- prompt_text=prompt_text if prompt_text else None,
183
- teacher_steps=teacher_steps,
184
- teacher_stopping_time=teacher_stopping_time,
185
- student_start_step=student_start_step,
186
- temperature=temperature,
187
- verbose=verbose
188
- )
189
-
190
- end_time = time.time()
191
-
192
- # Calculate metrics
193
- processing_time = end_time - start_time
194
- audio_duration = generated_audio.shape[-1] / 24000
195
- rtf = processing_time / audio_duration
196
-
197
- # Save audio
198
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
199
- output_path = tmp_file.name
200
-
201
- if isinstance(generated_audio, np.ndarray):
202
- generated_audio = torch.from_numpy(generated_audio)
203
-
204
- if generated_audio.dim() == 1:
205
- generated_audio = generated_audio.unsqueeze(0)
206
-
207
- torchaudio.save(output_path, generated_audio, 24000)
208
-
209
- # Format metrics
210
- metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x speed) | Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"
211
-
212
- return output_path, "Success!", metrics, f"Mode: {mode} | Transcribed: {prompt_text[:50]}..." if not prompt_text else f"Mode: {mode}"
213
-
214
- except Exception as e:
215
- return None, f"Error: {str(e)}", "", ""
216
 
217
  # Create Gradio interface
218
  with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
219
- gr.Markdown(f"""
 
220
  # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
221
 
222
  Generate natural speech in any voice with just a short reference audio!
223
- """)
224
-
 
225
  with gr.Row():
226
  with gr.Column(scale=1):
227
  # Reference audio input
228
  prompt_audio = gr.Audio(
229
  label="📎 Reference Audio",
230
  type="filepath",
231
- sources=["upload", "microphone"]
232
  )
233
-
234
  prompt_text = gr.Textbox(
235
  label="📝 Reference Text (leave empty for auto-transcription)",
236
  placeholder="The text spoken in the reference audio...",
237
- lines=2
238
  )
239
-
240
  target_text = gr.Textbox(
241
  label="✍️ Text to Generate",
242
  placeholder="Enter the text you want to synthesize...",
243
- lines=4
244
  )
245
-
246
  # Generation mode
247
  mode = gr.Radio(
248
  choices=[
249
  "Student Only (4 steps)",
250
  "Teacher-Guided (8 steps)",
251
  "High Diversity (16 steps)",
252
- "Custom"
253
  ],
254
  value="Teacher-Guided (8 steps)",
255
  label="🚀 Generation Mode",
256
- info="Choose speed vs quality/diversity tradeoff"
257
  )
258
-
259
  # Advanced settings (collapsible)
260
  with gr.Accordion("⚙️ Advanced Settings", open=False):
261
  temperature = gr.Slider(
@@ -264,9 +172,9 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
264
  value=0.0,
265
  step=0.1,
266
  label="Duration Temperature",
267
- info="0 = deterministic, >0 = more variation in speech rhythm"
268
  )
269
-
270
  with gr.Group(visible=False) as custom_settings:
271
  gr.Markdown("### Custom Mode Settings")
272
  custom_teacher_steps = gr.Slider(
@@ -275,60 +183,50 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
275
  value=16,
276
  step=1,
277
  label="Teacher Steps",
278
- info="More steps = higher quality"
279
  )
280
-
281
  custom_teacher_stopping_time = gr.Slider(
282
  minimum=0.0,
283
  maximum=1.0,
284
  value=0.07,
285
  step=0.01,
286
  label="Teacher Stopping Time",
287
- info="When to switch to student"
288
  )
289
-
290
  custom_student_start_step = gr.Slider(
291
  minimum=0,
292
  maximum=4,
293
  value=1,
294
  step=1,
295
  label="Student Start Step",
296
- info="Which student step to start from"
297
  )
298
-
299
  verbose = gr.Checkbox(
300
  value=False,
301
  label="Verbose Output",
302
- info="Show detailed generation steps"
303
  )
304
-
305
  generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
306
-
307
  with gr.Column(scale=1):
308
  # Output
309
  output_audio = gr.Audio(
310
- label="🔊 Generated Speech",
311
- type="filepath",
312
- autoplay=True
313
  )
314
-
315
- status = gr.Textbox(
316
- label="Status",
317
- interactive=False
318
- )
319
-
320
- metrics = gr.Textbox(
321
- label="Performance Metrics",
322
- interactive=False
323
- )
324
-
325
- info = gr.Textbox(
326
- label="Generation Info",
327
- interactive=False
328
- )
329
-
330
  # Tips
331
- gr.Markdown("""
 
332
  ### 💡 Quick Tips:
333
 
334
  - **Auto-transcription**: Leave reference text empty to auto-transcribe
@@ -341,8 +239,9 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
341
  - Student Only: ~0.05x (20x faster than real-time)
342
  - Teacher-Guided: ~0.10x (10x faster)
343
  - High Diversity: ~0.20x (5x faster)
344
- """)
345
-
 
346
  # Event handler
347
  generate_btn.click(
348
  generate_speech,
@@ -355,21 +254,17 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
355
  custom_teacher_steps,
356
  custom_teacher_stopping_time,
357
  custom_student_start_step,
358
- verbose
359
  ],
360
- outputs=[output_audio, status, metrics, info]
361
  )
362
-
363
  # Update visibility of custom settings based on mode
364
  def update_custom_visibility(mode):
365
- is_custom = (mode == "Custom")
366
  return gr.update(visible=is_custom)
367
-
368
- mode.change(
369
- update_custom_visibility,
370
- inputs=[mode],
371
- outputs=[custom_settings]
372
- )
373
 
374
  # Launch the app
375
  if __name__ == "__main__":
@@ -377,5 +272,5 @@ if __name__ == "__main__":
377
  print(f"Warning: Model failed to load - {status_message}")
378
  if not asr_pipe:
379
  print("Warning: ASR pipeline not available - auto-transcription disabled")
380
-
381
- demo.launch()
 
1
+ ## IMPORTS ##
2
+ import os
 
 
3
  import tempfile
4
  import time
5
  from pathlib import Path
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
  import spaces
10
+ import torch
11
+ import torchaudio
12
+ from cached_path import cached_path
13
+ from huggingface_hub import hf_hub_download
14
  from transformers import pipeline
15
 
 
16
  from infer import DMOInference
17
 
18
+ ## CUDA DEVICE ##
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ ## LOAD MODELS ##
+ asr_pipe = pipeline(
+ "automatic-speech-recognition", model="openai/whisper-large-v3-turbo", device=device
+ )
+ model = DMOInference(
+ student_checkpoint_path=str(cached_path("hf://yl4579/DMOSpeech2/model_85000.pt")),
+ duration_predictor_path=str(cached_path("hf://yl4579/DMOSpeech2/model_1500.pt")),
+ device=device,
+ model_type="F5TTS_Base",
+ )
+
+
33
  def transcribe(ref_audio, language=None):
34
  """Transcribe audio using the pre-loaded ASR pipeline."""
35
+ return asr_pipe(
36
+ ref_audio,
37
+ chunk_length_s=30,
38
+ batch_size=128,
39
+ generate_kwargs=(
40
+ {"task": "transcribe", "language": language}
41
+ if language
42
+ else {"task": "transcribe"}
43
+ ),
44
+ return_timestamps=False,
45
+ )["text"].strip()
46
+
47
+
48
+ @spaces.GPU(duration=120)
49
  def generate_speech(
50
  prompt_audio,
51
  prompt_text,
 
55
  custom_teacher_steps,
56
  custom_teacher_stopping_time,
57
  custom_student_start_step,
58
+ verbose,
59
  ):
 
 
 
 
 
60
  if prompt_audio is None:
61
+ raise gr.Error("Please upload a reference audio!")
62
+
63
  if not target_text:
64
+ raise gr.Error("Please enter text to generate!")
65
+
66
+ if not prompt_text and prompt_text != "":
67
+ prompt_text = transcribe(prompt_audio)
68
+
69
+
70
+ if mode == "Student Only (4 steps)":
71
+ teacher_steps = 0
72
+ student_start_step = 0
73
+ teacher_stopping_time = 1.0
74
+ elif mode == "Teacher-Guided (8 steps)":
75
+ teacher_steps = 16
76
+ teacher_stopping_time = 0.07
77
+ student_start_step = 1
78
+ elif mode == "High Diversity (16 steps)":
79
+ teacher_steps = 24
80
+ teacher_stopping_time = 0.3
81
+ student_start_step = 2
82
+ else: # Custom
83
+ teacher_steps = custom_teacher_steps
84
+ teacher_stopping_time = custom_teacher_stopping_time
85
+ student_start_step = custom_student_start_step
86
+
87
+ # Generate speech
88
+ generated_audio = model.generate(
89
+ gen_text=target_text,
90
+ audio_path=prompt_audio,
91
+ prompt_text=prompt_text if prompt_text else None,
92
+ teacher_steps=teacher_steps,
93
+ teacher_stopping_time=teacher_stopping_time,
94
+ student_start_step=student_start_step,
95
+ temperature=temperature,
96
+ verbose=verbose,
97
+ )
98
+
99
+
100
+ # Save audio
101
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
102
+ output_path = tmp_file.name
103
+
104
+ if isinstance(generated_audio, np.ndarray):
105
+ generated_audio = torch.from_numpy(generated_audio)
106
+
107
+ if generated_audio.dim() == 1:
108
+ generated_audio = generated_audio.unsqueeze(0)
109
+
110
+ torchaudio.save(output_path, generated_audio, 24000)
111
+
112
+ return (
113
+ output_path,
114
+ "Success!",
115
+ (
116
+ f"Mode: {mode} | Transcribed: {prompt_text[:50]}..."
117
+ if not prompt_text
118
+ else f"Mode: {mode}"
119
+ ),
120
+ )
121
+
 
 
 
 
 
 
 
 
 
 
122
 
123
  # Create Gradio interface
124
  with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
125
+ gr.Markdown(
126
+ f"""
127
  # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
128
 
129
  Generate natural speech in any voice with just a short reference audio!
130
+ """
131
+ )
132
+
133
  with gr.Row():
134
  with gr.Column(scale=1):
135
  # Reference audio input
136
  prompt_audio = gr.Audio(
137
  label="📎 Reference Audio",
138
  type="filepath",
139
+ sources=["upload", "microphone"],
140
  )
141
+
142
  prompt_text = gr.Textbox(
143
  label="📝 Reference Text (leave empty for auto-transcription)",
144
  placeholder="The text spoken in the reference audio...",
145
+ lines=2,
146
  )
147
+
148
  target_text = gr.Textbox(
149
  label="✍️ Text to Generate",
150
  placeholder="Enter the text you want to synthesize...",
151
+ lines=4,
152
  )
153
+
154
  # Generation mode
155
  mode = gr.Radio(
156
  choices=[
157
  "Student Only (4 steps)",
158
  "Teacher-Guided (8 steps)",
159
  "High Diversity (16 steps)",
160
+ "Custom",
161
  ],
162
  value="Teacher-Guided (8 steps)",
163
  label="🚀 Generation Mode",
164
+ info="Choose speed vs quality/diversity tradeoff",
165
  )
166
+
167
  # Advanced settings (collapsible)
168
  with gr.Accordion("⚙️ Advanced Settings", open=False):
169
  temperature = gr.Slider(
 
172
  value=0.0,
173
  step=0.1,
174
  label="Duration Temperature",
175
+ info="0 = deterministic, >0 = more variation in speech rhythm",
176
  )
177
+
178
  with gr.Group(visible=False) as custom_settings:
179
  gr.Markdown("### Custom Mode Settings")
180
  custom_teacher_steps = gr.Slider(
 
183
  value=16,
184
  step=1,
185
  label="Teacher Steps",
186
+ info="More steps = higher quality",
187
  )
188
+
189
  custom_teacher_stopping_time = gr.Slider(
190
  minimum=0.0,
191
  maximum=1.0,
192
  value=0.07,
193
  step=0.01,
194
  label="Teacher Stopping Time",
195
+ info="When to switch to student",
196
  )
197
+
198
  custom_student_start_step = gr.Slider(
199
  minimum=0,
200
  maximum=4,
201
  value=1,
202
  step=1,
203
  label="Student Start Step",
204
+ info="Which student step to start from",
205
  )
206
+
207
  verbose = gr.Checkbox(
208
  value=False,
209
  label="Verbose Output",
210
+ info="Show detailed generation steps",
211
  )
212
+
213
  generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
214
+
215
  with gr.Column(scale=1):
216
  # Output
217
  output_audio = gr.Audio(
218
+ label="🔊 Generated Speech", type="filepath", autoplay=True
 
 
219
  )
220
+
221
+ status = gr.Textbox(label="Status", interactive=False)
222
+
223
+ metrics = gr.Textbox(label="Performance Metrics", interactive=False)
224
+
225
+ info = gr.Textbox(label="Generation Info", interactive=False)
+
227
  # Tips
228
+ gr.Markdown(
229
+ """
230
  ### 💡 Quick Tips:
231
 
232
  - **Auto-transcription**: Leave reference text empty to auto-transcribe
 
239
  - Student Only: ~0.05x (20x faster than real-time)
240
  - Teacher-Guided: ~0.10x (10x faster)
241
  - High Diversity: ~0.20x (5x faster)
242
+ """
243
+ )
244
+
245
  # Event handler
246
  generate_btn.click(
247
  generate_speech,
 
254
  custom_teacher_steps,
255
  custom_teacher_stopping_time,
256
  custom_student_start_step,
257
+ verbose,
258
  ],
259
+ outputs=[output_audio, status, metrics, info],
260
  )
261
+
262
  # Update visibility of custom settings based on mode
263
  def update_custom_visibility(mode):
264
+ is_custom = mode == "Custom"
265
  return gr.update(visible=is_custom)
266
+
267
+ mode.change(update_custom_visibility, inputs=[mode], outputs=[custom_settings])
268
 
269
  # Launch the app
270
  if __name__ == "__main__":
 
272
  print(f"Warning: Model failed to load - {status_message}")
273
  if not asr_pipe:
274
  print("Warning: ASR pipeline not available - auto-transcription disabled")
275
+
276
+ demo.launch()
ctcmodel.py CHANGED
@@ -1,36 +1,24 @@
1
- from torch import nn
2
- import torch
3
  import copy
4
-
5
  from pathlib import Path
6
- from torchaudio.models import Conformer
7
-
8
 
9
- from f5_tts.model.utils import default
10
- from f5_tts.model.utils import exists
11
- from f5_tts.model.utils import list_str_to_idx
12
- from f5_tts.model.utils import list_str_to_tensor
13
- from f5_tts.model.utils import lens_to_mask
14
- from f5_tts.model.utils import mask_from_frac_lengths
15
 
 
 
16
 
17
- from f5_tts.model.utils import (
18
- default,
19
- exists,
20
- list_str_to_idx,
21
- list_str_to_tensor,
22
- lens_to_mask,
23
- mask_from_frac_lengths,
24
- )
25
 
26
  class ResBlock(nn.Module):
27
  def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
28
  super().__init__()
29
  self._n_groups = 8
30
- self.blocks = nn.ModuleList([
31
- self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p)
32
- for i in range(n_conv)])
33
-
 
 
34
 
35
  def forward(self, x):
36
  for block in self.blocks:
@@ -41,70 +29,71 @@ class ResBlock(nn.Module):
41
 
42
  def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
43
  layers = [
44
- nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
45
  nn.ReLU(),
46
  nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
47
  nn.Dropout(p=dropout_p),
48
  nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
49
  nn.ReLU(),
50
- nn.Dropout(p=dropout_p)
51
  ]
52
  return nn.Sequential(*layers)
53
 
54
 
55
  class ConformerCTC(nn.Module):
56
- def __init__(self,
57
- vocab_size,
58
- mel_dim=100,
59
- num_heads=8,
60
- d_hid=512,
61
- nlayers=6):
62
  super().__init__()
63
-
64
  self.mel_proj = nn.Conv1d(mel_dim, d_hid, kernel_size=3, padding=1)
65
-
66
  self.d_hid = d_hid
67
-
68
  self.resblock1 = nn.Sequential(
69
- ResBlock(d_hid),
70
- nn.GroupNorm(num_groups=1, num_channels=d_hid)
71
- )
72
-
73
  self.resblock2 = nn.Sequential(
74
- ResBlock(d_hid),
75
- nn.GroupNorm(num_groups=1, num_channels=d_hid)
76
- )
77
-
78
 
79
  self.conf_pre = torch.nn.ModuleList(
80
- [Conformer(
81
- input_dim=d_hid,
82
- num_heads=num_heads,
83
- ffn_dim=d_hid * 2,
84
- num_layers=1,
85
- depthwise_conv_kernel_size=15,
86
- use_group_norm=True,)
 
 
87
  for _ in range(nlayers // 2)
88
  ]
89
  )
90
-
91
  self.conf_after = torch.nn.ModuleList(
92
- [Conformer(
93
- input_dim=d_hid,
94
- num_heads=num_heads,
95
- ffn_dim=d_hid * 2,
96
- num_layers=1,
97
- depthwise_conv_kernel_size=7,
98
- use_group_norm=True,)
 
 
99
  for _ in range(nlayers // 2)
100
  ]
101
  )
102
 
103
- self.out = nn.Linear(d_hid, 1 + vocab_size) # 1 for blank
104
 
105
  self.ctc_loss = nn.CTCLoss(blank=vocab_size, zero_infinity=True).cuda()
106
 
107
-
108
  def forward(self, latent, text=None, text_lens=None):
109
  layers = []
110
 
@@ -125,20 +114,24 @@ class ConformerCTC(nn.Module):
125
 
126
  batch_size, time_steps, _ = x.shape
127
  # Create a dummy lengths tensor (all sequences are assumed to be full length).
128
- input_lengths = torch.full((batch_size,), time_steps, device=x.device, dtype=torch.int64)
 
 
129
 
130
- for layer in (self.conf_pre):
131
  x, _ = layer(x, input_lengths)
132
  layers.append(x.transpose(1, 2))
133
 
134
- for layer in (self.conf_after):
135
  x, _ = layer(x, input_lengths)
136
  layers.append(x.transpose(1, 2))
137
 
138
  x = self.out(x)
139
 
140
  if text_lens is not None and text is not None:
141
- loss = self.ctc_loss(x.log_softmax(dim=2).transpose(0, 1), text, input_lengths, text_lens)
 
 
142
  return x, layers, loss
143
  else:
144
  return x, layers
@@ -147,9 +140,8 @@ class ConformerCTC(nn.Module):
147
  if __name__ == "__main__":
148
  from f5_tts.model.utils import get_tokenizer
149
 
150
-
151
  bsz = 16
152
-
153
  tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
154
  tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
155
  dataset_name = "Emilia_ZH_EN"
@@ -158,15 +150,17 @@ if __name__ == "__main__":
158
  else:
159
  tokenizer_path = dataset_name
160
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
161
-
162
- model = ConformerCTC(vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6).cuda()
163
-
 
 
164
  text = ["hello world"] * bsz
165
  lens = torch.randint(1, 1000, (bsz,)).cuda()
166
  inp = torch.randn(bsz, lens.max(), 80).cuda()
167
-
168
  batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
169
-
170
  # handle text as string
171
  text_lens = torch.tensor([len(t) for t in text], device=device)
172
  if isinstance(text, list):
@@ -198,7 +192,6 @@ if __name__ == "__main__":
198
 
199
  char_vocab_map = list(vocab_char_map.keys())
200
 
201
-
202
  for batch in best_path:
203
  decoded_sequence = []
204
  previous_token = None
@@ -212,10 +205,15 @@ if __name__ == "__main__":
212
  decoded_sequences.append(decoded_sequence)
213
 
214
  # Convert token indices to characters
215
- decoded_texts = [''.join([char_vocab_map[token] for token in sequence]) for sequence in decoded_sequences]
 
 
 
216
  gt_texts = []
217
  for i in range(text_lens.size(0)):
218
- gt_texts.append(''.join([char_vocab_map[token] for token in text[i, :text_lens[i]]]))
219
-
 
 
220
  print(decoded_texts)
221
- print(gt_texts)
 
 
 
1
  import copy
 
2
  from pathlib import Path
 
 
3
 
4
+ import torch
5
+ from torch import nn
6
+ from torchaudio.models import Conformer
 
 
 
7
 
8
+ from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
9
+ list_str_to_tensor, mask_from_frac_lengths)
11
 
12
  class ResBlock(nn.Module):
13
  def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
14
  super().__init__()
15
  self._n_groups = 8
16
+ self.blocks = nn.ModuleList(
17
+ [
18
+ self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p)
19
+ for i in range(n_conv)
20
+ ]
21
+ )
22
 
23
  def forward(self, x):
24
  for block in self.blocks:
 
29
 
30
  def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
31
  layers = [
32
+ nn.Conv1d(
33
+ hidden_dim,
34
+ hidden_dim,
35
+ kernel_size=3,
36
+ padding=dilation,
37
+ dilation=dilation,
38
+ ),
39
  nn.ReLU(),
40
  nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
41
  nn.Dropout(p=dropout_p),
42
  nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
43
  nn.ReLU(),
44
+ nn.Dropout(p=dropout_p),
45
  ]
46
  return nn.Sequential(*layers)
47
 
48
 
49
  class ConformerCTC(nn.Module):
50
+ def __init__(self, vocab_size, mel_dim=100, num_heads=8, d_hid=512, nlayers=6):
 
 
 
 
 
51
  super().__init__()
52
+
53
  self.mel_proj = nn.Conv1d(mel_dim, d_hid, kernel_size=3, padding=1)
54
+
55
  self.d_hid = d_hid
56
+
57
  self.resblock1 = nn.Sequential(
58
+ ResBlock(d_hid), nn.GroupNorm(num_groups=1, num_channels=d_hid)
59
+ )
60
+
 
61
  self.resblock2 = nn.Sequential(
62
+ ResBlock(d_hid), nn.GroupNorm(num_groups=1, num_channels=d_hid)
63
+ )
 
 
64
 
65
  self.conf_pre = torch.nn.ModuleList(
66
+ [
67
+ Conformer(
68
+ input_dim=d_hid,
69
+ num_heads=num_heads,
70
+ ffn_dim=d_hid * 2,
71
+ num_layers=1,
72
+ depthwise_conv_kernel_size=15,
73
+ use_group_norm=True,
74
+ )
75
  for _ in range(nlayers // 2)
76
  ]
77
  )
78
+
79
  self.conf_after = torch.nn.ModuleList(
80
+ [
81
+ Conformer(
82
+ input_dim=d_hid,
83
+ num_heads=num_heads,
84
+ ffn_dim=d_hid * 2,
85
+ num_layers=1,
86
+ depthwise_conv_kernel_size=7,
87
+ use_group_norm=True,
88
+ )
89
  for _ in range(nlayers // 2)
90
  ]
91
  )
92
 
93
+ self.out = nn.Linear(d_hid, 1 + vocab_size) # 1 for blank
94
 
95
  self.ctc_loss = nn.CTCLoss(blank=vocab_size, zero_infinity=True).cuda()
96
 
 
97
  def forward(self, latent, text=None, text_lens=None):
98
  layers = []
99
 
 
114
 
115
  batch_size, time_steps, _ = x.shape
116
  # Create a dummy lengths tensor (all sequences are assumed to be full length).
117
+ input_lengths = torch.full(
118
+ (batch_size,), time_steps, device=x.device, dtype=torch.int64
119
+ )
120
 
121
+ for layer in self.conf_pre:
122
  x, _ = layer(x, input_lengths)
123
  layers.append(x.transpose(1, 2))
124
 
125
+ for layer in self.conf_after:
126
  x, _ = layer(x, input_lengths)
127
  layers.append(x.transpose(1, 2))
128
 
129
  x = self.out(x)
130
 
131
  if text_lens is not None and text is not None:
132
+ loss = self.ctc_loss(
133
+ x.log_softmax(dim=2).transpose(0, 1), text, input_lengths, text_lens
134
+ )
135
  return x, layers, loss
136
  else:
137
  return x, layers
 
140
  if __name__ == "__main__":
141
  from f5_tts.model.utils import get_tokenizer
142
 
 
143
  bsz = 16
144
+
145
  tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
146
  tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
147
  dataset_name = "Emilia_ZH_EN"
 
150
  else:
151
  tokenizer_path = dataset_name
152
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
153
+
154
+ model = ConformerCTC(
155
+ vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6
156
+ ).cuda()
157
+
158
  text = ["hello world"] * bsz
159
  lens = torch.randint(1, 1000, (bsz,)).cuda()
160
  inp = torch.randn(bsz, lens.max(), 80).cuda()
161
+
162
  batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
163
+
164
  # handle text as string
165
  text_lens = torch.tensor([len(t) for t in text], device=device)
166
  if isinstance(text, list):
 
192
 
193
  char_vocab_map = list(vocab_char_map.keys())
194
 
 
195
  for batch in best_path:
196
  decoded_sequence = []
197
  previous_token = None
 
205
  decoded_sequences.append(decoded_sequence)
206
 
207
  # Convert token indices to characters
208
+ decoded_texts = [
209
+ "".join([char_vocab_map[token] for token in sequence])
210
+ for sequence in decoded_sequences
211
+ ]
212
  gt_texts = []
213
  for i in range(text_lens.size(0)):
214
+ gt_texts.append(
215
+ "".join([char_vocab_map[token] for token in text[i, : text_lens[i]]])
216
+ )
217
+
218
  print(decoded_texts)
219
+ print(gt_texts)
discriminator_conformer.py CHANGED
@@ -2,30 +2,28 @@
2
 
3
  from __future__ import annotations
4
 
 
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  import torchaudio.transforms as trans
9
- from pathlib import Path
10
  from torchaudio.models import Conformer
11
 
12
- from f5_tts.model.utils import (
13
- default,
14
- exists,
15
- list_str_to_idx,
16
- list_str_to_tensor,
17
- lens_to_mask,
18
- mask_from_frac_lengths,
19
- )
20
 
21
  class ResBlock(nn.Module):
22
  def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
23
  super().__init__()
24
  self._n_groups = 8
25
- self.blocks = nn.ModuleList([
26
- self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p)
27
- for i in range(n_conv)])
28
-
 
 
29
 
30
  def forward(self, x):
31
  for block in self.blocks:
@@ -36,46 +34,67 @@ class ResBlock(nn.Module):
36
 
37
  def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
38
  layers = [
39
- nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
40
  nn.ReLU(),
41
  nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
42
  nn.Dropout(p=dropout_p),
43
  nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
44
  nn.ReLU(),
45
- nn.Dropout(p=dropout_p)
46
  ]
47
  return nn.Sequential(*layers)
48
 
 
49
  class ConformerDiscirminator(nn.Module):
50
- def __init__(self, input_dim, channels=512, num_layers=3, num_heads=8, depthwise_conv_kernel_size=15, use_group_norm=True):
51
  super().__init__()
52
-
53
  self.input_layer = nn.Conv1d(input_dim, channels, kernel_size=3, padding=1)
54
 
55
  self.resblock1 = nn.Sequential(
56
- ResBlock(channels),
57
- nn.GroupNorm(num_groups=1, num_channels=channels)
58
- )
59
-
60
  self.resblock2 = nn.Sequential(
61
- ResBlock(channels),
62
- nn.GroupNorm(num_groups=1, num_channels=channels)
63
- )
64
-
65
- self.conformer1 = Conformer(**{"input_dim": channels,
66
- "num_heads": num_heads,
67
- "ffn_dim": channels * 2,
68
- "num_layers": 1,
 
69
  "depthwise_conv_kernel_size": depthwise_conv_kernel_size // 2,
70
- "use_group_norm": use_group_norm})
71
-
72
- self.conformer2 = Conformer(**{"input_dim": channels,
73
- "num_heads": num_heads,
74
- "ffn_dim": channels * 2,
75
- "num_layers": num_layers - 1,
 
 
 
 
76
  "depthwise_conv_kernel_size": depthwise_conv_kernel_size,
77
- "use_group_norm": use_group_norm})
78
-
 
 
79
  self.linear = nn.Conv1d(channels, 1, kernel_size=1)
80
 
81
  def forward(self, x):
@@ -89,12 +108,14 @@ class ConformerDiscirminator(nn.Module):
89
  x = nn.functional.avg_pool1d(x, 2)
90
  x = self.resblock2(x)
91
  x = nn.functional.avg_pool1d(x, 2)
92
-
93
  # Transpose to (B, T, C) for the conformer.
94
  x = x.transpose(1, 2)
95
  batch_size, time_steps, _ = x.shape
96
  # Create a dummy lengths tensor (all sequences are assumed to be full length).
97
- lengths = torch.full((batch_size,), time_steps, device=x.device, dtype=torch.int64)
 
 
98
  # The built-in Conformer returns (output, output_lengths); we discard lengths.
99
 
100
  x, _ = self.conformer1(x, lengths)
@@ -107,12 +128,13 @@ class ConformerDiscirminator(nn.Module):
107
 
108
  return out
109
 
 
110
  if __name__ == "__main__":
111
- from f5_tts.model.utils import get_tokenizer
112
  from f5_tts.model import DiT
 
113
 
114
  bsz = 2
115
-
116
  tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
117
  tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
118
  dataset_name = "Emilia_ZH_EN"
@@ -121,20 +143,28 @@ if __name__ == "__main__":
121
  else:
122
  tokenizer_path = dataset_name
123
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
124
-
125
-
126
- fake_unet = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=80)
127
 
128
  fake_unet = fake_unet.cuda()
129
 
130
  text = ["hello world"] * bsz
131
  lens = torch.randint(1, 1000, (bsz,)).cuda()
132
  inp = torch.randn(bsz, lens.max(), 80).cuda()
133
-
134
  batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
135
 
136
  batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
137
-
138
  # handle text as string
139
  if isinstance(text, list):
140
  if exists(vocab_char_map):
@@ -147,13 +177,17 @@ if __name__ == "__main__":
147
  if not exists(lens):
148
  lens = torch.full((batch,), seq_len, device=device)
149
 
150
- mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
 
 
151
  frac_lengths_mask = (0.7, 1.0)
152
-
153
  # get a random span to mask out for training conditionally
154
- frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask)
 
 
155
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
156
-
157
  if exists(mask):
158
  rand_span_mask &= mask
159
 
@@ -163,38 +197,41 @@ if __name__ == "__main__":
163
  x1 = inp
164
  x0 = torch.randn_like(x1)
165
  t = time.unsqueeze(-1).unsqueeze(-1)
166
-
167
  phi = (1 - t) * x0 + t * x1
168
  flow = x1 - x0
169
  cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
170
 
171
  layers = fake_unet(
172
- x=phi,
173
  cond=cond,
174
- text=text,
175
- time=time,
176
  drop_audio_cond=False,
177
  drop_text=False,
178
- classify_mode=True
179
  )
180
 
181
  # layers = torch.stack(layers, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
182
  # print(layers.shape)
183
 
184
  from ctcmodel import ConformerCTC
185
- ctcmodel = ConformerCTC(vocab_size=vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6).cuda()
 
 
 
186
  real_out, layer = ctcmodel(inp)
187
- layer = layer[-3:] # only use the last 3 layers
188
- layer = [F.interpolate(l, mode='nearest', scale_factor=4).transpose(-1, -2) for l in layer]
 
 
 
189
  if layer[0].size(1) < layers[0].size(1):
190
  layer = [F.pad(l, (0, 0, 0, layers[0].size(1) - l.size(1))) for l in layer]
191
-
192
  layers = layer + layers
193
 
194
- model = ConformerDiscirminator(input_dim=23 * 1024 + 3 * 512,
195
- channels=512
196
- )
197
-
198
 
199
  model = model.cuda()
200
  print(model)
 
2
 
3
  from __future__ import annotations
4
 
5
+ from pathlib import Path
6
+
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
  import torchaudio.transforms as trans
 
11
  from torchaudio.models import Conformer
12
 
13
+ from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
14
+ list_str_to_tensor, mask_from_frac_lengths)
+
 
17
  class ResBlock(nn.Module):
18
  def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
19
  super().__init__()
20
  self._n_groups = 8
21
+ self.blocks = nn.ModuleList(
22
+ [
23
+ self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p)
24
+ for i in range(n_conv)
25
+ ]
26
+ )
27
 
28
  def forward(self, x):
29
  for block in self.blocks:
 
34
 
35
  def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
36
  layers = [
37
+ nn.Conv1d(
38
+ hidden_dim,
39
+ hidden_dim,
40
+ kernel_size=3,
41
+ padding=dilation,
42
+ dilation=dilation,
43
+ ),
44
  nn.ReLU(),
45
  nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
46
  nn.Dropout(p=dropout_p),
47
  nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
48
  nn.ReLU(),
49
+ nn.Dropout(p=dropout_p),
50
  ]
51
  return nn.Sequential(*layers)
52
 
53
+
54
  class ConformerDiscirminator(nn.Module):
55
+ def __init__(
56
+ self,
57
+ input_dim,
58
+ channels=512,
59
+ num_layers=3,
60
+ num_heads=8,
61
+ depthwise_conv_kernel_size=15,
62
+ use_group_norm=True,
63
+ ):
64
  super().__init__()
65
+
66
  self.input_layer = nn.Conv1d(input_dim, channels, kernel_size=3, padding=1)
67
 
68
  self.resblock1 = nn.Sequential(
69
+ ResBlock(channels), nn.GroupNorm(num_groups=1, num_channels=channels)
70
+ )
71
+
 
72
  self.resblock2 = nn.Sequential(
73
+ ResBlock(channels), nn.GroupNorm(num_groups=1, num_channels=channels)
74
+ )
75
+
76
+ self.conformer1 = Conformer(
77
+ **{
78
+ "input_dim": channels,
79
+ "num_heads": num_heads,
80
+ "ffn_dim": channels * 2,
81
+ "num_layers": 1,
82
  "depthwise_conv_kernel_size": depthwise_conv_kernel_size // 2,
83
+ "use_group_norm": use_group_norm,
84
+ }
85
+ )
86
+
87
+ self.conformer2 = Conformer(
88
+ **{
89
+ "input_dim": channels,
90
+ "num_heads": num_heads,
91
+ "ffn_dim": channels * 2,
92
+ "num_layers": num_layers - 1,
93
  "depthwise_conv_kernel_size": depthwise_conv_kernel_size,
94
+ "use_group_norm": use_group_norm,
95
+ }
96
+ )
97
+
98
  self.linear = nn.Conv1d(channels, 1, kernel_size=1)
99
 
100
  def forward(self, x):
 
108
  x = nn.functional.avg_pool1d(x, 2)
109
  x = self.resblock2(x)
110
  x = nn.functional.avg_pool1d(x, 2)
111
+
112
  # Transpose to (B, T, C) for the conformer.
113
  x = x.transpose(1, 2)
114
  batch_size, time_steps, _ = x.shape
115
  # Create a dummy lengths tensor (all sequences are assumed to be full length).
116
+ lengths = torch.full(
117
+ (batch_size,), time_steps, device=x.device, dtype=torch.int64
118
+ )
119
  # The built-in Conformer returns (output, output_lengths); we discard lengths.
120
 
121
  x, _ = self.conformer1(x, lengths)
 
128
 
129
  return out
130
 
131
+
132
  if __name__ == "__main__":
 
133
  from f5_tts.model import DiT
134
+ from f5_tts.model.utils import get_tokenizer
135
 
136
  bsz = 2
137
+
138
  tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
139
  tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
140
  dataset_name = "Emilia_ZH_EN"
 
143
  else:
144
  tokenizer_path = dataset_name
145
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
146
+
147
+ fake_unet = DiT(
148
+ dim=1024,
149
+ depth=22,
150
+ heads=16,
151
+ ff_mult=2,
152
+ text_dim=512,
153
+ conv_layers=4,
154
+ text_num_embeds=vocab_size,
155
+ mel_dim=80,
156
+ )
157
 
158
  fake_unet = fake_unet.cuda()
159
 
160
  text = ["hello world"] * bsz
161
  lens = torch.randint(1, 1000, (bsz,)).cuda()
162
  inp = torch.randn(bsz, lens.max(), 80).cuda()
163
+
164
  batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
165
 
166
  batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
167
+
168
  # handle text as string
169
  if isinstance(text, list):
170
  if exists(vocab_char_map):
 
177
  if not exists(lens):
178
  lens = torch.full((batch,), seq_len, device=device)
179
 
180
+ mask = lens_to_mask(
181
+ lens, length=seq_len
182
+ ) # useless here, as collate_fn will pad to max length in batch
183
  frac_lengths_mask = (0.7, 1.0)
184
+
185
  # get a random span to mask out for training conditionally
186
+ frac_lengths = (
187
+ torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask)
188
+ )
189
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
190
+
191
  if exists(mask):
192
  rand_span_mask &= mask
193
 
 
197
  x1 = inp
198
  x0 = torch.randn_like(x1)
199
  t = time.unsqueeze(-1).unsqueeze(-1)
200
+
201
  phi = (1 - t) * x0 + t * x1
202
  flow = x1 - x0
203
  cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
204
 
205
  layers = fake_unet(
206
+ x=phi,
207
  cond=cond,
208
+ text=text,
209
+ time=time,
210
  drop_audio_cond=False,
211
  drop_text=False,
212
+ classify_mode=True,
213
  )
214
 
215
  # layers = torch.stack(layers, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
216
  # print(layers.shape)
217
 
218
  from ctcmodel import ConformerCTC
219
+
220
+ ctcmodel = ConformerCTC(
221
+ vocab_size=vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6
222
+ ).cuda()
223
  real_out, layer = ctcmodel(inp)
224
+ layer = layer[-3:] # only use the last 3 layers
225
+ layer = [
226
+ F.interpolate(l, mode="nearest", scale_factor=4).transpose(-1, -2)
227
+ for l in layer
228
+ ]
229
  if layer[0].size(1) < layers[0].size(1):
230
  layer = [F.pad(l, (0, 0, 0, layers[0].size(1) - l.size(1))) for l in layer]
231
+
232
  layers = layer + layers
233
 
234
+ model = ConformerDiscirminator(input_dim=23 * 1024 + 3 * 512, channels=512)
 
 
 
235
 
236
  model = model.cuda()
237
  print(model)
dmd_trainer.py CHANGED
@@ -1,28 +1,26 @@
1
  from __future__ import annotations
2
 
3
- import os
4
  import gc
5
- from tqdm import tqdm
6
- import wandb
7
 
8
  import torch
9
  import torch.nn as nn
10
- from torch.optim import AdamW
11
- from torch.utils.data import DataLoader, Dataset, SequentialSampler
12
- from torch.optim.lr_scheduler import LinearLR, SequentialLR
13
-
14
  from accelerate import Accelerator
15
  from accelerate.utils import DistributedDataParallelKwargs
 
 
 
 
16
 
17
- from unimodel import UniModel
18
  from f5_tts.model import CFM
19
- from f5_tts.model.utils import exists, default
20
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
21
-
 
22
 
23
  # trainer
24
 
25
- import math
26
 
27
  class RunningStats:
28
  def __init__(self):
@@ -41,7 +39,7 @@ class RunningStats:
41
  @property
42
  def variance(self):
43
  """Return the sample variance. Returns NaN if fewer than two samples."""
44
- return self.M2 / (self.count - 1) if self.count > 1 else float('nan')
45
 
46
  @property
47
  def std(self):
@@ -49,7 +47,6 @@ class RunningStats:
49
  return math.sqrt(self.variance)
50
 
51
 
52
-
53
  class Trainer:
54
  def __init__(
55
  self,
@@ -74,7 +71,6 @@ class Trainer:
74
  accelerate_kwargs: dict = dict(),
75
  bnb_optimizer: bool = False,
76
  scale: float = 1.0,
77
-
78
  # training parameters for DMDSpeech
79
  num_student_step: int = 1,
80
  gen_update_ratio: int = 5,
@@ -82,7 +78,6 @@ class Trainer:
82
  lambda_generator_loss: float = 1.0,
83
  lambda_ctc_loss: float = 1.0,
84
  lambda_sim_loss: float = 1.0,
85
-
86
  num_GAN: int = 5000,
87
  num_D: int = 500,
88
  num_ctc: int = 5000,
@@ -103,7 +98,13 @@ class Trainer:
103
 
104
  if logger == "wandb":
105
  if exists(wandb_resume_id):
106
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
107
  else:
108
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
109
  self.accelerator.init_trackers(
@@ -130,7 +131,9 @@ class Trainer:
130
  self.epochs = epochs
131
  self.num_warmup_updates = num_warmup_updates
132
  self.save_per_updates = save_per_updates
133
- self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
 
 
134
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
135
 
136
  self.batch_size = batch_size
@@ -142,41 +145,56 @@ class Trainer:
142
  self.noise_scheduler = noise_scheduler
143
 
144
  self.duration_predictor = duration_predictor
145
-
146
  self.log_step = log_step
147
 
148
- self.gen_update_ratio = gen_update_ratio # number of generator updates per guidance (fake score function and discriminator) update
149
- self.lambda_discriminator_loss = lambda_discriminator_loss # weight for discriminator loss (L_adv)
150
- self.lambda_generator_loss = lambda_generator_loss # weight for generator loss (L_adv)
151
- self.lambda_ctc_loss = lambda_ctc_loss # weight for ctc loss
152
- self.lambda_sim_loss = lambda_sim_loss # weight for similarity loss
153
-
 
 
 
 
154
  # create distillation schedule for student model
155
- self.student_steps = (
156
- torch.linspace(0.0, 1.0, num_student_step + 1)[:-1])
157
-
158
- self.GAN = model.guidance_model.gen_cls_loss # whether to use GAN training
159
- self.num_GAN = num_GAN # number of steps before adversarial training
160
- self.num_D = num_D # number of steps to train the discriminator before adversarial training
161
- self.num_ctc = num_ctc # number of steps before CTC training
162
- self.num_sim = num_sim # number of steps before similarity training
163
- self.num_simu = num_simu # number of steps before using simulated data
164
 
165
  # Assuming `self.model.fake_unet.parameters()` and `self.model.guidance_model.parameters()` are accessible
166
  if bnb_optimizer:
167
  import bitsandbytes as bnb
168
- self.optimizer_generator = bnb.optim.AdamW8bit(self.model.feedforward_model.parameters(), lr=learning_rate)
169
- self.optimizer_guidance = bnb.optim.AdamW8bit(self.model.guidance_model.parameters(), lr=learning_rate)
 
 
 
 
 
170
  else:
171
- self.optimizer_generator = AdamW(self.model.feedforward_model.parameters(), lr=learning_rate, eps=1e-7)
172
- self.optimizer_guidance = AdamW(self.model.guidance_model.parameters(), lr=learning_rate, eps=1e-7)
 
 
 
 
173
 
174
- self.model, self.optimizer_generator, self.optimizer_guidance = self.accelerator.prepare(self.model, self.optimizer_generator, self.optimizer_guidance)
 
 
 
 
175
 
176
  self.generator_norm = RunningStats()
177
  self.guidance_norm = RunningStats()
178
 
179
-
180
  @property
181
  def is_main(self):
182
  return self.accelerator.is_main_process
@@ -186,8 +204,12 @@ class Trainer:
186
  if self.is_main:
187
  checkpoint = dict(
188
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
189
- optimizer_generator_state_dict=self.accelerator.unwrap_model(self.optimizer_generator).state_dict(),
190
- optimizer_guidance_state_dict=self.accelerator.unwrap_model(self.optimizer_guidance).state_dict(),
 
 
 
 
191
  scheduler_generator_state_dict=self.scheduler_generator.state_dict(),
192
  scheduler_guidance_state_dict=self.scheduler_guidance.state_dict(),
193
  step=step,
@@ -196,10 +218,14 @@ class Trainer:
196
  if not os.path.exists(self.checkpoint_path):
197
  os.makedirs(self.checkpoint_path)
198
  if last:
199
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
 
 
200
  print(f"Saved last checkpoint at step {step}")
201
  else:
202
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
 
 
203
 
204
  def load_checkpoint(self):
205
  if (
@@ -218,9 +244,15 @@ class Trainer:
218
  key=lambda x: int("".join(filter(str.isdigit, x))),
219
  )[-1]
220
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
221
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
 
 
 
222
 
223
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"], strict=False)
 
 
224
  # self.accelerator.unwrap_model(self.optimizer_generator).load_state_dict(checkpoint["optimizer_generator_state_dict"])
225
  # self.accelerator.unwrap_model(self.optimizer_guidance).load_state_dict(checkpoint["optimizer_guidance_state_dict"])
226
  # if self.scheduler_guidance:
@@ -232,9 +264,14 @@ class Trainer:
232
  del checkpoint
233
  gc.collect()
234
  return step
235
-
236
 
237
- def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int = None, vocoder: nn.Module = None):
 
 
 
 
 
 
238
  if exists(resumable_with_seed):
239
  generator = torch.Generator()
240
  generator.manual_seed(resumable_with_seed)
@@ -256,7 +293,11 @@ class Trainer:
256
  self.accelerator.even_batches = False
257
  sampler = SequentialSampler(train_dataset)
258
  batch_sampler = DynamicBatchSampler(
259
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
 
 
 
 
260
  )
261
  train_dataloader = DataLoader(
262
  train_dataset,
@@ -267,29 +308,63 @@ class Trainer:
267
  batch_sampler=batch_sampler,
268
  )
269
  else:
270
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
 
 
271
 
272
  # accelerator.prepare() dispatches batches to devices;
273
  # which means the length of dataloader calculated before, should consider the number of devices
274
- warmup_steps = (
275
- self.num_warmup_updates * self.accelerator.num_processes
276
- )
277
-
278
  # consider a fixed warmup steps while using accelerate multi-gpu ddp
279
  # otherwise by default with split_batches=False, warmup steps change with num_processes
280
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
281
  decay_steps = total_steps - warmup_steps
282
-
283
- warmup_scheduler_generator = LinearLR(self.optimizer_generator, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps))
284
- decay_scheduler_generator = LinearLR(self.optimizer_generator, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps // (self.gen_update_ratio * self.grad_accumulation_steps))
285
- self.scheduler_generator = SequentialLR(self.optimizer_generator, schedulers=[warmup_scheduler_generator, decay_scheduler_generator], milestones=[warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps)])
286
 
287
- warmup_scheduler_guidance = LinearLR(self.optimizer_guidance, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
288
- decay_scheduler_guidance = LinearLR(self.optimizer_guidance, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
289
- self.scheduler_guidance = SequentialLR(self.optimizer_guidance, schedulers=[warmup_scheduler_guidance, decay_scheduler_guidance], milestones=[warmup_steps])
290
 
291
- train_dataloader, self.scheduler_generator, self.scheduler_guidance = self.accelerator.prepare(
292
- train_dataloader, self.scheduler_generator, self.scheduler_guidance
 
 
293
  ) # actual steps = 1 gpu steps / gpus
294
  start_step = self.load_checkpoint()
295
  global_step = start_step
@@ -298,7 +373,9 @@ class Trainer:
298
  orig_epoch_step = len(train_dataloader)
299
  skipped_epoch = int(start_step // orig_epoch_step)
300
  skipped_batch = start_step % orig_epoch_step
301
- skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
 
 
302
  else:
303
  skipped_epoch = 0
304
 
@@ -323,48 +400,59 @@ class Trainer:
323
 
324
  for batch in progress_bar:
325
  update_generator = global_step % self.gen_update_ratio == 0
326
-
327
  with self.accelerator.accumulate(self.model):
328
  metrics = {}
329
  text_inputs = batch["text"]
330
  mel_spec = batch["mel"].permute(0, 2, 1)
331
  mel_lengths = batch["mel_lengths"]
332
-
333
  mel_spec = mel_spec / self.scale
334
-
335
- guidance_loss_dict, guidance_log_dict = self.model(inp=mel_spec,
336
- text=text_inputs,
337
- lens=mel_lengths,
338
- student_steps=self.student_steps,
339
- update_generator=False,
340
- use_simulated=global_step >= self.num_simu,
341
- )
 
342
 
343
  # if self.GAN and update_generator:
344
  # # only add discriminator loss if GAN is enabled and generator is being updated
345
  # guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (self.lambda_discriminator_loss if global_step >= self.num_GAN and update_generator else 0)
346
  # metrics['loss/discriminator_loss'] = guidance_loss_dict["guidance_cls_loss"]
347
  # self.accelerator.backward(guidance_cls_loss, retain_graph=True)
348
-
349
  # if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
350
  # metrics['grad_norm_guidance'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
351
 
352
  guidance_loss = 0
353
  guidance_loss += guidance_loss_dict["loss_fake_mean"]
354
- metrics['loss/fake_score'] = guidance_loss_dict["loss_fake_mean"]
355
  metrics["loss/guidance_loss"] = guidance_loss
356
 
357
  if self.GAN and update_generator:
358
  # only add discriminator loss if GAN is enabled and generator is being updated
359
- guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (self.lambda_discriminator_loss if global_step >= self.num_GAN and update_generator else 0)
360
- metrics['loss/discriminator_loss'] = guidance_loss_dict["guidance_cls_loss"]
 
 
 
 
 
 
361
 
362
  guidance_loss += guidance_cls_loss
363
-
364
  self.accelerator.backward(guidance_loss)
365
 
366
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
367
- metrics['grad_norm_guidance'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
 
 
 
 
368
 
369
  # if self.guidance_norm.count < 100:
370
  # self.guidance_norm.update(metrics['grad_norm_guidance'])
@@ -376,20 +464,20 @@ class Trainer:
376
  # elif self.guidance_norm.count >= 100:
377
  # self.guidance_norm.update(metrics['grad_norm_guidance'])
378
 
379
-
380
  self.optimizer_guidance.step()
381
  self.scheduler_guidance.step()
382
  self.optimizer_guidance.zero_grad()
383
  self.optimizer_generator.zero_grad() # zero out the generator's gradient as well
384
-
385
  if update_generator:
386
- generator_loss_dict, generator_log_dict = self.model(inp=mel_spec,
387
- text=text_inputs,
388
- lens=mel_lengths,
389
- student_steps=self.student_steps,
390
- update_generator=True,
391
- use_simulated=global_step >= self.num_ctc,
392
- )
 
393
  # if self.GAN:
394
  # gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) and update_generator else 0)
395
  # metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"]
@@ -402,32 +490,57 @@ class Trainer:
        generator_loss = 0
        generator_loss += generator_loss_dict["loss_dm"]
        if "loss_mse" in generator_loss_dict:
-           generator_loss += generator_loss_dict["loss_mse"]
-       generator_loss += generator_loss_dict["loss_ctc"] * (self.lambda_ctc_loss if global_step >= self.num_ctc else 0)
-       generator_loss += generator_loss_dict["loss_sim"] * (self.lambda_sim_loss if global_step >= self.num_sim else 0)
-       generator_loss += generator_loss_dict["loss_kl"] * (self.lambda_ctc_loss if global_step >= self.num_ctc else 0)

        if self.GAN:
-           gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) and update_generator else 0)
-           metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"]

            generator_loss += gen_cls_loss

-       metrics['loss/dm_loss'] = generator_loss_dict["loss_dm"]
-       metrics['loss/ctc_loss'] = generator_loss_dict["loss_ctc"]
-
-       metrics['loss/similarity_loss'] = generator_loss_dict["loss_sim"]
-       metrics['loss/generator_loss'] = generator_loss
-
-       if "loss_mse" in generator_loss_dict and generator_loss_dict["loss_mse"] != 0:
-           metrics['loss/mse_loss'] = generator_loss_dict["loss_mse"]
-       if "loss_kl" in generator_loss_dict and generator_loss_dict["loss_kl"] != 0:
-           metrics['loss/kl_loss'] = generator_loss_dict["loss_kl"]

        self.accelerator.backward(generator_loss)

        if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
-           metrics['grad_norm_generator'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # self.generator_norm.update(metrics['grad_norm_generator'])
-
        # if metrics['grad_norm_generator'] > self.generator_norm.mean + 15 * self.generator_norm.std:
        #     self.optimizer_generator.zero_grad()
        #     self.optimizer_guidance.zero_grad()
@@ -440,89 +553,165 @@ class Trainer:
        self.optimizer_generator.zero_grad()
        self.optimizer_guidance.zero_grad() # zero out the guidance's gradient as well

-
    global_step += 1

    if self.accelerator.is_local_main_process:
-       self.accelerator.log({**metrics,
-           "lr_generator": self.scheduler_generator.get_last_lr()[0],
-           "lr_guidance": self.scheduler_guidance.get_last_lr()[0],
-           }
-           , step=global_step)
-
-   if global_step % self.log_step == 0 and self.accelerator.is_local_main_process and vocoder is not None:
        # log the first batch of the epoch
        with torch.no_grad():
-           generator_input = generator_log_dict['generator_input'][0].unsqueeze(0).permute(0, 2, 1) * self.scale
            generator_input = vocoder.decode(generator_input.float().cpu())
            generator_input = wandb.Audio(
                generator_input.float().numpy().squeeze(),
                sample_rate=24000,
-               caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy())
            )

-           generator_output = generator_log_dict['generator_output'][0].unsqueeze(0).permute(0, 2, 1) * self.scale
-           generator_output = vocoder.decode(generator_output.float().cpu())
            generator_output = wandb.Audio(
                generator_output.float().numpy().squeeze(),
                sample_rate=24000,
-               caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy())
            )
-
-           generator_cond = generator_log_dict['generator_cond'][0].unsqueeze(0).permute(0, 2, 1) * self.scale
            generator_cond = vocoder.decode(generator_cond.float().cpu())
            generator_cond = wandb.Audio(
                generator_cond.float().numpy().squeeze(),
                sample_rate=24000,
-               caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy())
            )
-
-           ground_truth = generator_log_dict['ground_truth'][0].unsqueeze(0).permute(0, 2, 1) * self.scale
            ground_truth = vocoder.decode(ground_truth.float().cpu())
            ground_truth = wandb.Audio(
                ground_truth.float().numpy().squeeze(),
                sample_rate=24000,
-               caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy())
            )
-
-           dmtrain_noisy_inp = generator_log_dict['dmtrain_noisy_inp'][0].unsqueeze(0).permute(0, 2, 1) * self.scale
-           dmtrain_noisy_inp = vocoder.decode(dmtrain_noisy_inp.float().cpu())
            dmtrain_noisy_inp = wandb.Audio(
                dmtrain_noisy_inp.float().numpy().squeeze(),
                sample_rate=24000,
-               caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy())
            )
-
-           dmtrain_pred_real_image = generator_log_dict['dmtrain_pred_real_image'][0].unsqueeze(0).permute(0, 2, 1) * self.scale
-           dmtrain_pred_real_image = vocoder.decode(dmtrain_pred_real_image.float().cpu())
            dmtrain_pred_real_image = wandb.Audio(
                dmtrain_pred_real_image.float().numpy().squeeze(),
                sample_rate=24000,
-               caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy())
            )
-
-           dmtrain_pred_fake_image = generator_log_dict['dmtrain_pred_fake_image'][0].unsqueeze(0).permute(0, 2, 1) * self.scale
-           dmtrain_pred_fake_image = vocoder.decode(dmtrain_pred_fake_image.float().cpu())
            dmtrain_pred_fake_image = wandb.Audio(
                dmtrain_pred_fake_image.float().numpy().squeeze(),
                sample_rate=24000,
-               caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy())
            )
-
-
-           self.accelerator.log({"noisy_input": generator_input,
-               "output": generator_output,
-               "cond": generator_cond,
-               "ground_truth": ground_truth,
-               "dmtrain_noisy_inp": dmtrain_noisy_inp,
-               "dmtrain_pred_real_image": dmtrain_pred_real_image,
-               "dmtrain_pred_fake_image": dmtrain_pred_fake_image,
-
-               }, step=global_step)

    progress_bar.set_postfix(step=str(global_step), metrics=metrics)
-   if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
        self.save_checkpoint(global_step)

    if global_step % self.last_per_steps == 0:
@@ -531,5 +720,3 @@ class Trainer:
        self.save_checkpoint(global_step, last=True)

self.accelerator.end_training()
-
-
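For readers following the removed block above: the training step updates the guidance model (fake-score function, plus discriminator when GAN training is on) every iteration, updates the generator only every gen_update_ratio iterations, and after each optimizer step also zeroes the other model's gradients. Below is a minimal, self-contained sketch of that alternation; the toy linear modules and squared-error losses are placeholders for illustration only, not the repo's UniModel or its real objectives.

# Sketch of the alternating DMD-style update, with toy stand-in modules (assumption).
import torch
import torch.nn as nn
from torch.optim import AdamW

generator, guidance = nn.Linear(8, 8), nn.Linear(8, 8)   # stand-ins for feedforward_model / guidance_model
opt_gen = AdamW(generator.parameters(), lr=1e-4)
opt_gui = AdamW(guidance.parameters(), lr=1e-4)
gen_update_ratio = 5

for step in range(20):
    x = torch.randn(4, 8)

    # guidance (fake-score) update runs on every step; generator output is detached
    guidance_loss = guidance(generator(x).detach()).pow(2).mean()
    guidance_loss.backward()
    opt_gui.step()
    opt_gui.zero_grad()
    opt_gen.zero_grad()          # also clear the generator's gradients, as the trainer does

    # generator update runs every gen_update_ratio steps
    if step % gen_update_ratio == 0:
        generator_loss = (generator(x) - guidance(generator(x)).detach()).pow(2).mean()
        generator_loss.backward()
        opt_gen.step()
        opt_gen.zero_grad()
        opt_gui.zero_grad()      # and clear the guidance gradients after the generator step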
 
from __future__ import annotations

import gc
+ import math
+ import os

import torch
import torch.nn as nn
+ import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
+ from tqdm import tqdm

from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
+ from f5_tts.model.utils import default, exists
+ from unimodel import UniModel

# trainer

class RunningStats:
    def __init__(self):

    @property
    def variance(self):
        """Return the sample variance. Returns NaN if fewer than two samples."""
+       return self.M2 / (self.count - 1) if self.count > 1 else float("nan")

    @property
    def std(self):

        return math.sqrt(self.variance)

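The diff elides RunningStats.update(); given the count/mean/M2 attributes read by variance, a Welford-style online update is the natural fit, but that is an assumption. A small sketch under that assumption, showing how the trainer's (currently commented-out) gradient-norm outlier check could use it:

# RunningStats usage sketch; the Welford update() body is assumed, not shown in this diff.
import math

class RunningStatsSketch:
    def __init__(self):
        self.count, self.mean, self.M2 = 0, 0.0, 0.0

    def update(self, x):                      # assumed Welford online update
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.M2 += delta * (x - self.mean)

    @property
    def variance(self):
        return self.M2 / (self.count - 1) if self.count > 1 else float("nan")

    @property
    def std(self):
        return math.sqrt(self.variance)

stats = RunningStatsSketch()
for grad_norm in [1.0, 1.2, 0.9, 25.0]:       # e.g. per-step gradient norms
    stats.update(grad_norm)
print(stats.mean, stats.std)                  # could feed an outlier test like mean + 15 * std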
  class Trainer:
51
  def __init__(
52
  self,
 
71
  accelerate_kwargs: dict = dict(),
72
  bnb_optimizer: bool = False,
73
  scale: float = 1.0,
 
74
  # training parameters for DMDSpeech
75
  num_student_step: int = 1,
76
  gen_update_ratio: int = 5,
 
78
  lambda_generator_loss: float = 1.0,
79
  lambda_ctc_loss: float = 1.0,
80
  lambda_sim_loss: float = 1.0,
 
81
  num_GAN: int = 5000,
82
  num_D: int = 500,
83
  num_ctc: int = 5000,
 
98
 
99
  if logger == "wandb":
100
  if exists(wandb_resume_id):
101
+ init_kwargs = {
102
+ "wandb": {
103
+ "resume": "allow",
104
+ "name": wandb_run_name,
105
+ "id": wandb_resume_id,
106
+ }
107
+ }
108
  else:
109
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
110
  self.accelerator.init_trackers(
 
131
  self.epochs = epochs
132
  self.num_warmup_updates = num_warmup_updates
133
  self.save_per_updates = save_per_updates
134
+ self.last_per_steps = default(
135
+ last_per_steps, save_per_updates * grad_accumulation_steps
136
+ )
137
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
138
 
139
  self.batch_size = batch_size
 
145
  self.noise_scheduler = noise_scheduler
146
 
147
  self.duration_predictor = duration_predictor
148
+
149
  self.log_step = log_step
150
 
151
+ self.gen_update_ratio = gen_update_ratio # number of generator updates per guidance (fake score function and discriminator) update
152
+ self.lambda_discriminator_loss = (
153
+ lambda_discriminator_loss # weight for discriminator loss (L_adv)
154
+ )
155
+ self.lambda_generator_loss = (
156
+ lambda_generator_loss # weight for generator loss (L_adv)
157
+ )
158
+ self.lambda_ctc_loss = lambda_ctc_loss # weight for ctc loss
159
+ self.lambda_sim_loss = lambda_sim_loss # weight for similarity loss
160
+
161
  # create distillation schedule for student model
162
+ self.student_steps = torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]
163
+
164
+ self.GAN = model.guidance_model.gen_cls_loss # whether to use GAN training
165
+ self.num_GAN = num_GAN # number of steps before adversarial training
166
+ self.num_D = num_D # number of steps to train the discriminator before adversarial training
167
+ self.num_ctc = num_ctc # number of steps before CTC training
168
+ self.num_sim = num_sim # number of steps before similarity training
169
+ self.num_simu = num_simu # number of steps before using simulated data
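For reference, the distillation time grid built above is simply an evenly spaced set of flow times in [0, 1) excluding the endpoint, and the num_* counters gate when the CTC, similarity, and adversarial terms switch on. A quick illustration of the grid values only:

# Illustration of the student_steps schedule (values only, no repo API involved).
import torch

num_student_step = 4
student_steps = torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]
print(student_steps)   # tensor([0.0000, 0.2500, 0.5000, 0.7500]); with num_student_step=1 it is tensor([0.])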
 
170
 
171
  # Assuming `self.model.fake_unet.parameters()` and `self.model.guidance_model.parameters()` are accessible
172
  if bnb_optimizer:
173
  import bitsandbytes as bnb
174
+
175
+ self.optimizer_generator = bnb.optim.AdamW8bit(
176
+ self.model.feedforward_model.parameters(), lr=learning_rate
177
+ )
178
+ self.optimizer_guidance = bnb.optim.AdamW8bit(
179
+ self.model.guidance_model.parameters(), lr=learning_rate
180
+ )
181
  else:
182
+ self.optimizer_generator = AdamW(
183
+ self.model.feedforward_model.parameters(), lr=learning_rate, eps=1e-7
184
+ )
185
+ self.optimizer_guidance = AdamW(
186
+ self.model.guidance_model.parameters(), lr=learning_rate, eps=1e-7
187
+ )
188
 
189
+ self.model, self.optimizer_generator, self.optimizer_guidance = (
190
+ self.accelerator.prepare(
191
+ self.model, self.optimizer_generator, self.optimizer_guidance
192
+ )
193
+ )
194
 
195
  self.generator_norm = RunningStats()
196
  self.guidance_norm = RunningStats()
197
 
 
198
  @property
199
  def is_main(self):
200
  return self.accelerator.is_main_process
 
204
  if self.is_main:
205
  checkpoint = dict(
206
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
207
+ optimizer_generator_state_dict=self.accelerator.unwrap_model(
208
+ self.optimizer_generator
209
+ ).state_dict(),
210
+ optimizer_guidance_state_dict=self.accelerator.unwrap_model(
211
+ self.optimizer_guidance
212
+ ).state_dict(),
213
  scheduler_generator_state_dict=self.scheduler_generator.state_dict(),
214
  scheduler_guidance_state_dict=self.scheduler_guidance.state_dict(),
215
  step=step,
 
218
  if not os.path.exists(self.checkpoint_path):
219
  os.makedirs(self.checkpoint_path)
220
  if last:
221
+ self.accelerator.save(
222
+ checkpoint, f"{self.checkpoint_path}/model_last.pt"
223
+ )
224
  print(f"Saved last checkpoint at step {step}")
225
  else:
226
+ self.accelerator.save(
227
+ checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
228
+ )
229
 
230
  def load_checkpoint(self):
231
  if (
 
244
  key=lambda x: int("".join(filter(str.isdigit, x))),
245
  )[-1]
246
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
247
+ checkpoint = torch.load(
248
+ f"{self.checkpoint_path}/{latest_checkpoint}",
249
+ weights_only=True,
250
+ map_location="cpu",
251
+ )
252
 
253
+ self.accelerator.unwrap_model(self.model).load_state_dict(
254
+ checkpoint["model_state_dict"], strict=False
255
+ )
256
  # self.accelerator.unwrap_model(self.optimizer_generator).load_state_dict(checkpoint["optimizer_generator_state_dict"])
257
  # self.accelerator.unwrap_model(self.optimizer_guidance).load_state_dict(checkpoint["optimizer_guidance_state_dict"])
258
  # if self.scheduler_guidance:
 
264
  del checkpoint
265
  gc.collect()
266
  return step
 
267
 
268
+ def train(
269
+ self,
270
+ train_dataset: Dataset,
271
+ num_workers=64,
272
+ resumable_with_seed: int = None,
273
+ vocoder: nn.Module = None,
274
+ ):
275
  if exists(resumable_with_seed):
276
  generator = torch.Generator()
277
  generator.manual_seed(resumable_with_seed)
 
293
  self.accelerator.even_batches = False
294
  sampler = SequentialSampler(train_dataset)
295
  batch_sampler = DynamicBatchSampler(
296
+ sampler,
297
+ self.batch_size,
298
+ max_samples=self.max_samples,
299
+ random_seed=resumable_with_seed,
300
+ drop_last=False,
301
  )
302
  train_dataloader = DataLoader(
303
  train_dataset,
 
308
  batch_sampler=batch_sampler,
309
  )
310
  else:
311
+ raise ValueError(
312
+ f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
313
+ )
314
 
315
  # accelerator.prepare() dispatches batches to devices;
316
  # which means the length of dataloader calculated before, should consider the number of devices
317
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes
318
+
 
 
319
  # consider a fixed warmup steps while using accelerate multi-gpu ddp
320
  # otherwise by default with split_batches=False, warmup steps change with num_processes
321
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
322
  decay_steps = total_steps - warmup_steps
 
 
 
 
323
 
324
+ warmup_scheduler_generator = LinearLR(
325
+ self.optimizer_generator,
326
+ start_factor=1e-8,
327
+ end_factor=1.0,
328
+ total_iters=warmup_steps
329
+ // (self.gen_update_ratio * self.grad_accumulation_steps),
330
+ )
331
+ decay_scheduler_generator = LinearLR(
332
+ self.optimizer_generator,
333
+ start_factor=1.0,
334
+ end_factor=1e-8,
335
+ total_iters=decay_steps
336
+ // (self.gen_update_ratio * self.grad_accumulation_steps),
337
+ )
338
+ self.scheduler_generator = SequentialLR(
339
+ self.optimizer_generator,
340
+ schedulers=[warmup_scheduler_generator, decay_scheduler_generator],
341
+ milestones=[
342
+ warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps)
343
+ ],
344
+ )
345
+
346
+ warmup_scheduler_guidance = LinearLR(
347
+ self.optimizer_guidance,
348
+ start_factor=1e-8,
349
+ end_factor=1.0,
350
+ total_iters=warmup_steps,
351
+ )
352
+ decay_scheduler_guidance = LinearLR(
353
+ self.optimizer_guidance,
354
+ start_factor=1.0,
355
+ end_factor=1e-8,
356
+ total_iters=decay_steps,
357
+ )
358
+ self.scheduler_guidance = SequentialLR(
359
+ self.optimizer_guidance,
360
+ schedulers=[warmup_scheduler_guidance, decay_scheduler_guidance],
361
+ milestones=[warmup_steps],
362
+ )
363
 
364
+ train_dataloader, self.scheduler_generator, self.scheduler_guidance = (
365
+ self.accelerator.prepare(
366
+ train_dataloader, self.scheduler_generator, self.scheduler_guidance
367
+ )
368
  ) # actual steps = 1 gpu steps / gpus
369
  start_step = self.load_checkpoint()
370
  global_step = start_step
 
373
  orig_epoch_step = len(train_dataloader)
374
  skipped_epoch = int(start_step // orig_epoch_step)
375
  skipped_batch = start_step % orig_epoch_step
376
+ skipped_dataloader = self.accelerator.skip_first_batches(
377
+ train_dataloader, num_batches=skipped_batch
378
+ )
379
  else:
380
  skipped_epoch = 0
381
 
 
400
 
401
  for batch in progress_bar:
402
  update_generator = global_step % self.gen_update_ratio == 0
403
+
404
  with self.accelerator.accumulate(self.model):
405
  metrics = {}
406
  text_inputs = batch["text"]
407
  mel_spec = batch["mel"].permute(0, 2, 1)
408
  mel_lengths = batch["mel_lengths"]
409
+
410
  mel_spec = mel_spec / self.scale
411
+
412
+ guidance_loss_dict, guidance_log_dict = self.model(
413
+ inp=mel_spec,
414
+ text=text_inputs,
415
+ lens=mel_lengths,
416
+ student_steps=self.student_steps,
417
+ update_generator=False,
418
+ use_simulated=global_step >= self.num_simu,
419
+ )
420
 
421
  # if self.GAN and update_generator:
422
  # # only add discriminator loss if GAN is enabled and generator is being updated
423
  # guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (self.lambda_discriminator_loss if global_step >= self.num_GAN and update_generator else 0)
424
  # metrics['loss/discriminator_loss'] = guidance_loss_dict["guidance_cls_loss"]
425
  # self.accelerator.backward(guidance_cls_loss, retain_graph=True)
426
+
427
  # if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
428
  # metrics['grad_norm_guidance'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
429
 
430
  guidance_loss = 0
431
  guidance_loss += guidance_loss_dict["loss_fake_mean"]
432
+ metrics["loss/fake_score"] = guidance_loss_dict["loss_fake_mean"]
433
  metrics["loss/guidance_loss"] = guidance_loss
434
 
435
  if self.GAN and update_generator:
436
  # only add discriminator loss if GAN is enabled and generator is being updated
437
+ guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (
438
+ self.lambda_discriminator_loss
439
+ if global_step >= self.num_GAN and update_generator
440
+ else 0
441
+ )
442
+ metrics["loss/discriminator_loss"] = guidance_loss_dict[
443
+ "guidance_cls_loss"
444
+ ]
445
 
446
  guidance_loss += guidance_cls_loss
447
+
448
  self.accelerator.backward(guidance_loss)
449
 
450
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
451
+ metrics["grad_norm_guidance"] = (
452
+ self.accelerator.clip_grad_norm_(
453
+ self.model.parameters(), self.max_grad_norm
454
+ )
455
+ )
456
 
457
  # if self.guidance_norm.count < 100:
458
  # self.guidance_norm.update(metrics['grad_norm_guidance'])
 
464
  # elif self.guidance_norm.count >= 100:
465
  # self.guidance_norm.update(metrics['grad_norm_guidance'])
466
 
 
467
  self.optimizer_guidance.step()
468
  self.scheduler_guidance.step()
469
  self.optimizer_guidance.zero_grad()
470
  self.optimizer_generator.zero_grad() # zero out the generator's gradient as well
471
+
472
  if update_generator:
473
+ generator_loss_dict, generator_log_dict = self.model(
474
+ inp=mel_spec,
475
+ text=text_inputs,
476
+ lens=mel_lengths,
477
+ student_steps=self.student_steps,
478
+ update_generator=True,
479
+ use_simulated=global_step >= self.num_ctc,
480
+ )
481
  # if self.GAN:
482
  # gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) and update_generator else 0)
483
  # metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"]
 
490
  generator_loss = 0
491
  generator_loss += generator_loss_dict["loss_dm"]
492
  if "loss_mse" in generator_loss_dict:
493
+ generator_loss += generator_loss_dict["loss_mse"]
494
+ generator_loss += generator_loss_dict["loss_ctc"] * (
495
+ self.lambda_ctc_loss if global_step >= self.num_ctc else 0
496
+ )
497
+ generator_loss += generator_loss_dict["loss_sim"] * (
498
+ self.lambda_sim_loss if global_step >= self.num_sim else 0
499
+ )
500
+ generator_loss += generator_loss_dict["loss_kl"] * (
501
+ self.lambda_ctc_loss if global_step >= self.num_ctc else 0
502
+ )
503
  if self.GAN:
504
+ gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (
505
+ self.lambda_generator_loss
506
+ if global_step >= (self.num_GAN + self.num_D)
507
+ and update_generator
508
+ else 0
509
+ )
510
+ metrics["loss/gen_cls_loss"] = generator_loss_dict[
511
+ "gen_cls_loss"
512
+ ]
513
  generator_loss += gen_cls_loss
514
 
515
+ metrics["loss/dm_loss"] = generator_loss_dict["loss_dm"]
516
+ metrics["loss/ctc_loss"] = generator_loss_dict["loss_ctc"]
517
+
518
+ metrics["loss/similarity_loss"] = generator_loss_dict[
519
+ "loss_sim"
520
+ ]
521
+ metrics["loss/generator_loss"] = generator_loss
522
+
523
+ if (
524
+ "loss_mse" in generator_loss_dict
525
+ and generator_loss_dict["loss_mse"] != 0
526
+ ):
527
+ metrics["loss/mse_loss"] = generator_loss_dict["loss_mse"]
528
+ if (
529
+ "loss_kl" in generator_loss_dict
530
+ and generator_loss_dict["loss_kl"] != 0
531
+ ):
532
+ metrics["loss/kl_loss"] = generator_loss_dict["loss_kl"]
533
 
534
  self.accelerator.backward(generator_loss)
535
 
536
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
537
+ metrics["grad_norm_generator"] = (
538
+ self.accelerator.clip_grad_norm_(
539
+ self.model.parameters(), self.max_grad_norm
540
+ )
541
+ )
542
  # self.generator_norm.update(metrics['grad_norm_generator'])
543
+
544
  # if metrics['grad_norm_generator'] > self.generator_norm.mean + 15 * self.generator_norm.std:
545
  # self.optimizer_generator.zero_grad()
546
  # self.optimizer_guidance.zero_grad()
 
553
  self.optimizer_generator.zero_grad()
554
  self.optimizer_guidance.zero_grad() # zero out the guidance's gradient as well
555
 
 
556
  global_step += 1
557
 
558
  if self.accelerator.is_local_main_process:
559
+ self.accelerator.log(
560
+ {
561
+ **metrics,
562
+ "lr_generator": self.scheduler_generator.get_last_lr()[0],
563
+ "lr_guidance": self.scheduler_guidance.get_last_lr()[0],
564
+ },
565
+ step=global_step,
566
+ )
567
+
568
+ if (
569
+ global_step % self.log_step == 0
570
+ and self.accelerator.is_local_main_process
571
+ and vocoder is not None
572
+ ):
573
  # log the first batch of the epoch
574
  with torch.no_grad():
575
+ generator_input = (
576
+ generator_log_dict["generator_input"][0]
577
+ .unsqueeze(0)
578
+ .permute(0, 2, 1)
579
+ * self.scale
580
+ )
581
  generator_input = vocoder.decode(generator_input.float().cpu())
582
  generator_input = wandb.Audio(
583
  generator_input.float().numpy().squeeze(),
584
  sample_rate=24000,
585
+ caption="time: "
586
+ + str(generator_log_dict["time"][0].float().cpu().numpy()),
587
  )
588
 
589
+ generator_output = (
590
+ generator_log_dict["generator_output"][0]
591
+ .unsqueeze(0)
592
+ .permute(0, 2, 1)
593
+ * self.scale
594
+ )
595
+ generator_output = vocoder.decode(
596
+ generator_output.float().cpu()
597
+ )
598
  generator_output = wandb.Audio(
599
  generator_output.float().numpy().squeeze(),
600
  sample_rate=24000,
601
+ caption="time: "
602
+ + str(generator_log_dict["time"][0].float().cpu().numpy()),
603
+ )
604
+
605
+ generator_cond = (
606
+ generator_log_dict["generator_cond"][0]
607
+ .unsqueeze(0)
608
+ .permute(0, 2, 1)
609
+ * self.scale
610
  )
 
 
611
  generator_cond = vocoder.decode(generator_cond.float().cpu())
612
  generator_cond = wandb.Audio(
613
  generator_cond.float().numpy().squeeze(),
614
  sample_rate=24000,
615
+ caption="time: "
616
+ + str(generator_log_dict["time"][0].float().cpu().numpy()),
617
+ )
618
+
619
+ ground_truth = (
620
+ generator_log_dict["ground_truth"][0]
621
+ .unsqueeze(0)
622
+ .permute(0, 2, 1)
623
+ * self.scale
624
  )
 
 
625
  ground_truth = vocoder.decode(ground_truth.float().cpu())
626
  ground_truth = wandb.Audio(
627
  ground_truth.float().numpy().squeeze(),
628
  sample_rate=24000,
629
+ caption="time: "
630
+ + str(generator_log_dict["time"][0].float().cpu().numpy()),
631
+ )
632
+
633
+ dmtrain_noisy_inp = (
634
+ generator_log_dict["dmtrain_noisy_inp"][0]
635
+ .unsqueeze(0)
636
+ .permute(0, 2, 1)
637
+ * self.scale
638
+ )
639
+ dmtrain_noisy_inp = vocoder.decode(
640
+ dmtrain_noisy_inp.float().cpu()
641
  )
 
 
 
642
  dmtrain_noisy_inp = wandb.Audio(
643
  dmtrain_noisy_inp.float().numpy().squeeze(),
644
  sample_rate=24000,
645
+ caption="dmtrain_time: "
646
+ + str(
647
+ generator_log_dict["dmtrain_time"][0]
648
+ .float()
649
+ .cpu()
650
+ .numpy()
651
+ ),
652
+ )
653
+
654
+ dmtrain_pred_real_image = (
655
+ generator_log_dict["dmtrain_pred_real_image"][0]
656
+ .unsqueeze(0)
657
+ .permute(0, 2, 1)
658
+ * self.scale
659
+ )
660
+ dmtrain_pred_real_image = vocoder.decode(
661
+ dmtrain_pred_real_image.float().cpu()
662
  )
 
 
 
663
  dmtrain_pred_real_image = wandb.Audio(
664
  dmtrain_pred_real_image.float().numpy().squeeze(),
665
  sample_rate=24000,
666
+ caption="dmtrain_time: "
667
+ + str(
668
+ generator_log_dict["dmtrain_time"][0]
669
+ .float()
670
+ .cpu()
671
+ .numpy()
672
+ ),
673
+ )
674
+
675
+ dmtrain_pred_fake_image = (
676
+ generator_log_dict["dmtrain_pred_fake_image"][0]
677
+ .unsqueeze(0)
678
+ .permute(0, 2, 1)
679
+ * self.scale
680
+ )
681
+ dmtrain_pred_fake_image = vocoder.decode(
682
+ dmtrain_pred_fake_image.float().cpu()
683
  )
 
 
 
684
  dmtrain_pred_fake_image = wandb.Audio(
685
  dmtrain_pred_fake_image.float().numpy().squeeze(),
686
  sample_rate=24000,
687
+ caption="dmtrain_time: "
688
+ + str(
689
+ generator_log_dict["dmtrain_time"][0]
690
+ .float()
691
+ .cpu()
692
+ .numpy()
693
+ ),
694
+ )
695
+
696
+ self.accelerator.log(
697
+ {
698
+ "noisy_input": generator_input,
699
+ "output": generator_output,
700
+ "cond": generator_cond,
701
+ "ground_truth": ground_truth,
702
+ "dmtrain_noisy_inp": dmtrain_noisy_inp,
703
+ "dmtrain_pred_real_image": dmtrain_pred_real_image,
704
+ "dmtrain_pred_fake_image": dmtrain_pred_fake_image,
705
+ },
706
+ step=global_step,
707
  )
 
 
 
 
 
 
 
 
 
 
 
708
 
709
  progress_bar.set_postfix(step=str(global_step), metrics=metrics)
710
 
711
+ if (
712
+ global_step % (self.save_per_updates * self.grad_accumulation_steps)
713
+ == 0
714
+ ):
715
  self.save_checkpoint(global_step)
716
 
717
  if global_step % self.last_per_steps == 0:
 
720
  self.save_checkpoint(global_step, last=True)
721
 
722
  self.accelerator.end_training()
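The warmup-then-decay schedule wired up in train() can be reproduced standalone. The sketch below uses a dummy parameter and step counts and mirrors only the LinearLR + SequentialLR wiring; the trainer additionally divides the generator scheduler's total_iters by gen_update_ratio * grad_accumulation_steps because that scheduler is stepped less often.

# Standalone sketch of the warmup -> decay LR schedule (dummy optimizer and counts).
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

opt = AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
warmup_steps, decay_steps = 100, 900

warmup = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(opt, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
sched = SequentialLR(opt, schedulers=[warmup, decay], milestones=[warmup_steps])

for _ in range(warmup_steps + decay_steps):
    opt.step()                 # optimizer step first, then scheduler step
    sched.step()
print(sched.get_last_lr())     # decays toward 1e-4 * 1e-8 by the end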
 
 
duration_predictor.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn

# from tts_encode import tts_encode

def calculate_remaining_lengths(mel_lengths):
    B = mel_lengths.shape[0]
    max_L = mel_lengths.max().item() # Get the maximum length in the batch
@@ -21,64 +22,84 @@ class PositionalEncoding(nn.Module):
        super().__init__()
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-       div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / hidden_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0) # Shape: (1, max_len, hidden_dim)

    def forward(self, x):
-       x = x + self.pe[:, :x.size(1)].to(x.device)
        return x


class SpeechLengthPredictor(nn.Module):

-   def __init__(self,
-       vocab_size=2545, n_mel=100, hidden_dim=256,
-       n_text_layer=4, n_cross_layer=4, n_head=8,
        output_dim=1,
    ):
        super().__init__()
-
        # Text Encoder: Embedding + Transformer Layers
-       self.text_embedder = nn.Embedding(vocab_size+1, hidden_dim, padding_idx=vocab_size)
        self.text_pe = PositionalEncoding(hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
-           d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim*2, batch_first=True
        )
-       self.text_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_text_layer)
-
        # Mel Spectrogram Embedder
        self.mel_embedder = nn.Linear(n_mel, hidden_dim)
        self.mel_pe = PositionalEncoding(hidden_dim)

        # Transformer Decoder Layers with Cross-Attention in Every Layer
        decoder_layer = nn.TransformerDecoderLayer(
-           d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim*2, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer)
-
        # Final Classification Layer
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def forward(self, text_ids, mel):
        # Encode text
        text_embedded = self.text_pe(self.text_embedder(text_ids))
-       text_features = self.text_encoder(text_embedded) # (B, L_text, D)
-
        # Encode Mel spectrogram
        mel_features = self.mel_pe(self.mel_embedder(mel)) # (B, L_mel, D)
-
        # Causal Masking for Decoder
        seq_len = mel_features.size(1)
-       causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(mel.device)
        # causal_mask = torch.triu(
        #     torch.full((seq_len, seq_len), float('-inf'), device=mel.device), diagonal=1
        # )

        # Transformer Decoder with Cross-Attention in Each Layer
        decoder_out = self.decoder(mel_features, text_features, tgt_mask=causal_mask)
-
        # Length Prediction
        length_logits = self.predictor(decoder_out).squeeze(-1)
        return length_logits

3
 
4
  # from tts_encode import tts_encode
5
 
6
+
7
  def calculate_remaining_lengths(mel_lengths):
8
  B = mel_lengths.shape[0]
9
  max_L = mel_lengths.max().item() # Get the maximum length in the batch
 
22
  super().__init__()
23
  pe = torch.zeros(max_len, hidden_dim)
24
  position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
25
+ div_term = torch.exp(
26
+ torch.arange(0, hidden_dim, 2).float()
27
+ * (-torch.log(torch.tensor(10000.0)) / hidden_dim)
28
+ )
29
  pe[:, 0::2] = torch.sin(position * div_term)
30
  pe[:, 1::2] = torch.cos(position * div_term)
31
  self.pe = pe.unsqueeze(0) # Shape: (1, max_len, hidden_dim)
32
 
33
  def forward(self, x):
34
+ x = x + self.pe[:, : x.size(1)].to(x.device)
35
  return x
36
 
37
 
38
  class SpeechLengthPredictor(nn.Module):
39
 
40
+ def __init__(
41
+ self,
42
+ vocab_size=2545,
43
+ n_mel=100,
44
+ hidden_dim=256,
45
+ n_text_layer=4,
46
+ n_cross_layer=4,
47
+ n_head=8,
48
  output_dim=1,
49
  ):
50
  super().__init__()
51
+
52
  # Text Encoder: Embedding + Transformer Layers
53
+ self.text_embedder = nn.Embedding(
54
+ vocab_size + 1, hidden_dim, padding_idx=vocab_size
55
+ )
56
  self.text_pe = PositionalEncoding(hidden_dim)
57
  encoder_layer = nn.TransformerEncoderLayer(
58
+ d_model=hidden_dim,
59
+ nhead=n_head,
60
+ dim_feedforward=hidden_dim * 2,
61
+ batch_first=True,
62
+ )
63
+ self.text_encoder = nn.TransformerEncoder(
64
+ encoder_layer, num_layers=n_text_layer
65
  )
66
+
 
67
  # Mel Spectrogram Embedder
68
  self.mel_embedder = nn.Linear(n_mel, hidden_dim)
69
  self.mel_pe = PositionalEncoding(hidden_dim)
70
 
71
  # Transformer Decoder Layers with Cross-Attention in Every Layer
72
  decoder_layer = nn.TransformerDecoderLayer(
73
+ d_model=hidden_dim,
74
+ nhead=n_head,
75
+ dim_feedforward=hidden_dim * 2,
76
+ batch_first=True,
77
  )
78
  self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer)
79
+
80
  # Final Classification Layer
81
  self.predictor = nn.Linear(hidden_dim, output_dim)
82
 
83
  def forward(self, text_ids, mel):
84
  # Encode text
85
  text_embedded = self.text_pe(self.text_embedder(text_ids))
86
+ text_features = self.text_encoder(text_embedded) # (B, L_text, D)
87
+
88
  # Encode Mel spectrogram
89
  mel_features = self.mel_pe(self.mel_embedder(mel)) # (B, L_mel, D)
90
+
91
  # Causal Masking for Decoder
92
  seq_len = mel_features.size(1)
93
+ causal_mask = (
94
+ torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(mel.device)
95
+ )
96
  # causal_mask = torch.triu(
97
  # torch.full((seq_len, seq_len), float('-inf'), device=mel.device), diagonal=1
98
  # )
99
 
100
  # Transformer Decoder with Cross-Attention in Each Layer
101
  decoder_out = self.decoder(mel_features, text_features, tgt_mask=causal_mask)
102
+
103
  # Length Prediction
104
  length_logits = self.predictor(decoder_out).squeeze(-1)
105
  return length_logits
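A quick shape check for the predictor defined above. This is only a sketch: it assumes the class is importable from duration_predictor in this repo and that the default PositionalEncoding max_len covers the dummy sequence lengths used here.

# Shape check for SpeechLengthPredictor (sketch; constructor defaults assumed sufficient).
import torch
from duration_predictor import SpeechLengthPredictor

model = SpeechLengthPredictor()                 # vocab_size=2545, n_mel=100, hidden_dim=256 by default
text_ids = torch.randint(0, 2545, (2, 12))      # (B, L_text)
mel = torch.randn(2, 50, 100)                   # (B, L_mel, n_mel)
length_logits = model(text_ids, mel)            # causal self-attention over mel, cross-attention to text
print(length_logits.shape)                      # torch.Size([2, 50]) with output_dim=1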
duration_trainer.py CHANGED
@@ -1,11 +1,11 @@
1
  from __future__ import annotations
2
 
3
  import gc
4
- import os
5
-
6
  import math
 
7
 
8
  import torch
 
9
  import torchaudio
10
  import wandb
11
  from accelerate import Accelerator
@@ -13,37 +13,28 @@ from accelerate.utils import DistributedDataParallelKwargs
13
  from ema_pytorch import EMA
14
  from torch.optim import AdamW
15
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
16
- from torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset # <-- Added Subset import
 
17
  from tqdm import tqdm
18
 
19
- import torch.nn.functional as F
20
-
21
- from f5_tts.model import CFM
22
- from f5_tts.model.dataset import collate_fn, DynamicBatchSampler
23
- from f5_tts.model.utils import default, exists
24
-
25
  from duration_predictor import calculate_remaining_lengths
 
 
 
 
26
 
27
  # trainer
28
 
29
- from f5_tts.model.utils import (
30
- default,
31
- exists,
32
- list_str_to_idx,
33
- list_str_to_tensor,
34
- lens_to_mask,
35
- mask_from_frac_lengths,
36
- )
37
 
38
  SAMPLE_RATE = 24_000
39
 
40
 
41
  def masked_l1_loss(est_lengths, tar_lengths):
42
- first_zero_idx = (tar_lengths == 0).int().argmax(dim=1)
43
  B, L = tar_lengths.shape
44
- range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L)
45
  mask = range_tensor <= first_zero_idx[:, None] # Include the first 0
46
- loss = F.l1_loss(est_lengths, tar_lengths, reduction='none') # (B, L)
47
  loss = loss * mask # Zero out ignored positions
48
  loss = loss.sum() / mask.sum() # Normalize by valid elements
49
  return loss
@@ -55,9 +46,9 @@ def masked_cross_entropy_loss(est_length_logits, tar_length_labels):
55
  range_tensor = torch.arange(L, device=tar_length_labels.device).expand(B, L)
56
  mask = range_tensor <= first_zero_idx[:, None] # Include the first 0
57
  loss = F.cross_entropy(
58
- est_length_logits.reshape(-1, est_length_logits.size(-1)),
59
- tar_length_labels.reshape(-1),
60
- reduction='none'
61
  ).reshape(B, L)
62
  loss = loss * mask
63
  loss = loss.sum() / mask.sum()
@@ -71,7 +62,7 @@ class Trainer:
71
  vocab_size,
72
  vocab_char_map,
73
  process_token_to_id=True,
74
- loss_fn='L1',
75
  lambda_L1=1,
76
  gumbel_tau=0.5,
77
  n_class=301,
@@ -110,7 +101,13 @@ class Trainer:
110
  self.logger = logger
111
  if self.logger == "wandb":
112
  if exists(wandb_resume_id):
113
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
 
 
 
 
 
 
114
  else:
115
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
116
 
@@ -139,7 +136,7 @@ class Trainer:
139
  self.vocab_size = vocab_size
140
  self.vocab_char_map = vocab_char_map
141
  self.process_token_to_id = process_token_to_id
142
- assert loss_fn in ['L1', 'CE', 'L1_and_CE']
143
  self.loss_fn = loss_fn
144
  self.lambda_L1 = lambda_L1
145
  self.n_class = n_class
@@ -149,7 +146,9 @@ class Trainer:
149
  self.epochs = epochs
150
  self.num_warmup_updates = num_warmup_updates
151
  self.save_per_updates = save_per_updates
152
- self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
 
 
153
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
154
 
155
  self.batch_size = batch_size
@@ -164,33 +163,44 @@ class Trainer:
164
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
165
  else:
166
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
167
- self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
168
-
 
 
169
  @property
170
  def is_main(self):
171
  return self.accelerator.is_main_process
172
 
173
  def save_checkpoint(self, step, last=False):
174
  self.accelerator.wait_for_everyone()
175
- if self.is_main:
176
  checkpoint = dict(
177
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
178
- optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
 
 
179
  scheduler_state_dict=self.scheduler.state_dict(),
180
  step=step,
181
  )
182
  if not os.path.exists(self.checkpoint_path):
183
  os.makedirs(self.checkpoint_path)
184
  if last:
185
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
 
 
186
  else:
187
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
 
 
188
 
189
  def load_checkpoint(self):
190
  if (
191
  not exists(self.checkpoint_path)
192
  or not os.path.exists(self.checkpoint_path)
193
- or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
 
 
 
194
  ):
195
  return 0
196
 
@@ -203,21 +213,32 @@ class Trainer:
203
  key=lambda x: int("".join(filter(str.isdigit, x))),
204
  )[-1]
205
 
206
- print(f'To load from {latest_checkpoint}.')
207
 
208
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
209
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
 
 
 
210
 
211
- print(f'Loaded from {latest_checkpoint}.')
212
 
213
  if "step" in checkpoint:
214
  # patch for backward compatibility, 305e3ea
215
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
 
 
 
216
  if key in checkpoint["model_state_dict"]:
217
  del checkpoint["model_state_dict"][key]
218
 
219
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
220
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
 
 
 
 
221
  if self.scheduler:
222
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
223
  step = checkpoint["step"]
@@ -227,17 +248,18 @@ class Trainer:
227
  for k, v in checkpoint["ema_model_state_dict"].items()
228
  if k not in ["initted", "step"]
229
  }
230
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
 
 
231
  step = 0
232
-
233
  del checkpoint
234
  gc.collect()
235
 
236
- print(f'Exit load_checkpoint.')
237
 
238
  return step
239
 
240
-
241
  def validate(self, valid_dataloader, global_step):
242
  """
243
  Runs evaluation on the validation set, computes the average loss,
@@ -251,54 +273,61 @@ class Trainer:
251
  with torch.no_grad():
252
  for batch in valid_dataloader:
253
  # Inputs
254
- mel = batch['mel'].permute(0, 2, 1) # (B, L_mel, D)
255
- text = batch['text']
256
 
257
  if self.process_token_to_id:
258
  text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device)
259
- text_ids = text_ids.masked_fill(text_ids==-1, self.vocab_size)
260
  else:
261
  text_ids = text
262
 
263
  # Targets
264
- mel_lengths = batch['mel_lengths']
265
  tar_lengths = calculate_remaining_lengths(mel_lengths)
266
  predictions = self.model(text_ids=text_ids, mel=mel)
267
 
268
- if self.loss_fn == 'L1':
269
  est_lengths = predictions
270
  loss = masked_l1_loss(
271
  est_lengths=est_lengths, tar_lengths=tar_lengths
272
  )
273
  frame_error = loss
274
 
275
- elif self.loss_fn == 'CE':
276
- tar_length_labels = (tar_lengths // self.n_frame_per_class) \
277
- .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1]
 
278
  est_length_logtis = predictions
279
  est_length_labels = torch.argmax(est_length_logtis, dim=-1)
280
  loss = masked_cross_entropy_loss(
281
- est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels
 
282
  )
283
  est_lengths = est_length_labels * self.n_frame_per_class
284
  frame_error = masked_l1_loss(
285
  est_lengths=est_lengths, tar_lengths=tar_lengths
286
  )
287
 
288
- elif self.loss_fn == 'L1_and_CE':
289
- tar_length_labels = (tar_lengths // self.n_frame_per_class) \
290
- .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1]
 
291
  est_length_logtis = predictions
292
  est_length_1hots = F.gumbel_softmax(
293
  est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1
294
  )
295
- length_values = torch.arange(
296
- self.n_class, device=est_length_1hots.device
297
- ).float() * self.n_frame_per_class
 
 
 
298
  est_lengths = (est_length_1hots * length_values).sum(-1)
299
 
300
  loss_CE = masked_cross_entropy_loss(
301
- est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels
 
302
  )
303
 
304
  loss_L1 = masked_l1_loss(
@@ -321,18 +350,19 @@ class Trainer:
321
  avg_valid_sec_error = total_sec_error / count if count > 0 else 0.0
322
  # Log validation metrics
323
  self.accelerator.log(
324
- {
325
- f"valid_loss": avg_valid_loss,
326
- f"valid_sec_error": avg_valid_sec_error
327
- },
328
- step=global_step
329
  )
330
-
331
- self.model.train()
332
 
 
333
 
334
- def train(self, train_dataset: Dataset, valid_dataset: Dataset,
335
- num_workers=64, resumable_with_seed: int = None):
 
 
 
 
 
336
  if exists(resumable_with_seed):
337
  generator = torch.Generator()
338
  generator.manual_seed(resumable_with_seed)
@@ -366,7 +396,11 @@ class Trainer:
366
 
367
  sampler = SequentialSampler(train_dataset)
368
  batch_sampler = DynamicBatchSampler(
369
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
 
 
 
 
370
  )
371
  train_dataloader = DataLoader(
372
  train_dataset,
@@ -379,20 +413,26 @@ class Trainer:
379
 
380
  sampler = SequentialSampler(valid_dataset)
381
  batch_sampler = DynamicBatchSampler(
382
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
 
 
 
 
383
  )
384
  # Create validation dataloader (always sequential, no shuffling)
385
  valid_dataloader = DataLoader(
386
  valid_dataset,
387
  collate_fn=collate_fn,
388
  num_workers=num_workers,
389
- pin_memory=True,
390
  persistent_workers=True,
391
  batch_sampler=batch_sampler,
392
  )
393
  else:
394
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
395
-
 
 
396
  # accelerator.prepare() dispatches batches to devices;
397
  # which means the length of dataloader calculated before, should consider the number of devices
398
  warmup_steps = (
@@ -401,10 +441,16 @@ class Trainer:
401
  # otherwise by default with split_batches=False, warmup steps change with num_processes
402
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
403
  decay_steps = total_steps - warmup_steps
404
- warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
405
- decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
 
 
 
 
406
  self.scheduler = SequentialLR(
407
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
 
 
408
  )
409
  train_dataloader, self.scheduler = self.accelerator.prepare(
410
  train_dataloader, self.scheduler
@@ -418,7 +464,9 @@ class Trainer:
418
  orig_epoch_step = len(train_dataloader)
419
  skipped_epoch = int(start_step // orig_epoch_step)
420
  skipped_batch = start_step % orig_epoch_step
421
- skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
 
 
422
  else:
423
  skipped_epoch = 0
424
 
@@ -444,21 +492,23 @@ class Trainer:
444
  for batch in progress_bar:
445
  with self.accelerator.accumulate(self.model):
446
  # Inputs
447
- mel = batch['mel'].permute(0, 2, 1) # (B, L_mel, D)
448
- text = batch['text']
449
 
450
  if self.process_token_to_id:
451
- text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device)
452
- text_ids = text_ids.masked_fill(text_ids==-1, self.vocab_size)
 
 
453
  else:
454
  text_ids = text
455
 
456
  # Targets
457
- mel_lengths = batch['mel_lengths']
458
  tar_lengths = calculate_remaining_lengths(mel_lengths)
459
  predictions = self.model(text_ids=text_ids, mel=mel)
460
 
461
- if self.loss_fn == 'L1':
462
  est_lengths = predictions
463
  loss = masked_l1_loss(
464
  est_lengths=est_lengths, tar_lengths=tar_lengths
@@ -469,19 +519,23 @@ class Trainer:
469
  sec_error = frame_error * 256 / 24000
470
 
471
  log_dict = {
472
- 'loss': loss.item(),
473
- 'loss_L1': loss.item(),
474
- 'sec_error': sec_error.item(),
475
- 'lr': self.scheduler.get_last_lr()[0]
476
- }
477
-
478
- elif self.loss_fn == 'CE':
479
- tar_length_labels = (tar_lengths // self.n_frame_per_class) \
480
- .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1]
 
 
 
481
  est_length_logtis = predictions
482
  est_length_labels = torch.argmax(est_length_logtis, dim=-1)
483
  loss = masked_cross_entropy_loss(
484
- est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels
 
485
  )
486
  with torch.no_grad():
487
  est_lengths = est_length_labels * self.n_frame_per_class
@@ -491,29 +545,36 @@ class Trainer:
491
  sec_error = frame_error * 256 / 24000
492
 
493
  log_dict = {
494
- 'loss': loss.item(),
495
- 'loss_CE': loss.item(),
496
- 'sec_error': sec_error.item(),
497
- 'lr': self.scheduler.get_last_lr()[0]
498
- }
499
-
500
- elif self.loss_fn == 'L1_and_CE':
501
- tar_length_labels = (tar_lengths // self.n_frame_per_class) \
502
- .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1]
 
 
 
503
  est_length_logtis = predictions
504
  est_length_1hots = F.gumbel_softmax(
505
  est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1
506
  )
507
- length_values = torch.arange(
508
- self.n_class, device=est_length_1hots.device
509
- ).float() * self.n_frame_per_class
 
 
 
510
  est_lengths = (est_length_1hots * length_values).sum(-1)
511
 
512
  loss_CE = masked_cross_entropy_loss(
513
- est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels
 
514
  )
515
 
516
- loss_L1 = masked_l1_loss(
517
  est_lengths=est_lengths, tar_lengths=tar_lengths
518
  )
519
 
@@ -524,21 +585,22 @@ class Trainer:
524
  sec_error = frame_error * 256 / 24000
525
 
526
  log_dict = {
527
- 'loss': loss.item(),
528
- 'loss_L1': loss_L1.item(),
529
- 'loss_CE': loss_CE.item(),
530
- 'sec_error': sec_error.item(),
531
- 'lr': self.scheduler.get_last_lr()[0]
532
  }
533
 
534
  else:
535
  raise NotImplementedError(self.loss_fn)
536
 
537
-
538
  self.accelerator.backward(loss)
539
 
540
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
541
- self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
 
 
542
 
543
  self.optimizer.step()
544
  self.scheduler.step()
@@ -550,7 +612,10 @@ class Trainer:
550
  self.accelerator.log(log_dict, step=global_step)
551
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
552
 
553
- if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
 
 
 
554
  self.save_checkpoint(global_step)
555
  # if self.log_samples and self.accelerator.is_local_main_process:
556
  # Run validation at the end of each epoch (only on the main process)
 
1
  from __future__ import annotations
2
 
3
  import gc
 
 
4
  import math
5
+ import os
6
 
7
  import torch
8
+ import torch.nn.functional as F
9
  import torchaudio
10
  import wandb
11
  from accelerate import Accelerator
 
13
  from ema_pytorch import EMA
14
  from torch.optim import AdamW
15
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
16
+ from torch.utils.data import Dataset # <-- Added Subset import
17
+ from torch.utils.data import DataLoader, SequentialSampler, Subset
18
  from tqdm import tqdm
19
 
 
 
 
 
 
 
20
  from duration_predictor import calculate_remaining_lengths
21
+ from f5_tts.model import CFM
22
+ from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
23
+ from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
24
+ list_str_to_tensor, mask_from_frac_lengths)
25
 
26
  # trainer
27
 
 
 
 
 
 
 
 
 
28
 
29
  SAMPLE_RATE = 24_000
30
 
31
 
32
  def masked_l1_loss(est_lengths, tar_lengths):
33
+ first_zero_idx = (tar_lengths == 0).int().argmax(dim=1)
34
  B, L = tar_lengths.shape
35
+ range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L)
36
  mask = range_tensor <= first_zero_idx[:, None] # Include the first 0
37
+ loss = F.l1_loss(est_lengths, tar_lengths, reduction="none") # (B, L)
38
  loss = loss * mask # Zero out ignored positions
39
  loss = loss.sum() / mask.sum() # Normalize by valid elements
40
  return loss
 
46
  range_tensor = torch.arange(L, device=tar_length_labels.device).expand(B, L)
47
  mask = range_tensor <= first_zero_idx[:, None] # Include the first 0
48
  loss = F.cross_entropy(
49
+ est_length_logits.reshape(-1, est_length_logits.size(-1)),
50
+ tar_length_labels.reshape(-1),
51
+ reduction="none",
52
  ).reshape(B, L)
53
  loss = loss * mask
54
  loss = loss.sum() / mask.sum()
 
62
  vocab_size,
63
  vocab_char_map,
64
  process_token_to_id=True,
65
+ loss_fn="L1",
66
  lambda_L1=1,
67
  gumbel_tau=0.5,
68
  n_class=301,
 
101
  self.logger = logger
102
  if self.logger == "wandb":
103
  if exists(wandb_resume_id):
104
+ init_kwargs = {
105
+ "wandb": {
106
+ "resume": "allow",
107
+ "name": wandb_run_name,
108
+ "id": wandb_resume_id,
109
+ }
110
+ }
111
  else:
112
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
113
 
 
136
  self.vocab_size = vocab_size
137
  self.vocab_char_map = vocab_char_map
138
  self.process_token_to_id = process_token_to_id
139
+ assert loss_fn in ["L1", "CE", "L1_and_CE"]
140
  self.loss_fn = loss_fn
141
  self.lambda_L1 = lambda_L1
142
  self.n_class = n_class
 
146
  self.epochs = epochs
147
  self.num_warmup_updates = num_warmup_updates
148
  self.save_per_updates = save_per_updates
149
+ self.last_per_steps = default(
150
+ last_per_steps, save_per_updates * grad_accumulation_steps
151
+ )
152
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
153
 
154
  self.batch_size = batch_size
 
163
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
164
  else:
165
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
166
+ self.model, self.optimizer = self.accelerator.prepare(
167
+ self.model, self.optimizer
168
+ )
169
+
170
  @property
171
  def is_main(self):
172
  return self.accelerator.is_main_process
173
 
174
  def save_checkpoint(self, step, last=False):
175
  self.accelerator.wait_for_everyone()
176
+ if self.is_main:
177
  checkpoint = dict(
178
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
179
+ optimizer_state_dict=self.accelerator.unwrap_model(
180
+ self.optimizer
181
+ ).state_dict(),
182
  scheduler_state_dict=self.scheduler.state_dict(),
183
  step=step,
184
  )
185
  if not os.path.exists(self.checkpoint_path):
186
  os.makedirs(self.checkpoint_path)
187
  if last:
188
+ self.accelerator.save(
189
+ checkpoint, f"{self.checkpoint_path}/model_last.pt"
190
+ )
191
  else:
192
+ self.accelerator.save(
193
+ checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
194
+ )
195
 
196
  def load_checkpoint(self):
197
  if (
198
  not exists(self.checkpoint_path)
199
  or not os.path.exists(self.checkpoint_path)
200
+ or not any(
201
+ filename.endswith(".pt")
202
+ for filename in os.listdir(self.checkpoint_path)
203
+ )
204
  ):
205
  return 0
206
 
 
213
  key=lambda x: int("".join(filter(str.isdigit, x))),
214
  )[-1]
215
 
216
+ print(f"To load from {latest_checkpoint}.")
217
 
218
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
219
+ checkpoint = torch.load(
220
+ f"{self.checkpoint_path}/{latest_checkpoint}",
221
+ weights_only=True,
222
+ map_location="cpu",
223
+ )
224
 
225
+ print(f"Loaded from {latest_checkpoint}.")
226
 
227
  if "step" in checkpoint:
228
  # patch for backward compatibility, 305e3ea
229
+ for key in [
230
+ "mel_spec.mel_stft.mel_scale.fb",
231
+ "mel_spec.mel_stft.spectrogram.window",
232
+ ]:
233
  if key in checkpoint["model_state_dict"]:
234
  del checkpoint["model_state_dict"][key]
235
 
236
+ self.accelerator.unwrap_model(self.model).load_state_dict(
237
+ checkpoint["model_state_dict"]
238
+ )
239
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(
240
+ checkpoint["optimizer_state_dict"]
241
+ )
242
  if self.scheduler:
243
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
244
  step = checkpoint["step"]
 
248
  for k, v in checkpoint["ema_model_state_dict"].items()
249
  if k not in ["initted", "step"]
250
  }
251
+ self.accelerator.unwrap_model(self.model).load_state_dict(
252
+ checkpoint["model_state_dict"]
253
+ )
254
  step = 0
255
+
256
  del checkpoint
257
  gc.collect()
258
 
259
+ print(f"Exit load_checkpoint.")
260
 
261
  return step
262
 
 
263
  def validate(self, valid_dataloader, global_step):
264
  """
265
  Runs evaluation on the validation set, computes the average loss,
 
273
  with torch.no_grad():
274
  for batch in valid_dataloader:
275
  # Inputs
276
+ mel = batch["mel"].permute(0, 2, 1) # (B, L_mel, D)
277
+ text = batch["text"]
278
 
279
  if self.process_token_to_id:
280
  text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device)
281
+ text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size)
282
  else:
283
  text_ids = text
284
 
285
  # Targets
286
+ mel_lengths = batch["mel_lengths"]
287
  tar_lengths = calculate_remaining_lengths(mel_lengths)
288
  predictions = self.model(text_ids=text_ids, mel=mel)
289
 
290
+ if self.loss_fn == "L1":
291
  est_lengths = predictions
292
  loss = masked_l1_loss(
293
  est_lengths=est_lengths, tar_lengths=tar_lengths
294
  )
295
  frame_error = loss
296
 
297
+ elif self.loss_fn == "CE":
298
+ tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp(
299
+ min=0, max=self.n_class - 1
300
+ ) # [0, 1, ..., n_class-1]
301
  est_length_logtis = predictions
302
  est_length_labels = torch.argmax(est_length_logtis, dim=-1)
303
  loss = masked_cross_entropy_loss(
304
+ est_length_logits=est_length_logtis,
305
+ tar_length_labels=tar_length_labels,
306
  )
307
  est_lengths = est_length_labels * self.n_frame_per_class
308
  frame_error = masked_l1_loss(
309
  est_lengths=est_lengths, tar_lengths=tar_lengths
310
  )
311
 
312
+ elif self.loss_fn == "L1_and_CE":
313
+ tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp(
314
+ min=0, max=self.n_class - 1
315
+ ) # [0, 1, ..., n_class-1]
316
  est_length_logtis = predictions
317
  est_length_1hots = F.gumbel_softmax(
318
  est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1
319
  )
320
+ length_values = (
321
+ torch.arange(
322
+ self.n_class, device=est_length_1hots.device
323
+ ).float()
324
+ * self.n_frame_per_class
325
+ )
326
  est_lengths = (est_length_1hots * length_values).sum(-1)
327
 
328
  loss_CE = masked_cross_entropy_loss(
329
+ est_length_logits=est_length_logtis,
330
+ tar_length_labels=tar_length_labels,
331
  )
332
 
333
  loss_L1 = masked_l1_loss(
 
350
  avg_valid_sec_error = total_sec_error / count if count > 0 else 0.0
351
  # Log validation metrics
352
  self.accelerator.log(
353
+ {f"valid_loss": avg_valid_loss, f"valid_sec_error": avg_valid_sec_error},
354
+ step=global_step,
 
 
 
355
  )
 
 
356
 
357
+ self.model.train()
358
 
359
+ def train(
360
+ self,
361
+ train_dataset: Dataset,
362
+ valid_dataset: Dataset,
363
+ num_workers=64,
364
+ resumable_with_seed: int = None,
365
+ ):
366
  if exists(resumable_with_seed):
367
  generator = torch.Generator()
368
  generator.manual_seed(resumable_with_seed)
 
396
 
397
  sampler = SequentialSampler(train_dataset)
398
  batch_sampler = DynamicBatchSampler(
399
+ sampler,
400
+ self.batch_size,
401
+ max_samples=self.max_samples,
402
+ random_seed=resumable_with_seed,
403
+ drop_last=False,
404
  )
405
  train_dataloader = DataLoader(
406
  train_dataset,
 
413
 
414
  sampler = SequentialSampler(valid_dataset)
415
  batch_sampler = DynamicBatchSampler(
416
+ sampler,
417
+ self.batch_size,
418
+ max_samples=self.max_samples,
419
+ random_seed=resumable_with_seed,
420
+ drop_last=False,
421
  )
422
  # Create validation dataloader (always sequential, no shuffling)
423
  valid_dataloader = DataLoader(
424
  valid_dataset,
425
  collate_fn=collate_fn,
426
  num_workers=num_workers,
427
+ pin_memory=True,
428
  persistent_workers=True,
429
  batch_sampler=batch_sampler,
430
  )
431
  else:
432
+ raise ValueError(
433
+ f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
434
+ )
435
+
436
  # accelerator.prepare() dispatches batches to devices;
437
  # which means the length of dataloader calculated before, should consider the number of devices
438
  warmup_steps = (
 
441
  # otherwise by default with split_batches=False, warmup steps change with num_processes
442
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
443
  decay_steps = total_steps - warmup_steps
444
+ warmup_scheduler = LinearLR(
445
+ self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps
446
+ )
447
+ decay_scheduler = LinearLR(
448
+ self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps
449
+ )
450
  self.scheduler = SequentialLR(
451
+ self.optimizer,
452
+ schedulers=[warmup_scheduler, decay_scheduler],
453
+ milestones=[warmup_steps],
454
  )
455
  train_dataloader, self.scheduler = self.accelerator.prepare(
456
  train_dataloader, self.scheduler
 
464
  orig_epoch_step = len(train_dataloader)
465
  skipped_epoch = int(start_step // orig_epoch_step)
466
  skipped_batch = start_step % orig_epoch_step
467
+ skipped_dataloader = self.accelerator.skip_first_batches(
468
+ train_dataloader, num_batches=skipped_batch
469
+ )
470
  else:
471
  skipped_epoch = 0
472
 
 
492
  for batch in progress_bar:
493
  with self.accelerator.accumulate(self.model):
494
  # Inputs
495
+ mel = batch["mel"].permute(0, 2, 1) # (B, L_mel, D)
496
+ text = batch["text"]
497
 
498
  if self.process_token_to_id:
499
+ text_ids = list_str_to_idx(text, self.vocab_char_map).to(
500
+ mel.device
501
+ )
502
+ text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size)
503
  else:
504
  text_ids = text
505
 
506
  # Targets
507
+ mel_lengths = batch["mel_lengths"]
508
  tar_lengths = calculate_remaining_lengths(mel_lengths)
509
  predictions = self.model(text_ids=text_ids, mel=mel)
510
 
511
+ if self.loss_fn == "L1":
512
  est_lengths = predictions
513
  loss = masked_l1_loss(
514
  est_lengths=est_lengths, tar_lengths=tar_lengths
 
519
  sec_error = frame_error * 256 / 24000
520
 
521
  log_dict = {
522
+ "loss": loss.item(),
523
+ "loss_L1": loss.item(),
524
+ "sec_error": sec_error.item(),
525
+ "lr": self.scheduler.get_last_lr()[0],
526
+ }
527
+
528
+ elif self.loss_fn == "CE":
529
+ tar_length_labels = (
530
+ tar_lengths // self.n_frame_per_class
531
+ ).clamp(
532
+ min=0, max=self.n_class - 1
533
+ ) # [0, 1, ..., n_class-1]
534
  est_length_logtis = predictions
535
  est_length_labels = torch.argmax(est_length_logtis, dim=-1)
536
  loss = masked_cross_entropy_loss(
537
+ est_length_logits=est_length_logtis,
538
+ tar_length_labels=tar_length_labels,
539
  )
540
  with torch.no_grad():
541
  est_lengths = est_length_labels * self.n_frame_per_class
 
545
  sec_error = frame_error * 256 / 24000
546
 
547
  log_dict = {
548
+ "loss": loss.item(),
549
+ "loss_CE": loss.item(),
550
+ "sec_error": sec_error.item(),
551
+ "lr": self.scheduler.get_last_lr()[0],
552
+ }
553
+
554
+ elif self.loss_fn == "L1_and_CE":
555
+ tar_length_labels = (
556
+ tar_lengths // self.n_frame_per_class
557
+ ).clamp(
558
+ min=0, max=self.n_class - 1
559
+ ) # [0, 1, ..., n_class-1]
560
  est_length_logtis = predictions
561
  est_length_1hots = F.gumbel_softmax(
562
  est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1
563
  )
564
+ length_values = (
565
+ torch.arange(
566
+ self.n_class, device=est_length_1hots.device
567
+ ).float()
568
+ * self.n_frame_per_class
569
+ )
570
  est_lengths = (est_length_1hots * length_values).sum(-1)
571
 
572
  loss_CE = masked_cross_entropy_loss(
573
+ est_length_logits=est_length_logtis,
574
+ tar_length_labels=tar_length_labels,
575
  )
576
 
577
+ loss_L1 = masked_l1_loss(
578
  est_lengths=est_lengths, tar_lengths=tar_lengths
579
  )
580
 
 
585
  sec_error = frame_error * 256 / 24000
586
 
587
  log_dict = {
588
+ "loss": loss.item(),
589
+ "loss_L1": loss_L1.item(),
590
+ "loss_CE": loss_CE.item(),
591
+ "sec_error": sec_error.item(),
592
+ "lr": self.scheduler.get_last_lr()[0],
593
  }
594
 
595
  else:
596
  raise NotImplementedError(self.loss_fn)
597
 
 
598
  self.accelerator.backward(loss)
599
 
600
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
601
+ self.accelerator.clip_grad_norm_(
602
+ self.model.parameters(), self.max_grad_norm
603
+ )
604
 
605
  self.optimizer.step()
606
  self.scheduler.step()
 
612
  self.accelerator.log(log_dict, step=global_step)
613
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
614
 
615
+ if (
616
+ global_step % (self.save_per_updates * self.grad_accumulation_steps)
617
+ == 0
618
+ ):
619
  self.save_checkpoint(global_step)
620
  # if self.log_samples and self.accelerator.is_local_main_process:
621
  # Run validation at the end of each epoch (only on the main process)
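
Note on the L1_and_CE branch in the duration_trainer.py hunks above: the length-class logits are turned into a differentiable frame-count estimate with straight-through Gumbel-softmax before the L1 term is applied. A minimal standalone sketch of that step follows; n_frame_per_class and the batch size are illustrative values, not taken from the project config (n_class=301 and tau=0.5 mirror the trainer defaults).

import torch
import torch.nn.functional as F

n_class, n_frame_per_class, tau = 301, 10, 0.5                     # 10 frames/class is an assumed value
logits = torch.randn(4, n_class, requires_grad=True)               # (B, n_class) length-class logits
one_hots = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)    # straight-through one-hot samples
length_values = torch.arange(n_class).float() * n_frame_per_class  # frame count represented by each class
est_lengths = (one_hots * length_values).sum(-1)                   # (B,) differentiable frame estimates
est_lengths.sum().backward()                                       # gradients flow back to the logits
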
duration_trainer_with_prompt.py CHANGED
@@ -1,11 +1,11 @@
1
  from __future__ import annotations
2
 
3
  import gc
4
- import os
5
-
6
  import math
 
7
 
8
  import torch
 
9
  import torchaudio
10
  import wandb
11
  from accelerate import Accelerator
@@ -13,25 +13,17 @@ from accelerate.utils import DistributedDataParallelKwargs
13
  from ema_pytorch import EMA
14
  from torch.optim import AdamW
15
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
16
- from torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset # <-- Added Subset import
 
17
  from tqdm import tqdm
18
 
19
- import torch.nn.functional as F
20
-
21
  from f5_tts.model import CFM
22
- from f5_tts.model.dataset import collate_fn, DynamicBatchSampler
23
- from f5_tts.model.utils import default, exists
 
24
 
25
  # trainer
26
 
27
- from f5_tts.model.utils import (
28
- default,
29
- exists,
30
- list_str_to_idx,
31
- list_str_to_tensor,
32
- lens_to_mask,
33
- mask_from_frac_lengths,
34
- )
35
 
36
  SAMPLE_RATE = 24_000
37
 
@@ -43,7 +35,7 @@ class Trainer:
43
  vocab_size,
44
  vocab_char_map,
45
  process_token_to_id=True,
46
- loss_fn='L1',
47
  lambda_L1=1,
48
  gumbel_tau=0.5,
49
  n_class=301,
@@ -83,7 +75,13 @@ class Trainer:
83
  self.logger = logger
84
  if self.logger == "wandb":
85
  if exists(wandb_resume_id):
86
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
 
 
 
 
 
 
87
  else:
88
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
89
 
@@ -112,7 +110,7 @@ class Trainer:
112
  self.vocab_size = vocab_size
113
  self.vocab_char_map = vocab_char_map
114
  self.process_token_to_id = process_token_to_id
115
- assert loss_fn in ['L1', 'CE', 'L1_and_CE']
116
  self.loss_fn = loss_fn
117
  self.lambda_L1 = lambda_L1
118
  self.n_class = n_class
@@ -122,7 +120,9 @@ class Trainer:
122
  self.epochs = epochs
123
  self.num_warmup_updates = num_warmup_updates
124
  self.save_per_updates = save_per_updates
125
- self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
 
 
126
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
127
 
128
  self.batch_size = batch_size
@@ -137,33 +137,44 @@ class Trainer:
137
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
138
  else:
139
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
140
- self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
141
-
 
 
142
  @property
143
  def is_main(self):
144
  return self.accelerator.is_main_process
145
 
146
  def save_checkpoint(self, step, last=False):
147
  self.accelerator.wait_for_everyone()
148
- if self.is_main:
149
  checkpoint = dict(
150
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
151
- optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
 
 
152
  scheduler_state_dict=self.scheduler.state_dict(),
153
  step=step,
154
  )
155
  if not os.path.exists(self.checkpoint_path):
156
  os.makedirs(self.checkpoint_path)
157
  if last:
158
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
 
 
159
  else:
160
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
 
 
161
 
162
  def load_checkpoint(self):
163
  if (
164
  not exists(self.checkpoint_path)
165
  or not os.path.exists(self.checkpoint_path)
166
- or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
 
 
 
167
  ):
168
  return 0
169
 
@@ -176,21 +187,32 @@ class Trainer:
176
  key=lambda x: int("".join(filter(str.isdigit, x))),
177
  )[-1]
178
 
179
- print(f'To load from {latest_checkpoint}.')
180
 
181
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
182
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
 
 
 
183
 
184
- print(f'Loaded from {latest_checkpoint}.')
185
 
186
  if "step" in checkpoint:
187
  # patch for backward compatibility, 305e3ea
188
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
 
 
 
189
  if key in checkpoint["model_state_dict"]:
190
  del checkpoint["model_state_dict"][key]
191
 
192
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
193
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
 
 
 
 
194
  if self.scheduler:
195
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
196
  step = checkpoint["step"]
@@ -200,17 +222,18 @@ class Trainer:
200
  for k, v in checkpoint["ema_model_state_dict"].items()
201
  if k not in ["initted", "step"]
202
  }
203
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
 
 
204
  step = 0
205
-
206
  del checkpoint
207
  gc.collect()
208
 
209
- print(f'Exit load_checkpoint.')
210
 
211
  return step
212
 
213
-
214
  def validate(self, valid_dataloader, global_step):
215
  """
216
  Runs evaluation on the validation set, computes the average loss,
@@ -226,31 +249,40 @@ class Trainer:
226
  for batch in valid_dataloader:
227
 
228
  # Inputs
229
- prompt_mel = batch['pmt_mel_specs'].permute(0, 2, 1) # (B, L_mel, D)
230
- prompt_text = batch['pmt_text']
231
- text = batch['text']
232
 
233
- target_ids = list_str_to_idx(text, self.vocab_char_map).to(prompt_mel.device)
234
- target_ids = target_ids.masked_fill(target_ids==-1, vocab_size)
 
 
235
 
236
- prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(prompt_mel.device)
237
- prompt_ids = prompt_ids.masked_fill(prompt_ids==-1, vocab_size)
 
 
238
 
239
  # Targets
240
- tar_lengths = batch['mel_lengths']
241
 
242
  # Forward
243
- predictions = SLP(target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel) # (B, C)
244
-
245
- if self.loss_fn == 'CE':
246
- tar_length_labels = (tar_lengths // self.n_frame_per_class) \
247
- .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1]
 
 
 
248
  est_length_logtis = predictions
249
  est_length_labels = torch.argmax(est_length_logtis, dim=-1)
250
  loss = F.cross_entropy(est_length_logtis, tar_length_labels)
251
-
252
  est_lengths = est_length_labels * self.n_frame_per_class
253
- frame_error = (est_lengths.float() - tar_lengths.float()).abs().mean()
 
 
254
  sec_error = frame_error * 256 / 24000
255
 
256
  total_sec_error += sec_error.item()
@@ -262,18 +294,19 @@ class Trainer:
262
 
263
  # Log validation metrics
264
  self.accelerator.log(
265
- {
266
- f"valid_loss": avg_valid_loss,
267
- f"valid_sec_error": avg_valid_sec_error
268
- },
269
- step=global_step
270
  )
271
-
272
- self.model.train()
273
 
 
274
 
275
- def train(self, train_dataset: Dataset, valid_dataset: Dataset,
276
- num_workers=64, resumable_with_seed: int = None):
 
 
 
 
 
277
  if exists(resumable_with_seed):
278
  generator = torch.Generator()
279
  generator.manual_seed(resumable_with_seed)
@@ -307,7 +340,11 @@ class Trainer:
307
 
308
  sampler = SequentialSampler(train_dataset)
309
  batch_sampler = DynamicBatchSampler(
310
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
 
 
 
 
311
  )
312
  train_dataloader = DataLoader(
313
  train_dataset,
@@ -320,20 +357,26 @@ class Trainer:
320
 
321
  sampler = SequentialSampler(valid_dataset)
322
  batch_sampler = DynamicBatchSampler(
323
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
 
 
 
 
324
  )
325
  # Create validation dataloader (always sequential, no shuffling)
326
  valid_dataloader = DataLoader(
327
  valid_dataset,
328
  collate_fn=collate_fn,
329
  num_workers=num_workers,
330
- pin_memory=True,
331
  persistent_workers=True,
332
  batch_sampler=batch_sampler,
333
  )
334
  else:
335
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
336
-
 
 
337
  # accelerator.prepare() dispatches batches to devices;
338
  # which means the length of dataloader calculated before, should consider the number of devices
339
  warmup_steps = (
@@ -342,10 +385,16 @@ class Trainer:
342
  # otherwise by default with split_batches=False, warmup steps change with num_processes
343
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
344
  decay_steps = total_steps - warmup_steps
345
- warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
346
- decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
 
 
 
 
347
  self.scheduler = SequentialLR(
348
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
 
 
349
  )
350
  train_dataloader, self.scheduler = self.accelerator.prepare(
351
  train_dataloader, self.scheduler
@@ -359,7 +408,9 @@ class Trainer:
359
  orig_epoch_step = len(train_dataloader)
360
  skipped_epoch = int(start_step // orig_epoch_step)
361
  skipped_batch = start_step % orig_epoch_step
362
- skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
 
 
363
  else:
364
  skipped_epoch = 0
365
 
@@ -385,49 +436,65 @@ class Trainer:
385
  for batch in progress_bar:
386
  with self.accelerator.accumulate(self.model):
387
  # Inputs
388
- prompt_mel = batch['pmt_mel_specs'].permute(0, 2, 1) # (B, L_mel, D)
389
- prompt_text = batch['pmt_text']
390
- text = batch['text']
391
-
392
- target_ids = list_str_to_idx(text, self.vocab_char_map).to(prompt_mel.device)
393
- target_ids = target_ids.masked_fill(target_ids==-1, vocab_size)
394
-
395
- prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(prompt_mel.device)
396
- prompt_ids = prompt_ids.masked_fill(prompt_ids==-1, vocab_size)
 
 
 
 
 
 
397
 
398
  # Targets
399
- tar_lengths = batch['mel_lengths']
400
 
401
  # Forward
402
- predictions = SLP(target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel) # (B, C)
403
-
404
- if self.loss_fn == 'CE':
405
- tar_length_labels = (tar_lengths // self.n_frame_per_class) \
406
- .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1]
 
 
 
 
 
 
 
407
  est_length_logtis = predictions
408
  est_length_labels = torch.argmax(est_length_logtis, dim=-1)
409
  loss = F.cross_entropy(est_length_logtis, tar_length_labels)
410
-
411
  with torch.no_grad():
412
  est_lengths = est_length_labels * self.n_frame_per_class
413
- frame_error = (est_lengths.float() - tar_lengths.float()).abs().mean()
 
 
414
  sec_error = frame_error * 256 / 24000
415
 
416
  log_dict = {
417
- 'loss': loss.item(),
418
- 'loss_CE': loss.item(),
419
- 'sec_error': sec_error.item(),
420
- 'lr': self.scheduler.get_last_lr()[0]
421
- }
422
 
423
  else:
424
  raise NotImplementedError(self.loss_fn)
425
 
426
-
427
  self.accelerator.backward(loss)
428
 
429
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
430
- self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
 
 
431
 
432
  self.optimizer.step()
433
  self.scheduler.step()
@@ -439,7 +506,10 @@ class Trainer:
439
  self.accelerator.log(log_dict, step=global_step)
440
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
441
 
442
- if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
 
 
 
443
  self.save_checkpoint(global_step)
444
  # if self.log_samples and self.accelerator.is_local_main_process:
445
  # Run validation at the end of each epoch (only on the main process)
 
1
  from __future__ import annotations
2
 
3
  import gc
 
 
4
  import math
5
+ import os
6
 
7
  import torch
8
+ import torch.nn.functional as F
9
  import torchaudio
10
  import wandb
11
  from accelerate import Accelerator
 
13
  from ema_pytorch import EMA
14
  from torch.optim import AdamW
15
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
16
+ from torch.utils.data import Dataset
17
+ from torch.utils.data import DataLoader, SequentialSampler, Subset  # <-- Added Subset import
18
  from tqdm import tqdm
19
 
 
 
20
  from f5_tts.model import CFM
21
+ from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
22
+ from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
23
+ list_str_to_tensor, mask_from_frac_lengths)
24
 
25
  # trainer
26
 
 
 
 
 
 
 
 
 
27
 
28
  SAMPLE_RATE = 24_000
29
 
 
35
  vocab_size,
36
  vocab_char_map,
37
  process_token_to_id=True,
38
+ loss_fn="L1",
39
  lambda_L1=1,
40
  gumbel_tau=0.5,
41
  n_class=301,
 
75
  self.logger = logger
76
  if self.logger == "wandb":
77
  if exists(wandb_resume_id):
78
+ init_kwargs = {
79
+ "wandb": {
80
+ "resume": "allow",
81
+ "name": wandb_run_name,
82
+ "id": wandb_resume_id,
83
+ }
84
+ }
85
  else:
86
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
87
 
 
110
  self.vocab_size = vocab_size
111
  self.vocab_char_map = vocab_char_map
112
  self.process_token_to_id = process_token_to_id
113
+ assert loss_fn in ["L1", "CE", "L1_and_CE"]
114
  self.loss_fn = loss_fn
115
  self.lambda_L1 = lambda_L1
116
  self.n_class = n_class
 
120
  self.epochs = epochs
121
  self.num_warmup_updates = num_warmup_updates
122
  self.save_per_updates = save_per_updates
123
+ self.last_per_steps = default(
124
+ last_per_steps, save_per_updates * grad_accumulation_steps
125
+ )
126
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
127
 
128
  self.batch_size = batch_size
 
137
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
138
  else:
139
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
140
+ self.model, self.optimizer = self.accelerator.prepare(
141
+ self.model, self.optimizer
142
+ )
143
+
144
  @property
145
  def is_main(self):
146
  return self.accelerator.is_main_process
147
 
148
  def save_checkpoint(self, step, last=False):
149
  self.accelerator.wait_for_everyone()
150
+ if self.is_main:
151
  checkpoint = dict(
152
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
153
+ optimizer_state_dict=self.accelerator.unwrap_model(
154
+ self.optimizer
155
+ ).state_dict(),
156
  scheduler_state_dict=self.scheduler.state_dict(),
157
  step=step,
158
  )
159
  if not os.path.exists(self.checkpoint_path):
160
  os.makedirs(self.checkpoint_path)
161
  if last:
162
+ self.accelerator.save(
163
+ checkpoint, f"{self.checkpoint_path}/model_last.pt"
164
+ )
165
  else:
166
+ self.accelerator.save(
167
+ checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
168
+ )
169
 
170
  def load_checkpoint(self):
171
  if (
172
  not exists(self.checkpoint_path)
173
  or not os.path.exists(self.checkpoint_path)
174
+ or not any(
175
+ filename.endswith(".pt")
176
+ for filename in os.listdir(self.checkpoint_path)
177
+ )
178
  ):
179
  return 0
180
 
 
187
  key=lambda x: int("".join(filter(str.isdigit, x))),
188
  )[-1]
189
 
190
+ print(f"Loading from {latest_checkpoint}.")
191
 
192
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
193
+ checkpoint = torch.load(
194
+ f"{self.checkpoint_path}/{latest_checkpoint}",
195
+ weights_only=True,
196
+ map_location="cpu",
197
+ )
198
 
199
+ print(f"Loaded from {latest_checkpoint}.")
200
 
201
  if "step" in checkpoint:
202
  # patch for backward compatibility, 305e3ea
203
+ for key in [
204
+ "mel_spec.mel_stft.mel_scale.fb",
205
+ "mel_spec.mel_stft.spectrogram.window",
206
+ ]:
207
  if key in checkpoint["model_state_dict"]:
208
  del checkpoint["model_state_dict"][key]
209
 
210
+ self.accelerator.unwrap_model(self.model).load_state_dict(
211
+ checkpoint["model_state_dict"]
212
+ )
213
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(
214
+ checkpoint["optimizer_state_dict"]
215
+ )
216
  if self.scheduler:
217
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
218
  step = checkpoint["step"]
 
222
  for k, v in checkpoint["ema_model_state_dict"].items()
223
  if k not in ["initted", "step"]
224
  }
225
+ self.accelerator.unwrap_model(self.model).load_state_dict(
226
+ checkpoint["model_state_dict"]
227
+ )
228
  step = 0
229
+
230
  del checkpoint
231
  gc.collect()
232
 
233
+ print("Exit load_checkpoint.")
234
 
235
  return step
236
 
 
237
  def validate(self, valid_dataloader, global_step):
238
  """
239
  Runs evaluation on the validation set, computes the average loss,
 
249
  for batch in valid_dataloader:
250
 
251
  # Inputs
252
+ prompt_mel = batch["pmt_mel_specs"].permute(0, 2, 1) # (B, L_mel, D)
253
+ prompt_text = batch["pmt_text"]
254
+ text = batch["text"]
255
 
256
+ target_ids = list_str_to_idx(text, self.vocab_char_map).to(
257
+ prompt_mel.device
258
+ )
259
+ target_ids = target_ids.masked_fill(target_ids == -1, self.vocab_size)
260
 
261
+ prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(
262
+ prompt_mel.device
263
+ )
264
+ prompt_ids = prompt_ids.masked_fill(prompt_ids == -1, self.vocab_size)
265
 
266
  # Targets
267
+ tar_lengths = batch["mel_lengths"]
268
 
269
  # Forward
270
+ predictions = self.model(
271
+ target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel
272
+ ) # (B, C)
273
+
274
+ if self.loss_fn == "CE":
275
+ tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp(
276
+ min=0, max=self.n_class - 1
277
+ ) # [0, 1, ..., n_class-1]
278
  est_length_logtis = predictions
279
  est_length_labels = torch.argmax(est_length_logtis, dim=-1)
280
  loss = F.cross_entropy(est_length_logtis, tar_length_labels)
281
+
282
  est_lengths = est_length_labels * self.n_frame_per_class
283
+ frame_error = (
284
+ (est_lengths.float() - tar_lengths.float()).abs().mean()
285
+ )
286
  sec_error = frame_error * 256 / 24000
287
 
288
  total_sec_error += sec_error.item()
 
294
 
295
  # Log validation metrics
296
  self.accelerator.log(
297
+ {"valid_loss": avg_valid_loss, "valid_sec_error": avg_valid_sec_error},
298
+ step=global_step,
 
 
 
299
  )
 
 
300
 
301
+ self.model.train()
302
 
303
+ def train(
304
+ self,
305
+ train_dataset: Dataset,
306
+ valid_dataset: Dataset,
307
+ num_workers=64,
308
+ resumable_with_seed: int = None,
309
+ ):
310
  if exists(resumable_with_seed):
311
  generator = torch.Generator()
312
  generator.manual_seed(resumable_with_seed)
 
340
 
341
  sampler = SequentialSampler(train_dataset)
342
  batch_sampler = DynamicBatchSampler(
343
+ sampler,
344
+ self.batch_size,
345
+ max_samples=self.max_samples,
346
+ random_seed=resumable_with_seed,
347
+ drop_last=False,
348
  )
349
  train_dataloader = DataLoader(
350
  train_dataset,
 
357
 
358
  sampler = SequentialSampler(valid_dataset)
359
  batch_sampler = DynamicBatchSampler(
360
+ sampler,
361
+ self.batch_size,
362
+ max_samples=self.max_samples,
363
+ random_seed=resumable_with_seed,
364
+ drop_last=False,
365
  )
366
  # Create validation dataloader (always sequential, no shuffling)
367
  valid_dataloader = DataLoader(
368
  valid_dataset,
369
  collate_fn=collate_fn,
370
  num_workers=num_workers,
371
+ pin_memory=True,
372
  persistent_workers=True,
373
  batch_sampler=batch_sampler,
374
  )
375
  else:
376
+ raise ValueError(
377
+ f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
378
+ )
379
+
380
  # accelerator.prepare() dispatches batches to devices;
381
  # which means the dataloader length computed above must account for the number of devices
382
  warmup_steps = (
 
385
  # otherwise by default with split_batches=False, warmup steps change with num_processes
386
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
387
  decay_steps = total_steps - warmup_steps
388
+ warmup_scheduler = LinearLR(
389
+ self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps
390
+ )
391
+ decay_scheduler = LinearLR(
392
+ self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps
393
+ )
394
  self.scheduler = SequentialLR(
395
+ self.optimizer,
396
+ schedulers=[warmup_scheduler, decay_scheduler],
397
+ milestones=[warmup_steps],
398
  )
399
  train_dataloader, self.scheduler = self.accelerator.prepare(
400
  train_dataloader, self.scheduler
 
408
  orig_epoch_step = len(train_dataloader)
409
  skipped_epoch = int(start_step // orig_epoch_step)
410
  skipped_batch = start_step % orig_epoch_step
411
+ skipped_dataloader = self.accelerator.skip_first_batches(
412
+ train_dataloader, num_batches=skipped_batch
413
+ )
414
  else:
415
  skipped_epoch = 0
416
 
 
436
  for batch in progress_bar:
437
  with self.accelerator.accumulate(self.model):
438
  # Inputs
439
+ prompt_mel = batch["pmt_mel_specs"].permute(
440
+ 0, 2, 1
441
+ ) # (B, L_mel, D)
442
+ prompt_text = batch["pmt_text"]
443
+ text = batch["text"]
444
+
445
+ target_ids = list_str_to_idx(text, self.vocab_char_map).to(
446
+ prompt_mel.device
447
+ )
448
+ target_ids = target_ids.masked_fill(target_ids == -1, self.vocab_size)
449
+
450
+ prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(
451
+ prompt_mel.device
452
+ )
453
+ prompt_ids = prompt_ids.masked_fill(prompt_ids == -1, self.vocab_size)
454
 
455
  # Targets
456
+ tar_lengths = batch["mel_lengths"]
457
 
458
  # Forward
459
+ predictions = self.model(
460
+ target_ids=target_ids,
461
+ prompt_ids=prompt_ids,
462
+ prompt_mel=prompt_mel,
463
+ ) # (B, C)
464
+
465
+ if self.loss_fn == "CE":
466
+ tar_length_labels = (
467
+ tar_lengths // self.n_frame_per_class
468
+ ).clamp(
469
+ min=0, max=self.n_class - 1
470
+ ) # [0, 1, ..., n_class-1]
471
  est_length_logtis = predictions
472
  est_length_labels = torch.argmax(est_length_logtis, dim=-1)
473
  loss = F.cross_entropy(est_length_logtis, tar_length_labels)
474
+
475
  with torch.no_grad():
476
  est_lengths = est_length_labels * self.n_frame_per_class
477
+ frame_error = (
478
+ (est_lengths.float() - tar_lengths.float()).abs().mean()
479
+ )
480
  sec_error = frame_error * 256 / 24000
481
 
482
  log_dict = {
483
+ "loss": loss.item(),
484
+ "loss_CE": loss.item(),
485
+ "sec_error": sec_error.item(),
486
+ "lr": self.scheduler.get_last_lr()[0],
487
+ }
488
 
489
  else:
490
  raise NotImplementedError(self.loss_fn)
491
 
 
492
  self.accelerator.backward(loss)
493
 
494
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
495
+ self.accelerator.clip_grad_norm_(
496
+ self.model.parameters(), self.max_grad_norm
497
+ )
498
 
499
  self.optimizer.step()
500
  self.scheduler.step()
 
506
  self.accelerator.log(log_dict, step=global_step)
507
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
508
 
509
+ if (
510
+ global_step % (self.save_per_updates * self.grad_accumulation_steps)
511
+ == 0
512
+ ):
513
  self.save_checkpoint(global_step)
514
  # if self.log_samples and self.accelerator.is_local_main_process:
515
  # Run validation at the end of each epoch (only on the main process)
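
Both duration trainers assemble the same warmup-then-decay learning-rate schedule from two LinearLR segments chained with SequentialLR, stepping the scheduler once per optimizer update. The sketch below reproduces that construction in isolation; the toy model and the step counts are placeholders, not values from the repository.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(8, 8)                        # stand-in module
optimizer = AdamW(model.parameters(), lr=1e-4)
warmup_steps, decay_steps = 1000, 9000               # placeholder step counts
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])
for _ in range(5):                                   # one scheduler.step() per optimizer update
    optimizer.step()
    scheduler.step()
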
ecapa_tdnn.py CHANGED
@@ -1,23 +1,34 @@
1
  # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
2
 
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  import torchaudio.transforms as trans
 
7
  from ctcmodel import ConformerCTC
8
- # from ctcmodel_nopool import ConformerCTC as ConformerCTCNoPool
9
- from pathlib import Path
10
 
11
- ''' Res2Conv1d + BatchNorm1d + ReLU
12
- '''
13
 
14
 
15
  class Res2Conv1dReluBn(nn.Module):
16
- '''
17
  in_channels == out_channels == channels
18
- '''
19
-
20
- def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
 
 
 
 
 
 
 
 
 
21
  super().__init__()
22
  assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
23
  self.scale = scale
@@ -27,7 +38,17 @@ class Res2Conv1dReluBn(nn.Module):
27
  self.convs = []
28
  self.bns = []
29
  for i in range(self.nums):
30
- self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
 
 
 
 
 
 
 
 
 
 
31
  self.bns.append(nn.BatchNorm1d(self.width))
32
  self.convs = nn.ModuleList(self.convs)
33
  self.bns = nn.ModuleList(self.bns)
@@ -51,22 +72,33 @@ class Res2Conv1dReluBn(nn.Module):
51
  return out
52
 
53
 
54
- ''' Conv1d + BatchNorm1d + ReLU
55
- '''
56
 
57
 
58
  class Conv1dReluBn(nn.Module):
59
- def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
 
 
 
 
 
 
 
 
 
60
  super().__init__()
61
- self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
 
 
62
  self.bn = nn.BatchNorm1d(out_channels)
63
 
64
  def forward(self, x):
65
  return self.bn(F.relu(self.conv(x)))
66
 
67
 
68
- ''' The SE connection of 1D case.
69
- '''
70
 
71
 
72
  class SE_Connect(nn.Module):
@@ -84,15 +116,32 @@ class SE_Connect(nn.Module):
84
  return out
85
 
86
 
87
- ''' SE-Res2Block of the ECAPA-TDNN architecture.
88
- '''
 
89
 
90
  class SE_Res2Block(nn.Module):
91
- def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
 
 
 
 
 
 
 
 
 
 
92
  super().__init__()
93
- self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
94
- self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
95
- self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
 
 
 
 
 
96
  self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
97
 
98
  self.shortcut = None
@@ -116,8 +165,9 @@ class SE_Res2Block(nn.Module):
116
  return x + residual
117
 
118
 
119
- ''' Attentive weighted mean and standard deviation pooling.
120
- '''
 
121
 
122
  class AttentiveStatsPool(nn.Module):
123
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
@@ -126,16 +176,24 @@ class AttentiveStatsPool(nn.Module):
126
 
127
  # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
128
  if global_context_att:
129
- self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
 
 
130
  else:
131
- self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
132
- self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
 
 
 
 
133
 
134
  def forward(self, x):
135
 
136
  if self.global_context_att:
137
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
138
- context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
 
 
139
  x_in = torch.cat((x, context_mean, context_std), dim=1)
140
  else:
141
  x_in = x
@@ -145,42 +203,52 @@ class AttentiveStatsPool(nn.Module):
145
  # alpha = F.relu(self.linear1(x_in))
146
  alpha = torch.softmax(self.linear2(alpha), dim=2)
147
  mean = torch.sum(alpha * x, dim=2)
148
- residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
149
  std = torch.sqrt(residuals.clamp(min=1e-9))
150
  return torch.cat([mean, std], dim=1)
151
 
152
 
153
  class ECAPA_TDNN(nn.Module):
154
- def __init__(self, channels=512, emb_dim=512,
155
- global_context_att=False, use_fp16=True,
 
 
 
 
156
  ctc_cls=ConformerCTC,
157
- ctc_path='/data4/F5TTS/ckpts/F5TTS_norm_ASR_vocos_pinyin_Emilia_ZH_EN/model_last.pt',
158
- ctc_args={'vocab_size': 2545, 'mel_dim': 100, 'num_heads': 8, 'd_hid': 512, 'nlayers': 6},
159
- ctc_no_grad=False
 
 
 
 
 
 
160
  ):
161
  super().__init__()
162
  if ctc_path != None:
163
  ctc_path = Path(ctc_path)
164
  model = ctc_cls(**ctc_args)
165
- state_dict = torch.load(ctc_path, map_location='cpu')
166
- model.load_state_dict(state_dict['model_state_dict'])
167
  print(f"Initialized pretrained ConformerCTC backbone from {ctc_path}.")
168
  else:
169
  raise ValueError(ctc_path)
170
 
171
  self.ctc_model = model
172
  self.ctc_model.out.requires_grad_(False)
173
-
174
  if ctc_cls == ConformerCTC:
175
- self.feat_num = ctc_args['nlayers'] + 2 + 1
176
  # elif ctc_cls == ConformerCTCNoPool:
177
  # self.feat_num = ctc_args['nlayers'] + 1
178
  else:
179
  raise ValueError(ctc_cls)
180
- feat_dim = ctc_args['d_hid']
181
 
182
  self.emb_dim = emb_dim
183
-
184
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
185
  self.instance_norm = nn.InstanceNorm1d(feat_dim)
186
 
@@ -188,14 +256,45 @@ class ECAPA_TDNN(nn.Module):
188
  self.channels = [channels] * 4 + [1536]
189
 
190
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
191
- self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
192
- self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
193
- self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
196
  cat_channels = channels * 3
197
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
198
- self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
 
 
 
 
199
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
200
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
201
 
@@ -206,21 +305,26 @@ class ECAPA_TDNN(nn.Module):
206
  else:
207
  self.ctc_model = self.ctc_model.train()
208
  self.ctc_no_grad = ctc_no_grad
209
- print('ctc_no_grad: ', self.ctc_no_grad)
210
 
211
- def forward(self, latent, input_lengths, return_asr=False):
212
  if self.ctc_no_grad:
213
  with torch.no_grad():
214
  asr, h = self.ctc_model(latent, input_lengths)
215
  else:
216
  asr, h = self.ctc_model(latent, input_lengths)
217
-
218
  x = torch.stack(h, dim=0)
219
- norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
 
 
 
 
 
220
  x = (norm_weights * x).sum(dim=0)
221
  x = x + 1e-6
222
  # x = torch.transpose(x, 1, 2) + 1e-6
223
-
224
  x = self.instance_norm(x)
225
  # x = torch.transpose(x, 1, 2)
226
 
@@ -238,9 +342,10 @@ class ECAPA_TDNN(nn.Module):
238
  return out, asr
239
  return out
240
 
 
241
  if __name__ == "__main__":
242
- from diffspeech.ldm.model import DiT
243
  from diffspeech.data.collate import get_mask_from_lengths
 
244
  from diffspeech.tools.text.vocab import IPA
245
 
246
  bsz = 3
@@ -265,4 +370,4 @@ if __name__ == "__main__":
265
 
266
  emb = model(latent, latent_mask.sum(axis=-1))
267
 
268
- print(emb.shape)
 
1
  # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
2
 
3
+ # from ctcmodel_nopool import ConformerCTC as ConformerCTCNoPool
4
+ from pathlib import Path
5
+
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
  import torchaudio.transforms as trans
10
+
11
  from ctcmodel import ConformerCTC
 
 
12
 
13
+ """ Res2Conv1d + BatchNorm1d + ReLU
14
+ """
15
 
16
 
17
  class Res2Conv1dReluBn(nn.Module):
18
+ """
19
  in_channels == out_channels == channels
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ channels,
25
+ kernel_size=1,
26
+ stride=1,
27
+ padding=0,
28
+ dilation=1,
29
+ bias=True,
30
+ scale=4,
31
+ ):
32
  super().__init__()
33
  assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
34
  self.scale = scale
 
38
  self.convs = []
39
  self.bns = []
40
  for i in range(self.nums):
41
+ self.convs.append(
42
+ nn.Conv1d(
43
+ self.width,
44
+ self.width,
45
+ kernel_size,
46
+ stride,
47
+ padding,
48
+ dilation,
49
+ bias=bias,
50
+ )
51
+ )
52
  self.bns.append(nn.BatchNorm1d(self.width))
53
  self.convs = nn.ModuleList(self.convs)
54
  self.bns = nn.ModuleList(self.bns)
 
72
  return out
73
 
74
 
75
+ """ Conv1d + BatchNorm1d + ReLU
76
+ """
77
 
78
 
79
  class Conv1dReluBn(nn.Module):
80
+ def __init__(
81
+ self,
82
+ in_channels,
83
+ out_channels,
84
+ kernel_size=1,
85
+ stride=1,
86
+ padding=0,
87
+ dilation=1,
88
+ bias=True,
89
+ ):
90
  super().__init__()
91
+ self.conv = nn.Conv1d(
92
+ in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
93
+ )
94
  self.bn = nn.BatchNorm1d(out_channels)
95
 
96
  def forward(self, x):
97
  return self.bn(F.relu(self.conv(x)))
98
 
99
 
100
+ """ The SE connection of 1D case.
101
+ """
102
 
103
 
104
  class SE_Connect(nn.Module):
 
116
  return out
117
 
118
 
119
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
120
+ """
121
+
122
 
123
  class SE_Res2Block(nn.Module):
124
+ def __init__(
125
+ self,
126
+ in_channels,
127
+ out_channels,
128
+ kernel_size,
129
+ stride,
130
+ padding,
131
+ dilation,
132
+ scale,
133
+ se_bottleneck_dim,
134
+ ):
135
  super().__init__()
136
+ self.Conv1dReluBn1 = Conv1dReluBn(
137
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
138
+ )
139
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(
140
+ out_channels, kernel_size, stride, padding, dilation, scale=scale
141
+ )
142
+ self.Conv1dReluBn2 = Conv1dReluBn(
143
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
144
+ )
145
  self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
146
 
147
  self.shortcut = None
 
165
  return x + residual
166
 
167
 
168
+ """ Attentive weighted mean and standard deviation pooling.
169
+ """
170
+
171
 
172
  class AttentiveStatsPool(nn.Module):
173
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
 
176
 
177
  # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
178
  if global_context_att:
179
+ self.linear1 = nn.Conv1d(
180
+ in_dim * 3, attention_channels, kernel_size=1
181
+ ) # equals W and b in the paper
182
  else:
183
+ self.linear1 = nn.Conv1d(
184
+ in_dim, attention_channels, kernel_size=1
185
+ ) # equals W and b in the paper
186
+ self.linear2 = nn.Conv1d(
187
+ attention_channels, in_dim, kernel_size=1
188
+ ) # equals V and k in the paper
189
 
190
  def forward(self, x):
191
 
192
  if self.global_context_att:
193
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
194
+ context_std = torch.sqrt(
195
+ torch.var(x, dim=-1, keepdim=True) + 1e-10
196
+ ).expand_as(x)
197
  x_in = torch.cat((x, context_mean, context_std), dim=1)
198
  else:
199
  x_in = x
 
203
  # alpha = F.relu(self.linear1(x_in))
204
  alpha = torch.softmax(self.linear2(alpha), dim=2)
205
  mean = torch.sum(alpha * x, dim=2)
206
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
207
  std = torch.sqrt(residuals.clamp(min=1e-9))
208
  return torch.cat([mean, std], dim=1)
209
 
210
 
211
  class ECAPA_TDNN(nn.Module):
212
+ def __init__(
213
+ self,
214
+ channels=512,
215
+ emb_dim=512,
216
+ global_context_att=False,
217
+ use_fp16=True,
218
  ctc_cls=ConformerCTC,
219
+ ctc_path="/data4/F5TTS/ckpts/F5TTS_norm_ASR_vocos_pinyin_Emilia_ZH_EN/model_last.pt",
220
+ ctc_args={
221
+ "vocab_size": 2545,
222
+ "mel_dim": 100,
223
+ "num_heads": 8,
224
+ "d_hid": 512,
225
+ "nlayers": 6,
226
+ },
227
+ ctc_no_grad=False,
228
  ):
229
  super().__init__()
230
  if ctc_path != None:
231
  ctc_path = Path(ctc_path)
232
  model = ctc_cls(**ctc_args)
233
+ state_dict = torch.load(ctc_path, map_location="cpu")
234
+ model.load_state_dict(state_dict["model_state_dict"])
235
  print(f"Initialized pretrained ConformerCTC backbone from {ctc_path}.")
236
  else:
237
  raise ValueError(ctc_path)
238
 
239
  self.ctc_model = model
240
  self.ctc_model.out.requires_grad_(False)
241
+
242
  if ctc_cls == ConformerCTC:
243
+ self.feat_num = ctc_args["nlayers"] + 2 + 1
244
  # elif ctc_cls == ConformerCTCNoPool:
245
  # self.feat_num = ctc_args['nlayers'] + 1
246
  else:
247
  raise ValueError(ctc_cls)
248
+ feat_dim = ctc_args["d_hid"]
249
 
250
  self.emb_dim = emb_dim
251
+
252
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
253
  self.instance_norm = nn.InstanceNorm1d(feat_dim)
254
 
 
256
  self.channels = [channels] * 4 + [1536]
257
 
258
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
259
+ self.layer2 = SE_Res2Block(
260
+ self.channels[0],
261
+ self.channels[1],
262
+ kernel_size=3,
263
+ stride=1,
264
+ padding=2,
265
+ dilation=2,
266
+ scale=8,
267
+ se_bottleneck_dim=128,
268
+ )
269
+ self.layer3 = SE_Res2Block(
270
+ self.channels[1],
271
+ self.channels[2],
272
+ kernel_size=3,
273
+ stride=1,
274
+ padding=3,
275
+ dilation=3,
276
+ scale=8,
277
+ se_bottleneck_dim=128,
278
+ )
279
+ self.layer4 = SE_Res2Block(
280
+ self.channels[2],
281
+ self.channels[3],
282
+ kernel_size=3,
283
+ stride=1,
284
+ padding=4,
285
+ dilation=4,
286
+ scale=8,
287
+ se_bottleneck_dim=128,
288
+ )
289
 
290
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
291
  cat_channels = channels * 3
292
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
293
+ self.pooling = AttentiveStatsPool(
294
+ self.channels[-1],
295
+ attention_channels=128,
296
+ global_context_att=global_context_att,
297
+ )
298
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
299
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
300
 
 
305
  else:
306
  self.ctc_model = self.ctc_model.train()
307
  self.ctc_no_grad = ctc_no_grad
308
+ print("ctc_no_grad: ", self.ctc_no_grad)
309
 
310
+ def forward(self, latent, input_lengths, return_asr=False):
311
  if self.ctc_no_grad:
312
  with torch.no_grad():
313
  asr, h = self.ctc_model(latent, input_lengths)
314
  else:
315
  asr, h = self.ctc_model(latent, input_lengths)
316
+
317
  x = torch.stack(h, dim=0)
318
+ norm_weights = (
319
+ F.softmax(self.feature_weight, dim=-1)
320
+ .unsqueeze(-1)
321
+ .unsqueeze(-1)
322
+ .unsqueeze(-1)
323
+ )
324
  x = (norm_weights * x).sum(dim=0)
325
  x = x + 1e-6
326
  # x = torch.transpose(x, 1, 2) + 1e-6
327
+
328
  x = self.instance_norm(x)
329
  # x = torch.transpose(x, 1, 2)
330
 
 
342
  return out, asr
343
  return out
344
 
345
+
346
  if __name__ == "__main__":
 
347
  from diffspeech.data.collate import get_mask_from_lengths
348
+ from diffspeech.ldm.model import DiT
349
  from diffspeech.tools.text.vocab import IPA
350
 
351
  bsz = 3
 
370
 
371
  emb = model(latent, latent_mask.sum(axis=-1))
372
 
373
+ print(emb.shape)
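
The AttentiveStatsPool module in the file above computes an attention-weighted mean and standard deviation over the time axis and concatenates them into an utterance-level vector. A shape-only sketch of that pooling step; random tensors stand in for the frame features and the learned attention, and the shapes are assumed for illustration.

import torch

B, C, T = 2, 1536, 120                               # assumed batch, channel and frame counts
x = torch.randn(B, C, T)                             # frame-level features
alpha = torch.softmax(torch.randn(B, C, T), dim=2)   # stand-in for the learned attention weights
mean = torch.sum(alpha * x, dim=2)                   # (B, C) weighted mean
var = torch.sum(alpha * x**2, dim=2) - mean**2
std = torch.sqrt(var.clamp(min=1e-9))                # (B, C) weighted standard deviation
emb = torch.cat([mean, std], dim=1)                  # (B, 2C) utterance-level statistics
print(emb.shape)                                     # torch.Size([2, 3072])
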
f5_tts/api.py CHANGED
@@ -8,15 +8,10 @@ from cached_path import cached_path
8
  from hydra.utils import get_class
9
  from omegaconf import OmegaConf
10
 
11
- from f5_tts.infer.utils_infer import (
12
- infer_process,
13
- load_model,
14
- load_vocoder,
15
- preprocess_ref_audio_text,
16
- remove_silence_for_generated_wav,
17
- save_spectrogram,
18
- transcribe,
19
- )
20
  from f5_tts.model.utils import seed_everything
21
 
22
 
@@ -32,7 +27,9 @@ class F5TTS:
32
  device=None,
33
  hf_cache_dir=None,
34
  ):
35
- model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
 
 
36
  model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
37
  model_arc = model_cfg.model.arch
38
 
@@ -50,16 +47,20 @@ class F5TTS:
50
  self.device = (
51
  "cuda"
52
  if torch.cuda.is_available()
53
- else "xpu"
54
- if torch.xpu.is_available()
55
- else "mps"
56
- if torch.backends.mps.is_available()
57
- else "cpu"
58
  )
59
 
60
  # Load models
61
  self.vocoder = load_vocoder(
62
- self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
 
 
 
 
63
  )
64
 
65
  repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
@@ -77,10 +78,20 @@ class F5TTS:
77
 
78
  if not ckpt_file:
79
  ckpt_file = str(
80
- cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
 
 
 
81
  )
82
  self.ema_model = load_model(
83
- model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
 
 
 
 
 
 
 
84
  )
85
 
86
  def transcribe(self, ref_audio, language=None):
 
8
  from hydra.utils import get_class
9
  from omegaconf import OmegaConf
10
 
11
+ from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
12
+ preprocess_ref_audio_text,
13
+ remove_silence_for_generated_wav,
14
+ save_spectrogram, transcribe)
 
 
 
 
 
15
  from f5_tts.model.utils import seed_everything
16
 
17
 
 
27
  device=None,
28
  hf_cache_dir=None,
29
  ):
30
+ model_cfg = OmegaConf.load(
31
+ str(files("f5_tts").joinpath(f"configs/{model}.yaml"))
32
+ )
33
  model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
34
  model_arc = model_cfg.model.arch
35
 
 
47
  self.device = (
48
  "cuda"
49
  if torch.cuda.is_available()
50
+ else (
51
+ "xpu"
52
+ if torch.xpu.is_available()
53
+ else "mps" if torch.backends.mps.is_available() else "cpu"
54
+ )
55
  )
56
 
57
  # Load models
58
  self.vocoder = load_vocoder(
59
+ self.mel_spec_type,
60
+ vocoder_local_path is not None,
61
+ vocoder_local_path,
62
+ self.device,
63
+ hf_cache_dir,
64
  )
65
 
66
  repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
 
78
 
79
  if not ckpt_file:
80
  ckpt_file = str(
81
+ cached_path(
82
+ f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}",
83
+ cache_dir=hf_cache_dir,
84
+ )
85
  )
86
  self.ema_model = load_model(
87
+ model_cls,
88
+ model_arc,
89
+ ckpt_file,
90
+ self.mel_spec_type,
91
+ vocab_file,
92
+ self.ode_method,
93
+ self.use_ema,
94
+ self.device,
95
  )
96
 
97
  def transcribe(self, ref_audio, language=None):
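
The reformatted device selection in F5TTS.__init__ is a nested conditional expression that falls back from CUDA to XPU to MPS to CPU. A standalone sketch of the same chain; the hasattr guard is an extra precaution added here for older PyTorch builds and is not part of the original code.

import torch

device = (
    "cuda"
    if torch.cuda.is_available()
    else "xpu"
    if hasattr(torch, "xpu") and torch.xpu.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(device)
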
f5_tts/eval/ecapa_tdnn.py CHANGED
@@ -9,7 +9,6 @@ import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
 
12
-
13
  """ Res2Conv1d + BatchNorm1d + ReLU
14
  """
15
 
@@ -19,7 +18,16 @@ class Res2Conv1dReluBn(nn.Module):
19
  in_channels == out_channels == channels
20
  """
21
 
22
- def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
 
 
 
 
 
 
 
 
 
23
  super().__init__()
24
  assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
25
  self.scale = scale
@@ -29,7 +37,17 @@ class Res2Conv1dReluBn(nn.Module):
29
  self.convs = []
30
  self.bns = []
31
  for i in range(self.nums):
32
- self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
 
 
 
 
 
 
 
 
 
 
33
  self.bns.append(nn.BatchNorm1d(self.width))
34
  self.convs = nn.ModuleList(self.convs)
35
  self.bns = nn.ModuleList(self.bns)
@@ -58,9 +76,20 @@ class Res2Conv1dReluBn(nn.Module):
58
 
59
 
60
  class Conv1dReluBn(nn.Module):
61
- def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
 
 
 
 
 
 
 
 
 
62
  super().__init__()
63
- self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
 
 
64
  self.bn = nn.BatchNorm1d(out_channels)
65
 
66
  def forward(self, x):
@@ -99,11 +128,27 @@ class SE_Connect(nn.Module):
99
 
100
 
101
  class SE_Res2Block(nn.Module):
102
- def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
 
 
 
 
 
 
 
 
 
 
103
  super().__init__()
104
- self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
105
- self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
106
- self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
 
 
 
 
 
107
  self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
108
 
109
  self.shortcut = None
@@ -138,15 +183,23 @@ class AttentiveStatsPool(nn.Module):
138
 
139
  # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
140
  if global_context_att:
141
- self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
 
 
142
  else:
143
- self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
144
- self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
 
 
 
 
145
 
146
  def forward(self, x):
147
  if self.global_context_att:
148
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
149
- context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
 
 
150
  x_in = torch.cat((x, context_mean, context_std), dim=1)
151
  else:
152
  x_in = x
@@ -184,24 +237,36 @@ class ECAPA_TDNN(nn.Module):
184
  torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
185
  try:
186
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
187
- self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
 
 
188
  except: # noqa: E722
189
  self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
190
 
191
  if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
192
  self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
193
  ):
194
- self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
 
 
195
  if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
196
  self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
197
  ):
198
- self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
 
 
199
 
200
  self.feat_num = self.get_feat_num()
201
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
202
 
203
  if feat_type != "fbank" and feat_type != "mfcc":
204
- freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
 
 
 
 
 
 
205
  for name, param in self.feature_extract.named_parameters():
206
  for freeze_val in freeze_list:
207
  if freeze_val in name:
@@ -252,7 +317,9 @@ class ECAPA_TDNN(nn.Module):
252
  cat_channels = channels * 3
253
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
254
  self.pooling = AttentiveStatsPool(
255
- self.channels[-1], attention_channels=128, global_context_att=global_context_att
 
 
256
  )
257
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
258
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
@@ -287,7 +354,12 @@ class ECAPA_TDNN(nn.Module):
287
  x = torch.stack(x, dim=0)
288
  else:
289
  x = x.unsqueeze(0)
290
- norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
 
 
 
 
 
291
  x = (norm_weights * x).sum(dim=0)
292
  x = torch.transpose(x, 1, 2) + 1e-6
293
 
 
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
 
 
12
  """ Res2Conv1d + BatchNorm1d + ReLU
13
  """
14
 
 
18
  in_channels == out_channels == channels
19
  """
20
 
21
+ def __init__(
22
+ self,
23
+ channels,
24
+ kernel_size=1,
25
+ stride=1,
26
+ padding=0,
27
+ dilation=1,
28
+ bias=True,
29
+ scale=4,
30
+ ):
31
  super().__init__()
32
  assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
33
  self.scale = scale
 
37
  self.convs = []
38
  self.bns = []
39
  for i in range(self.nums):
40
+ self.convs.append(
41
+ nn.Conv1d(
42
+ self.width,
43
+ self.width,
44
+ kernel_size,
45
+ stride,
46
+ padding,
47
+ dilation,
48
+ bias=bias,
49
+ )
50
+ )
51
  self.bns.append(nn.BatchNorm1d(self.width))
52
  self.convs = nn.ModuleList(self.convs)
53
  self.bns = nn.ModuleList(self.bns)
 
76
 
77
 
78
  class Conv1dReluBn(nn.Module):
79
+ def __init__(
80
+ self,
81
+ in_channels,
82
+ out_channels,
83
+ kernel_size=1,
84
+ stride=1,
85
+ padding=0,
86
+ dilation=1,
87
+ bias=True,
88
+ ):
89
  super().__init__()
90
+ self.conv = nn.Conv1d(
91
+ in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
92
+ )
93
  self.bn = nn.BatchNorm1d(out_channels)
94
 
95
  def forward(self, x):
 
128
 
129
 
130
  class SE_Res2Block(nn.Module):
131
+ def __init__(
132
+ self,
133
+ in_channels,
134
+ out_channels,
135
+ kernel_size,
136
+ stride,
137
+ padding,
138
+ dilation,
139
+ scale,
140
+ se_bottleneck_dim,
141
+ ):
142
  super().__init__()
143
+ self.Conv1dReluBn1 = Conv1dReluBn(
144
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
145
+ )
146
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(
147
+ out_channels, kernel_size, stride, padding, dilation, scale=scale
148
+ )
149
+ self.Conv1dReluBn2 = Conv1dReluBn(
150
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
151
+ )
152
  self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
153
 
154
  self.shortcut = None
 
183
 
184
  # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
185
  if global_context_att:
186
+ self.linear1 = nn.Conv1d(
187
+ in_dim * 3, attention_channels, kernel_size=1
188
+ ) # equals W and b in the paper
189
  else:
190
+ self.linear1 = nn.Conv1d(
191
+ in_dim, attention_channels, kernel_size=1
192
+ ) # equals W and b in the paper
193
+ self.linear2 = nn.Conv1d(
194
+ attention_channels, in_dim, kernel_size=1
195
+ ) # equals V and k in the paper
196
 
197
  def forward(self, x):
198
  if self.global_context_att:
199
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
200
+ context_std = torch.sqrt(
201
+ torch.var(x, dim=-1, keepdim=True) + 1e-10
202
+ ).expand_as(x)
203
  x_in = torch.cat((x, context_mean, context_std), dim=1)
204
  else:
205
  x_in = x
 
237
  torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
238
  try:
239
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
240
+ self.feature_extract = torch.hub.load(
241
+ local_s3prl_path, feat_type, source="local", config_path=config_path
242
+ )
243
  except: # noqa: E722
244
  self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
245
 
246
  if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
247
  self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
248
  ):
249
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = (
250
+ False
251
+ )
252
  if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
253
  self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
254
  ):
255
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = (
256
+ False
257
+ )
258
 
259
  self.feat_num = self.get_feat_num()
260
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
261
 
262
  if feat_type != "fbank" and feat_type != "mfcc":
263
+ freeze_list = [
264
+ "final_proj",
265
+ "label_embs_concat",
266
+ "mask_emb",
267
+ "project_q",
268
+ "quantizer",
269
+ ]
270
  for name, param in self.feature_extract.named_parameters():
271
  for freeze_val in freeze_list:
272
  if freeze_val in name:
 
317
  cat_channels = channels * 3
318
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
319
  self.pooling = AttentiveStatsPool(
320
+ self.channels[-1],
321
+ attention_channels=128,
322
+ global_context_att=global_context_att,
323
  )
324
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
325
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
 
354
  x = torch.stack(x, dim=0)
355
  else:
356
  x = x.unsqueeze(0)
357
+ norm_weights = (
358
+ F.softmax(self.feature_weight, dim=-1)
359
+ .unsqueeze(-1)
360
+ .unsqueeze(-1)
361
+ .unsqueeze(-1)
362
+ )
363
  x = (norm_weights * x).sum(dim=0)
364
  x = torch.transpose(x, 1, 2) + 1e-6
365
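
Both ECAPA_TDNN variants mix the stacked backbone layer outputs with a learned softmax weighting over the feature_weight parameter before pooling. A shape-only sketch of that mixing; the layer count and feature sizes are assumed for illustration.

import torch
import torch.nn.functional as F

n_layers, B, T, D = 9, 2, 120, 512                        # assumed shapes
hidden = torch.randn(n_layers, B, T, D)                   # stacked per-layer hidden states
feature_weight = torch.zeros(n_layers, requires_grad=True)
w = F.softmax(feature_weight, dim=-1).view(-1, 1, 1, 1)   # (n_layers, 1, 1, 1) mixing weights
mixed = (w * hidden).sum(dim=0)                           # (B, T, D) weighted combination
print(mixed.shape)
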
 
f5_tts/eval/eval_infer_batch.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import sys
3
 
4
-
5
  sys.path.append(os.getcwd())
6
 
7
  import argparse
@@ -15,16 +14,13 @@ from hydra.utils import get_class
15
  from omegaconf import OmegaConf
16
  from tqdm import tqdm
17
 
18
- from f5_tts.eval.utils_eval import (
19
- get_inference_prompt,
20
- get_librispeech_test_clean_metainfo,
21
- get_seedtts_testset_metainfo,
22
- )
23
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
24
  from f5_tts.model import CFM
25
  from f5_tts.model.utils import get_tokenizer
26
 
27
-
28
  accelerator = Accelerator()
29
  device = f"cuda:{accelerator.process_index}"
30
 
@@ -67,7 +63,9 @@ def main():
67
  use_truth_duration = False
68
  no_ref_audio = False
69
 
70
- model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
 
 
71
  model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
72
  model_arc = model_cfg.model.arch
73
 
@@ -83,8 +81,12 @@ def main():
83
 
84
  if testset == "ls_pc_test_clean":
85
  metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
86
- librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
87
- metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
 
 
 
 
88
 
89
  elif testset == "seedtts_test_zh":
90
  metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
@@ -126,14 +128,18 @@ def main():
126
  vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
127
  elif mel_spec_type == "bigvgan":
128
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
129
- vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
 
 
130
 
131
  # Tokenizer
132
  vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
133
 
134
  # Model
135
  model = CFM(
136
- transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
137
  mel_spec_kwargs=dict(
138
  n_fft=n_fft,
139
  hop_length=hop_length,
@@ -154,7 +160,9 @@ def main():
154
  elif os.path.exists(ckpt_prefix + ".safetensors"):
155
  ckpt_path = ckpt_prefix + ".safetensors"
156
  else:
157
- print("Loading from self-organized training checkpoints rather than released pretrained.")
158
  ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
159
 
160
  dtype = torch.float32 if mel_spec_type == "bigvgan" else None
@@ -169,7 +177,14 @@ def main():
169
 
170
  with accelerator.split_between_processes(prompts_all) as prompts:
171
  for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
172
- utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
173
  ref_mels = ref_mels.to(device)
174
  ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
175
  total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
@@ -198,7 +213,11 @@ def main():
198
 
199
  if ref_rms_list[i] < target_rms:
200
  generated_wave = generated_wave * ref_rms_list[i] / target_rms
201
- torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
202
 
203
  accelerator.wait_for_everyone()
204
  if accelerator.is_main_process:
 
1
  import os
2
  import sys
3
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import argparse
 
14
  from omegaconf import OmegaConf
15
  from tqdm import tqdm
16
 
17
+ from f5_tts.eval.utils_eval import (get_inference_prompt,
18
+ get_librispeech_test_clean_metainfo,
19
+ get_seedtts_testset_metainfo)
20
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
21
  from f5_tts.model import CFM
22
  from f5_tts.model.utils import get_tokenizer
23
 
 
24
  accelerator = Accelerator()
25
  device = f"cuda:{accelerator.process_index}"
26
 
 
63
  use_truth_duration = False
64
  no_ref_audio = False
65
 
66
+ model_cfg = OmegaConf.load(
67
+ str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))
68
+ )
69
  model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
70
  model_arc = model_cfg.model.arch
71
 
 
81
 
82
  if testset == "ls_pc_test_clean":
83
  metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
84
+ librispeech_test_clean_path = (
85
+ "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
86
+ )
87
+ metainfo = get_librispeech_test_clean_metainfo(
88
+ metalst, librispeech_test_clean_path
89
+ )
90
 
91
  elif testset == "seedtts_test_zh":
92
  metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
 
128
  vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129
  elif mel_spec_type == "bigvgan":
130
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
131
+ vocoder = load_vocoder(
132
+ vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path
133
+ )
134
 
135
  # Tokenizer
136
  vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
137
 
138
  # Model
139
  model = CFM(
140
+ transformer=model_cls(
141
+ **model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels
142
+ ),
143
  mel_spec_kwargs=dict(
144
  n_fft=n_fft,
145
  hop_length=hop_length,
 
160
  elif os.path.exists(ckpt_prefix + ".safetensors"):
161
  ckpt_path = ckpt_prefix + ".safetensors"
162
  else:
163
+ print(
164
+ "Loading from self-organized training checkpoints rather than released pretrained."
165
+ )
166
  ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
167
 
168
  dtype = torch.float32 if mel_spec_type == "bigvgan" else None
 
177
 
178
  with accelerator.split_between_processes(prompts_all) as prompts:
179
  for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
180
+ (
181
+ utts,
182
+ ref_rms_list,
183
+ ref_mels,
184
+ ref_mel_lens,
185
+ total_mel_lens,
186
+ final_text_list,
187
+ ) = prompt
188
  ref_mels = ref_mels.to(device)
189
  ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
190
  total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
 
213
 
214
  if ref_rms_list[i] < target_rms:
215
  generated_wave = generated_wave * ref_rms_list[i] / target_rms
216
+ torchaudio.save(
217
+ f"{output_dir}/{utts[i]}.wav",
218
+ generated_wave,
219
+ target_sample_rate,
220
+ )
221
 
222
  accelerator.wait_for_everyone()
223
  if accelerator.is_main_process:
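# Minimal sketch of the multi-GPU fan-out pattern used in this script, assuming a
# recent `accelerate` release that provides Accelerator.split_between_processes.
# The prompt list is a toy stand-in for the bucketed inference prompts.
from accelerate import Accelerator

accelerator = Accelerator()
prompts_all = [f"utt_{i}" for i in range(16)]

with accelerator.split_between_processes(prompts_all) as prompts:
    # each rank receives a disjoint slice of prompts_all
    for prompt in prompts:
        print(f"rank {accelerator.process_index} handles {prompt}")

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    print("all ranks finished")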
f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -5,7 +5,6 @@ import json
5
  import os
6
  import sys
7
 
8
-
9
  sys.path.append(os.getcwd())
10
 
11
  import multiprocessing as mp
@@ -15,18 +14,23 @@ import numpy as np
15
 
16
  from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
17
 
18
-
19
  rel_path = str(files("f5_tts").joinpath("../../"))
20
 
21
 
22
  def get_args():
23
  parser = argparse.ArgumentParser()
24
- parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
25
  parser.add_argument("-l", "--lang", type=str, default="en")
26
  parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
27
  parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
28
- parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
29
- parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
30
  return parser.parse_args()
31
 
32
 
@@ -39,7 +43,9 @@ def main():
39
  metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
40
 
41
  gpus = list(range(args.gpu_nums))
42
- test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
43
 
44
  ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
45
  ## leading to a low similarity for the ground truth in some cases.
@@ -59,13 +65,19 @@ def main():
59
 
60
  if eval_task == "wer":
61
  with mp.Pool(processes=len(gpus)) as pool:
62
- args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
63
  results = pool.map(run_asr_wer, args)
64
  for r in results:
65
  full_results.extend(r)
66
  elif eval_task == "sim":
67
  with mp.Pool(processes=len(gpus)) as pool:
68
- args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
69
  results = pool.map(run_sim, args)
70
  for r in results:
71
  full_results.extend(r)
 
5
  import os
6
  import sys
7
 
 
8
  sys.path.append(os.getcwd())
9
 
10
  import multiprocessing as mp
 
14
 
15
  from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
16
 
 
17
  rel_path = str(files("f5_tts").joinpath("../../"))
18
 
19
 
20
  def get_args():
21
  parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]
24
+ )
25
  parser.add_argument("-l", "--lang", type=str, default="en")
26
  parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
27
  parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
28
+ parser.add_argument(
29
+ "-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use"
30
+ )
31
+ parser.add_argument(
32
+ "--local", action="store_true", help="Use local custom checkpoint directory"
33
+ )
34
  return parser.parse_args()
35
 
36
 
 
43
  metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
44
 
45
  gpus = list(range(args.gpu_nums))
46
+ test_set = get_librispeech_test(
47
+ metalst, gen_wav_dir, gpus, librispeech_test_clean_path
48
+ )
49
 
50
  ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
51
  ## leading to a low similarity for the ground truth in some cases.
 
65
 
66
  if eval_task == "wer":
67
  with mp.Pool(processes=len(gpus)) as pool:
68
+ args = [
69
+ (rank, lang, sub_test_set, asr_ckpt_dir)
70
+ for (rank, sub_test_set) in test_set
71
+ ]
72
  results = pool.map(run_asr_wer, args)
73
  for r in results:
74
  full_results.extend(r)
75
  elif eval_task == "sim":
76
  with mp.Pool(processes=len(gpus)) as pool:
77
+ args = [
78
+ (rank, sub_test_set, wavlm_ckpt_dir)
79
+ for (rank, sub_test_set) in test_set
80
+ ]
81
  results = pool.map(run_sim, args)
82
  for r in results:
83
  full_results.extend(r)
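# Toy sketch of the per-GPU evaluation fan-out above: one worker per (rank, subset)
# pair, results flattened into a single list. fake_worker stands in for run_asr_wer/run_sim.
import multiprocessing as mp

def fake_worker(task):
    rank, subset = task
    return [f"rank{rank}:{item}" for item in subset]

if __name__ == "__main__":
    test_set = [(0, ["a", "b"]), (1, ["c"])]
    with mp.Pool(processes=len(test_set)) as pool:
        results = pool.map(fake_worker, test_set)
    full_results = [r for chunk in results for r in chunk]
    print(full_results)  # ['rank0:a', 'rank0:b', 'rank1:c']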
f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -5,7 +5,6 @@ import json
5
  import os
6
  import sys
7
 
8
-
9
  sys.path.append(os.getcwd())
10
 
11
  import multiprocessing as mp
@@ -15,17 +14,22 @@ import numpy as np
15
 
16
  from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
17
 
18
-
19
  rel_path = str(files("f5_tts").joinpath("../../"))
20
 
21
 
22
  def get_args():
23
  parser = argparse.ArgumentParser()
24
- parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
25
  parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
26
  parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
27
- parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
28
- parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
29
  return parser.parse_args()
30
 
31
 
@@ -58,13 +62,19 @@ def main():
58
 
59
  if eval_task == "wer":
60
  with mp.Pool(processes=len(gpus)) as pool:
61
- args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
62
  results = pool.map(run_asr_wer, args)
63
  for r in results:
64
  full_results.extend(r)
65
  elif eval_task == "sim":
66
  with mp.Pool(processes=len(gpus)) as pool:
67
- args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
68
  results = pool.map(run_sim, args)
69
  for r in results:
70
  full_results.extend(r)
 
5
  import os
6
  import sys
7
 
 
8
  sys.path.append(os.getcwd())
9
 
10
  import multiprocessing as mp
 
14
 
15
  from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
16
 
 
17
  rel_path = str(files("f5_tts").joinpath("../../"))
18
 
19
 
20
  def get_args():
21
  parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]
24
+ )
25
  parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
26
  parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
27
+ parser.add_argument(
28
+ "-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use"
29
+ )
30
+ parser.add_argument(
31
+ "--local", action="store_true", help="Use local custom checkpoint directory"
32
+ )
33
  return parser.parse_args()
34
 
35
 
 
62
 
63
  if eval_task == "wer":
64
  with mp.Pool(processes=len(gpus)) as pool:
65
+ args = [
66
+ (rank, lang, sub_test_set, asr_ckpt_dir)
67
+ for (rank, sub_test_set) in test_set
68
+ ]
69
  results = pool.map(run_asr_wer, args)
70
  for r in results:
71
  full_results.extend(r)
72
  elif eval_task == "sim":
73
  with mp.Pool(processes=len(gpus)) as pool:
74
+ args = [
75
+ (rank, sub_test_set, wavlm_ckpt_dir)
76
+ for (rank, sub_test_set) in test_set
77
+ ]
78
  results = pool.map(run_sim, args)
79
  for r in results:
80
  full_results.extend(r)
f5_tts/eval/eval_utmos.py CHANGED
@@ -13,9 +13,15 @@ def main():
13
  parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
14
  args = parser.parse_args()
15
 
16
- device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
17
-
18
- predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
19
  predictor = predictor.to(device)
20
 
21
  audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
 
13
  parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
14
  args = parser.parse_args()
15
 
16
+ device = (
17
+ "cuda"
18
+ if torch.cuda.is_available()
19
+ else "xpu" if torch.xpu.is_available() else "cpu"
20
+ )
21
+
22
+ predictor = torch.hub.load(
23
+ "tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True
24
+ )
25
  predictor = predictor.to(device)
26
 
27
  audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
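# Minimal scoring sketch matching this script; the waveform-plus-sample-rate call
# signature of the SpeechMOS hub model is an assumption, and "sample.wav" is a
# hypothetical input file.
import torch
import torchaudio

device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device)

wave, sr = torchaudio.load("sample.wav")
with torch.inference_mode():
    score = predictor(wave.to(device), sr)  # 1-element tensor holding the predicted MOS
print(f"UTMOS: {score.item():.3f}")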
f5_tts/eval/utils_eval.py CHANGED
@@ -43,11 +43,15 @@ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
43
 
44
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
45
  ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
46
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
47
 
48
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
49
  gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
50
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
51
 
52
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
53
 
@@ -106,13 +110,17 @@ def get_inference_prompt(
106
  mel_spec_type=mel_spec_type,
107
  )
108
 
109
- for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
110
  # Audio
111
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
112
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
113
  if ref_rms < target_rms:
114
  ref_audio = ref_audio * target_rms / ref_rms
115
- assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
116
  if ref_sr != target_sample_rate:
117
  resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
118
  ref_audio = resampler(ref_audio)
@@ -145,14 +153,18 @@ def get_inference_prompt(
145
  else:
146
  ref_text_len = len(prompt_text.encode("utf-8"))
147
  gen_text_len = len(gt_text.encode("utf-8"))
148
- total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
149
 
150
  # deal with batch
151
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
152
- assert min_tokens <= total_mel_len <= max_tokens, (
153
- f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
 
  )
155
- bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
156
 
157
  utts[bucket_i].append(utt)
158
  ref_rms_list[bucket_i].append(ref_rms)
@@ -183,7 +195,14 @@ def get_inference_prompt(
183
  ref_mel_lens[bucket_i],
184
  total_mel_lens[bucket_i],
185
  final_text_list[bucket_i],
186
- ) = [], [], [], [], [], []
187
 
188
  # add residual
189
  for bucket_i, bucket_frames in enumerate(batch_accum):
@@ -244,7 +263,9 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
244
  # get librispeech test-clean cross sentence test
245
 
246
 
247
- def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
248
  f = open(metalst)
249
  lines = f.readlines()
250
  f.close()
@@ -255,14 +276,21 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path
255
 
256
  if eval_ground_truth:
257
  gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
258
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
259
  else:
260
  if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
261
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
262
  gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
263
 
264
  ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
265
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
266
 
267
  test_set_.append((gen_wav, ref_wav, gen_txt))
268
 
@@ -382,7 +410,9 @@ def run_sim(args):
382
  device = f"cuda:{rank}"
383
 
384
  model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
385
- state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
386
  model.load_state_dict(state_dict["model"], strict=False)
387
 
388
  use_gpu = True if torch.cuda.is_available() else False
 
43
 
44
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
45
  ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
46
+ ref_wav = os.path.join(
47
+ librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac"
48
+ )
49
 
50
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
51
  gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
52
+ gen_wav = os.path.join(
53
+ librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac"
54
+ )
55
 
56
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
57
 
 
110
  mel_spec_type=mel_spec_type,
111
  )
112
 
113
+ for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
114
+ metainfo, desc="Processing prompts..."
115
+ ):
116
  # Audio
117
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
118
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
119
  if ref_rms < target_rms:
120
  ref_audio = ref_audio * target_rms / ref_rms
121
+ assert (
122
+ ref_audio.shape[-1] > 5000
123
+ ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
124
  if ref_sr != target_sample_rate:
125
  resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
126
  ref_audio = resampler(ref_audio)
 
153
  else:
154
  ref_text_len = len(prompt_text.encode("utf-8"))
155
  gen_text_len = len(gt_text.encode("utf-8"))
156
+ total_mel_len = ref_mel_len + int(
157
+ ref_mel_len / ref_text_len * gen_text_len / speed
158
+ )
159
 
160
  # deal with batch
161
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
162
+ assert (
163
+ min_tokens <= total_mel_len <= max_tokens
164
+ ), f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
165
+ bucket_i = math.floor(
166
+ (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
167
  )
 
168
 
169
  utts[bucket_i].append(utt)
170
  ref_rms_list[bucket_i].append(ref_rms)
 
195
  ref_mel_lens[bucket_i],
196
  total_mel_lens[bucket_i],
197
  final_text_list[bucket_i],
198
+ ) = (
199
+ [],
200
+ [],
201
+ [],
202
+ [],
203
+ [],
204
+ [],
205
+ )
206
 
207
  # add residual
208
  for bucket_i, bucket_frames in enumerate(batch_accum):
 
263
  # get librispeech test-clean cross sentence test
264
 
265
 
266
+ def get_librispeech_test(
267
+ metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False
268
+ ):
269
  f = open(metalst)
270
  lines = f.readlines()
271
  f.close()
 
276
 
277
  if eval_ground_truth:
278
  gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
279
+ gen_wav = os.path.join(
280
+ librispeech_test_clean_path,
281
+ gen_spk_id,
282
+ gen_chaptr_id,
283
+ gen_utt + ".flac",
284
+ )
285
  else:
286
  if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
287
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
288
  gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
289
 
290
  ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
291
+ ref_wav = os.path.join(
292
+ librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac"
293
+ )
294
 
295
  test_set_.append((gen_wav, ref_wav, gen_txt))
296
 
 
410
  device = f"cuda:{rank}"
411
 
412
  model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
413
+ state_dict = torch.load(
414
+ ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage
415
+ )
416
  model.load_state_dict(state_dict["model"], strict=False)
417
 
418
  use_gpu = True if torch.cuda.is_available() else False
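# Worked numeric example of the duration estimate and bucket index computed in
# get_inference_prompt (toy values; the real code derives ref_mel_len from the prompt
# mel and the text lengths from UTF-8 byte counts).
import math

ref_mel_len, ref_text_len, gen_text_len, speed = 500, 100, 150, 1.0
min_tokens, max_tokens, num_buckets = 100, 3000, 200

total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
print(total_mel_len, bucket_i)  # 1250 79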
f5_tts/infer/infer_cli.py CHANGED
@@ -14,23 +14,12 @@ from hydra.utils import get_class
14
  from omegaconf import OmegaConf
15
  from unidecode import unidecode
16
 
17
- from f5_tts.infer.utils_infer import (
18
- cfg_strength,
19
- cross_fade_duration,
20
- device,
21
- fix_duration,
22
- infer_process,
23
- load_model,
24
- load_vocoder,
25
- mel_spec_type,
26
- nfe_step,
27
- preprocess_ref_audio_text,
28
- remove_silence_for_generated_wav,
29
- speed,
30
- sway_sampling_coef,
31
- target_rms,
32
- )
33
-
34
 
35
  parser = argparse.ArgumentParser(
36
  prog="python3 infer-cli.py",
@@ -41,7 +30,9 @@ parser.add_argument(
41
  "-c",
42
  "--config",
43
  type=str,
44
- default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
45
  help="The configuration file, default see infer/examples/basic/basic.toml",
46
  )
47
 
@@ -188,13 +179,17 @@ model = args.model or config.get("model", "F5TTS_v1_Base")
188
  ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
189
  vocab_file = args.vocab_file or config.get("vocab_file", "")
190
 
191
- ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
192
  ref_text = (
193
  args.ref_text
194
  if args.ref_text is not None
195
  else config.get("ref_text", "Some call me nature, others call me mother nature.")
196
  )
197
- gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
198
  gen_file = args.gen_file or config.get("gen_file", "")
199
 
200
  output_dir = args.output_dir or config.get("output_dir", "tests")
@@ -203,21 +198,29 @@ output_file = args.output_file or config.get(
203
  )
204
 
205
  save_chunk = args.save_chunk or config.get("save_chunk", False)
206
- use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg
207
  if save_chunk and use_legacy_text:
208
  print(
209
  "\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
210
  )
211
 
212
  remove_silence = args.remove_silence or config.get("remove_silence", False)
213
- load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
214
 
215
  vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
216
  target_rms = args.target_rms or config.get("target_rms", target_rms)
217
- cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
218
  nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
219
  cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
220
- sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
221
  speed = args.speed or config.get("speed", speed)
222
  fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
223
  device = args.device or config.get("device", device)
@@ -232,7 +235,9 @@ if "voices" in config:
232
  for voice in config["voices"]:
233
  voice_ref_audio = config["voices"][voice]["ref_audio"]
234
  if "infer/examples/" in voice_ref_audio:
235
- config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
236
 
237
 
238
  # ignore gen_text if gen_file provided
@@ -259,14 +264,18 @@ elif vocoder_name == "bigvgan":
259
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
260
 
261
  vocoder = load_vocoder(
262
- vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
263
  )
264
 
265
 
266
  # load TTS model
267
 
268
  model_cfg = OmegaConf.load(
269
- args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
 
270
  )
271
  model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
272
  model_arc = model_cfg.model.arch
@@ -288,11 +297,18 @@ elif model == "E2TTS_Base":
288
  ckpt_step = 1200000
289
 
290
  if not ckpt_file:
291
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
292
 
293
  print(f"Using {model}...")
294
  ema_model = load_model(
295
- model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
296
  )
297
 
298
 
@@ -309,8 +325,10 @@ def main():
309
  for voice in voices:
310
  print("Voice:", voice)
311
  print("ref_audio ", voices[voice]["ref_audio"])
312
- voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
313
- voices[voice]["ref_audio"], voices[voice]["ref_text"]
314
  )
315
  print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
316
 
@@ -360,7 +378,10 @@ def main():
360
  if use_legacy_text:
361
  gen_text_ = unidecode(gen_text_)
362
  sf.write(
363
- os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
364
  audio_segment,
365
  final_sample_rate,
366
  )
 
14
  from omegaconf import OmegaConf
15
  from unidecode import unidecode
16
 
17
+ from f5_tts.infer.utils_infer import (cfg_strength, cross_fade_duration,
18
+ device, fix_duration, infer_process,
19
+ load_model, load_vocoder, mel_spec_type,
20
+ nfe_step, preprocess_ref_audio_text,
21
+ remove_silence_for_generated_wav, speed,
22
+ sway_sampling_coef, target_rms)
23
 
24
  parser = argparse.ArgumentParser(
25
  prog="python3 infer-cli.py",
 
30
  "-c",
31
  "--config",
32
  type=str,
33
+ default=os.path.join(
34
+ files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"
35
+ ),
36
  help="The configuration file, default see infer/examples/basic/basic.toml",
37
  )
38
 
 
179
  ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
180
  vocab_file = args.vocab_file or config.get("vocab_file", "")
181
 
182
+ ref_audio = args.ref_audio or config.get(
183
+ "ref_audio", "infer/examples/basic/basic_ref_en.wav"
184
+ )
185
  ref_text = (
186
  args.ref_text
187
  if args.ref_text is not None
188
  else config.get("ref_text", "Some call me nature, others call me mother nature.")
189
  )
190
+ gen_text = args.gen_text or config.get(
191
+ "gen_text", "Here we generate something just for test."
192
+ )
193
  gen_file = args.gen_file or config.get("gen_file", "")
194
 
195
  output_dir = args.output_dir or config.get("output_dir", "tests")
 
198
  )
199
 
200
  save_chunk = args.save_chunk or config.get("save_chunk", False)
201
+ use_legacy_text = args.no_legacy_text or config.get(
202
+ "no_legacy_text", False
203
+ ) # no_legacy_text is a store_false arg
204
  if save_chunk and use_legacy_text:
205
  print(
206
  "\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
207
  )
208
 
209
  remove_silence = args.remove_silence or config.get("remove_silence", False)
210
+ load_vocoder_from_local = args.load_vocoder_from_local or config.get(
211
+ "load_vocoder_from_local", False
212
+ )
213
 
214
  vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
215
  target_rms = args.target_rms or config.get("target_rms", target_rms)
216
+ cross_fade_duration = args.cross_fade_duration or config.get(
217
+ "cross_fade_duration", cross_fade_duration
218
+ )
219
  nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
220
  cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
221
+ sway_sampling_coef = args.sway_sampling_coef or config.get(
222
+ "sway_sampling_coef", sway_sampling_coef
223
+ )
224
  speed = args.speed or config.get("speed", speed)
225
  fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
226
  device = args.device or config.get("device", device)
 
235
  for voice in config["voices"]:
236
  voice_ref_audio = config["voices"][voice]["ref_audio"]
237
  if "infer/examples/" in voice_ref_audio:
238
+ config["voices"][voice]["ref_audio"] = str(
239
+ files("f5_tts").joinpath(f"{voice_ref_audio}")
240
+ )
241
 
242
 
243
  # ignore gen_text if gen_file provided
 
264
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
265
 
266
  vocoder = load_vocoder(
267
+ vocoder_name=vocoder_name,
268
+ is_local=load_vocoder_from_local,
269
+ local_path=vocoder_local_path,
270
+ device=device,
271
  )
272
 
273
 
274
  # load TTS model
275
 
276
  model_cfg = OmegaConf.load(
277
+ args.model_cfg
278
+ or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
279
  )
280
  model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
281
  model_arc = model_cfg.model.arch
 
297
  ckpt_step = 1200000
298
 
299
  if not ckpt_file:
300
+ ckpt_file = str(
301
+ cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")
302
+ )
303
 
304
  print(f"Using {model}...")
305
  ema_model = load_model(
306
+ model_cls,
307
+ model_arc,
308
+ ckpt_file,
309
+ mel_spec_type=vocoder_name,
310
+ vocab_file=vocab_file,
311
+ device=device,
312
  )
313
 
314
 
 
325
  for voice in voices:
326
  print("Voice:", voice)
327
  print("ref_audio ", voices[voice]["ref_audio"])
328
+ voices[voice]["ref_audio"], voices[voice]["ref_text"] = (
329
+ preprocess_ref_audio_text(
330
+ voices[voice]["ref_audio"], voices[voice]["ref_text"]
331
+ )
332
  )
333
  print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
334
 
 
378
  if use_legacy_text:
379
  gen_text_ = unidecode(gen_text_)
380
  sf.write(
381
+ os.path.join(
382
+ output_chunk_dir,
383
+ f"{len(generated_audio_segments) - 1}_{gen_text_}.wav",
384
+ ),
385
  audio_segment,
386
  final_sample_rate,
387
  )
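# Sketch of the precedence rule repeated throughout infer_cli.py: CLI flag first, then
# the config file key, then a built-in default (toy config dict instead of the TOML file).
# Note that `or` also falls through on falsy CLI values such as 0, so numeric flags that
# may legitimately be zero would need an explicit `is not None` check instead.
import argparse

config = {"nfe_step": 32}
parser = argparse.ArgumentParser()
parser.add_argument("--nfe_step", type=int, default=None)
args = parser.parse_args([])

nfe_step = args.nfe_step or config.get("nfe_step", 16)
print(nfe_step)  # 32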
f5_tts/infer/infer_gradio.py CHANGED
@@ -19,7 +19,6 @@ import torchaudio
19
  from cached_path import cached_path
20
  from transformers import AutoModelForCausalLM, AutoTokenizer
21
 
22
-
23
  try:
24
  import spaces
25
 
@@ -35,25 +34,21 @@ def gpu_decorator(func):
35
  return func
36
 
37
 
38
- from f5_tts.infer.utils_infer import (
39
- infer_process,
40
- load_model,
41
- load_vocoder,
42
- preprocess_ref_audio_text,
43
- remove_silence_for_generated_wav,
44
- save_spectrogram,
45
- tempfile_kwargs,
46
- )
47
  from f5_tts.model import DiT, UNetT
48
 
49
-
50
  DEFAULT_TTS_MODEL = "F5-TTS_v1"
51
  tts_model_choice = DEFAULT_TTS_MODEL
52
 
53
  DEFAULT_TTS_MODEL_CFG = [
54
  "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
55
  "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
56
- json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
57
  ]
58
 
59
 
@@ -69,8 +64,12 @@ def load_f5tts():
69
 
70
 
71
  def load_e2tts():
72
- ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
73
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
74
  return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
75
 
76
 
@@ -113,7 +112,8 @@ def chat_model_inference(messages, model, tokenizer):
113
  )
114
 
115
  generated_ids = [
116
- output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
 
117
  ]
118
  return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
119
 
@@ -157,7 +157,9 @@ def infer(
157
  gr.Warning("Please enter text to generate or upload a text file.")
158
  return gr.update(), gr.update(), ref_text
159
 
160
 
161
 
162
  if model == DEFAULT_TTS_MODEL:
163
  ema_model = F5TTS_ema_model
@@ -172,7 +174,9 @@ def infer(
172
  global custom_ema_model, pre_custom_path
173
  if pre_custom_path != model[1]:
174
  show_info("Loading Custom TTS model...")
175
 
176
  pre_custom_path = model[1]
177
  ema_model = custom_ema_model
178
 
@@ -202,7 +206,9 @@ def infer(
202
  final_wave = final_wave.squeeze().cpu().numpy()
203
 
204
  # Save the spectrogram
205
 
206
  spectrogram_path = tmp_spectrogram.name
207
  save_spectrogram(combined_spectrogram, spectrogram_path)
208
 
@@ -219,7 +225,9 @@ with gr.Blocks() as app_tts:
219
  max_lines=40,
220
  scale=4,
221
  )
222
 
223
  generate_btn = gr.Button("Synthesize", variant="primary")
224
  with gr.Accordion("Advanced Settings", open=False):
225
  with gr.Row():
@@ -229,7 +237,11 @@ with gr.Blocks() as app_tts:
229
  lines=2,
230
  scale=4,
231
  )
232
- ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
233
  with gr.Row():
234
  randomize_seed = gr.Checkbox(
235
  label="Randomize Seed",
@@ -417,13 +429,25 @@ with gr.Blocks() as app_multistyle:
417
  regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
418
  with gr.Row():
419
  regular_seed_slider = gr.Slider(
420
- show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
421
  )
422
  regular_speed_slider = gr.Slider(
423
- show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
424
  )
425
  with gr.Column(scale=1, min_width=160):
426
- regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
427
 
428
  # Regular speech type (max 100)
429
  max_speech_types = 100
@@ -450,13 +474,25 @@ with gr.Blocks() as app_multistyle:
450
  ref_text_input = gr.Textbox(label="Reference Text", lines=4)
451
  with gr.Row():
452
  seed_input = gr.Slider(
453
- show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random"
454
  )
455
  speed_input = gr.Slider(
456
- show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
457
  )
458
  with gr.Column(scale=1, min_width=160):
459
- ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
460
  speech_type_rows.append(row)
461
  speech_type_names.append(name_input)
462
  speech_type_audios.append(audio_input)
@@ -494,7 +530,9 @@ with gr.Blocks() as app_multistyle:
494
  row_updates[speech_type_count] = gr.update(visible=True)
495
  speech_type_count += 1
496
  else:
497
- gr.Warning("Exhausted maximum number of speech types. Consider restart the app.")
498
  return row_updates
499
 
500
  add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)
@@ -525,10 +563,14 @@ with gr.Blocks() as app_multistyle:
525
  scale=4,
526
  placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
527
  )
528
- gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
529
 
530
  def make_insert_speech_type_fn(index):
531
- def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
532
  current_text = current_text or ""
533
  if not speech_type_name:
534
  gr.Warning("Please enter speech type name before insert.")
@@ -547,7 +589,12 @@ with gr.Blocks() as app_multistyle:
547
  insert_fn = make_insert_speech_type_fn(i)
548
  insert_btn.click(
549
  insert_fn,
550
- inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
551
  outputs=gen_text_input_multistyle,
552
  )
553
 
@@ -567,7 +614,9 @@ with gr.Blocks() as app_multistyle:
567
  )
568
 
569
  # Generate button
570
- generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
571
 
572
  # Output audio
573
  audio_output_multistyle = gr.Audio(label="Synthesized Audio")
@@ -613,7 +662,10 @@ with gr.Blocks() as app_multistyle:
613
  speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
614
  ):
615
  if name_input and audio_input:
616
- speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
617
  else:
618
  speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
619
  ref_text_idx += 1
@@ -635,14 +687,22 @@ with gr.Blocks() as app_multistyle:
635
  if name in speech_types:
636
  current_type_name = name
637
  else:
638
- gr.Warning(f"Type {name} is not available, will use Regular as default.")
639
  current_type_name = "Regular"
640
 
641
  try:
642
  ref_audio = speech_types[current_type_name]["audio"]
643
  except KeyError:
644
- gr.Warning(f"Please provide reference audio for type {current_type_name}.")
645
- return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
646
  ref_text = speech_types[current_type_name].get("ref_text", "")
647
 
648
  if seed_input == -1:
@@ -664,7 +724,9 @@ with gr.Blocks() as app_multistyle:
664
 
665
  generated_audio_segments.append(audio_data)
666
  speech_types[current_type_name]["ref_text"] = ref_text_out
667
- inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
668
 
669
  # Concatenate all audio segments
670
  if generated_audio_segments:
@@ -676,7 +738,11 @@ with gr.Blocks() as app_multistyle:
676
  )
677
  else:
678
  gr.Warning("No audio generated.")
679
- return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
680
 
681
  generate_multistyle_btn.click(
682
  generate_multistyle_speech,
@@ -689,7 +755,9 @@ with gr.Blocks() as app_multistyle:
689
  + [
690
  remove_silence_multistyle,
691
  ],
692
- outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
693
  )
694
 
695
  # Validation function to disable Generate button if speech types are missing
@@ -753,7 +821,9 @@ Have a conversation with an AI using your reference voice!
753
  torch.cuda.empty_cache()
754
 
755
  show_info(f"Loading chat model: {chat_model_name}")
756
- chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
757
  chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
758
  show_info(f"Chat model {chat_model_name} loaded successfully!")
759
 
@@ -769,7 +839,9 @@ Have a conversation with an AI using your reference voice!
769
  info="Enter the name of a HuggingFace chat model",
770
  allow_custom_value=not USING_SPACES,
771
  )
772
- load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES)
773
  chat_interface_container = gr.Column(visible=USING_SPACES)
774
 
775
  chat_model_name_input.change(
@@ -779,7 +851,9 @@ Have a conversation with an AI using your reference voice!
779
  show_progress="hidden",
780
  )
781
  load_chat_model_btn.click(
782
- load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
783
  )
784
 
785
  with chat_interface_container:
@@ -796,7 +870,9 @@ Have a conversation with an AI using your reference voice!
796
  scale=3,
797
  )
798
  ref_text_file_chat = gr.File(
799
- label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
800
  )
801
  with gr.Row():
802
  randomize_seed_chat = gr.Checkbox(
@@ -805,7 +881,9 @@ Have a conversation with an AI using your reference voice!
805
  info="Uncheck to use the seed specified.",
806
  scale=3,
807
  )
808
- seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1)
809
  remove_silence_chat = gr.Checkbox(
810
  label="Remove Silences",
811
  value=True,
@@ -855,13 +933,17 @@ Have a conversation with an AI using your reference voice!
855
  """Generate text response from AI"""
856
 
857
  system_prompt_state = [{"role": "system", "content": system_prompt}]
858
- response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)
859
 
860
  conv_state.append({"role": "assistant", "content": response})
861
  return conv_state
862
 
863
  @gpu_decorator
864
- def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input):
865
  """Generate TTS audio for AI response"""
866
  if not conv_state or not ref_audio:
867
  return None, ref_text, seed_input
@@ -896,7 +978,11 @@ Have a conversation with an AI using your reference voice!
896
  outputs=[ref_text_chat],
897
  )
898
 
899
- for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]:
900
  user_operation(
901
  process_audio_input,
902
  inputs=[chatbot_interface, audio_input_chat, text_input_chat],
@@ -923,7 +1009,11 @@ Have a conversation with an AI using your reference voice!
923
  )
924
 
925
  # Handle clear button or system prompt change and reset conversation
926
- for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]:
927
  user_operation(
928
  clear_conversation,
929
  outputs=[chatbot_interface, audio_output_chat],
@@ -931,13 +1021,15 @@ Have a conversation with an AI using your reference voice!
931
 
932
 
933
  with gr.Blocks() as app_credits:
934
- gr.Markdown("""
 
935
  # Credits
936
 
937
  * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
938
  * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
939
  * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
940
- """)
 
941
 
942
 
943
  with gr.Blocks() as app:
@@ -958,7 +1050,9 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
958
  """
959
  )
960
 
961
- last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
962
 
963
  def load_last_used_custom():
964
  try:
@@ -974,8 +1068,15 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
974
  def switch_tts_model(new_choice):
975
  global tts_model_choice
976
  if new_choice == "Custom": # override in case webpage is refreshed
977
- custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
978
- tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
979
  return (
980
  gr.update(visible=True, value=custom_ckpt_path),
981
  gr.update(visible=True, value=custom_vocab_path),
@@ -983,22 +1084,42 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
983
  )
984
  else:
985
  tts_model_choice = new_choice
986
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
987
 
988
  def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
989
  global tts_model_choice
990
- tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
991
  with open(last_used_custom, "w", encoding="utf-8") as f:
992
- f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
993
 
994
  with gr.Row():
995
  if not USING_SPACES:
996
  choose_tts_model = gr.Radio(
997
- choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
998
  )
999
  else:
1000
  choose_tts_model = gr.Radio(
1001
- choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
1002
  )
1003
  custom_ckpt_path = gr.Dropdown(
1004
  choices=[DEFAULT_TTS_MODEL_CFG[0]],
 
19
  from cached_path import cached_path
20
  from transformers import AutoModelForCausalLM, AutoTokenizer
21
 
 
22
  try:
23
  import spaces
24
 
 
34
  return func
35
 
36
 
37
+ from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
38
+ preprocess_ref_audio_text,
39
+ remove_silence_for_generated_wav,
40
+ save_spectrogram, tempfile_kwargs)
41
  from f5_tts.model import DiT, UNetT
42
 
 
43
  DEFAULT_TTS_MODEL = "F5-TTS_v1"
44
  tts_model_choice = DEFAULT_TTS_MODEL
45
 
46
  DEFAULT_TTS_MODEL_CFG = [
47
  "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
48
  "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
49
+ json.dumps(
50
+ dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
51
+ ),
52
  ]
53
 
54
 
 
64
 
65
 
66
  def load_e2tts():
67
+ ckpt_path = str(
68
+ cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")
69
+ )
70
+ E2TTS_model_cfg = dict(
71
+ dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1
72
+ )
73
  return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
74
 
75
 
 
112
  )
113
 
114
  generated_ids = [
115
+ output_ids[len(input_ids) :]
116
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
117
  ]
118
  return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
119
 
 
157
  gr.Warning("Please enter text to generate or upload a text file.")
158
  return gr.update(), gr.update(), ref_text
159
 
160
+ ref_audio, ref_text = preprocess_ref_audio_text(
161
+ ref_audio_orig, ref_text, show_info=show_info
162
+ )
163
 
164
  if model == DEFAULT_TTS_MODEL:
165
  ema_model = F5TTS_ema_model
 
174
  global custom_ema_model, pre_custom_path
175
  if pre_custom_path != model[1]:
176
  show_info("Loading Custom TTS model...")
177
+ custom_ema_model = load_custom(
178
+ model[1], vocab_path=model[2], model_cfg=model[3]
179
+ )
180
  pre_custom_path = model[1]
181
  ema_model = custom_ema_model
182
 
 
206
  final_wave = final_wave.squeeze().cpu().numpy()
207
 
208
  # Save the spectrogram
209
+ with tempfile.NamedTemporaryFile(
210
+ suffix=".png", **tempfile_kwargs
211
+ ) as tmp_spectrogram:
212
  spectrogram_path = tmp_spectrogram.name
213
  save_spectrogram(combined_spectrogram, spectrogram_path)
214
 
 
225
  max_lines=40,
226
  scale=4,
227
  )
228
+ gen_text_file = gr.File(
229
+ label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1
230
+ )
231
  generate_btn = gr.Button("Synthesize", variant="primary")
232
  with gr.Accordion("Advanced Settings", open=False):
233
  with gr.Row():
 
237
  lines=2,
238
  scale=4,
239
  )
240
+ ref_text_file = gr.File(
241
+ label="Load Reference Text from File (.txt)",
242
+ file_types=[".txt"],
243
+ scale=1,
244
+ )
245
  with gr.Row():
246
  randomize_seed = gr.Checkbox(
247
  label="Randomize Seed",
 
429
  regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
430
  with gr.Row():
431
  regular_seed_slider = gr.Slider(
432
+ show_label=False,
433
+ minimum=-1,
434
+ maximum=999,
435
+ value=-1,
436
+ step=1,
437
+ info="Seed, -1 for random",
438
  )
439
  regular_speed_slider = gr.Slider(
440
+ show_label=False,
441
+ minimum=0.3,
442
+ maximum=2.0,
443
+ value=1.0,
444
+ step=0.1,
445
+ info="Adjust the speed",
446
  )
447
  with gr.Column(scale=1, min_width=160):
448
+ regular_ref_text_file = gr.File(
449
+ label="Load Reference Text from File (.txt)", file_types=[".txt"]
450
+ )
451
 
452
  # Regular speech type (max 100)
453
  max_speech_types = 100
 
474
  ref_text_input = gr.Textbox(label="Reference Text", lines=4)
475
  with gr.Row():
476
  seed_input = gr.Slider(
477
+ show_label=False,
478
+ minimum=-1,
479
+ maximum=999,
480
+ value=-1,
481
+ step=1,
482
+ info="Seed. -1 for random",
483
  )
484
  speed_input = gr.Slider(
485
+ show_label=False,
486
+ minimum=0.3,
487
+ maximum=2.0,
488
+ value=1.0,
489
+ step=0.1,
490
+ info="Adjust the speed",
491
  )
492
  with gr.Column(scale=1, min_width=160):
493
+ ref_text_file_input = gr.File(
494
+ label="Load Reference Text from File (.txt)", file_types=[".txt"]
495
+ )
496
  speech_type_rows.append(row)
497
  speech_type_names.append(name_input)
498
  speech_type_audios.append(audio_input)
 
530
  row_updates[speech_type_count] = gr.update(visible=True)
531
  speech_type_count += 1
532
  else:
533
+ gr.Warning(
534
+ "Exhausted maximum number of speech types. Consider restart the app."
535
+ )
536
  return row_updates
537
 
538
  add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)
 
563
  scale=4,
564
  placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
565
  )
566
+ gen_text_file_multistyle = gr.File(
567
+ label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1
568
+ )
569
 
570
  def make_insert_speech_type_fn(index):
571
+ def insert_speech_type_fn(
572
+ current_text, speech_type_name, speech_type_seed, speech_type_speed
573
+ ):
574
  current_text = current_text or ""
575
  if not speech_type_name:
576
  gr.Warning("Please enter speech type name before insert.")
 
589
  insert_fn = make_insert_speech_type_fn(i)
590
  insert_btn.click(
591
  insert_fn,
592
+ inputs=[
593
+ gen_text_input_multistyle,
594
+ speech_type_names[i],
595
+ speech_type_seeds[i],
596
+ speech_type_speeds[i],
597
+ ],
598
  outputs=gen_text_input_multistyle,
599
  )
600
 
 
614
  )
615
 
616
  # Generate button
617
+ generate_multistyle_btn = gr.Button(
618
+ "Generate Multi-Style Speech", variant="primary"
619
+ )
620
 
621
  # Output audio
622
  audio_output_multistyle = gr.Audio(label="Synthesized Audio")
 
662
  speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
663
  ):
664
  if name_input and audio_input:
665
+ speech_types[name_input] = {
666
+ "audio": audio_input,
667
+ "ref_text": ref_text_input,
668
+ }
669
  else:
670
  speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
671
  ref_text_idx += 1
 
687
  if name in speech_types:
688
  current_type_name = name
689
  else:
690
+ gr.Warning(
691
+ f"Type {name} is not available, will use Regular as default."
692
+ )
693
  current_type_name = "Regular"
694
 
695
  try:
696
  ref_audio = speech_types[current_type_name]["audio"]
697
  except KeyError:
698
+ gr.Warning(
699
+ f"Please provide reference audio for type {current_type_name}."
700
+ )
701
+ return (
702
+ [None]
703
+ + [speech_types[name]["ref_text"] for name in speech_types]
704
+ + [None]
705
+ )
706
  ref_text = speech_types[current_type_name].get("ref_text", "")
707
 
708
  if seed_input == -1:
 
724
 
725
  generated_audio_segments.append(audio_data)
726
  speech_types[current_type_name]["ref_text"] = ref_text_out
727
+ inference_meta_data += (
728
+ json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
729
+ )
730
 
731
  # Concatenate all audio segments
732
  if generated_audio_segments:
 
738
  )
739
  else:
740
  gr.Warning("No audio generated.")
741
+ return (
742
+ [None]
743
+ + [speech_types[name]["ref_text"] for name in speech_types]
744
+ + [None]
745
+ )
746
 
747
  generate_multistyle_btn.click(
748
  generate_multistyle_speech,
 
755
  + [
756
  remove_silence_multistyle,
757
  ],
758
+ outputs=[audio_output_multistyle]
759
+ + speech_type_ref_texts
760
+ + [cherrypick_interface_multistyle],
761
  )
762
 
763
  # Validation function to disable Generate button if speech types are missing
 
821
  torch.cuda.empty_cache()
822
 
823
  show_info(f"Loading chat model: {chat_model_name}")
824
+ chat_model_state = AutoModelForCausalLM.from_pretrained(
825
+ chat_model_name, torch_dtype="auto", device_map="auto"
826
+ )
827
  chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
828
  show_info(f"Chat model {chat_model_name} loaded successfully!")
829
 
 
839
  info="Enter the name of a HuggingFace chat model",
840
  allow_custom_value=not USING_SPACES,
841
  )
842
+ load_chat_model_btn = gr.Button(
843
+ "Load Chat Model", variant="primary", visible=not USING_SPACES
844
+ )
845
  chat_interface_container = gr.Column(visible=USING_SPACES)
846
 
847
  chat_model_name_input.change(
 
851
  show_progress="hidden",
852
  )
853
  load_chat_model_btn.click(
854
+ load_chat_model,
855
+ inputs=[chat_model_name_input],
856
+ outputs=[load_chat_model_btn, chat_interface_container],
857
  )
858
 
859
  with chat_interface_container:
 
870
  scale=3,
871
  )
872
  ref_text_file_chat = gr.File(
873
+ label="Load Reference Text from File (.txt)",
874
+ file_types=[".txt"],
875
+ scale=1,
876
  )
877
  with gr.Row():
878
  randomize_seed_chat = gr.Checkbox(
 
881
  info="Uncheck to use the seed specified.",
882
  scale=3,
883
  )
884
+ seed_input_chat = gr.Number(
885
+ show_label=False, value=0, precision=0, scale=1
886
+ )
887
  remove_silence_chat = gr.Checkbox(
888
  label="Remove Silences",
889
  value=True,
 
933
  """Generate text response from AI"""
934
 
935
  system_prompt_state = [{"role": "system", "content": system_prompt}]
936
+ response = chat_model_inference(
937
+ system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state
938
+ )
939
 
940
  conv_state.append({"role": "assistant", "content": response})
941
  return conv_state
942
 
943
  @gpu_decorator
944
+ def generate_audio_response(
945
+ conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input
946
+ ):
947
  """Generate TTS audio for AI response"""
948
  if not conv_state or not ref_audio:
949
  return None, ref_text, seed_input
 
978
  outputs=[ref_text_chat],
979
  )
980
 
981
+ for user_operation in [
982
+ audio_input_chat.stop_recording,
983
+ text_input_chat.submit,
984
+ send_btn_chat.click,
985
+ ]:
986
  user_operation(
987
  process_audio_input,
988
  inputs=[chatbot_interface, audio_input_chat, text_input_chat],
 
1009
  )
1010
 
1011
  # Handle clear button or system prompt change and reset conversation
1012
+ for user_operation in [
1013
+ clear_btn_chat.click,
1014
+ system_prompt_chat.change,
1015
+ chatbot_interface.clear,
1016
+ ]:
1017
  user_operation(
1018
  clear_conversation,
1019
  outputs=[chatbot_interface, audio_output_chat],
 
1021
 
1022
 
1023
  with gr.Blocks() as app_credits:
1024
+ gr.Markdown(
1025
+ """
1026
  # Credits
1027
 
1028
  * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
1029
  * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
1030
  * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
1031
+ """
1032
+ )
1033
 
1034
 
1035
  with gr.Blocks() as app:
 
1050
  """
1051
  )
1052
 
1053
+ last_used_custom = files("f5_tts").joinpath(
1054
+ "infer/.cache/last_used_custom_model_info_v1.txt"
1055
+ )
1056
 
1057
  def load_last_used_custom():
1058
  try:
 
1068
  def switch_tts_model(new_choice):
1069
  global tts_model_choice
1070
  if new_choice == "Custom": # override in case webpage is refreshed
1071
+ custom_ckpt_path, custom_vocab_path, custom_model_cfg = (
1072
+ load_last_used_custom()
1073
+ )
1074
+ tts_model_choice = (
1075
+ "Custom",
1076
+ custom_ckpt_path,
1077
+ custom_vocab_path,
1078
+ custom_model_cfg,
1079
+ )
1080
  return (
1081
  gr.update(visible=True, value=custom_ckpt_path),
1082
  gr.update(visible=True, value=custom_vocab_path),
 
1084
  )
1085
  else:
1086
  tts_model_choice = new_choice
1087
+ return (
1088
+ gr.update(visible=False),
1089
+ gr.update(visible=False),
1090
+ gr.update(visible=False),
1091
+ )
1092
 
1093
  def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
1094
  global tts_model_choice
1095
+ tts_model_choice = (
1096
+ "Custom",
1097
+ custom_ckpt_path,
1098
+ custom_vocab_path,
1099
+ custom_model_cfg,
1100
+ )
1101
  with open(last_used_custom, "w", encoding="utf-8") as f:
1102
+ f.write(
1103
+ custom_ckpt_path
1104
+ + "\n"
1105
+ + custom_vocab_path
1106
+ + "\n"
1107
+ + custom_model_cfg
1108
+ + "\n"
1109
+ )
1110
 
1111
  with gr.Row():
1112
  if not USING_SPACES:
1113
  choose_tts_model = gr.Radio(
1114
+ choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"],
1115
+ label="Choose TTS Model",
1116
+ value=DEFAULT_TTS_MODEL,
1117
  )
1118
  else:
1119
  choose_tts_model = gr.Radio(
1120
+ choices=[DEFAULT_TTS_MODEL, "E2-TTS"],
1121
+ label="Choose TTS Model",
1122
+ value=DEFAULT_TTS_MODEL,
1123
  )
1124
  custom_ckpt_path = gr.Dropdown(
1125
  choices=[DEFAULT_TTS_MODEL_CFG[0]],
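# Tiny sketch of the prompt-token trimming done in chat_model_inference: keep only the
# newly generated ids for each sequence in the batch (toy tensors, no model involved).
import torch

input_ids = [torch.tensor([1, 2, 3]), torch.tensor([1, 2])]
generated = [torch.tensor([1, 2, 3, 7, 8]), torch.tensor([1, 2, 9])]

new_tokens = [out[len(inp):] for inp, out in zip(input_ids, generated)]
print([t.tolist() for t in new_tokens])  # [[7, 8], [9]]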
f5_tts/infer/speech_edit.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
 
3
-
4
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
5
 
6
  from importlib.resources import files
@@ -12,19 +11,19 @@ from cached_path import cached_path
12
  from hydra.utils import get_class
13
  from omegaconf import OmegaConf
14
 
15
- from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
 
16
  from f5_tts.model import CFM
17
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
18
 
19
-
20
  device = (
21
  "cuda"
22
  if torch.cuda.is_available()
23
- else "xpu"
24
- if torch.xpu.is_available()
25
- else "mps"
26
- if torch.backends.mps.is_available()
27
- else "cpu"
28
  )
29
 
30
 
@@ -59,7 +58,9 @@ n_fft = model_cfg.model.mel_spec.n_fft
59
 
60
 
61
  # ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
62
- ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
 
 
63
  output_dir = "tests"
64
 
65
 
@@ -103,14 +104,18 @@ if mel_spec_type == "vocos":
103
  vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
104
  elif mel_spec_type == "bigvgan":
105
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
106
- vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
 
 
107
 
108
  # Tokenizer
109
  vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
110
 
111
  # Model
112
  model = CFM(
113
- transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
 
 
114
  mel_spec_kwargs=dict(
115
  n_fft=n_fft,
116
  hop_length=hop_length,
@@ -146,7 +151,14 @@ for part in parts_to_edit:
146
  part_dur = end - start if fix_duration is None else fix_duration.pop(0)
147
  part_dur = part_dur * target_sample_rate
148
  start = start * target_sample_rate
149
- audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
 
 
 
 
 
 
 
150
  edit_mask = torch.cat(
151
  (
152
  edit_mask,
@@ -157,7 +169,9 @@ for part in parts_to_edit:
157
  )
158
  offset = end * target_sample_rate
159
  audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
160
- edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
 
 
161
  audio = audio.to(device)
162
  edit_mask = edit_mask.to(device)
163
 
@@ -201,5 +215,7 @@ with torch.inference_mode():
201
  generated_wave = generated_wave * rms / target_rms
202
 
203
  save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
204
- torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
 
 
205
  print(f"Generated wav: {generated_wave.shape}")
 
1
  import os
2
 
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
4
 
5
  from importlib.resources import files
 
11
  from hydra.utils import get_class
12
  from omegaconf import OmegaConf
13
 
14
+ from f5_tts.infer.utils_infer import (load_checkpoint, load_vocoder,
15
+ save_spectrogram)
16
  from f5_tts.model import CFM
17
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
18
 
 
19
  device = (
20
  "cuda"
21
  if torch.cuda.is_available()
22
+ else (
23
+ "xpu"
24
+ if torch.xpu.is_available()
25
+ else "mps" if torch.backends.mps.is_available() else "cpu"
26
+ )
27
  )
28
 
29
 
 
58
 
59
 
60
  # ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
61
+ ckpt_path = str(
62
+ cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors")
63
+ )
64
  output_dir = "tests"
65
 
66
 
 
104
  vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
105
  elif mel_spec_type == "bigvgan":
106
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
107
+ vocoder = load_vocoder(
108
+ vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path
109
+ )
110
 
111
  # Tokenizer
112
  vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
113
 
114
  # Model
115
  model = CFM(
116
+ transformer=model_cls(
117
+ **model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels
118
+ ),
119
  mel_spec_kwargs=dict(
120
  n_fft=n_fft,
121
  hop_length=hop_length,
 
151
  part_dur = end - start if fix_duration is None else fix_duration.pop(0)
152
  part_dur = part_dur * target_sample_rate
153
  start = start * target_sample_rate
154
+ audio_ = torch.cat(
155
+ (
156
+ audio_,
157
+ audio[:, round(offset) : round(start)],
158
+ torch.zeros(1, round(part_dur)),
159
+ ),
160
+ dim=-1,
161
+ )
162
  edit_mask = torch.cat(
163
  (
164
  edit_mask,
 
169
  )
170
  offset = end * target_sample_rate
171
  audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
172
+ edit_mask = F.pad(
173
+ edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True
174
+ )
175
  audio = audio.to(device)
176
  edit_mask = edit_mask.to(device)
177
 
 
215
  generated_wave = generated_wave * rms / target_rms
216
 
217
  save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
218
+ torchaudio.save(
219
+ f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate
220
+ )
221
  print(f"Generated wav: {generated_wave.shape}")
f5_tts/infer/utils_infer.py CHANGED
@@ -4,9 +4,10 @@ import os
4
  import sys
5
  from concurrent.futures import ThreadPoolExecutor
6
 
7
-
8
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
9
- sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
 
 
10
 
11
  import hashlib
12
  import re
@@ -15,7 +16,6 @@ from importlib.resources import files
15
 
16
  import matplotlib
17
 
18
-
19
  matplotlib.use("Agg")
20
 
21
  import matplotlib.pylab as plt
@@ -31,21 +31,22 @@ from vocos import Vocos
31
  from f5_tts.model import CFM
32
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
33
 
34
-
35
  _ref_audio_cache = {}
36
  _ref_text_cache = {}
37
 
38
  device = (
39
  "cuda"
40
  if torch.cuda.is_available()
41
- else "xpu"
42
- if torch.xpu.is_available()
43
- else "mps"
44
- if torch.backends.mps.is_available()
45
- else "cpu"
46
  )
47
 
48
- tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
 
 
49
 
50
  # -----------------------------------------
51
 
@@ -87,12 +88,23 @@ def chunk_text(text, max_chars=135):
87
  sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
88
 
89
  for sentence in sentences:
90
- if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
91
- current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
 
 
 
 
 
 
 
92
  else:
93
  if current_chunk:
94
  chunks.append(current_chunk.strip())
95
- current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
 
 
 
 
96
 
97
  if current_chunk:
98
  chunks.append(current_chunk.strip())
@@ -101,7 +113,13 @@ def chunk_text(text, max_chars=135):
101
 
102
 
103
  # load vocoder
104
- def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
 
 
 
 
 
 
105
  if vocoder_name == "vocos":
106
  # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
107
  if is_local:
@@ -111,8 +129,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
111
  else:
112
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
113
  repo_id = "charactr/vocos-mel-24khz"
114
- config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
115
- model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
 
 
 
 
116
  vocoder = Vocos.from_hparams(config_path)
117
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
118
  from vocos.feature_extractors import EncodecFeatures
@@ -129,13 +151,17 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
129
  try:
130
  from third_party.BigVGAN import bigvgan
131
  except ImportError:
132
- print("You need to follow the README to init submodule and change the BigVGAN source code.")
 
 
133
  if is_local:
134
  # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
135
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
136
  else:
137
  vocoder = bigvgan.BigVGAN.from_pretrained(
138
- "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
 
 
139
  )
140
 
141
  vocoder.remove_weight_norm()
@@ -177,7 +203,11 @@ def transcribe(ref_audio, language=None):
177
  ref_audio,
178
  chunk_length_s=30,
179
  batch_size=128,
180
- generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
 
 
 
 
181
  return_timestamps=False,
182
  )["text"].strip()
183
 
@@ -214,7 +244,10 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
214
  }
215
 
216
  # patch for backward compatibility, 305e3ea
217
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
 
 
 
218
  if key in checkpoint["model_state_dict"]:
219
  del checkpoint["model_state_dict"][key]
220
 
@@ -253,7 +286,9 @@ def load_model(
253
 
254
  vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
255
  model = CFM(
256
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
 
 
257
  mel_spec_kwargs=dict(
258
  n_fft=n_fft,
259
  hop_length=hop_length,
@@ -276,7 +311,9 @@ def load_model(
276
 
277
  def remove_silence_edges(audio, silence_threshold=-42):
278
  # Remove silence from the start
279
- non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
 
 
280
  audio = audio[non_silent_start_idx:]
281
 
282
  # Remove silence from the end
@@ -315,11 +352,18 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
315
 
316
  # 1. try to find long silence for clipping
317
  non_silent_segs = silence.split_on_silence(
318
- aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
 
 
 
 
319
  )
320
  non_silent_wave = AudioSegment.silent(duration=0)
321
  for non_silent_seg in non_silent_segs:
322
- if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
 
 
 
323
  show_info("Audio is over 12s, clipping short. (1)")
324
  break
325
  non_silent_wave += non_silent_seg
@@ -327,11 +371,18 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
327
  # 2. try to find short silence for clipping if 1. failed
328
  if len(non_silent_wave) > 12000:
329
  non_silent_segs = silence.split_on_silence(
330
- aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
 
 
 
 
331
  )
332
  non_silent_wave = AudioSegment.silent(duration=0)
333
  for non_silent_seg in non_silent_segs:
334
- if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
 
 
 
335
  show_info("Audio is over 12s, clipping short. (2)")
336
  break
337
  non_silent_wave += non_silent_seg
@@ -399,7 +450,12 @@ def infer_process(
399
  ):
400
  # Split the input text into batches
401
  audio, sr = torchaudio.load(ref_audio)
402
- max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
 
 
 
 
 
403
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
404
  for i, gen_text in enumerate(gen_text_batches):
405
  print(f"gen_text {i}", gen_text)
@@ -483,7 +539,9 @@ def infer_batch_process(
483
  # Calculate duration
484
  ref_text_len = len(ref_text.encode("utf-8"))
485
  gen_text_len = len(gen_text.encode("utf-8"))
486
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
 
 
487
 
488
  # inference
489
  with torch.inference_mode():
@@ -519,12 +577,19 @@ def infer_batch_process(
519
  yield generated_wave, generated_cpu
520
 
521
  if streaming:
522
- for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
 
 
 
 
523
  for chunk in process_batch(gen_text):
524
  yield chunk
525
  else:
526
  with ThreadPoolExecutor() as executor:
527
- futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
 
 
 
528
  for future in progress.tqdm(futures) if progress is not None else futures:
529
  result = future.result()
530
  if result:
@@ -545,7 +610,9 @@ def infer_batch_process(
545
 
546
  # Calculate cross-fade samples, ensuring it does not exceed wave lengths
547
  cross_fade_samples = int(cross_fade_duration * target_sample_rate)
548
- cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
 
 
549
 
550
  if cross_fade_samples <= 0:
551
  # No overlap possible, concatenate
@@ -561,11 +628,17 @@ def infer_batch_process(
561
  fade_in = np.linspace(0, 1, cross_fade_samples)
562
 
563
  # Cross-faded overlap
564
- cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
 
 
565
 
566
  # Combine
567
  new_wave = np.concatenate(
568
- [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
 
 
 
 
569
  )
570
 
571
  final_wave = new_wave
 
4
  import sys
5
  from concurrent.futures import ThreadPoolExecutor
6
 
 
7
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
8
+ sys.path.append(
9
+ f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/"
10
+ )
11
 
12
  import hashlib
13
  import re
 
16
 
17
  import matplotlib
18
 
 
19
  matplotlib.use("Agg")
20
 
21
  import matplotlib.pylab as plt
 
31
  from f5_tts.model import CFM
32
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
33
 
 
34
  _ref_audio_cache = {}
35
  _ref_text_cache = {}
36
 
37
  device = (
38
  "cuda"
39
  if torch.cuda.is_available()
40
+ else (
41
+ "xpu"
42
+ if torch.xpu.is_available()
43
+ else "mps" if torch.backends.mps.is_available() else "cpu"
44
+ )
45
  )
46
 
47
+ tempfile_kwargs = (
48
+ {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
49
+ )
50
 
51
  # -----------------------------------------
52
 
 
88
  sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
89
 
90
  for sentence in sentences:
91
+ if (
92
+ len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8"))
93
+ <= max_chars
94
+ ):
95
+ current_chunk += (
96
+ sentence + " "
97
+ if sentence and len(sentence[-1].encode("utf-8")) == 1
98
+ else sentence
99
+ )
100
  else:
101
  if current_chunk:
102
  chunks.append(current_chunk.strip())
103
+ current_chunk = (
104
+ sentence + " "
105
+ if sentence and len(sentence[-1].encode("utf-8")) == 1
106
+ else sentence
107
+ )
108
 
109
  if current_chunk:
110
  chunks.append(current_chunk.strip())
 
113
 
114
 
115
  # load vocoder
116
+ def load_vocoder(
117
+ vocoder_name="vocos",
118
+ is_local=False,
119
+ local_path="",
120
+ device=device,
121
+ hf_cache_dir=None,
122
+ ):
123
  if vocoder_name == "vocos":
124
  # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
125
  if is_local:
 
129
  else:
130
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
131
  repo_id = "charactr/vocos-mel-24khz"
132
+ config_path = hf_hub_download(
133
+ repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml"
134
+ )
135
+ model_path = hf_hub_download(
136
+ repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin"
137
+ )
138
  vocoder = Vocos.from_hparams(config_path)
139
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
140
  from vocos.feature_extractors import EncodecFeatures
 
151
  try:
152
  from third_party.BigVGAN import bigvgan
153
  except ImportError:
154
+ print(
155
+ "You need to follow the README to init submodule and change the BigVGAN source code."
156
+ )
157
  if is_local:
158
  # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
159
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
160
  else:
161
  vocoder = bigvgan.BigVGAN.from_pretrained(
162
+ "nvidia/bigvgan_v2_24khz_100band_256x",
163
+ use_cuda_kernel=False,
164
+ cache_dir=hf_cache_dir,
165
  )
166
 
167
  vocoder.remove_weight_norm()
 
203
  ref_audio,
204
  chunk_length_s=30,
205
  batch_size=128,
206
+ generate_kwargs=(
207
+ {"task": "transcribe", "language": language}
208
+ if language
209
+ else {"task": "transcribe"}
210
+ ),
211
  return_timestamps=False,
212
  )["text"].strip()
213
 
 
244
  }
245
 
246
  # patch for backward compatibility, 305e3ea
247
+ for key in [
248
+ "mel_spec.mel_stft.mel_scale.fb",
249
+ "mel_spec.mel_stft.spectrogram.window",
250
+ ]:
251
  if key in checkpoint["model_state_dict"]:
252
  del checkpoint["model_state_dict"][key]
253
 
 
286
 
287
  vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
288
  model = CFM(
289
+ transformer=model_cls(
290
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
291
+ ),
292
  mel_spec_kwargs=dict(
293
  n_fft=n_fft,
294
  hop_length=hop_length,
 
311
 
312
  def remove_silence_edges(audio, silence_threshold=-42):
313
  # Remove silence from the start
314
+ non_silent_start_idx = silence.detect_leading_silence(
315
+ audio, silence_threshold=silence_threshold
316
+ )
317
  audio = audio[non_silent_start_idx:]
318
 
319
  # Remove silence from the end
 
352
 
353
  # 1. try to find long silence for clipping
354
  non_silent_segs = silence.split_on_silence(
355
+ aseg,
356
+ min_silence_len=1000,
357
+ silence_thresh=-50,
358
+ keep_silence=1000,
359
+ seek_step=10,
360
  )
361
  non_silent_wave = AudioSegment.silent(duration=0)
362
  for non_silent_seg in non_silent_segs:
363
+ if (
364
+ len(non_silent_wave) > 6000
365
+ and len(non_silent_wave + non_silent_seg) > 12000
366
+ ):
367
  show_info("Audio is over 12s, clipping short. (1)")
368
  break
369
  non_silent_wave += non_silent_seg
 
371
  # 2. try to find short silence for clipping if 1. failed
372
  if len(non_silent_wave) > 12000:
373
  non_silent_segs = silence.split_on_silence(
374
+ aseg,
375
+ min_silence_len=100,
376
+ silence_thresh=-40,
377
+ keep_silence=1000,
378
+ seek_step=10,
379
  )
380
  non_silent_wave = AudioSegment.silent(duration=0)
381
  for non_silent_seg in non_silent_segs:
382
+ if (
383
+ len(non_silent_wave) > 6000
384
+ and len(non_silent_wave + non_silent_seg) > 12000
385
+ ):
386
  show_info("Audio is over 12s, clipping short. (2)")
387
  break
388
  non_silent_wave += non_silent_seg
 
450
  ):
451
  # Split the input text into batches
452
  audio, sr = torchaudio.load(ref_audio)
453
+ max_chars = int(
454
+ len(ref_text.encode("utf-8"))
455
+ / (audio.shape[-1] / sr)
456
+ * (22 - audio.shape[-1] / sr)
457
+ * speed
458
+ )
459
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
460
  for i, gen_text in enumerate(gen_text_batches):
461
  print(f"gen_text {i}", gen_text)
 
539
  # Calculate duration
540
  ref_text_len = len(ref_text.encode("utf-8"))
541
  gen_text_len = len(gen_text.encode("utf-8"))
542
+ duration = ref_audio_len + int(
543
+ ref_audio_len / ref_text_len * gen_text_len / local_speed
544
+ )
545
 
546
  # inference
547
  with torch.inference_mode():
 
577
  yield generated_wave, generated_cpu
578
 
579
  if streaming:
580
+ for gen_text in (
581
+ progress.tqdm(gen_text_batches)
582
+ if progress is not None
583
+ else gen_text_batches
584
+ ):
585
  for chunk in process_batch(gen_text):
586
  yield chunk
587
  else:
588
  with ThreadPoolExecutor() as executor:
589
+ futures = [
590
+ executor.submit(process_batch, gen_text)
591
+ for gen_text in gen_text_batches
592
+ ]
593
  for future in progress.tqdm(futures) if progress is not None else futures:
594
  result = future.result()
595
  if result:
 
610
 
611
  # Calculate cross-fade samples, ensuring it does not exceed wave lengths
612
  cross_fade_samples = int(cross_fade_duration * target_sample_rate)
613
+ cross_fade_samples = min(
614
+ cross_fade_samples, len(prev_wave), len(next_wave)
615
+ )
616
 
617
  if cross_fade_samples <= 0:
618
  # No overlap possible, concatenate
 
628
  fade_in = np.linspace(0, 1, cross_fade_samples)
629
 
630
  # Cross-faded overlap
631
+ cross_faded_overlap = (
632
+ prev_overlap * fade_out + next_overlap * fade_in
633
+ )
634
 
635
  # Combine
636
  new_wave = np.concatenate(
637
+ [
638
+ prev_wave[:-cross_fade_samples],
639
+ cross_faded_overlap,
640
+ next_wave[cross_fade_samples:],
641
+ ]
642
  )
643
 
644
  final_wave = new_wave
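
A condensed, standalone sketch of the linear cross-fade logic above (illustrative only; the in-tree version also handles the streaming and progress paths):

    import numpy as np

    def crossfade(prev_wave, next_wave, n):
        n = min(n, len(prev_wave), len(next_wave))
        if n <= 0:
            return np.concatenate([prev_wave, next_wave])  # no overlap possible
        fade_out = np.linspace(1.0, 0.0, n)
        fade_in = np.linspace(0.0, 1.0, n)
        overlap = prev_wave[-n:] * fade_out + next_wave[:n] * fade_in
        return np.concatenate([prev_wave[:-n], overlap, next_wave[n:]])
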
f5_tts/model/__init__.py CHANGED
@@ -1,10 +1,7 @@
1
- from f5_tts.model.cfm import CFM
2
-
3
- from f5_tts.model.backbones.unett import UNetT
4
  from f5_tts.model.backbones.dit import DiT
5
  from f5_tts.model.backbones.mmdit import MMDiT
6
-
 
7
  from f5_tts.model.trainer import Trainer
8
 
9
-
10
  __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
 
 
 
 
1
  from f5_tts.model.backbones.dit import DiT
2
  from f5_tts.model.backbones.mmdit import MMDiT
3
+ from f5_tts.model.backbones.unett import UNetT
4
+ from f5_tts.model.cfm import CFM
5
  from f5_tts.model.trainer import Trainer
6
 
 
7
  __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
f5_tts/model/backbones/dit.py CHANGED
@@ -10,21 +10,14 @@ d - dimension
10
  from __future__ import annotations
11
 
12
  import torch
13
- from torch import nn
14
  import torch.nn.functional as F
15
-
16
  from x_transformers.x_transformers import RotaryEmbedding
17
 
18
- from f5_tts.model.modules import (
19
- TimestepEmbedding,
20
- ConvNeXtV2Block,
21
- ConvPositionEmbedding,
22
- DiTBlock,
23
- AdaLayerNormZero_Final,
24
- precompute_freqs_cis,
25
- get_pos_embed_indices,
26
- )
27
-
28
 
29
  # Text embedding
30
 
@@ -32,34 +25,49 @@ from f5_tts.model.modules import (
32
  class TextEmbedding(nn.Module):
33
  def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
- self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
 
 
36
 
37
  if conv_layers > 0:
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
- self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
 
 
 
 
41
  self.text_blocks = nn.Sequential(
42
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
 
 
 
43
  )
44
  else:
45
  self.extra_modeling = False
46
 
47
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
48
- text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49
- text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
 
 
 
 
50
  batch, text_len = text.shape[0], text.shape[1]
51
  text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
54
  text = torch.zeros_like(text)
55
-
56
  text = self.text_embed(text) # b n -> b n d
57
 
58
  # possible extra modeling
59
  if self.extra_modeling:
60
  # sinus pos emb
61
  batch_start = torch.zeros((batch,), dtype=torch.long)
62
- pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
 
 
63
  text_pos_embed = self.freqs_cis[pos_idx]
64
  text = text + text_pos_embed
65
 
@@ -78,7 +86,13 @@ class InputEmbedding(nn.Module):
78
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
  self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
80
 
81
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
 
 
 
 
 
 
82
  if drop_audio_cond: # cfg for cond audio
83
  cond = torch.zeros_like(cond)
84
 
@@ -114,17 +128,23 @@ class DiT(nn.Module):
114
  if second_time:
115
  self.time_embed2 = TimestepEmbedding(dim)
116
  # Zero-init the weights and biases of the first and last Linear layers in time_mlp
117
- nn.init.zeros_(self.time_embed2.time_mlp[0].weight) # First Linear layer weights
118
- nn.init.zeros_(self.time_embed2.time_mlp[0].bias) # First Linear layer bias
119
- nn.init.zeros_(self.time_embed2.time_mlp[-1].weight) # Last Linear layer weights
120
- nn.init.zeros_(self.time_embed2.time_mlp[-1].bias) # Last Linear layer bias
 
 
 
 
121
  else:
122
  self.time_embed2 = None
123
-
124
  if text_dim is None:
125
  text_dim = mel_dim
126
  self.vocab_size = text_num_embeds
127
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
 
 
128
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
129
 
130
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -133,9 +153,20 @@ class DiT(nn.Module):
133
  self.depth = depth
134
 
135
  self.transformer_blocks = nn.ModuleList(
136
- [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
 
 
 
 
 
 
 
 
 
 
 
 
137
  )
138
- self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
139
 
140
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
141
  self.proj_out = nn.Linear(dim, mel_dim)
@@ -171,7 +202,7 @@ class DiT(nn.Module):
171
  if second_time is not None and self.time_embed2 is not None:
172
  t2 = self.time_embed2(second_time)
173
  t = t + t2
174
-
175
  text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
176
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
177
 
@@ -185,7 +216,9 @@ class DiT(nn.Module):
185
 
186
  for block in self.transformer_blocks:
187
  if self.checkpoint_activations:
188
- x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
 
 
189
  else:
190
  x = block(x, t, mask=mask, rope=rope)
191
 
 
10
  from __future__ import annotations
11
 
12
  import torch
 
13
  import torch.nn.functional as F
14
+ from torch import nn
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
+ from f5_tts.model.modules import (AdaLayerNormZero_Final, ConvNeXtV2Block,
18
+ ConvPositionEmbedding, DiTBlock,
19
+ TimestepEmbedding, get_pos_embed_indices,
20
+ precompute_freqs_cis)
 
 
 
 
 
 
21
 
22
  # Text embedding
23
 
 
25
  class TextEmbedding(nn.Module):
26
  def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
27
  super().__init__()
28
+ self.text_embed = nn.Embedding(
29
+ text_num_embeds + 1, text_dim
30
+ ) # use 0 as filler token
31
 
32
  if conv_layers > 0:
33
  self.extra_modeling = True
34
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
35
+ self.register_buffer(
36
+ "freqs_cis",
37
+ precompute_freqs_cis(text_dim, self.precompute_max_pos),
38
+ persistent=False,
39
+ )
40
  self.text_blocks = nn.Sequential(
41
+ *[
42
+ ConvNeXtV2Block(text_dim, text_dim * conv_mult)
43
+ for _ in range(conv_layers)
44
+ ]
45
  )
46
  else:
47
  self.extra_modeling = False
48
 
49
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
50
+ text = (
51
+ text + 1
52
+ )  # use 0 as the filler token; batch preprocessing pads with -1, see list_str_to_idx()

53
+ text = text[
54
+ :, :seq_len
55
+ ] # curtail if character tokens are more than the mel spec tokens
56
  batch, text_len = text.shape[0], text.shape[1]
57
  text = F.pad(text, (0, seq_len - text_len), value=0)
58
 
59
  if drop_text: # cfg for text
60
  text = torch.zeros_like(text)
61
+
62
  text = self.text_embed(text) # b n -> b n d
63
 
64
  # possible extra modeling
65
  if self.extra_modeling:
66
  # sinus pos emb
67
  batch_start = torch.zeros((batch,), dtype=torch.long)
68
+ pos_idx = get_pos_embed_indices(
69
+ batch_start, seq_len, max_pos=self.precompute_max_pos
70
+ )
71
  text_pos_embed = self.freqs_cis[pos_idx]
72
  text = text + text_pos_embed
73
 
 
86
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
87
  self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
88
 
89
+ def forward(
90
+ self,
91
+ x: float["b n d"],
92
+ cond: float["b n d"],
93
+ text_embed: float["b n d"],
94
+ drop_audio_cond=False,
95
+ ): # noqa: F722
96
  if drop_audio_cond: # cfg for cond audio
97
  cond = torch.zeros_like(cond)
98
 
 
128
  if second_time:
129
  self.time_embed2 = TimestepEmbedding(dim)
130
  # Zero-init the weights and biases of the first and last Linear layers in time_mlp
131
+ nn.init.zeros_(
132
+ self.time_embed2.time_mlp[0].weight
133
+ ) # First Linear layer weights
134
+ nn.init.zeros_(self.time_embed2.time_mlp[0].bias) # First Linear layer bias
135
+ nn.init.zeros_(
136
+ self.time_embed2.time_mlp[-1].weight
137
+ ) # Last Linear layer weights
138
+ nn.init.zeros_(self.time_embed2.time_mlp[-1].bias) # Last Linear layer bias
139
  else:
140
  self.time_embed2 = None
141
+
142
  if text_dim is None:
143
  text_dim = mel_dim
144
  self.vocab_size = text_num_embeds
145
+ self.text_embed = TextEmbedding(
146
+ text_num_embeds, text_dim, conv_layers=conv_layers
147
+ )
148
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
149
 
150
  self.rotary_embed = RotaryEmbedding(dim_head)
 
153
  self.depth = depth
154
 
155
  self.transformer_blocks = nn.ModuleList(
156
+ [
157
+ DiTBlock(
158
+ dim=dim,
159
+ heads=heads,
160
+ dim_head=dim_head,
161
+ ff_mult=ff_mult,
162
+ dropout=dropout,
163
+ )
164
+ for _ in range(depth)
165
+ ]
166
+ )
167
+ self.long_skip_connection = (
168
+ nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
169
  )
 
170
 
171
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
172
  self.proj_out = nn.Linear(dim, mel_dim)
 
202
  if second_time is not None and self.time_embed2 is not None:
203
  t2 = self.time_embed2(second_time)
204
  t = t + t2
205
+
206
  text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
207
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
208
 
 
216
 
217
  for block in self.transformer_blocks:
218
  if self.checkpoint_activations:
219
+ x = torch.utils.checkpoint.checkpoint(
220
+ self.ckpt_wrapper(block), x, t, mask, rope
221
+ )
222
  else:
223
  x = block(x, t, mask=mask, rope=rope)
224
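
To make the filler-token bookkeeping in TextEmbedding.forward above concrete, a small standalone sketch (token ids are hypothetical):

    import torch
    import torch.nn.functional as F

    text = torch.tensor([[12, 7, 3, -1, -1]])  # batch padded with -1 by list_str_to_idx
    seq_len = 8                                # mel length to align the text to
    text = (text + 1)[:, :seq_len]             # -1 -> 0 (filler id), real ids shift by +1
    text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
    # tensor([[13, 8, 4, 0, 0, 0, 0, 0]]) -- id 0 indexes the extra embedding row
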
 
f5_tts/model/backbones/mmdit.py CHANGED
@@ -10,41 +10,37 @@ d - dimension
10
  from __future__ import annotations
11
 
12
  import torch
13
- from torch import nn
14
  import torch.nn.functional as F
15
-
16
  from x_transformers.x_transformers import RotaryEmbedding
17
 
18
- from f5_tts.model.modules import (
19
- TimestepEmbedding,
20
- ConvPositionEmbedding,
21
- MMDiTBlock,
22
- DiTBlock,
23
- AdaLayerNormZero_Final,
24
- precompute_freqs_cis,
25
- get_pos_embed_indices,
26
- )
27
-
28
- from f5_tts.model.utils import (
29
- default,
30
- exists,
31
- lens_to_mask,
32
- list_str_to_idx,
33
- list_str_to_tensor,
34
- mask_from_frac_lengths,
35
- )
36
  # text embedding
37
 
38
 
39
  class TextEmbedding(nn.Module):
40
  def __init__(self, out_dim, text_num_embeds):
41
  super().__init__()
42
- self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
 
 
43
 
44
  self.precompute_max_pos = 1024
45
- self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
 
 
 
 
46
 
47
- def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
 
 
48
  text = text + 1
49
  if drop_text:
50
  text = torch.zeros_like(text)
@@ -53,7 +49,9 @@ class TextEmbedding(nn.Module):
53
  # sinus pos emb
54
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
55
  batch_text_len = text.shape[1]
56
- pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
 
 
57
  text_pos_embed = self.freqs_cis[pos_idx]
58
 
59
  text = text + text_pos_embed
@@ -70,7 +68,9 @@ class AudioEmbedding(nn.Module):
70
  self.linear = nn.Linear(2 * in_dim, out_dim)
71
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
72
 
73
- def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
 
 
74
  if drop_audio_cond:
75
  cond = torch.zeros_like(cond)
76
  x = torch.cat((x, cond), dim=-1)
@@ -97,23 +97,24 @@ class MMDiT(nn.Module):
97
  mel_dim=100,
98
  checkpoint_activations=False,
99
  text_encoder=True,
100
-
101
  ):
102
  super().__init__()
103
 
104
  self.time_embed = TimestepEmbedding(dim)
105
  if text_encoder:
106
- self.text_encoder = TextEncoder(text_num_embeds=text_num_embeds,
107
- text_dim=dim,
108
- depth=text_depth,
109
- heads=heads,
110
- dim_head=dim_head,
111
- ff_mult=ff_mult,
112
- dropout=dropout)
 
 
113
  else:
114
  self.text_encoder = None
115
  self.text_embed = TextEmbedding(dim, text_num_embeds)
116
-
117
  self.audio_embed = AudioEmbedding(mel_dim, dim)
118
 
119
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -136,9 +137,8 @@ class MMDiT(nn.Module):
136
  )
137
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
138
  self.proj_out = nn.Linear(dim, mel_dim)
139
-
140
- self.checkpoint_activations = checkpoint_activations
141
 
 
142
 
143
  def forward(
144
  self,
@@ -161,45 +161,53 @@ class MMDiT(nn.Module):
161
  c = self.text_encoder(text, t, mask=text_mask, drop_text=drop_text)
162
  else:
163
  c = self.text_embed(text, drop_text=drop_text)
164
-
165
  x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
166
 
167
  seq_len = x.shape[1]
168
  text_len = text.shape[1]
169
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
170
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
171
-
172
  # if mask is not None:
173
  # rope_audio = self.rotary_embed.forward_from_seq_len(seq_len + 1)
174
-
175
  # dummy_token = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype)
176
  # x = torch.cat([x, dummy_token], dim=1) # shape is now [b, nw+1, d]
177
-
178
  # # pad the mask so that new dummy token is always masked out
179
  # # mask: [b, nw] -> [b, nw+1]
180
  # false_col = torch.zeros((x.shape[0], 1), dtype=torch.bool, device=x.device)
181
  # mask = torch.cat([mask, false_col], dim=1)
182
-
183
  # if text_mask is not None:
184
  # rope_text = self.rotary_embed.forward_from_seq_len(text_len + 1)
185
 
186
  # dummy_token = torch.zeros((c.shape[0], 1, c.shape[-1]), device=c.device, dtype=c.dtype)
187
  # c = torch.cat([c, dummy_token], dim=1) # shape is now [b, nt+1, d]
188
-
189
  # # pad the text mask so that new dummy token is always masked out
190
  # # text_mask: [b, nt] -> [b, nt+1]
191
  # false_col = torch.zeros((c.shape[0], 1), dtype=torch.bool, device=c.device)
192
  # text_mask = torch.cat([text_mask, false_col], dim=1)
193
-
194
  for block in self.transformer_blocks:
195
- c, x = block(x, c, t, mask=mask, src_mask=text_mask, rope=rope_audio, c_rope=rope_text)
 
 
 
 
 
 
 
 
196
 
197
  x = self.norm_out(x, t)
198
  output = self.proj_out(x)
199
-
200
 
201
  return output
202
 
 
203
  class TextEncoder(nn.Module):
204
  def __init__(
205
  self,
@@ -219,7 +227,7 @@ class TextEncoder(nn.Module):
219
  # Embeddings
220
  self.text_embed = TextEmbedding(text_dim, text_num_embeds)
221
  self.rotary_embed = RotaryEmbedding(dim_head)
222
-
223
  # Example stack of DiTBlocks or any custom blocks
224
  self.transformer_blocks = nn.ModuleList(
225
  [
@@ -239,7 +247,7 @@ class TextEncoder(nn.Module):
239
  text: int["b nt"], # noqa: F821
240
  time: float["b"] | float[""], # time step # noqa: F821 F722
241
  mask: bool["b nt"] | None = None, # noqa: F821 F722
242
- drop_text: bool = False
243
  ):
244
  """
245
  Encode text into hidden states of shape [b, nt, d].
@@ -251,7 +259,7 @@ class TextEncoder(nn.Module):
251
 
252
  # Basic embedding
253
  hidden_states = self.text_embed(text, seq_len) # [b, nt, d]
254
-
255
  # lens and mask
256
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
257
 
@@ -260,17 +268,18 @@ class TextEncoder(nn.Module):
260
  # Here, you likely want standard self-attn, so no cross-attn
261
  hidden_states = block(
262
  x=hidden_states,
263
- t=time, # no time embedding for the text encoder by default
264
- mask=mask, # or pass a text mask if needed
265
- rope=rope # pass a rope if you want rotary embeddings for text
266
  )
267
  return hidden_states
268
 
 
269
  if __name__ == "__main__":
270
  from f5_tts.model.utils import get_tokenizer
271
 
272
  bsz = 16
273
-
274
  tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
275
  tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
276
  dataset_name = "Emilia_ZH_EN"
@@ -279,23 +288,22 @@ if __name__ == "__main__":
279
  else:
280
  tokenizer_path = dataset_name
281
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
282
-
283
  text = ["hello world"] * bsz
284
- text_lens = torch.ones((bsz, ), dtype=torch.long) * len("hello world")
285
  text_lens[-1] = 5
286
  device = "cuda"
287
  batch = bsz
288
  time_embed = TimestepEmbedding(512).to(device)
289
-
290
-
291
  # handle text as string
292
  if isinstance(text, list):
293
  if exists(vocab_char_map):
294
  text = list_str_to_idx(text, vocab_char_map).to(device)
295
  else:
296
  text = list_str_to_tensor(text).to(device)
297
- assert text.shape[0] == batch
298
-
299
  time = torch.rand((batch,), device=device)
300
  text_mask = lens_to_mask(text_lens).to(device)
301
 
@@ -311,7 +319,7 @@ if __name__ == "__main__":
311
  # ).to('cuda')
312
  # hidden_states = text_encoder(text, time_embed(time), mask)
313
  # print(hidden_states.shape) # [bsz, seq_len, text_dim]
314
-
315
  # test MMDiT
316
  mel_dim = 80
317
  model = MMDiT(
@@ -323,14 +331,23 @@ if __name__ == "__main__":
323
  dropout=0.1,
324
  ff_mult=4,
325
  text_num_embeds=vocab_size,
326
- mel_dim=mel_dim
327
  ).to(device)
328
-
329
  x = torch.rand((batch, 100, mel_dim), device=device)
330
  cond = torch.rand((batch, 100, mel_dim), device=device)
331
  lens = torch.ones((batch,), dtype=torch.long) * 100
332
  mask = lens_to_mask(lens).to(device)
333
-
334
- output = model(x, cond, text, time, drop_audio_cond=False, drop_text=False, mask=mask, text_mask=text_mask)
335
-
336
- print(output.shape) # [bsz, seq_len, mel_dim]
 
 
 
 
 
 
 
 
 
 
10
  from __future__ import annotations
11
 
12
  import torch
 
13
  import torch.nn.functional as F
14
+ from torch import nn
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
+ from f5_tts.model.modules import (AdaLayerNormZero_Final,
18
+ ConvPositionEmbedding, DiTBlock, MMDiTBlock,
19
+ TimestepEmbedding, get_pos_embed_indices,
20
+ precompute_freqs_cis)
21
+ from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
22
+ list_str_to_tensor, mask_from_frac_lengths)
23
+
 
 
 
 
 
 
 
 
 
 
 
24
  # text embedding
25
 
26
 
27
  class TextEmbedding(nn.Module):
28
  def __init__(self, out_dim, text_num_embeds):
29
  super().__init__()
30
+ self.text_embed = nn.Embedding(
31
+ text_num_embeds + 1, out_dim
32
+ ) # will use 0 as filler token
33
 
34
  self.precompute_max_pos = 1024
35
+ self.register_buffer(
36
+ "freqs_cis",
37
+ precompute_freqs_cis(out_dim, self.precompute_max_pos),
38
+ persistent=False,
39
+ )
40
 
41
+ def forward(
42
+ self, text: int["b nt"], drop_text=False
43
+ ) -> int["b nt d"]: # noqa: F722
44
  text = text + 1
45
  if drop_text:
46
  text = torch.zeros_like(text)
 
49
  # sinus pos emb
50
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
51
  batch_text_len = text.shape[1]
52
+ pos_idx = get_pos_embed_indices(
53
+ batch_start, batch_text_len, max_pos=self.precompute_max_pos
54
+ )
55
  text_pos_embed = self.freqs_cis[pos_idx]
56
 
57
  text = text + text_pos_embed
 
68
  self.linear = nn.Linear(2 * in_dim, out_dim)
69
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
70
 
71
+ def forward(
72
+ self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False
73
+ ): # noqa: F722
74
  if drop_audio_cond:
75
  cond = torch.zeros_like(cond)
76
  x = torch.cat((x, cond), dim=-1)
 
97
  mel_dim=100,
98
  checkpoint_activations=False,
99
  text_encoder=True,
 
100
  ):
101
  super().__init__()
102
 
103
  self.time_embed = TimestepEmbedding(dim)
104
  if text_encoder:
105
+ self.text_encoder = TextEncoder(
106
+ text_num_embeds=text_num_embeds,
107
+ text_dim=dim,
108
+ depth=text_depth,
109
+ heads=heads,
110
+ dim_head=dim_head,
111
+ ff_mult=ff_mult,
112
+ dropout=dropout,
113
+ )
114
  else:
115
  self.text_encoder = None
116
  self.text_embed = TextEmbedding(dim, text_num_embeds)
117
+
118
  self.audio_embed = AudioEmbedding(mel_dim, dim)
119
 
120
  self.rotary_embed = RotaryEmbedding(dim_head)
 
137
  )
138
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
139
  self.proj_out = nn.Linear(dim, mel_dim)
 
 
140
 
141
+ self.checkpoint_activations = checkpoint_activations
142
 
143
  def forward(
144
  self,
 
161
  c = self.text_encoder(text, t, mask=text_mask, drop_text=drop_text)
162
  else:
163
  c = self.text_embed(text, drop_text=drop_text)
164
+
165
  x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
166
 
167
  seq_len = x.shape[1]
168
  text_len = text.shape[1]
169
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
170
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
171
+
172
  # if mask is not None:
173
  # rope_audio = self.rotary_embed.forward_from_seq_len(seq_len + 1)
174
+
175
  # dummy_token = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype)
176
  # x = torch.cat([x, dummy_token], dim=1) # shape is now [b, nw+1, d]
177
+
178
  # # pad the mask so that new dummy token is always masked out
179
  # # mask: [b, nw] -> [b, nw+1]
180
  # false_col = torch.zeros((x.shape[0], 1), dtype=torch.bool, device=x.device)
181
  # mask = torch.cat([mask, false_col], dim=1)
182
+
183
  # if text_mask is not None:
184
  # rope_text = self.rotary_embed.forward_from_seq_len(text_len + 1)
185
 
186
  # dummy_token = torch.zeros((c.shape[0], 1, c.shape[-1]), device=c.device, dtype=c.dtype)
187
  # c = torch.cat([c, dummy_token], dim=1) # shape is now [b, nt+1, d]
188
+
189
  # # pad the text mask so that new dummy token is always masked out
190
  # # text_mask: [b, nt] -> [b, nt+1]
191
  # false_col = torch.zeros((c.shape[0], 1), dtype=torch.bool, device=c.device)
192
  # text_mask = torch.cat([text_mask, false_col], dim=1)
193
+
194
  for block in self.transformer_blocks:
195
+ c, x = block(
196
+ x,
197
+ c,
198
+ t,
199
+ mask=mask,
200
+ src_mask=text_mask,
201
+ rope=rope_audio,
202
+ c_rope=rope_text,
203
+ )
204
 
205
  x = self.norm_out(x, t)
206
  output = self.proj_out(x)
 
207
 
208
  return output
209
 
210
+
211
  class TextEncoder(nn.Module):
212
  def __init__(
213
  self,
 
227
  # Embeddings
228
  self.text_embed = TextEmbedding(text_dim, text_num_embeds)
229
  self.rotary_embed = RotaryEmbedding(dim_head)
230
+
231
  # Example stack of DiTBlocks or any custom blocks
232
  self.transformer_blocks = nn.ModuleList(
233
  [
 
247
  text: int["b nt"], # noqa: F821
248
  time: float["b"] | float[""], # time step # noqa: F821 F722
249
  mask: bool["b nt"] | None = None, # noqa: F821 F722
250
+ drop_text: bool = False,
251
  ):
252
  """
253
  Encode text into hidden states of shape [b, nt, d].
 
259
 
260
  # Basic embedding
261
  hidden_states = self.text_embed(text, seq_len) # [b, nt, d]
262
+
263
  # lens and mask
264
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
265
 
 
268
  # Here, you likely want standard self-attn, so no cross-attn
269
  hidden_states = block(
270
  x=hidden_states,
271
+ t=time, # no time embedding for the text encoder by default
272
+ mask=mask, # or pass a text mask if needed
273
+ rope=rope, # pass a rope if you want rotary embeddings for text
274
  )
275
  return hidden_states
276
 
277
+
278
  if __name__ == "__main__":
279
  from f5_tts.model.utils import get_tokenizer
280
 
281
  bsz = 16
282
+
283
  tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
284
  tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
285
  dataset_name = "Emilia_ZH_EN"
 
288
  else:
289
  tokenizer_path = dataset_name
290
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
291
+
292
  text = ["hello world"] * bsz
293
+ text_lens = torch.ones((bsz,), dtype=torch.long) * len("hello world")
294
  text_lens[-1] = 5
295
  device = "cuda"
296
  batch = bsz
297
  time_embed = TimestepEmbedding(512).to(device)
298
+
 
299
  # handle text as string
300
  if isinstance(text, list):
301
  if exists(vocab_char_map):
302
  text = list_str_to_idx(text, vocab_char_map).to(device)
303
  else:
304
  text = list_str_to_tensor(text).to(device)
305
+ assert text.shape[0] == batch
306
+
307
  time = torch.rand((batch,), device=device)
308
  text_mask = lens_to_mask(text_lens).to(device)
309
 
 
319
  # ).to('cuda')
320
  # hidden_states = text_encoder(text, time_embed(time), mask)
321
  # print(hidden_states.shape) # [bsz, seq_len, text_dim]
322
+
323
  # test MMDiT
324
  mel_dim = 80
325
  model = MMDiT(
 
331
  dropout=0.1,
332
  ff_mult=4,
333
  text_num_embeds=vocab_size,
334
+ mel_dim=mel_dim,
335
  ).to(device)
336
+
337
  x = torch.rand((batch, 100, mel_dim), device=device)
338
  cond = torch.rand((batch, 100, mel_dim), device=device)
339
  lens = torch.ones((batch,), dtype=torch.long) * 100
340
  mask = lens_to_mask(lens).to(device)
341
+
342
+ output = model(
343
+ x,
344
+ cond,
345
+ text,
346
+ time,
347
+ drop_audio_cond=False,
348
+ drop_text=False,
349
+ mask=mask,
350
+ text_mask=text_mask,
351
+ )
352
+
353
+ print(output.shape) # [bsz, seq_len, mel_dim]
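
The smoke test above leans on lens_to_mask from f5_tts.model.utils; a minimal equivalent of what that call is assumed to return (True for valid positions) is:

    import torch

    def lens_to_mask_sketch(lens, length=None):
        length = length if length is not None else int(lens.amax())
        return torch.arange(length, device=lens.device)[None, :] < lens[:, None]

    # lens_to_mask_sketch(torch.tensor([3, 5]))
    # -> [[True, True, True, False, False],
    #     [True, True, True, True,  True ]]
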
f5_tts/model/backbones/unett.py CHANGED
@@ -8,26 +8,19 @@ d - dimension
8
  """
9
 
10
  from __future__ import annotations
 
11
  from typing import Literal
12
 
13
  import torch
14
- from torch import nn
15
  import torch.nn.functional as F
16
-
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
- from f5_tts.model.modules import (
21
- TimestepEmbedding,
22
- ConvNeXtV2Block,
23
- ConvPositionEmbedding,
24
- Attention,
25
- AttnProcessor,
26
- FeedForward,
27
- precompute_freqs_cis,
28
- get_pos_embed_indices,
29
- )
30
-
31
 
32
  # Text embedding
33
 
@@ -35,21 +28,34 @@ from f5_tts.model.modules import (
35
  class TextEmbedding(nn.Module):
36
  def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
- self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
 
 
39
 
40
  if conv_layers > 0:
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
- self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
 
 
 
 
44
  self.text_blocks = nn.Sequential(
45
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
 
 
 
46
  )
47
  else:
48
  self.extra_modeling = False
49
 
50
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
51
- text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
52
- text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
 
 
 
 
53
  batch, text_len = text.shape[0], text.shape[1]
54
  text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
@@ -62,7 +68,9 @@ class TextEmbedding(nn.Module):
62
  if self.extra_modeling:
63
  # sinus pos emb
64
  batch_start = torch.zeros((batch,), dtype=torch.long)
65
- pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
 
 
66
  text_pos_embed = self.freqs_cis[pos_idx]
67
  text = text + text_pos_embed
68
 
@@ -81,7 +89,13 @@ class InputEmbedding(nn.Module):
81
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
82
  self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
83
 
84
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
 
 
 
 
 
 
85
  if drop_audio_cond: # cfg for cond audio
86
  cond = torch.zeros_like(cond)
87
 
@@ -115,7 +129,9 @@ class UNetT(nn.Module):
115
  self.time_embed = TimestepEmbedding(dim)
116
  if text_dim is None:
117
  text_dim = mel_dim
118
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
 
 
119
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
 
121
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -144,7 +160,11 @@ class UNetT(nn.Module):
144
  ff_norm = RMSNorm(dim)
145
  ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
146
 
147
- skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
 
 
 
 
148
 
149
  self.layers.append(
150
  nn.ModuleList(
@@ -190,7 +210,9 @@ class UNetT(nn.Module):
190
  # flat unet transformer
191
  skip_connect_type = self.skip_connect_type
192
  skips = []
193
- for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
 
 
194
  layer = idx + 1
195
 
196
  # skip connection logic
 
8
  """
9
 
10
  from __future__ import annotations
11
+
12
  from typing import Literal
13
 
14
  import torch
 
15
  import torch.nn.functional as F
16
+ from torch import nn
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
+ from f5_tts.model.modules import (Attention, AttnProcessor, ConvNeXtV2Block,
21
+ ConvPositionEmbedding, FeedForward,
22
+ TimestepEmbedding, get_pos_embed_indices,
23
+ precompute_freqs_cis)
 
 
 
 
 
 
 
24
 
25
  # Text embedding
26
 
 
28
  class TextEmbedding(nn.Module):
29
  def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
30
  super().__init__()
31
+ self.text_embed = nn.Embedding(
32
+ text_num_embeds + 1, text_dim
33
+ ) # use 0 as filler token
34
 
35
  if conv_layers > 0:
36
  self.extra_modeling = True
37
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
38
+ self.register_buffer(
39
+ "freqs_cis",
40
+ precompute_freqs_cis(text_dim, self.precompute_max_pos),
41
+ persistent=False,
42
+ )
43
  self.text_blocks = nn.Sequential(
44
+ *[
45
+ ConvNeXtV2Block(text_dim, text_dim * conv_mult)
46
+ for _ in range(conv_layers)
47
+ ]
48
  )
49
  else:
50
  self.extra_modeling = False
51
 
52
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
53
+ text = (
54
+ text + 1
55
+ ) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
56
+ text = text[
57
+ :, :seq_len
58
+ ] # curtail if character tokens are more than the mel spec tokens
59
  batch, text_len = text.shape[0], text.shape[1]
60
  text = F.pad(text, (0, seq_len - text_len), value=0)
61
 
 
68
  if self.extra_modeling:
69
  # sinus pos emb
70
  batch_start = torch.zeros((batch,), dtype=torch.long)
71
+ pos_idx = get_pos_embed_indices(
72
+ batch_start, seq_len, max_pos=self.precompute_max_pos
73
+ )
74
  text_pos_embed = self.freqs_cis[pos_idx]
75
  text = text + text_pos_embed
76
 
 
89
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
90
  self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
91
 
92
+ def forward(
93
+ self,
94
+ x: float["b n d"],
95
+ cond: float["b n d"],
96
+ text_embed: float["b n d"],
97
+ drop_audio_cond=False,
98
+ ): # noqa: F722
99
  if drop_audio_cond: # cfg for cond audio
100
  cond = torch.zeros_like(cond)
101
 
 
129
  self.time_embed = TimestepEmbedding(dim)
130
  if text_dim is None:
131
  text_dim = mel_dim
132
+ self.text_embed = TextEmbedding(
133
+ text_num_embeds, text_dim, conv_layers=conv_layers
134
+ )
135
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
136
 
137
  self.rotary_embed = RotaryEmbedding(dim_head)
 
160
  ff_norm = RMSNorm(dim)
161
  ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
162
 
163
+ skip_proj = (
164
+ nn.Linear(dim * 2, dim, bias=False)
165
+ if needs_skip_proj and is_later_half
166
+ else None
167
+ )
168
 
169
  self.layers.append(
170
  nn.ModuleList(
 
210
  # flat unet transformer
211
  skip_connect_type = self.skip_connect_type
212
  skips = []
213
+ for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(
214
+ self.layers
215
+ ):
216
  layer = idx + 1
217
 
218
  # skip connection logic
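
The truncated hunk above only shows the setup of the flat-UNet skip connections; a standalone toy of the "concat" wiring (a sketch, not the repo implementation — the real per-layer blocks are attention plus feed-forward):

    import torch
    from torch import nn

    depth, dim = 4, 8
    blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
    skip_projs = nn.ModuleList(
        [nn.Linear(dim * 2, dim, bias=False) if i + 1 > depth // 2 else nn.Identity()
         for i in range(depth)]
    )

    x = torch.randn(2, 16, dim)
    skips = []
    for idx, block in enumerate(blocks):
        layer = idx + 1
        if layer <= depth // 2:          # first half: stash the activation
            skips.append(x)
        else:                            # later half: fuse the matching skip
            x = skip_projs[idx](torch.cat((x, skips.pop()), dim=-1))
        x = block(x) + x                 # stand-in for the attention + FF blocks
    print(x.shape)                       # torch.Size([2, 16, 8])
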
f5_tts/model/cfm.py CHANGED
@@ -19,14 +19,8 @@ from torch.nn.utils.rnn import pad_sequence
19
  from torchdiffeq import odeint
20
 
21
  from f5_tts.model.modules import MelSpec
22
- from f5_tts.model.utils import (
23
- default,
24
- exists,
25
- lens_to_mask,
26
- list_str_to_idx,
27
- list_str_to_tensor,
28
- mask_from_frac_lengths,
29
- )
30
 
31
 
32
  class CFM(nn.Module):
@@ -74,7 +68,7 @@ class CFM(nn.Module):
74
 
75
  # vocab map for tokenization
76
  self.vocab_char_map = vocab_char_map
77
-
78
  self.scale = scale
79
 
80
  @property
@@ -109,11 +103,11 @@ class CFM(nn.Module):
109
  assert cond.shape[-1] == self.num_channels
110
 
111
  cond = cond.to(next(self.parameters()).dtype)
112
-
113
  print(self.scale)
114
 
115
  cond = cond / self.scale
116
-
117
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
118
  if not exists(lens):
119
  lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
@@ -129,7 +123,9 @@ class CFM(nn.Module):
129
 
130
  if exists(text):
131
  text_lens = (text != -1).sum(dim=-1)
132
- lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
 
 
133
 
134
  # duration
135
 
@@ -140,19 +136,25 @@ class CFM(nn.Module):
140
  if isinstance(duration, int):
141
  duration = torch.full((batch,), duration, device=device, dtype=torch.long)
142
 
143
- duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
 
 
144
  duration = duration.clamp(max=max_duration)
145
  max_duration = duration.amax()
146
 
147
  # duplicate test corner for inner time step observation
148
  if duplicate_test:
149
- test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
 
 
150
 
151
  cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
152
  if no_ref_audio:
153
  cond = torch.zeros_like(cond)
154
 
155
- cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
 
 
156
  cond_mask = cond_mask.unsqueeze(-1)
157
  step_cond = torch.where(
158
  cond_mask, cond, torch.zeros_like(cond)
@@ -171,13 +173,25 @@ class CFM(nn.Module):
171
 
172
  # predict flow
173
  pred = self.transformer(
174
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
 
 
 
 
 
 
175
  )
176
  if cfg_strength < 1e-5:
177
  return pred
178
 
179
  null_pred = self.transformer(
180
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
 
 
 
 
 
 
181
  )
182
  return pred + (pred - null_pred) * cfg_strength
183
 
@@ -188,7 +202,11 @@ class CFM(nn.Module):
188
  for dur in duration:
189
  if exists(seed):
190
  torch.manual_seed(seed)
191
- y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
 
 
 
 
192
  y0 = pad_sequence(y0, padding_value=0, batch_first=True)
193
 
194
  t_start = 0
@@ -199,7 +217,9 @@ class CFM(nn.Module):
199
  y0 = (1 - t_start) * y0 + t_start * test_cond
200
  steps = int(steps * (1 - t_start))
201
 
202
- t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
 
 
203
  if sway_sampling_coef is not None:
204
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
205
 
@@ -210,7 +230,7 @@ class CFM(nn.Module):
210
  out = torch.where(cond_mask, cond, out)
211
 
212
  out = out * self.scale
213
-
214
  if exists(vocoder):
215
  out = out.permute(0, 2, 1)
216
  out = vocoder(out)
@@ -231,7 +251,12 @@ class CFM(nn.Module):
231
  inp = inp.permute(0, 2, 1)
232
  assert inp.shape[-1] == self.num_channels
233
 
234
- batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
 
 
 
 
 
235
 
236
  # handle text as string
237
  if isinstance(text, list):
@@ -245,10 +270,16 @@ class CFM(nn.Module):
245
  if not exists(lens):
246
  lens = torch.full((batch,), seq_len, device=device)
247
 
248
- mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
 
 
249
 
250
  # get a random span to mask out for training conditionally
251
- frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
 
 
 
 
252
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
253
 
254
  if exists(mask):
@@ -283,11 +314,16 @@ class CFM(nn.Module):
283
  # if we want to rigorously mask out padding, record it in collate_fn in dataset.py and pass it in here
284
  # adding the mask uses more memory, so also scale down the batch-sampler threshold for long sequences
285
  pred = self.transformer(
286
- x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
287
  )
288
 
289
  # flow matching loss
290
  loss = F.mse_loss(pred, flow, reduction="none")
291
  loss = loss[rand_span_mask]
292
 
293
- return loss.mean(), cond, pred, t
 
19
  from torchdiffeq import odeint
20
 
21
  from f5_tts.model.modules import MelSpec
22
+ from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
23
+ list_str_to_tensor, mask_from_frac_lengths)
24
 
25
 
26
  class CFM(nn.Module):
 
68
 
69
  # vocab map for tokenization
70
  self.vocab_char_map = vocab_char_map
71
+
72
  self.scale = scale
73
 
74
  @property
 
103
  assert cond.shape[-1] == self.num_channels
104
 
105
  cond = cond.to(next(self.parameters()).dtype)
106
+
107
  print(self.scale)
108
 
109
  cond = cond / self.scale
110
+
111
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
112
  if not exists(lens):
113
  lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
 
123
 
124
  if exists(text):
125
  text_lens = (text != -1).sum(dim=-1)
126
+ lens = torch.maximum(
127
+ text_lens, lens
128
+ ) # make sure lengths are at least those of the text characters
129
 
130
  # duration
131
 
 
136
  if isinstance(duration, int):
137
  duration = torch.full((batch,), duration, device=device, dtype=torch.long)
138
 
139
+ duration = torch.maximum(
140
+ lens + 1, duration
141
+ ) # just add one token so something is generated
142
  duration = duration.clamp(max=max_duration)
143
  max_duration = duration.amax()
144
 
145
  # duplicate test corner for inner time step observation
146
  if duplicate_test:
147
+ test_cond = F.pad(
148
+ cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
149
+ )
150
 
151
  cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
152
  if no_ref_audio:
153
  cond = torch.zeros_like(cond)
154
 
155
+ cond_mask = F.pad(
156
+ cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
157
+ )
158
  cond_mask = cond_mask.unsqueeze(-1)
159
  step_cond = torch.where(
160
  cond_mask, cond, torch.zeros_like(cond)
 
173
 
174
  # predict flow
175
  pred = self.transformer(
176
+ x=x,
177
+ cond=step_cond,
178
+ text=text,
179
+ time=t,
180
+ mask=mask,
181
+ drop_audio_cond=False,
182
+ drop_text=False,
183
  )
184
  if cfg_strength < 1e-5:
185
  return pred
186
 
187
  null_pred = self.transformer(
188
+ x=x,
189
+ cond=step_cond,
190
+ text=text,
191
+ time=t,
192
+ mask=mask,
193
+ drop_audio_cond=True,
194
+ drop_text=True,
195
  )
196
  return pred + (pred - null_pred) * cfg_strength
197
 
 
202
  for dur in duration:
203
  if exists(seed):
204
  torch.manual_seed(seed)
205
+ y0.append(
206
+ torch.randn(
207
+ dur, self.num_channels, device=self.device, dtype=step_cond.dtype
208
+ )
209
+ )
210
  y0 = pad_sequence(y0, padding_value=0, batch_first=True)
211
 
212
  t_start = 0
 
217
  y0 = (1 - t_start) * y0 + t_start * test_cond
218
  steps = int(steps * (1 - t_start))
219
 
220
+ t = torch.linspace(
221
+ t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
222
+ )
223
  if sway_sampling_coef is not None:
224
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
225
 
 
230
  out = torch.where(cond_mask, cond, out)
231
 
232
  out = out * self.scale
233
+
234
  if exists(vocoder):
235
  out = out.permute(0, 2, 1)
236
  out = vocoder(out)
 
251
  inp = inp.permute(0, 2, 1)
252
  assert inp.shape[-1] == self.num_channels
253
 
254
+ batch, seq_len, dtype, device, _σ1 = (
255
+ *inp.shape[:2],
256
+ inp.dtype,
257
+ self.device,
258
+ self.sigma,
259
+ )
260
 
261
  # handle text as string
262
  if isinstance(text, list):
 
270
  if not exists(lens):
271
  lens = torch.full((batch,), seq_len, device=device)
272
 
273
+ mask = lens_to_mask(
274
+ lens, length=seq_len
275
+ ) # useless here, as collate_fn will pad to max length in batch
276
 
277
  # get a random span to mask out for training conditionally
278
+ frac_lengths = (
279
+ torch.zeros((batch,), device=self.device)
280
+ .float()
281
+ .uniform_(*self.frac_lengths_mask)
282
+ )
283
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
284
 
285
  if exists(mask):
 
314
  # if you want to rigorously mask out padding, record it in collate_fn in dataset.py and pass it in here
315
  # adding a mask uses more memory, so the batch sampler threshold also needs to be scaled down for long sequences
316
  pred = self.transformer(
317
+ x=φ,
318
+ cond=cond,
319
+ text=text,
320
+ time=time,
321
+ drop_audio_cond=drop_audio_cond,
322
+ drop_text=drop_text,
323
  )
324
 
325
  # flow matching loss
326
  loss = F.mse_loss(pred, flow, reduction="none")
327
  loss = loss[rand_span_mask]
328
 
329
+ return loss.mean(), cond, pred, t
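Note on the sampling hunks above: the reformatted `sample` path keeps the same two numerical ingredients, classifier-free guidance over a conditional and an unconditional transformer pass, and an optional sway-sampling warp of the ODE time grid. Below is a minimal standalone sketch of just those two pieces; the function names and the toy call at the end are illustrative, not part of the diff.

```python
import torch


def cfg_combine(pred: torch.Tensor, null_pred: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    # Same guidance rule as in CFM.sample: if the strength is effectively zero,
    # the unconditional pass is skipped and the conditional prediction is returned as-is.
    if cfg_strength < 1e-5:
        return pred
    return pred + (pred - null_pred) * cfg_strength


def sway_time_grid(steps: int, sway_sampling_coef=None, t_start: float = 0.0) -> torch.Tensor:
    # Uniform grid on [t_start, 1], then the sway warp t + c * (cos(pi/2 * t) - 1 + t).
    t = torch.linspace(t_start, 1, steps + 1)
    if sway_sampling_coef is not None:
        t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
    return t


# A negative coefficient front-loads the grid toward small t (more early steps).
print(sway_time_grid(8, sway_sampling_coef=-1.0))
```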
f5_tts/model/dataset.py CHANGED
@@ -1,6 +1,6 @@
1
- import re
2
  import json
3
  import random
 
4
  from importlib.resources import files
5
 
6
  import torch
@@ -15,8 +15,9 @@ from tqdm import tqdm
15
  from f5_tts.model.modules import MelSpec
16
  from f5_tts.model.utils import default
17
 
 
18
  def get_speaker_id(path):
19
- parts = path.split('/')
20
  speaker_id = parts[-3]
21
  return speaker_id
22
 
@@ -40,7 +41,7 @@ class CustomDataset(Dataset):
40
  return_wavform=False,
41
  remove_starting_space=True,
42
  need_prompt_speech=False,
43
- prompt_repository: dict=None,
44
  ):
45
  self.data = custom_dataset
46
  self.durations = durations
@@ -63,42 +64,32 @@ class CustomDataset(Dataset):
63
  mel_spec_type=mel_spec_type,
64
  ),
65
  )
66
-
67
  self.validation = validation
68
  self.validation_num = validation_num
69
 
70
  if (not validation) and data_augmentation:
71
- print('Using data augmentation.')
72
- self.augment = Compose([
73
- AddBackgroundNoise(
74
- sounds_path="/data5/ESC-50-master",
75
- min_snr_db=3.0,
76
- max_snr_db=30.0,
77
- noise_transform=PolarityInversion(),
78
- p=0.5
79
- ),
80
- AddGaussianNoise(
81
- min_amplitude=0.001,
82
- max_amplitude=0.015,
83
- p=0.5
84
- ),
85
- PitchShift(
86
- min_semitones=-12.0,
87
- max_semitones=12.0,
88
- p=0.8
89
- ),
90
- ApplyImpulseResponse(ir_path="/data5/Audio", p=1.0),
91
- Aliasing(min_sample_rate=4000, max_sample_rate=30000, p=0.3),
92
- BandPassFilter(min_center_freq=100.0, max_center_freq=6000, p=0.2),
93
- SevenBandParametricEQ(p=0.2),
94
- TanhDistortion(
95
- min_distortion=0.01,
96
- max_distortion=0.7,
97
- p=0.2
98
- ),
99
- ])
100
  else:
101
- print('No data augmentation.')
102
  self.augment = None
103
 
104
  self.return_wavform = return_wavform
@@ -112,7 +103,7 @@ class CustomDataset(Dataset):
112
  text = row["text"]
113
  duration = row["duration"]
114
  spk_id = get_speaker_id(audio_path)
115
- assert spk_id != None and spk_id != 'mp3'
116
  if spk_id not in self.prompt_repository:
117
  self.prompt_repository[spk_id] = [row]
118
  else:
@@ -120,13 +111,14 @@ class CustomDataset(Dataset):
120
  else:
121
  self.prompt_repository = prompt_repository
122
 
123
- print(f'Grouped samples into {len(self.prompt_repository.keys())} speakers.')
 
 
124
  self.need_prompt_speech = True
125
 
126
  else:
127
  self.need_prompt_speech = False
128
 
129
-
130
  def get_frame_len(self, index):
131
  if self.validation:
132
  index += len(self.data) - self.validation_num
@@ -164,9 +156,9 @@ class CustomDataset(Dataset):
164
  index = (index + 1) % len(self.data)
165
 
166
  if self.remove_starting_space:
167
- while len(text) > 1 and text[0] == ' ':
168
  text = text[1:]
169
-
170
  if self.preprocessed_mel:
171
  mel_spec = torch.tensor(row["mel_spec"])
172
  else:
@@ -178,31 +170,37 @@ class CustomDataset(Dataset):
178
 
179
  # resample if necessary
180
  if source_sample_rate != self.target_sample_rate:
181
- resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
 
 
182
  audio = resampler(audio)
183
 
184
  if not self.validation:
185
  if self.augment != None:
186
- audio = self.augment(audio.squeeze().numpy(), sample_rate=self.target_sample_rate)
 
 
187
  audio = torch.from_numpy(audio).float().unsqueeze(0)
188
 
189
  # to mel spectrogram
190
  mel_spec = self.mel_spectrogram(audio)
191
  mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
192
 
193
- out['mel_spec'] = mel_spec
194
- out['text'] = text
195
- out['duration'] = duration
196
- out['target_text'] = self.data[(index + len(self.data) // 2) % len(self.data)]["text"]
 
 
197
 
198
  if self.return_wavform:
199
- out['wav'] = audio
200
 
201
  if return_path:
202
- out['path'] = audio_path
203
 
204
  if return_row:
205
- out['row'] = row
206
 
207
  # Sample a prompt speech of the same speaker
208
  # From prompt_repository
@@ -212,9 +210,9 @@ class CustomDataset(Dataset):
212
  _count = 100
213
  while True:
214
  pmt_row = random.choice(spk_repository)
215
- pmt_audio_path = pmt_row['audio_path']
216
- pmt_text = pmt_row['text']
217
- pmt_duration = pmt_row['duration']
218
 
219
  if not isinstance(pmt_text, list):
220
  pmt_text = list(pmt_text)
@@ -223,14 +221,14 @@ class CustomDataset(Dataset):
223
  if 0.3 <= pmt_duration <= 30 and (0 < len(pmt_text) < 2048):
224
  if pmt_text != text:
225
  break
226
- _count = _count - 1
227
  if _count <= 0:
228
  break
229
 
230
  if self.remove_starting_space:
231
- while len(pmt_text) > 1 and pmt_text[0] == ' ':
232
  pmt_text = pmt_text[1:]
233
-
234
  if self.preprocessed_mel:
235
  pmt_mel_spec = torch.tensor(pmt_row["mel_spec"])
236
  else:
@@ -242,30 +240,35 @@ class CustomDataset(Dataset):
242
 
243
  # resample if necessary
244
  if source_sample_rate != self.target_sample_rate:
245
- resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
 
 
246
  pmt_audio = resampler(pmt_audio)
247
 
248
  if not self.validation:
249
  if self.augment != None:
250
- pmt_audio = self.augment(pmt_audio.squeeze().numpy(), sample_rate=self.target_sample_rate)
251
  pmt_audio = torch.from_numpy(pmt_audio).float().unsqueeze(0)
252
 
253
  # to mel spectrogram
254
  pmt_mel_spec = self.mel_spectrogram(pmt_audio)
255
  pmt_mel_spec = pmt_mel_spec.squeeze(0) # '1 d t -> d t'
256
 
257
- out['pmt_mel_spec'] = pmt_mel_spec
258
- out['pmt_text'] = pmt_text
259
- out['pmt_duration'] = pmt_duration
260
 
261
  if self.return_wavform:
262
- out['pmt_wav'] = pmt_audio
263
 
264
  if return_path:
265
- out['pmt_path'] = pmt_audio_path
266
 
267
  if return_row:
268
- out['pmt_row'] = pmt_row
269
 
270
  return out
271
 
@@ -280,7 +283,12 @@ class DynamicBatchSampler(Sampler[list[int]]):
280
  """
281
 
282
  def __init__(
283
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
284
  ):
285
  self.sampler = sampler
286
  self.frames_threshold = frames_threshold
@@ -302,7 +310,9 @@ class DynamicBatchSampler(Sampler[list[int]]):
302
  # indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
303
  # ):
304
  for idx, frame_len in indices:
305
- if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
 
 
306
  batch.append(idx)
307
  batch_frames += frame_len
308
  else:
@@ -337,6 +347,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
337
 
338
  # Load dataset
339
 
 
340
  def load_dataset(
341
  dataset_name: str,
342
  tokenizer: str = "pinyin",
@@ -349,7 +360,7 @@ def load_dataset(
349
  return_wavform: bool = False,
350
  remove_starting_space: bool = True,
351
  need_prompt_speech: bool = False,
352
- prompt_repository: dict = None
353
  ) -> CustomDataset:
354
  """
355
  dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
@@ -359,9 +370,13 @@ def load_dataset(
359
  print("Loading dataset ...")
360
 
361
  if dataset_type == "CustomDataset":
362
- rel_data_path = str(f'/home/yl4579/F5-TTS-diff/F5-TTS-DMD-flow-ds/data/{dataset_name}_{tokenizer}')
363
- if 'LibriTTS_100_360_500_char_pinyin' in rel_data_path:
364
- rel_data_path = rel_data_path.replace('LibriTTS_100_360_500_char_pinyin', 'LibriTTS_100_360_500_char')
365
  if audio_type == "raw":
366
  try:
367
  train_dataset = load_from_disk(f"{rel_data_path}/raw")
@@ -385,7 +400,7 @@ def load_dataset(
385
  return_wavform=return_wavform,
386
  remove_starting_space=remove_starting_space,
387
  need_prompt_speech=need_prompt_speech,
388
- prompt_repository=prompt_repository
389
  )
390
 
391
  elif dataset_type == "CustomDatasetPath":
@@ -398,7 +413,10 @@ def load_dataset(
398
  data_dict = json.load(f)
399
  durations = data_dict["duration"]
400
  train_dataset = CustomDataset(
401
- train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
402
  )
403
 
404
  return train_dataset
@@ -410,7 +428,7 @@ def collate_fn(batch):
410
  mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
411
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
412
  max_mel_length = mel_lengths.amax()
413
-
414
  # Pad mel_specs
415
  padded_mel_specs = []
416
  for spec in mel_specs: # TODO. maybe records mask for attention here
@@ -419,8 +437,8 @@ def collate_fn(batch):
419
  padded_mel_specs.append(padded_spec)
420
  mel_specs = torch.stack(padded_mel_specs)
421
 
422
- text = [item['text'] for item in batch]
423
- target_text = [item['target_text'] for item in batch]
424
 
425
  text_lengths = torch.LongTensor([len(item) for item in text])
426
 
@@ -432,26 +450,26 @@ def collate_fn(batch):
432
  target_text=target_text,
433
  )
434
 
435
- if 'pmt_mel_spec' in batch[0]:
436
  pmt_mel_specs = [item["pmt_mel_spec"].squeeze(0) for item in batch]
437
  pmt_mel_lengths = torch.LongTensor([spec.shape[-1] for spec in pmt_mel_specs])
438
  max_pmt_mel_length = pmt_mel_lengths.amax()
439
-
440
  # Pad mel_specs
441
  padded_pmt_mel_specs = []
442
- for spec in pmt_mel_specs:
443
  padding = (0, max_pmt_mel_length - spec.size(-1))
444
  padded_spec = F.pad(spec, padding, value=0)
445
  padded_pmt_mel_specs.append(padded_spec)
446
  pmt_mel_specs = torch.stack(padded_pmt_mel_specs)
447
 
448
- out['pmt_mel_specs'] = pmt_mel_specs
449
 
450
- if 'pmt_text' in batch[0]:
451
- pmt_text = [item['pmt_text'] for item in batch]
452
  pmt_text_lengths = torch.LongTensor([len(item) for item in pmt_text])
453
 
454
- out['pmt_text'] = pmt_text
455
- out['pmt_text_lengths'] = pmt_text_lengths
456
 
457
- return out
 
 
1
  import json
2
  import random
3
+ import re
4
  from importlib.resources import files
5
 
6
  import torch
 
15
  from f5_tts.model.modules import MelSpec
16
  from f5_tts.model.utils import default
17
 
18
+
19
  def get_speaker_id(path):
20
+ parts = path.split("/")
21
  speaker_id = parts[-3]
22
  return speaker_id
23
 
 
41
  return_wavform=False,
42
  remove_starting_space=True,
43
  need_prompt_speech=False,
44
+ prompt_repository: dict = None,
45
  ):
46
  self.data = custom_dataset
47
  self.durations = durations
 
64
  mel_spec_type=mel_spec_type,
65
  ),
66
  )
67
+
68
  self.validation = validation
69
  self.validation_num = validation_num
70
 
71
  if (not validation) and data_augmentation:
72
+ print("Using data augmentation.")
73
+ self.augment = Compose(
74
+ [
75
+ AddBackgroundNoise(
76
+ sounds_path="/data5/ESC-50-master",
77
+ min_snr_db=3.0,
78
+ max_snr_db=30.0,
79
+ noise_transform=PolarityInversion(),
80
+ p=0.5,
81
+ ),
82
+ AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
83
+ PitchShift(min_semitones=-12.0, max_semitones=12.0, p=0.8),
84
+ ApplyImpulseResponse(ir_path="/data5/Audio", p=1.0),
85
+ Aliasing(min_sample_rate=4000, max_sample_rate=30000, p=0.3),
86
+ BandPassFilter(min_center_freq=100.0, max_center_freq=6000, p=0.2),
87
+ SevenBandParametricEQ(p=0.2),
88
+ TanhDistortion(min_distortion=0.01, max_distortion=0.7, p=0.2),
89
+ ]
90
+ )
91
  else:
92
+ print("No data augmentation.")
93
  self.augment = None
94
 
95
  self.return_wavform = return_wavform
 
103
  text = row["text"]
104
  duration = row["duration"]
105
  spk_id = get_speaker_id(audio_path)
106
+ assert spk_id != None and spk_id != "mp3"
107
  if spk_id not in self.prompt_repository:
108
  self.prompt_repository[spk_id] = [row]
109
  else:
 
111
  else:
112
  self.prompt_repository = prompt_repository
113
 
114
+ print(
115
+ f"Grouped samples into {len(self.prompt_repository.keys())} speakers."
116
+ )
117
  self.need_prompt_speech = True
118
 
119
  else:
120
  self.need_prompt_speech = False
121
 
 
122
  def get_frame_len(self, index):
123
  if self.validation:
124
  index += len(self.data) - self.validation_num
 
156
  index = (index + 1) % len(self.data)
157
 
158
  if self.remove_starting_space:
159
+ while len(text) > 1 and text[0] == " ":
160
  text = text[1:]
161
+
162
  if self.preprocessed_mel:
163
  mel_spec = torch.tensor(row["mel_spec"])
164
  else:
 
170
 
171
  # resample if necessary
172
  if source_sample_rate != self.target_sample_rate:
173
+ resampler = torchaudio.transforms.Resample(
174
+ source_sample_rate, self.target_sample_rate
175
+ )
176
  audio = resampler(audio)
177
 
178
  if not self.validation:
179
  if self.augment != None:
180
+ audio = self.augment(
181
+ audio.squeeze().numpy(), sample_rate=self.target_sample_rate
182
+ )
183
  audio = torch.from_numpy(audio).float().unsqueeze(0)
184
 
185
  # to mel spectrogram
186
  mel_spec = self.mel_spectrogram(audio)
187
  mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
188
 
189
+ out["mel_spec"] = mel_spec
190
+ out["text"] = text
191
+ out["duration"] = duration
192
+ out["target_text"] = self.data[(index + len(self.data) // 2) % len(self.data)][
193
+ "text"
194
+ ]
195
 
196
  if self.return_wavform:
197
+ out["wav"] = audio
198
 
199
  if return_path:
200
+ out["path"] = audio_path
201
 
202
  if return_row:
203
+ out["row"] = row
204
 
205
  # Sample a prompt speech of the same speaker
206
  # From prompt_repository
 
210
  _count = 100
211
  while True:
212
  pmt_row = random.choice(spk_repository)
213
+ pmt_audio_path = pmt_row["audio_path"]
214
+ pmt_text = pmt_row["text"]
215
+ pmt_duration = pmt_row["duration"]
216
 
217
  if not isinstance(pmt_text, list):
218
  pmt_text = list(pmt_text)
 
221
  if 0.3 <= pmt_duration <= 30 and (0 < len(pmt_text) < 2048):
222
  if pmt_text != text:
223
  break
224
+ _count = _count - 1
225
  if _count <= 0:
226
  break
227
 
228
  if self.remove_starting_space:
229
+ while len(pmt_text) > 1 and pmt_text[0] == " ":
230
  pmt_text = pmt_text[1:]
231
+
232
  if self.preprocessed_mel:
233
  pmt_mel_spec = torch.tensor(pmt_row["mel_spec"])
234
  else:
 
240
 
241
  # resample if necessary
242
  if source_sample_rate != self.target_sample_rate:
243
+ resampler = torchaudio.transforms.Resample(
244
+ source_sample_rate, self.target_sample_rate
245
+ )
246
  pmt_audio = resampler(pmt_audio)
247
 
248
  if not self.validation:
249
  if self.augment != None:
250
+ pmt_audio = self.augment(
251
+ pmt_audio.squeeze().numpy(),
252
+ sample_rate=self.target_sample_rate,
253
+ )
254
  pmt_audio = torch.from_numpy(pmt_audio).float().unsqueeze(0)
255
 
256
  # to mel spectrogram
257
  pmt_mel_spec = self.mel_spectrogram(pmt_audio)
258
  pmt_mel_spec = pmt_mel_spec.squeeze(0) # '1 d t -> d t'
259
 
260
+ out["pmt_mel_spec"] = pmt_mel_spec
261
+ out["pmt_text"] = pmt_text
262
+ out["pmt_duration"] = pmt_duration
263
 
264
  if self.return_wavform:
265
+ out["pmt_wav"] = pmt_audio
266
 
267
  if return_path:
268
+ out["pmt_path"] = pmt_audio_path
269
 
270
  if return_row:
271
+ out["pmt_row"] = pmt_row
272
 
273
  return out
274
 
 
283
  """
284
 
285
  def __init__(
286
+ self,
287
+ sampler: Sampler[int],
288
+ frames_threshold: int,
289
+ max_samples=0,
290
+ random_seed=None,
291
+ drop_last: bool = False,
292
  ):
293
  self.sampler = sampler
294
  self.frames_threshold = frames_threshold
 
310
  # indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
311
  # ):
312
  for idx, frame_len in indices:
313
+ if batch_frames + frame_len <= self.frames_threshold and (
314
+ max_samples == 0 or len(batch) < max_samples
315
+ ):
316
  batch.append(idx)
317
  batch_frames += frame_len
318
  else:
 
347
 
348
  # Load dataset
349
 
350
+
351
  def load_dataset(
352
  dataset_name: str,
353
  tokenizer: str = "pinyin",
 
360
  return_wavform: bool = False,
361
  remove_starting_space: bool = True,
362
  need_prompt_speech: bool = False,
363
+ prompt_repository: dict = None,
364
  ) -> CustomDataset:
365
  """
366
  dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
 
370
  print("Loading dataset ...")
371
 
372
  if dataset_type == "CustomDataset":
373
+ rel_data_path = str(
374
+ f"/home/yl4579/F5-TTS-diff/F5-TTS-DMD-flow-ds/data/{dataset_name}_{tokenizer}"
375
+ )
376
+ if "LibriTTS_100_360_500_char_pinyin" in rel_data_path:
377
+ rel_data_path = rel_data_path.replace(
378
+ "LibriTTS_100_360_500_char_pinyin", "LibriTTS_100_360_500_char"
379
+ )
380
  if audio_type == "raw":
381
  try:
382
  train_dataset = load_from_disk(f"{rel_data_path}/raw")
 
400
  return_wavform=return_wavform,
401
  remove_starting_space=remove_starting_space,
402
  need_prompt_speech=need_prompt_speech,
403
+ prompt_repository=prompt_repository,
404
  )
405
 
406
  elif dataset_type == "CustomDatasetPath":
 
413
  data_dict = json.load(f)
414
  durations = data_dict["duration"]
415
  train_dataset = CustomDataset(
416
+ train_dataset,
417
+ durations=durations,
418
+ preprocessed_mel=preprocessed_mel,
419
+ **mel_spec_kwargs,
420
  )
421
 
422
  return train_dataset
 
428
  mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
429
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
430
  max_mel_length = mel_lengths.amax()
431
+
432
  # Pad mel_specs
433
  padded_mel_specs = []
434
  for spec in mel_specs: # TODO. maybe records mask for attention here
 
437
  padded_mel_specs.append(padded_spec)
438
  mel_specs = torch.stack(padded_mel_specs)
439
 
440
+ text = [item["text"] for item in batch]
441
+ target_text = [item["target_text"] for item in batch]
442
 
443
  text_lengths = torch.LongTensor([len(item) for item in text])
444
 
 
450
  target_text=target_text,
451
  )
452
 
453
+ if "pmt_mel_spec" in batch[0]:
454
  pmt_mel_specs = [item["pmt_mel_spec"].squeeze(0) for item in batch]
455
  pmt_mel_lengths = torch.LongTensor([spec.shape[-1] for spec in pmt_mel_specs])
456
  max_pmt_mel_length = pmt_mel_lengths.amax()
457
+
458
  # Pad mel_specs
459
  padded_pmt_mel_specs = []
460
+ for spec in pmt_mel_specs:
461
  padding = (0, max_pmt_mel_length - spec.size(-1))
462
  padded_spec = F.pad(spec, padding, value=0)
463
  padded_pmt_mel_specs.append(padded_spec)
464
  pmt_mel_specs = torch.stack(padded_pmt_mel_specs)
465
 
466
+ out["pmt_mel_specs"] = pmt_mel_specs
467
 
468
+ if "pmt_text" in batch[0]:
469
+ pmt_text = [item["pmt_text"] for item in batch]
470
  pmt_text_lengths = torch.LongTensor([len(item) for item in pmt_text])
471
 
472
+ out["pmt_text"] = pmt_text
473
+ out["pmt_text_lengths"] = pmt_text_lengths
474
 
475
+ return out
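Note on the `DynamicBatchSampler` hunk above: only the accept branch of the greedy loop is visible in the diff, so the sketch below assumes the usual flush-and-start-a-new-batch behavior for the else branch; the function name and the toy `(index, frame_len)` data are made up for illustration.

```python
def greedy_frame_batches(indices_with_lens, frames_threshold, max_samples=0):
    """Group indices so each batch stays within a total-frame budget (and an optional size cap)."""
    batches, batch, batch_frames = [], [], 0
    for idx, frame_len in indices_with_lens:
        fits_budget = batch_frames + frame_len <= frames_threshold
        fits_count = max_samples == 0 or len(batch) < max_samples
        if fits_budget and fits_count:
            batch.append(idx)
            batch_frames += frame_len
        else:
            # Assumed else branch: close the current batch and start a new one with this item.
            if batch:
                batches.append(batch)
            batch, batch_frames = [idx], frame_len
    if batch:
        batches.append(batch)
    return batches


# Budget of 10 frames per batch over toy lengths.
print(greedy_frame_batches([(0, 4), (1, 5), (2, 3), (3, 9), (4, 2)], frames_threshold=10))
# [[0, 1], [2], [3], [4]]
```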
f5_tts/model/modules.py CHANGED
@@ -19,7 +19,6 @@ from librosa.filters import mel as librosa_mel_fn
19
  from torch import nn
20
  from x_transformers.x_transformers import apply_rotary_pos_emb
21
 
22
-
23
  # raw wav to mel spec
24
 
25
 
@@ -42,15 +41,25 @@ def get_bigvgan_mel_spectrogram(
42
  key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
43
 
44
  if key not in mel_basis_cache:
45
- mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
46
- mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
47
  hann_window_cache[key] = torch.hann_window(win_length).to(device)
48
 
49
  mel_basis = mel_basis_cache[key]
50
  hann_window = hann_window_cache[key]
51
 
52
  padding = (n_fft - hop_length) // 2
53
- waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
 
 
54
 
55
  spec = torch.stft(
56
  waveform,
@@ -112,7 +121,9 @@ class MelSpec(nn.Module):
112
  mel_spec_type="vocos",
113
  ):
114
  super().__init__()
115
- assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two mel extraction backends: vocos or bigvgan")
 
 
116
 
117
  self.n_fft = n_fft
118
  self.hop_length = hop_length
@@ -193,7 +204,9 @@ class ConvPositionEmbedding(nn.Module):
193
  # rotary positional embedding related
194
 
195
 
196
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
 
 
197
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
198
  # has some connection to NTK literature
199
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -209,10 +222,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
209
 
210
  def get_pos_embed_indices(start, length, max_pos, scale=1.0):
211
  # length = length if isinstance(length, int) else length.max()
212
- scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
 
 
213
  pos = (
214
  start.unsqueeze(1)
215
- + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
216
  )
217
  # avoid extra long error.
218
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
@@ -251,7 +269,9 @@ class ConvNeXtV2Block(nn.Module):
251
  dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
252
  ) # depthwise conv
253
  self.norm = nn.LayerNorm(dim, eps=1e-6)
254
- self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
 
 
255
  self.act = nn.GELU()
256
  self.grn = GRN(intermediate_dim)
257
  self.pwconv2 = nn.Linear(intermediate_dim, dim)
@@ -284,7 +304,9 @@ class AdaLayerNormZero(nn.Module):
284
 
285
  def forward(self, x, emb=None):
286
  emb = self.linear(self.silu(emb))
287
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
 
 
288
 
289
  x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
290
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
@@ -315,14 +337,18 @@ class AdaLayerNormZero_Final(nn.Module):
315
 
316
 
317
  class FeedForward(nn.Module):
318
- def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
 
 
319
  super().__init__()
320
  inner_dim = int(dim * mult)
321
  dim_out = dim_out if dim_out is not None else dim
322
 
323
  activation = nn.GELU(approximate=approximate)
324
  project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
325
- self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
 
326
 
327
  def forward(self, x):
328
  return self.ff(x)
@@ -346,7 +372,9 @@ class Attention(nn.Module):
346
  super().__init__()
347
 
348
  if not hasattr(F, "scaled_dot_product_attention"):
349
- raise ImportError("Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
 
350
 
351
  self.processor = processor
352
 
@@ -385,7 +413,9 @@ class Attention(nn.Module):
385
  c_rope=None, # rotary position embedding for c
386
  ) -> torch.Tensor:
387
  if c is not None:
388
- return self.processor(self, x, c=c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope)
 
 
389
  else:
390
  return self.processor(self, x, mask=mask, rope=rope)
391
 
@@ -414,7 +444,9 @@ class AttnProcessor:
414
  # apply rotary position embedding
415
  if rope is not None:
416
  freqs, xpos_scale = rope
417
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
 
 
418
 
419
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
420
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
@@ -430,11 +462,15 @@ class AttnProcessor:
430
  if mask is not None:
431
  attn_mask = mask
432
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
433
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
 
 
434
  else:
435
  attn_mask = None
436
 
437
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
 
 
438
  x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
439
  x = x.to(query.dtype)
440
 
@@ -461,12 +497,12 @@ class JointAttnProcessor:
461
  def __call__(
462
  self,
463
  attn: Attention,
464
- x: float["b n d"], # noised input x
465
- c: float["b nt d"] = None, # context c, here text
466
  mask: bool["b n"] | None = None,
467
  src_mask: bool["b nt"] | None = None,
468
- rope=None, # rotary position embedding for x
469
- c_rope=None, # rotary position embedding for c
470
  ) -> torch.FloatTensor:
471
  residual = x
472
  batch_size = c.shape[0]
@@ -484,14 +520,18 @@ class JointAttnProcessor:
484
  # apply rope for x
485
  if rope is not None:
486
  freqs, xpos_scale = rope
487
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
 
 
488
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
489
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
490
 
491
  # apply rope for c
492
  if c_rope is not None:
493
  freqs, xpos_scale = c_rope
494
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
 
 
495
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
496
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
497
 
@@ -515,17 +555,23 @@ class JointAttnProcessor:
515
  attn_mask_c = F.pad(src_mask, (x.shape[1], 0), value=True)
516
  attn_mask = attn_mask & attn_mask_c
517
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
518
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
 
 
519
  else:
520
  if src_mask is not None:
521
  # if there's no mask for x but there's src_mask
522
  attn_mask = F.pad(src_mask, (x.shape[1], 0), value=True)
523
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
524
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
 
 
525
  else:
526
  attn_mask = None
527
 
528
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
 
 
529
  x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
530
  x = x.to(query.dtype)
531
 
@@ -546,7 +592,6 @@ class JointAttnProcessor:
546
  return x, c
547
 
548
 
549
-
550
  # DiT Block
551
 
552
 
@@ -564,7 +609,9 @@ class DiTBlock(nn.Module):
564
  )
565
 
566
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
567
- self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
 
 
568
 
569
  def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
570
  # pre-norm & modulation for attention input
@@ -596,12 +643,16 @@ class MMDiTBlock(nn.Module):
596
  context_pre_only: last layer only does prenorm + modulation since there is no more ffn
597
  """
598
 
599
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
 
 
600
  super().__init__()
601
 
602
  self.context_pre_only = context_pre_only
603
 
604
- self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
 
 
605
  self.attn_norm_x = AdaLayerNormZero(dim)
606
  self.attn = Attention(
607
  processor=JointAttnProcessor(),
@@ -615,23 +666,35 @@ class MMDiTBlock(nn.Module):
615
 
616
  if not context_pre_only:
617
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
618
- self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
 
 
619
  else:
620
  self.ff_norm_c = None
621
  self.ff_c = None
622
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
623
- self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
 
 
624
 
625
- def forward(self, x, c, t, mask=None, src_mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
 
 
626
  # pre-norm & modulation for attention input
627
  if self.context_pre_only:
628
  norm_c = self.attn_norm_c(c, t)
629
  else:
630
- norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
631
- norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
632
 
633
  # attention
634
- x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope)
 
 
635
 
636
  # process attention output for context c
637
  if self.context_pre_only:
@@ -639,7 +702,9 @@ class MMDiTBlock(nn.Module):
639
  else: # if not last layer
640
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
641
 
642
- norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
 
 
643
  c_ff_output = self.ff_c(norm_c)
644
  c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
645
 
@@ -660,7 +725,9 @@ class TimestepEmbedding(nn.Module):
660
  def __init__(self, dim, freq_embed_dim=256):
661
  super().__init__()
662
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
663
- self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
 
664
 
665
  def forward(self, timestep: float["b"]): # noqa: F821
666
  time_hidden = self.time_embed(timestep)
 
19
  from torch import nn
20
  from x_transformers.x_transformers import apply_rotary_pos_emb
21
 
 
22
  # raw wav to mel spec
23
 
24
 
 
41
  key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
42
 
43
  if key not in mel_basis_cache:
44
+ mel = librosa_mel_fn(
45
+ sr=target_sample_rate,
46
+ n_fft=n_fft,
47
+ n_mels=n_mel_channels,
48
+ fmin=fmin,
49
+ fmax=fmax,
50
+ )
51
+ mel_basis_cache[key] = (
52
+ torch.from_numpy(mel).float().to(device)
53
+ ) # TODO: why they need .float()?
54
  hann_window_cache[key] = torch.hann_window(win_length).to(device)
55
 
56
  mel_basis = mel_basis_cache[key]
57
  hann_window = hann_window_cache[key]
58
 
59
  padding = (n_fft - hop_length) // 2
60
+ waveform = torch.nn.functional.pad(
61
+ waveform.unsqueeze(1), (padding, padding), mode="reflect"
62
+ ).squeeze(1)
63
 
64
  spec = torch.stft(
65
  waveform,
 
121
  mel_spec_type="vocos",
122
  ):
123
  super().__init__()
124
+ assert mel_spec_type in ["vocos", "bigvgan"], print(
125
+ "We only support two mel extraction backends: vocos or bigvgan"
126
+ )
127
 
128
  self.n_fft = n_fft
129
  self.hop_length = hop_length
 
204
  # rotary positional embedding related
205
 
206
 
207
+ def precompute_freqs_cis(
208
+ dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
209
+ ):
210
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
211
  # has some connection to NTK literature
212
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
 
222
 
223
  def get_pos_embed_indices(start, length, max_pos, scale=1.0):
224
  # length = length if isinstance(length, int) else length.max()
225
+ scale = scale * torch.ones_like(
226
+ start, dtype=torch.float32
227
+ ) # in case scale is a scalar
228
  pos = (
229
  start.unsqueeze(1)
230
+ + (
231
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
232
+ * scale.unsqueeze(1)
233
+ ).long()
234
  )
235
  # avoid extra long error.
236
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
 
269
  dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
270
  ) # depthwise conv
271
  self.norm = nn.LayerNorm(dim, eps=1e-6)
272
+ self.pwconv1 = nn.Linear(
273
+ dim, intermediate_dim
274
+ ) # pointwise/1x1 convs, implemented with linear layers
275
  self.act = nn.GELU()
276
  self.grn = GRN(intermediate_dim)
277
  self.pwconv2 = nn.Linear(intermediate_dim, dim)
 
304
 
305
  def forward(self, x, emb=None):
306
  emb = self.linear(self.silu(emb))
307
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
308
+ emb, 6, dim=1
309
+ )
310
 
311
  x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
312
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
 
337
 
338
 
339
  class FeedForward(nn.Module):
340
+ def __init__(
341
+ self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
342
+ ):
343
  super().__init__()
344
  inner_dim = int(dim * mult)
345
  dim_out = dim_out if dim_out is not None else dim
346
 
347
  activation = nn.GELU(approximate=approximate)
348
  project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
349
+ self.ff = nn.Sequential(
350
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
351
+ )
352
 
353
  def forward(self, x):
354
  return self.ff(x)
 
372
  super().__init__()
373
 
374
  if not hasattr(F, "scaled_dot_product_attention"):
375
+ raise ImportError(
376
+ "Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
377
+ )
378
 
379
  self.processor = processor
380
 
 
413
  c_rope=None, # rotary position embedding for c
414
  ) -> torch.Tensor:
415
  if c is not None:
416
+ return self.processor(
417
+ self, x, c=c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope
418
+ )
419
  else:
420
  return self.processor(self, x, mask=mask, rope=rope)
421
 
 
444
  # apply rotary position embedding
445
  if rope is not None:
446
  freqs, xpos_scale = rope
447
+ q_xpos_scale, k_xpos_scale = (
448
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
449
+ )
450
 
451
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
452
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
 
462
  if mask is not None:
463
  attn_mask = mask
464
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
465
+ attn_mask = attn_mask.expand(
466
+ batch_size, attn.heads, query.shape[-2], key.shape[-2]
467
+ )
468
  else:
469
  attn_mask = None
470
 
471
+ x = F.scaled_dot_product_attention(
472
+ query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
473
+ )
474
  x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
475
  x = x.to(query.dtype)
476
 
 
497
  def __call__(
498
  self,
499
  attn: Attention,
500
+ x: float["b n d"], # noised input x
501
+ c: float["b nt d"] = None, # context c, here text
502
  mask: bool["b n"] | None = None,
503
  src_mask: bool["b nt"] | None = None,
504
+ rope=None, # rotary position embedding for x
505
+ c_rope=None, # rotary position embedding for c
506
  ) -> torch.FloatTensor:
507
  residual = x
508
  batch_size = c.shape[0]
 
520
  # apply rope for x
521
  if rope is not None:
522
  freqs, xpos_scale = rope
523
+ q_xpos_scale, k_xpos_scale = (
524
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
525
+ )
526
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
527
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
528
 
529
  # apply rope for c
530
  if c_rope is not None:
531
  freqs, xpos_scale = c_rope
532
+ q_xpos_scale, k_xpos_scale = (
533
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
534
+ )
535
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
536
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
537
 
 
555
  attn_mask_c = F.pad(src_mask, (x.shape[1], 0), value=True)
556
  attn_mask = attn_mask & attn_mask_c
557
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
558
+ attn_mask = attn_mask.expand(
559
+ batch_size, attn.heads, query.shape[-2], key.shape[-2]
560
+ )
561
  else:
562
  if src_mask is not None:
563
  # if there's no mask for x but there's src_mask
564
  attn_mask = F.pad(src_mask, (x.shape[1], 0), value=True)
565
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
566
+ attn_mask = attn_mask.expand(
567
+ batch_size, attn.heads, query.shape[-2], key.shape[-2]
568
+ )
569
  else:
570
  attn_mask = None
571
 
572
+ x = F.scaled_dot_product_attention(
573
+ query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
574
+ )
575
  x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
576
  x = x.to(query.dtype)
577
 
 
592
  return x, c
593
 
594
 
 
595
  # DiT Block
596
 
597
 
 
609
  )
610
 
611
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
612
+ self.ff = FeedForward(
613
+ dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
614
+ )
615
 
616
  def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
617
  # pre-norm & modulation for attention input
 
643
  context_pre_only: last layer only does prenorm + modulation since there is no more ffn
644
  """
645
 
646
+ def __init__(
647
+ self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False
648
+ ):
649
  super().__init__()
650
 
651
  self.context_pre_only = context_pre_only
652
 
653
+ self.attn_norm_c = (
654
+ AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
655
+ )
656
  self.attn_norm_x = AdaLayerNormZero(dim)
657
  self.attn = Attention(
658
  processor=JointAttnProcessor(),
 
666
 
667
  if not context_pre_only:
668
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
669
+ self.ff_c = FeedForward(
670
+ dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
671
+ )
672
  else:
673
  self.ff_norm_c = None
674
  self.ff_c = None
675
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
676
+ self.ff_x = FeedForward(
677
+ dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
678
+ )
679
 
680
+ def forward(
681
+ self, x, c, t, mask=None, src_mask=None, rope=None, c_rope=None
682
+ ): # x: noised input, c: context, t: time embedding
683
  # pre-norm & modulation for attention input
684
  if self.context_pre_only:
685
  norm_c = self.attn_norm_c(c, t)
686
  else:
687
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
688
+ c, emb=t
689
+ )
690
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
691
+ x, emb=t
692
+ )
693
 
694
  # attention
695
+ x_attn_output, c_attn_output = self.attn(
696
+ x=norm_x, c=norm_c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope
697
+ )
698
 
699
  # process attention output for context c
700
  if self.context_pre_only:
 
702
  else: # if not last layer
703
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
704
 
705
+ norm_c = (
706
+ self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
707
+ )
708
  c_ff_output = self.ff_c(norm_c)
709
  c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
710
 
 
725
  def __init__(self, dim, freq_embed_dim=256):
726
  super().__init__()
727
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
728
+ self.time_mlp = nn.Sequential(
729
+ nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
730
+ )
731
 
732
  def forward(self, timestep: float["b"]): # noqa: F821
733
  time_hidden = self.time_embed(timestep)
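Note on the `precompute_freqs_cis` hunk above: the diff only shows the reformatted signature and the NTK-aware comment, not the body, so the following is a conventional sketch of how a `theta_rescale_factor` is typically applied to a rotary frequency table, not the repository's exact implementation; the function name is illustrative and the dim/(dim - 2) exponent is the correction commonly cited in the NTK-RoPE discussion linked in the comment.

```python
import torch


def ntk_rotary_freqs(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor: float = 1.0):
    # NTK-aware rescale: stretch the rotary base so low frequencies interpolate over longer contexts.
    theta = theta * theta_rescale_factor ** (dim / (dim - 2))
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    positions = torch.arange(end).float()                                # (end,)
    angles = torch.outer(positions, inv_freq)                            # (end, dim/2)
    # cos/sin tables that a rotary embedding layer would consume
    return torch.cos(angles), torch.sin(angles)


cos_tab, sin_tab = ntk_rotary_freqs(dim=64, end=4096, theta_rescale_factor=2.0)
print(cos_tab.shape, sin_tab.shape)  # torch.Size([4096, 32]) twice
```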
f5_tts/model/trainer.py CHANGED
@@ -67,7 +67,13 @@ class Trainer:
67
  self.logger = logger
68
  if self.logger == "wandb":
69
  if exists(wandb_resume_id):
70
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
71
  else:
72
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
73
 
@@ -102,7 +108,9 @@ class Trainer:
102
  self.epochs = epochs
103
  self.num_warmup_updates = num_warmup_updates
104
  self.save_per_updates = save_per_updates
105
- self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
 
 
106
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
107
 
108
  self.batch_size = batch_size
@@ -126,8 +134,10 @@ class Trainer:
126
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
127
  else:
128
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
129
- self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
130
-
 
 
131
  self.scale = None
132
  self.count = 0
133
 
@@ -137,10 +147,12 @@ class Trainer:
137
 
138
  def save_checkpoint(self, step, last=False):
139
  self.accelerator.wait_for_everyone()
140
- if self.is_main:
141
  checkpoint = dict(
142
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
143
- optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
 
 
144
  ema_model_state_dict=self.ema_model.state_dict(),
145
  scheduler_state_dict=self.scheduler.state_dict(),
146
  step=step,
@@ -150,16 +162,23 @@ class Trainer:
150
  if not os.path.exists(self.checkpoint_path):
151
  os.makedirs(self.checkpoint_path)
152
  if last:
153
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
 
 
154
  print(f"Saved last checkpoint at step {step}")
155
  else:
156
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
 
 
157
 
158
  def load_checkpoint(self):
159
  if (
160
  not exists(self.checkpoint_path)
161
  or not os.path.exists(self.checkpoint_path)
162
- or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
163
  ):
164
  return 0
165
 
@@ -172,10 +191,17 @@ class Trainer:
172
  key=lambda x: int("".join(filter(str.isdigit, x))),
173
  )[-1]
174
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
175
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
176
 
177
  # patch for backward compatibility, 305e3ea
178
- for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
179
  if key in checkpoint["ema_model_state_dict"]:
180
  del checkpoint["ema_model_state_dict"][key]
181
 
@@ -184,12 +210,19 @@ class Trainer:
184
 
185
  if "step" in checkpoint:
186
  # patch for backward compatibility, 305e3ea
187
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
188
  if key in checkpoint["model_state_dict"]:
189
  del checkpoint["model_state_dict"][key]
190
 
191
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
192
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
193
  if self.scheduler:
194
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
195
  step = checkpoint["step"]
@@ -199,28 +232,37 @@ class Trainer:
199
  for k, v in checkpoint["ema_model_state_dict"].items()
200
  if k not in ["initted", "step"]
201
  }
202
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
 
 
203
  step = 0
204
 
205
  if "scale" in checkpoint:
206
  self.scale = float(checkpoint["scale"])
207
  self.model.scale = self.scale
208
-
209
  if "count" in checkpoint:
210
  self.count = int(checkpoint["count"])
211
-
212
  del checkpoint
213
  gc.collect()
214
  return step
215
 
216
- def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
 
 
217
  if self.log_samples:
218
- from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
 
219
 
220
  vocoder = load_vocoder(
221
- vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
 
 
222
  )
223
- target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
 
 
224
  log_samples_path = f"{self.checkpoint_path}/samples"
225
  os.makedirs(log_samples_path, exist_ok=True)
226
 
@@ -245,7 +287,11 @@ class Trainer:
245
  self.accelerator.even_batches = False
246
  sampler = SequentialSampler(train_dataset)
247
  batch_sampler = DynamicBatchSampler(
248
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
249
  )
250
  train_dataloader = DataLoader(
251
  train_dataset,
@@ -256,7 +302,9 @@ class Trainer:
256
  batch_sampler=batch_sampler,
257
  )
258
  else:
259
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
 
 
260
 
261
  # accelerator.prepare() dispatches batches to devices;
262
  # which means the length of dataloader calculated before, should consider the number of devices
@@ -266,10 +314,16 @@ class Trainer:
266
  # otherwise by default with split_batches=False, warmup steps change with num_processes
267
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
268
  decay_steps = total_steps - warmup_steps
269
- warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
270
- decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
 
  self.scheduler = SequentialLR(
272
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
 
 
273
  )
274
  train_dataloader, self.scheduler = self.accelerator.prepare(
275
  train_dataloader, self.scheduler
@@ -281,7 +335,9 @@ class Trainer:
281
  orig_epoch_step = len(train_dataloader)
282
  skipped_epoch = int(start_step // orig_epoch_step)
283
  skipped_batch = start_step % orig_epoch_step
284
- skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
 
 
285
  else:
286
  skipped_epoch = 0
287
 
@@ -309,28 +365,40 @@ class Trainer:
309
  text_inputs = batch["text"]
310
  mel_spec = batch["mel"].permute(0, 2, 1)
311
  mel_lengths = batch["mel_lengths"]
312
-
313
  self.count += 1
314
-
315
  if self.scale is None:
316
  self.scale = mel_spec.std()
317
  else:
318
  self.scale += (mel_spec.std() - self.scale) / self.count
319
-
320
- mel_spec = mel_spec / self.scale # normalize mel spectrogram
321
-
322
  # TODO. add duration predictor training
323
- if self.duration_predictor is not None and self.accelerator.is_local_main_process:
324
- dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
325
- self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
 
 
327
  loss, cond, pred, t = self.model(
328
- mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
329
  )
330
  self.accelerator.backward(loss)
331
 
332
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
333
- self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
 
 
334
 
335
  self.optimizer.step()
336
  self.scheduler.step()
@@ -342,18 +410,30 @@ class Trainer:
342
  global_step += 1
343
 
344
  if self.accelerator.is_local_main_process:
345
- self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
346
  if self.logger == "tensorboard":
347
  self.writer.add_scalar("loss", loss.item(), global_step)
348
- self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
 
 
349
 
350
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
351
 
352
- if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
353
  self.save_checkpoint(global_step)
354
  if self.log_samples and self.accelerator.is_local_main_process:
355
- gen_mel_spec = pred[0].unsqueeze(0).permute(0, 2, 1) * self.scale
356
- ref_mel_spec = cond[0].unsqueeze(0).permute(0, 2, 1) * self.scale
 
  with torch.inference_mode():
358
  if self.vocoder_name == "vocos":
359
  gen_audio = vocoder.decode(gen_mel_spec).cpu()
@@ -361,51 +441,56 @@ class Trainer:
361
  elif self.vocoder_name == "bigvgan":
362
  gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
363
  ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
364
-
365
  gen_audio = wandb.Audio(
366
  gen_audio.float().numpy().squeeze(),
367
  sample_rate=24000,
368
- caption="time: " + str(t[0].squeeze().float().cpu().numpy())
 
369
  )
370
  ref_audio = wandb.Audio(
371
  ref_audio.float().numpy().squeeze(),
372
  sample_rate=24000,
373
- caption="time: " + str(t[0].squeeze().float().cpu().numpy())
374
  )
375
 
376
- self.accelerator.log({"gen_audio": gen_audio,
377
- "ref_audio": ref_audio,
378
- }, step=global_step)
379
-
380
-
381
- # if self.log_samples and self.accelerator.is_local_main_process:
382
- # ref_audio_len = mel_lengths[0]
383
- # infer_text = [
384
- # text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
385
- # ]
386
- # with torch.inference_mode():
387
- # # generated, _ = self.accelerator.unwrap_model(self.model).sample(
388
- # # cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
389
- # # text=infer_text,
390
- # # duration=ref_audio_len * 2,
391
- # # steps=nfe_step,
392
- # # cfg_strength=cfg_strength,
393
- # # sway_sampling_coef=sway_sampling_coef,
394
- # # )
395
- # # generated = generated.to(torch.float32)
396
- # # gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
397
- # # ref_mel_spec = batch["mel"][0].unsqueeze(0)
398
- # gen_mel_spec = pred[0].unsqueeze(0).permute(0, 2, 1)
399
- # ref_mel_spec = cond[0].unsqueeze(0).permute(0, 2, 1)
400
- # if self.vocoder_name == "vocos":
401
- # gen_audio = vocoder.decode(gen_mel_spec).cpu()
402
- # ref_audio = vocoder.decode(ref_mel_spec).cpu()
403
- # elif self.vocoder_name == "bigvgan":
404
- # gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
405
- # ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
406
-
407
- # torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
408
- # torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
409
 
410
  if global_step % self.last_per_steps == 0:
411
  self.save_checkpoint(global_step, last=True)
 
67
  self.logger = logger
68
  if self.logger == "wandb":
69
  if exists(wandb_resume_id):
70
+ init_kwargs = {
71
+ "wandb": {
72
+ "resume": "allow",
73
+ "name": wandb_run_name,
74
+ "id": wandb_resume_id,
75
+ }
76
+ }
77
  else:
78
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
79
 
 
108
  self.epochs = epochs
109
  self.num_warmup_updates = num_warmup_updates
110
  self.save_per_updates = save_per_updates
111
+ self.last_per_steps = default(
112
+ last_per_steps, save_per_updates * grad_accumulation_steps
113
+ )
114
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
115
 
116
  self.batch_size = batch_size
 
134
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
135
  else:
136
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
137
+ self.model, self.optimizer = self.accelerator.prepare(
138
+ self.model, self.optimizer
139
+ )
140
+
141
  self.scale = None
142
  self.count = 0
143
 
 
147
 
148
  def save_checkpoint(self, step, last=False):
149
  self.accelerator.wait_for_everyone()
150
+ if self.is_main:
151
  checkpoint = dict(
152
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
153
+ optimizer_state_dict=self.accelerator.unwrap_model(
154
+ self.optimizer
155
+ ).state_dict(),
156
  ema_model_state_dict=self.ema_model.state_dict(),
157
  scheduler_state_dict=self.scheduler.state_dict(),
158
  step=step,
 
162
  if not os.path.exists(self.checkpoint_path):
163
  os.makedirs(self.checkpoint_path)
164
  if last:
165
+ self.accelerator.save(
166
+ checkpoint, f"{self.checkpoint_path}/model_last.pt"
167
+ )
168
  print(f"Saved last checkpoint at step {step}")
169
  else:
170
+ self.accelerator.save(
171
+ checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
172
+ )
173
 
174
  def load_checkpoint(self):
175
  if (
176
  not exists(self.checkpoint_path)
177
  or not os.path.exists(self.checkpoint_path)
178
+ or not any(
179
+ filename.endswith(".pt")
180
+ for filename in os.listdir(self.checkpoint_path)
181
+ )
182
  ):
183
  return 0
184
 
 
191
  key=lambda x: int("".join(filter(str.isdigit, x))),
192
  )[-1]
193
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
194
+ checkpoint = torch.load(
195
+ f"{self.checkpoint_path}/{latest_checkpoint}",
196
+ weights_only=True,
197
+ map_location="cpu",
198
+ )
199
 
200
  # patch for backward compatibility, 305e3ea
201
+ for key in [
202
+ "ema_model.mel_spec.mel_stft.mel_scale.fb",
203
+ "ema_model.mel_spec.mel_stft.spectrogram.window",
204
+ ]:
205
  if key in checkpoint["ema_model_state_dict"]:
206
  del checkpoint["ema_model_state_dict"][key]
207
 
 
210
 
211
  if "step" in checkpoint:
212
  # patch for backward compatibility, 305e3ea
213
+ for key in [
214
+ "mel_spec.mel_stft.mel_scale.fb",
215
+ "mel_spec.mel_stft.spectrogram.window",
216
+ ]:
217
  if key in checkpoint["model_state_dict"]:
218
  del checkpoint["model_state_dict"][key]
219
 
220
+ self.accelerator.unwrap_model(self.model).load_state_dict(
221
+ checkpoint["model_state_dict"]
222
+ )
223
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(
224
+ checkpoint["optimizer_state_dict"]
225
+ )
226
  if self.scheduler:
227
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
228
  step = checkpoint["step"]
 
232
  for k, v in checkpoint["ema_model_state_dict"].items()
233
  if k not in ["initted", "step"]
234
  }
235
+ self.accelerator.unwrap_model(self.model).load_state_dict(
236
+ checkpoint["model_state_dict"]
237
+ )
238
  step = 0
239
 
240
  if "scale" in checkpoint:
241
  self.scale = float(checkpoint["scale"])
242
  self.model.scale = self.scale
243
+
244
  if "count" in checkpoint:
245
  self.count = int(checkpoint["count"])
246
+
247
  del checkpoint
248
  gc.collect()
249
  return step
250
 
251
+ def train(
252
+ self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None
253
+ ):
254
  if self.log_samples:
255
+ from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
256
+ nfe_step, sway_sampling_coef)
257
 
258
  vocoder = load_vocoder(
259
+ vocoder_name=self.vocoder_name,
260
+ is_local=self.is_local_vocoder,
261
+ local_path=self.local_vocoder_path,
262
  )
263
+ target_sample_rate = self.accelerator.unwrap_model(
264
+ self.model
265
+ ).mel_spec.target_sample_rate
266
  log_samples_path = f"{self.checkpoint_path}/samples"
267
  os.makedirs(log_samples_path, exist_ok=True)
268
 
 
287
  self.accelerator.even_batches = False
288
  sampler = SequentialSampler(train_dataset)
289
  batch_sampler = DynamicBatchSampler(
290
+ sampler,
291
+ self.batch_size,
292
+ max_samples=self.max_samples,
293
+ random_seed=resumable_with_seed,
294
+ drop_last=False,
295
  )
296
  train_dataloader = DataLoader(
297
  train_dataset,
 
302
  batch_sampler=batch_sampler,
303
  )
304
  else:
305
+ raise ValueError(
306
+ f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
307
+ )
308
 
309
  # accelerator.prepare() dispatches batches to devices;
310
  # which means the length of dataloader calculated before, should consider the number of devices
 
314
  # otherwise by default with split_batches=False, warmup steps change with num_processes
315
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
316
  decay_steps = total_steps - warmup_steps
317
+ warmup_scheduler = LinearLR(
318
+ self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps
319
+ )
320
+ decay_scheduler = LinearLR(
321
+ self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps
322
+ )
323
  self.scheduler = SequentialLR(
324
+ self.optimizer,
325
+ schedulers=[warmup_scheduler, decay_scheduler],
326
+ milestones=[warmup_steps],
327
  )
328
  train_dataloader, self.scheduler = self.accelerator.prepare(
329
  train_dataloader, self.scheduler
 
335
  orig_epoch_step = len(train_dataloader)
336
  skipped_epoch = int(start_step // orig_epoch_step)
337
  skipped_batch = start_step % orig_epoch_step
338
+ skipped_dataloader = self.accelerator.skip_first_batches(
339
+ train_dataloader, num_batches=skipped_batch
340
+ )
341
  else:
342
  skipped_epoch = 0
343
 
 
365
  text_inputs = batch["text"]
366
  mel_spec = batch["mel"].permute(0, 2, 1)
367
  mel_lengths = batch["mel_lengths"]
368
+
369
  self.count += 1
370
+
371
  if self.scale is None:
372
  self.scale = mel_spec.std()
373
  else:
374
  self.scale += (mel_spec.std() - self.scale) / self.count
375
+
376
+ mel_spec = mel_spec / self.scale # normalize mel spectrogram
377
+
378
  # TODO. add duration predictor training
379
+ if (
380
+ self.duration_predictor is not None
381
+ and self.accelerator.is_local_main_process
382
+ ):
383
+ dur_loss = self.duration_predictor(
384
+ mel_spec, lens=batch.get("durations")
385
+ )
386
+ self.accelerator.log(
387
+ {"duration loss": dur_loss.item()}, step=global_step
388
+ )
389
 
390
  loss, cond, pred, t = self.model(
391
+ mel_spec,
392
+ text=text_inputs,
393
+ lens=mel_lengths,
394
+ noise_scheduler=self.noise_scheduler,
395
  )
396
  self.accelerator.backward(loss)
397
 
398
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
399
+ self.accelerator.clip_grad_norm_(
400
+ self.model.parameters(), self.max_grad_norm
401
+ )
402
 
403
  self.optimizer.step()
404
  self.scheduler.step()
 
410
  global_step += 1
411
 
412
  if self.accelerator.is_local_main_process:
413
+ self.accelerator.log(
414
+ {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]},
415
+ step=global_step,
416
+ )
417
  if self.logger == "tensorboard":
418
  self.writer.add_scalar("loss", loss.item(), global_step)
419
+ self.writer.add_scalar(
420
+ "lr", self.scheduler.get_last_lr()[0], global_step
421
+ )
422
 
423
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
424
 
425
+ if (
426
+ global_step % (self.save_per_updates * self.grad_accumulation_steps)
427
+ == 0
428
+ ):
429
  self.save_checkpoint(global_step)
430
  if self.log_samples and self.accelerator.is_local_main_process:
431
+ gen_mel_spec = (
432
+ pred[0].unsqueeze(0).permute(0, 2, 1) * self.scale
433
+ )
434
+ ref_mel_spec = (
435
+ cond[0].unsqueeze(0).permute(0, 2, 1) * self.scale
436
+ )
437
  with torch.inference_mode():
438
  if self.vocoder_name == "vocos":
439
  gen_audio = vocoder.decode(gen_mel_spec).cpu()
 
441
  elif self.vocoder_name == "bigvgan":
442
  gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
443
  ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
444
+
445
  gen_audio = wandb.Audio(
446
  gen_audio.float().numpy().squeeze(),
447
  sample_rate=24000,
448
+ caption="time: "
449
+ + str(t[0].squeeze().float().cpu().numpy()),
450
  )
451
  ref_audio = wandb.Audio(
452
  ref_audio.float().numpy().squeeze(),
453
  sample_rate=24000,
454
+ caption="time: "
455
+ + str(t[0].squeeze().float().cpu().numpy()),
456
+ )
457
+
458
+ self.accelerator.log(
459
+ {
460
+ "gen_audio": gen_audio,
461
+ "ref_audio": ref_audio,
462
+ },
463
+ step=global_step,
464
  )
465
 
466
+ # if self.log_samples and self.accelerator.is_local_main_process:
467
+ # ref_audio_len = mel_lengths[0]
468
+ # infer_text = [
469
+ # text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
470
+ # ]
471
+ # with torch.inference_mode():
472
+ # # generated, _ = self.accelerator.unwrap_model(self.model).sample(
473
+ # # cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
474
+ # # text=infer_text,
475
+ # # duration=ref_audio_len * 2,
476
+ # # steps=nfe_step,
477
+ # # cfg_strength=cfg_strength,
478
+ # # sway_sampling_coef=sway_sampling_coef,
479
+ # # )
480
+ # # generated = generated.to(torch.float32)
481
+ # # gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
482
+ # # ref_mel_spec = batch["mel"][0].unsqueeze(0)
483
+ # gen_mel_spec = pred[0].unsqueeze(0).permute(0, 2, 1)
484
+ # ref_mel_spec = cond[0].unsqueeze(0).permute(0, 2, 1)
485
+ # if self.vocoder_name == "vocos":
486
+ # gen_audio = vocoder.decode(gen_mel_spec).cpu()
487
+ # ref_audio = vocoder.decode(ref_mel_spec).cpu()
488
+ # elif self.vocoder_name == "bigvgan":
489
+ # gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
490
+ # ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
491
+
492
+ # torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
493
+ # torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
 
 
 
 
 
494
 
495
  if global_step % self.last_per_steps == 0:
496
  self.save_checkpoint(global_step, last=True)
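
Most of the training-loop edits above only re-wrap long lines; the one piece of real logic worth spelling out is the mel normalization: the trainer keeps an incremental running average of the per-batch mel standard deviation in `self.scale` (updated once per batch via `self.count`) and divides each batch by it, persisting both values in the checkpoint. A small sketch of that running estimate, using random tensors in place of real mel batches (shapes are illustrative):

```python
import torch

scale, count = None, 0

for _ in range(5):  # stand-in for the epoch/dataloader loop
    mel_spec = torch.randn(8, 100, 80)  # (batch, frames, n_mels), illustrative shapes

    count += 1
    batch_std = mel_spec.std()
    if scale is None:
        scale = batch_std
    else:
        # incremental running mean of the per-batch std:
        # scale_n = scale_{n-1} + (x_n - scale_{n-1}) / n
        scale += (batch_std - scale) / count

    mel_spec = mel_spec / scale  # normalize before feeding the model

print(f"running std estimate: {scale.item():.3f}")
```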
f5_tts/model/utils.py CHANGED
@@ -5,13 +5,11 @@ import random
5
  from collections import defaultdict
6
  from importlib.resources import files
7
 
 
8
  import torch
 
9
  from torch.nn.utils.rnn import pad_sequence
10
 
11
- import jieba
12
- from pypinyin import lazy_pinyin, Style
13
-
14
-
15
  # seed everything
16
 
17
 
@@ -39,7 +37,9 @@ def default(v, d):
39
  # tensor helpers
40
 
41
 
42
- def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
 
 
43
  if not exists(length):
44
  length = t.amax()
45
 
@@ -47,7 +47,9 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa
47
  return seq[None, :] < t[:, None]
48
 
49
 
50
- def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
 
 
51
  max_seq_len = seq_len.max().item()
52
  seq = torch.arange(max_seq_len, device=start.device).long()
53
  start_mask = seq[None, :] >= start[:, None]
@@ -55,7 +57,9 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"
55
  return start_mask & end_mask
56
 
57
 
58
- def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
 
 
59
  lengths = (frac_lengths * seq_len).long()
60
  max_start = seq_len - lengths
61
 
@@ -66,7 +70,9 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa
66
  return mask_from_start_end_indices(seq_len, start, end)
67
 
68
 
69
- def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
 
 
70
  if not exists(mask):
71
  return t.mean(dim=1)
72
 
@@ -90,7 +96,9 @@ def list_str_to_idx(
90
  vocab_char_map: dict[str, int], # {char: idx}
91
  padding_value=-1,
92
  ) -> int["b nt"]: # noqa: F722
93
- list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
 
 
94
  text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
95
  return text
96
 
@@ -109,13 +117,17 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
109
  - if use "byte", set to 256 (unicode byte range)
110
  """
111
  if tokenizer in ["pinyin", "char"]:
112
- tokenizer_path = os.path.join(files("f5_tts").joinpath("../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
 
 
113
  with open(tokenizer_path, "r", encoding="utf-8") as f:
114
  vocab_char_map = {}
115
  for i, char in enumerate(f):
116
  vocab_char_map[char[:-1]] = i
117
  vocab_size = len(vocab_char_map)
118
- assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
 
 
119
 
120
  elif tokenizer == "byte":
121
  vocab_char_map = None
@@ -131,7 +143,6 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
131
  return vocab_char_map, vocab_size
132
 
133
 
134
-
135
  # convert char to pinyin
136
 
137
  jieba.initialize()
@@ -145,9 +156,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
145
  ) # add custom trans here, to address oov
146
 
147
  def is_chinese(c):
148
- return (
149
- "\u3100" <= c <= "\u9fff" # common chinese characters
150
- )
151
 
152
  for text in text_list:
153
  char_list = []
@@ -158,7 +167,9 @@ def convert_char_to_pinyin(text_list, polyphone=True):
158
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
159
  char_list.append(" ")
160
  char_list.extend(seg)
161
- elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
 
 
162
  seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
163
  for i, c in enumerate(seg):
164
  if is_chinese(c):
@@ -170,7 +181,9 @@ def convert_char_to_pinyin(text_list, polyphone=True):
170
  char_list.extend(c)
171
  elif is_chinese(c):
172
  char_list.append(" ")
173
- char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
 
 
174
  else:
175
  char_list.append(c)
176
  final_text_list.append(char_list)
@@ -224,7 +237,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema=True):
224
  def sample_consecutive_steps(float_list):
225
  idx = torch.randint(0, len(float_list), size=(1,))
226
  next_idx = idx - 1
227
-
228
  if next_idx < 0:
229
  next_idx = 0
230
  else:
 
5
  from collections import defaultdict
6
  from importlib.resources import files
7
 
8
+ import jieba
9
  import torch
10
+ from pypinyin import Style, lazy_pinyin
11
  from torch.nn.utils.rnn import pad_sequence
12
 
 
 
 
 
13
  # seed everything
14
 
15
 
 
37
  # tensor helpers
38
 
39
 
40
+ def lens_to_mask(
41
+ t: int["b"], length: int | None = None
42
+ ) -> bool["b n"]: # noqa: F722 F821
43
  if not exists(length):
44
  length = t.amax()
45
 
 
47
  return seq[None, :] < t[:, None]
48
 
49
 
50
+ def mask_from_start_end_indices(
51
+ seq_len: int["b"], start: int["b"], end: int["b"]
52
+ ): # noqa: F722 F821
53
  max_seq_len = seq_len.max().item()
54
  seq = torch.arange(max_seq_len, device=start.device).long()
55
  start_mask = seq[None, :] >= start[:, None]
 
57
  return start_mask & end_mask
58
 
59
 
60
+ def mask_from_frac_lengths(
61
+ seq_len: int["b"], frac_lengths: float["b"]
62
+ ): # noqa: F722 F821
63
  lengths = (frac_lengths * seq_len).long()
64
  max_start = seq_len - lengths
65
 
 
70
  return mask_from_start_end_indices(seq_len, start, end)
71
 
72
 
73
+ def maybe_masked_mean(
74
+ t: float["b n d"], mask: bool["b n"] = None
75
+ ) -> float["b d"]: # noqa: F722
76
  if not exists(mask):
77
  return t.mean(dim=1)
78
 
 
96
  vocab_char_map: dict[str, int], # {char: idx}
97
  padding_value=-1,
98
  ) -> int["b nt"]: # noqa: F722
99
+ list_idx_tensors = [
100
+ torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
101
+ ] # pinyin or char style
102
  text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
103
  return text
104
 
 
117
  - if use "byte", set to 256 (unicode byte range)
118
  """
119
  if tokenizer in ["pinyin", "char"]:
120
+ tokenizer_path = os.path.join(
121
+ files("f5_tts").joinpath("../data"), f"{dataset_name}_{tokenizer}/vocab.txt"
122
+ )
123
  with open(tokenizer_path, "r", encoding="utf-8") as f:
124
  vocab_char_map = {}
125
  for i, char in enumerate(f):
126
  vocab_char_map[char[:-1]] = i
127
  vocab_size = len(vocab_char_map)
128
+ assert (
129
+ vocab_char_map[" "] == 0
130
+ ), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
131
 
132
  elif tokenizer == "byte":
133
  vocab_char_map = None
 
143
  return vocab_char_map, vocab_size
144
 
145
 
 
146
  # convert char to pinyin
147
 
148
  jieba.initialize()
 
156
  ) # add custom trans here, to address oov
157
 
158
  def is_chinese(c):
159
+ return "\u3100" <= c <= "\u9fff" # common chinese characters
 
 
160
 
161
  for text in text_list:
162
  char_list = []
 
167
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
168
  char_list.append(" ")
169
  char_list.extend(seg)
170
+ elif polyphone and seg_byte_len == 3 * len(
171
+ seg
172
+ ): # if pure east asian characters
173
  seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
174
  for i, c in enumerate(seg):
175
  if is_chinese(c):
 
181
  char_list.extend(c)
182
  elif is_chinese(c):
183
  char_list.append(" ")
184
+ char_list.extend(
185
+ lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
186
+ )
187
  else:
188
  char_list.append(c)
189
  final_text_list.append(char_list)
 
237
  def sample_consecutive_steps(float_list):
238
  idx = torch.randint(0, len(float_list), size=(1,))
239
  next_idx = idx - 1
240
+
241
  if next_idx < 0:
242
  next_idx = 0
243
  else:
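
The `utils.py` hunks above are formatting-only, but the two masking helpers they touch are central to training: `lens_to_mask` turns per-sample lengths into a boolean padding mask, and `mask_from_frac_lengths` selects a random contiguous span covering a given fraction of each sequence (the region the model learns to infill). A self-contained sketch of equivalent logic, without the project's shape-annotation type hints:

```python
from __future__ import annotations

import torch


def lens_to_mask(lens: torch.Tensor, length: int | None = None) -> torch.Tensor:
    # (b,) integer lengths -> (b, n) bool mask, True inside each sequence
    if length is None:
        length = int(lens.amax())
    seq = torch.arange(length, device=lens.device)
    return seq[None, :] < lens[:, None]


def mask_from_frac_lengths(seq_len: torch.Tensor, frac: torch.Tensor) -> torch.Tensor:
    # mask a random contiguous span covering `frac` of each sequence
    lengths = (frac * seq_len).long()
    max_start = seq_len - lengths
    start = (max_start * torch.rand_like(frac)).long().clamp(min=0)
    end = start + lengths
    seq = torch.arange(int(seq_len.max()), device=seq_len.device)
    return (seq[None, :] >= start[:, None]) & (seq[None, :] < end[:, None])


lens = torch.tensor([4, 6])
print(lens_to_mask(lens))                                       # padding mask, width 6
print(mask_from_frac_lengths(lens, torch.tensor([0.5, 0.5])))   # random half-span per row
```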
f5_tts/model_new/__init__.py CHANGED
@@ -4,5 +4,4 @@ from f5_tts.model_new.backbones.unett import UNetT
 from f5_tts.model_new.cfm import CFM
 from f5_tts.model_new.trainer import Trainer
 
-
 __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
f5_tts/model_new/backbones/dit.py CHANGED
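
The DiT changes below are re-wraps of the same `cfg_infer` pattern: instead of two forward passes, the conditional and unconditional inputs are concatenated along the batch dimension, run through one pass, then split and combined with a guidance strength. A minimal sketch of that packing with a stand-in network (the toy `nn.Linear` and the guidance formula follow standard classifier-free guidance, not the project's exact inference code):

```python
import torch
from torch import nn

net = nn.Linear(80, 80)                        # stand-in for the DiT backbone
x = torch.randn(2, 100, 80)                    # noised input, (b, n, d)
cond_embed = torch.randn(2, 100, 80)           # audio + text conditioning embedding
uncond_embed = torch.zeros_like(cond_embed)    # conditioning dropped for the uncond branch
cfg_strength = 2.0

# pack cond & uncond along batch: b n d -> 2b n d, one forward pass instead of two
packed = torch.cat((x + cond_embed, x + uncond_embed), dim=0)
out = net(packed)

pred_cond, pred_uncond = torch.chunk(out, 2, dim=0)
guided = pred_cond + (pred_cond - pred_uncond) * cfg_strength
print(guided.shape)  # torch.Size([2, 100, 80])
```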
@@ -14,40 +14,49 @@ import torch.nn.functional as F
14
  from torch import nn
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
- from f5_tts.model_new.modules import (
18
- AdaLayerNorm_Final,
19
- ConvNeXtV2Block,
20
- ConvPositionEmbedding,
21
- DiTBlock,
22
- TimestepEmbedding,
23
- get_pos_embed_indices,
24
- precompute_freqs_cis,
25
- )
26
-
27
 
28
  # Text embedding
29
 
30
 
31
  class TextEmbedding(nn.Module):
32
- def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
 
 
33
  super().__init__()
34
- self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
 
 
35
 
36
  self.mask_padding = mask_padding # mask filler and batch padding tokens or not
37
 
38
  if conv_layers > 0:
39
  self.extra_modeling = True
40
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
41
- self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
 
 
 
 
42
  self.text_blocks = nn.Sequential(
43
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
 
 
 
44
  )
45
  else:
46
  self.extra_modeling = False
47
 
48
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
49
- text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
50
- text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
 
 
 
 
51
  batch, text_len = text.shape[0], text.shape[1]
52
  text = F.pad(text, (0, seq_len - text_len), value=0)
53
  if self.mask_padding:
@@ -62,16 +71,22 @@ class TextEmbedding(nn.Module):
62
  if self.extra_modeling:
63
  # sinus pos emb
64
  batch_start = torch.zeros((batch,), dtype=torch.long)
65
- pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
 
 
66
  text_pos_embed = self.freqs_cis[pos_idx]
67
  text = text + text_pos_embed
68
 
69
  # convnextv2 blocks
70
  if self.mask_padding:
71
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
 
 
72
  for block in self.text_blocks:
73
  text = block(text)
74
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
 
 
75
  else:
76
  text = self.text_blocks(text)
77
 
@@ -87,7 +102,13 @@ class InputEmbedding(nn.Module):
87
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
88
  self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
89
 
90
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
 
 
 
 
 
 
91
  if drop_audio_cond: # cfg for cond audio
92
  cond = torch.zeros_like(cond)
93
 
@@ -127,7 +148,10 @@ class DiT(nn.Module):
127
  if text_dim is None:
128
  text_dim = mel_dim
129
  self.text_embed = TextEmbedding(
130
- text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
 
 
 
131
  )
132
  self.text_cond, self.text_uncond = None, None # text cache
133
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
@@ -153,7 +177,9 @@ class DiT(nn.Module):
153
  for _ in range(depth)
154
  ]
155
  )
156
- self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
 
 
157
 
158
  self.norm_out = AdaLayerNorm_Final(dim) # final modulation
159
  self.proj_out = nn.Linear(dim, mel_dim)
@@ -230,13 +256,24 @@ class DiT(nn.Module):
230
  # t: conditioning time, text: text, x: noised audio + cond audio + text
231
  t = self.time_embed(time)
232
  if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
233
- x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
234
- x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
 
 
 
 
235
  x = torch.cat((x_cond, x_uncond), dim=0)
236
  t = torch.cat((t, t), dim=0)
237
  mask = torch.cat((mask, mask), dim=0) if mask is not None else None
238
  else:
239
- x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
 
 
 
 
 
 
 
240
 
241
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
242
 
@@ -246,7 +283,9 @@ class DiT(nn.Module):
246
  for block in self.transformer_blocks:
247
  if self.checkpoint_activations:
248
  # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
249
- x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
 
 
250
  else:
251
  x = block(x, t, mask=mask, rope=rope)
252
 
 
14
  from torch import nn
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
+ from f5_tts.model_new.modules import (AdaLayerNorm_Final, ConvNeXtV2Block,
18
+ ConvPositionEmbedding, DiTBlock,
19
+ TimestepEmbedding, get_pos_embed_indices,
20
+ precompute_freqs_cis)
 
 
 
 
 
 
21
 
22
  # Text embedding
23
 
24
 
25
  class TextEmbedding(nn.Module):
26
+ def __init__(
27
+ self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2
28
+ ):
29
  super().__init__()
30
+ self.text_embed = nn.Embedding(
31
+ text_num_embeds + 1, text_dim
32
+ ) # use 0 as filler token
33
 
34
  self.mask_padding = mask_padding # mask filler and batch padding tokens or not
35
 
36
  if conv_layers > 0:
37
  self.extra_modeling = True
38
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
39
+ self.register_buffer(
40
+ "freqs_cis",
41
+ precompute_freqs_cis(text_dim, self.precompute_max_pos),
42
+ persistent=False,
43
+ )
44
  self.text_blocks = nn.Sequential(
45
+ *[
46
+ ConvNeXtV2Block(text_dim, text_dim * conv_mult)
47
+ for _ in range(conv_layers)
48
+ ]
49
  )
50
  else:
51
  self.extra_modeling = False
52
 
53
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
54
+ text = (
55
+ text + 1
56
+ ) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
57
+ text = text[
58
+ :, :seq_len
59
+ ] # curtail if character tokens are more than the mel spec tokens
60
  batch, text_len = text.shape[0], text.shape[1]
61
  text = F.pad(text, (0, seq_len - text_len), value=0)
62
  if self.mask_padding:
 
71
  if self.extra_modeling:
72
  # sinus pos emb
73
  batch_start = torch.zeros((batch,), dtype=torch.long)
74
+ pos_idx = get_pos_embed_indices(
75
+ batch_start, seq_len, max_pos=self.precompute_max_pos
76
+ )
77
  text_pos_embed = self.freqs_cis[pos_idx]
78
  text = text + text_pos_embed
79
 
80
  # convnextv2 blocks
81
  if self.mask_padding:
82
+ text = text.masked_fill(
83
+ text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
84
+ )
85
  for block in self.text_blocks:
86
  text = block(text)
87
+ text = text.masked_fill(
88
+ text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
89
+ )
90
  else:
91
  text = self.text_blocks(text)
92
 
 
102
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
103
  self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
104
 
105
+ def forward(
106
+ self,
107
+ x: float["b n d"],
108
+ cond: float["b n d"],
109
+ text_embed: float["b n d"],
110
+ drop_audio_cond=False,
111
+ ): # noqa: F722
112
  if drop_audio_cond: # cfg for cond audio
113
  cond = torch.zeros_like(cond)
114
 
 
148
  if text_dim is None:
149
  text_dim = mel_dim
150
  self.text_embed = TextEmbedding(
151
+ text_num_embeds,
152
+ text_dim,
153
+ mask_padding=text_mask_padding,
154
+ conv_layers=conv_layers,
155
  )
156
  self.text_cond, self.text_uncond = None, None # text cache
157
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
 
177
  for _ in range(depth)
178
  ]
179
  )
180
+ self.long_skip_connection = (
181
+ nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
182
+ )
183
 
184
  self.norm_out = AdaLayerNorm_Final(dim) # final modulation
185
  self.proj_out = nn.Linear(dim, mel_dim)
 
256
  # t: conditioning time, text: text, x: noised audio + cond audio + text
257
  t = self.time_embed(time)
258
  if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
259
+ x_cond = self.get_input_embed(
260
+ x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache
261
+ )
262
+ x_uncond = self.get_input_embed(
263
+ x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache
264
+ )
265
  x = torch.cat((x_cond, x_uncond), dim=0)
266
  t = torch.cat((t, t), dim=0)
267
  mask = torch.cat((mask, mask), dim=0) if mask is not None else None
268
  else:
269
+ x = self.get_input_embed(
270
+ x,
271
+ cond,
272
+ text,
273
+ drop_audio_cond=drop_audio_cond,
274
+ drop_text=drop_text,
275
+ cache=cache,
276
+ )
277
 
278
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
279
 
 
283
  for block in self.transformer_blocks:
284
  if self.checkpoint_activations:
285
  # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
286
+ x = torch.utils.checkpoint.checkpoint(
287
+ self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False
288
+ )
289
  else:
290
  x = block(x, t, mask=mask, rope=rope)
291
 
f5_tts/model_new/backbones/mmdit.py CHANGED
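
The MMDiT `TextEmbedding` below keeps the same token convention as the DiT one: batches of character ids are padded with -1 (see `list_str_to_idx`), the embedding reserves index 0 as a filler token, so the module shifts ids by +1 and masks wherever the shifted value is 0. A short sketch of that convention with a toy vocabulary size:

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

text_num_embeds = 10                           # toy vocab size; the real one comes from vocab.txt
embed = nn.Embedding(text_num_embeds + 1, 16)  # +1 row so index 0 can be reserved as filler

# two "sentences" of character ids, batch-padded with -1 as in list_str_to_idx()
ids = [torch.tensor([3, 5, 7]), torch.tensor([2, 4])]
text = pad_sequence(ids, padding_value=-1, batch_first=True)

text = text + 1            # pad -1 -> 0 (filler); real ids shift up by one
text_mask = text == 0      # True at filler / batch-padding positions
emb = embed(text)          # (b, nt, d)
emb = emb.masked_fill(text_mask.unsqueeze(-1), 0.0)  # zero out padded positions
print(emb.shape, text_mask.tolist())
```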
@@ -13,15 +13,10 @@ import torch
13
  from torch import nn
14
  from x_transformers.x_transformers import RotaryEmbedding
15
 
16
- from f5_tts.model_new.modules import (
17
- AdaLayerNorm_Final,
18
- ConvPositionEmbedding,
19
- MMDiTBlock,
20
- TimestepEmbedding,
21
- get_pos_embed_indices,
22
- precompute_freqs_cis,
23
- )
24
-
25
 
26
  # text embedding
27
 
@@ -29,15 +24,25 @@ from f5_tts.model_new.modules import (
29
  class TextEmbedding(nn.Module):
30
  def __init__(self, out_dim, text_num_embeds, mask_padding=True):
31
  super().__init__()
32
- self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
 
 
33
 
34
  self.mask_padding = mask_padding # mask filler and batch padding tokens or not
35
 
36
  self.precompute_max_pos = 1024
37
- self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
 
 
 
 
38
 
39
- def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
40
- text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
 
 
 
 
41
  if self.mask_padding:
42
  text_mask = text == 0
43
 
@@ -49,13 +54,17 @@ class TextEmbedding(nn.Module):
49
  # sinus pos emb
50
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
51
  batch_text_len = text.shape[1]
52
- pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
 
 
53
  text_pos_embed = self.freqs_cis[pos_idx]
54
 
55
  text = text + text_pos_embed
56
 
57
  if self.mask_padding:
58
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
 
 
59
 
60
  return text
61
 
@@ -69,7 +78,9 @@ class AudioEmbedding(nn.Module):
69
  self.linear = nn.Linear(2 * in_dim, out_dim)
70
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
71
 
72
- def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
 
 
73
  if drop_audio_cond:
74
  cond = torch.zeros_like(cond)
75
  x = torch.cat((x, cond), dim=-1)
@@ -99,7 +110,9 @@ class MMDiT(nn.Module):
99
  super().__init__()
100
 
101
  self.time_embed = TimestepEmbedding(dim)
102
- self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
 
 
103
  self.text_cond, self.text_uncond = None, None # text cache
104
  self.audio_embed = AudioEmbedding(mel_dim, dim)
105
 
@@ -187,15 +200,24 @@ class MMDiT(nn.Module):
187
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
188
  t = self.time_embed(time)
189
  if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
190
- x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
191
- x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
 
 
 
 
192
  x = torch.cat((x_cond, x_uncond), dim=0)
193
  c = torch.cat((c_cond, c_uncond), dim=0)
194
  t = torch.cat((t, t), dim=0)
195
  mask = torch.cat((mask, mask), dim=0) if mask is not None else None
196
  else:
197
  x, c = self.get_input_embed(
198
- x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
 
 
 
 
 
199
  )
200
 
201
  seq_len = x.shape[1]
 
13
  from torch import nn
14
  from x_transformers.x_transformers import RotaryEmbedding
15
 
16
+ from f5_tts.model_new.modules import (AdaLayerNorm_Final,
17
+ ConvPositionEmbedding, MMDiTBlock,
18
+ TimestepEmbedding, get_pos_embed_indices,
19
+ precompute_freqs_cis)
 
 
 
 
 
20
 
21
  # text embedding
22
 
 
24
  class TextEmbedding(nn.Module):
25
  def __init__(self, out_dim, text_num_embeds, mask_padding=True):
26
  super().__init__()
27
+ self.text_embed = nn.Embedding(
28
+ text_num_embeds + 1, out_dim
29
+ ) # will use 0 as filler token
30
 
31
  self.mask_padding = mask_padding # mask filler and batch padding tokens or not
32
 
33
  self.precompute_max_pos = 1024
34
+ self.register_buffer(
35
+ "freqs_cis",
36
+ precompute_freqs_cis(out_dim, self.precompute_max_pos),
37
+ persistent=False,
38
+ )
39
 
40
+ def forward(
41
+ self, text: int["b nt"], drop_text=False
42
+ ) -> int["b nt d"]: # noqa: F722
43
+ text = (
44
+ text + 1
45
+ ) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
46
  if self.mask_padding:
47
  text_mask = text == 0
48
 
 
54
  # sinus pos emb
55
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
56
  batch_text_len = text.shape[1]
57
+ pos_idx = get_pos_embed_indices(
58
+ batch_start, batch_text_len, max_pos=self.precompute_max_pos
59
+ )
60
  text_pos_embed = self.freqs_cis[pos_idx]
61
 
62
  text = text + text_pos_embed
63
 
64
  if self.mask_padding:
65
+ text = text.masked_fill(
66
+ text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
67
+ )
68
 
69
  return text
70
 
 
78
  self.linear = nn.Linear(2 * in_dim, out_dim)
79
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
80
 
81
+ def forward(
82
+ self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False
83
+ ): # noqa: F722
84
  if drop_audio_cond:
85
  cond = torch.zeros_like(cond)
86
  x = torch.cat((x, cond), dim=-1)
 
110
  super().__init__()
111
 
112
  self.time_embed = TimestepEmbedding(dim)
113
+ self.text_embed = TextEmbedding(
114
+ dim, text_num_embeds, mask_padding=text_mask_padding
115
+ )
116
  self.text_cond, self.text_uncond = None, None # text cache
117
  self.audio_embed = AudioEmbedding(mel_dim, dim)
118
 
 
200
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
201
  t = self.time_embed(time)
202
  if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
203
+ x_cond, c_cond = self.get_input_embed(
204
+ x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache
205
+ )
206
+ x_uncond, c_uncond = self.get_input_embed(
207
+ x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache
208
+ )
209
  x = torch.cat((x_cond, x_uncond), dim=0)
210
  c = torch.cat((c_cond, c_uncond), dim=0)
211
  t = torch.cat((t, t), dim=0)
212
  mask = torch.cat((mask, mask), dim=0) if mask is not None else None
213
  else:
214
  x, c = self.get_input_embed(
215
+ x,
216
+ cond,
217
+ text,
218
+ drop_audio_cond=drop_audio_cond,
219
+ drop_text=drop_text,
220
+ cache=cache,
221
  )
222
 
223
  seq_len = x.shape[1]
f5_tts/model_new/backbones/unett.py CHANGED
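
The UNetT edits below preserve the flat U-Net-style wiring: the first half of the transformer layers push their activations onto a stack, the second half pop a matching skip, concatenate it with the current hidden states, and (for the "concat" skip type) project back down with the optional `skip_proj` linear. A compact sketch of that control flow, with plain linear layers standing in for the attention/feed-forward blocks:

```python
import torch
from torch import nn

dim, depth = 32, 6
blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])  # stand-in attn/ff blocks
skip_projs = nn.ModuleList(
    [nn.Linear(dim * 2, dim, bias=False) if i >= depth // 2 else nn.Identity() for i in range(depth)]
)

x = torch.randn(2, 100, dim)
skips = []
for idx, (block, skip_proj) in enumerate(zip(blocks, skip_projs)):
    layer = idx + 1
    if layer <= depth // 2:                      # first half: remember activations
        skips.append(x)
    else:                                        # second half: pop a skip, concat, project back
        skip = skips.pop()
        x = skip_proj(torch.cat((x, skip), dim=-1))
    x = block(x)
print(x.shape)  # torch.Size([2, 100, 32])
```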
@@ -17,41 +17,50 @@ from torch import nn
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
- from f5_tts.model_new.modules import (
21
- Attention,
22
- AttnProcessor,
23
- ConvNeXtV2Block,
24
- ConvPositionEmbedding,
25
- FeedForward,
26
- TimestepEmbedding,
27
- get_pos_embed_indices,
28
- precompute_freqs_cis,
29
- )
30
-
31
 
32
  # Text embedding
33
 
34
 
35
  class TextEmbedding(nn.Module):
36
- def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
 
 
37
  super().__init__()
38
- self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
 
 
39
 
40
  self.mask_padding = mask_padding # mask filler and batch padding tokens or not
41
 
42
  if conv_layers > 0:
43
  self.extra_modeling = True
44
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
45
- self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
 
 
 
 
46
  self.text_blocks = nn.Sequential(
47
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
 
 
 
48
  )
49
  else:
50
  self.extra_modeling = False
51
 
52
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
53
- text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
54
- text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
 
 
 
 
55
  batch, text_len = text.shape[0], text.shape[1]
56
  text = F.pad(text, (0, seq_len - text_len), value=0)
57
  if self.mask_padding:
@@ -66,16 +75,22 @@ class TextEmbedding(nn.Module):
66
  if self.extra_modeling:
67
  # sinus pos emb
68
  batch_start = torch.zeros((batch,), dtype=torch.long)
69
- pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
 
 
70
  text_pos_embed = self.freqs_cis[pos_idx]
71
  text = text + text_pos_embed
72
 
73
  # convnextv2 blocks
74
  if self.mask_padding:
75
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
 
 
76
  for block in self.text_blocks:
77
  text = block(text)
78
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
 
 
79
  else:
80
  text = self.text_blocks(text)
81
 
@@ -91,7 +106,13 @@ class InputEmbedding(nn.Module):
91
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
92
  self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
93
 
94
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
 
 
 
 
 
 
95
  if drop_audio_cond: # cfg for cond audio
96
  cond = torch.zeros_like(cond)
97
 
@@ -129,7 +150,10 @@ class UNetT(nn.Module):
129
  if text_dim is None:
130
  text_dim = mel_dim
131
  self.text_embed = TextEmbedding(
132
- text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
 
 
 
133
  )
134
  self.text_cond, self.text_uncond = None, None # text cache
135
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
@@ -161,7 +185,11 @@ class UNetT(nn.Module):
161
  ff_norm = RMSNorm(dim)
162
  ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
163
 
164
- skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
 
 
 
 
165
 
166
  self.layers.append(
167
  nn.ModuleList(
@@ -226,13 +254,24 @@ class UNetT(nn.Module):
226
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
227
  t = self.time_embed(time)
228
  if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
229
- x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
230
- x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
 
 
 
 
231
  x = torch.cat((x_cond, x_uncond), dim=0)
232
  t = torch.cat((t, t), dim=0)
233
  mask = torch.cat((mask, mask), dim=0) if mask is not None else None
234
  else:
235
- x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
 
 
 
 
 
 
 
236
 
237
  # postfix time t to input x, [b n d] -> [b n+1 d]
238
  x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
@@ -244,7 +283,9 @@ class UNetT(nn.Module):
244
  # flat unet transformer
245
  skip_connect_type = self.skip_connect_type
246
  skips = []
247
- for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
 
 
248
  layer = idx + 1
249
 
250
  # skip connection logic
 
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
+ from f5_tts.model_new.modules import (Attention, AttnProcessor,
21
+ ConvNeXtV2Block, ConvPositionEmbedding,
22
+ FeedForward, TimestepEmbedding,
23
+ get_pos_embed_indices,
24
+ precompute_freqs_cis)
 
 
 
 
 
 
25
 
26
  # Text embedding
27
 
28
 
29
  class TextEmbedding(nn.Module):
30
+ def __init__(
31
+ self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2
32
+ ):
33
  super().__init__()
34
+ self.text_embed = nn.Embedding(
35
+ text_num_embeds + 1, text_dim
36
+ ) # use 0 as filler token
37
 
38
  self.mask_padding = mask_padding # mask filler and batch padding tokens or not
39
 
40
  if conv_layers > 0:
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
+ self.register_buffer(
44
+ "freqs_cis",
45
+ precompute_freqs_cis(text_dim, self.precompute_max_pos),
46
+ persistent=False,
47
+ )
48
  self.text_blocks = nn.Sequential(
49
+ *[
50
+ ConvNeXtV2Block(text_dim, text_dim * conv_mult)
51
+ for _ in range(conv_layers)
52
+ ]
53
  )
54
  else:
55
  self.extra_modeling = False
56
 
57
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
58
+ text = (
59
+ text + 1
60
+ ) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
61
+ text = text[
62
+ :, :seq_len
63
+ ] # curtail if character tokens are more than the mel spec tokens
64
  batch, text_len = text.shape[0], text.shape[1]
65
  text = F.pad(text, (0, seq_len - text_len), value=0)
66
  if self.mask_padding:
 
75
  if self.extra_modeling:
76
  # sinus pos emb
77
  batch_start = torch.zeros((batch,), dtype=torch.long)
78
+ pos_idx = get_pos_embed_indices(
79
+ batch_start, seq_len, max_pos=self.precompute_max_pos
80
+ )
81
  text_pos_embed = self.freqs_cis[pos_idx]
82
  text = text + text_pos_embed
83
 
84
  # convnextv2 blocks
85
  if self.mask_padding:
86
+ text = text.masked_fill(
87
+ text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
88
+ )
89
  for block in self.text_blocks:
90
  text = block(text)
91
+ text = text.masked_fill(
92
+ text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
93
+ )
94
  else:
95
  text = self.text_blocks(text)
96
 
 
106
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
107
  self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
108
 
109
+ def forward(
110
+ self,
111
+ x: float["b n d"],
112
+ cond: float["b n d"],
113
+ text_embed: float["b n d"],
114
+ drop_audio_cond=False,
115
+ ): # noqa: F722
116
  if drop_audio_cond: # cfg for cond audio
117
  cond = torch.zeros_like(cond)
118
 
 
150
  if text_dim is None:
151
  text_dim = mel_dim
152
  self.text_embed = TextEmbedding(
153
+ text_num_embeds,
154
+ text_dim,
155
+ mask_padding=text_mask_padding,
156
+ conv_layers=conv_layers,
157
  )
158
  self.text_cond, self.text_uncond = None, None # text cache
159
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
 
185
  ff_norm = RMSNorm(dim)
186
  ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
187
 
188
+ skip_proj = (
189
+ nn.Linear(dim * 2, dim, bias=False)
190
+ if needs_skip_proj and is_later_half
191
+ else None
192
+ )
193
 
194
  self.layers.append(
195
  nn.ModuleList(
 
254
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
255
  t = self.time_embed(time)
256
  if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
257
+ x_cond = self.get_input_embed(
258
+ x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache
259
+ )
260
+ x_uncond = self.get_input_embed(
261
+ x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache
262
+ )
263
  x = torch.cat((x_cond, x_uncond), dim=0)
264
  t = torch.cat((t, t), dim=0)
265
  mask = torch.cat((mask, mask), dim=0) if mask is not None else None
266
  else:
267
+ x = self.get_input_embed(
268
+ x,
269
+ cond,
270
+ text,
271
+ drop_audio_cond=drop_audio_cond,
272
+ drop_text=drop_text,
273
+ cache=cache,
274
+ )
275
 
276
  # postfix time t to input x, [b n d] -> [b n+1 d]
277
  x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
 
283
  # flat unet transformer
284
  skip_connect_type = self.skip_connect_type
285
  skips = []
286
+ for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(
287
+ self.layers
288
+ ):
289
  layer = idx + 1
290
 
291
  # skip connection logic
f5_tts/model_new/cfm.py CHANGED
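
The CFM hunks below mostly re-wrap the sampling setup; the detail worth unpacking is the timestep schedule: a uniform grid on [t_start, 1] that is optionally warped by sway sampling, `t <- t + s * (cos(pi/2 * t) - 1 + t)`, which for negative `s` spends more of the step budget near t = 0. A tiny sketch that just materializes the warped grid (step count and coefficient are illustrative; EPSS and the ODE solve itself are omitted):

```python
import torch

steps = 8                   # NFE; illustrative
sway_sampling_coef = -1.0   # negative coefficient biases steps toward t = 0

t = torch.linspace(0, 1, steps + 1)
if sway_sampling_coef is not None:
    t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

print([round(v, 3) for v in t.tolist()])
# with s = -1 the grid becomes 1 - cos(pi/2 * t): dense near 0, sparse near 1
```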
@@ -19,15 +19,9 @@ from torch.nn.utils.rnn import pad_sequence
19
  from torchdiffeq import odeint
20
 
21
  from f5_tts.model_new.modules import MelSpec
22
- from f5_tts.model_new.utils import (
23
- default,
24
- exists,
25
- get_epss_timesteps,
26
- lens_to_mask,
27
- list_str_to_idx,
28
- list_str_to_tensor,
29
- mask_from_frac_lengths,
30
- )
31
 
32
 
33
  class CFM(nn.Module):
@@ -139,13 +133,17 @@ class CFM(nn.Module):
139
 
140
  # duplicate test corner for inner time step oberservation
141
  if duplicate_test:
142
- test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
 
 
143
 
144
  cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
145
  if no_ref_audio:
146
  cond = torch.zeros_like(cond)
147
 
148
- cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
 
 
149
  cond_mask = cond_mask.unsqueeze(-1)
150
  step_cond = torch.where(
151
  cond_mask, cond, torch.zeros_like(cond)
@@ -196,7 +194,11 @@ class CFM(nn.Module):
196
  for dur in duration:
197
  if exists(seed):
198
  torch.manual_seed(seed)
199
- y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
 
 
 
 
200
  y0 = pad_sequence(y0, padding_value=0, batch_first=True)
201
 
202
  t_start = 0
@@ -207,10 +209,14 @@ class CFM(nn.Module):
207
  y0 = (1 - t_start) * y0 + t_start * test_cond
208
  steps = int(steps * (1 - t_start))
209
 
210
- if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
 
 
211
  t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
212
  else:
213
- t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
 
 
214
  if sway_sampling_coef is not None:
215
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
216
 
@@ -241,7 +247,12 @@ class CFM(nn.Module):
241
  inp = inp.permute(0, 2, 1)
242
  assert inp.shape[-1] == self.num_channels
243
 
244
- batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
 
 
 
 
 
245
 
246
  # handle text as string
247
  if isinstance(text, list):
@@ -255,10 +266,16 @@ class CFM(nn.Module):
255
  if not exists(lens):
256
  lens = torch.full((batch,), seq_len, device=device)
257
 
258
- mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
 
 
259
 
260
  # get a random span to mask out for training conditionally
261
- frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
 
 
 
 
262
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
263
 
264
  if exists(mask):
@@ -292,7 +309,13 @@ class CFM(nn.Module):
292
 
293
  # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
294
  pred = self.transformer(
295
- x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
 
 
 
 
 
 
296
  )
297
 
298
  # flow matching loss
 
19
  from torchdiffeq import odeint
20
 
21
  from f5_tts.model_new.modules import MelSpec
22
+ from f5_tts.model_new.utils import (default, exists, get_epss_timesteps,
23
+ lens_to_mask, list_str_to_idx,
24
+ list_str_to_tensor, mask_from_frac_lengths)
 
 
 
 
 
 
25
 
26
 
27
  class CFM(nn.Module):
 
133
 
134
  # duplicate test corner for inner time step oberservation
135
  if duplicate_test:
136
+ test_cond = F.pad(
137
+ cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
138
+ )
139
 
140
  cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
141
  if no_ref_audio:
142
  cond = torch.zeros_like(cond)
143
 
144
+ cond_mask = F.pad(
145
+ cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
146
+ )
147
  cond_mask = cond_mask.unsqueeze(-1)
148
  step_cond = torch.where(
149
  cond_mask, cond, torch.zeros_like(cond)
 
194
  for dur in duration:
195
  if exists(seed):
196
  torch.manual_seed(seed)
197
+ y0.append(
198
+ torch.randn(
199
+ dur, self.num_channels, device=self.device, dtype=step_cond.dtype
200
+ )
201
+ )
202
  y0 = pad_sequence(y0, padding_value=0, batch_first=True)
203
 
204
  t_start = 0
 
209
  y0 = (1 - t_start) * y0 + t_start * test_cond
210
  steps = int(steps * (1 - t_start))
211
 
212
+ if (
213
+ t_start == 0 and use_epss
214
+ ): # use Empirically Pruned Step Sampling for low NFE
215
  t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
216
  else:
217
+ t = torch.linspace(
218
+ t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
219
+ )
220
  if sway_sampling_coef is not None:
221
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
222
 
 
247
  inp = inp.permute(0, 2, 1)
248
  assert inp.shape[-1] == self.num_channels
249
 
250
+ batch, seq_len, dtype, device, _σ1 = (
251
+ *inp.shape[:2],
252
+ inp.dtype,
253
+ self.device,
254
+ self.sigma,
255
+ )
256
 
257
  # handle text as string
258
  if isinstance(text, list):
 
266
  if not exists(lens):
267
  lens = torch.full((batch,), seq_len, device=device)
268
 
269
+ mask = lens_to_mask(
270
+ lens, length=seq_len
271
+ ) # useless here, as collate_fn will pad to max length in batch
272
 
273
  # get a random span to mask out for training conditionally
274
+ frac_lengths = (
275
+ torch.zeros((batch,), device=self.device)
276
+ .float()
277
+ .uniform_(*self.frac_lengths_mask)
278
+ )
279
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
280
 
281
  if exists(mask):
 
309
 
310
  # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
311
  pred = self.transformer(
312
+ x=φ,
313
+ cond=cond,
314
+ text=text,
315
+ time=time,
316
+ drop_audio_cond=drop_audio_cond,
317
+ drop_text=drop_text,
318
+ mask=mask,
319
  )
320
 
321
  # flow matching loss
f5_tts/model_new/dataset.py CHANGED
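
`DynamicBatchSampler` (reformatted below, not rewritten) batches by an audio-frame budget rather than a fixed sample count: indices are sorted by duration and packed greedily until adding another clip would exceed `frames_threshold` or the optional `max_samples` cap; the real class also shuffles the resulting batches with `random_seed` and can drop the residual batch. A stripped-down sketch of the packing rule over a toy list of per-clip frame lengths:

```python
# toy per-clip frame lengths; the real sampler asks the dataset via get_frame_len(idx)
frame_lens = [120, 300, 80, 250, 500, 90, 410]
frames_threshold = 600   # max audio frames per GPU batch (illustrative)
max_samples = 4          # optional cap on clips per batch; 0 means unlimited

indices = sorted(range(len(frame_lens)), key=lambda i: frame_lens[i])

batches, batch, batch_frames = [], [], 0
for idx in indices:
    frame_len = frame_lens[idx]
    if batch_frames + frame_len <= frames_threshold and (
        max_samples == 0 or len(batch) < max_samples
    ):
        batch.append(idx)
        batch_frames += frame_len
    else:
        if batch:
            batches.append(batch)
        batch, batch_frames = [idx], frame_len
if batch:
    batches.append(batch)

print(batches)  # [[2, 5, 0, 3], [1], [6], [4]] for these lengths
```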
@@ -62,7 +62,9 @@ class HFDataset(Dataset):
62
  audio_tensor = torch.from_numpy(audio).float()
63
 
64
  if sample_rate != self.target_sample_rate:
65
- resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
 
 
66
  audio_tensor = resampler(audio_tensor)
67
 
68
  audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
@@ -149,7 +151,9 @@ class CustomDataset(Dataset):
149
 
150
  # resample if necessary
151
  if source_sample_rate != self.target_sample_rate:
152
- resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
 
 
153
  audio = resampler(audio)
154
 
155
  # to mel spectrogram
@@ -173,7 +177,12 @@ class DynamicBatchSampler(Sampler[list[int]]):
173
  """
174
 
175
  def __init__(
176
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
 
 
 
 
 
177
  ):
178
  self.sampler = sampler
179
  self.frames_threshold = frames_threshold
@@ -185,7 +194,8 @@ class DynamicBatchSampler(Sampler[list[int]]):
185
  data_source = self.sampler.data_source
186
 
187
  for idx in tqdm(
188
- self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
 
189
  ):
190
  indices.append((idx, data_source.get_frame_len(idx)))
191
  indices.sort(key=lambda elem: elem[1])
@@ -193,9 +203,12 @@ class DynamicBatchSampler(Sampler[list[int]]):
193
  batch = []
194
  batch_frames = 0
195
  for idx, frame_len in tqdm(
196
- indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
 
197
  ):
198
- if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
 
 
199
  batch.append(idx)
200
  batch_frames += frame_len
201
  else:
@@ -256,7 +269,9 @@ def load_dataset(
256
  print("Loading dataset ...")
257
 
258
  if dataset_type == "CustomDataset":
259
- rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
 
 
260
  if audio_type == "raw":
261
  try:
262
  train_dataset = load_from_disk(f"{rel_data_path}/raw")
@@ -287,7 +302,10 @@ def load_dataset(
287
  data_dict = json.load(f)
288
  durations = data_dict["duration"]
289
  train_dataset = CustomDataset(
290
- train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
 
 
 
291
  )
292
 
293
  elif dataset_type == "HFDataset":
@@ -297,7 +315,11 @@ def load_dataset(
297
  )
298
  pre, post = dataset_name.split("_")
299
  train_dataset = HFDataset(
300
- load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
 
 
 
 
301
  )
302
 
303
  return train_dataset
 
62
  audio_tensor = torch.from_numpy(audio).float()
63
 
64
  if sample_rate != self.target_sample_rate:
65
+ resampler = torchaudio.transforms.Resample(
66
+ sample_rate, self.target_sample_rate
67
+ )
68
  audio_tensor = resampler(audio_tensor)
69
 
70
  audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
 
151
 
152
  # resample if necessary
153
  if source_sample_rate != self.target_sample_rate:
154
+ resampler = torchaudio.transforms.Resample(
155
+ source_sample_rate, self.target_sample_rate
156
+ )
157
  audio = resampler(audio)
158
 
159
  # to mel spectrogram
 
177
  """
178
 
179
  def __init__(
180
+ self,
181
+ sampler: Sampler[int],
182
+ frames_threshold: int,
183
+ max_samples=0,
184
+ random_seed=None,
185
+ drop_residual: bool = False,
186
  ):
187
  self.sampler = sampler
188
  self.frames_threshold = frames_threshold
 
194
  data_source = self.sampler.data_source
195
 
196
  for idx in tqdm(
197
+ self.sampler,
198
+ desc="Sorting with sampler... if slow, check whether dataset is provided with duration",
199
  ):
200
  indices.append((idx, data_source.get_frame_len(idx)))
201
  indices.sort(key=lambda elem: elem[1])
 
203
  batch = []
204
  batch_frames = 0
205
  for idx, frame_len in tqdm(
206
+ indices,
207
+ desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu",
208
  ):
209
+ if batch_frames + frame_len <= self.frames_threshold and (
210
+ max_samples == 0 or len(batch) < max_samples
211
+ ):
212
  batch.append(idx)
213
  batch_frames += frame_len
214
  else:
 
269
  print("Loading dataset ...")
270
 
271
  if dataset_type == "CustomDataset":
272
+ rel_data_path = str(
273
+ files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}")
274
+ )
275
  if audio_type == "raw":
276
  try:
277
  train_dataset = load_from_disk(f"{rel_data_path}/raw")
 
302
  data_dict = json.load(f)
303
  durations = data_dict["duration"]
304
  train_dataset = CustomDataset(
305
+ train_dataset,
306
+ durations=durations,
307
+ preprocessed_mel=preprocessed_mel,
308
+ **mel_spec_kwargs,
309
  )
310
 
311
  elif dataset_type == "HFDataset":
 
315
  )
316
  pre, post = dataset_name.split("_")
317
  train_dataset = HFDataset(
318
+ load_dataset(
319
+ f"{pre}/{pre}",
320
+ split=f"train.{post}",
321
+ cache_dir=str(files("f5_tts").joinpath("../../data")),
322
+ ),
323
  )
324
 
325
  return train_dataset
f5_tts/model_new/modules.py CHANGED
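
The modules diff below re-wraps the building blocks without changing them; the one most worth a concrete example is the adaptive LayerNorm used by the DiT blocks: a SiLU + linear head maps the timestep embedding to six vectors (shift/scale/gate for attention and for the MLP), and the normalized activations are modulated as `x * (1 + scale) + shift`, with the gates scaling the residual branches. A minimal sketch of that modulation, independent of the surrounding block classes:

```python
import torch
from torch import nn

dim = 64
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
to_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))  # t-embedding -> 6 modulation vectors

x = torch.randn(2, 100, dim)   # token activations (b, n, d)
t_emb = torch.randn(2, dim)    # timestep embedding (b, d)

emb = to_mod(t_emb)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

# modulate the pre-attention norm; the gate later scales the residual branch
h = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
attn_out = h                              # stand-in for the attention block
x = x + gate_msa[:, None] * attn_out      # gated residual
print(x.shape)
```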
@@ -6,6 +6,7 @@ nt - text sequence
6
  nw - raw wave length
7
  d - dimension
8
  """
 
9
  # flake8: noqa
10
 
11
  from __future__ import annotations
@@ -22,7 +23,6 @@ from x_transformers.x_transformers import apply_rotary_pos_emb
22
 
23
  from f5_tts.model_new.utils import is_package_available
24
 
25
-
26
  # raw wav to mel spec
27
 
28
 
@@ -45,15 +45,25 @@ def get_bigvgan_mel_spectrogram(
45
  key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
46
 
47
  if key not in mel_basis_cache:
48
- mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
49
- mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
 
 
 
 
 
 
 
 
50
  hann_window_cache[key] = torch.hann_window(win_length).to(device)
51
 
52
  mel_basis = mel_basis_cache[key]
53
  hann_window = hann_window_cache[key]
54
 
55
  padding = (n_fft - hop_length) // 2
56
- waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
 
 
57
 
58
  spec = torch.stft(
59
  waveform,
@@ -115,7 +125,9 @@ class MelSpec(nn.Module):
115
  mel_spec_type="vocos",
116
  ):
117
  super().__init__()
118
- assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan")
 
 
119
 
120
  self.n_fft = n_fft
121
  self.hop_length = hop_length
@@ -196,7 +208,9 @@ class ConvPositionEmbedding(nn.Module):
196
  # rotary positional embedding related
197
 
198
 
199
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
 
 
200
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
201
  # has some connection to NTK literature
202
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -212,10 +226,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
212
 
213
  def get_pos_embed_indices(start, length, max_pos, scale=1.0):
214
  # length = length if isinstance(length, int) else length.max()
215
- scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
 
 
216
  pos = (
217
  start.unsqueeze(1)
218
- + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
 
 
 
219
  )
220
  # avoid extra long error.
221
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
@@ -254,7 +273,9 @@ class ConvNeXtV2Block(nn.Module):
254
  dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
255
  ) # depthwise conv
256
  self.norm = nn.LayerNorm(dim, eps=1e-6)
257
- self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
 
 
258
  self.act = nn.GELU()
259
  self.grn = GRN(intermediate_dim)
260
  self.pwconv2 = nn.Linear(intermediate_dim, dim)
@@ -286,7 +307,9 @@ class RMSNorm(nn.Module):
286
  if self.native_rms_norm:
287
  if self.weight.dtype in [torch.float16, torch.bfloat16]:
288
  x = x.to(self.weight.dtype)
289
- x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
 
 
290
  else:
291
  variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
292
  x = x * torch.rsqrt(variance + self.eps)
@@ -312,7 +335,9 @@ class AdaLayerNorm(nn.Module):
312
 
313
  def forward(self, x, emb=None):
314
  emb = self.linear(self.silu(emb))
315
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
 
 
316
 
317
  x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
318
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
@@ -343,14 +368,18 @@ class AdaLayerNorm_Final(nn.Module):
343
 
344
 
345
  class FeedForward(nn.Module):
346
- def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
 
 
347
  super().__init__()
348
  inner_dim = int(dim * mult)
349
  dim_out = dim_out if dim_out is not None else dim
350
 
351
  activation = nn.GELU(approximate=approximate)
352
  project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
353
- self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
 
354
 
355
  def forward(self, x):
356
  return self.ff(x)
@@ -375,7 +404,9 @@ class Attention(nn.Module):
375
  super().__init__()
376
 
377
  if not hasattr(F, "scaled_dot_product_attention"):
378
- raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
 
379
 
380
  self.processor = processor
381
 
@@ -435,19 +466,23 @@ class Attention(nn.Module):
435
  # Attention processor
436
 
437
  if is_package_available("flash_attn"):
 
438
  from flash_attn.bert_padding import pad_input, unpad_input
439
- from flash_attn import flash_attn_varlen_func, flash_attn_func
440
 
441
 
442
  class AttnProcessor:
443
  def __init__(
444
  self,
445
- pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
 
 
446
  attn_backend: str = "torch", # "torch" or "flash_attn"
447
  attn_mask_enabled: bool = True,
448
  ):
449
  if attn_backend == "flash_attn":
450
- assert is_package_available("flash_attn"), "Please install flash-attn first."
 
 
451
 
452
  self.pe_attn_head = pe_attn_head
453
  self.attn_backend = attn_backend
@@ -483,12 +518,18 @@ class AttnProcessor:
483
  # apply rotary position embedding
484
  if rope is not None:
485
  freqs, xpos_scale = rope
486
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
 
 
487
 
488
  if self.pe_attn_head is not None:
489
  pn = self.pe_attn_head
490
- query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
491
- key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
 
 
 
 
492
  else:
493
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
494
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
@@ -498,10 +539,14 @@ class AttnProcessor:
498
  if self.attn_mask_enabled and mask is not None:
499
  attn_mask = mask
500
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
501
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
 
 
502
  else:
503
  attn_mask = None
504
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
 
 
505
  x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
506
 
507
  elif self.attn_backend == "flash_attn":
@@ -509,7 +554,9 @@ class AttnProcessor:
509
  key = key.transpose(1, 2)
510
  value = value.transpose(1, 2)
511
  if self.attn_mask_enabled and mask is not None:
512
- query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
 
 
513
  key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
514
  value, _, _, _, _ = unpad_input(value, mask)
515
  x = flash_attn_varlen_func(
@@ -595,12 +642,16 @@ class JointAttnProcessor:
595
  # apply rope for context and noised input independently
596
  if rope is not None:
597
  freqs, xpos_scale = rope
598
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
 
 
599
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
600
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
601
  if c_rope is not None:
602
  freqs, xpos_scale = c_rope
603
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
 
 
604
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
605
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
606
 
@@ -613,11 +664,15 @@ class JointAttnProcessor:
613
  if mask is not None:
614
  attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
615
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
616
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
 
 
617
  else:
618
  attn_mask = None
619
 
620
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
 
 
621
  x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
622
  x = x.to(query.dtype)
623
 
@@ -675,7 +730,9 @@ class DiTBlock(nn.Module):
675
  )
676
 
677
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
678
- self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
 
 
679
 
680
  def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
681
  # pre-norm & modulation for attention input
@@ -708,14 +765,26 @@ class MMDiTBlock(nn.Module):
708
  """
709
 
710
  def __init__(
711
- self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
712
  ):
713
  super().__init__()
714
  if context_dim is None:
715
  context_dim = dim
716
  self.context_pre_only = context_pre_only
717
 
718
- self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
 
 
 
 
719
  self.attn_norm_x = AdaLayerNorm(dim)
720
  self.attn = Attention(
721
  processor=JointAttnProcessor(),
@@ -729,24 +798,38 @@ class MMDiTBlock(nn.Module):
729
  )
730
 
731
  if not context_pre_only:
732
- self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
733
- self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
 
 
 
 
734
  else:
735
  self.ff_norm_c = None
736
  self.ff_c = None
737
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
738
- self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
 
 
739
 
740
- def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
 
 
741
  # pre-norm & modulation for attention input
742
  if self.context_pre_only:
743
  norm_c = self.attn_norm_c(c, t)
744
  else:
745
- norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
746
- norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
 
 
 
 
747
 
748
  # attention
749
- x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
 
 
750
 
751
  # process attention output for context c
752
  if self.context_pre_only:
@@ -754,7 +837,9 @@ class MMDiTBlock(nn.Module):
754
  else: # if not last layer
755
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
756
 
757
- norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
 
 
758
  c_ff_output = self.ff_c(norm_c)
759
  c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
760
 
@@ -775,7 +860,9 @@ class TimestepEmbedding(nn.Module):
775
  def __init__(self, dim, freq_embed_dim=256):
776
  super().__init__()
777
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
778
- self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
 
779
 
780
  def forward(self, timestep: float["b"]):
781
  time_hidden = self.time_embed(timestep)
 
6
  nw - raw wave length
7
  d - dimension
8
  """
9
+
10
  # flake8: noqa
11
 
12
  from __future__ import annotations
 
23
 
24
  from f5_tts.model_new.utils import is_package_available
25
 
 
26
  # raw wav to mel spec
27
 
28
 
 
45
  key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
46
 
47
  if key not in mel_basis_cache:
48
+ mel = librosa_mel_fn(
49
+ sr=target_sample_rate,
50
+ n_fft=n_fft,
51
+ n_mels=n_mel_channels,
52
+ fmin=fmin,
53
+ fmax=fmax,
54
+ )
55
+ mel_basis_cache[key] = (
56
+ torch.from_numpy(mel).float().to(device)
57
+ ) # TODO: why do they need .float()?
58
  hann_window_cache[key] = torch.hann_window(win_length).to(device)
59
 
60
  mel_basis = mel_basis_cache[key]
61
  hann_window = hann_window_cache[key]
62
 
63
  padding = (n_fft - hop_length) // 2
64
+ waveform = torch.nn.functional.pad(
65
+ waveform.unsqueeze(1), (padding, padding), mode="reflect"
66
+ ).squeeze(1)
67
 
68
  spec = torch.stft(
69
  waveform,
 
125
  mel_spec_type="vocos",
126
  ):
127
  super().__init__()
128
+ assert mel_spec_type in ["vocos", "bigvgan"], (
129
+ "We only support two mel extraction backends: vocos and bigvgan"
130
+ )
131
 
132
  self.n_fft = n_fft
133
  self.hop_length = hop_length
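Note: the cache above keys both the mel filterbank and the Hann window on the full parameter tuple so they are built once per configuration and reused. A minimal standalone sketch of the same caching pattern (window part only, so it runs without librosa; names here are illustrative, not the repo's):

import torch

_window_cache = {}

def get_hann_window(win_length: int, device: str = "cpu") -> torch.Tensor:
    key = f"{win_length}_{device}"
    if key not in _window_cache:  # build once, then reuse
        _window_cache[key] = torch.hann_window(win_length).to(device)
    return _window_cache[key]

w1 = get_hann_window(1024)
w2 = get_hann_window(1024)
print(w1 is w2)  # True: the second call hits the cache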
 
208
  # rotary positional embedding related
209
 
210
 
211
+ def precompute_freqs_cis(
212
+ dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
213
+ ):
214
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
215
  # has some connection to NTK literature
216
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
 
226
 
227
  def get_pos_embed_indices(start, length, max_pos, scale=1.0):
228
  # length = length if isinstance(length, int) else length.max()
229
+ scale = scale * torch.ones_like(
230
+ start, dtype=torch.float32
231
+ ) # in case scale is a scalar
232
  pos = (
233
  start.unsqueeze(1)
234
+ + (
235
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
236
+ * scale.unsqueeze(1)
237
+ ).long()
238
  )
239
  # avoid extra long error.
240
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
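Note: a standalone sketch of the clamped-index arithmetic reformatted above (reimplemented here for illustration, not imported from the repo; values are made up):

import torch

def pos_embed_indices(start, length, max_pos, scale=1.0):
    # start: (b,) int tensor; returns (b, length) indices saturating at max_pos - 1
    scale = scale * torch.ones_like(start, dtype=torch.float32)
    pos = start.unsqueeze(1) + (
        torch.arange(length, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)
    ).long()
    return torch.where(pos < max_pos, pos, max_pos - 1)

start = torch.tensor([0, 5])
print(pos_embed_indices(start, length=4, max_pos=6))
# tensor([[0, 1, 2, 3],
#         [5, 5, 5, 5]])  <- second row is clamped at max_pos - 1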
 
273
  dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
274
  ) # depthwise conv
275
  self.norm = nn.LayerNorm(dim, eps=1e-6)
276
+ self.pwconv1 = nn.Linear(
277
+ dim, intermediate_dim
278
+ ) # pointwise/1x1 convs, implemented with linear layers
279
  self.act = nn.GELU()
280
  self.grn = GRN(intermediate_dim)
281
  self.pwconv2 = nn.Linear(intermediate_dim, dim)
 
307
  if self.native_rms_norm:
308
  if self.weight.dtype in [torch.float16, torch.bfloat16]:
309
  x = x.to(self.weight.dtype)
310
+ x = F.rms_norm(
311
+ x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps
312
+ )
313
  else:
314
  variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
315
  x = x * torch.rsqrt(variance + self.eps)
 
335
 
336
  def forward(self, x, emb=None):
337
  emb = self.linear(self.silu(emb))
338
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
339
+ emb, 6, dim=1
340
+ )
341
 
342
  x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
343
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
 
368
 
369
 
370
  class FeedForward(nn.Module):
371
+ def __init__(
372
+ self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
373
+ ):
374
  super().__init__()
375
  inner_dim = int(dim * mult)
376
  dim_out = dim_out if dim_out is not None else dim
377
 
378
  activation = nn.GELU(approximate=approximate)
379
  project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
380
+ self.ff = nn.Sequential(
381
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
382
+ )
383
 
384
  def forward(self, x):
385
  return self.ff(x)
 
404
  super().__init__()
405
 
406
  if not hasattr(F, "scaled_dot_product_attention"):
407
+ raise ImportError(
408
+ "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
409
+ )
410
 
411
  self.processor = processor
412
 
 
466
  # Attention processor
467
 
468
  if is_package_available("flash_attn"):
469
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
470
  from flash_attn.bert_padding import pad_input, unpad_input
 
471
 
472
 
473
  class AttnProcessor:
474
  def __init__(
475
  self,
476
+ pe_attn_head: (
477
+ int | None
478
+ ) = None, # number of attention head to apply rope, None for all
479
  attn_backend: str = "torch", # "torch" or "flash_attn"
480
  attn_mask_enabled: bool = True,
481
  ):
482
  if attn_backend == "flash_attn":
483
+ assert is_package_available(
484
+ "flash_attn"
485
+ ), "Please install flash-attn first."
486
 
487
  self.pe_attn_head = pe_attn_head
488
  self.attn_backend = attn_backend
 
518
  # apply rotary position embedding
519
  if rope is not None:
520
  freqs, xpos_scale = rope
521
+ q_xpos_scale, k_xpos_scale = (
522
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
523
+ )
524
 
525
  if self.pe_attn_head is not None:
526
  pn = self.pe_attn_head
527
+ query[:, :pn, :, :] = apply_rotary_pos_emb(
528
+ query[:, :pn, :, :], freqs, q_xpos_scale
529
+ )
530
+ key[:, :pn, :, :] = apply_rotary_pos_emb(
531
+ key[:, :pn, :, :], freqs, k_xpos_scale
532
+ )
533
  else:
534
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
535
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
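Note: pe_attn_head above restricts rotary embedding to the first N attention heads via in-place slicing. A tiny self-contained illustration of that slicing pattern (apply_rotary_pos_emb is replaced by a dummy transform so the snippet runs on its own):

import torch

def dummy_rope(x):
    # stand-in for apply_rotary_pos_emb, only to show which heads are touched
    return x * 2.0

batch, heads, seq, head_dim = 1, 4, 3, 8
query = torch.ones(batch, heads, seq, head_dim)

pe_attn_head = 2  # rotate only the first two heads
query[:, :pe_attn_head, :, :] = dummy_rope(query[:, :pe_attn_head, :, :])

print(query[0, :, 0, 0])  # tensor([2., 2., 1., 1.])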
 
539
  if self.attn_mask_enabled and mask is not None:
540
  attn_mask = mask
541
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
542
+ attn_mask = attn_mask.expand(
543
+ batch_size, attn.heads, query.shape[-2], key.shape[-2]
544
+ )
545
  else:
546
  attn_mask = None
547
+ x = F.scaled_dot_product_attention(
548
+ query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
549
+ )
550
  x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
551
 
552
  elif self.attn_backend == "flash_attn":
 
554
  key = key.transpose(1, 2)
555
  value = value.transpose(1, 2)
556
  if self.attn_mask_enabled and mask is not None:
557
+ query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(
558
+ query, mask
559
+ )
560
  key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
561
  value, _, _, _, _ = unpad_input(value, mask)
562
  x = flash_attn_varlen_func(
 
642
  # apply rope for context and noised input independently
643
  if rope is not None:
644
  freqs, xpos_scale = rope
645
+ q_xpos_scale, k_xpos_scale = (
646
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
647
+ )
648
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
649
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
650
  if c_rope is not None:
651
  freqs, xpos_scale = c_rope
652
+ q_xpos_scale, k_xpos_scale = (
653
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
654
+ )
655
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
656
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
657
 
 
664
  if mask is not None:
665
  attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
666
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
667
+ attn_mask = attn_mask.expand(
668
+ batch_size, attn.heads, query.shape[-2], key.shape[-2]
669
+ )
670
  else:
671
  attn_mask = None
672
 
673
+ x = F.scaled_dot_product_attention(
674
+ query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
675
+ )
676
  x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
677
  x = x.to(query.dtype)
678
 
 
730
  )
731
 
732
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
733
+ self.ff = FeedForward(
734
+ dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
735
+ )
736
 
737
  def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
738
  # pre-norm & modulation for attention input
 
765
  """
766
 
767
  def __init__(
768
+ self,
769
+ dim,
770
+ heads,
771
+ dim_head,
772
+ ff_mult=4,
773
+ dropout=0.1,
774
+ context_dim=None,
775
+ context_pre_only=False,
776
+ qk_norm=None,
777
  ):
778
  super().__init__()
779
  if context_dim is None:
780
  context_dim = dim
781
  self.context_pre_only = context_pre_only
782
 
783
+ self.attn_norm_c = (
784
+ AdaLayerNorm_Final(context_dim)
785
+ if context_pre_only
786
+ else AdaLayerNorm(context_dim)
787
+ )
788
  self.attn_norm_x = AdaLayerNorm(dim)
789
  self.attn = Attention(
790
  processor=JointAttnProcessor(),
 
798
  )
799
 
800
  if not context_pre_only:
801
+ self.ff_norm_c = nn.LayerNorm(
802
+ context_dim, elementwise_affine=False, eps=1e-6
803
+ )
804
+ self.ff_c = FeedForward(
805
+ dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh"
806
+ )
807
  else:
808
  self.ff_norm_c = None
809
  self.ff_c = None
810
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
811
+ self.ff_x = FeedForward(
812
+ dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
813
+ )
814
 
815
+ def forward(
816
+ self, x, c, t, mask=None, rope=None, c_rope=None
817
+ ): # x: noised input, c: context, t: time embedding
818
  # pre-norm & modulation for attention input
819
  if self.context_pre_only:
820
  norm_c = self.attn_norm_c(c, t)
821
  else:
822
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
823
+ c, emb=t
824
+ )
825
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
826
+ x, emb=t
827
+ )
828
 
829
  # attention
830
+ x_attn_output, c_attn_output = self.attn(
831
+ x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope
832
+ )
833
 
834
  # process attention output for context c
835
  if self.context_pre_only:
 
837
  else: # if not last layer
838
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
839
 
840
+ norm_c = (
841
+ self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
842
+ )
843
  c_ff_output = self.ff_c(norm_c)
844
  c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
845
 
 
860
  def __init__(self, dim, freq_embed_dim=256):
861
  super().__init__()
862
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
863
+ self.time_mlp = nn.Sequential(
864
+ nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
865
+ )
866
 
867
  def forward(self, timestep: float["b"]):
868
  time_hidden = self.time_embed(timestep)
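Note: a stripped-down sketch of the adaLN shift/scale/gate pathway used by the blocks above (arbitrary dimensions, standalone module objects; this illustrates the modulation math, not the repo's classes):

import torch
import torch.nn as nn

dim = 16
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
to_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))  # time emb -> 6 modulation chunks

x = torch.randn(2, 10, dim)  # (batch, seq, dim) hidden states
t = torch.randn(2, dim)      # time embedding

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = to_mod(t).chunk(6, dim=1)
x_mod = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]  # pre-attention modulation
print(x_mod.shape, gate_msa.shape)  # torch.Size([2, 10, 16]) torch.Size([2, 16])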
f5_tts/model_new/trainer.py CHANGED
@@ -19,7 +19,6 @@ from f5_tts.model import CFM
19
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
20
  from f5_tts.model.utils import default, exists
21
 
22
-
23
  # trainer
24
 
25
 
@@ -70,7 +69,13 @@ class Trainer:
70
  self.logger = logger
71
  if self.logger == "wandb":
72
  if exists(wandb_resume_id):
73
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
 
 
 
 
 
 
74
  else:
75
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
76
 
@@ -138,7 +143,9 @@ class Trainer:
138
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
139
  else:
140
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
141
- self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
 
 
142
 
143
  @property
144
  def is_main(self):
@@ -157,12 +164,16 @@ class Trainer:
157
  if not os.path.exists(self.checkpoint_path):
158
  os.makedirs(self.checkpoint_path)
159
  if last:
160
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
 
 
161
  print(f"Saved last checkpoint at update {update}")
162
  else:
163
  if self.keep_last_n_checkpoints == 0:
164
  return
165
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
 
 
166
  if self.keep_last_n_checkpoints > 0:
167
  # Updated logic to exclude pretrained model from rotation
168
  checkpoints = [
@@ -183,7 +194,10 @@ class Trainer:
183
  if (
184
  not exists(self.checkpoint_path)
185
  or not os.path.exists(self.checkpoint_path)
186
- or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
 
 
 
187
  ):
188
  return 0
189
 
@@ -195,11 +209,16 @@ class Trainer:
195
  all_checkpoints = [
196
  f
197
  for f in os.listdir(self.checkpoint_path)
198
- if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
 
199
  ]
200
 
201
  # First try to find regular training checkpoints
202
- training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
 
 
 
 
203
  if training_checkpoints:
204
  latest_checkpoint = sorted(
205
  training_checkpoints,
@@ -207,21 +226,30 @@ class Trainer:
207
  )[-1]
208
  else:
209
  # If no training checkpoints, use pretrained model
210
- latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
 
 
211
 
212
  if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
213
  from safetensors.torch import load_file
214
 
215
- checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
 
 
216
  checkpoint = {"ema_model_state_dict": checkpoint}
217
  elif latest_checkpoint.endswith(".pt"):
218
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
219
  checkpoint = torch.load(
220
- f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
 
 
221
  )
222
 
223
  # patch for backward compatibility, 305e3ea
224
- for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
 
 
 
225
  if key in checkpoint["ema_model_state_dict"]:
226
  del checkpoint["ema_model_state_dict"][key]
227
 
@@ -231,17 +259,24 @@ class Trainer:
231
  if "update" in checkpoint or "step" in checkpoint:
232
  # patch for backward compatibility, with before f992c4e
233
  if "step" in checkpoint:
234
- checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
 
 
235
  if self.grad_accumulation_steps > 1 and self.is_main:
236
  print(
237
  "F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
238
  )
239
  # patch for backward compatibility, 305e3ea
240
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
 
 
 
241
  if key in checkpoint["model_state_dict"]:
242
  del checkpoint["model_state_dict"][key]
243
 
244
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
 
 
245
  self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
246
  if self.scheduler:
247
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
@@ -252,21 +287,30 @@ class Trainer:
252
  for k, v in checkpoint["ema_model_state_dict"].items()
253
  if k not in ["initted", "update", "step"]
254
  }
255
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
 
 
256
  update = 0
257
 
258
  del checkpoint
259
  gc.collect()
260
  return update
261
 
262
- def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
 
 
263
  if self.log_samples:
264
- from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
 
265
 
266
  vocoder = load_vocoder(
267
- vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
 
 
268
  )
269
- target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
 
 
270
  log_samples_path = f"{self.checkpoint_path}/samples"
271
  os.makedirs(log_samples_path, exist_ok=True)
272
 
@@ -306,7 +350,9 @@ class Trainer:
306
  batch_sampler=batch_sampler,
307
  )
308
  else:
309
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
 
 
310
 
311
  # accelerator.prepare() dispatches batches to devices;
312
  # which means the length of dataloader calculated before, should consider the number of devices
@@ -314,12 +360,24 @@ class Trainer:
314
  self.num_warmup_updates * self.accelerator.num_processes
315
  ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
316
  # otherwise by default with split_batches=False, warmup steps change with num_processes
317
- total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
 
 
 
318
  decay_updates = total_updates - warmup_updates
319
- warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
320
- decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
 
 
 
 
 
 
 
321
  self.scheduler = SequentialLR(
322
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
 
 
323
  )
324
  train_dataloader, self.scheduler = self.accelerator.prepare(
325
  train_dataloader, self.scheduler
@@ -332,21 +390,27 @@ class Trainer:
332
  start_step = start_update * self.grad_accumulation_steps
333
  skipped_epoch = int(start_step // orig_epoch_step)
334
  skipped_batch = start_step % orig_epoch_step
335
- skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
 
 
336
  else:
337
  skipped_epoch = 0
338
 
339
  for epoch in range(skipped_epoch, self.epochs):
340
  self.model.train()
341
  if exists(resumable_with_seed) and epoch == skipped_epoch:
342
- progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
 
 
343
  current_dataloader = skipped_dataloader
344
  else:
345
  progress_bar_initial = 0
346
  current_dataloader = train_dataloader
347
 
348
  # Set epoch for the batch sampler if it exists
349
- if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
 
 
350
  train_dataloader.batch_sampler.set_epoch(epoch)
351
 
352
  progress_bar = tqdm(
@@ -364,17 +428,29 @@ class Trainer:
364
  mel_lengths = batch["mel_lengths"]
365
 
366
  # TODO. add duration predictor training
367
- if self.duration_predictor is not None and self.accelerator.is_local_main_process:
368
- dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
369
- self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)
 
 
 
 
 
 
 
370
 
371
  loss, cond, pred = self.model(
372
- mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
 
 
 
373
  )
374
  self.accelerator.backward(loss)
375
 
376
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
377
- self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
 
 
378
 
379
  self.optimizer.step()
380
  self.scheduler.step()
@@ -386,29 +462,44 @@ class Trainer:
386
 
387
  global_update += 1
388
  progress_bar.update(1)
389
- progress_bar.set_postfix(update=str(global_update), loss=loss.item())
 
 
390
 
391
  if self.accelerator.is_local_main_process:
392
  self.accelerator.log(
393
- {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
 
394
  )
395
  if self.logger == "tensorboard":
396
  self.writer.add_scalar("loss", loss.item(), global_update)
397
- self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
 
398
 
399
- if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
 
 
 
400
  self.save_checkpoint(global_update, last=True)
401
 
402
- if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
 
 
 
403
  self.save_checkpoint(global_update)
404
 
405
  if self.log_samples and self.accelerator.is_local_main_process:
406
  ref_audio_len = mel_lengths[0]
407
  infer_text = [
408
- text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
 
 
409
  ]
410
  with torch.inference_mode():
411
- generated, _ = self.accelerator.unwrap_model(self.model).sample(
 
 
412
  cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
413
  text=infer_text,
414
  duration=ref_audio_len * 2,
@@ -417,7 +508,11 @@ class Trainer:
417
  sway_sampling_coef=sway_sampling_coef,
418
  )
419
  generated = generated.to(torch.float32)
420
- gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
 
 
 
 
421
  ref_mel_spec = batch["mel"][0].unsqueeze(0)
422
  if self.vocoder_name == "vocos":
423
  gen_audio = vocoder.decode(gen_mel_spec).cpu()
@@ -427,10 +522,14 @@ class Trainer:
427
  ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
428
 
429
  torchaudio.save(
430
- f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
 
 
431
  )
432
  torchaudio.save(
433
- f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
 
 
434
  )
435
  self.model.train()
436
 
 
19
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
20
  from f5_tts.model.utils import default, exists
21
 
 
22
  # trainer
23
 
24
 
 
69
  self.logger = logger
70
  if self.logger == "wandb":
71
  if exists(wandb_resume_id):
72
+ init_kwargs = {
73
+ "wandb": {
74
+ "resume": "allow",
75
+ "name": wandb_run_name,
76
+ "id": wandb_resume_id,
77
+ }
78
+ }
79
  else:
80
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
81
 
 
143
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
144
  else:
145
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
146
+ self.model, self.optimizer = self.accelerator.prepare(
147
+ self.model, self.optimizer
148
+ )
149
 
150
  @property
151
  def is_main(self):
 
164
  if not os.path.exists(self.checkpoint_path):
165
  os.makedirs(self.checkpoint_path)
166
  if last:
167
+ self.accelerator.save(
168
+ checkpoint, f"{self.checkpoint_path}/model_last.pt"
169
+ )
170
  print(f"Saved last checkpoint at update {update}")
171
  else:
172
  if self.keep_last_n_checkpoints == 0:
173
  return
174
+ self.accelerator.save(
175
+ checkpoint, f"{self.checkpoint_path}/model_{update}.pt"
176
+ )
177
  if self.keep_last_n_checkpoints > 0:
178
  # Updated logic to exclude pretrained model from rotation
179
  checkpoints = [
 
194
  if (
195
  not exists(self.checkpoint_path)
196
  or not os.path.exists(self.checkpoint_path)
197
+ or not any(
198
+ filename.endswith((".pt", ".safetensors"))
199
+ for filename in os.listdir(self.checkpoint_path)
200
+ )
201
  ):
202
  return 0
203
 
 
209
  all_checkpoints = [
210
  f
211
  for f in os.listdir(self.checkpoint_path)
212
+ if (f.startswith("model_") or f.startswith("pretrained_"))
213
+ and f.endswith((".pt", ".safetensors"))
214
  ]
215
 
216
  # First try to find regular training checkpoints
217
+ training_checkpoints = [
218
+ f
219
+ for f in all_checkpoints
220
+ if f.startswith("model_") and f != "model_last.pt"
221
+ ]
222
  if training_checkpoints:
223
  latest_checkpoint = sorted(
224
  training_checkpoints,
 
226
  )[-1]
227
  else:
228
  # If no training checkpoints, use pretrained model
229
+ latest_checkpoint = next(
230
+ f for f in all_checkpoints if f.startswith("pretrained_")
231
+ )
232
 
233
  if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
234
  from safetensors.torch import load_file
235
 
236
+ checkpoint = load_file(
237
+ f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu"
238
+ )
239
  checkpoint = {"ema_model_state_dict": checkpoint}
240
  elif latest_checkpoint.endswith(".pt"):
241
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
242
  checkpoint = torch.load(
243
+ f"{self.checkpoint_path}/{latest_checkpoint}",
244
+ weights_only=True,
245
+ map_location="cpu",
246
  )
247
 
248
  # patch for backward compatibility, 305e3ea
249
+ for key in [
250
+ "ema_model.mel_spec.mel_stft.mel_scale.fb",
251
+ "ema_model.mel_spec.mel_stft.spectrogram.window",
252
+ ]:
253
  if key in checkpoint["ema_model_state_dict"]:
254
  del checkpoint["ema_model_state_dict"][key]
255
 
 
259
  if "update" in checkpoint or "step" in checkpoint:
260
  # patch for backward compatibility, with before f992c4e
261
  if "step" in checkpoint:
262
+ checkpoint["update"] = (
263
+ checkpoint["step"] // self.grad_accumulation_steps
264
+ )
265
  if self.grad_accumulation_steps > 1 and self.is_main:
266
  print(
267
  "F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
268
  )
269
  # patch for backward compatibility, 305e3ea
270
+ for key in [
271
+ "mel_spec.mel_stft.mel_scale.fb",
272
+ "mel_spec.mel_stft.spectrogram.window",
273
+ ]:
274
  if key in checkpoint["model_state_dict"]:
275
  del checkpoint["model_state_dict"][key]
276
 
277
+ self.accelerator.unwrap_model(self.model).load_state_dict(
278
+ checkpoint["model_state_dict"]
279
+ )
280
  self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
281
  if self.scheduler:
282
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
 
287
  for k, v in checkpoint["ema_model_state_dict"].items()
288
  if k not in ["initted", "update", "step"]
289
  }
290
+ self.accelerator.unwrap_model(self.model).load_state_dict(
291
+ checkpoint["model_state_dict"]
292
+ )
293
  update = 0
294
 
295
  del checkpoint
296
  gc.collect()
297
  return update
298
 
299
+ def train(
300
+ self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None
301
+ ):
302
  if self.log_samples:
303
+ from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
304
+ nfe_step, sway_sampling_coef)
305
 
306
  vocoder = load_vocoder(
307
+ vocoder_name=self.vocoder_name,
308
+ is_local=self.is_local_vocoder,
309
+ local_path=self.local_vocoder_path,
310
  )
311
+ target_sample_rate = self.accelerator.unwrap_model(
312
+ self.model
313
+ ).mel_spec.target_sample_rate
314
  log_samples_path = f"{self.checkpoint_path}/samples"
315
  os.makedirs(log_samples_path, exist_ok=True)
316
 
 
350
  batch_sampler=batch_sampler,
351
  )
352
  else:
353
+ raise ValueError(
354
+ f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
355
+ )
356
 
357
  # accelerator.prepare() dispatches batches to devices;
358
  # which means the length of dataloader calculated before, should consider the number of devices
 
360
  self.num_warmup_updates * self.accelerator.num_processes
361
  ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
362
  # otherwise by default with split_batches=False, warmup steps change with num_processes
363
+ total_updates = (
364
+ math.ceil(len(train_dataloader) / self.grad_accumulation_steps)
365
+ * self.epochs
366
+ )
367
  decay_updates = total_updates - warmup_updates
368
+ warmup_scheduler = LinearLR(
369
+ self.optimizer,
370
+ start_factor=1e-8,
371
+ end_factor=1.0,
372
+ total_iters=warmup_updates,
373
+ )
374
+ decay_scheduler = LinearLR(
375
+ self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates
376
+ )
377
  self.scheduler = SequentialLR(
378
+ self.optimizer,
379
+ schedulers=[warmup_scheduler, decay_scheduler],
380
+ milestones=[warmup_updates],
381
  )
382
  train_dataloader, self.scheduler = self.accelerator.prepare(
383
  train_dataloader, self.scheduler
 
390
  start_step = start_update * self.grad_accumulation_steps
391
  skipped_epoch = int(start_step // orig_epoch_step)
392
  skipped_batch = start_step % orig_epoch_step
393
+ skipped_dataloader = self.accelerator.skip_first_batches(
394
+ train_dataloader, num_batches=skipped_batch
395
+ )
396
  else:
397
  skipped_epoch = 0
398
 
399
  for epoch in range(skipped_epoch, self.epochs):
400
  self.model.train()
401
  if exists(resumable_with_seed) and epoch == skipped_epoch:
402
+ progress_bar_initial = math.ceil(
403
+ skipped_batch / self.grad_accumulation_steps
404
+ )
405
  current_dataloader = skipped_dataloader
406
  else:
407
  progress_bar_initial = 0
408
  current_dataloader = train_dataloader
409
 
410
  # Set epoch for the batch sampler if it exists
411
+ if hasattr(train_dataloader, "batch_sampler") and hasattr(
412
+ train_dataloader.batch_sampler, "set_epoch"
413
+ ):
414
  train_dataloader.batch_sampler.set_epoch(epoch)
415
 
416
  progress_bar = tqdm(
 
428
  mel_lengths = batch["mel_lengths"]
429
 
430
  # TODO. add duration predictor training
431
+ if (
432
+ self.duration_predictor is not None
433
+ and self.accelerator.is_local_main_process
434
+ ):
435
+ dur_loss = self.duration_predictor(
436
+ mel_spec, lens=batch.get("durations")
437
+ )
438
+ self.accelerator.log(
439
+ {"duration loss": dur_loss.item()}, step=global_update
440
+ )
441
 
442
  loss, cond, pred = self.model(
443
+ mel_spec,
444
+ text=text_inputs,
445
+ lens=mel_lengths,
446
+ noise_scheduler=self.noise_scheduler,
447
  )
448
  self.accelerator.backward(loss)
449
 
450
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
451
+ self.accelerator.clip_grad_norm_(
452
+ self.model.parameters(), self.max_grad_norm
453
+ )
454
 
455
  self.optimizer.step()
456
  self.scheduler.step()
 
462
 
463
  global_update += 1
464
  progress_bar.update(1)
465
+ progress_bar.set_postfix(
466
+ update=str(global_update), loss=loss.item()
467
+ )
468
 
469
  if self.accelerator.is_local_main_process:
470
  self.accelerator.log(
471
+ {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]},
472
+ step=global_update,
473
  )
474
  if self.logger == "tensorboard":
475
  self.writer.add_scalar("loss", loss.item(), global_update)
476
+ self.writer.add_scalar(
477
+ "lr", self.scheduler.get_last_lr()[0], global_update
478
+ )
479
 
480
+ if (
481
+ global_update % self.last_per_updates == 0
482
+ and self.accelerator.sync_gradients
483
+ ):
484
  self.save_checkpoint(global_update, last=True)
485
 
486
+ if (
487
+ global_update % self.save_per_updates == 0
488
+ and self.accelerator.sync_gradients
489
+ ):
490
  self.save_checkpoint(global_update)
491
 
492
  if self.log_samples and self.accelerator.is_local_main_process:
493
  ref_audio_len = mel_lengths[0]
494
  infer_text = [
495
+ text_inputs[0]
496
+ + ([" "] if isinstance(text_inputs[0], list) else " ")
497
+ + text_inputs[0]
498
  ]
499
  with torch.inference_mode():
500
+ generated, _ = self.accelerator.unwrap_model(
501
+ self.model
502
+ ).sample(
503
  cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
504
  text=infer_text,
505
  duration=ref_audio_len * 2,
 
508
  sway_sampling_coef=sway_sampling_coef,
509
  )
510
  generated = generated.to(torch.float32)
511
+ gen_mel_spec = (
512
+ generated[:, ref_audio_len:, :]
513
+ .permute(0, 2, 1)
514
+ .to(self.accelerator.device)
515
+ )
516
  ref_mel_spec = batch["mel"][0].unsqueeze(0)
517
  if self.vocoder_name == "vocos":
518
  gen_audio = vocoder.decode(gen_mel_spec).cpu()
 
522
  ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
523
 
524
  torchaudio.save(
525
+ f"{log_samples_path}/update_{global_update}_gen.wav",
526
+ gen_audio,
527
+ target_sample_rate,
528
  )
529
  torchaudio.save(
530
+ f"{log_samples_path}/update_{global_update}_ref.wav",
531
+ ref_audio,
532
+ target_sample_rate,
533
  )
534
  self.model.train()
535
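Note: the warmup-then-decay schedule wired above (two LinearLR segments joined by SequentialLR) can be exercised in isolation; a minimal sketch with made-up update counts:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = AdamW([param], lr=1e-4)

warmup_updates, decay_updates = 10, 90
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_updates])

lrs = []
for _ in range(warmup_updates + decay_updates):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
print(f"peak lr ~ {max(lrs):.2e}")  # ramps up to ~1e-4, then decays linearly toward ~0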
 
f5_tts/model_new/utils.py CHANGED
@@ -10,7 +10,6 @@ import torch
10
  from pypinyin import Style, lazy_pinyin
11
  from torch.nn.utils.rnn import pad_sequence
12
 
13
-
14
  # seed everything
15
 
16
 
@@ -48,7 +47,9 @@ def is_package_available(package_name: str) -> bool:
48
  # tensor helpers
49
 
50
 
51
- def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
 
 
52
  if not exists(length):
53
  length = t.amax()
54
 
@@ -56,7 +57,9 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa
56
  return seq[None, :] < t[:, None]
57
 
58
 
59
- def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
 
 
60
  max_seq_len = seq_len.max().item()
61
  seq = torch.arange(max_seq_len, device=start.device).long()
62
  start_mask = seq[None, :] >= start[:, None]
@@ -64,7 +67,9 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"
64
  return start_mask & end_mask
65
 
66
 
67
- def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
 
 
68
  lengths = (frac_lengths * seq_len).long()
69
  max_start = seq_len - lengths
70
 
@@ -75,7 +80,9 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa
75
  return mask_from_start_end_indices(seq_len, start, end)
76
 
77
 
78
- def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
 
 
79
  if not exists(mask):
80
  return t.mean(dim=1)
81
 
@@ -99,7 +106,9 @@ def list_str_to_idx(
99
  vocab_char_map: dict[str, int], # {char: idx}
100
  padding_value=-1,
101
  ) -> int["b nt"]: # noqa: F722
102
- list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
 
 
103
  text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
104
  return text
105
 
@@ -118,13 +127,18 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
118
  - if use "byte", set to 256 (unicode byte range)
119
  """
120
  if tokenizer in ["pinyin", "char"]:
121
- tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
 
 
 
122
  with open(tokenizer_path, "r", encoding="utf-8") as f:
123
  vocab_char_map = {}
124
  for i, char in enumerate(f):
125
  vocab_char_map[char[:-1]] = i
126
  vocab_size = len(vocab_char_map)
127
- assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
 
 
128
 
129
  elif tokenizer == "byte":
130
  vocab_char_map = None
@@ -154,9 +168,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
154
  ) # add custom trans here, to address oov
155
 
156
  def is_chinese(c):
157
- return (
158
- "\u3100" <= c <= "\u9fff" # common chinese characters
159
- )
160
 
161
  for text in text_list:
162
  char_list = []
@@ -167,7 +179,9 @@ def convert_char_to_pinyin(text_list, polyphone=True):
167
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
168
  char_list.append(" ")
169
  char_list.extend(seg)
170
- elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
 
 
171
  seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
172
  for i, c in enumerate(seg):
173
  if is_chinese(c):
@@ -179,7 +193,9 @@ def convert_char_to_pinyin(text_list, polyphone=True):
179
  char_list.extend(c)
180
  elif is_chinese(c):
181
  char_list.append(" ")
182
- char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
 
 
183
  else:
184
  char_list.append(c)
185
  final_text_list.append(char_list)
 
10
  from pypinyin import Style, lazy_pinyin
11
  from torch.nn.utils.rnn import pad_sequence
12
 
 
13
  # seed everything
14
 
15
 
 
47
  # tensor helpers
48
 
49
 
50
+ def lens_to_mask(
51
+ t: int["b"], length: int | None = None
52
+ ) -> bool["b n"]: # noqa: F722 F821
53
  if not exists(length):
54
  length = t.amax()
55
 
 
57
  return seq[None, :] < t[:, None]
58
 
59
 
60
+ def mask_from_start_end_indices(
61
+ seq_len: int["b"], start: int["b"], end: int["b"]
62
+ ): # noqa: F722 F821
63
  max_seq_len = seq_len.max().item()
64
  seq = torch.arange(max_seq_len, device=start.device).long()
65
  start_mask = seq[None, :] >= start[:, None]
 
67
  return start_mask & end_mask
68
 
69
 
70
+ def mask_from_frac_lengths(
71
+ seq_len: int["b"], frac_lengths: float["b"]
72
+ ): # noqa: F722 F821
73
  lengths = (frac_lengths * seq_len).long()
74
  max_start = seq_len - lengths
75
 
 
80
  return mask_from_start_end_indices(seq_len, start, end)
81
 
82
 
83
+ def maybe_masked_mean(
84
+ t: float["b n d"], mask: bool["b n"] = None
85
+ ) -> float["b d"]: # noqa: F722
86
  if not exists(mask):
87
  return t.mean(dim=1)
88
 
 
106
  vocab_char_map: dict[str, int], # {char: idx}
107
  padding_value=-1,
108
  ) -> int["b nt"]: # noqa: F722
109
+ list_idx_tensors = [
110
+ torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
111
+ ] # pinyin or char style
112
  text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
113
  return text
114
 
 
127
  - if use "byte", set to 256 (unicode byte range)
128
  """
129
  if tokenizer in ["pinyin", "char"]:
130
+ tokenizer_path = os.path.join(
131
+ files("f5_tts").joinpath("../../data"),
132
+ f"{dataset_name}_{tokenizer}/vocab.txt",
133
+ )
134
  with open(tokenizer_path, "r", encoding="utf-8") as f:
135
  vocab_char_map = {}
136
  for i, char in enumerate(f):
137
  vocab_char_map[char[:-1]] = i
138
  vocab_size = len(vocab_char_map)
139
+ assert (
140
+ vocab_char_map[" "] == 0
141
+ ), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
142
 
143
  elif tokenizer == "byte":
144
  vocab_char_map = None
 
168
  ) # add custom trans here, to address oov
169
 
170
  def is_chinese(c):
171
+ return "\u3100" <= c <= "\u9fff" # common chinese characters
 
 
172
 
173
  for text in text_list:
174
  char_list = []
 
179
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
180
  char_list.append(" ")
181
  char_list.extend(seg)
182
+ elif polyphone and seg_byte_len == 3 * len(
183
+ seg
184
+ ): # if pure east asian characters
185
  seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
186
  for i, c in enumerate(seg):
187
  if is_chinese(c):
 
193
  char_list.extend(c)
194
  elif is_chinese(c):
195
  char_list.append(" ")
196
+ char_list.extend(
197
+ lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
198
+ )
199
  else:
200
  char_list.append(c)
201
  final_text_list.append(char_list)
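Note: a quick standalone illustration of the lens_to_mask broadcasting trick reformatted above (reimplemented here for the example, not imported from the repo):

import torch

def lens_to_mask(t, length=None):
    # True for valid positions, False for padding; result shape (b, n)
    if length is None:
        length = int(t.amax())
    seq = torch.arange(length, device=t.device)
    return seq[None, :] < t[:, None]

lens = torch.tensor([3, 5])
print(lens_to_mask(lens).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])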
f5_tts/runtime/triton_trtllm/benchmark.py CHANGED
@@ -51,7 +51,6 @@ from torch.utils.data import DataLoader, DistributedSampler
51
  from tqdm import tqdm
52
  from vocos import Vocos
53
 
54
-
55
  torch.manual_seed(0)
56
 
57
 
@@ -64,7 +63,9 @@ def get_args():
64
  choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
65
  help="huggingface dataset split name",
66
  )
67
- parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
 
 
68
  parser.add_argument(
69
  "--vocab-file",
70
  required=True,
@@ -89,8 +90,12 @@ def get_args():
89
  type=int,
90
  help="batch size (per-device) for inference",
91
  )
92
- parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
93
- parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
 
 
 
 
94
  parser.add_argument(
95
  "--vocoder",
96
  default="vocos",
@@ -105,8 +110,16 @@ def get_args():
105
  )
106
  parser.add_argument("--enable-warmup", action="store_true")
107
  parser.add_argument("--remove-input-padding", action="store_true")
108
- parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
109
- parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
 
 
 
 
 
 
 
 
110
  args = parser.parse_args()
111
  return args
112
 
@@ -126,7 +139,13 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
126
  torch.cuda.nvtx.range_push("data_collator")
127
  target_sample_rate = 24000
128
  target_rms = 0.1
129
- ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
 
 
 
 
 
 
130
  [],
131
  [],
132
  [],
@@ -170,7 +189,14 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
170
  ref_mel_len_list.append(ref_mel_len)
171
 
172
  estimated_reference_target_mel_len.append(
173
- int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
 
 
 
 
 
 
 
174
  )
175
 
176
  max_seq_len = max(estimated_reference_target_mel_len)
@@ -182,12 +208,22 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
182
 
183
  for i, item in enumerate(text_pad_sequence):
184
  text_pad_sequence[i] = F.pad(
185
- item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
 
 
 
186
  )
187
- text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
188
- text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
 
 
 
 
189
  text_pad_sequence = F.pad(
190
- text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
 
 
 
191
  )
192
  if use_perf:
193
  torch.cuda.nvtx.range_pop()
@@ -252,7 +288,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
252
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
253
  char_list.append(" ")
254
  char_list.extend(seg)
255
- elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
 
 
256
  seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
257
  for i, c in enumerate(seg):
258
  if is_chinese(c):
@@ -264,7 +302,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
264
  char_list.extend(c)
265
  elif is_chinese(c):
266
  char_list.append(" ")
267
- char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
 
 
268
  else:
269
  char_list.append(c)
270
  final_reference_target_texts_list.append(char_list)
@@ -277,13 +317,20 @@ def list_str_to_idx(
277
  vocab_char_map: Dict[str, int], # {char: idx}
278
  padding_value=-1,
279
  ):
280
- list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
 
 
281
  # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
282
  return list_idx_tensors
283
 
284
 
285
  def load_vocoder(
286
- vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
 
 
 
 
 
287
  ):
288
  if vocoder_name == "vocos":
289
  if vocoder_trt_engine_path is not None:
@@ -297,8 +344,14 @@ def load_vocoder(
297
  else:
298
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
299
  repo_id = "charactr/vocos-mel-24khz"
300
- config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
301
- model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
 
 
 
 
 
 
302
  vocoder = Vocos.from_hparams(config_path)
303
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
304
  from vocos.feature_extractors import EncodecFeatures
@@ -343,14 +396,21 @@ class VocosTensorRT:
343
  with open(engine_path, "rb") as f:
344
  engine_buffer = f.read()
345
  self.session = Session.from_serialized_engine(engine_buffer)
346
- self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
 
 
347
 
348
  def decode(self, mels):
349
  mels = mels.contiguous()
350
  inputs = {"mel": mels}
351
- output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
 
 
352
  outputs = {
353
- t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
 
 
 
354
  }
355
  ok = self.session.run(inputs, outputs, self.stream)
356
 
@@ -376,12 +436,18 @@ def main():
376
  config = json.load(f)
377
  if args.backend_type == "trt":
378
  model = F5TTS(
379
- config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
 
 
 
 
380
  )
381
  elif args.backend_type == "pytorch":
382
  import sys
383
 
384
- sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
 
 
385
  from f5_tts.infer.utils_infer import load_model
386
  from f5_tts.model import DiT
387
 
@@ -398,7 +464,9 @@ def main():
398
  model = load_model(DiT, F5TTS_model_cfg, args.model_path)
399
 
400
  vocoder = load_vocoder(
401
- vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
 
 
402
  )
403
 
404
  dataset = load_dataset(
@@ -411,7 +479,9 @@ def main():
411
  prompt_audio_len = example["prompt_audio"]["array"].shape[0]
412
  scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
413
  estimated_duration = prompt_audio_len * scale_factor
414
- example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
 
 
415
  return example
416
 
417
  dataset = dataset.map(add_estimated_duration)
@@ -442,12 +512,18 @@ def main():
442
 
443
  if args.enable_warmup:
444
  for batch in dataloader:
445
- ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
 
 
446
  text_pad_seq = batch["text_pad_sequence"].to(device)
447
  total_mel_lens = batch["estimated_reference_target_mel_len"]
448
  if args.backend_type == "trt":
449
  _ = model.sample(
450
- text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
 
 
 
 
451
  )
452
  elif args.backend_type == "pytorch":
453
  with torch.inference_mode():
@@ -475,7 +551,9 @@ def main():
475
  for batch in dataloader:
476
  if args.use_perf:
477
  torch.cuda.nvtx.range_push("data sample")
478
- ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
 
 
479
  text_pad_seq = batch["text_pad_sequence"].to(device)
480
  total_mel_lens = batch["estimated_reference_target_mel_len"]
481
 
 
51
  from tqdm import tqdm
52
  from vocos import Vocos
53
 
 
54
  torch.manual_seed(0)
55
 
56
 
 
63
  choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
64
  help="huggingface dataset split name",
65
  )
66
+ parser.add_argument(
67
+ "--output-dir", required=True, type=str, help="dir to save result"
68
+ )
69
  parser.add_argument(
70
  "--vocab-file",
71
  required=True,
 
90
  type=int,
91
  help="batch size (per-device) for inference",
92
  )
93
+ parser.add_argument(
94
+ "--num-workers", type=int, default=0, help="workers for dataloader"
95
+ )
96
+ parser.add_argument(
97
+ "--prefetch", type=int, default=None, help="prefetch for dataloader"
98
+ )
99
  parser.add_argument(
100
  "--vocoder",
101
  default="vocos",
 
110
  )
111
  parser.add_argument("--enable-warmup", action="store_true")
112
  parser.add_argument("--remove-input-padding", action="store_true")
113
+ parser.add_argument(
114
+ "--use-perf", action="store_true", help="use nvtx to record performance"
115
+ )
116
+ parser.add_argument(
117
+ "--backend-type",
118
+ type=str,
119
+ default="triton",
120
+ choices=["trt", "pytorch"],
121
+ help="backend type",
122
+ )
123
  args = parser.parse_args()
124
  return args
125
 
 
139
  torch.cuda.nvtx.range_push("data_collator")
140
  target_sample_rate = 24000
141
  target_rms = 0.1
142
+ (
143
+ ids,
144
+ ref_mel_list,
145
+ ref_mel_len_list,
146
+ estimated_reference_target_mel_len,
147
+ reference_target_texts_list,
148
+ ) = (
149
  [],
150
  [],
151
  [],
 
189
  ref_mel_len_list.append(ref_mel_len)
190
 
191
  estimated_reference_target_mel_len.append(
192
+ int(
193
+ ref_mel.shape[0]
194
+ * (
195
+ 1
196
+ + len(target_text.encode("utf-8"))
197
+ / len(prompt_text.encode("utf-8"))
198
+ )
199
+ )
200
  )
201
 
202
  max_seq_len = max(estimated_reference_target_mel_len)
 
208
 
209
  for i, item in enumerate(text_pad_sequence):
210
  text_pad_sequence[i] = F.pad(
211
+ item,
212
+ (0, estimated_reference_target_mel_len[i] - len(item)),
213
+ mode="constant",
214
+ value=-1,
215
  )
216
+ text_pad_sequence[
217
+ i
218
+ ] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
219
+ text_pad_sequence = pad_sequence(
220
+ text_pad_sequence, padding_value=-1, batch_first=True
221
+ ).to(device)
222
  text_pad_sequence = F.pad(
223
+ text_pad_sequence,
224
+ (0, max_seq_len - text_pad_sequence.shape[1]),
225
+ mode="constant",
226
+ value=-1,
227
  )
228
  if use_perf:
229
  torch.cuda.nvtx.range_pop()
 
288
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
289
  char_list.append(" ")
290
  char_list.extend(seg)
291
+ elif polyphone and seg_byte_len == 3 * len(
292
+ seg
293
+ ): # if pure east asian characters
294
  seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
295
  for i, c in enumerate(seg):
296
  if is_chinese(c):
 
302
  char_list.extend(c)
303
  elif is_chinese(c):
304
  char_list.append(" ")
305
+ char_list.extend(
306
+ lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
307
+ )
308
  else:
309
  char_list.append(c)
310
  final_reference_target_texts_list.append(char_list)
 
317
  vocab_char_map: Dict[str, int], # {char: idx}
318
  padding_value=-1,
319
  ):
320
+ list_idx_tensors = [
321
+ torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
322
+ ] # pinyin or char style
323
  # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
324
  return list_idx_tensors
325
 
326
 
327
  def load_vocoder(
328
+ vocoder_name="vocos",
329
+ is_local=False,
330
+ local_path="",
331
+ device="cuda",
332
+ hf_cache_dir=None,
333
+ vocoder_trt_engine_path=None,
334
  ):
335
  if vocoder_name == "vocos":
336
  if vocoder_trt_engine_path is not None:
 
344
  else:
345
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
346
  repo_id = "charactr/vocos-mel-24khz"
347
+ config_path = hf_hub_download(
348
+ repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml"
349
+ )
350
+ model_path = hf_hub_download(
351
+ repo_id=repo_id,
352
+ cache_dir=hf_cache_dir,
353
+ filename="pytorch_model.bin",
354
+ )
355
  vocoder = Vocos.from_hparams(config_path)
356
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
357
  from vocos.feature_extractors import EncodecFeatures
 
396
  with open(engine_path, "rb") as f:
397
  engine_buffer = f.read()
398
  self.session = Session.from_serialized_engine(engine_buffer)
399
+ self.stream = (
400
+ stream if stream is not None else torch.cuda.current_stream().cuda_stream
401
+ )
402
 
403
  def decode(self, mels):
404
  mels = mels.contiguous()
405
  inputs = {"mel": mels}
406
+ output_info = self.session.infer_shapes(
407
+ [TensorInfo("mel", trt.DataType.FLOAT, mels.shape)]
408
+ )
409
  outputs = {
410
+ t.name: torch.empty(
411
+ tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda"
412
+ )
413
+ for t in output_info
414
  }
415
  ok = self.session.run(inputs, outputs, self.stream)
416
 
 
436
  config = json.load(f)
437
  if args.backend_type == "trt":
438
  model = F5TTS(
439
+ config,
440
+ debug_mode=False,
441
+ tllm_model_dir=tllm_model_dir,
442
+ model_path=args.model_path,
443
+ vocab_size=vocab_size,
444
  )
445
  elif args.backend_type == "pytorch":
446
  import sys
447
 
448
+ sys.path.append(
449
+ f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/"
450
+ )
451
  from f5_tts.infer.utils_infer import load_model
452
  from f5_tts.model import DiT
453
 
 
464
  model = load_model(DiT, F5TTS_model_cfg, args.model_path)
465
 
466
  vocoder = load_vocoder(
467
+ vocoder_name=args.vocoder,
468
+ device=device,
469
+ vocoder_trt_engine_path=args.vocoder_trt_engine_path,
470
  )
471
 
472
  dataset = load_dataset(
 
479
  prompt_audio_len = example["prompt_audio"]["array"].shape[0]
480
  scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
481
  estimated_duration = prompt_audio_len * scale_factor
482
+ example["estimated_duration"] = (
483
+ estimated_duration / example["prompt_audio"]["sampling_rate"]
484
+ )
485
  return example
486
 
487
  dataset = dataset.map(add_estimated_duration)
 
512
 
513
  if args.enable_warmup:
514
  for batch in dataloader:
515
+ ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
516
+ "ref_mel_len_batch"
517
+ ].to(device)
518
  text_pad_seq = batch["text_pad_sequence"].to(device)
519
  total_mel_lens = batch["estimated_reference_target_mel_len"]
520
  if args.backend_type == "trt":
521
  _ = model.sample(
522
+ text_pad_seq,
523
+ ref_mels,
524
+ ref_mel_lens,
525
+ total_mel_lens,
526
+ remove_input_padding=args.remove_input_padding,
527
  )
528
  elif args.backend_type == "pytorch":
529
  with torch.inference_mode():
 
551
  for batch in dataloader:
552
  if args.use_perf:
553
  torch.cuda.nvtx.range_push("data sample")
554
+ ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
555
+ "ref_mel_len_batch"
556
+ ].to(device)
557
  text_pad_seq = batch["text_pad_sequence"].to(device)
558
  total_mel_lens = batch["estimated_reference_target_mel_len"]
559
 
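Note on the warmup/batching path above: the benchmark sizes each request from a text-length heuristic rather than from real synthesis, assuming the generated audio grows in proportion to how much target text is added on top of the prompt. A standalone sketch of that estimate (the numbers are illustrative, not taken from the dataset):

# Sketch of the duration heuristic used for bucketing/warmup; all values are illustrative.
def estimate_total_duration(prompt_audio_len, sampling_rate, prompt_text, target_text):
    # Assume output audio scales with how much target text is added relative to the prompt.
    scale_factor = 1 + len(target_text) / len(prompt_text)
    estimated_samples = prompt_audio_len * scale_factor
    return estimated_samples / sampling_rate  # seconds

print(estimate_total_duration(3 * 24000, 24000, "short prompt text", "a somewhat longer target sentence"))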
f5_tts/runtime/triton_trtllm/client_grpc.py CHANGED
@@ -64,8 +64,12 @@ def write_triton_stats(stats, summary_file):
64
  "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
65
  )
66
  summary_f.write("To learn more about the log, please refer to: \n")
67
- summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
68
- summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
69
  summary_f.write(
70
  "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
71
  )
@@ -86,7 +90,9 @@ def write_triton_stats(stats, summary_file):
86
  total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
87
  total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
88
  total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
89
- total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
90
  summary_f.write(
91
  f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
92
  )
@@ -97,7 +103,11 @@ def write_triton_stats(stats, summary_file):
97
  compute_output = batch["compute_output"]
98
  compute_infer = batch["compute_infer"]
99
  batch_count = int(compute_infer["count"])
100
- assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
101
  compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
102
  compute_input_time_ms = int(compute_input["ns"]) / 1e6
103
  compute_output_time_ms = int(compute_output["ns"]) / 1e6
@@ -113,7 +123,9 @@ def write_triton_stats(stats, summary_file):
113
 
114
 
115
  def get_args():
116
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
117
 
118
  parser.add_argument(
119
  "--server-addr",
@@ -254,7 +266,9 @@ async def send(
254
  for i, item in enumerate(manifest_item_list):
255
  if i % log_interval == 0:
256
  print(f"{name}: {i}/{len(manifest_item_list)}")
257
- waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
258
  duration = len(waveform) / sample_rate
259
  lengths = np.array([[len(waveform)]], dtype=np.int32)
260
 
@@ -269,7 +283,10 @@ async def send(
269
  1,
270
  padding_duration
271
  * sample_rate
272
- * ((int(estimated_target_duration + duration) // padding_duration) + 1),
273
  ),
274
  dtype=np.float32,
275
  )
@@ -281,8 +298,12 @@ async def send(
281
  samples = samples.reshape(1, -1).astype(np.float32)
282
 
283
  inputs = [
284
- protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
285
- protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
286
  protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
287
  protocol_client.InferInput("target_text", [1, 1], "BYTES"),
288
  ]
@@ -301,13 +322,17 @@ async def send(
301
 
302
  sequence_id = 100000000 + i + task_id * 10
303
  start = time.time()
304
- response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
 
 
305
 
306
  audio = response.as_numpy("waveform").reshape(-1)
307
 
308
  end = time.time() - start
309
 
310
- audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
 
 
311
  sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
312
 
313
  actual_duration = len(audio) / save_sample_rate
@@ -341,7 +366,9 @@ def load_manifests(manifest_path):
341
  def split_data(data, k):
342
  n = len(data)
343
  if n < k:
344
- print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
345
  k = n
346
 
347
  quotient = n // k
@@ -461,7 +488,9 @@ async def main():
461
  stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
462
  write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
463
 
464
- metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
465
  with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
466
  json.dump(metadata, f, indent=4)
467
 
 
64
  "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
65
  )
66
  summary_f.write("To learn more about the log, please refer to: \n")
67
+ summary_f.write(
68
+ "1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n"
69
+ )
70
+ summary_f.write(
71
+ "2. https://github.com/triton-inference-server/server/issues/5374 \n\n"
72
+ )
73
  summary_f.write(
74
  "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
75
  )
 
90
  total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
91
  total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
92
  total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
93
+ total_output_time_s = (
94
+ int(model_inference_stats["compute_output"]["ns"]) / 1e9
95
+ )
96
  summary_f.write(
97
  f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
98
  )
 
103
  compute_output = batch["compute_output"]
104
  compute_infer = batch["compute_infer"]
105
  batch_count = int(compute_infer["count"])
106
+ assert (
107
+ compute_infer["count"]
108
+ == compute_output["count"]
109
+ == compute_input["count"]
110
+ )
111
  compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
112
  compute_input_time_ms = int(compute_input["ns"]) / 1e6
113
  compute_output_time_ms = int(compute_output["ns"]) / 1e6
 
123
 
124
 
125
  def get_args():
126
+ parser = argparse.ArgumentParser(
127
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
128
+ )
129
 
130
  parser.add_argument(
131
  "--server-addr",
 
266
  for i, item in enumerate(manifest_item_list):
267
  if i % log_interval == 0:
268
  print(f"{name}: {i}/{len(manifest_item_list)}")
269
+ waveform, sample_rate = load_audio(
270
+ item["audio_filepath"], target_sample_rate=24000
271
+ )
272
  duration = len(waveform) / sample_rate
273
  lengths = np.array([[len(waveform)]], dtype=np.int32)
274
 
 
283
  1,
284
  padding_duration
285
  * sample_rate
286
+ * (
287
+ (int(estimated_target_duration + duration) // padding_duration)
288
+ + 1
289
+ ),
290
  ),
291
  dtype=np.float32,
292
  )
 
298
  samples = samples.reshape(1, -1).astype(np.float32)
299
 
300
  inputs = [
301
+ protocol_client.InferInput(
302
+ "reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)
303
+ ),
304
+ protocol_client.InferInput(
305
+ "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
306
+ ),
307
  protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
308
  protocol_client.InferInput("target_text", [1, 1], "BYTES"),
309
  ]
 
322
 
323
  sequence_id = 100000000 + i + task_id * 10
324
  start = time.time()
325
+ response = await triton_client.infer(
326
+ model_name, inputs, request_id=str(sequence_id), outputs=outputs
327
+ )
328
 
329
  audio = response.as_numpy("waveform").reshape(-1)
330
 
331
  end = time.time() - start
332
 
333
+ audio_save_path = os.path.join(
334
+ audio_save_dir, f"{item['target_audio_path']}.wav"
335
+ )
336
  sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
337
 
338
  actual_duration = len(audio) / save_sample_rate
 
366
  def split_data(data, k):
367
  n = len(data)
368
  if n < k:
369
+ print(
370
+ f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}."
371
+ )
372
  k = n
373
 
374
  quotient = n // k
 
488
  stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
489
  write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
490
 
491
+ metadata = await triton_client.get_model_config(
492
+ model_name=args.model_name, as_json=True
493
+ )
494
  with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
495
  json.dump(metadata, f, indent=4)
496
 
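For context on the zero-padding above: the client rounds the reference-plus-estimated-target duration up to a whole multiple of padding_duration seconds before copying the waveform into the request buffer. A self-contained sketch of that computation (dummy waveform values):

import numpy as np

sample_rate = 24000
waveform = np.zeros(int(2.3 * sample_rate), dtype=np.float32)  # dummy reference audio
duration = len(waveform) / sample_rate
estimated_target_duration = 3.1  # rough guess for the generated part, in seconds
padding_duration = 1  # pad the buffer up to a whole number of seconds

padded_len = padding_duration * sample_rate * (
    (int(estimated_target_duration + duration) // padding_duration) + 1
)
samples = np.zeros((1, padded_len), dtype=np.float32)
samples[0, : len(waveform)] = waveform
print(samples.shape)  # (1, 144000) -> a 6 s buffer for roughly 5.4 s of audio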
f5_tts/runtime/triton_trtllm/client_http.py CHANGED
@@ -31,7 +31,9 @@ import soundfile as sf
31
 
32
 
33
  def get_args():
34
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
 
35
 
36
  parser.add_argument(
37
  "--server-url",
@@ -91,15 +93,30 @@ def prepare_request(
91
 
92
  data = {
93
  "inputs": [
94
- {"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
95
  {
96
  "name": "reference_wav_len",
97
  "shape": lengths.shape,
98
  "datatype": "INT32",
99
  "data": lengths.tolist(),
100
  },
101
- {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
102
- {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
 
103
  ]
104
  }
105
 
@@ -135,7 +152,11 @@ if __name__ == "__main__":
135
  data = prepare_request(samples, args.reference_text, args.target_text)
136
 
137
  rsp = requests.post(
138
- url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
139
  )
140
  result = rsp.json()
141
  audio = result["outputs"][0]["data"]
 
31
 
32
 
33
  def get_args():
34
+ parser = argparse.ArgumentParser(
35
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
36
+ )
37
 
38
  parser.add_argument(
39
  "--server-url",
 
93
 
94
  data = {
95
  "inputs": [
96
+ {
97
+ "name": "reference_wav",
98
+ "shape": samples.shape,
99
+ "datatype": "FP32",
100
+ "data": samples.tolist(),
101
+ },
102
  {
103
  "name": "reference_wav_len",
104
  "shape": lengths.shape,
105
  "datatype": "INT32",
106
  "data": lengths.tolist(),
107
  },
108
+ {
109
+ "name": "reference_text",
110
+ "shape": [1, 1],
111
+ "datatype": "BYTES",
112
+ "data": [reference_text],
113
+ },
114
+ {
115
+ "name": "target_text",
116
+ "shape": [1, 1],
117
+ "datatype": "BYTES",
118
+ "data": [target_text],
119
+ },
120
  ]
121
  }
122
 
 
152
  data = prepare_request(samples, args.reference_text, args.target_text)
153
 
154
  rsp = requests.post(
155
+ url,
156
+ headers={"Content-Type": "application/json"},
157
+ json=data,
158
+ verify=False,
159
+ params={"request_id": "0"},
160
  )
161
  result = rsp.json()
162
  audio = result["outputs"][0]["data"]
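The JSON body built in prepare_request follows Triton's KServe-style HTTP inference schema; a minimal end-to-end sketch of posting such a request (the endpoint URL and model name are placeholders, and the waveform is a dummy buffer):

import numpy as np
import requests

samples = np.zeros((1, 24000), dtype=np.float32)      # dummy 1 s reference waveform
lengths = np.array([[samples.shape[1]]], dtype=np.int32)

data = {
    "inputs": [
        {"name": "reference_wav", "shape": list(samples.shape), "datatype": "FP32",
         "data": samples.tolist()},
        {"name": "reference_wav_len", "shape": list(lengths.shape), "datatype": "INT32",
         "data": lengths.tolist()},
        {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": ["reference text"]},
        {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": ["target text"]},
    ]
}
url = "http://localhost:8000/v2/models/f5_tts/infer"  # placeholder server address and model name
rsp = requests.post(url, headers={"Content-Type": "application/json"}, json=data,
                    params={"request_id": "0"})
audio = np.array(rsp.json()["outputs"][0]["data"], dtype=np.float32)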
f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py CHANGED
@@ -17,7 +17,9 @@ from tensorrt_llm.runtime.session import Session
17
  def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
18
  # Audio tensor case: batch, seq_len, feature_len
19
  # position_ids case: batch, seq_len
20
- assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
21
 
22
  # Initialize a list to collect valid sequences
23
  valid_sequences = []
@@ -32,11 +34,29 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
32
 
33
 
34
  class TextEmbedding(nn.Module):
35
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
 
36
  super().__init__()
37
- self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
38
- self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
39
- self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
40
 
41
  def forward(self, text):
42
  # only keep tensors with value not -1
@@ -80,7 +100,9 @@ class ConvNeXtV2Block(nn.Module):
80
  dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
81
  ) # depthwise conv
82
  self.norm = nn.LayerNorm(dim, eps=1e-6)
83
- self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
 
 
84
  self.act = nn.GELU()
85
  self.grn = GRN(intermediate_dim)
86
  self.pwconv2 = nn.Linear(intermediate_dim, dim)
@@ -98,7 +120,9 @@ class ConvNeXtV2Block(nn.Module):
98
  return residual + x
99
 
100
 
101
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
 
 
102
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
103
  # has some connection to NTK literature
104
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -125,7 +149,9 @@ def load_checkpoint(ckpt_path, use_ema=True):
125
  for key in dict_state.keys():
126
  # transformer.text_embed.text_embed.weight -> text_embed.weight
127
  if "text_embed" in key:
128
- text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
 
 
129
  return text_embed_dict
130
 
131
 
@@ -148,7 +174,12 @@ class F5TTS(object):
148
  pp_size = config["pretrained_config"]["mapping"]["pp_size"]
149
  assert pp_size == 1
150
  self.mapping = tensorrt_llm.Mapping(
151
- world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
 
152
  )
153
 
154
  local_rank = rank % self.mapping.gpus_per_node
@@ -176,10 +207,23 @@ class F5TTS(object):
176
  self.outputs = {}
177
  self.buffer_allocated = False
178
 
179
- expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]
180
-
181
- found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
182
- if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
 
183
  logger.error(
184
  f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
185
  )
@@ -190,11 +234,16 @@ class F5TTS(object):
190
  logger.error(f"Found tensor names: {found_tensor_names}")
191
  raise RuntimeError("Tensor names in engine are not the same as expected.")
192
  if self.debug_mode:
193
- self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))
 
 
194
 
195
  self.max_mel_len = 4096
196
  self.text_embedding = TextEmbedding(
197
- text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
 
198
  ).to(self.device)
199
  self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
200
 
@@ -208,9 +257,16 @@ class F5TTS(object):
208
  self.head_dim = 64
209
  self.base_rescale_factor = 1.0
210
  self.interpolation_factor = 1.0
211
- base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
212
- inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
213
- freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
 
214
  self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
215
  self.rope_cos = self.freqs.cos().half()
216
  self.rope_sin = self.freqs.sin().half()
@@ -223,7 +279,9 @@ class F5TTS(object):
223
  time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
224
  half_dim = tmp_dim // 2
225
  emb_factor = math.log(10000) / (half_dim - 1)
226
- emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
 
 
227
  for i in range(self.nfe_steps):
228
  emb = time_step[i] * emb_factor
229
  time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
@@ -242,7 +300,9 @@ class F5TTS(object):
242
  shape = list(self.session.engine.get_tensor_shape(name))
243
  shape[0] = batch_size
244
  shape[1] = seq_len
245
- self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device)
 
 
246
 
247
  self.buffer_allocated = True
248
 
@@ -356,17 +416,29 @@ class F5TTS(object):
356
  max_seq_len = ref_mel_batch.shape[1]
357
 
358
  text_pad_sequence_drop = torch.cat(
359
- (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
 
360
  )
361
 
362
  text_embedding_drop_list = []
363
  for i in range(batch + 1):
364
- text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
 
365
  text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
366
 
367
  text_embedding = text_embedding_drop_condition[:-1]
368
  # text_embedding_drop B,T,C batch should be the same
369
- text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
 
 
370
 
371
  noise = torch.randn_like(ref_mel_batch).to(self.device)
372
  rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
@@ -375,7 +447,9 @@ class F5TTS(object):
375
  cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
376
  cat_mel_text_drop = torch.cat(
377
  (
378
- torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
 
 
379
  text_embedding_drop,
380
  ),
381
  dim=-1,
@@ -384,7 +458,9 @@ class F5TTS(object):
384
  time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
385
 
386
  # Convert estimated_reference_target_mel_len to tensor
387
- input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)
 
 
388
 
389
  # combine above along the batch dimension
390
  inputs = {
@@ -393,20 +469,34 @@ class F5TTS(object):
393
  "time_expand": time_expand,
394
  "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
395
  "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
396
- "input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(),
 
 
397
  "delta_t": self.delta_t,
398
  }
399
  if use_perf and remove_input_padding:
400
  torch.cuda.nvtx.range_push("remove input padding")
401
  if remove_input_padding:
402
  max_seq_len = inputs["cond"].shape[1]
403
- inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
404
- inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
 
 
 
 
405
  # for time_expand, convert from B,D to B,T,D by repeat
406
- inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
407
- inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
408
- inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
409
- inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
 
 
 
 
 
 
 
 
410
  if use_perf and remove_input_padding:
411
  torch.cuda.nvtx.range_pop()
412
  for key in inputs:
@@ -422,7 +512,9 @@ class F5TTS(object):
422
  denoised_list = []
423
  start_idx = 0
424
  for i in range(batch):
425
- denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]])
 
 
426
  start_idx += inputs["input_lengths"][i]
427
  if use_perf and remove_input_padding:
428
  torch.cuda.nvtx.range_pop()
 
17
  def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
18
  # Audio tensor case: batch, seq_len, feature_len
19
  # position_ids case: batch, seq_len
20
+ assert (
21
+ input_tensor_lengths is not None
22
+ ), "input_tensor_lengths must be provided for 3D input_tensor"
23
 
24
  # Initialize a list to collect valid sequences
25
  valid_sequences = []
 
34
 
35
 
36
  class TextEmbedding(nn.Module):
37
+ def __init__(
38
+ self,
39
+ text_num_embeds,
40
+ text_dim,
41
+ conv_layers=0,
42
+ conv_mult=2,
43
+ precompute_max_pos=4096,
44
+ ):
45
  super().__init__()
46
+ self.text_embed = nn.Embedding(
47
+ text_num_embeds + 1, text_dim
48
+ ) # use 0 as filler token
49
+ self.register_buffer(
50
+ "freqs_cis",
51
+ precompute_freqs_cis(text_dim, precompute_max_pos),
52
+ persistent=False,
53
+ )
54
+ self.text_blocks = nn.Sequential(
55
+ *[
56
+ ConvNeXtV2Block(text_dim, text_dim * conv_mult)
57
+ for _ in range(conv_layers)
58
+ ]
59
+ )
60
 
61
  def forward(self, text):
62
  # only keep tensors with value not -1
 
100
  dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
101
  ) # depthwise conv
102
  self.norm = nn.LayerNorm(dim, eps=1e-6)
103
+ self.pwconv1 = nn.Linear(
104
+ dim, intermediate_dim
105
+ ) # pointwise/1x1 convs, implemented with linear layers
106
  self.act = nn.GELU()
107
  self.grn = GRN(intermediate_dim)
108
  self.pwconv2 = nn.Linear(intermediate_dim, dim)
 
120
  return residual + x
121
 
122
 
123
+ def precompute_freqs_cis(
124
+ dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
125
+ ):
126
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
127
  # has some connection to NTK literature
128
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
 
149
  for key in dict_state.keys():
150
  # transformer.text_embed.text_embed.weight -> text_embed.weight
151
  if "text_embed" in key:
152
+ text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[
153
+ key
154
+ ]
155
  return text_embed_dict
156
 
157
 
 
174
  pp_size = config["pretrained_config"]["mapping"]["pp_size"]
175
  assert pp_size == 1
176
  self.mapping = tensorrt_llm.Mapping(
177
+ world_size=world_size,
178
+ rank=rank,
179
+ cp_size=cp_size,
180
+ tp_size=tp_size,
181
+ pp_size=1,
182
+ gpus_per_node=1,
183
  )
184
 
185
  local_rank = rank % self.mapping.gpus_per_node
 
207
  self.outputs = {}
208
  self.buffer_allocated = False
209
 
210
+ expected_tensor_names = [
211
+ "noise",
212
+ "cond",
213
+ "time",
214
+ "rope_cos",
215
+ "rope_sin",
216
+ "input_lengths",
217
+ "denoised",
218
+ ]
219
+
220
+ found_tensor_names = [
221
+ self.session.engine.get_tensor_name(i)
222
+ for i in range(self.session.engine.num_io_tensors)
223
+ ]
224
+ if not self.debug_mode and set(expected_tensor_names) != set(
225
+ found_tensor_names
226
+ ):
227
  logger.error(
228
  f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
229
  )
 
234
  logger.error(f"Found tensor names: {found_tensor_names}")
235
  raise RuntimeError("Tensor names in engine are not the same as expected.")
236
  if self.debug_mode:
237
+ self.debug_tensors = list(
238
+ set(found_tensor_names) - set(expected_tensor_names)
239
+ )
240
 
241
  self.max_mel_len = 4096
242
  self.text_embedding = TextEmbedding(
243
+ text_num_embeds=vocab_size,
244
+ text_dim=512,
245
+ conv_layers=4,
246
+ precompute_max_pos=self.max_mel_len,
247
  ).to(self.device)
248
  self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
249
 
 
257
  self.head_dim = 64
258
  self.base_rescale_factor = 1.0
259
  self.interpolation_factor = 1.0
260
+ base = 10000.0 * self.base_rescale_factor ** (
261
+ self.head_dim / (self.head_dim - 2)
262
+ )
263
+ inv_freq = 1.0 / (
264
+ base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)
265
+ )
266
+ freqs = (
267
+ torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq)
268
+ / self.interpolation_factor
269
+ )
270
  self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
271
  self.rope_cos = self.freqs.cos().half()
272
  self.rope_sin = self.freqs.sin().half()
 
279
  time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
280
  half_dim = tmp_dim // 2
281
  emb_factor = math.log(10000) / (half_dim - 1)
282
+ emb_factor = 1000.0 * torch.exp(
283
+ torch.arange(half_dim, dtype=torch.float32) * -emb_factor
284
+ )
285
  for i in range(self.nfe_steps):
286
  emb = time_step[i] * emb_factor
287
  time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
 
300
  shape = list(self.session.engine.get_tensor_shape(name))
301
  shape[0] = batch_size
302
  shape[1] = seq_len
303
+ self.outputs[name] = torch.empty(
304
+ shape, dtype=self._tensor_dtype(name), device=self.device
305
+ )
306
 
307
  self.buffer_allocated = True
308
 
 
416
  max_seq_len = ref_mel_batch.shape[1]
417
 
418
  text_pad_sequence_drop = torch.cat(
419
+ (
420
+ text_pad_sequence,
421
+ torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(
422
+ self.device
423
+ ),
424
+ ),
425
+ dim=0,
426
  )
427
 
428
  text_embedding_drop_list = []
429
  for i in range(batch + 1):
430
+ text_embedding_drop_list.append(
431
+ self.text_embedding(
432
+ text_pad_sequence_drop[i].unsqueeze(0).to(self.device)
433
+ )
434
+ )
435
  text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
436
 
437
  text_embedding = text_embedding_drop_condition[:-1]
438
  # text_embedding_drop B,T,C batch should be the same
439
+ text_embedding_drop = (
440
+ text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
441
+ )
442
 
443
  noise = torch.randn_like(ref_mel_batch).to(self.device)
444
  rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
 
447
  cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
448
  cat_mel_text_drop = torch.cat(
449
  (
450
+ torch.zeros(
451
+ (batch, max_seq_len, self.n_mel_channels), dtype=torch.float32
452
+ ).to(self.device),
453
  text_embedding_drop,
454
  ),
455
  dim=-1,
 
458
  time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
459
 
460
  # Convert estimated_reference_target_mel_len to tensor
461
+ input_lengths = torch.tensor(
462
+ estimated_reference_target_mel_len, dtype=torch.int32
463
+ )
464
 
465
  # combine above along the batch dimension
466
  inputs = {
 
469
  "time_expand": time_expand,
470
  "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
471
  "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
472
+ "input_lengths": torch.cat(
473
+ (input_lengths, input_lengths), dim=0
474
+ ).contiguous(),
475
  "delta_t": self.delta_t,
476
  }
477
  if use_perf and remove_input_padding:
478
  torch.cuda.nvtx.range_push("remove input padding")
479
  if remove_input_padding:
480
  max_seq_len = inputs["cond"].shape[1]
481
+ inputs["noise"] = remove_tensor_padding(
482
+ inputs["noise"], inputs["input_lengths"]
483
+ )
484
+ inputs["cond"] = remove_tensor_padding(
485
+ inputs["cond"], inputs["input_lengths"]
486
+ )
487
  # for time_expand, convert from B,D to B,T,D by repeat
488
+ inputs["time_expand"] = (
489
+ inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
490
+ )
491
+ inputs["time_expand"] = remove_tensor_padding(
492
+ inputs["time_expand"], inputs["input_lengths"]
493
+ )
494
+ inputs["rope_cos"] = remove_tensor_padding(
495
+ inputs["rope_cos"], inputs["input_lengths"]
496
+ )
497
+ inputs["rope_sin"] = remove_tensor_padding(
498
+ inputs["rope_sin"], inputs["input_lengths"]
499
+ )
500
  if use_perf and remove_input_padding:
501
  torch.cuda.nvtx.range_pop()
502
  for key in inputs:
 
512
  denoised_list = []
513
  start_idx = 0
514
  for i in range(batch):
515
+ denoised_list.append(
516
+ denoised[start_idx : start_idx + inputs["input_lengths"][i]]
517
+ )
518
  start_idx += inputs["input_lengths"][i]
519
  if use_perf and remove_input_padding:
520
  torch.cuda.nvtx.range_pop()
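remove_input_padding packs the padded (B, T, C) batch into one flat sequence before the engine call, and the denoised output is later split back per item using input_lengths. A plain-PyTorch sketch of that round trip (shapes are illustrative):

import torch

def remove_tensor_padding(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Pack a padded (B, T, C) batch into a single (sum(lengths), C) tensor.
    return torch.cat([x[i, : lengths[i]] for i in range(x.shape[0])], dim=0)

x = torch.randn(3, 10, 4)
lengths = torch.tensor([10, 7, 3])
packed = remove_tensor_padding(x, lengths)         # (20, 4)

# Splitting a packed result back into per-item tensors, as done with `denoised` above.
pieces, start = [], 0
for n in lengths.tolist():
    pieces.append(packed[start : start + n])
    start += n
print(packed.shape, [p.shape[0] for p in pieces])  # torch.Size([20, 4]) [10, 7, 3]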
f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py CHANGED
@@ -73,7 +73,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
73
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
74
  char_list.append(" ")
75
  char_list.extend(seg)
76
- elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
77
  seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
78
  for i, c in enumerate(seg):
79
  if is_chinese(c):
@@ -85,7 +87,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
85
  char_list.extend(c)
86
  elif is_chinese(c):
87
  char_list.append(" ")
88
- char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
 
 
89
  else:
90
  char_list.append(c)
91
  final_reference_target_texts_list.append(char_list)
@@ -98,7 +102,9 @@ def list_str_to_idx(
98
  vocab_char_map: dict[str, int], # {char: idx}
99
  padding_value=-1,
100
  ): # noqa: F722
101
- list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
 
 
102
  return list_idx_tensors
103
 
104
 
@@ -121,7 +127,9 @@ class TritonPythonModel:
121
 
122
  self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
123
  self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
124
- self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)
 
 
125
 
126
  self.tllm_model_dir = parameters["tllm_model_dir"]
127
  config_file = os.path.join(self.tllm_model_dir, "config.json")
@@ -163,13 +171,17 @@ class TritonPythonModel:
163
  input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
164
 
165
  inference_request = pb_utils.InferenceRequest(
166
- model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
 
 
167
  )
168
  inference_response = inference_request.exec()
169
  if inference_response.has_error():
170
  raise pb_utils.TritonModelException(inference_response.error().message())
171
  else:
172
- waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
 
 
173
  waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
174
 
175
  return waveform
@@ -181,7 +193,13 @@ class TritonPythonModel:
181
  reference_target_texts_list,
182
  estimated_reference_target_mel_len,
183
  reference_mel_len,
184
- ) = [], [], [], [], []
 
185
  mel_features_list = []
186
  if self.use_perf:
187
  torch.cuda.nvtx.range_push("preprocess")
@@ -189,10 +207,14 @@ class TritonPythonModel:
189
  wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
190
  wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
191
 
192
- reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
 
 
193
  reference_text = reference_text[0][0].decode("utf-8")
194
  reference_text_list.append(reference_text)
195
- target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
 
 
196
  target_text = target_text[0][0].decode("utf-8")
197
  target_text_list.append(target_text)
198
 
@@ -221,30 +243,49 @@ class TritonPythonModel:
221
  reference_mel_len.append(mel_features.shape[1])
222
  estimated_reference_target_mel_len.append(
223
  int(
224
- mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8")))
 
225
  )
226
  )
227
 
228
  max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
229
 
230
  batch = len(requests)
231
- mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
 
 
232
  for i, mel in enumerate(mel_features_list):
233
  mel_features[i, : mel.shape[1], :] = mel
234
 
235
  reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)
236
 
237
- pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
 
 
238
  text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
239
 
240
  for i, item in enumerate(text_pad_sequence):
241
  text_pad_sequence[i] = F.pad(
242
- item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
 
 
 
243
  )
244
- text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
245
- text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
 
 
 
 
246
  text_pad_sequence = F.pad(
247
- text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
 
 
 
248
  )
249
  if self.use_perf:
250
  torch.cuda.nvtx.range_pop()
@@ -264,7 +305,11 @@ class TritonPythonModel:
264
  for i in range(batch):
265
  ref_me_len = reference_mel_len[i]
266
  estimated_mel_len = estimated_reference_target_mel_len[i]
267
- denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
 
268
  audio = self.forward_vocoder(denoised_one_item)
269
  rms = torch.sqrt(torch.mean(torch.square(audio)))
270
  if rms < self.target_rms:
 
73
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
74
  char_list.append(" ")
75
  char_list.extend(seg)
76
+ elif polyphone and seg_byte_len == 3 * len(
77
+ seg
78
+ ): # if pure east asian characters
79
  seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
80
  for i, c in enumerate(seg):
81
  if is_chinese(c):
 
87
  char_list.extend(c)
88
  elif is_chinese(c):
89
  char_list.append(" ")
90
+ char_list.extend(
91
+ lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
92
+ )
93
  else:
94
  char_list.append(c)
95
  final_reference_target_texts_list.append(char_list)
 
102
  vocab_char_map: dict[str, int], # {char: idx}
103
  padding_value=-1,
104
  ): # noqa: F722
105
+ list_idx_tensors = [
106
+ torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
107
+ ] # pinyin or char style
108
  return list_idx_tensors
109
 
110
 
 
127
 
128
  self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
129
  self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
130
+ self.resampler = torchaudio.transforms.Resample(
131
+ self.reference_sample_rate, self.target_audio_sample_rate
132
+ )
133
 
134
  self.tllm_model_dir = parameters["tllm_model_dir"]
135
  config_file = os.path.join(self.tllm_model_dir, "config.json")
 
171
  input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
172
 
173
  inference_request = pb_utils.InferenceRequest(
174
+ model_name="vocoder",
175
+ requested_output_names=["waveform"],
176
+ inputs=[input_tensor_0],
177
  )
178
  inference_response = inference_request.exec()
179
  if inference_response.has_error():
180
  raise pb_utils.TritonModelException(inference_response.error().message())
181
  else:
182
+ waveform = pb_utils.get_output_tensor_by_name(
183
+ inference_response, "waveform"
184
+ )
185
  waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
186
 
187
  return waveform
 
193
  reference_target_texts_list,
194
  estimated_reference_target_mel_len,
195
  reference_mel_len,
196
+ ) = (
197
+ [],
198
+ [],
199
+ [],
200
+ [],
201
+ [],
202
+ )
203
  mel_features_list = []
204
  if self.use_perf:
205
  torch.cuda.nvtx.range_push("preprocess")
 
207
  wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
208
  wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
209
 
210
+ reference_text = pb_utils.get_input_tensor_by_name(
211
+ request, "reference_text"
212
+ ).as_numpy()
213
  reference_text = reference_text[0][0].decode("utf-8")
214
  reference_text_list.append(reference_text)
215
+ target_text = pb_utils.get_input_tensor_by_name(
216
+ request, "target_text"
217
+ ).as_numpy()
218
  target_text = target_text[0][0].decode("utf-8")
219
  target_text_list.append(target_text)
220
 
 
243
  reference_mel_len.append(mel_features.shape[1])
244
  estimated_reference_target_mel_len.append(
245
  int(
246
+ mel_features.shape[1]
247
+ * (
248
+ 1
249
+ + len(target_text.encode("utf-8"))
250
+ / len(reference_text.encode("utf-8"))
251
+ )
252
  )
253
  )
254
 
255
  max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
256
 
257
  batch = len(requests)
258
+ mel_features = torch.zeros(
259
+ (batch, max_seq_len, self.n_mel_channels), dtype=torch.float16
260
+ ).to(self.device)
261
  for i, mel in enumerate(mel_features_list):
262
  mel_features[i, : mel.shape[1], :] = mel
263
 
264
  reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)
265
 
266
+ pinyin_list = convert_char_to_pinyin(
267
+ reference_target_texts_list, polyphone=True
268
+ )
269
  text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
270
 
271
  for i, item in enumerate(text_pad_sequence):
272
  text_pad_sequence[i] = F.pad(
273
+ item,
274
+ (0, estimated_reference_target_mel_len[i] - len(item)),
275
+ mode="constant",
276
+ value=-1,
277
  )
278
+ text_pad_sequence[
279
+ i
280
+ ] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
281
+ text_pad_sequence = pad_sequence(
282
+ text_pad_sequence, padding_value=-1, batch_first=True
283
+ ).to(self.device)
284
  text_pad_sequence = F.pad(
285
+ text_pad_sequence,
286
+ (0, max_seq_len - text_pad_sequence.shape[1]),
287
+ mode="constant",
288
+ value=-1,
289
  )
290
  if self.use_perf:
291
  torch.cuda.nvtx.range_pop()
 
305
  for i in range(batch):
306
  ref_me_len = reference_mel_len[i]
307
  estimated_mel_len = estimated_reference_target_mel_len[i]
308
+ denoised_one_item = (
309
+ denoised[i, ref_me_len:estimated_mel_len, :]
310
+ .unsqueeze(0)
311
+ .transpose(1, 2)
312
+ )
313
  audio = self.forward_vocoder(denoised_one_item)
314
  rms = torch.sqrt(torch.mean(torch.square(audio)))
315
  if rms < self.target_rms:
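The token preparation above relies on two conventions: every vocabulary index is shifted by one because index 0 is reserved as the filler/padding token inside the engine, and -1 marks padded positions. A small self-contained sketch of that indexing convention (toy vocab, and not the exact call order used in execute()):

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

vocab_char_map = {"a": 0, "b": 1, "c": 2}  # toy vocab for illustration
texts = [["a", "b"], ["c", "a", "b", "b"]]

idx_tensors = [torch.tensor([vocab_char_map.get(ch, 0) for ch in t]) for t in texts]
idx_tensors = [t + 1 for t in idx_tensors]  # WAR: shift so index 0 stays free for the filler token
batch = pad_sequence(idx_tensors, padding_value=-1, batch_first=True)
batch = F.pad(batch, (0, 8 - batch.shape[1]), mode="constant", value=-1)  # pad to a common max length
print(batch)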
f5_tts/runtime/triton_trtllm/patch/__init__.py CHANGED
@@ -13,14 +13,10 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  from .baichuan.model import BaichuanForCausalLM
16
- from .bert.model import (
17
- BertForQuestionAnswering,
18
- BertForSequenceClassification,
19
- BertModel,
20
- RobertaForQuestionAnswering,
21
- RobertaForSequenceClassification,
22
- RobertaModel,
23
- )
24
  from .bloom.model import BloomForCausalLM, BloomModel
25
  from .chatglm.config import ChatGLMConfig
26
  from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
@@ -51,17 +47,17 @@ from .mamba.model import MambaForCausalLM
51
  from .medusa.config import MedusaConfig
52
  from .medusa.model import MedusaForCausalLm
53
  from .mllama.model import MLLaMAModel
54
- from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
 
55
  from .mpt.model import MPTForCausalLM, MPTModel
56
  from .nemotron_nas.model import DeciLMForCausalLM
57
  from .opt.model import OPTForCausalLM, OPTModel
58
- from .phi.model import PhiForCausalLM, PhiModel
59
  from .phi3.model import Phi3ForCausalLM, Phi3Model
 
60
  from .qwen.model import QWenForCausalLM
61
  from .recurrentgemma.model import RecurrentGemmaForCausalLM
62
  from .redrafter.model import ReDrafterForCausalLM
63
 
64
-
65
  __all__ = [
66
  "BertModel",
67
  "BertForQuestionAnswering",
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  from .baichuan.model import BaichuanForCausalLM
16
+ from .bert.model import (BertForQuestionAnswering,
+                          BertForSequenceClassification, BertModel,
+                          RobertaForQuestionAnswering,
+                          RobertaForSequenceClassification, RobertaModel)
20
  from .bloom.model import BloomForCausalLM, BloomModel
21
  from .chatglm.config import ChatGLMConfig
22
  from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
 
47
  from .medusa.config import MedusaConfig
48
  from .medusa.model import MedusaForCausalLm
49
  from .mllama.model import MLLaMAModel
50
+ from .modeling_utils import (PretrainedConfig, PretrainedModel,
+                              SpeculativeDecodingMode)
52
  from .mpt.model import MPTForCausalLM, MPTModel
53
  from .nemotron_nas.model import DeciLMForCausalLM
54
  from .opt.model import OPTForCausalLM, OPTModel
 
55
  from .phi3.model import Phi3ForCausalLM, Phi3Model
56
+ from .phi.model import PhiForCausalLM, PhiModel
57
  from .qwen.model import QWenForCausalLM
58
  from .recurrentgemma.model import RecurrentGemmaForCausalLM
59
  from .redrafter.model import ReDrafterForCausalLM
60
 
 
61
  __all__ = [
62
  "BertModel",
63
  "BertForQuestionAnswering",
f5_tts/runtime/triton_trtllm/patch/f5tts/model.py CHANGED
@@ -13,8 +13,8 @@ from ...layers import Linear
13
  from ...module import Module, ModuleList
14
  from ...plugin import current_all_reduce_helper
15
  from ..modeling_utils import PretrainedConfig, PretrainedModel
16
- from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
17
-
18
 
19
  current_file_path = os.path.abspath(__file__)
20
  parent_dir = os.path.dirname(current_file_path)
@@ -38,7 +38,9 @@ class F5TTS(PretrainedModel):
38
  self.dtype = str_dtype_to_trt(config.dtype)
39
 
40
  self.time_embed = TimestepEmbedding(config.hidden_size)
41
- self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
 
 
42
 
43
  self.dim = config.hidden_size
44
  self.depth = config.num_hidden_layers
@@ -71,7 +73,14 @@ class F5TTS(PretrainedModel):
71
  t = self.time_embed(time)
72
  x = self.input_embed(noise, cond)
73
  for block in self.transformer_blocks:
74
- x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
 
 
 
 
 
 
 
75
  denoise = self.proj_out(self.norm_out(x, t))
76
  denoise.mark_output("denoised", self.dtype)
77
  return denoise
 
13
  from ...module import Module, ModuleList
14
  from ...plugin import current_all_reduce_helper
15
  from ..modeling_utils import PretrainedConfig, PretrainedModel
16
+ from .modules import (AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock,
17
+ TimestepEmbedding)
18
 
19
  current_file_path = os.path.abspath(__file__)
20
  parent_dir = os.path.dirname(current_file_path)
 
38
  self.dtype = str_dtype_to_trt(config.dtype)
39
 
40
  self.time_embed = TimestepEmbedding(config.hidden_size)
41
+ self.input_embed = InputEmbedding(
42
+ config.mel_dim, config.text_dim, config.hidden_size
43
+ )
44
 
45
  self.dim = config.hidden_size
46
  self.depth = config.num_hidden_layers
 
73
  t = self.time_embed(time)
74
  x = self.input_embed(noise, cond)
75
  for block in self.transformer_blocks:
76
+ x = block(
77
+ x,
78
+ t,
79
+ rope_cos=rope_cos,
80
+ rope_sin=rope_sin,
81
+ input_lengths=input_lengths,
82
+ scale=scale,
83
+ )
84
  denoise = self.proj_out(self.norm_out(x, t))
85
  denoise.mark_output("denoised", self.dtype)
86
  return denoise
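Conceptually, the patched F5TTS network is a stack of DiT blocks between an input projection and a final norm + linear head. A much-simplified plain-PyTorch skeleton of that forward pass (layer sizes are arbitrary, standard attention stands in for the real blocks, and the real model also consumes rotary tables and input lengths):

import torch
import torch.nn as nn

class TinyDiT(nn.Module):
    def __init__(self, mel_dim=100, text_dim=512, hidden=256, depth=4, heads=4):
        super().__init__()
        self.time_embed = nn.Sequential(nn.Linear(256, hidden), nn.SiLU(), nn.Linear(hidden, hidden))
        self.input_embed = nn.Linear(mel_dim * 2 + text_dim, hidden)  # noise + cond (mel + text embedding)
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(hidden, heads, batch_first=True) for _ in range(depth)
        )
        self.norm_out = nn.LayerNorm(hidden)
        self.proj_out = nn.Linear(hidden, mel_dim)

    def forward(self, noise, cond, time_emb):
        t = self.time_embed(time_emb)                       # (B, hidden)
        x = self.input_embed(torch.cat([noise, cond], -1))  # (B, T, hidden)
        x = x + t.unsqueeze(1)                              # crude stand-in for adaLN conditioning
        for block in self.blocks:
            x = block(x)
        return self.proj_out(self.norm_out(x))              # denoised mel prediction

model = TinyDiT()
out = model(torch.randn(2, 50, 100), torch.randn(2, 50, 612), torch.randn(2, 256))
print(out.shape)  # torch.Size([2, 50, 100])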
f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py CHANGED
@@ -9,28 +9,10 @@ import torch.nn.functional as F
9
  from tensorrt_llm._common import default_net
10
 
11
  from ..._utils import str_dtype_to_trt, trt_dtype_to_np
12
- from ...functional import (
13
- Tensor,
14
- bert_attention,
15
- cast,
16
- chunk,
17
- concat,
18
- constant,
19
- expand,
20
- expand_dims,
21
- expand_dims_like,
22
- expand_mask,
23
- gelu,
24
- matmul,
25
- permute,
26
- shape,
27
- silu,
28
- slice,
29
- softmax,
30
- squeeze,
31
- unsqueeze,
32
- view,
33
- )
34
  from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
35
  from ...module import Module
36
 
@@ -57,7 +39,9 @@ class AdaLayerNormZero(Module):
57
 
58
  def forward(self, x, emb=None):
59
  emb = self.linear(silu(emb))
60
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
 
 
61
  x = self.norm(x)
62
  ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
63
  if default_net().plugin_config.remove_input_padding:
@@ -91,8 +75,12 @@ class ConvPositionEmbedding(Module):
91
  def __init__(self, dim, kernel_size=31, groups=16):
92
  super().__init__()
93
  assert kernel_size % 2 != 0
94
- self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
95
- self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
 
 
 
 
96
  self.mish = Mish()
97
 
98
  def forward(self, x, mask=None): # noqa: F722
@@ -120,7 +108,9 @@ class Attention(Module):
120
  super().__init__()
121
 
122
  if not hasattr(F, "scaled_dot_product_attention"):
123
- raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
 
124
 
125
  self.processor = processor
126
 
@@ -191,16 +181,32 @@ class Attention(Module):
191
  c_rope=None, # rotary position embedding for c
192
  ) -> torch.Tensor:
193
  if c is not None:
194
- return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
 
195
  else:
196
  return self.processor(
197
- self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
 
198
  )
199
 
200
 
201
  def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
202
  shape_tensor = concat(
203
- [shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
 
 
 
204
  )
205
  if default_net().plugin_config.remove_input_padding:
206
  assert tensor.ndim() == 2
@@ -208,7 +214,9 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
208
  x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
209
  x1 = expand_dims(x1, 2)
210
  x2 = expand_dims(x2, 2)
211
- zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
 
 
212
  x2 = zero - x2
213
  x = concat([x2, x1], 2)
214
  out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
@@ -219,7 +227,9 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
219
  x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
220
  x1 = expand_dims(x1, 3)
221
  x2 = expand_dims(x2, 3)
222
- zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
 
 
223
  x2 = zero - x2
224
  x = concat([x2, x1], 3)
225
  out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
@@ -235,15 +245,23 @@ def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
235
  end_dim = shape(x, -1) - shape(rope_cos, -1)
236
  new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
237
  x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
238
- out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
 
 
239
  else:
240
  rot_dim = shape(rope_cos, 2) # 64
241
  new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
242
  x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
243
  end_dim = shape(x, 2) - shape(rope_cos, 2)
244
- new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
245
- x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
246
- out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
 
 
 
 
 
 
247
  return out
248
 
249
 
@@ -279,8 +297,12 @@ class AttnProcessor:
279
  seq_len_2d = concat([1, N])
280
  max_position_embeddings = 4096
281
  # create position ids
282
- position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
283
- tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
 
 
 
 
284
  tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
285
  tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
286
  tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
@@ -315,14 +337,28 @@ class AttnProcessor:
315
  assert not default_net().plugin_config.remove_input_padding
316
 
317
  def transpose_for_scores(x):
318
- new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
 
319
 
320
  y = x.view(new_x_shape)
321
  y = y.transpose(1, 2)
322
  return y
323
 
324
  def transpose_for_scores_k(x):
325
- new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
 
 
 
 
 
 
 
326
 
327
  y = x.view(new_x_shape)
328
  y = y.permute([0, 2, 3, 1])
@@ -342,7 +378,11 @@ class AttnProcessor:
342
  attention_probs = softmax(attention_scores, dim=-1)
343
 
344
  context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
345
- context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
 
 
 
 
346
  context = attn.to_out(context)
347
  if mask is not None:
348
  mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
@@ -370,13 +410,26 @@ class DiTBlock(Module):
370
  self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
371
 
372
  def forward(
373
- self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError
 
374
  ): # x: noised input, t: time embedding
375
  # pre-norm & modulation for attention input
376
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
377
  # attention
378
  # norm ----> (2,1226,1024)
379
- attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
 
380
 
381
  # process attention output for input x
382
  if default_net().plugin_config.remove_input_padding:
@@ -387,7 +440,9 @@ class DiTBlock(Module):
387
  if default_net().plugin_config.remove_input_padding:
388
  norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
389
  else:
390
- norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
 
 
391
  # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
392
  ff_output = self.ff(norm)
393
  if default_net().plugin_config.remove_input_padding:
 
9
  from tensorrt_llm._common import default_net
10
 
11
  from ..._utils import str_dtype_to_trt, trt_dtype_to_np
12
+ from ...functional import (Tensor, bert_attention, cast, chunk, concat,
13
+ constant, expand, expand_dims, expand_dims_like,
14
+ expand_mask, gelu, matmul, permute, shape, silu,
15
+ slice, softmax, squeeze, unsqueeze, view)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
17
  from ...module import Module
18
 
 
39
 
40
  def forward(self, x, emb=None):
41
  emb = self.linear(silu(emb))
42
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(
43
+ emb, 6, dim=1
44
+ )
45
  x = self.norm(x)
46
  ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
47
  if default_net().plugin_config.remove_input_padding:
 
75
  def __init__(self, dim, kernel_size=31, groups=16):
76
  super().__init__()
77
  assert kernel_size % 2 != 0
78
+ self.conv1d1 = Conv1d(
79
+ dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
80
+ )
81
+ self.conv1d2 = Conv1d(
82
+ dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
83
+ )
84
  self.mish = Mish()
85
 
86
  def forward(self, x, mask=None): # noqa: F722
 
108
  super().__init__()
109
 
110
  if not hasattr(F, "scaled_dot_product_attention"):
111
+ raise ImportError(
112
+ "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
113
+ )
114
 
115
  self.processor = processor
116
 
 
181
  c_rope=None, # rotary position embedding for c
182
  ) -> torch.Tensor:
183
  if c is not None:
184
+ return self.processor(
185
+ self,
186
+ x,
187
+ c=c,
188
+ input_lengths=input_lengths,
189
+ scale=scale,
190
+ rope=rope,
191
+ c_rope=c_rope,
192
+ )
193
  else:
194
  return self.processor(
195
+ self,
196
+ x,
197
+ rope_cos=rope_cos,
198
+ rope_sin=rope_sin,
199
+ input_lengths=input_lengths,
200
+ scale=scale,
201
  )
202
 
203
 
204
  def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
205
  shape_tensor = concat(
206
+ [
207
+ shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i)
208
+ for i in range(tensor.ndim())
209
+ ]
210
  )
211
  if default_net().plugin_config.remove_input_padding:
212
  assert tensor.ndim() == 2
 
214
  x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
215
  x1 = expand_dims(x1, 2)
216
  x2 = expand_dims(x2, 2)
217
+ zero = constant(
218
+ np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))
219
+ )
220
  x2 = zero - x2
221
  x = concat([x2, x1], 2)
222
  out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
 
227
  x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
228
  x1 = expand_dims(x1, 3)
229
  x2 = expand_dims(x2, 3)
+ zero = constant(
+ np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))
+ )
  x2 = zero - x2
  x = concat([x2, x1], 3)
  out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))

  end_dim = shape(x, -1) - shape(rope_cos, -1)
  new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
  x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
+ out = concat(
+ [x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1
+ )
  else:
  rot_dim = shape(rope_cos, 2) # 64
  new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
  x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
  end_dim = shape(x, 2) - shape(rope_cos, 2)
+ new_t_unrotated_shape = concat(
+ [shape(x, 0), shape(x, 1), end_dim]
+ ) # (2, -1, 960)
+ x_unrotated = slice(
+ x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1]
+ )
+ out = concat(
+ [x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1
+ )
  return out

  seq_len_2d = concat([1, N])
  max_position_embeddings = 4096
  # create position ids
+ position_ids_buffer = constant(
+ np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0)
+ )
+ tmp_position_ids = slice(
+ position_ids_buffer, starts=[0, 0], sizes=seq_len_2d
+ )
  tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
  tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
  tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL

  assert not default_net().plugin_config.remove_input_padding

  def transpose_for_scores(x):
+ new_x_shape = concat(
+ [
+ shape(x, 0),
+ shape(x, 1),
+ attn.num_attention_heads,
+ attn.attention_head_size,
+ ]
+ )

  y = x.view(new_x_shape)
  y = y.transpose(1, 2)
  return y

  def transpose_for_scores_k(x):
+ new_x_shape = concat(
+ [
+ shape(x, 0),
+ shape(x, 1),
+ attn.num_attention_heads,
+ attn.attention_head_size,
+ ]
+ )

  y = x.view(new_x_shape)
  y = y.permute([0, 2, 3, 1])

  attention_probs = softmax(attention_scores, dim=-1)

  context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
+ context = context.view(
+ concat(
+ [shape(context, 0), shape(context, 1), attn.attention_hidden_size]
+ )
+ )
  context = attn.to_out(context)
  if mask is not None:
  mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))

  self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)

  def forward(
+ self,
+ x,
+ t,
+ rope_cos,
+ rope_sin,
+ input_lengths,
+ scale=1.0,
+ rope=ModuleNotFoundError,
  ): # x: noised input, t: time embedding
  # pre-norm & modulation for attention input
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
  # attention
  # norm ----> (2,1226,1024)
+ attn_output = self.attn(
+ x=norm,
+ rope_cos=rope_cos,
+ rope_sin=rope_sin,
+ input_lengths=input_lengths,
+ scale=scale,
+ )

  # process attention output for input x
  if default_net().plugin_config.remove_input_padding:

  if default_net().plugin_config.remove_input_padding:
  norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
  else:
+ norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(
+ shift_mlp, 1
+ )
  # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
  ff_output = self.ff(norm)
  if default_net().plugin_config.remove_input_padding:
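The hunks above only re-wrap the TensorRT-LLM graph ops (constant/slice/concat) that apply rotary position embedding via the "rotate every two" trick. For readers who want the math without the graph API, here is a minimal PyTorch sketch of the same operation; it is illustrative only, with names and a (batch, seq, dim) layout that are assumptions, not the patched code:

import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # Pair adjacent channels (x1, x2) and map them to (-x2, x1), which is what the
    # expand_dims / (zero - x2) / concat / view sequence above builds in the graph.
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rope(x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor) -> torch.Tensor:
    # Only the first rot_dim channels are rotated; the rest pass through unchanged,
    # mirroring the x_ / x_unrotated split in the hunk above.
    rot_dim = rope_cos.shape[-1]
    x_rot, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
    rotated = x_rot * rope_cos + rotate_every_two(x_rot) * rope_sin
    return torch.cat([rotated, x_unrotated], dim=-1)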
f5_tts/runtime/triton_trtllm/scripts/conv_stft.py CHANGED
@@ -40,7 +40,6 @@ import torch as th
  import torch.nn.functional as F
  from scipy.signal import check_COLA, get_window

-
  support_clp_op = None
  if th.__version__ >= "1.7.0":
  from torch.fft import rfft as fft
@@ -124,7 +123,9 @@ class STFT(th.nn.Module):
  ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
  window = get_window(self.win_type, self.win_len)

- self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
+ self.perfect_reconstruct = check_COLA(
+ window, self.win_len, self.win_len - self.win_hop
+ )
  window = th.FloatTensor(window)
  if self.mode == "continue":
  left_pad = (self.fft_len - self.win_len) // 2
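The wrapped check_COLA call above is what sets self.perfect_reconstruct: scipy's check_COLA(window, nperseg, noverlap) verifies the constant-overlap-add condition under which the overlap-added analysis windows sum to a constant, so the inverse STFT reconstructs the signal exactly. A standalone sketch, with example window and hop values that are assumptions:

from scipy.signal import check_COLA, get_window

win_len, win_hop = 1024, 256                       # assumed example values
window = get_window("hann", win_len)
perfect = check_COLA(window, win_len, win_len - win_hop)
print(f"COLA satisfied for hann / win {win_len} / hop {win_hop}: {perfect}")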
f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py CHANGED
@@ -179,19 +179,47 @@ def parse_arguments():
  ) # TODO: support F5TTS_v1_Base
  parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
  parser.add_argument(
- "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
+ "--output_dir",
+ type=str,
+ default="./tllm_checkpoint",
+ help="The path to save the TensorRT-LLM checkpoint",
+ )
+ parser.add_argument(
+ "--hidden_size", type=int, default=1024, help="The hidden size of DiT"
+ )
+ parser.add_argument(
+ "--depth", type=int, default=22, help="The number of DiTBlock layers"
+ )
+ parser.add_argument(
+ "--num_heads",
+ type=int,
+ default=16,
+ help="The number of heads of attention module",
  )
- parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
- parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
- parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
  parser.add_argument("--cfg_scale", type=float, default=4.0)
- parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
- parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
- parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
- parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
- parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
  parser.add_argument(
- "--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
+ "--tp_size", type=int, default=1, help="N-way tensor parallelism size"
+ )
+ parser.add_argument(
+ "--cp_size", type=int, default=1, help="Context parallelism size"
+ )
+ parser.add_argument(
+ "--pp_size", type=int, default=1, help="N-way pipeline parallelism size"
+ )
+ parser.add_argument(
+ "--dtype",
+ type=str,
+ default="float16",
+ choices=["float32", "bfloat16", "float16"],
+ )
+ parser.add_argument(
+ "--fp8_linear", action="store_true", help="Whether use FP8 for linear layers"
+ )
+ parser.add_argument(
+ "--workers",
+ type=int,
+ default=1,
+ help="The number of workers for converting checkpoint in parallel",
  )
  args = parser.parse_args()
  return args
@@ -205,10 +233,15 @@ def convert_timm_dit(args, mapping, dtype="float32"):

  model_params = dict(torch.load(args.timm_ckpt))
  model_params = {
- k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
+ k: v
+ for k, v in model_params["ema_model_state_dict"].items()
+ if k.startswith("ema_model.transformer")
  }
  prefix = "ema_model.transformer."
- model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
+ model_params = {
+ key[len(prefix) :] if key.startswith(prefix) else key: value
+ for key, value in model_params.items()
+ }

  timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING

@@ -223,8 +256,13 @@ def convert_timm_dit(args, mapping, dtype="float32"):

  weights = dict()
  for name, param in model_params.items():
- if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight":
- weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1)
+ if (
+ name == "input_embed.conv_pos_embed.conv1d.0.weight"
+ or name == "input_embed.conv_pos_embed.conv1d.2.weight"
+ ):
+ weights[get_trtllm_name(name)] = (
+ param.contiguous().to(torch_dtype).unsqueeze(-1)
+ )
  else:
  weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)

@@ -239,25 +277,37 @@ def convert_timm_dit(args, mapping, dtype="float32"):
  for k, v in weights.items():
  if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
  weights[k] *= scale_factor
- weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
+ weights[k] = split_q_tp(
+ v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+ )

  elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
  weights[k] *= scale_factor
- weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
+ weights[k] = split_q_bias_tp(
+ v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+ )

  elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
- weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
+ weights[k] = split_q_tp(
+ v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+ )
  weights[k] *= scale_factor

  elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
- weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
+ weights[k] = split_q_bias_tp(
+ v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+ )
  weights[k] *= scale_factor

  elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
- weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
+ weights[k] = split_q_tp(
+ v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+ )

  elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
- weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
+ weights[k] = split_q_bias_tp(
+ v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank
+ )

  elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
  weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)
@@ -317,7 +367,9 @@ def covert_and_save(args, rank):

  weights = convert_timm_dit(args, mapping, dtype=args.dtype)

- safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
+ safetensors.torch.save_file(
+ weights, os.path.join(args.output_dir, f"rank{rank}.safetensors")
+ )


  def execute(workers, func, args):
@@ -334,7 +386,9 @@ def execute(workers, func, args):
  except Exception as e:
  traceback.print_exc()
  exceptions.append(e)
- assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log."
+ assert (
+ len(exceptions) == 0
+ ), "Checkpoint conversion failed, please check error log."


  def main():
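split_q_tp and split_q_bias_tp (defined elsewhere in convert_checkpoint.py) shard the attention projections across tensor-parallel ranks head by head. A hedged sketch of what such a head-wise split typically looks like; the real helpers may differ in detail, and the (hidden_size, hidden_size) weight layout is an assumption:

import torch

def split_heads_tp(w: torch.Tensor, num_heads: int, hidden_size: int,
                   tp_size: int, tp_rank: int) -> torch.Tensor:
    # Group output rows by attention head, give each rank a contiguous slice of
    # heads, then flatten back to a 2D weight for that rank.
    head_dim = hidden_size // num_heads
    w = w.reshape(num_heads, head_dim, hidden_size)
    local = w.chunk(tp_size, dim=0)[tp_rank]
    return local.reshape(-1, hidden_size).contiguous()

# e.g. split_heads_tp(q_weight, num_heads=16, hidden_size=1024, tp_size=2, tp_rank=0)
# keeps 8 of the 16 heads, giving a (512, 1024) shard for rank 0.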
f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py CHANGED
@@ -20,12 +20,13 @@ from conv_stft import STFT
  from huggingface_hub import hf_hub_download
  from vocos import Vocos

-
  opset_version = 17


  def get_args():
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
  parser.add_argument(
  "--vocoder",
  type=str,
@@ -108,7 +109,9 @@ def export_VocosVocoder(vocos_vocoder, output_path, verbose):
  print("Exported to {}".format(output_path))


- def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
+ def load_vocoder(
+ vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None
+ ):
  if vocoder_name == "vocos":
  # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
  if is_local:
@@ -118,8 +121,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cp
  else:
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
  repo_id = "charactr/vocos-mel-24khz"
- config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
- model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
+ config_path = hf_hub_download(
+ repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml"
+ )
+ model_path = hf_hub_download(
+ repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin"
+ )
  vocoder = Vocos.from_hparams(config_path)
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
  vocoder.load_state_dict(state_dict)
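For context, the reformatted load_vocoder above ends up returning a Vocos model whose decode() maps mel spectrograms to waveforms. A minimal usage sketch, assuming the script is importable as a module and using a dummy 100-bin mel (both assumptions):

import torch
from export_vocoder_to_onnx import load_vocoder

vocoder = load_vocoder(vocoder_name="vocos", is_local=False, device="cpu")
mel = torch.randn(1, 100, 256)   # (batch, n_mels, frames); sizes are hypothetical
with torch.no_grad():
    audio = vocoder.decode(mel)  # Vocos decodes mel features into a waveform
print(audio.shape)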
f5_tts/runtime/triton_trtllm/scripts/fill_template.py CHANGED
@@ -29,8 +29,12 @@ if __name__ == "__main__":
  "substitutions",
  help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
  )
- parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
- parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
+ parser.add_argument(
+ "--in_place", "-i", action="store_true", help="do the operation in-place"
+ )
+ parser.add_argument(
+ "--participant_ids", help="Participant IDs for the model", default=""
+ )
  args = parser.parse_args()

  main(**vars(args))
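The substitutions argument shown above is a comma-separated list of name:value pairs. A hypothetical re-implementation of that parsing for clarity (not fill_template.py itself; the ${...} placeholder style and the variable names are assumptions):

from string import Template

substitutions = "triton_max_batch_size:8,model_dir:/models/f5_tts"  # hypothetical values
pairs = dict(item.split(":", 1) for item in substitutions.split(","))
template = Template("max_batch_size: ${triton_max_batch_size}\nmodel_dir: ${model_dir}")
print(template.safe_substitute(pairs))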
f5_tts/scripts/count_max_epoch.py CHANGED
@@ -24,10 +24,14 @@ updates_per_epoch = total_hours / mini_batch_hours

  # result
  epochs = wanted_max_updates / updates_per_epoch
- print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
+ print(
+ f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})"
+ )
  print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
  # print(f" or approx. 0/{steps_per_epoch:.0f} steps")

  # others
  print(f"total {total_hours:.0f} hours")
- print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
+ print(
+ f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch"
+ )
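As a quick sanity check of the arithmetic these prints report, a worked example with assumed numbers (the real wanted_max_updates, grad_accum, total_hours, and mini_batch_hours are set near the top of count_max_epoch.py):

wanted_max_updates = 1_000_000   # assumed target update count
grad_accum = 1                   # assumed gradient accumulation factor
total_hours = 950.0              # assumed total audio hours in the dataset
mini_batch_hours = 0.11          # assumed hours of audio per mini-batch

updates_per_epoch = total_hours / mini_batch_hours   # ~8636 updates per epoch
epochs = wanted_max_updates / updates_per_epoch      # ~116 epochs
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")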