alex committed · Commit c20d00d · 1 Parent(s): a8a56b0

do faster CPU based tts

Files changed (36)
  1. app.py +32 -50
  2. higgs_audio/__init__.py +0 -1
  3. higgs_audio/audio_processing/LICENSE +0 -51
  4. higgs_audio/audio_processing/descriptaudiocodec/__init__.py +0 -0
  5. higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py +0 -286
  6. higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py +0 -365
  7. higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py +0 -33
  8. higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py +0 -251
  9. higgs_audio/audio_processing/higgs_audio_tokenizer.py +0 -341
  10. higgs_audio/audio_processing/quantization/__init__.py +0 -8
  11. higgs_audio/audio_processing/quantization/ac.py +0 -301
  12. higgs_audio/audio_processing/quantization/core_vq.py +0 -360
  13. higgs_audio/audio_processing/quantization/core_vq_lsx_version.py +0 -431
  14. higgs_audio/audio_processing/quantization/ddp_utils.py +0 -197
  15. higgs_audio/audio_processing/quantization/distrib.py +0 -123
  16. higgs_audio/audio_processing/quantization/vq.py +0 -116
  17. higgs_audio/audio_processing/semantic_module.py +0 -310
  18. higgs_audio/constants.py +0 -3
  19. higgs_audio/data_collator/__init__.py +0 -0
  20. higgs_audio/data_collator/higgs_audio_collator.py +0 -583
  21. higgs_audio/data_types.py +0 -38
  22. higgs_audio/dataset/__init__.py +0 -0
  23. higgs_audio/dataset/chatml_dataset.py +0 -554
  24. higgs_audio/model/__init__.py +0 -9
  25. higgs_audio/model/audio_head.py +0 -139
  26. higgs_audio/model/common.py +0 -27
  27. higgs_audio/model/configuration_higgs_audio.py +0 -235
  28. higgs_audio/model/cuda_graph_runner.py +0 -129
  29. higgs_audio/model/custom_modules.py +0 -155
  30. higgs_audio/model/modeling_higgs_audio.py +0 -0
  31. higgs_audio/model/utils.py +0 -778
  32. higgs_audio/serve/serve_engine.py +0 -474
  33. higgs_audio/serve/utils.py +0 -254
  34. higgs_audio_utils.py +0 -280
  35. requirements.txt +2 -13
  36. supertonic.py +364 -0
app.py CHANGED
@@ -16,40 +16,11 @@ from tqdm import tqdm
 import importlib, site, sys
 from huggingface_hub import hf_hub_download, snapshot_download
 
-# Re-discover all .pth/.egg-link files
-for sitedir in site.getsitepackages():
-    site.addsitedir(sitedir)
-
-# Clear caches so importlib will pick up new modules
-importlib.invalidate_caches()
-
 def sh(cmd): subprocess.check_call(cmd, shell=True)
 
-flash_attention_installed = False
-
-try:
-    print("Attempting to download and install FlashAttention wheel...")
-    flash_attention_wheel = hf_hub_download(
-        repo_id="alexnasa/flash-attn-3",
-        repo_type="model",
-        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
-    )
-
-    sh(f"pip install {flash_attention_wheel}")
-
-    # tell Python to re-scan site-packages now that the egg-link exists
-    import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
-
-    flash_attention_installed = True
-    print("FlashAttention installed successfully.")
-
-except Exception as e:
-    print(f"⚠️ Could not install FlashAttention: {e}")
-    print("Continuing without FlashAttention...")
-
 import torch
 print(f"Torch version: {torch.__version__}")
-print(f"FlashAttention available: {flash_attention_installed}")
+
 
 
 import torch.nn as nn
@@ -82,41 +53,38 @@ from transformers import Wav2Vec2FeatureExtractor
 import torchvision.transforms as transforms
 import torch.nn.functional as F
 from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
-from higgs_audio_utils import text_to_speech, initialize_engine
-
 
-DEFAULT_TTS_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
-DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
-engine = initialize_engine(DEFAULT_TTS_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH)
+from supertonic import generate_speech
 
 os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
 
-@spaces.GPU
-def tts_from_text(text):
-    _, output = text_to_speech(engine, text)
+def tts_from_text(text, tts_dir, voice_choice):
+
+    output = generate_speech([text], tts_dir, voice_choice)[0]
     return output
 
 def speak_to_me(session_id, evt: gr.EventData):
     detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
 
-    current_text = detail.get("text", "")
-
-    output = tts_from_text(current_text)
-
     if session_id is None:
         session_id = uuid.uuid4().hex
 
+    current_text = detail.get("text", "")
+    voice = detail.get("choice")
+    voice_choice = "M1"
+
+    if voice == "Person1":
+        voice_choice = "M1"
+    elif voice == "Person2":
+        voice_choice = "F1"
+
     output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
 
     tts_dir = output_dir + '/tts'
     os.makedirs(tts_dir, exist_ok=True)
-    speech_to_text_path = os.path.join(tts_dir, f"speech_to_text.wav")
 
-    sampling_rate = output[0]
-    audio_data = output[1]
-
-    torchaudio.save(speech_to_text_path, torch.from_numpy(audio_data)[None, :], output[0])
-
+    speech_to_text_path = tts_from_text(current_text, tts_dir, voice_choice)
+
     return speech_to_text_path
 
 def tensor_to_pil(tensor):
@@ -814,6 +782,20 @@ css = """
 .stateful textarea[readonly]{
   color: var(--body-text-color-subdued) !important; /* subdued in both modes */
 }
+
+.button-gradient {
+  background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%;
+  border: none;
+  padding: 14px 28px;
+  font-size: 16px;
+  font-weight: bold;
+  color: white;
+  border-radius: 10px;
+  cursor: pointer;
+  transition: 0.3s ease-in-out;
+  animation: 2s linear 0s infinite normal none running gradientAnimation;
+  box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px;
+}
 """
 
 with gr.Blocks(css=css) as demo:
@@ -859,7 +841,7 @@ with gr.Blocks(css=css) as demo:
 with gr.Column():
 
     image_input = extendedimage(label="Reference Image", type="filepath", height=512)
-    audio_input = ExtendedAudio(label="Input Audio", type="filepath", options=["EMPTY"], show_download_button=True)
+    audio_input = ExtendedAudio(label="Input Audio", type="filepath", options=["Person1", "Person2"], show_download_button=True)
    gr.Markdown("*A 5-second limit is applied to audio files to shorten generation time. You can turn this off in Advanced Settings*")
 
 
@@ -869,7 +851,7 @@ with gr.Blocks(css=css) as demo:
     num_steps = gr.Slider(4, 50, value=8, step=1, label="Steps")
 
     time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
-    infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
+    infer_btn = gr.Button("🦜 Avatar Me", variant='primary', elem_classes="button-gradient")
     with gr.Accordion("Advanced Settings", open=False):
         raw_img_text = gr.Text(show_label=False, label="", value='', visible=False)
         limit_on = gr.Checkbox(label="Limit Audio files to 5 seconds", value=True)
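
For context, the GPU-bound Higgs engine is replaced by a single call into the new supertonic module (added in this commit as supertonic.py, not reproduced in this view). A minimal sketch of the new CPU path, assuming, as the diff implies but supertonic.py does not confirm here, that generate_speech takes a list of texts, an output directory, and a voice id ("M1" or "F1") and returns one generated audio file path per text:

```python
# Sketch only: the generate_speech contract below is inferred from how the new
# app.py uses it (list of texts in, list of audio file paths out).
import os

from supertonic import generate_speech


def synthesize(text: str, out_dir: str, voice: str = "M1") -> str:
    """Run the CPU-based TTS for a single string and return the generated file path."""
    os.makedirs(out_dir, exist_ok=True)
    # generate_speech takes a list of texts; index 0 is the output for our single text.
    return generate_speech([text], out_dir, voice)[0]


if __name__ == "__main__":
    wav_path = synthesize("Hello from the CPU pipeline.", "./tts_out", voice="F1")
    print(wav_path)
```

In the app itself, speak_to_me() maps the UI options "Person1"/"Person2" onto the "M1"/"F1" voice ids before calling tts_from_text().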
higgs_audio/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .model import HiggsAudioConfig, HiggsAudioModel
 
 
higgs_audio/audio_processing/LICENSE DELETED
@@ -1,51 +0,0 @@
1
- Third-Party License Attribution for Audio Processing Module
2
- ===========================================================
3
-
4
- This directory contains code derived from multiple open-source projects.
5
- The following sections detail the licenses and attributions for third-party code.
6
-
7
- ## XCodec Repository
8
- The code in this directory is derived from:
9
- https://github.com/zhenye234/xcodec
10
-
11
- ## Individual File Attributions
12
-
13
- ### Quantization Module (quantization/)
14
- - Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
15
- - Individual files contain their own license headers where applicable
16
- - The vector-quantize-pytorch portions are licensed under the MIT License
17
-
18
- ## License Terms
19
-
20
- ### MIT License (for applicable portions)
21
- Permission is hereby granted, free of charge, to any person obtaining a copy
22
- of this software and associated documentation files (the "Software"), to deal
23
- in the Software without restriction, including without limitation the rights
24
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
- copies of the Software, and to permit persons to whom the Software is
26
- furnished to do so, subject to the following conditions:
27
-
28
- The above copyright notice and this permission notice shall be included in all
29
- copies or substantial portions of the Software.
30
-
31
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
- SOFTWARE.
38
-
39
- ## Attribution Requirements
40
- When using this code, please ensure proper attribution to:
41
- 1. The original xcodec repository: https://github.com/zhenye234/xcodec
42
- 2. Any other repositories mentioned in individual file headers
43
- 3. This derivative work and its modifications
44
-
45
- ## Disclaimer
46
- This directory contains modified versions of the original code. Please refer to
47
- the original repositories for the canonical implementations and their specific
48
- license terms.
49
-
50
- For any questions about licensing or attribution, please check the individual
51
- file headers and the original source repositories.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
higgs_audio/audio_processing/descriptaudiocodec/__init__.py DELETED
File without changes
higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py DELETED
@@ -1,286 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
- from pathlib import Path
4
- from typing import Union
5
-
6
- import numpy as np
7
- import torch
8
- import tqdm
9
- from audiotools import AudioSignal
10
- from torch import nn
11
-
12
- SUPPORTED_VERSIONS = ["1.0.0"]
13
-
14
-
15
- @dataclass
16
- class DACFile:
17
- codes: torch.Tensor
18
-
19
- # Metadata
20
- chunk_length: int
21
- original_length: int
22
- input_db: float
23
- channels: int
24
- sample_rate: int
25
- padding: bool
26
- dac_version: str
27
-
28
- def save(self, path):
29
- artifacts = {
30
- "codes": self.codes.numpy().astype(np.uint16),
31
- "metadata": {
32
- "input_db": self.input_db.numpy().astype(np.float32),
33
- "original_length": self.original_length,
34
- "sample_rate": self.sample_rate,
35
- "chunk_length": self.chunk_length,
36
- "channels": self.channels,
37
- "padding": self.padding,
38
- "dac_version": SUPPORTED_VERSIONS[-1],
39
- },
40
- }
41
- path = Path(path).with_suffix(".dac")
42
- with open(path, "wb") as f:
43
- np.save(f, artifacts)
44
- return path
45
-
46
- @classmethod
47
- def load(cls, path):
48
- artifacts = np.load(path, allow_pickle=True)[()]
49
- codes = torch.from_numpy(artifacts["codes"].astype(int))
50
- if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
- raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
52
- return cls(codes=codes, **artifacts["metadata"])
53
-
54
-
55
- class CodecMixin:
56
- @property
57
- def padding(self):
58
- if not hasattr(self, "_padding"):
59
- self._padding = True
60
- return self._padding
61
-
62
- @padding.setter
63
- def padding(self, value):
64
- assert isinstance(value, bool)
65
-
66
- layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
67
-
68
- for layer in layers:
69
- if value:
70
- if hasattr(layer, "original_padding"):
71
- layer.padding = layer.original_padding
72
- else:
73
- layer.original_padding = layer.padding
74
- layer.padding = tuple(0 for _ in range(len(layer.padding)))
75
-
76
- self._padding = value
77
-
78
- def get_delay(self):
79
- # Any number works here, delay is invariant to input length
80
- l_out = self.get_output_length(0)
81
- L = l_out
82
-
83
- layers = []
84
- for layer in self.modules():
85
- if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
86
- layers.append(layer)
87
-
88
- for layer in reversed(layers):
89
- d = layer.dilation[0]
90
- k = layer.kernel_size[0]
91
- s = layer.stride[0]
92
-
93
- if isinstance(layer, nn.ConvTranspose1d):
94
- L = ((L - d * (k - 1) - 1) / s) + 1
95
- elif isinstance(layer, nn.Conv1d):
96
- L = (L - 1) * s + d * (k - 1) + 1
97
-
98
- L = math.ceil(L)
99
-
100
- l_in = L
101
-
102
- return (l_in - l_out) // 2
103
-
104
- def get_output_length(self, input_length):
105
- L = input_length
106
- # Calculate output length
107
- for layer in self.modules():
108
- if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
109
- d = layer.dilation[0]
110
- k = layer.kernel_size[0]
111
- s = layer.stride[0]
112
-
113
- if isinstance(layer, nn.Conv1d):
114
- L = ((L - d * (k - 1) - 1) / s) + 1
115
- elif isinstance(layer, nn.ConvTranspose1d):
116
- L = (L - 1) * s + d * (k - 1) + 1
117
-
118
- L = math.floor(L)
119
- return L
120
-
121
- @torch.no_grad()
122
- def compress(
123
- self,
124
- audio_path_or_signal: Union[str, Path, AudioSignal],
125
- win_duration: float = 1.0,
126
- verbose: bool = False,
127
- normalize_db: float = -16,
128
- n_quantizers: int = None,
129
- ) -> DACFile:
130
- """Processes an audio signal from a file or AudioSignal object into
131
- discrete codes. This function processes the signal in short windows,
132
- using constant GPU memory.
133
-
134
- Parameters
135
- ----------
136
- audio_path_or_signal : Union[str, Path, AudioSignal]
137
- audio signal to reconstruct
138
- win_duration : float, optional
139
- window duration in seconds, by default 5.0
140
- verbose : bool, optional
141
- by default False
142
- normalize_db : float, optional
143
- normalize db, by default -16
144
-
145
- Returns
146
- -------
147
- DACFile
148
- Object containing compressed codes and metadata
149
- required for decompression
150
- """
151
- audio_signal = audio_path_or_signal
152
- if isinstance(audio_signal, (str, Path)):
153
- audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
154
-
155
- self.eval()
156
- original_padding = self.padding
157
- original_device = audio_signal.device
158
-
159
- audio_signal = audio_signal.clone()
160
- original_sr = audio_signal.sample_rate
161
-
162
- resample_fn = audio_signal.resample
163
- loudness_fn = audio_signal.loudness
164
-
165
- # If audio is > 10 minutes long, use the ffmpeg versions
166
- if audio_signal.signal_duration >= 10 * 60 * 60:
167
- resample_fn = audio_signal.ffmpeg_resample
168
- loudness_fn = audio_signal.ffmpeg_loudness
169
-
170
- original_length = audio_signal.signal_length
171
- resample_fn(self.sample_rate)
172
- input_db = loudness_fn()
173
-
174
- if normalize_db is not None:
175
- audio_signal.normalize(normalize_db)
176
- audio_signal.ensure_max_of_audio()
177
-
178
- nb, nac, nt = audio_signal.audio_data.shape
179
- audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
180
- win_duration = audio_signal.signal_duration if win_duration is None else win_duration
181
-
182
- if audio_signal.signal_duration <= win_duration:
183
- # Unchunked compression (used if signal length < win duration)
184
- self.padding = True
185
- n_samples = nt
186
- hop = nt
187
- else:
188
- # Chunked inference
189
- self.padding = False
190
- # Zero-pad signal on either side by the delay
191
- audio_signal.zero_pad(self.delay, self.delay)
192
- n_samples = int(win_duration * self.sample_rate)
193
- # Round n_samples to nearest hop length multiple
194
- n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
195
- hop = self.get_output_length(n_samples)
196
-
197
- codes = []
198
- range_fn = range if not verbose else tqdm.trange
199
-
200
- for i in range_fn(0, nt, hop):
201
- x = audio_signal[..., i : i + n_samples]
202
- x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
203
-
204
- audio_data = x.audio_data.to(self.device)
205
- audio_data = self.preprocess(audio_data, self.sample_rate)
206
- _, c, _, _, _ = self.encode(audio_data, n_quantizers)
207
- codes.append(c.to(original_device))
208
- chunk_length = c.shape[-1]
209
-
210
- codes = torch.cat(codes, dim=-1)
211
-
212
- dac_file = DACFile(
213
- codes=codes,
214
- chunk_length=chunk_length,
215
- original_length=original_length,
216
- input_db=input_db,
217
- channels=nac,
218
- sample_rate=original_sr,
219
- padding=self.padding,
220
- dac_version=SUPPORTED_VERSIONS[-1],
221
- )
222
-
223
- if n_quantizers is not None:
224
- codes = codes[:, :n_quantizers, :]
225
-
226
- self.padding = original_padding
227
- return dac_file
228
-
229
- @torch.no_grad()
230
- def decompress(
231
- self,
232
- obj: Union[str, Path, DACFile],
233
- verbose: bool = False,
234
- ) -> AudioSignal:
235
- """Reconstruct audio from a given .dac file
236
-
237
- Parameters
238
- ----------
239
- obj : Union[str, Path, DACFile]
240
- .dac file location or corresponding DACFile object.
241
- verbose : bool, optional
242
- Prints progress if True, by default False
243
-
244
- Returns
245
- -------
246
- AudioSignal
247
- Object with the reconstructed audio
248
- """
249
- self.eval()
250
- if isinstance(obj, (str, Path)):
251
- obj = DACFile.load(obj)
252
-
253
- original_padding = self.padding
254
- self.padding = obj.padding
255
-
256
- range_fn = range if not verbose else tqdm.trange
257
- codes = obj.codes
258
- original_device = codes.device
259
- chunk_length = obj.chunk_length
260
- recons = []
261
-
262
- for i in range_fn(0, codes.shape[-1], chunk_length):
263
- c = codes[..., i : i + chunk_length].to(self.device)
264
- z = self.quantizer.from_codes(c)[0]
265
- r = self.decode(z)
266
- recons.append(r.to(original_device))
267
-
268
- recons = torch.cat(recons, dim=-1)
269
- recons = AudioSignal(recons, self.sample_rate)
270
-
271
- resample_fn = recons.resample
272
- loudness_fn = recons.loudness
273
-
274
- # If audio is > 10 minutes long, use the ffmpeg versions
275
- if recons.signal_duration >= 10 * 60 * 60:
276
- resample_fn = recons.ffmpeg_resample
277
- loudness_fn = recons.ffmpeg_loudness
278
-
279
- recons.normalize(obj.input_db)
280
- resample_fn(obj.sample_rate)
281
- recons = recons[..., : obj.original_length]
282
- loudness_fn()
283
- recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
284
-
285
- self.padding = original_padding
286
- return recons
higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py DELETED
@@ -1,365 +0,0 @@
1
- import math
2
- from typing import List
3
- from typing import Union
4
-
5
- import numpy as np
6
- import torch
7
- from audiotools import AudioSignal
8
- from audiotools.ml import BaseModel
9
- from torch import nn
10
-
11
- from .base import CodecMixin
12
- from dac.nn.layers import Snake1d
13
- from dac.nn.layers import WNConv1d
14
- from dac.nn.layers import WNConvTranspose1d
15
- from dac.nn.quantize import ResidualVectorQuantize
16
-
17
-
18
- def init_weights(m):
19
- if isinstance(m, nn.Conv1d):
20
- nn.init.trunc_normal_(m.weight, std=0.02)
21
- nn.init.constant_(m.bias, 0)
22
-
23
-
24
- class ResidualUnit(nn.Module):
25
- def __init__(self, dim: int = 16, dilation: int = 1):
26
- super().__init__()
27
- pad = ((7 - 1) * dilation) // 2
28
- self.block = nn.Sequential(
29
- Snake1d(dim),
30
- WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
31
- Snake1d(dim),
32
- WNConv1d(dim, dim, kernel_size=1),
33
- )
34
-
35
- def forward(self, x):
36
- y = self.block(x)
37
- pad = (x.shape[-1] - y.shape[-1]) // 2
38
- if pad > 0:
39
- x = x[..., pad:-pad]
40
- return x + y
41
-
42
-
43
- class EncoderBlock(nn.Module):
44
- def __init__(self, dim: int = 16, stride: int = 1):
45
- super().__init__()
46
- self.block = nn.Sequential(
47
- ResidualUnit(dim // 2, dilation=1),
48
- ResidualUnit(dim // 2, dilation=3),
49
- ResidualUnit(dim // 2, dilation=9),
50
- Snake1d(dim // 2),
51
- WNConv1d(
52
- dim // 2,
53
- dim,
54
- kernel_size=2 * stride,
55
- stride=stride,
56
- padding=math.ceil(stride / 2),
57
- ),
58
- )
59
-
60
- def forward(self, x):
61
- return self.block(x)
62
-
63
-
64
- class Encoder(nn.Module):
65
- def __init__(
66
- self,
67
- d_model: int = 64,
68
- strides: list = [2, 4, 8, 8],
69
- d_latent: int = 256,
70
- ):
71
- super().__init__()
72
- # Create first convolution
73
- self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
74
-
75
- # Create EncoderBlocks that double channels as they downsample by `stride`
76
- for stride in strides:
77
- d_model *= 2
78
- self.block += [EncoderBlock(d_model, stride=stride)]
79
-
80
- # Create last convolution
81
- self.block += [
82
- Snake1d(d_model),
83
- WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
84
- ]
85
-
86
- # Wrap black into nn.Sequential
87
- self.block = nn.Sequential(*self.block)
88
- self.enc_dim = d_model
89
-
90
- def forward(self, x):
91
- return self.block(x)
92
-
93
-
94
- class DecoderBlock(nn.Module):
95
- def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
96
- super().__init__()
97
- self.block = nn.Sequential(
98
- Snake1d(input_dim),
99
- WNConvTranspose1d(
100
- input_dim,
101
- output_dim,
102
- kernel_size=2 * stride,
103
- stride=stride,
104
- padding=math.ceil(stride / 2),
105
- output_padding=stride % 2, # out_pad,
106
- ),
107
- ResidualUnit(output_dim, dilation=1),
108
- ResidualUnit(output_dim, dilation=3),
109
- ResidualUnit(output_dim, dilation=9),
110
- )
111
-
112
- def forward(self, x):
113
- return self.block(x)
114
-
115
-
116
- class Decoder(nn.Module):
117
- def __init__(
118
- self,
119
- input_channel,
120
- channels,
121
- rates,
122
- d_out: int = 1,
123
- ):
124
- super().__init__()
125
-
126
- # Add first conv layer
127
- layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
128
-
129
- # Add upsampling + MRF blocks
130
- for i, stride in enumerate(rates):
131
- input_dim = channels // 2**i
132
- output_dim = channels // 2 ** (i + 1)
133
- if i == 1:
134
- out_pad = 1
135
- else:
136
- out_pad = 0
137
- layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
138
-
139
- # Add final conv layer
140
- layers += [
141
- Snake1d(output_dim),
142
- WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
143
- # nn.Tanh(),
144
- ]
145
-
146
- self.model = nn.Sequential(*layers)
147
-
148
- def forward(self, x):
149
- return self.model(x)
150
-
151
-
152
- class DAC(BaseModel, CodecMixin):
153
- def __init__(
154
- self,
155
- encoder_dim: int = 64,
156
- encoder_rates: List[int] = [2, 4, 8, 8],
157
- latent_dim: int = None,
158
- decoder_dim: int = 1536,
159
- decoder_rates: List[int] = [8, 8, 4, 2],
160
- n_codebooks: int = 9,
161
- codebook_size: int = 1024,
162
- codebook_dim: Union[int, list] = 8,
163
- quantizer_dropout: bool = False,
164
- sample_rate: int = 44100,
165
- ):
166
- super().__init__()
167
-
168
- self.encoder_dim = encoder_dim
169
- self.encoder_rates = encoder_rates
170
- self.decoder_dim = decoder_dim
171
- self.decoder_rates = decoder_rates
172
- self.sample_rate = sample_rate
173
-
174
- if latent_dim is None:
175
- latent_dim = encoder_dim * (2 ** len(encoder_rates))
176
-
177
- self.latent_dim = latent_dim
178
-
179
- self.hop_length = np.prod(encoder_rates)
180
- self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
181
-
182
- self.n_codebooks = n_codebooks
183
- self.codebook_size = codebook_size
184
- self.codebook_dim = codebook_dim
185
- self.quantizer = ResidualVectorQuantize(
186
- input_dim=latent_dim,
187
- n_codebooks=n_codebooks,
188
- codebook_size=codebook_size,
189
- codebook_dim=codebook_dim,
190
- quantizer_dropout=quantizer_dropout,
191
- )
192
-
193
- self.decoder = Decoder(
194
- latent_dim,
195
- decoder_dim,
196
- decoder_rates,
197
- )
198
- self.sample_rate = sample_rate
199
- self.apply(init_weights)
200
-
201
- self.delay = self.get_delay()
202
-
203
- def preprocess(self, audio_data, sample_rate):
204
- if sample_rate is None:
205
- sample_rate = self.sample_rate
206
- assert sample_rate == self.sample_rate
207
-
208
- length = audio_data.shape[-1]
209
- right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
210
- audio_data = nn.functional.pad(audio_data, (0, right_pad))
211
-
212
- return audio_data
213
-
214
- def encode(
215
- self,
216
- audio_data: torch.Tensor,
217
- n_quantizers: int = None,
218
- ):
219
- """Encode given audio data and return quantized latent codes
220
-
221
- Parameters
222
- ----------
223
- audio_data : Tensor[B x 1 x T]
224
- Audio data to encode
225
- n_quantizers : int, optional
226
- Number of quantizers to use, by default None
227
- If None, all quantizers are used.
228
-
229
- Returns
230
- -------
231
- dict
232
- A dictionary with the following keys:
233
- "z" : Tensor[B x D x T]
234
- Quantized continuous representation of input
235
- "codes" : Tensor[B x N x T]
236
- Codebook indices for each codebook
237
- (quantized discrete representation of input)
238
- "latents" : Tensor[B x N*D x T]
239
- Projected latents (continuous representation of input before quantization)
240
- "vq/commitment_loss" : Tensor[1]
241
- Commitment loss to train encoder to predict vectors closer to codebook
242
- entries
243
- "vq/codebook_loss" : Tensor[1]
244
- Codebook loss to update the codebook
245
- "length" : int
246
- Number of samples in input audio
247
- """
248
- z = self.encoder(audio_data)
249
- z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
250
- return z, codes, latents, commitment_loss, codebook_loss
251
-
252
- def decode(self, z: torch.Tensor):
253
- """Decode given latent codes and return audio data
254
-
255
- Parameters
256
- ----------
257
- z : Tensor[B x D x T]
258
- Quantized continuous representation of input
259
- length : int, optional
260
- Number of samples in output audio, by default None
261
-
262
- Returns
263
- -------
264
- dict
265
- A dictionary with the following keys:
266
- "audio" : Tensor[B x 1 x length]
267
- Decoded audio data.
268
- """
269
- return self.decoder(z)
270
-
271
- def forward(
272
- self,
273
- audio_data: torch.Tensor,
274
- sample_rate: int = None,
275
- n_quantizers: int = None,
276
- ):
277
- """Model forward pass
278
-
279
- Parameters
280
- ----------
281
- audio_data : Tensor[B x 1 x T]
282
- Audio data to encode
283
- sample_rate : int, optional
284
- Sample rate of audio data in Hz, by default None
285
- If None, defaults to `self.sample_rate`
286
- n_quantizers : int, optional
287
- Number of quantizers to use, by default None.
288
- If None, all quantizers are used.
289
-
290
- Returns
291
- -------
292
- dict
293
- A dictionary with the following keys:
294
- "z" : Tensor[B x D x T]
295
- Quantized continuous representation of input
296
- "codes" : Tensor[B x N x T]
297
- Codebook indices for each codebook
298
- (quantized discrete representation of input)
299
- "latents" : Tensor[B x N*D x T]
300
- Projected latents (continuous representation of input before quantization)
301
- "vq/commitment_loss" : Tensor[1]
302
- Commitment loss to train encoder to predict vectors closer to codebook
303
- entries
304
- "vq/codebook_loss" : Tensor[1]
305
- Codebook loss to update the codebook
306
- "length" : int
307
- Number of samples in input audio
308
- "audio" : Tensor[B x 1 x length]
309
- Decoded audio data.
310
- """
311
- length = audio_data.shape[-1]
312
- audio_data = self.preprocess(audio_data, sample_rate)
313
- z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
314
-
315
- x = self.decode(z)
316
- return {
317
- "audio": x[..., :length],
318
- "z": z,
319
- "codes": codes,
320
- "latents": latents,
321
- "vq/commitment_loss": commitment_loss,
322
- "vq/codebook_loss": codebook_loss,
323
- }
324
-
325
-
326
- if __name__ == "__main__":
327
- import numpy as np
328
- from functools import partial
329
-
330
- model = DAC().to("cpu")
331
-
332
- for n, m in model.named_modules():
333
- o = m.extra_repr()
334
- p = sum([np.prod(p.size()) for p in m.parameters()])
335
- fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
336
- setattr(m, "extra_repr", partial(fn, o=o, p=p))
337
- print(model)
338
- print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
339
-
340
- length = 88200 * 2
341
- x = torch.randn(1, 1, length).to(model.device)
342
- x.requires_grad_(True)
343
- x.retain_grad()
344
-
345
- # Make a forward pass
346
- out = model(x)["audio"]
347
- print("Input shape:", x.shape)
348
- print("Output shape:", out.shape)
349
-
350
- # Create gradient variable
351
- grad = torch.zeros_like(out)
352
- grad[:, :, grad.shape[-1] // 2] = 1
353
-
354
- # Make a backward pass
355
- out.backward(grad)
356
-
357
- # Check non-zero values
358
- gradmap = x.grad.squeeze(0)
359
- gradmap = (gradmap != 0).sum(0) # sum across features
360
- rf = (gradmap != 0).sum()
361
-
362
- print(f"Receptive field: {rf.item()}")
363
-
364
- x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
365
- model.decompress(model.compress(x, verbose=True), verbose=True)
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py DELETED
@@ -1,33 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from einops import rearrange
6
- from torch.nn.utils import weight_norm
7
-
8
-
9
- def WNConv1d(*args, **kwargs):
10
- return weight_norm(nn.Conv1d(*args, **kwargs))
11
-
12
-
13
- def WNConvTranspose1d(*args, **kwargs):
14
- return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
-
16
-
17
- # Scripting this brings model speed up 1.4x
18
- @torch.jit.script
19
- def snake(x, alpha):
20
- shape = x.shape
21
- x = x.reshape(shape[0], shape[1], -1)
22
- x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
- x = x.reshape(shape)
24
- return x
25
-
26
-
27
- class Snake1d(nn.Module):
28
- def __init__(self, channels):
29
- super().__init__()
30
- self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
-
32
- def forward(self, x):
33
- return snake(x, self.alpha)
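
For reference, the Snake activation deleted above computes, elementwise with a learnable per-channel parameter $\alpha$ (the 1e-9 term in the code is only a guard against division by zero):

$$\operatorname{snake}(x) = x + \frac{1}{\alpha}\,\sin^{2}(\alpha x)$$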
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py DELETED
@@ -1,251 +0,0 @@
1
- from typing import Union
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from einops import rearrange
8
- from torch.nn.utils import weight_norm
9
-
10
- from dac.nn.layers import WNConv1d
11
-
12
-
13
- class VectorQuantize(nn.Module):
14
- """
15
- Implementation of VQ similar to Karpathy's repo:
16
- https://github.com/karpathy/deep-vector-quantization
17
- Additionally uses following tricks from Improved VQGAN
18
- (https://arxiv.org/pdf/2110.04627.pdf):
19
- 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
- for improved codebook usage
21
- 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
- improves training stability
23
- """
24
-
25
- def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
- super().__init__()
27
- self.codebook_size = codebook_size
28
- self.codebook_dim = codebook_dim
29
-
30
- self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
- self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
- self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
-
34
- def forward(self, z):
35
- """Quantized the input tensor using a fixed codebook and returns
36
- the corresponding codebook vectors
37
-
38
- Parameters
39
- ----------
40
- z : Tensor[B x D x T]
41
-
42
- Returns
43
- -------
44
- Tensor[B x D x T]
45
- Quantized continuous representation of input
46
- Tensor[1]
47
- Commitment loss to train encoder to predict vectors closer to codebook
48
- entries
49
- Tensor[1]
50
- Codebook loss to update the codebook
51
- Tensor[B x T]
52
- Codebook indices (quantized discrete representation of input)
53
- Tensor[B x D x T]
54
- Projected latents (continuous representation of input before quantization)
55
- """
56
-
57
- # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
- z_e = self.in_proj(z) # z_e : (B x D x T)
59
- z_q, indices = self.decode_latents(z_e)
60
-
61
- commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
- codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
-
64
- z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
65
-
66
- z_q = self.out_proj(z_q)
67
-
68
- return z_q, commitment_loss, codebook_loss, indices, z_e
69
-
70
- def embed_code(self, embed_id):
71
- return F.embedding(embed_id, self.codebook.weight)
72
-
73
- def decode_code(self, embed_id):
74
- return self.embed_code(embed_id).transpose(1, 2)
75
-
76
- def decode_latents(self, latents):
77
- encodings = rearrange(latents, "b d t -> (b t) d")
78
- codebook = self.codebook.weight # codebook: (N x D)
79
-
80
- # L2 normalize encodings and codebook (ViT-VQGAN)
81
- encodings = F.normalize(encodings)
82
- codebook = F.normalize(codebook)
83
-
84
- # Compute euclidean distance with codebook
85
- dist = (
86
- encodings.pow(2).sum(1, keepdim=True)
87
- - 2 * encodings @ codebook.t()
88
- + codebook.pow(2).sum(1, keepdim=True).t()
89
- )
90
- indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
91
- z_q = self.decode_code(indices)
92
- return z_q, indices
93
-
94
-
95
- class ResidualVectorQuantize(nn.Module):
96
- """
97
- Introduced in SoundStream: An end2end neural audio codec
98
- https://arxiv.org/abs/2107.03312
99
- """
100
-
101
- def __init__(
102
- self,
103
- input_dim: int = 512,
104
- n_codebooks: int = 9,
105
- codebook_size: int = 1024,
106
- codebook_dim: Union[int, list] = 8,
107
- quantizer_dropout: float = 0.0,
108
- ):
109
- super().__init__()
110
- if isinstance(codebook_dim, int):
111
- codebook_dim = [codebook_dim for _ in range(n_codebooks)]
112
-
113
- self.n_codebooks = n_codebooks
114
- self.codebook_dim = codebook_dim
115
- self.codebook_size = codebook_size
116
-
117
- self.quantizers = nn.ModuleList(
118
- [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
119
- )
120
- self.quantizer_dropout = quantizer_dropout
121
-
122
- def forward(self, z, n_quantizers: int = None):
123
- """Quantized the input tensor using a fixed set of `n` codebooks and returns
124
- the corresponding codebook vectors
125
- Parameters
126
- ----------
127
- z : Tensor[B x D x T]
128
- n_quantizers : int, optional
129
- No. of quantizers to use
130
- (n_quantizers < self.n_codebooks ex: for quantizer dropout)
131
- Note: if `self.quantizer_dropout` is True, this argument is ignored
132
- when in training mode, and a random number of quantizers is used.
133
- Returns
134
- -------
135
- dict
136
- A dictionary with the following keys:
137
-
138
- "z" : Tensor[B x D x T]
139
- Quantized continuous representation of input
140
- "codes" : Tensor[B x N x T]
141
- Codebook indices for each codebook
142
- (quantized discrete representation of input)
143
- "latents" : Tensor[B x N*D x T]
144
- Projected latents (continuous representation of input before quantization)
145
- "vq/commitment_loss" : Tensor[1]
146
- Commitment loss to train encoder to predict vectors closer to codebook
147
- entries
148
- "vq/codebook_loss" : Tensor[1]
149
- Codebook loss to update the codebook
150
- """
151
- z_q = 0
152
- residual = z
153
- commitment_loss = 0
154
- codebook_loss = 0
155
-
156
- codebook_indices = []
157
- latents = []
158
-
159
- if n_quantizers is None:
160
- n_quantizers = self.n_codebooks
161
- if self.training:
162
- n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
163
- dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
164
- n_dropout = int(z.shape[0] * self.quantizer_dropout)
165
- n_quantizers[:n_dropout] = dropout[:n_dropout]
166
- n_quantizers = n_quantizers.to(z.device)
167
-
168
- for i, quantizer in enumerate(self.quantizers):
169
- if self.training is False and i >= n_quantizers:
170
- break
171
-
172
- z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
173
-
174
- # Create mask to apply quantizer dropout
175
- mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
176
- z_q = z_q + z_q_i * mask[:, None, None]
177
- residual = residual - z_q_i
178
-
179
- # Sum losses
180
- commitment_loss += (commitment_loss_i * mask).mean()
181
- codebook_loss += (codebook_loss_i * mask).mean()
182
-
183
- codebook_indices.append(indices_i)
184
- latents.append(z_e_i)
185
-
186
- codes = torch.stack(codebook_indices, dim=1)
187
- latents = torch.cat(latents, dim=1)
188
-
189
- return z_q, codes, latents, commitment_loss, codebook_loss
190
-
191
- def from_codes(self, codes: torch.Tensor):
192
- """Given the quantized codes, reconstruct the continuous representation
193
- Parameters
194
- ----------
195
- codes : Tensor[B x N x T]
196
- Quantized discrete representation of input
197
- Returns
198
- -------
199
- Tensor[B x D x T]
200
- Quantized continuous representation of input
201
- """
202
- z_q = 0.0
203
- z_p = []
204
- n_codebooks = codes.shape[1]
205
- for i in range(n_codebooks):
206
- z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
207
- z_p.append(z_p_i)
208
-
209
- z_q_i = self.quantizers[i].out_proj(z_p_i)
210
- z_q = z_q + z_q_i
211
- return z_q, torch.cat(z_p, dim=1), codes
212
-
213
- def from_latents(self, latents: torch.Tensor):
214
- """Given the unquantized latents, reconstruct the
215
- continuous representation after quantization.
216
-
217
- Parameters
218
- ----------
219
- latents : Tensor[B x N x T]
220
- Continuous representation of input after projection
221
-
222
- Returns
223
- -------
224
- Tensor[B x D x T]
225
- Quantized representation of full-projected space
226
- Tensor[B x D x T]
227
- Quantized representation of latent space
228
- """
229
- z_q = 0
230
- z_p = []
231
- codes = []
232
- dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
233
-
234
- n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
235
- for i in range(n_codebooks):
236
- j, k = dims[i], dims[i + 1]
237
- z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
238
- z_p.append(z_p_i)
239
- codes.append(codes_i)
240
-
241
- z_q_i = self.quantizers[i].out_proj(z_p_i)
242
- z_q = z_q + z_q_i
243
-
244
- return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
245
-
246
-
247
- if __name__ == "__main__":
248
- rvq = ResidualVectorQuantize(quantizer_dropout=True)
249
- x = torch.randn(16, 512, 80)
250
- y = rvq(x)
251
- print(y["latents"].shape)
higgs_audio/audio_processing/higgs_audio_tokenizer.py DELETED
@@ -1,341 +0,0 @@
1
- # Based on code from: https://github.com/zhenye234/xcodec
2
- # Licensed under MIT License
3
- # Modifications by BosonAI
4
-
5
- import math
6
- import os
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- from typing import Optional, Union, Sequence
11
- import numpy as np
12
- from transformers import AutoModel
13
- import torchaudio
14
- import json
15
- import librosa
16
- from huggingface_hub import snapshot_download
17
-
18
- from vector_quantize_pytorch import ResidualFSQ
19
- from .descriptaudiocodec.dac.model import dac as dac2
20
- from .quantization.vq import ResidualVectorQuantizer
21
- from .semantic_module import Encoder, Decoder
22
-
23
-
24
- class EncodedResult:
25
- def __init__(self, audio_codes):
26
- self.audio_codes = audio_codes
27
-
28
-
29
- class HiggsAudioFeatureExtractor(nn.Module):
30
- def __init__(self, sampling_rate=16000):
31
- super().__init__()
32
- self.sampling_rate = sampling_rate
33
-
34
- def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
35
- # Convert from librosa to torch
36
- audio_signal = torch.tensor(raw_audio)
37
- audio_signal = audio_signal.unsqueeze(0)
38
- if len(audio_signal.shape) < 3:
39
- audio_signal = audio_signal.unsqueeze(0)
40
- return {"input_values": audio_signal}
41
-
42
-
43
- class HiggsAudioTokenizer(nn.Module):
44
- def __init__(
45
- self,
46
- n_filters: int = 32,
47
- D: int = 128,
48
- target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
49
- ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
50
- sample_rate: int = 16000,
51
- bins: int = 1024,
52
- n_q: int = 8,
53
- codebook_dim: int = None,
54
- normalize: bool = False,
55
- causal: bool = False,
56
- semantic_techer: str = "hubert_base_general",
57
- last_layer_semantic: bool = True,
58
- merge_mode: str = "concat",
59
- downsample_mode: str = "step_down",
60
- semantic_mode: str = "classic",
61
- vq_scale: int = 1,
62
- semantic_sample_rate: int = None,
63
- device: str = "cuda",
64
- ):
65
- super().__init__()
66
- self.hop_length = np.prod(ratios)
67
- self.semantic_techer = semantic_techer
68
-
69
- self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz
70
-
71
- self.target_bandwidths = target_bandwidths
72
- self.n_q = n_q
73
- self.sample_rate = sample_rate
74
- self.encoder = dac2.Encoder(64, ratios, D)
75
-
76
- self.decoder_2 = dac2.Decoder(D, 1024, ratios)
77
- self.last_layer_semantic = last_layer_semantic
78
- self.device = device
79
- if semantic_techer == "hubert_base":
80
- self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
81
- self.semantic_sample_rate = 16000
82
- self.semantic_dim = 768
83
- self.encoder_semantic_dim = 768
84
-
85
- elif semantic_techer == "wavlm_base_plus":
86
- self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
87
- self.semantic_sample_rate = 16000
88
- self.semantic_dim = 768
89
- self.encoder_semantic_dim = 768
90
-
91
- elif semantic_techer == "hubert_base_general":
92
- self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio")
93
- self.semantic_sample_rate = 16000
94
- self.semantic_dim = 768
95
- self.encoder_semantic_dim = 768
96
-
97
- # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
98
- if semantic_sample_rate is not None:
99
- self.semantic_sample_rate = semantic_sample_rate
100
-
101
- self.semantic_model.eval()
102
-
103
- # make the semantic model parameters do not need gradient
104
- for param in self.semantic_model.parameters():
105
- param.requires_grad = False
106
-
107
- self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
108
-
109
- self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
110
- self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
111
- self.decoder_semantic = Decoder(
112
- code_dim=self.encoder_semantic_dim,
113
- output_channels=self.semantic_dim,
114
- decode_channels=self.semantic_dim,
115
- )
116
-
117
- # out_D=D+768
118
- if isinstance(bins, int): # RVQ
119
- self.quantizer = ResidualVectorQuantizer(
120
- dimension=self.quantizer_dim,
121
- codebook_dim=codebook_dim,
122
- n_q=n_q,
123
- bins=bins,
124
- )
125
- self.quantizer_type = "RVQ"
126
- else: # RFSQ
127
- self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
128
- self.quantizer_type = "RFSQ"
129
-
130
- self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
131
- self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
132
- self.fc_post2 = nn.Linear(self.quantizer_dim, D)
133
-
134
- self.downsample_mode = downsample_mode
135
- if downsample_mode == "avg":
136
- self.semantic_pooling = nn.AvgPool1d(
137
- kernel_size=self.semantic_downsample_factor,
138
- stride=self.semantic_downsample_factor,
139
- )
140
-
141
- self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)
142
-
143
- @property
144
- def tps(self):
145
- return self.frame_rate
146
-
147
- @property
148
- def sampling_rate(self):
149
- return self.sample_rate
150
-
151
- @property
152
- def num_codebooks(self):
153
- return self.n_q
154
-
155
- @property
156
- def codebook_size(self):
157
- return self.quantizer_dim
158
-
159
- def get_last_layer(self):
160
- return self.decoder.layers[-1].weight
161
-
162
- def calculate_rec_loss(self, rec, target):
163
- target = target / target.norm(dim=-1, keepdim=True)
164
- rec = rec / rec.norm(dim=-1, keepdim=True)
165
- rec_loss = (1 - (target * rec).sum(-1)).mean()
166
-
167
- return rec_loss
168
-
169
- @torch.no_grad()
170
- def get_regress_target(self, x):
171
- x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
172
-
173
- if (
174
- self.semantic_techer == "hubert_base"
175
- or self.semantic_techer == "hubert_base_general"
176
- or self.semantic_techer == "wavlm_base_plus"
177
- ):
178
- x = x[:, 0, :]
179
- x = F.pad(x, (160, 160))
180
- target = self.semantic_model(x, output_hidden_states=True).hidden_states
181
- target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
182
-
183
- # average for all layers
184
- target = target.mean(1)
185
- # target = target[9]
186
- # if self.hop_length > 320:
187
- # target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
188
-
189
- elif self.semantic_techer == "w2v_bert2":
190
- target = self.semantic_model(x)
191
-
192
- elif self.semantic_techer.startswith("whisper"):
193
- if self.last_layer_semantic:
194
- target = self.semantic_model(x, avg_layers=False)
195
- else:
196
- target = self.semantic_model(x, avg_layers=True)
197
-
198
- elif self.semantic_techer.startswith("mert_music"):
199
- if self.last_layer_semantic:
200
- target = self.semantic_model(x, avg_layers=False)
201
- else:
202
- target = self.semantic_model(x, avg_layers=True)
203
-
204
- elif self.semantic_techer.startswith("qwen_audio_omni"):
205
- target = self.semantic_model(x)
206
-
207
- if self.downsample_mode == "step_down":
208
- if self.semantic_downsample_factor > 1:
209
- target = target[:, :: self.semantic_downsample_factor, :]
210
-
211
- elif self.downsample_mode == "avg":
212
- target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
213
- return target
214
-
215
- def forward(self, x: torch.Tensor, bw: int):
216
- e_semantic_input = self.get_regress_target(x).detach()
217
-
218
- e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
219
- e_acoustic = self.encoder(x)
220
-
221
- e = torch.cat([e_acoustic, e_semantic], dim=1)
222
-
223
- e = self.fc_prior(e.transpose(1, 2))
224
-
225
- if self.quantizer_type == "RVQ":
226
- e = e.transpose(1, 2)
227
- quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
228
- quantized = quantized.transpose(1, 2)
229
- else:
230
- quantized, codes = self.quantizer(e)
231
- commit_loss = torch.tensor(0.0)
232
-
233
- quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
234
- quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
235
-
236
- o = self.decoder_2(quantized_acoustic)
237
-
238
- o_semantic = self.decoder_semantic(quantized_semantic)
239
- semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
240
-
241
- return o, commit_loss, semantic_recon_loss, None
242
-
243
- def encode(
244
- self,
245
- audio_path_or_wv,
246
- sr=None,
247
- loudness_normalize=False,
248
- loudness_threshold=-23.0,
249
- ):
250
- if isinstance(audio_path_or_wv, str):
251
- wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
252
- else:
253
- wv = audio_path_or_wv
254
- assert sr is not None
255
- if loudness_normalize:
256
- import pyloudnorm as pyln
257
-
258
- meter = pyln.Meter(sr)
259
- l = meter.integrated_loudness(wv)
260
- wv = pyln.normalize.loudness(wv, l, loudness_threshold)
261
- if sr != self.sampling_rate:
262
- wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
263
- if self.audio_tokenizer_feature_extractor is not None:
264
- inputs = self.audio_tokenizer_feature_extractor(
265
- raw_audio=wv,
266
- sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate,
267
- return_tensors="pt",
268
- )
269
- input_values = inputs["input_values"].to(self.device)
270
- else:
271
- input_values = torch.from_numpy(wv).float().unsqueeze(0)
272
- with torch.no_grad():
273
- encoder_outputs = self._xcodec_encode(input_values)
274
- vq_code = encoder_outputs.audio_codes[0]
275
- return vq_code
276
-
277
- def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
278
- bw = target_bw
279
-
280
- e_semantic_input = self.get_regress_target(x).detach()
281
-
282
- e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
283
- e_acoustic = self.encoder(x)
284
-
285
- if e_acoustic.shape[2] != e_semantic.shape[2]:
286
- pad_size = 160 * self.semantic_downsample_factor
287
- e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))
288
-
289
- if e_acoustic.shape[2] != e_semantic.shape[2]:
290
- if e_acoustic.shape[2] > e_semantic.shape[2]:
291
- e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
292
- else:
293
- e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]
294
-
295
- e = torch.cat([e_acoustic, e_semantic], dim=1)
296
-
297
- e = self.fc_prior(e.transpose(1, 2))
298
-
299
- if self.quantizer_type == "RVQ":
300
- e = e.transpose(1, 2)
301
- quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
302
- codes = codes.permute(1, 0, 2)
303
- else:
304
- quantized, codes = self.quantizer(e)
305
- codes = codes.permute(0, 2, 1)
306
-
307
- # return codes
308
- return EncodedResult(codes)
309
-
310
- def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
311
- if self.quantizer_type == "RVQ":
312
- vq_code = vq_code.permute(1, 0, 2)
313
- quantized = self.quantizer.decode(vq_code)
314
- quantized = quantized.transpose(1, 2)
315
- else:
316
- vq_code = vq_code.permute(0, 2, 1)
317
- quantized = self.quantizer.get_output_from_indices(vq_code)
318
- quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
319
-
320
- o = self.decoder_2(quantized_acoustic)
321
- return o.cpu().numpy()
322
-
323
-
324
- def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
325
- is_local = os.path.exists(tokenizer_name_or_path)
326
- if not is_local:
327
- tokenizer_path = snapshot_download(tokenizer_name_or_path)
328
- else:
329
- tokenizer_path = tokenizer_name_or_path
330
- config_path = os.path.join(tokenizer_path, "config.json")
331
- model_path = os.path.join(tokenizer_path, "model.pth")
332
- config = json.load(open(config_path))
333
- model = HiggsAudioTokenizer(
334
- **config,
335
- device=device,
336
- )
337
- parameter_dict = torch.load(model_path, map_location=device)
338
- model.load_state_dict(parameter_dict, strict=False)
339
- model.to(device)
340
- model.eval()
341
- return model
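
For reference, the tokenizer removed here exposed a load/encode/decode API. A rough sketch of how it was driven, inferred only from the deleted code above (the repo id comes from the old app.py defaults; the device, file name, and added batch dimension are assumptions):

```python
# Sketch of the removed API, based on the deleted higgs_audio_tokenizer.py above.
from higgs_audio.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer

tokenizer = load_higgs_audio_tokenizer("bosonai/higgs-audio-v2-tokenizer", device="cuda")
vq_code = tokenizer.encode("reference.wav")     # discrete codes for the clip
audio = tokenizer.decode(vq_code.unsqueeze(0))  # back to a waveform (numpy array); batch dim assumed
```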
higgs_audio/audio_processing/quantization/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- # flake8: noqa
8
- from .vq import QuantizedResult, ResidualVectorQuantizer
higgs_audio/audio_processing/quantization/ac.py DELETED
@@ -1,301 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """Arithmetic coder."""
8
-
9
- import io
10
- import math
11
- import random
12
- import typing as tp
13
- import torch
14
-
15
- from ..binary import BitPacker, BitUnpacker
16
-
17
-
18
- def build_stable_quantized_cdf(
19
- pdf: torch.Tensor,
20
- total_range_bits: int,
21
- roundoff: float = 1e-8,
22
- min_range: int = 2,
23
- check: bool = True,
24
- ) -> torch.Tensor:
25
- """Turn the given PDF into a quantized CDF that splits
26
- [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
27
- to the PDF.
28
-
29
- Args:
30
- pdf (torch.Tensor): probability distribution, shape should be `[N]`.
31
- total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
32
- during the coding process is `[0, 2 ** total_range_bits - 1]`.
33
- roundoff (float): will round the pdf up to that level to remove difference coming
34
- from e.g. evaluating the Language Model on different architectures.
35
- min_range (int): minimum range width. Should always be at least 2 for numerical
36
- stability. Use this to avoid pathological behavior if a value
37
- that is expected to be rare actually happens in real life.
38
- check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
39
- """
40
- pdf = pdf.detach()
41
- if roundoff:
42
- pdf = (pdf / roundoff).floor() * roundoff
43
- # interpolate with uniform distribution to achieve desired minimum probability.
44
- total_range = 2**total_range_bits
45
- cardinality = len(pdf)
46
- alpha = min_range * cardinality / total_range
47
- assert alpha <= 1, "you must reduce min_range"
48
- ranges = (((1 - alpha) * total_range) * pdf).floor().long()
49
- ranges += min_range
50
- quantized_cdf = torch.cumsum(ranges, dim=-1)
51
- if min_range < 2:
52
- raise ValueError("min_range must be at least 2.")
53
- if check:
54
- assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
55
- if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
56
- raise ValueError("You must increase your total_range_bits.")
57
- return quantized_cdf
58
-
59
-
60
- class ArithmeticCoder:
61
- """ArithmeticCoder,
62
- Let us take a distribution `p` over `N` symbols, and assume we have a stream
63
- of random variables `s_t` sampled from `p`. Let us assume that we have a budget
64
- of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
65
- corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
66
- sequence `(s_t)` by doing the following:
67
-
68
- 1) Initialize the current range to `[0, 2 ** B - 1]`.
69
- 2) For each time step t, split the current range into contiguous chunks,
70
- one for each possible outcome, with size roughly proportional to `p`.
71
- For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
72
- would be `{[0, 2], [3, 3]}`.
73
- 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
74
- 4) When done encoding all the values, just select any value remaining in the range.
75
-
76
- You will notice that this procedure can fail: for instance if at any point in time
77
- the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
78
- possible outcome. Intuitively, the more likely a value is, the less the range width
79
- will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
80
- coding scheme, likely outcomes would take less bits, and more of them can be coded
81
- with a fixed budget.
82
-
83
- In practice, we do not know `B` ahead of time, but we have a way to inject new bits
84
- when the current range decreases below a given limit (given by `total_range_bits`), without
85
- having to redo all the computations. If we encode mostly likely values, we will seldom
86
- need to inject new bits, but a single rare value can deplete our stock of entropy!
87
-
88
- In this explanation, we assumed that the distribution `p` was constant. In fact, the present
89
- code works for any sequence `(p_t)` possibly different for each timestep.
90
- We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
91
- the KL between the true distribution and `p_t`, the more efficient the coding will be.
92
-
93
- Args:
94
- fo (IO[bytes]): file-like object to which the bytes will be written to.
95
- total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
96
- Any time the current range width fall under this limit, new bits will
97
- be injected to rescale the initial range.
98
- """
99
-
100
- def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
101
- assert total_range_bits <= 30
102
- self.total_range_bits = total_range_bits
103
- self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
104
- self.low: int = 0
105
- self.high: int = 0
106
- self.max_bit: int = -1
107
- self._dbg: tp.List[tp.Any] = []
108
- self._dbg2: tp.List[tp.Any] = []
109
-
110
- @property
111
- def delta(self) -> int:
112
- """Return the current range width."""
113
- return self.high - self.low + 1
114
-
115
- def _flush_common_prefix(self):
116
- # If self.low and self.high start with the same bits,
117
- # those won't change anymore as we always just increase the range
118
- # by powers of 2, and we can flush them out to the bit stream.
119
- assert self.high >= self.low, (self.low, self.high)
120
- assert self.high < 2 ** (self.max_bit + 1)
121
- while self.max_bit >= 0:
122
- b1 = self.low >> self.max_bit
123
- b2 = self.high >> self.max_bit
124
- if b1 == b2:
125
- self.low -= b1 << self.max_bit
126
- self.high -= b1 << self.max_bit
127
- assert self.high >= self.low, (self.high, self.low, self.max_bit)
128
- assert self.low >= 0
129
- self.max_bit -= 1
130
- self.packer.push(b1)
131
- else:
132
- break
133
-
134
- def push(self, symbol: int, quantized_cdf: torch.Tensor):
135
- """Push the given symbol on the stream, flushing out bits
136
- if possible.
137
-
138
- Args:
139
- symbol (int): symbol to encode with the AC.
140
- quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
141
- to build this from your pdf estimate.
142
- """
143
- while self.delta < 2**self.total_range_bits:
144
- self.low *= 2
145
- self.high = self.high * 2 + 1
146
- self.max_bit += 1
147
-
148
- range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
149
- range_high = quantized_cdf[symbol].item() - 1
150
- effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
151
- effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
152
- assert self.low <= self.high
153
- self.high = self.low + effective_high
154
- self.low = self.low + effective_low
155
- assert self.low <= self.high, (
156
- effective_low,
157
- effective_high,
158
- range_low,
159
- range_high,
160
- )
161
- self._dbg.append((self.low, self.high))
162
- self._dbg2.append((self.low, self.high))
163
- outs = self._flush_common_prefix()
164
- assert self.low <= self.high
165
- assert self.max_bit >= -1
166
- assert self.max_bit <= 61, self.max_bit
167
- return outs
168
-
169
- def flush(self):
170
- """Flush the remaining information to the stream."""
171
- while self.max_bit >= 0:
172
- b1 = (self.low >> self.max_bit) & 1
173
- self.packer.push(b1)
174
- self.max_bit -= 1
175
- self.packer.flush()
176
-
177
-
178
- class ArithmeticDecoder:
179
- """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
180
-
181
- Note that this must be called with **exactly** the same parameters and sequence
182
- of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
183
-
184
- If the AC encoder current range is [L, H], with `L` and `H` having the same common
185
- prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
186
- For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
187
- `[b1 b2 b3 0 ... 0 b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
188
- for a specific sequence of symbols and a binary-search allows us to decode those symbols.
189
- At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
190
- and we will need to read new bits from the stream and repeat the process.
191
-
192
- """
193
-
194
- def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
195
- self.total_range_bits = total_range_bits
196
- self.low: int = 0
197
- self.high: int = 0
198
- self.current: int = 0
199
- self.max_bit: int = -1
200
- self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
201
- # Following is for debugging
202
- self._dbg: tp.List[tp.Any] = []
203
- self._dbg2: tp.List[tp.Any] = []
204
- self._last: tp.Any = None
205
-
206
- @property
207
- def delta(self) -> int:
208
- return self.high - self.low + 1
209
-
210
- def _flush_common_prefix(self):
211
- # Given the current range [L, H], if both have a common prefix,
212
- # we know we can remove it from our representation to avoid handling large numbers.
213
- while self.max_bit >= 0:
214
- b1 = self.low >> self.max_bit
215
- b2 = self.high >> self.max_bit
216
- if b1 == b2:
217
- self.low -= b1 << self.max_bit
218
- self.high -= b1 << self.max_bit
219
- self.current -= b1 << self.max_bit
220
- assert self.high >= self.low
221
- assert self.low >= 0
222
- self.max_bit -= 1
223
- else:
224
- break
225
-
226
- def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
227
- """Pull a symbol, reading as many bits from the stream as required.
228
- This returns `None` when the stream has been exhausted.
229
-
230
- Args:
231
- quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
232
- to build this from your pdf estimate. This must be **exactly**
233
- the same cdf as the one used at encoding time.
234
- """
235
- while self.delta < 2**self.total_range_bits:
236
- bit = self.unpacker.pull()
237
- if bit is None:
238
- return None
239
- self.low *= 2
240
- self.high = self.high * 2 + 1
241
- self.current = self.current * 2 + bit
242
- self.max_bit += 1
243
-
244
- def bin_search(low_idx: int, high_idx: int):
245
- # Binary search is not just for coding interviews :)
246
- if high_idx < low_idx:
247
- raise RuntimeError("Binary search failed")
248
- mid = (low_idx + high_idx) // 2
249
- range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
250
- range_high = quantized_cdf[mid].item() - 1
251
- effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
252
- effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
253
- low = effective_low + self.low
254
- high = effective_high + self.low
255
- if self.current >= low:
256
- if self.current <= high:
257
- return (mid, low, high, self.current)
258
- else:
259
- return bin_search(mid + 1, high_idx)
260
- else:
261
- return bin_search(low_idx, mid - 1)
262
-
263
- self._last = (self.low, self.high, self.current, self.max_bit)
264
- sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
265
- self._dbg.append((self.low, self.high, self.current))
266
- self._flush_common_prefix()
267
- self._dbg2.append((self.low, self.high, self.current))
268
-
269
- return sym
270
-
271
-
272
- def test():
273
- torch.manual_seed(1234)
274
- random.seed(1234)
275
- for _ in range(4):
276
- pdfs = []
277
- cardinality = random.randrange(4000)
278
- steps = random.randrange(100, 500)
279
- fo = io.BytesIO()
280
- encoder = ArithmeticCoder(fo)
281
- symbols = []
282
- for step in range(steps):
283
- pdf = torch.softmax(torch.randn(cardinality), dim=0)
284
- pdfs.append(pdf)
285
- q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
286
- symbol = torch.multinomial(pdf, 1).item()
287
- symbols.append(symbol)
288
- encoder.push(symbol, q_cdf)
289
- encoder.flush()
290
-
291
- fo.seek(0)
292
- decoder = ArithmeticDecoder(fo)
293
- for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
294
- q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
295
- decoded_symbol = decoder.pull(q_cdf)
296
- assert decoded_symbol == symbol, idx
297
- assert decoder.pull(torch.zeros(1)) is None
298
-
299
-
300
- if __name__ == "__main__":
301
- test()
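The deleted `test()` above already exercises the coder with random distributions; the following smaller, fixed-distribution sketch uses the same classes and the same `build_stable_quantized_cdf` helper to make the round-trip contract explicit: the decoder must be fed exactly the same quantized CDFs, in the same order, as the encoder.

import io
import torch

pdf = torch.tensor([0.75, 0.20, 0.05])               # toy distribution over 3 symbols
fo = io.BytesIO()
enc = ArithmeticCoder(fo)
cdf = build_stable_quantized_cdf(pdf, enc.total_range_bits)
for s in [0, 0, 1, 2, 0]:
    enc.push(s, cdf)                                  # pdf is constant here, so one cdf is reused
enc.flush()

fo.seek(0)
dec = ArithmeticDecoder(fo)
decoded = [dec.pull(cdf) for _ in range(5)]
assert decoded == [0, 0, 1, 2, 0]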
 
higgs_audio/audio_processing/quantization/core_vq.py DELETED
@@ -1,360 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- #
7
- # This implementation is inspired from
8
- # https://github.com/lucidrains/vector-quantize-pytorch
9
- # which is released under MIT License. Hereafter, the original license:
10
- # MIT License
11
- #
12
- # Copyright (c) 2020 Phil Wang
13
- #
14
- # Permission is hereby granted, free of charge, to any person obtaining a copy
15
- # of this software and associated documentation files (the "Software"), to deal
16
- # in the Software without restriction, including without limitation the rights
17
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
- # copies of the Software, and to permit persons to whom the Software is
19
- # furnished to do so, subject to the following conditions:
20
- #
21
- # The above copyright notice and this permission notice shall be included in all
22
- # copies or substantial portions of the Software.
23
- #
24
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
- # SOFTWARE.
31
-
32
- """Core vector quantization implementation."""
33
-
34
- import typing as tp
35
-
36
- from einops import rearrange, repeat
37
- import torch
38
- from torch import nn
39
- import torch.nn.functional as F
40
-
41
- from xcodec.quantization.distrib import broadcast_tensors, rank
42
-
43
-
44
- def default(val: tp.Any, d: tp.Any) -> tp.Any:
45
- return val if val is not None else d
46
-
47
-
48
- def ema_inplace(moving_avg, new, decay: float):
49
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
50
-
51
-
52
- def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
53
- return (x + epsilon) / (x.sum() + n_categories * epsilon)
54
-
55
-
56
- def uniform_init(*shape: int):
57
- t = torch.empty(shape)
58
- nn.init.kaiming_uniform_(t)
59
- return t
60
-
61
-
62
- def sample_vectors(samples, num: int):
63
- num_samples, device = samples.shape[0], samples.device
64
-
65
- if num_samples >= num:
66
- indices = torch.randperm(num_samples, device=device)[:num]
67
- else:
68
- indices = torch.randint(0, num_samples, (num,), device=device)
69
-
70
- return samples[indices]
71
-
72
-
73
- def kmeans(samples, num_clusters: int, num_iters: int = 10):
74
- dim, dtype = samples.shape[-1], samples.dtype
75
-
76
- means = sample_vectors(samples, num_clusters)
77
-
78
- for _ in range(num_iters):
79
- diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
80
- dists = -(diffs**2).sum(dim=-1)
81
-
82
- buckets = dists.max(dim=-1).indices
83
- bins = torch.bincount(buckets, minlength=num_clusters)
84
- zero_mask = bins == 0
85
- bins_min_clamped = bins.masked_fill(zero_mask, 1)
86
-
87
- new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
88
- new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
89
- new_means = new_means / bins_min_clamped[..., None]
90
-
91
- means = torch.where(zero_mask[..., None], means, new_means)
92
-
93
- return means, bins
94
-
95
-
96
- class EuclideanCodebook(nn.Module):
97
- """Codebook with Euclidean distance.
98
- Args:
99
- dim (int): Dimension.
100
- codebook_size (int): Codebook size.
101
- kmeans_init (bool): Whether to use k-means to initialize the codebooks.
102
- If set to true, run the k-means algorithm on the first training batch and use
103
- the learned centroids as initialization.
104
- kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
105
- decay (float): Decay for exponential moving average over the codebooks.
106
- epsilon (float): Epsilon value for numerical stability.
107
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
108
- that have an exponential moving average cluster size less than the specified threshold with
109
- randomly selected vector from the current batch.
110
- """
111
-
112
- def __init__(
113
- self,
114
- dim: int,
115
- codebook_size: int,
116
- kmeans_init: int = False,
117
- kmeans_iters: int = 10,
118
- decay: float = 0.99,
119
- epsilon: float = 1e-5,
120
- threshold_ema_dead_code: int = 2,
121
- ):
122
- super().__init__()
123
- self.decay = decay
124
- init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
125
- embed = init_fn(codebook_size, dim)
126
-
127
- self.codebook_size = codebook_size
128
-
129
- self.kmeans_iters = kmeans_iters
130
- self.epsilon = epsilon
131
- self.threshold_ema_dead_code = threshold_ema_dead_code
132
-
133
- self.register_buffer("inited", torch.Tensor([not kmeans_init]))
134
- self.register_buffer("cluster_size", torch.zeros(codebook_size))
135
- self.register_buffer("embed", embed)
136
- self.register_buffer("embed_avg", embed.clone())
137
-
138
- @torch.jit.ignore
139
- def init_embed_(self, data):
140
- if self.inited:
141
- return
142
-
143
- embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
144
- self.embed.data.copy_(embed)
145
- self.embed_avg.data.copy_(embed.clone())
146
- self.cluster_size.data.copy_(cluster_size)
147
- self.inited.data.copy_(torch.Tensor([True]))
148
- # Make sure all buffers across workers are in sync after initialization
149
- broadcast_tensors(self.buffers())
150
-
151
- def replace_(self, samples, mask):
152
- modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
153
- self.embed.data.copy_(modified_codebook)
154
-
155
- def expire_codes_(self, batch_samples):
156
- if self.threshold_ema_dead_code == 0:
157
- return
158
-
159
- expired_codes = self.cluster_size < self.threshold_ema_dead_code
160
- if not torch.any(expired_codes):
161
- return
162
-
163
- batch_samples = rearrange(batch_samples, "... d -> (...) d")
164
- self.replace_(batch_samples, mask=expired_codes)
165
- broadcast_tensors(self.buffers())
166
-
167
- def preprocess(self, x):
168
- x = rearrange(x, "... d -> (...) d")
169
- return x
170
-
171
- def quantize(self, x):
172
- embed = self.embed.t()
173
- dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
174
- embed_ind = dist.max(dim=-1).indices
175
- return embed_ind
176
-
177
- def postprocess_emb(self, embed_ind, shape):
178
- return embed_ind.view(*shape[:-1])
179
-
180
- def dequantize(self, embed_ind):
181
- quantize = F.embedding(embed_ind, self.embed) # get embedding based on index
182
- return quantize
183
-
184
- def encode(self, x):
185
- shape = x.shape
186
- # pre-process
187
- x = self.preprocess(x)
188
- # quantize
189
- embed_ind = self.quantize(x) # get index based on Euclidean distance
190
- # post-process
191
- embed_ind = self.postprocess_emb(embed_ind, shape)
192
- return embed_ind
193
-
194
- def decode(self, embed_ind):
195
- quantize = self.dequantize(embed_ind)
196
- return quantize
197
-
198
- def forward(self, x):
199
- shape, dtype = x.shape, x.dtype
200
- x = self.preprocess(x)
201
-
202
- self.init_embed_(x)
203
-
204
- embed_ind = self.quantize(x)
205
- embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
206
- embed_ind = self.postprocess_emb(embed_ind, shape)
207
- quantize = self.dequantize(embed_ind)
208
-
209
- if self.training:
210
- # We do the expiry of code at that point as buffers are in sync
211
- # and all the workers will take the same decision.
212
- self.expire_codes_(x)
213
- ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
214
- embed_sum = x.t() @ embed_onehot
215
- ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
216
- cluster_size = (
217
- laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
218
- )
219
- embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
220
- self.embed.data.copy_(embed_normalized)
221
-
222
- return quantize, embed_ind
223
-
224
-
225
- class VectorQuantization(nn.Module):
226
- """Vector quantization implementation.
227
- Currently supports only euclidean distance.
228
- Args:
229
- dim (int): Dimension
230
- codebook_size (int): Codebook size
231
- codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
232
- decay (float): Decay for exponential moving average over the codebooks.
233
- epsilon (float): Epsilon value for numerical stability.
234
- kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
235
- kmeans_iters (int): Number of iterations used for kmeans initialization.
236
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
237
- that have an exponential moving average cluster size less than the specified threshold with
238
- randomly selected vector from the current batch.
239
- commitment_weight (float): Weight for commitment loss.
240
- """
241
-
242
- def __init__(
243
- self,
244
- dim: int,
245
- codebook_size: int,
246
- codebook_dim: tp.Optional[int] = None,
247
- decay: float = 0.99,
248
- epsilon: float = 1e-5,
249
- kmeans_init: bool = True,
250
- kmeans_iters: int = 50,
251
- threshold_ema_dead_code: int = 2,
252
- commitment_weight: float = 1.0,
253
- ):
254
- super().__init__()
255
- _codebook_dim: int = default(codebook_dim, dim)
256
-
257
- requires_projection = _codebook_dim != dim
258
- self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
259
- self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
260
-
261
- self.epsilon = epsilon
262
- self.commitment_weight = commitment_weight
263
-
264
- self._codebook = EuclideanCodebook(
265
- dim=_codebook_dim,
266
- codebook_size=codebook_size,
267
- kmeans_init=kmeans_init,
268
- kmeans_iters=kmeans_iters,
269
- decay=decay,
270
- epsilon=epsilon,
271
- threshold_ema_dead_code=threshold_ema_dead_code,
272
- )
273
- self.codebook_size = codebook_size
274
-
275
- @property
276
- def codebook(self):
277
- return self._codebook.embed
278
-
279
- def encode(self, x):
280
- x = rearrange(x, "b d n -> b n d")
281
- x = self.project_in(x)
282
- embed_in = self._codebook.encode(x)
283
- return embed_in
284
-
285
- def decode(self, embed_ind):
286
- quantize = self._codebook.decode(embed_ind)
287
- quantize = self.project_out(quantize)
288
- quantize = rearrange(quantize, "b n d -> b d n")
289
- return quantize
290
-
291
- def forward(self, x):
292
- device = x.device
293
- x = rearrange(x, "b d n -> b n d")
294
- x = self.project_in(x)
295
-
296
- quantize, embed_ind = self._codebook(x)
297
-
298
- if self.training:
299
- quantize = x + (quantize - x).detach()
300
-
301
- loss = torch.tensor([0.0], device=device, requires_grad=self.training)
302
-
303
- if self.training:
304
- if self.commitment_weight > 0:
305
- commit_loss = F.mse_loss(quantize.detach(), x)
306
- loss = loss + commit_loss * self.commitment_weight
307
-
308
- quantize = self.project_out(quantize)
309
- quantize = rearrange(quantize, "b n d -> b d n")
310
- return quantize, embed_ind, loss
311
-
312
-
313
- class ResidualVectorQuantization(nn.Module):
314
- """Residual vector quantization implementation.
315
- Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
316
- """
317
-
318
- def __init__(self, *, num_quantizers, **kwargs):
319
- super().__init__()
320
- self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
321
-
322
- def forward(self, x, n_q: tp.Optional[int] = None):
323
- quantized_out = 0.0
324
- residual = x
325
-
326
- all_losses = []
327
- all_indices = []
328
-
329
- n_q = n_q or len(self.layers)
330
-
331
- for layer in self.layers[:n_q]:
332
- quantized, indices, loss = layer(residual)
333
- residual = residual - quantized
334
- quantized_out = quantized_out + quantized
335
-
336
- all_indices.append(indices)
337
- all_losses.append(loss)
338
-
339
- out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
340
- return quantized_out, out_indices, out_losses
341
-
342
- def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
343
- residual = x
344
- all_indices = []
345
- n_q = n_q or len(self.layers)
346
- for layer in self.layers[:n_q]:
347
- indices = layer.encode(residual)
348
- quantized = layer.decode(indices)
349
- residual = residual - quantized
350
- all_indices.append(indices)
351
- out_indices = torch.stack(all_indices)
352
- return out_indices
353
-
354
- def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
355
- quantized_out = torch.tensor(0.0, device=q_indices.device)
356
- for i, indices in enumerate(q_indices):
357
- layer = self.layers[i]
358
- quantized = layer.decode(indices)
359
- quantized_out = quantized_out + quantized
360
- return quantized_out
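The training branch in `VectorQuantization.forward` above combines a straight-through estimator with a commitment loss. Below is a self-contained sketch of that pattern, not tied to these classes, using a fixed toy codebook for illustration: the forward pass sees the quantized vectors, the backward pass treats quantization as the identity, and the commitment term pulls encoder outputs toward the detached codewords.

import torch
import torch.nn.functional as F

x = torch.randn(4, 16, requires_grad=True)    # stand-in for encoder outputs
codebook = torch.randn(512, 16)               # toy codebook, frozen for illustration
idx = torch.cdist(x, codebook).argmin(dim=1)  # nearest codeword per vector
q = codebook[idx]                             # quantized vectors; no gradient path back to x

q_st = x + (q - x).detach()                   # straight-through: forward uses q, backward acts as identity
commit = F.mse_loss(q.detach(), x)            # commitment term (weight 1.0 by default above)
(q_st.sum() + commit).backward()              # x receives gradients from both terms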
 
higgs_audio/audio_processing/quantization/core_vq_lsx_version.py DELETED
@@ -1,431 +0,0 @@
1
- # Copyright (c)
2
- #
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- # This implementation is inspired from
6
- # https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
7
- # https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
8
- #
9
- # Copyright (c) Meta Platforms, Inc. and affiliates.
10
- # All rights reserved.
11
- #
12
- # This source code is licensed under the license found in the
13
- # LICENSE file in the root directory of this source tree.
14
- #
15
- # This implementation is inspired from
16
- # https://github.com/lucidrains/vector-quantize-pytorch
17
- # which is released under MIT License. Hereafter, the original license:
18
- # MIT License
19
- #
20
- # Copyright (c) 2020 Phil Wang
21
- #
22
- # Permission is hereby granted, free of charge, to any person obtaining a copy
23
- # of this software and associated documentation files (the "Software"), to deal
24
- # in the Software without restriction, including without limitation the rights
25
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
26
- # copies of the Software, and to permit persons to whom the Software is
27
- # furnished to do so, subject to the following conditions:
28
- #
29
- # The above copyright notice and this permission notice shall be included in all
30
- # copies or substantial portions of the Software.
31
- #
32
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38
- # SOFTWARE.
39
-
40
- """Core vector quantization implementation."""
41
-
42
- import typing as tp
43
-
44
- from einops import rearrange
45
- import torch
46
- from torch import nn
47
- import torch.nn.functional as F
48
- import torch.distributed as dist
49
-
50
- from .distrib import broadcast_tensors, is_distributed
51
- from .ddp_utils import SyncFunction
52
-
53
-
54
- def default(val: tp.Any, d: tp.Any) -> tp.Any:
55
- return val if val is not None else d
56
-
57
-
58
- def ema_inplace(moving_avg, new, decay: float):
59
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
60
-
61
-
62
- def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
63
- return (x + epsilon) / (x.sum() + n_categories * epsilon)
64
-
65
-
66
- def uniform_init(*shape: int):
67
- t = torch.empty(shape)
68
- nn.init.kaiming_uniform_(t)
69
- return t
70
-
71
-
72
- def sample_vectors(samples, num: int):
73
- num_samples, device = samples.shape[0], samples.device
74
-
75
- if num_samples >= num:
76
- indices = torch.randperm(num_samples, device=device)[:num]
77
- else:
78
- indices = torch.randint(0, num_samples, (num,), device=device)
79
-
80
- return samples[indices]
81
-
82
-
83
- def kmeans(
84
- samples,
85
- num_clusters: int,
86
- num_iters: int = 10,
87
- frames_to_use: int = 10_000,
88
- batch_size: int = 64,
89
- ):
90
- """
91
- Memory-efficient K-means clustering.
92
- Args:
93
- samples (tensor): shape [N, D]
94
- num_clusters (int): number of centroids.
95
- num_iters (int): number of iterations.
96
- frames_to_use (int): subsample size from total samples.
97
- batch_size (int): batch size used in distance computation.
98
- Returns:
99
- means: [num_clusters, D]
100
- bins: [num_clusters] (number of points per cluster)
101
- """
102
- N, D = samples.shape
103
- dtype, device = samples.dtype, samples.device
104
-
105
- if frames_to_use < N:
106
- indices = torch.randperm(N, device=device)[:frames_to_use]
107
- samples = samples[indices]
108
-
109
- means = sample_vectors(samples, num_clusters)
110
-
111
- for _ in range(num_iters):
112
- # Store cluster assignments
113
- all_assignments = []
114
-
115
- for i in range(0, samples.shape[0], batch_size):
116
- batch = samples[i : i + batch_size] # [B, D]
117
- dists = torch.cdist(batch, means, p=2) # [B, C]
118
- assignments = dists.argmin(dim=1) # [B]
119
- all_assignments.append(assignments)
120
-
121
- buckets = torch.cat(all_assignments, dim=0) # [N]
122
- bins = torch.bincount(buckets, minlength=num_clusters)
123
- zero_mask = bins == 0
124
- bins_min_clamped = bins.masked_fill(zero_mask, 1)
125
-
126
- # Compute new means
127
- new_means = torch.zeros_like(means)
128
- for i in range(num_clusters):
129
- mask = buckets == i
130
- if mask.any():
131
- new_means[i] = samples[mask].mean(dim=0)
132
-
133
- means = torch.where(zero_mask[:, None], means, new_means)
134
-
135
- return means, bins
136
-
137
-
138
- class EuclideanCodebook(nn.Module):
139
- """Codebook with Euclidean distance.
140
- Args:
141
- dim (int): Dimension.
142
- codebook_size (int): Codebook size.
143
- kmeans_init (bool): Whether to use k-means to initialize the codebooks.
144
- If set to true, run the k-means algorithm on the first training batch and use
145
- the learned centroids as initialization.
146
- kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
147
- decay (float): Decay for exponential moving average over the codebooks.
148
- epsilon (float): Epsilon value for numerical stability.
149
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
150
- that have an exponential moving average cluster size less than the specified threshold with
151
- randomly selected vector from the current batch.
152
- """
153
-
154
- def __init__(
155
- self,
156
- dim: int,
157
- codebook_size: int,
158
- kmeans_init: bool = False,
159
- kmeans_iters: int = 10,
160
- decay: float = 0.99,
161
- epsilon: float = 1e-5,
162
- threshold_ema_dead_code: int = 2,
163
- ):
164
- super().__init__()
165
- self.decay = decay
166
- init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
167
- embed = init_fn(codebook_size, dim)
168
-
169
- self.codebook_size = codebook_size
170
-
171
- self.kmeans_iters = kmeans_iters
172
- self.epsilon = epsilon
173
- self.threshold_ema_dead_code = threshold_ema_dead_code
174
-
175
- # Flag variable to indicate whether the codebook is initialized
176
- self.register_buffer("inited", torch.Tensor([not kmeans_init]))
177
- # Running EMA cluster size/count: N_i^t in eq. (6) in vqvae paper
178
- self.register_buffer("cluster_size", torch.zeros(codebook_size))
179
- # Codebook
180
- self.register_buffer("embed", embed)
181
- # EMA codebook: eq. (7) in vqvae paper
182
- self.register_buffer("embed_avg", embed.clone())
183
-
184
- @torch.jit.ignore
185
- def init_embed_(self, data):
186
- """Initialize codebook.
187
- Args:
188
- data (tensor): [B * T, D].
189
- """
190
- if self.inited:
191
- return
192
-
193
- ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
194
- if dist.is_available() and dist.is_initialized():
195
- # [B * T * world_size, D]
196
- data = SyncFunction.apply(data)
197
-
198
- embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
199
- self.embed.data.copy_(embed)
200
- self.embed_avg.data.copy_(embed.clone())
201
- self.cluster_size.data.copy_(cluster_size)
202
- self.inited.data.copy_(torch.Tensor([True]))
203
- # Make sure all buffers across workers are in sync after initialization
204
- broadcast_tensors(self.buffers())
205
-
206
- def replace_(self, samples, mask):
207
- modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
208
- self.embed.data.copy_(modified_codebook)
209
-
210
- def expire_codes_(self, batch_samples):
211
- if self.threshold_ema_dead_code == 0:
212
- return
213
-
214
- expired_codes = self.cluster_size < self.threshold_ema_dead_code
215
- if not torch.any(expired_codes):
216
- return
217
-
218
- ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
219
- if is_distributed():
220
- # [B * T * world_size, D]
221
- batch_samples = SyncFunction.apply(batch_samples)
222
-
223
- batch_samples = rearrange(batch_samples, "... d -> (...) d")
224
- self.replace_(batch_samples, mask=expired_codes)
225
- broadcast_tensors(self.buffers())
226
-
227
- def preprocess(self, x):
228
- x = rearrange(x, "... d -> (...) d")
229
- return x
230
-
231
- def quantize(self, x):
232
- embed = self.embed.t()
233
- dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
234
- embed_ind = dist.max(dim=-1).indices
235
- return embed_ind
236
-
237
- def postprocess_emb(self, embed_ind, shape):
238
- return embed_ind.view(*shape[:-1])
239
-
240
- def dequantize(self, embed_ind):
241
- quantize = F.embedding(embed_ind, self.embed)
242
- return quantize
243
-
244
- def encode(self, x):
245
- shape = x.shape
246
- # pre-process
247
- x = self.preprocess(x) # [B, T, D] -> [B*T, D]
248
- # quantize
249
- embed_ind = self.quantize(x)
250
- # post-process
251
- embed_ind = self.postprocess_emb(embed_ind, shape)
252
- return embed_ind
253
-
254
- def decode(self, embed_ind):
255
- quantize = self.dequantize(embed_ind)
256
- return quantize
257
-
258
- def forward(self, x):
259
- # shape: [B, T, D]
260
- shape, dtype = x.shape, x.dtype
261
- x = self.preprocess(x) # [B, T, D] -> [B*T, D]
262
-
263
- # Initialize codebook
264
- self.init_embed_(x)
265
-
266
- embed_ind = self.quantize(x) # [B*T,]
267
- embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size]
268
- embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T]
269
- quantize = self.dequantize(embed_ind) # [B, T, D]
270
-
271
- if self.training:
272
- ### Update codebook by EMA
273
- embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
274
- embed_sum = x.t() @ embed_onehot # [D, cb-size]
275
- if is_distributed():
276
- dist.all_reduce(embed_onehot_sum)
277
- dist.all_reduce(embed_sum)
278
- # Update ema cluster count N_i^t, eq. (6) in vqvae paper
279
- self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
280
- # Update ema embed: eq. (7) in vqvae paper
281
- self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
282
- # apply laplace smoothing
283
- n = self.cluster_size.sum()
284
- cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
285
- # Update ema embed: eq. (8) in vqvae paper
286
- embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
287
- self.embed.data.copy_(embed_normalized)
288
-
289
- # We do the expiry of code at that point as buffers are in sync
290
- # and all the workers will take the same decision.
291
- self.expire_codes_(x)
292
-
293
- return quantize, embed_ind
294
-
295
-
296
- class VectorQuantization(nn.Module):
297
- """Vector quantization implementation.
298
- Currently supports only euclidean distance.
299
- Args:
300
- dim (int): Dimension
301
- codebook_size (int): Codebook size
302
- codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
303
- decay (float): Decay for exponential moving average over the codebooks.
304
- epsilon (float): Epsilon value for numerical stability.
305
- kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
306
- kmeans_iters (int): Number of iterations used for kmeans initialization.
307
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
308
- that have an exponential moving average cluster size less than the specified threshold with
309
- randomly selected vector from the current batch.
310
- commitment_weight (float): Weight for commitment loss.
311
- """
312
-
313
- def __init__(
314
- self,
315
- dim: int,
316
- codebook_size: int,
317
- codebook_dim: tp.Optional[int] = None,
318
- decay: float = 0.99,
319
- epsilon: float = 1e-5,
320
- kmeans_init: bool = True,
321
- kmeans_iters: int = 50,
322
- threshold_ema_dead_code: int = 2,
323
- commitment_weight: float = 1.0,
324
- ):
325
- super().__init__()
326
- _codebook_dim: int = default(codebook_dim, dim)
327
-
328
- requires_projection = _codebook_dim != dim
329
- self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
330
- self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
331
-
332
- self.epsilon = epsilon
333
- self.commitment_weight = commitment_weight
334
-
335
- self._codebook = EuclideanCodebook(
336
- dim=_codebook_dim,
337
- codebook_size=codebook_size,
338
- kmeans_init=kmeans_init,
339
- kmeans_iters=kmeans_iters,
340
- decay=decay,
341
- epsilon=epsilon,
342
- threshold_ema_dead_code=threshold_ema_dead_code,
343
- )
344
- self.codebook_size = codebook_size
345
-
346
- @property
347
- def codebook(self):
348
- return self._codebook.embed
349
-
350
- def encode(self, x):
351
- x = rearrange(x, "b d n -> b n d")
352
- x = self.project_in(x)
353
- embed_in = self._codebook.encode(x)
354
- return embed_in
355
-
356
- def decode(self, embed_ind):
357
- quantize = self._codebook.decode(embed_ind)
358
- quantize = self.project_out(quantize)
359
- quantize = rearrange(quantize, "b n d -> b d n")
360
- return quantize
361
-
362
- def forward(self, x):
363
- device = x.device
364
- x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d]
365
- x = self.project_in(x)
366
-
367
- quantize, embed_ind = self._codebook(x)
368
-
369
- if self.training:
370
- quantize = x + (quantize - x).detach()
371
-
372
- loss = torch.tensor([0.0], device=device, requires_grad=self.training)
373
-
374
- if self.training:
375
- if self.commitment_weight > 0:
376
- commit_loss = F.mse_loss(quantize.detach(), x)
377
- loss = loss + commit_loss * self.commitment_weight
378
-
379
- quantize = self.project_out(quantize)
380
- quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n]
381
- return quantize, embed_ind, loss
382
-
383
-
384
- class ResidualVectorQuantization(nn.Module):
385
- """Residual vector quantization implementation.
386
- Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
387
- """
388
-
389
- def __init__(self, *, num_quantizers, **kwargs):
390
- super().__init__()
391
- self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
392
-
393
- def forward(self, x, n_q: tp.Optional[int] = None):
394
- quantized_out = 0.0
395
- residual = x
396
-
397
- all_losses = []
398
- all_indices = []
399
-
400
- n_q = n_q or len(self.layers)
401
-
402
- for layer in self.layers[:n_q]:
403
- quantized, indices, loss = layer(residual)
404
- residual = residual - quantized
405
- quantized_out = quantized_out + quantized
406
-
407
- all_indices.append(indices)
408
- all_losses.append(loss)
409
-
410
- out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
411
- return quantized_out, out_indices, out_losses
412
-
413
- def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
414
- residual = x
415
- all_indices = []
416
- n_q = n_q or len(self.layers)
417
- for layer in self.layers[:n_q]:
418
- indices = layer.encode(residual)
419
- quantized = layer.decode(indices)
420
- residual = residual - quantized
421
- all_indices.append(indices)
422
- out_indices = torch.stack(all_indices)
423
- return out_indices
424
-
425
- def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
426
- quantized_out = torch.tensor(0.0, device=q_indices.device)
427
- for i, indices in enumerate(q_indices):
428
- layer = self.layers[i]
429
- quantized = layer.decode(indices)
430
- quantized_out = quantized_out + quantized
431
- return quantized_out
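Taken together, the residual loop above defines a simple shape contract, sketched below with made-up sizes (batch 2, dim 128, 50 frames, 4 quantizers). `kmeans_init=False` is passed here only so the codebooks are usable without a training step; it is not the default of the deleted code.

import torch

rvq = ResidualVectorQuantization(num_quantizers=4, dim=128, codebook_size=1024, kmeans_init=False)
x = torch.randn(2, 128, 50)        # [batch, dim, frames], the layout expected by VectorQuantization
codes = rvq.encode(x)              # [num_quantizers, batch, frames]
recon = rvq.decode(codes)          # [batch, dim, frames]: sum of the per-layer dequantizations
print(codes.shape, recon.shape)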
 
higgs_audio/audio_processing/quantization/ddp_utils.py DELETED
@@ -1,197 +0,0 @@
1
- import logging
2
- import random
3
- import subprocess
4
- from datetime import datetime
5
-
6
- import numpy as np
7
- import torch
8
- import torch.distributed as dist
9
- from torch.nn.parallel import DistributedDataParallel
10
- from torch.nn.parallel.distributed import _find_tensors
11
- import torch.optim
12
- import torch.utils.data
13
- from packaging import version
14
- from omegaconf import OmegaConf
15
-
16
-
17
- def set_random_seed(seed):
18
- random.seed(seed)
19
- np.random.seed(seed)
20
- torch.manual_seed(seed)
21
- torch.cuda.manual_seed_all(seed)
22
-
23
-
24
- def is_logging_process():
25
- return not dist.is_initialized() or dist.get_rank() == 0
26
-
27
-
28
- def get_logger(cfg, name=None):
29
- # log_file_path is used when unit testing
30
- if is_logging_process():
31
- logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
32
- return logging.getLogger(name)
33
-
34
-
35
- # from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
36
- class SyncFunction(torch.autograd.Function):
37
- @staticmethod
38
- # @torch.no_grad()
39
- def forward(ctx, tensor):
40
- ctx.batch_size = tensor.shape[0]
41
-
42
- gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
43
-
44
- torch.distributed.all_gather(gathered_tensor, tensor)
45
- gathered_tensor = torch.cat(gathered_tensor, 0)
46
-
47
- return gathered_tensor
48
-
49
- @staticmethod
50
- def backward(ctx, grad_output):
51
- grad_input = grad_output.clone()
52
- torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
53
-
54
- idx_from = torch.distributed.get_rank() * ctx.batch_size
55
- idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
56
- return grad_input[idx_from:idx_to]
57
-
58
-
59
- def get_timestamp():
60
- return datetime.now().strftime("%y%m%d-%H%M%S")
61
-
62
-
63
- def get_commit_hash():
64
- message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
65
- return message.strip().decode("utf-8")
66
-
67
-
68
- class DDP(DistributedDataParallel):
69
- """
70
- Override the forward call in lightning so it goes to training and validation step respectively
71
- """
72
-
73
- def forward(self, *inputs, **kwargs): # pragma: no cover
74
- if version.parse(torch.__version__[:6]) < version.parse("1.11"):
75
- self._sync_params()
76
- inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
77
- assert len(self.device_ids) == 1
78
- if self.module.training:
79
- output = self.module.training_step(*inputs[0], **kwargs[0])
80
- elif self.module.testing:
81
- output = self.module.test_step(*inputs[0], **kwargs[0])
82
- else:
83
- output = self.module.validation_step(*inputs[0], **kwargs[0])
84
- if torch.is_grad_enabled():
85
- # We'll return the output object verbatim since it is a freeform
86
- # object. We need to find any tensors in this object, though,
87
- # because we need to figure out which parameters were used during
88
- # this forward pass, to ensure we short circuit reduction for any
89
- # unused parameters. Only if `find_unused_parameters` is set.
90
- if self.find_unused_parameters:
91
- self.reducer.prepare_for_backward(list(_find_tensors(output)))
92
- else:
93
- self.reducer.prepare_for_backward([])
94
- else:
95
- from torch.nn.parallel.distributed import (
96
- logging,
97
- Join,
98
- _DDPSink,
99
- _tree_flatten_with_rref,
100
- _tree_unflatten_with_rref,
101
- )
102
-
103
- with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
104
- if torch.is_grad_enabled() and self.require_backward_grad_sync:
105
- self.logger.set_runtime_stats_and_log()
106
- self.num_iterations += 1
107
- self.reducer.prepare_for_forward()
108
-
109
- # Notify the join context that this process has not joined, if
110
- # needed
111
- work = Join.notify_join_context(self)
112
- if work:
113
- self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
114
-
115
- # Calling _rebuild_buckets before forward compuation,
116
- # It may allocate new buckets before deallocating old buckets
117
- # inside _rebuild_buckets. To save peak memory usage,
118
- # call _rebuild_buckets before the peak memory usage increases
119
- # during forward computation.
120
- # This should be called only once during whole training period.
121
- if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
122
- logging.info("Reducer buckets have been rebuilt in this iteration.")
123
- self._has_rebuilt_buckets = True
124
-
125
- # sync params according to location (before/after forward) user
126
- # specified as part of hook, if hook was specified.
127
- buffer_hook_registered = hasattr(self, "buffer_hook")
128
- if self._check_sync_bufs_pre_fwd():
129
- self._sync_buffers()
130
-
131
- if self._join_config.enable:
132
- # Notify joined ranks whether they should sync in backwards pass or not.
133
- self._check_global_requires_backward_grad_sync(is_joined_rank=False)
134
-
135
- inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
136
- if self.module.training:
137
- output = self.module.training_step(*inputs[0], **kwargs[0])
138
- elif self.module.testing:
139
- output = self.module.test_step(*inputs[0], **kwargs[0])
140
- else:
141
- output = self.module.validation_step(*inputs[0], **kwargs[0])
142
-
143
- # sync params according to location (before/after forward) user
144
- # specified as part of hook, if hook was specified.
145
- if self._check_sync_bufs_post_fwd():
146
- self._sync_buffers()
147
-
148
- if torch.is_grad_enabled() and self.require_backward_grad_sync:
149
- self.require_forward_param_sync = True
150
- # We'll return the output object verbatim since it is a freeform
151
- # object. We need to find any tensors in this object, though,
152
- # because we need to figure out which parameters were used during
153
- # this forward pass, to ensure we short circuit reduction for any
154
- # unused parameters. Only if `find_unused_parameters` is set.
155
- if self.find_unused_parameters and not self.static_graph:
156
- # Do not need to populate this for static graph.
157
- self.reducer.prepare_for_backward(list(_find_tensors(output)))
158
- else:
159
- self.reducer.prepare_for_backward([])
160
- else:
161
- self.require_forward_param_sync = False
162
-
163
- # TODO: DDPSink is currently enabled for unused parameter detection and
164
- # static graph training for first iteration.
165
- if (self.find_unused_parameters and not self.static_graph) or (
166
- self.static_graph and self.num_iterations == 1
167
- ):
168
- state_dict = {
169
- "static_graph": self.static_graph,
170
- "num_iterations": self.num_iterations,
171
- }
172
-
173
- output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
174
- output_placeholders = [None for _ in range(len(output_tensor_list))]
175
- # Do not touch tensors that have no grad_fn, which can cause issues
176
- # such as https://github.com/pytorch/pytorch/issues/60733
177
- for i, output in enumerate(output_tensor_list):
178
- if torch.is_tensor(output) and output.grad_fn is None:
179
- output_placeholders[i] = output
180
-
181
- # When find_unused_parameters=True, makes tensors which require grad
182
- # run through the DDPSink backward pass. When not all outputs are
183
- # used in loss, this makes those corresponding tensors receive
184
- # undefined gradient which the reducer then handles to ensure
185
- # param.grad field is not touched and we don't error out.
186
- passthrough_tensor_list = _DDPSink.apply(
187
- self.reducer,
188
- state_dict,
189
- *output_tensor_list,
190
- )
191
- for i in range(len(output_placeholders)):
192
- if output_placeholders[i] is None:
193
- output_placeholders[i] = passthrough_tensor_list[i]
194
-
195
- # Reconstruct output data structure.
196
- output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
197
- return output
 
higgs_audio/audio_processing/quantization/distrib.py DELETED
@@ -1,123 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """Torch distributed utilities."""
8
-
9
- import typing as tp
10
-
11
- import torch
12
-
13
-
14
- def rank():
15
- if torch.distributed.is_initialized():
16
- return torch.distributed.get_rank()
17
- else:
18
- return 0
19
-
20
-
21
- def world_size():
22
- if torch.distributed.is_initialized():
23
- return torch.distributed.get_world_size()
24
- else:
25
- return 1
26
-
27
-
28
- def is_distributed():
29
- return world_size() > 1
30
-
31
-
32
- def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
33
- if is_distributed():
34
- return torch.distributed.all_reduce(tensor, op)
35
-
36
-
37
- def _is_complex_or_float(tensor):
38
- return torch.is_floating_point(tensor) or torch.is_complex(tensor)
39
-
40
-
41
- def _check_number_of_params(params: tp.List[torch.Tensor]):
42
- # utility function to check that the number of params in all workers is the same,
43
- # and thus avoid a deadlock with distributed all reduce.
44
- if not is_distributed() or not params:
45
- return
46
- # print('params[0].device ', params[0].device)
47
- tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
48
- all_reduce(tensor)
49
- if tensor.item() != len(params) * world_size():
50
- # If not all the workers have the same number, for at least one of them,
51
- # this inequality will be verified.
52
- raise RuntimeError(
53
- f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
54
- )
55
-
56
-
57
- def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
58
- """Broadcast the tensors from the given parameters to all workers.
59
- This can be used to ensure that all workers have the same model to start with.
60
- """
61
- if not is_distributed():
62
- return
63
- tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
64
- _check_number_of_params(tensors)
65
- handles = []
66
- for tensor in tensors:
67
- handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
68
- handles.append(handle)
69
- for handle in handles:
70
- handle.wait()
71
-
72
-
73
- def sync_buffer(buffers, average=True):
74
- """
75
- Sync grad for buffers. If average is False, broadcast instead of averaging.
76
- """
77
- if not is_distributed():
78
- return
79
- handles = []
80
- for buffer in buffers:
81
- if torch.is_floating_point(buffer.data):
82
- if average:
83
- handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
84
- else:
85
- handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
86
- handles.append((buffer, handle))
87
- for buffer, handle in handles:
88
- handle.wait()
89
- if average:
90
- buffer.data /= world_size()
91
-
92
-
93
- def sync_grad(params):
94
- """
95
- Simpler alternative to DistributedDataParallel, that doesn't rely
96
- on any black magic. For simple models it can also be as fast.
97
- Just call this on your model parameters after the call to backward!
98
- """
99
- if not is_distributed():
100
- return
101
- handles = []
102
- for p in params:
103
- if p.grad is not None:
104
- handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
105
- handles.append((p, handle))
106
- for p, handle in handles:
107
- handle.wait()
108
- p.grad.data /= world_size()
109
-
110
-
111
- def average_metrics(metrics: tp.Dict[str, float], count=1.0):
112
- """Average a dictionary of metrics across all workers, using the optional
113
- `count` as unormalized weight.
114
- """
115
- if not is_distributed():
116
- return metrics
117
- keys, values = zip(*metrics.items())
118
- device = "cuda" if torch.cuda.is_available() else "cpu"
119
- tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
120
- tensor *= count
121
- all_reduce(tensor)
122
- averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
123
- return dict(zip(keys, averaged))
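These helpers are written to degrade gracefully when `torch.distributed` has not been initialized, which is the case in the single-process CPU setup this commit moves to. A quick sketch of that behavior:

import torch

t = torch.ones(3)
all_reduce(t)                                   # no-op when no process group is initialized
print(is_distributed())                         # False in a single-process run
print(average_metrics({"loss": 0.5}, count=8))  # returned unchanged when not distributed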
 
higgs_audio/audio_processing/quantization/vq.py DELETED
@@ -1,116 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """Residual vector quantizer implementation."""
8
-
9
- from dataclasses import dataclass, field
10
- import math
11
- import typing as tp
12
-
13
- import torch
14
- from torch import nn
15
-
16
- # from .core_vq import ResidualVectorQuantization
17
- from .core_vq_lsx_version import ResidualVectorQuantization
18
-
19
-
20
- @dataclass
21
- class QuantizedResult:
22
- quantized: torch.Tensor
23
- codes: torch.Tensor
24
- bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
25
- penalty: tp.Optional[torch.Tensor] = None
26
- metrics: dict = field(default_factory=dict)
27
-
28
-
29
- class ResidualVectorQuantizer(nn.Module):
30
- """Residual Vector Quantizer.
31
- Args:
32
- dimension (int): Dimension of the codebooks.
33
- n_q (int): Number of residual vector quantizers used.
34
- bins (int): Codebook size.
35
- decay (float): Decay for exponential moving average over the codebooks.
36
- kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
37
- kmeans_iters (int): Number of iterations used for kmeans initialization.
38
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
39
- that have an exponential moving average cluster size less than the specified threshold with
40
- randomly selected vector from the current batch.
41
- """
42
-
43
- def __init__(
44
- self,
45
- dimension: int = 256,
46
- codebook_dim: int = None,
47
- n_q: int = 8,
48
- bins: int = 1024,
49
- decay: float = 0.99,
50
- kmeans_init: bool = True,
51
- kmeans_iters: int = 50,
52
- threshold_ema_dead_code: int = 2,
53
- ):
54
- super().__init__()
55
- self.n_q = n_q
56
- self.dimension = dimension
57
- self.codebook_dim = codebook_dim
58
- self.bins = bins
59
- self.decay = decay
60
- self.kmeans_init = kmeans_init
61
- self.kmeans_iters = kmeans_iters
62
- self.threshold_ema_dead_code = threshold_ema_dead_code
63
- self.vq = ResidualVectorQuantization(
64
- dim=self.dimension,
65
- codebook_dim=self.codebook_dim,
66
- codebook_size=self.bins,
67
- num_quantizers=self.n_q,
68
- decay=self.decay,
69
- kmeans_init=self.kmeans_init,
70
- kmeans_iters=self.kmeans_iters,
71
- threshold_ema_dead_code=self.threshold_ema_dead_code,
72
- )
73
-
74
- def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult:
75
- """Residual vector quantization on the given input tensor.
76
- Args:
77
- x (torch.Tensor): Input tensor.
78
- sample_rate (int): Sample rate of the input tensor.
79
- bandwidth (float): Target bandwidth.
80
- Returns:
81
- QuantizedResult:
82
- The quantized (or approximately quantized) representation with
83
- the associated bandwidth and any penalty term for the loss.
84
- """
85
- bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
86
- n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
87
- quantized, codes, commit_loss = self.vq(x, n_q=n_q)
88
- bw = torch.tensor(n_q * bw_per_q).to(x)
89
- return quantized, codes, bw, torch.mean(commit_loss)
90
- # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
91
-
92
- def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
93
- """Return n_q based on specified target bandwidth."""
94
- bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
95
- n_q = self.n_q
96
- if bandwidth and bandwidth > 0.0:
97
- n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
98
- return n_q
99
-
100
- def get_bandwidth_per_quantizer(self, sample_rate: int):
101
- """Return bandwidth per quantizer for a given input sample rate."""
102
- return math.log2(self.bins) * sample_rate / 1000
103
-
104
- def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
105
- """Encode a given input tensor with the specified sample rate at the given bandwidth.
106
- The RVQ encode method sets the appropriate number of quantizer to use
107
- and returns indices for each quantizer.
108
- """
109
- n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
110
- codes = self.vq.encode(x, n_q=n_q)
111
- return codes
112
-
113
- def decode(self, codes: torch.Tensor) -> torch.Tensor:
114
- """Decode the given codes to the quantized representation."""
115
- quantized = self.vq.decode(codes)
116
- return quantized
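
The bandwidth bookkeeping above reduces to simple arithmetic: each residual stage spends log2(bins) bits per latent frame, so the per-quantizer bandwidth is log2(bins) * frame_rate / 1000 kbps (in the upstream EnCodec code this module derives from, the `sample_rate` argument is the frame rate of the latent sequence rather than the audio sample rate). A worked sketch that mirrors the two helper methods; the 50 Hz frame rate is an illustrative assumption:

import math

bins, max_n_q = 1024, 8        # codebook size and number of residual stages
frame_rate = 50                # latent frames per second (assumed for illustration)

bw_per_q = math.log2(bins) * frame_rate / 1000   # 10 bits/frame * 50 f/s = 0.5 kbps
target_bw = 4.0                                  # requested bandwidth in kbps

n_q = int(max(1, math.floor(target_bw / bw_per_q))) if target_bw > 0 else max_n_q
print(bw_per_q, n_q)           # 0.5 kbps per quantizer -> 8 quantizers for 4 kbps
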
higgs_audio/audio_processing/semantic_module.py DELETED
@@ -1,310 +0,0 @@
1
- # Based on code from: https://github.com/zhenye234/xcodec
2
- # Licensed under MIT License
3
- # Modifications by BosonAI
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
-
9
- class Conv1d1x1(nn.Conv1d):
10
- """1x1 Conv1d."""
11
-
12
- def __init__(self, in_channels, out_channels, bias=True):
13
- super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
14
-
15
-
16
- class Conv1d(nn.Module):
17
- def __init__(
18
- self,
19
- in_channels: int,
20
- out_channels: int,
21
- kernel_size: int,
22
- stride: int = 1,
23
- padding: int = -1,
24
- dilation: int = 1,
25
- groups: int = 1,
26
- bias: bool = True,
27
- ):
28
- super().__init__()
29
- self.in_channels = in_channels
30
- self.out_channels = out_channels
31
- self.kernel_size = kernel_size
32
- if padding < 0:
33
- padding = (kernel_size - 1) // 2 * dilation
34
- self.dilation = dilation
35
- self.conv = nn.Conv1d(
36
- in_channels=in_channels,
37
- out_channels=out_channels,
38
- kernel_size=kernel_size,
39
- stride=stride,
40
- padding=padding,
41
- dilation=dilation,
42
- groups=groups,
43
- bias=bias,
44
- )
45
-
46
- def forward(self, x):
47
- """
48
- Args:
49
- x (Tensor): Float tensor variable with the shape (B, C, T).
50
- Returns:
51
- Tensor: Float tensor variable with the shape (B, C, T).
52
- """
53
- x = self.conv(x)
54
- return x
55
-
56
-
57
- class ResidualUnit(nn.Module):
58
- def __init__(
59
- self,
60
- in_channels: int,
61
- out_channels: int,
62
- kernel_size=3,
63
- dilation=1,
64
- bias=False,
65
- nonlinear_activation="ELU",
66
- nonlinear_activation_params={},
67
- ):
68
- super().__init__()
69
- self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
70
- self.conv1 = Conv1d(
71
- in_channels=in_channels,
72
- out_channels=out_channels,
73
- kernel_size=kernel_size,
74
- stride=1,
75
- dilation=dilation,
76
- bias=bias,
77
- )
78
- self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
79
-
80
- def forward(self, x):
81
- y = self.conv1(self.activation(x))
82
- y = self.conv2(self.activation(y))
83
- return x + y
84
-
85
-
86
- class ConvTranspose1d(nn.Module):
87
- def __init__(
88
- self,
89
- in_channels: int,
90
- out_channels: int,
91
- kernel_size: int,
92
- stride: int,
93
- padding=-1,
94
- output_padding=-1,
95
- groups=1,
96
- bias=True,
97
- ):
98
- super().__init__()
99
- if padding < 0:
100
- padding = (stride + 1) // 2
101
- if output_padding < 0:
102
- output_padding = 1 if stride % 2 else 0
103
- self.deconv = nn.ConvTranspose1d(
104
- in_channels=in_channels,
105
- out_channels=out_channels,
106
- kernel_size=kernel_size,
107
- stride=stride,
108
- padding=padding,
109
- output_padding=output_padding,
110
- groups=groups,
111
- bias=bias,
112
- )
113
-
114
- def forward(self, x):
115
- """
116
- Args:
117
- x (Tensor): Float tensor variable with the shape (B, C, T).
118
- Returns:
119
- Tensor: Float tensor variable with the shape (B, C', T').
120
- """
121
- x = self.deconv(x)
122
- return x
123
-
124
-
125
- class EncoderBlock(nn.Module):
126
- def __init__(
127
- self,
128
- in_channels: int,
129
- out_channels: int,
130
- stride: int,
131
- dilations=(1, 1),
132
- unit_kernel_size=3,
133
- bias=True,
134
- ):
135
- super().__init__()
136
- self.res_units = torch.nn.ModuleList()
137
- for dilation in dilations:
138
- self.res_units += [
139
- ResidualUnit(
140
- in_channels,
141
- in_channels,
142
- kernel_size=unit_kernel_size,
143
- dilation=dilation,
144
- )
145
- ]
146
- self.num_res = len(self.res_units)
147
-
148
- self.conv = Conv1d(
149
- in_channels=in_channels,
150
- out_channels=out_channels,
151
- kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2
152
- stride=stride,
153
- bias=bias,
154
- )
155
-
156
- def forward(self, x):
157
- for idx in range(self.num_res):
158
- x = self.res_units[idx](x)
159
- x = self.conv(x)
160
- return x
161
-
162
-
163
- class Encoder(nn.Module):
164
- def __init__(
165
- self,
166
- input_channels: int,
167
- encode_channels: int,
168
- channel_ratios=(1, 1),
169
- strides=(1, 1),
170
- kernel_size=3,
171
- bias=True,
172
- block_dilations=(1, 1),
173
- unit_kernel_size=3,
174
- ):
175
- super().__init__()
176
- assert len(channel_ratios) == len(strides)
177
-
178
- self.conv = Conv1d(
179
- in_channels=input_channels,
180
- out_channels=encode_channels,
181
- kernel_size=kernel_size,
182
- stride=1,
183
- bias=False,
184
- )
185
- self.conv_blocks = torch.nn.ModuleList()
186
- in_channels = encode_channels
187
- for idx, stride in enumerate(strides):
188
- out_channels = int(encode_channels * channel_ratios[idx]) # could be float
189
- self.conv_blocks += [
190
- EncoderBlock(
191
- in_channels,
192
- out_channels,
193
- stride,
194
- dilations=block_dilations,
195
- unit_kernel_size=unit_kernel_size,
196
- bias=bias,
197
- )
198
- ]
199
- in_channels = out_channels
200
- self.num_blocks = len(self.conv_blocks)
201
- self.out_channels = out_channels
202
-
203
- def forward(self, x):
204
- x = self.conv(x)
205
- for i in range(self.num_blocks):
206
- x = self.conv_blocks[i](x)
207
- return x
208
-
209
-
210
- class DecoderBlock(nn.Module):
211
- """Decoder block (no up-sampling)"""
212
-
213
- def __init__(
214
- self,
215
- in_channels: int,
216
- out_channels: int,
217
- stride: int,
218
- dilations=(1, 1),
219
- unit_kernel_size=3,
220
- bias=True,
221
- ):
222
- super().__init__()
223
-
224
- if stride == 1:
225
- self.conv = Conv1d(
226
- in_channels=in_channels,
227
- out_channels=out_channels,
228
- kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
229
- stride=stride,
230
- bias=bias,
231
- )
232
- else:
233
- self.conv = ConvTranspose1d(
234
- in_channels=in_channels,
235
- out_channels=out_channels,
236
- kernel_size=(2 * stride),
237
- stride=stride,
238
- bias=bias,
239
- )
240
-
241
- self.res_units = torch.nn.ModuleList()
242
- for idx, dilation in enumerate(dilations):
243
- self.res_units += [
244
- ResidualUnit(
245
- out_channels,
246
- out_channels,
247
- kernel_size=unit_kernel_size,
248
- dilation=dilation,
249
- )
250
- ]
251
- self.num_res = len(self.res_units)
252
-
253
- def forward(self, x):
254
- x = self.conv(x)
255
- for idx in range(self.num_res):
256
- x = self.res_units[idx](x)
257
- return x
258
-
259
-
260
- class Decoder(nn.Module):
261
- def __init__(
262
- self,
263
- code_dim: int,
264
- output_channels: int,
265
- decode_channels: int,
266
- channel_ratios=(1, 1),
267
- strides=(1, 1),
268
- kernel_size=3,
269
- bias=True,
270
- block_dilations=(1, 1),
271
- unit_kernel_size=3,
272
- ):
273
- super().__init__()
274
- assert len(channel_ratios) == len(strides)
275
-
276
- self.conv1 = Conv1d(
277
- in_channels=code_dim,
278
- out_channels=int(decode_channels * channel_ratios[0]),
279
- kernel_size=kernel_size,
280
- stride=1,
281
- bias=False,
282
- )
283
-
284
- self.conv_blocks = torch.nn.ModuleList()
285
- for idx, stride in enumerate(strides):
286
- in_channels = int(decode_channels * channel_ratios[idx])
287
- if idx < (len(channel_ratios) - 1):
288
- out_channels = int(decode_channels * channel_ratios[idx + 1])
289
- else:
290
- out_channels = decode_channels
291
- self.conv_blocks += [
292
- DecoderBlock(
293
- in_channels,
294
- out_channels,
295
- stride,
296
- dilations=block_dilations,
297
- unit_kernel_size=unit_kernel_size,
298
- bias=bias,
299
- )
300
- ]
301
- self.num_blocks = len(self.conv_blocks)
302
-
303
- self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
304
-
305
- def forward(self, z):
306
- x = self.conv1(z)
307
- for i in range(self.num_blocks):
308
- x = self.conv_blocks[i](x)
309
- x = self.conv2(x)
310
- return x
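
As a sanity check on the blocks above: the Encoder downsamples time by the product of its strides and scales channels by the last entry of channel_ratios, and the Decoder mirrors that. A small shape-check sketch, assuming the pre-deletion import path; the concrete channel and stride values are illustrative, not the ones used elsewhere in this repo:

import torch

# Pre-deletion module path (assumption).
from higgs_audio.audio_processing.semantic_module import Encoder, Decoder

enc = Encoder(input_channels=1, encode_channels=32, channel_ratios=(2, 4), strides=(2, 4))
dec = Decoder(code_dim=enc.out_channels, output_channels=1, decode_channels=32,
              channel_ratios=(4, 2), strides=(4, 2))

wav = torch.randn(1, 1, 1600)   # (B, C, T)
z = enc(wav)                    # time reduced by 2 * 4 = 8
print(z.shape)                  # expected: torch.Size([1, 128, 200])
print(dec(z).shape)             # expected: torch.Size([1, 1, 1600])
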
higgs_audio/constants.py DELETED
@@ -1,3 +0,0 @@
1
- AUDIO_IN_TOKEN = "<|AUDIO|>"
2
- AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
3
- EOS_TOKEN = "<|end_of_text|>"
 
 
 
 
higgs_audio/data_collator/__init__.py DELETED
File without changes
higgs_audio/data_collator/higgs_audio_collator.py DELETED
@@ -1,583 +0,0 @@
1
- import librosa
2
- import torch
3
- import torch.nn.functional as F
4
- import math
5
- import numpy as np
6
- from typing import List, Tuple, Dict
7
-
8
- from dataclasses import dataclass
9
- from typing import List, Optional
10
- from transformers.models.whisper.processing_whisper import WhisperProcessor
11
-
12
- from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
13
- from ..model.utils import build_delay_pattern_mask
14
-
15
-
16
- def _ceil_to_nearest(n, round_to):
17
- return (n + round_to - 1) // round_to * round_to
18
-
19
-
20
- @dataclass
21
- class HiggsAudioBatchInput:
22
- input_ids: torch.LongTensor # shape (bsz, seq_len).
23
- attention_mask: torch.Tensor # shape (bsz, seq_len).
24
- audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len).
25
- audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len).
26
- audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
27
- audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,)
28
- # The audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover group location in a batch for an audio segment
29
- # Currently, we concatenate audio segments along the sequence dimension to handle variable segment lengths. However, in the alignment stage, we need the location information
30
- # For example,
31
- # audio_out_ids_start = [0, 2, 4, 8]; and the first two audio segments come from the same sample in a batch, and other two come from different samples.
32
- # This is a batch of 3 samples, then we will have the group location as:
33
- # audio_out_ids_start_group_loc = [0, 0, 1, 2]
34
- audio_out_ids_start_group_loc: Optional[
35
- torch.LongTensor
36
- ] # shape (num_audio_out,), specify which a sample's group location in the batch
37
- audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length)
38
- audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,)
39
- label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len)
40
- label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
41
- reward: Optional[float] = None
42
-
43
-
44
- class HiggsAudioSampleCollator:
45
- """Sample collator for Higgs-Audio model.
46
-
47
- Args:
48
- whisper_processor (WhisperProcessor): The whisper processor.
49
- audio_in_token_id (int): The token id for audio-in.
50
- audio_out_token_id (int): The token id for audio-out.
51
- pad_token_id (int): The token id for padding.
52
- audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
53
- audio_stream_eos_id (int): The token id for audio-stream end of sentence.
54
- round_to (int): The round-to value.
55
- pad_left (bool): Whether to pad left.
56
- return_audio_in_tokens (bool): Whether to return audio-in tokens.
57
- use_delay_pattern (bool): Whether to use delay pattern.
58
- disable_audio_codes_transform (bool): If True, do not add bos and eos tokens to the audio codes.
59
- chunk_size_seconds (int): The chunk size in seconds.
60
- add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
61
- mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.
62
-
63
- """
64
-
65
- def __init__(
66
- self,
67
- whisper_processor: WhisperProcessor,
68
- audio_in_token_id,
69
- audio_out_token_id,
70
- pad_token_id,
71
- audio_stream_bos_id,
72
- audio_stream_eos_id,
73
- round_to=8,
74
- pad_left=False,
75
- encode_whisper_embed=True,
76
- return_audio_in_tokens=True,
77
- audio_num_codebooks=None,
78
- use_delay_pattern=False,
79
- disable_audio_codes_transform=False,
80
- chunk_size_seconds=30, # Maximum duration for each chunk
81
- add_new_bos_eos_for_long_chunk=True,
82
- mask_audio_out_token_label=True,
83
- ):
84
- self.whisper_processor = whisper_processor
85
- self.round_to = round_to
86
- self.pad_left = pad_left
87
- self.audio_in_token_id = audio_in_token_id
88
- self.audio_out_token_id = audio_out_token_id
89
- self.audio_stream_bos_id = audio_stream_bos_id
90
- self.audio_stream_eos_id = audio_stream_eos_id
91
- self.pad_token_id = pad_token_id
92
- self.encode_whisper_embed = encode_whisper_embed
93
- self.return_audio_in_tokens = return_audio_in_tokens
94
- self.audio_num_codebooks = audio_num_codebooks
95
- self.use_delay_pattern = use_delay_pattern
96
- if encode_whisper_embed:
97
- self.chunk_size_seconds = chunk_size_seconds
98
- self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
99
- else:
100
- self.chunk_size_seconds = None
101
- self.chunk_size_samples = None
102
- self.disable_audio_codes_transform = disable_audio_codes_transform
103
- self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
104
- self.mask_audio_out_token_label = mask_audio_out_token_label
105
-
106
- def _process_and_duplicate_audio_tokens(
107
- self,
108
- input_ids: torch.Tensor,
109
- audio_idx: int,
110
- wv: torch.Tensor,
111
- sr: int,
112
- labels: Optional[torch.Tensor] = None,
113
- ) -> Tuple[torch.Tensor, torch.Tensor, int]:
114
- """Process long audio and duplicate corresponding audio tokens.
115
-
116
- Args:
117
- input_ids: Input token ids
118
- audio_idx: Index of the audio token in the sequence
119
- wv: Audio waveform
120
- sr: Sample rate
121
- labels: Optional label ids to be duplicated alongside input ids
122
-
123
- Returns:
124
- Tuple of:
125
- - New input ids with duplicated audio tokens
126
- - New label ids (if labels were provided) or None
127
- - Number of chunks created
128
- """
129
- # Calculate number of chunks needed
130
- total_samples = len(wv)
131
- num_chunks = math.ceil(total_samples / self.chunk_size_samples)
132
-
133
- if num_chunks <= 1:
134
- return input_ids, labels, 1
135
-
136
- # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
137
- audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
138
- # Duplicate sequence for each chunk
139
- duplicated_sequence = audio_token_seq.repeat(num_chunks)
140
-
141
- # Create new input_ids with duplicated tokens
142
- new_input_ids = torch.cat(
143
- [
144
- input_ids[: audio_idx - 1],
145
- duplicated_sequence,
146
- input_ids[audio_idx + 2 :],
147
- ]
148
- )
149
-
150
- # If labels are provided, duplicate them as well
151
- new_labels = None
152
- if labels is not None:
153
- label_seq = labels[audio_idx - 1 : audio_idx + 2]
154
- duplicated_labels = label_seq.repeat(num_chunks)
155
- new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
156
-
157
- return new_input_ids, new_labels, num_chunks
158
-
159
- def __call__(self, batch: List[ChatMLDatasetSample]):
160
- """Collate the input data with support for long audio processing."""
161
-
162
- label_ids = None
163
- label_audio_ids = None
164
- if all([ele.label_ids is None for ele in batch]):
165
- return_labels = False
166
- else:
167
- return_labels = True
168
-
169
- if self.encode_whisper_embed:
170
- # Process each sample in the batch to handle long audio
171
- # TODO(?) The implementation here can be optimized.
172
- processed_batch = []
173
- for i in range(len(batch)):
174
- sample = batch[i]
175
- audio_in_mask = sample.input_ids == self.audio_in_token_id
176
- audio_in_indices = torch.where(audio_in_mask)[0]
177
- audio_out_mask = sample.input_ids == self.audio_out_token_id
178
-
179
- # Process each audio token and duplicate if needed
180
- modified_input_ids = sample.input_ids
181
- modified_labels = sample.label_ids if return_labels else None
182
- modified_waveforms_concat = []
183
- modified_waveforms_start = []
184
- modified_sample_rate = []
185
- offset = 0 # Track position changes from duplicating tokens
186
- curr_wv_offset = 0
187
-
188
- # Process input audio tokens
189
- for idx, audio_idx in enumerate(audio_in_indices):
190
- # Get the audio for this token
191
- wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index
192
- if sr != self.whisper_processor.feature_extractor.sampling_rate:
193
- resampled_wv = librosa.resample(
194
- wv.cpu().numpy(),
195
- orig_sr=sr,
196
- target_sr=self.whisper_processor.feature_extractor.sampling_rate,
197
- )
198
- else:
199
- resampled_wv = wv.cpu().numpy()
200
- wv = torch.tensor(resampled_wv, device=wv.device)
201
- sr = self.whisper_processor.feature_extractor.sampling_rate
202
-
203
- # Process and duplicate tokens if necessary
204
- token_pos = audio_idx + offset
205
- modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
206
- modified_input_ids, token_pos, wv, sr, modified_labels
207
- )
208
-
209
- # Update audio data
210
- for chunk_idx in range(num_chunks):
211
- chunk_start = chunk_idx * self.chunk_size_samples
212
- chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
213
- chunk_wv = wv[chunk_start:chunk_end]
214
- modified_waveforms_concat.append(chunk_wv)
215
- modified_waveforms_start.append(curr_wv_offset)
216
- curr_wv_offset += len(chunk_wv)
217
- modified_sample_rate.append(sr)
218
-
219
- # Update offset for next iteration
220
- offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens
221
-
222
- # Create new sample with modified tokens and audio data
223
- processed_sample = ChatMLDatasetSample(
224
- input_ids=modified_input_ids,
225
- label_ids=modified_labels if return_labels else sample.label_ids,
226
- audio_ids_concat=sample.audio_ids_concat,
227
- audio_ids_start=sample.audio_ids_start,
228
- audio_waveforms_concat=torch.cat(modified_waveforms_concat)
229
- if modified_waveforms_concat
230
- else sample.audio_waveforms_concat,
231
- audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
232
- if modified_waveforms_start
233
- else sample.audio_waveforms_start,
234
- audio_sample_rate=torch.tensor(modified_sample_rate)
235
- if modified_sample_rate
236
- else sample.audio_sample_rate,
237
- audio_speaker_indices=torch.tensor([]),
238
- # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
239
- audio_label_ids_concat=sample.audio_label_ids_concat,
240
- )
241
- # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
242
- # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
243
- processed_batch.append(processed_sample)
244
- else:
245
- processed_batch = batch
246
-
247
- # Get the max sequence length based on processed batch
248
- max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)
249
-
250
- # Get the ids for audio-in and audio-out for each batch
251
- audio_in_wv_l = []
252
- audio_in_ids_l = []
253
- audio_out_ids_l = []
254
- audio_out_ids_group_loc_l = []
255
- audio_in_label_ids_l = None
256
- audio_out_label_ids_l = None
257
- reward_l = []
258
-
259
- if return_labels:
260
- audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not.
261
-
262
- # Process the audio inputs and outputs
263
- for i in range(len(processed_batch)):
264
- audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
265
- audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
266
- audio_ids = torch.ones_like(processed_batch[i].input_ids)
267
- audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
268
- audio_in_ids = audio_ids[audio_in_mask]
269
- audio_out_ids = audio_ids[audio_out_mask]
270
-
271
- if return_labels:
272
- audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
273
- if self.mask_audio_out_token_label:
274
- processed_batch[i].label_ids[audio_out_mask] = -100
275
-
276
- # Process audio inputs
277
- if self.return_audio_in_tokens:
278
- audio_in_ids_l.extend(
279
- [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
280
- )
281
- if processed_batch[i].audio_label_ids_concat is not None:
282
- if audio_in_label_ids_l is None:
283
- audio_in_label_ids_l = []
284
- audio_in_label_ids_l.extend(
285
- [
286
- processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
287
- for idx in audio_in_ids
288
- ]
289
- )
290
-
291
- audio_out_ids_l.extend(
292
- [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
293
- )
294
- audio_out_ids_group_loc_l.append(i)
295
- if processed_batch[i].reward is not None:
296
- reward_l.append(processed_batch[i].reward)
297
-
298
- if processed_batch[i].audio_label_ids_concat is not None:
299
- if audio_out_label_ids_l is None:
300
- audio_out_label_ids_l = []
301
- audio_out_label_ids_l.extend(
302
- [
303
- processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
304
- for idx in audio_out_ids
305
- ]
306
- )
307
-
308
- if self.encode_whisper_embed:
309
- for idx in audio_in_ids:
310
- wv, sr = processed_batch[i].get_wv(idx)
311
- resampled_wv = wv.cpu().numpy()
312
- # Split long audio into chunks
313
- total_samples = len(resampled_wv)
314
- for chunk_start in range(0, total_samples, self.chunk_size_samples):
315
- chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
316
- chunk = resampled_wv[chunk_start:chunk_end]
317
- audio_in_wv_l.append(chunk)
318
- # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
319
- # f"Assertion failed: Mismatch in number of audios. " \
320
- # f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."
321
-
322
- if return_labels:
323
- audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
324
-
325
- # Process all audio features
326
- if len(audio_in_wv_l) > 0:
327
- feature_ret = self.whisper_processor.feature_extractor(
328
- audio_in_wv_l,
329
- sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
330
- return_attention_mask=True,
331
- padding="max_length",
332
- )
333
- audio_features = torch.from_numpy(feature_ret["input_features"])
334
- audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
335
- else:
336
- if self.encode_whisper_embed:
337
- audio_features = torch.zeros(
338
- (
339
- 0,
340
- self.whisper_processor.feature_extractor.feature_size,
341
- self.whisper_processor.feature_extractor.nb_max_frames,
342
- ),
343
- dtype=torch.float32,
344
- )
345
- audio_feature_attention_mask = torch.zeros(
346
- (0, self.whisper_processor.feature_extractor.nb_max_frames),
347
- dtype=torch.int32,
348
- )
349
- else:
350
- audio_features = None
351
- audio_feature_attention_mask = None
352
-
353
- # Process audio input tokens
354
- if len(audio_in_ids_l) > 0:
355
- # Append audio-stream-bos and eos tokens
356
- new_audio_in_ids_l = []
357
- for ele in audio_in_ids_l:
358
- if self.disable_audio_codes_transform:
359
- # Do not add audio-stream-bos or eos tokens.
360
- # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
361
- audio_codes = ele
362
- else:
363
- audio_codes = torch.cat(
364
- [
365
- torch.full(
366
- (ele.shape[0], 1),
367
- self.audio_stream_bos_id,
368
- dtype=torch.long,
369
- ),
370
- ele,
371
- torch.full(
372
- (ele.shape[0], 1),
373
- self.audio_stream_eos_id,
374
- dtype=torch.long,
375
- ),
376
- ],
377
- dim=1,
378
- )
379
- if self.use_delay_pattern:
380
- audio_codes = build_delay_pattern_mask(
381
- audio_codes.unsqueeze(0),
382
- bos_token_id=self.audio_stream_bos_id,
383
- pad_token_id=self.audio_stream_eos_id,
384
- )[0].squeeze(0)
385
- new_audio_in_ids_l.append(audio_codes)
386
- audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
387
- audio_in_ids_start = torch.cumsum(
388
- torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]),
389
- dim=0,
390
- )
391
- else:
392
- audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
393
- audio_in_ids_start = torch.zeros(0, dtype=torch.long)
394
-
395
- # Process audio output tokens
396
- audio_out_ids_start_group_loc = None
397
- if len(audio_out_ids_l) > 0:
398
- new_audio_out_ids_l = []
399
- label_audio_ids_l = []
400
- for idx, ele in enumerate(audio_out_ids_l):
401
- if self.disable_audio_codes_transform:
402
- # Do not add audio-stream-bos or eos tokens.
403
- # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
404
- audio_codes = ele
405
- if return_labels:
406
- label_audio_ids = audio_out_label_ids_l[idx]
407
- else:
408
- audio_codes = torch.cat(
409
- [
410
- torch.full(
411
- (ele.shape[0], 1),
412
- self.audio_stream_bos_id,
413
- dtype=torch.long,
414
- ),
415
- ele,
416
- torch.full(
417
- (ele.shape[0], 1),
418
- self.audio_stream_eos_id,
419
- dtype=torch.long,
420
- ),
421
- ],
422
- dim=1,
423
- )
424
- if return_labels:
425
- label_audio_ids = torch.cat(
426
- [
427
- torch.full((ele.shape[0], 1), -100, dtype=torch.long),
428
- ele,
429
- torch.full(
430
- (ele.shape[0], 1),
431
- self.audio_stream_eos_id,
432
- dtype=torch.long,
433
- ),
434
- ],
435
- dim=1,
436
- )
437
- if self.use_delay_pattern:
438
- audio_codes = build_delay_pattern_mask(
439
- audio_codes.unsqueeze(0),
440
- bos_token_id=self.audio_stream_bos_id,
441
- pad_token_id=self.audio_stream_eos_id,
442
- )[0].squeeze(0)
443
- if return_labels:
444
- label_audio_ids = build_delay_pattern_mask(
445
- label_audio_ids.unsqueeze(0),
446
- bos_token_id=-100,
447
- pad_token_id=-100,
448
- )[0].squeeze(0)
449
- new_audio_out_ids_l.append(audio_codes)
450
-
451
- if return_labels:
452
- if audio_out_no_train_flag[idx]:
453
- label_audio_ids[:] = -100
454
- label_audio_ids_l.append(label_audio_ids)
455
-
456
- audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
457
- if return_labels:
458
- label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
459
- audio_out_ids_start = torch.cumsum(
460
- torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]),
461
- dim=0,
462
- )
463
- audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
464
- else:
465
- audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
466
- audio_out_ids_start = torch.zeros(0, dtype=torch.long)
467
- if return_labels:
468
- label_audio_ids = torch.zeros((0, 0), dtype=torch.long)
469
-
470
- reward = torch.tensor(reward_l, dtype=torch.float32)
471
-
472
- # Handle padding for input ids and attention mask
473
- if self.pad_left:
474
- input_ids = torch.stack(
475
- [
476
- F.pad(
477
- ele.input_ids,
478
- (max_seq_length - len(ele.input_ids), 0),
479
- value=self.pad_token_id,
480
- )
481
- for ele in processed_batch
482
- ]
483
- )
484
- if return_labels:
485
- label_ids = torch.stack(
486
- [
487
- F.pad(
488
- ele.label_ids,
489
- (max_seq_length - len(ele.label_ids), 0),
490
- value=-100,
491
- )
492
- for ele in processed_batch
493
- ]
494
- )
495
- attention_mask = torch.stack(
496
- [
497
- F.pad(
498
- torch.ones_like(ele.input_ids),
499
- (max_seq_length - len(ele.input_ids), 0),
500
- value=0,
501
- )
502
- for ele in processed_batch
503
- ]
504
- )
505
- else:
506
- input_ids = torch.stack(
507
- [
508
- F.pad(
509
- ele.input_ids,
510
- (0, max_seq_length - len(ele.input_ids)),
511
- value=self.pad_token_id,
512
- )
513
- for ele in processed_batch
514
- ]
515
- )
516
- if return_labels:
517
- label_ids = torch.stack(
518
- [
519
- F.pad(
520
- ele.label_ids,
521
- (0, max_seq_length - len(ele.label_ids)),
522
- value=-100,
523
- )
524
- for ele in processed_batch
525
- ]
526
- )
527
- attention_mask = torch.stack(
528
- [
529
- F.pad(
530
- torch.ones_like(ele.input_ids),
531
- (0, max_seq_length - len(ele.input_ids)),
532
- value=0,
533
- )
534
- for ele in processed_batch
535
- ]
536
- )
537
-
538
- if not self.return_audio_in_tokens:
539
- audio_in_ids = None
540
- audio_in_ids_start = None
541
-
542
- # Apply audio_num_codebooks limit if specified
543
- if self.audio_num_codebooks is not None:
544
- if audio_in_ids is not None:
545
- audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
546
- if audio_out_ids is not None:
547
- audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
548
- if label_audio_ids is not None:
549
- label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
550
-
551
- return HiggsAudioBatchInput(
552
- input_ids=input_ids,
553
- attention_mask=attention_mask,
554
- audio_features=audio_features,
555
- audio_feature_attention_mask=audio_feature_attention_mask,
556
- audio_out_ids=audio_out_ids,
557
- audio_out_ids_start=audio_out_ids_start,
558
- audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
559
- audio_in_ids=audio_in_ids,
560
- audio_in_ids_start=audio_in_ids_start,
561
- label_ids=label_ids,
562
- label_audio_ids=label_audio_ids,
563
- reward=reward,
564
- )
565
-
566
-
567
- class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
568
- def __init__(self, *args, **kwargs):
569
- super().__init__(*args, **kwargs)
570
-
571
- def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
572
- # flatten ranked chatml samples
573
- chosen = []
574
- rejected = []
575
-
576
- for sample in batch:
577
- chosen.append(sample.max_score_sample())
578
- rejected.append(sample.min_score_sample())
579
-
580
- merged = chosen
581
- merged.extend(rejected)
582
-
583
- return super().__call__(batch=merged)
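
The long-audio handling in the collator above is driven by one piece of arithmetic: audio longer than the Whisper window (30 s by default) is split into ceil(total_samples / chunk_size_samples) chunks, and the three-token <|audio_bos|><|AUDIO|><|audio_eos|> group is duplicated once per chunk. A small sketch of that bookkeeping (16 kHz matches the Whisper feature extractor's sampling rate; the 75 s duration is an arbitrary example):

import math

sampling_rate = 16_000                      # Whisper feature-extractor rate
chunk_size_seconds = 30                     # collator default
chunk_size_samples = chunk_size_seconds * sampling_rate

total_samples = 75 * sampling_rate          # a 75 second input waveform
num_chunks = math.ceil(total_samples / chunk_size_samples)

audio_tokens = ["<|audio_bos|>", "<|AUDIO|>", "<|audio_eos|>"] * num_chunks
print(num_chunks, len(audio_tokens))        # 3 chunks -> 9 tokens replace the original 3
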
higgs_audio/data_types.py DELETED
@@ -1,38 +0,0 @@
1
- """Basic data types for multimodal ChatML format."""
2
-
3
- from dataclasses import dataclass
4
- from typing import Dict, List, Optional, Union
5
-
6
-
7
- @dataclass
8
- class AudioContent:
9
- audio_url: str
10
- # Base64 encoded audio bytes
11
- raw_audio: Optional[str] = None
12
- offset: Optional[float] = None
13
- duration: Optional[float] = None
14
- row_id: Optional[int] = None
15
- type: str = "audio"
16
-
17
-
18
- @dataclass
19
- class TextContent:
20
- text: str
21
- type: str = "text"
22
-
23
-
24
- @dataclass
25
- class Message:
26
- role: str
27
- content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
28
- recipient: Optional[str] = None
29
-
30
-
31
- @dataclass
32
- class ChatMLSample:
33
- """Dataclass to hold multimodal ChatML data."""
34
-
35
- messages: List[Message]
36
- start_index: Optional[int] = None # We will mask the messages[:start_index] when finetuning the LLM.
37
- misc: Optional[Dict] = None
38
- speaker: Optional[str] = None
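
For reference, a minimal sample built from the dataclasses above, assuming the pre-deletion import path; the message contents are placeholder values:

# Pre-deletion path (assumption).
from higgs_audio.data_types import AudioContent, ChatMLSample, Message, TextContent

sample = ChatMLSample(
    messages=[
        Message(role="system", content="You are a text-to-speech assistant."),
        Message(role="user", content=[TextContent(text="Read this aloud:"),
                                      TextContent(text="Hello world.")]),
        Message(role="assistant", content=AudioContent(audio_url="generated.wav")),
    ],
    speaker="speaker_0",
)
print(len(sample.messages))  # 3
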
higgs_audio/dataset/__init__.py DELETED
File without changes
higgs_audio/dataset/chatml_dataset.py DELETED
@@ -1,554 +0,0 @@
1
- import dacite
2
- import pandas as pd
3
- import torch
4
- import json
5
-
6
- import numpy as np
7
- import multiprocessing as mp
8
-
9
- from dataclasses import dataclass, fields
10
- from abc import ABC, abstractmethod
11
- from typing import Union, List, Dict, Optional
12
-
13
- from ..data_types import ChatMLSample, TextContent, AudioContent
14
- from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN
15
-
16
- from loguru import logger
17
-
18
- # Whisper processor, 30 sec -> 3000 features
19
- # Then we downsample by 4 in the audio tower, reducing the 3000 features to 750, which gives 25 Hz
20
- WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
21
-
22
-
23
- @dataclass
24
- class ChatMLDatasetSample:
25
- input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens.
26
- label_ids: torch.LongTensor # Shape (seq_len,): The label ids.
27
- audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
28
- # Here `audio_seq_len` is the length of the concatenated audio tokens.`
29
- audio_ids_start: (
30
- torch.LongTensor
31
- ) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens.
32
- audio_waveforms_concat: (
33
- torch.Tensor
34
- ) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
35
- audio_waveforms_start: (
36
- torch.LongTensor
37
- ) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
38
- audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms.
39
- audio_speaker_indices: (
40
- torch.LongTensor
41
- ) # Shape (num_audios,): The speaker indices for each audio; -1 means unknown speaker.
42
- audio_label_ids_concat: Optional[torch.LongTensor] = (
43
- None # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
44
- )
45
- # Here `audio_seq_len` is the length of the concatenated audio tokens.`
46
- reward: Optional[float] = None
47
-
48
- def num_audios(self):
49
- return max(len(self.audio_waveforms_start), len(self.audio_ids_start))
50
-
51
- def get_audio_codes(self, idx):
52
- code_start = self.audio_ids_start[idx]
53
- if idx < len(self.audio_ids_start) - 1:
54
- code_end = self.audio_ids_start[idx + 1]
55
- else:
56
- code_end = self.audio_ids_concat.shape[-1]
57
-
58
- return self.audio_ids_concat[:, code_start:code_end]
59
-
60
- def get_audio_codes_labels(self, idx):
61
- if self.audio_label_ids_concat is None:
62
- return None
63
- code_start = self.audio_ids_start[idx]
64
- if idx < len(self.audio_ids_start) - 1:
65
- code_end = self.audio_ids_start[idx + 1]
66
- else:
67
- code_end = self.audio_ids_concat.shape[-1]
68
-
69
- return self.audio_label_ids_concat[:, code_start:code_end]
70
-
71
- def get_wv(self, idx):
72
- wv_start = self.audio_waveforms_start[idx]
73
- sr = self.audio_sample_rate[idx]
74
- if idx < len(self.audio_waveforms_start) - 1:
75
- wv_end = self.audio_waveforms_start[idx + 1]
76
- else:
77
- wv_end = self.audio_waveforms_concat.shape[-1]
78
- return self.audio_waveforms_concat[wv_start:wv_end], sr
79
-
80
- def cal_num_tokens(
81
- self,
82
- encode_whisper_embed: bool = True,
83
- encode_audio_in_tokens: bool = False,
84
- encode_audio_out_tokens: bool = True,
85
- audio_in_token_id: int = 128015,
86
- audio_out_token_id: int = 128016,
87
- ) -> int:
88
- # we first exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those positions with actual audio features and audio token ids
89
- # It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa)
90
- num_tokens = len(self.input_ids) - len(self.audio_ids_start)
91
-
92
- if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
93
- audio_lengths = torch.diff(self.audio_waveforms_start)
94
- if len(audio_lengths):
95
- # Sum before calling .item()
96
- num_tokens += (
97
- (
98
- np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
99
- ).sum()
100
- ).item()
101
- # add the last audio's token estimation
102
- num_tokens += (
103
- np.ceil(
104
- WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
105
- * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
106
- / self.audio_sample_rate[-1]
107
- )
108
- ).item()
109
-
110
- if self.audio_ids_concat.size(1) > 0:
111
- audio_io_ids = self.input_ids[
112
- (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
113
- ]
114
- audio_io_id_lengths = torch.concat(
115
- [
116
- torch.diff(self.audio_ids_start),
117
- torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
118
- ]
119
- )
120
- if encode_audio_in_tokens:
121
- num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()
122
-
123
- if encode_audio_out_tokens:
124
- num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()
125
-
126
- return int(num_tokens)
127
-
128
- @classmethod
129
- def merge(
130
- cls,
131
- samples: List["ChatMLDatasetSample"],
132
- eos_token_id: int,
133
- ignore_index: int,
134
- padding_size: Optional[int] = None,
135
- ) -> "ChatMLDatasetSample":
136
- """Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start.
137
-
138
- Args:
139
- samples (List[ChatMLDatasetSample]): List of samples to merge.
140
- eos_token_id (int): Tokens to be inserted into input_ids between samples.
141
- ignore_index (int): Default label for padding.
142
- padding_size (Optional[int]): If provided, pad the sequence with this many extra tokens.
143
-
144
- Returns:
145
- ChatMLDatasetSample: Merged and potentially padded sample.
146
- """
147
- if not samples:
148
- logger.fatal("The samples list is empty and cannot be merged.")
149
- raise ValueError("The samples list is empty and cannot be merged.")
150
-
151
- # Initialize empty lists for concatenation
152
- input_ids_list = []
153
- label_ids_list = []
154
- audio_ids_concat_list = []
155
- audio_ids_start_list = []
156
- audio_waveforms_concat_list = []
157
- audio_waveforms_start_list = []
158
- audio_sample_rate_list = []
159
- audio_speaker_indices_list = []
160
-
161
- # Track offsets
162
- audio_ids_offset = 0
163
- audio_waveforms_offset = 0
164
-
165
- for sample in samples:
166
- # Add input_ids and label_ids with padding
167
- if input_ids_list:
168
- input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
169
- label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
170
- input_ids_list.append(sample.input_ids)
171
- label_ids_list.append(sample.label_ids)
172
-
173
- # Add audio_ids_concat and handle empty audio ids
174
- if sample.audio_ids_concat.size(1) > 0:
175
- audio_ids_concat_list.append(sample.audio_ids_concat)
176
-
177
- # Offset and add audio_ids_start
178
- audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
179
- audio_ids_offset += sample.audio_ids_concat.size(
180
- 1
181
- ) # (num_codebooks, seq_len): Update offset by audio_seq_len
182
-
183
- # Add audio_waveforms_concat
184
- if sample.audio_waveforms_concat.size(0) > 0:
185
- # Check dimensions of the audio waveform to ensure consistency
186
- if (
187
- audio_waveforms_concat_list
188
- and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
189
- ):
190
- logger.warning(
191
- f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
192
- )
193
- continue
194
-
195
- audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
196
- audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
197
- audio_waveforms_offset += sample.audio_waveforms_concat.size(0)
198
-
199
- # Add audio_sample_rate and audio_speaker_indices
200
- audio_sample_rate_list.append(sample.audio_sample_rate)
201
-
202
- audio_speaker_indices_list.append(sample.audio_speaker_indices)
203
-
204
- # Concatenate all tensors
205
- input_ids = torch.cat(input_ids_list, dim=0)
206
- label_ids = torch.cat(label_ids_list, dim=0)
207
-
208
- # Apply padding if padding_size is specified
209
- if padding_size is not None and padding_size > 0:
210
- input_ids = torch.cat(
211
- [
212
- input_ids,
213
- torch.full((padding_size,), eos_token_id, dtype=torch.long),
214
- ],
215
- dim=0,
216
- )
217
- label_ids = torch.cat(
218
- [
219
- label_ids,
220
- torch.full((padding_size,), ignore_index, dtype=torch.long),
221
- ],
222
- dim=0,
223
- )
224
-
225
- # Safely concatenate audio tensors with proper error handling
226
- try:
227
- audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
228
- audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])
229
-
230
- # Check for dimensional consistency in audio waveforms
231
- if audio_waveforms_concat_list:
232
- dims = [t.dim() for t in audio_waveforms_concat_list]
233
- if not all(d == dims[0] for d in dims):
234
- # If dimensions don't match, log warning and filter out the problematic tensors
235
- logger.warning(
236
- f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
237
- )
238
- expected_dim = max(set(dims), key=dims.count) # Most common dimension
239
- audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]
240
-
241
- # Recalculate audio_waveforms_start with the filtered list
242
- if audio_waveforms_concat_list:
243
- audio_waveforms_offset = 0
244
- audio_waveforms_start_list = []
245
- for waveform in audio_waveforms_concat_list:
246
- audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
247
- audio_waveforms_offset += waveform.size(0)
248
-
249
- audio_waveforms_concat = (
250
- torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
251
- )
252
- audio_waveforms_start = (
253
- torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
254
- )
255
- audio_sample_rate = (
256
- torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
257
- )
258
- audio_speaker_indices = (
259
- torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
260
- )
261
-
262
- except RuntimeError as e:
263
- logger.error(f"Error during tensor concatenation: {str(e)}")
264
- logger.warning("Falling back to empty audio tensors")
265
- # Fall back to empty tensors
266
- audio_ids_concat = torch.tensor([[]])
267
- audio_ids_start = torch.tensor([])
268
- audio_waveforms_concat = torch.tensor([])
269
- audio_waveforms_start = torch.tensor([])
270
- audio_sample_rate = torch.tensor([])
271
- audio_speaker_indices = torch.tensor([])
272
-
273
- # Create the merged sample
274
- merged_sample = cls(
275
- input_ids=input_ids,
276
- label_ids=label_ids,
277
- audio_ids_concat=audio_ids_concat,
278
- audio_ids_start=audio_ids_start,
279
- audio_waveforms_concat=audio_waveforms_concat,
280
- audio_waveforms_start=audio_waveforms_start,
281
- audio_sample_rate=audio_sample_rate,
282
- audio_speaker_indices=audio_speaker_indices,
283
- )
284
-
285
- return merged_sample
286
-
287
-
288
- @dataclass
289
- class RankedChatMLDatasetSampleTuple:
290
- samples: List[ChatMLDatasetSample]
291
- scores: List[float]
292
-
293
- def max_score_sample(self) -> ChatMLDatasetSample:
294
- idx = self.scores.index(max(self.scores))
295
- self.samples[idx].reward = self.scores[idx]
296
- return self.samples[idx]
297
-
298
- def min_score_sample(self) -> ChatMLDatasetSample:
299
- idx = self.scores.index(min(self.scores))
300
- self.samples[idx].reward = self.scores[idx]
301
- return self.samples[idx]
302
-
303
-
304
- @dataclass
305
- class ChatMLDatasetStorageSample:
306
- input_tokens: torch.LongTensor
307
- label_tokens: torch.LongTensor
308
- audio_bytes_cache_dir_index: int
309
- audio_codes_cache_dir_index: int
310
- audio_bytes_indices: torch.LongTensor
311
- audio_codes_indices: torch.LongTensor
312
- speaker_indices: torch.LongTensor
313
- file_index: int
314
- original_sample_index: int
315
-
316
-
317
- # TODO(sxjscience): We need to revisit the logic about parsing speaker ids.
318
- # Currently, we assume that the speaker id is stored at the "misc" field in ChatMLSample.
319
- def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
320
- """Preprocess the ChatML sample to get the tokens for the text part.
321
-
322
- Args:
323
- sample (ChatMLSample): The ChatML sample to preprocess.
324
- tokenizer: The tokenizer to use for encoding the text.
325
-
326
- """
327
-
328
- try:
329
- if not isinstance(sample, ChatMLSample):
330
- # Handle all fields that could be NaN
331
- if "speaker" in sample and pd.isna(sample["speaker"]):
332
- sample["speaker"] = None
333
- if "start_index" in sample and pd.isna(sample["start_index"]):
334
- sample["start_index"] = None
335
- if "content" in sample and pd.isna(sample["content"]):
336
- sample["content"] = ""
337
-
338
- # Convert any other potential NaN values in nested structures
339
- def convert_nan_to_none(obj):
340
- import numpy as np
341
-
342
- if isinstance(obj, (pd.Series, np.ndarray)):
343
- return obj.tolist()
344
- elif pd.api.types.is_scalar(obj) and pd.isna(obj):
345
- return None
346
- elif isinstance(obj, dict):
347
- return {k: convert_nan_to_none(v) for k, v in obj.items()}
348
- elif isinstance(obj, (list, tuple)): # Fixed: Handle both list and tuple
349
- return [convert_nan_to_none(item) for item in obj]
350
- return obj
351
-
352
- # Clean the sample data
353
- clean_sample = convert_nan_to_none(sample)
354
-
355
- val_keys = []
356
- for field in fields(ChatMLSample):
357
- if field.name in clean_sample:
358
- val_keys.append(field.name)
359
- clean_sample = {k: clean_sample[k] for k in val_keys}
360
-
361
- try:
362
- sample = dacite.from_dict(
363
- data_class=ChatMLSample,
364
- data=clean_sample,
365
- config=dacite.Config(strict=True, check_types=True),
366
- )
367
- except Exception as e:
368
- print(f"Failed to convert to ChatMLSample: {e}")
369
- print(f"Clean sample: {json.dumps(clean_sample, indent=2)}")
370
- return None, None, None, None
371
-
372
- input_tokens = []
373
- label_tokens = []
374
- audio_contents = []
375
- speaker_id = None
376
- if sample.speaker is not None:
377
- speaker_id = sample.speaker
378
- elif sample.misc is not None:
379
- if "speaker" in sample.misc:
380
- speaker_id = sample.misc["speaker"]
381
-
382
- total_m = len(sample.messages)
383
- for turn_id, message in enumerate(sample.messages):
384
- role = message.role
385
- recipient = message.recipient
386
- content = message.content
387
- content_l = []
388
-
389
- if isinstance(content, str):
390
- content_l.append(TextContent(text=content))
391
- elif isinstance(content, TextContent):
392
- content_l.append(content)
393
- elif isinstance(content, AudioContent):
394
- content_l.append(content)
395
- elif isinstance(content, list):
396
- for ele in content:
397
- if isinstance(ele, str):
398
- content_l.append(TextContent(text=ele))
399
- else:
400
- content_l.append(ele)
401
- if turn_id == 0:
402
- prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
403
- else:
404
- prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
405
- eot_postfix = "<|eot_id|>"
406
- eom_postfix = "<|eom_id|>"
407
-
408
- prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
409
- input_tokens.extend(prefix_tokens)
410
- label_tokens.extend([-100 for _ in prefix_tokens])
411
-
412
- if recipient:
413
- assert role == "assistant", "Recipient is only available for assistant role."
414
- recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
415
- input_tokens.extend(recipient_tokens)
416
- label_tokens.extend(recipient_tokens)
417
-
418
- for content in content_l:
419
- if content.type == "text":
420
- text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
421
- input_tokens.extend(text_tokens)
422
- if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
423
- label_tokens.extend(text_tokens)
424
- else:
425
- label_tokens.extend([-100 for _ in text_tokens])
426
-
427
- elif content.type == "audio":
428
- # Generate the text-part of the audio tokens
429
- audio_contents.append(content)
430
- if role == "user" or role == "system":
431
- # Add the text tokens
432
- text_tokens = tokenizer.encode(
433
- f"<|audio_bos|><|AUDIO|><|audio_eos|>",
434
- add_special_tokens=False,
435
- )
436
- input_tokens.extend(text_tokens)
437
- label_tokens.extend([-100 for _ in text_tokens])
438
- elif role == "assistant":
439
- # Add the text tokens for audio-out part.
440
- text_tokens = tokenizer.encode(
441
- f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
442
- add_special_tokens=False,
443
- )
444
- input_tokens.extend(text_tokens)
445
- if sample.start_index is None or turn_id >= sample.start_index:
446
- label_tokens.extend(text_tokens)
447
- else:
448
- label_tokens.extend([-100 for _ in text_tokens])
449
- next_id = turn_id + 1
450
- if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
451
- postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
452
- input_tokens.extend(postfix_tokens)
453
- else:
454
- postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
455
- input_tokens.extend(postfix_tokens)
456
- if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
457
- label_tokens.extend(postfix_tokens)
458
- else:
459
- label_tokens.extend([-100 for _ in postfix_tokens])
460
-
461
- return input_tokens, label_tokens, audio_contents, speaker_id
462
-
463
- except Exception as e:
464
- print(f"Error in prepare_chatml_sample: {str(e)}")
465
- print(f"Sample data: {json.dumps(sample, indent=2)}")
466
- return None, None, None, None
467
-
468
-
469
- def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
470
- """Extract the generation prompt and reference answer from the input tokens.
471
-
472
- For example:
473
-
474
- Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
475
- What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
476
- <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'
477
-
478
- -->
479
-
480
- Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
481
- What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
482
- <|start_header_id|>assistant<|end_header_id|>\n\n',
483
- Reference = 'At first they went by quick, too quick to even get.'
484
-
485
- Args:
486
- input_tokens: The input tokens.
487
- audio_contents: The audio contents.
488
- tokenizer: The tokenizer to use for decoding the text.
489
-
490
- Returns:
491
- prompt_tokens: The tokens for the prompt.
492
- reference_answer: The reference answer.
493
- num_audios_in_reference: The number of audios in the reference answer.
494
-
495
- """
496
- input_text = tokenizer.decode(input_tokens)
497
- generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
498
- postfix = "<|eot_id|>"
499
- assert generation_prefix in input_text
500
- generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
501
- generation_prompt = input_text[:generation_prompt_end_loc]
502
- reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)]
503
- num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
504
- return (
505
- tokenizer.encode(generation_prompt, add_special_tokens=False),
506
- reference_answer,
507
- num_audios_in_reference,
508
- )
509
-
510
-
511
- def prepare_chatml_dataframe_single_process(df, tokenizer):
512
- """Prepare the ChatML DataFrame."""
513
- ret = []
514
- for _, row in df.iterrows():
515
- input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
516
- ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
517
- return ret
518
-
519
-
520
- def prepare_chatml_dataframe(df, tokenizer, num_process=16):
521
- if num_process is None:
522
- return prepare_chatml_dataframe_single_process(df, tokenizer)
523
- else:
524
- num_process = max(min(len(df) // 1000, num_process), 1)
525
- workloads = np.array_split(df, num_process)
526
- with mp.Pool(num_process) as pool:
527
- ret = pool.starmap(
528
- prepare_chatml_dataframe_single_process,
529
- [(workload, tokenizer) for workload in workloads],
530
- )
531
- return sum(ret, [])
532
-
533
-
534
- class DatasetInterface(ABC):
535
- @abstractmethod
536
- def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
537
- """Retrieve a dataset sample by index."""
538
- raise NotImplementedError
539
-
540
-
541
- class IterableDatasetInterface(ABC):
542
- @abstractmethod
543
- def __iter__(
544
- self,
545
- ) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
546
- """Retrieve a sample by iterating through the dataset."""
547
- raise NotImplementedError
548
-
549
-
550
- @dataclass
551
- class DatasetInfo:
552
- dataset_type: str
553
- group_type: Optional[str] = None
554
- mask_text: Optional[bool] = None # Whether to mask the text tokens for pretraining samples.
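For reference, the prompt/reference split performed by the removed extract_generation_prompt_from_input_tokens can be sketched on a plain decoded string; token ids and the Llama-3 tokenizer are elided here, and the literals below are illustrative only.

# Minimal sketch of the split described in the docstring above, done on a plain
# string (the removed helper operated on token ids plus a tokenizer).
generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
postfix = "<|eot_id|>"

decoded = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "What words do you hear from the provided audio?"
    "<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
    "At first they went by quick, too quick to even get.<|eot_id|>"
)

# The prompt ends right after the last assistant header; the reference runs
# up to the following <|eot_id|>.
end = decoded.rfind(generation_prefix) + len(generation_prefix)
prompt = decoded[:end]
reference = decoded[end:decoded.find(postfix, end)]
print(reference)  # At first they went by quick, too quick to even get.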
higgs_audio/model/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- from transformers import AutoConfig, AutoModel
2
-
3
- from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
4
- from .modeling_higgs_audio import HiggsAudioModel
5
-
6
-
7
- AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
8
- AutoConfig.register("higgs_audio", HiggsAudioConfig)
9
- AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
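With the registrations above in place, the Higgs-Audio classes could be resolved through the Auto classes. A minimal, hypothetical usage sketch; the checkpoint path is a placeholder, not a real repo id.

# Assumes the (now removed) package is installed so the registration runs on import.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("path/to/higgs_audio_checkpoint")  # resolves to HiggsAudioConfig
model = AutoModel.from_config(config)                                  # resolves to HiggsAudioModel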
higgs_audio/model/audio_head.py DELETED
@@ -1,139 +0,0 @@
1
- """Projector that maps hidden states from the LLM component to multimodal logits."""
2
-
3
- import torch
4
- from torch import nn
5
-
6
- from dataclasses import dataclass
7
- from typing import Optional, Tuple
8
-
9
- from .common import HiggsAudioPreTrainedModel
10
- from .configuration_higgs_audio import HiggsAudioConfig
11
-
12
-
13
- @dataclass
14
- class HiggsAudioDecoderLayerOutput:
15
- logits: torch.FloatTensor
16
- audio_logits: torch.FloatTensor
17
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
18
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
19
-
20
-
21
- class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
22
- """Projection layers that map hidden states from the LLM component to audio / text logits.
23
-
24
- We support two types of audio heads:
25
- - Basic Audio Head:
26
- Directly map the hidden states to audio logits for all the codebooks.
27
- """
28
-
29
- def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
30
- super().__init__(config)
31
- self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
32
- self.audio_lm_head = nn.Linear(
33
- config.text_config.hidden_size,
34
- config.audio_num_codebooks * (config.audio_codebook_size + 2),
35
- bias=False,
36
- )
37
-
38
- # Initialize weights and apply final processing
39
- self.post_init()
40
-
41
- def forward(
42
- self,
43
- hidden_states,
44
- audio_out_mask,
45
- label_audio_ids=None,
46
- attention_mask=None,
47
- position_ids=None,
48
- past_key_values=None,
49
- use_cache=None,
50
- output_attentions=None,
51
- output_hidden_states=None,
52
- output_audio_hidden_states=False,
53
- cache_position=None,
54
- ):
55
- """
56
- Args:
57
- hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
58
- Hidden states from the LLM component
59
- audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
60
- Mask for identifying the audio out tokens.
61
- label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
62
- Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used.
63
- attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
64
- Mask to avoid performing attention on padding token indices
65
- position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
66
- Position ids for the input tokens
67
-
68
- Returns:
69
- logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`):
70
- Logits for text tokens
71
- audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`):
72
- Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len`
73
- """
74
- logits = self.text_lm_head(hidden_states)
75
-
76
- all_hidden_states = () if output_hidden_states else None
77
- all_self_attns = () if output_attentions else None
78
- next_decoder_cache = None
79
-
80
- # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
81
- if self.config.audio_decoder_proj_num_layers > 0:
82
- # create position embeddings to be shared across the decoder layers
83
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
84
- for decoder_layer in self.transformer_layers:
85
- if output_hidden_states:
86
- all_hidden_states += (hidden_states,)
87
-
88
- if self.gradient_checkpointing and self.training:
89
- layer_outputs = self._gradient_checkpointing_func(
90
- decoder_layer.__call__,
91
- hidden_states,
92
- attention_mask,
93
- position_ids,
94
- past_key_values,
95
- output_attentions,
96
- use_cache,
97
- cache_position,
98
- position_embeddings,
99
- )
100
- else:
101
- layer_outputs = decoder_layer(
102
- hidden_states,
103
- attention_mask=attention_mask,
104
- position_ids=position_ids,
105
- past_key_value=past_key_values,
106
- output_attentions=output_attentions,
107
- use_cache=use_cache,
108
- cache_position=cache_position,
109
- position_embeddings=position_embeddings,
110
- )
111
- hidden_states = layer_outputs[0]
112
- hidden_states = self.norm(hidden_states)
113
-
114
- if output_hidden_states:
115
- all_hidden_states += (hidden_states,)
116
-
117
- if output_attentions:
118
- all_self_attns += (layer_outputs[1],)
119
-
120
- if use_cache:
121
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
122
-
123
- next_cache = next_decoder_cache if use_cache else None
124
-
125
- audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])
126
-
127
- if output_audio_hidden_states:
128
- audio_hidden_states = hidden_states[audio_out_mask]
129
- else:
130
- audio_hidden_states = None
131
-
132
- return (
133
- logits,
134
- audio_logits,
135
- all_self_attns,
136
- all_hidden_states,
137
- audio_hidden_states,
138
- next_cache,
139
- )
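For illustration, the basic audio head described in the docstring reduces to a single linear projection applied at the audio-out positions. A toy-sized sketch; the dimensions are illustrative, not the released model's.

import torch
from torch import nn

hidden_size, num_codebooks, codebook_size = 64, 12, 1024
audio_lm_head = nn.Linear(hidden_size, num_codebooks * (codebook_size + 2), bias=False)

hidden_states = torch.randn(2, 10, hidden_size)   # (batch, seq, hidden) from the LLM
audio_out_mask = torch.zeros(2, 10, dtype=torch.bool)
audio_out_mask[:, 5:] = True                      # pretend the last 5 positions are audio-out

audio_logits = audio_lm_head(hidden_states[audio_out_mask])             # (num_audio_out_tokens, C * (K + 2))
audio_logits = audio_logits.view(-1, num_codebooks, codebook_size + 2)  # per-codebook logits
print(audio_logits.shape)  # torch.Size([10, 12, 1026])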
higgs_audio/model/common.py DELETED
@@ -1,27 +0,0 @@
1
- from torch import nn
2
-
3
- from transformers.modeling_utils import PreTrainedModel
4
-
5
- from .configuration_higgs_audio import HiggsAudioConfig
6
-
7
-
8
- class HiggsAudioPreTrainedModel(PreTrainedModel):
9
- config_class = HiggsAudioConfig
10
- base_model_prefix = "model"
11
- supports_gradient_checkpointing = True
12
- _no_split_modules = []
13
- _skip_keys_device_placement = "past_key_values"
14
- _supports_flash_attn_2 = True
15
- _supports_sdpa = True
16
-
17
- def _init_weights(self, module):
18
- std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_encoder_config.init_std
19
-
20
- if isinstance(module, (nn.Linear, nn.Conv1d)):
21
- module.weight.data.normal_(mean=0.0, std=std)
22
- if module.bias is not None:
23
- module.bias.data.zero_()
24
- elif isinstance(module, nn.Embedding):
25
- module.weight.data.normal_(mean=0.0, std=std)
26
- if module.padding_idx is not None:
27
- module.weight.data[module.padding_idx].zero_()
higgs_audio/model/configuration_higgs_audio.py DELETED
@@ -1,235 +0,0 @@
1
- from transformers.configuration_utils import PretrainedConfig
2
- from transformers.models.auto import CONFIG_MAPPING
3
-
4
-
5
- class HiggsAudioEncoderConfig(PretrainedConfig):
6
- """Configuration of the Audio encoder in Higgs-Audio."""
7
-
8
- model_type = "higgs_audio_encoder"
9
-
10
- def __init__(
11
- self,
12
- num_mel_bins=128,
13
- encoder_layers=32,
14
- encoder_attention_heads=20,
15
- encoder_ffn_dim=5120,
16
- encoder_layerdrop=0.0,
17
- d_model=1280,
18
- dropout=0.0,
19
- attention_dropout=0.0,
20
- activation_function="gelu",
21
- activation_dropout=0.0,
22
- scale_embedding=False,
23
- init_std=0.02,
24
- max_source_positions=1500,
25
- pad_token_id=128001,
26
- **kwargs,
27
- ):
28
- super().__init__(**kwargs)
29
-
30
- self.num_mel_bins = num_mel_bins
31
- self.d_model = d_model
32
- self.encoder_layers = encoder_layers
33
- self.encoder_attention_heads = encoder_attention_heads
34
- self.encoder_ffn_dim = encoder_ffn_dim
35
- self.dropout = dropout
36
- self.attention_dropout = attention_dropout
37
- self.activation_function = activation_function
38
- self.activation_dropout = activation_dropout
39
- self.encoder_layerdrop = encoder_layerdrop
40
- self.num_hidden_layers = encoder_layers
41
- self.init_std = init_std
42
- self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
43
- self.max_source_positions = max_source_positions
44
- self.pad_token_id = pad_token_id
45
-
46
-
47
- class HiggsAudioConfig(PretrainedConfig):
48
- r"""
49
- This is the configuration class for the HiggsAudioModel.
50
-
51
- Args:
52
- text_config (`Union[AutoConfig, dict]`):
53
- The config object or dictionary of the text backbone.
54
- audio_encoder_config (`Union[AutoConfig, dict]`):
55
- The config object or dictionary of the whisper encoder.
56
- The audio encoder is bidirectional and is only used for audio understanding.
57
- audio_tokenizer_config
58
- The config object or dictionary of the audio tokenizer.
59
- audio_adapter_type
60
- The type of audio adapter to use. We support three types of adapters:
61
- - stack:
62
- We stack additional Transformer layers after the main LLM backbone for audio generation.
63
- - dual_ffn:
64
- For selected layers of the LLM backbone, we replace the text FFN with a dual-FFN architecture
65
- that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens.
66
- - dual_ffn_fast_forward:
67
- We pick a few layers in the LLM backbone to plug-in the audio FFN. For the remaining layers,
68
- the audio hidden states are directly fast-forwarded to the next layer.
69
- This reduces the computational cost for audio generation.
70
- audio_embed_avg (`bool`, *optional*, defaults to False):
71
- Whether to average the audio embeddings before sending them to the text attention layer.
72
- audio_ffn_hidden_size
73
- The hidden size of the audio feedforward network in dual-path FFN
74
- audio_ffn_intermediate_size
75
- The intermediate size of the audio feedforward network in dual-path FFN
76
- audio_dual_ffn_layers
77
- The layers in the LLM backbone to plug-in the dual FFN layer (mixture of audio FFN and text FFN).
78
- audio_decoder_proj_num_layers (`int`, *optional*, defaults to 0):
79
- The number of Transformer layers in the audio decoder projection module.
80
- use_delay_pattern (`bool`, *optional*, defaults to False):
81
- Whether to use delay pattern in the audio decoder.
82
- skip_audio_tower (`bool`, *optional*, defaults to False):
83
- Whether to skip the audio tower in the audio encoder.
84
- use_audio_out_embed_projector (`bool`, *optional*, defaults to False):
85
- Whether to use an embedding projector to map audio out embeddings.
86
- use_audio_out_self_attention (`bool`, *optional*, defaults to False):
87
- Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer.
88
- audio_num_codebooks (`int`, *optional*, defaults to 12):
89
- The number of codebooks in RVQGAN.
90
- audio_codebook_size (`int`, *optional*, defaults to 1024):
91
- The size of each codebook in RVQGAN.
92
- audio_stream_bos_id
93
- The id of the bos in the audio stream
94
- audio_stream_eos_id
95
- The id of the eos in the audio stream
96
- audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"):
97
- The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011,
98
- which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer.
99
- audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"):
100
- The special `<|audio_eos|>` token. We use 128012 as the default value,
101
- which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer.
102
- audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"):
103
- The special `<|audio_out_bos|>` token. We use 128013 as the default value,
104
- which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer.
105
- audio_token (`str`, *optional*, defaults to "<|AUDIO|>"):
106
- The special `<|AUDIO|>` token. We use 128015 as the default value,
107
- which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer.
108
- This token indicates that the location should be filled in with whisper features.
109
- audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"):
110
- The special `<|AUDIO_OUT|>` token. We use 128016 as the default value,
111
- which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer.
112
- This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer.
113
- """
114
-
115
- model_type = "higgs_audio"
116
- is_composition = True
117
-
118
- def __init__(
119
- self,
120
- text_config=None,
121
- audio_encoder_config=None,
122
- audio_tokenizer_config=None,
123
- audio_adapter_type="stack",
124
- audio_embed_avg=False,
125
- audio_ffn_hidden_size=4096,
126
- audio_ffn_intermediate_size=14336,
127
- audio_dual_ffn_layers=None,
128
- audio_decoder_proj_num_layers=0,
129
- encode_whisper_embed=True,
130
- encode_audio_in_tokens=False,
131
- use_delay_pattern=False,
132
- skip_audio_tower=False,
133
- use_audio_out_embed_projector=False,
134
- use_audio_out_self_attention=False,
135
- use_rq_transformer=False,
136
- rq_transformer_hidden_size=None,
137
- rq_transformer_intermediate_size=None,
138
- rq_transformer_num_attention_heads=None,
139
- rq_transformer_num_key_value_heads=None,
140
- rq_transformer_num_hidden_layers=3,
141
- audio_num_codebooks=12,
142
- audio_codebook_size=1024,
143
- audio_stream_bos_id=1024,
144
- audio_stream_eos_id=1025,
145
- audio_bos_token="<|audio_bos|>",
146
- audio_eos_token="<|audio_eos|>",
147
- audio_out_bos_token="<|audio_out_bos|>",
148
- audio_in_token="<|AUDIO|>",
149
- audio_out_token="<|AUDIO_OUT|>",
150
- audio_in_token_idx=128015,
151
- audio_out_token_idx=128016,
152
- pad_token_id=128001,
153
- audio_out_bos_token_id=128013,
154
- audio_eos_token_id=128012,
155
- **kwargs,
156
- ):
157
- if isinstance(audio_encoder_config, dict):
158
- audio_encoder_config["model_type"] = (
159
- audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder"
160
- )
161
- audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config)
162
- elif audio_encoder_config is None:
163
- audio_encoder_config = HiggsAudioEncoderConfig()
164
-
165
- if isinstance(text_config, dict):
166
- text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
167
- text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
168
- elif text_config is None:
169
- text_config = CONFIG_MAPPING["llama"]()
170
-
171
- assert audio_adapter_type in [
172
- "stack",
173
- "dual_ffn",
174
- "dual_ffn_fast_forward",
175
- ], f"Invalid audio adapter type: {audio_adapter_type}"
176
- if audio_adapter_type.startswith("dual_ffn"):
177
- assert audio_dual_ffn_layers is not None, (
178
- "audio_dual_ffn_layers must be specified when using dual_ffn adapter."
179
- )
180
- self.text_config = text_config
181
- self.audio_encoder_config = audio_encoder_config
182
- self.audio_tokenizer_config = audio_tokenizer_config
183
- self.audio_adapter_type = audio_adapter_type
184
- self.audio_embed_avg = audio_embed_avg
185
- self.audio_ffn_hidden_size = audio_ffn_hidden_size
186
- self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
187
- self.audio_dual_ffn_layers = audio_dual_ffn_layers
188
- self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
189
- self.encode_whisper_embed = encode_whisper_embed
190
- self.encode_audio_in_tokens = encode_audio_in_tokens
191
- self.use_delay_pattern = use_delay_pattern
192
- self.skip_audio_tower = skip_audio_tower
193
- self.use_audio_out_embed_projector = use_audio_out_embed_projector
194
- self.use_audio_out_self_attention = use_audio_out_self_attention
195
-
196
- self.use_rq_transformer = use_rq_transformer
197
-
198
- if self.use_rq_transformer:
199
- assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!"
200
- self.rq_transformer_hidden_size = rq_transformer_hidden_size
201
- self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
202
- self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
203
- self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
204
- self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers
205
-
206
- if use_rq_transformer:
207
- # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
208
- if self.rq_transformer_hidden_size is None:
209
- self.rq_transformer_hidden_size = text_config.hidden_size
210
- assert self.rq_transformer_hidden_size % 128 == 0
211
- if self.rq_transformer_intermediate_size is None:
212
- self.rq_transformer_intermediate_size = text_config.intermediate_size
213
- if self.rq_transformer_num_attention_heads is None:
214
- self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
215
- if self.rq_transformer_num_key_value_heads is None:
216
- self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
217
- assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
218
- assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0
219
-
220
- self.audio_num_codebooks = audio_num_codebooks
221
- self.audio_codebook_size = audio_codebook_size
222
- self.audio_bos_token = audio_bos_token
223
- self.audio_eos_token = audio_eos_token
224
- self.audio_out_bos_token = audio_out_bos_token
225
- self.audio_in_token = audio_in_token
226
- self.audio_out_token = audio_out_token
227
- self.audio_in_token_idx = audio_in_token_idx
228
- self.audio_out_token_idx = audio_out_token_idx
229
- self.audio_stream_bos_id = audio_stream_bos_id
230
- self.audio_stream_eos_id = audio_stream_eos_id
231
- self.audio_out_bos_token_id = audio_out_bos_token_id
232
- self.audio_eos_token_id = audio_eos_token_id
233
-
234
- super().__init__(**kwargs)
235
- self.pad_token_id = pad_token_id
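A construction sketch for the dual-FFN adapter described in the docstring, assuming the removed configuration module is still importable as it was before this commit; the layer indices are arbitrary examples, not the released model's values.

# Module removed in this commit; import shown for illustration only.
from higgs_audio.model.configuration_higgs_audio import HiggsAudioConfig

config = HiggsAudioConfig(
    audio_adapter_type="dual_ffn",
    audio_dual_ffn_layers=[8, 16, 24],  # required whenever a dual_ffn* adapter is chosen
    audio_num_codebooks=12,
    audio_codebook_size=1024,
)
print(config.audio_adapter_type, config.audio_dual_ffn_layers)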
higgs_audio/model/cuda_graph_runner.py DELETED
@@ -1,129 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from typing import Optional, List, Dict, Tuple, Union
4
- import gc
5
-
6
- from transformers.cache_utils import Cache
7
-
8
-
9
- _NUM_WARMUP_ITERS = 2
10
-
11
-
12
- class CUDAGraphRunner(nn.Module):
13
- def __init__(self, model):
14
- super().__init__()
15
- self.model = model
16
-
17
- self.input_buffers: Dict[str, torch.Tensor] = {}
18
- self.output_buffers: Dict[str, torch.Tensor] = {}
19
-
20
- self._graph: Optional[torch.cuda.CUDAGraph] = None
21
-
22
- @property
23
- def graph(self):
24
- assert self._graph is not None
25
- return self._graph
26
-
27
- def capture(
28
- self,
29
- hidden_states: torch.Tensor,
30
- causal_mask: torch.Tensor,
31
- position_ids: torch.Tensor,
32
- audio_discrete_codes_mask: torch.Tensor,
33
- cache_position: torch.Tensor,
34
- past_key_values: Union[Cache, List[torch.FloatTensor]],
35
- use_cache: bool,
36
- audio_attention_mask: torch.Tensor,
37
- fast_forward_attention_mask: torch.Tensor,
38
- output_attentions: bool,
39
- output_hidden_states: bool,
40
- is_decoding_audio_token: Optional[bool] = None,
41
- is_using_cuda_graph: Optional[bool] = False,
42
- stream: torch.cuda.Stream = None,
43
- memory_pool: Optional[Tuple[int, int]] = None,
44
- ):
45
- assert self._graph is None
46
- # Run warmup iterations
47
- for _ in range(_NUM_WARMUP_ITERS):
48
- self.model(
49
- hidden_states=hidden_states,
50
- causal_mask=causal_mask,
51
- position_ids=position_ids,
52
- audio_discrete_codes_mask=audio_discrete_codes_mask,
53
- cache_position=cache_position,
54
- past_key_values=past_key_values,
55
- use_cache=use_cache,
56
- audio_attention_mask=audio_attention_mask,
57
- fast_forward_attention_mask=fast_forward_attention_mask,
58
- output_attentions=output_attentions,
59
- output_hidden_states=output_hidden_states,
60
- is_decoding_audio_token=is_decoding_audio_token,
61
- is_using_cuda_graph=is_using_cuda_graph,
62
- )
63
-
64
- torch.cuda.synchronize()
65
-
66
- # Capture the graph
67
- self._graph = torch.cuda.CUDAGraph()
68
- with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
69
- out_hidden_states, all_hidden_states, all_self_attns = self.model(
70
- hidden_states=hidden_states,
71
- causal_mask=causal_mask,
72
- position_ids=position_ids,
73
- audio_discrete_codes_mask=audio_discrete_codes_mask,
74
- cache_position=cache_position,
75
- past_key_values=past_key_values,
76
- use_cache=use_cache,
77
- audio_attention_mask=audio_attention_mask,
78
- fast_forward_attention_mask=fast_forward_attention_mask,
79
- output_attentions=output_attentions,
80
- output_hidden_states=output_hidden_states,
81
- is_decoding_audio_token=is_decoding_audio_token,
82
- is_using_cuda_graph=is_using_cuda_graph,
83
- )
84
- # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
85
- # del outputs
86
- gc.collect()
87
- torch.cuda.synchronize()
88
-
89
- # Save input and output buffers
90
- self.input_buffers = {
91
- "hidden_states": hidden_states,
92
- "causal_mask": causal_mask,
93
- "position_ids": position_ids,
94
- "audio_discrete_codes_mask": audio_discrete_codes_mask,
95
- "cache_position": cache_position,
96
- "past_key_values": past_key_values,
97
- "audio_attention_mask": audio_attention_mask,
98
- "fast_forward_attention_mask": fast_forward_attention_mask,
99
- }
100
- self.output_buffers = {
101
- "hidden_states": out_hidden_states,
102
- "all_hidden_states": all_hidden_states,
103
- "all_self_attns": all_self_attns,
104
- }
105
-
106
- def forward(
107
- self,
108
- hidden_states: torch.Tensor,
109
- causal_mask: torch.Tensor,
110
- position_ids: torch.Tensor,
111
- audio_discrete_codes_mask: torch.Tensor,
112
- cache_position: torch.Tensor,
113
- audio_attention_mask: torch.Tensor,
114
- fast_forward_attention_mask: torch.Tensor,
115
- **kwargs,
116
- ) -> torch.Tensor:
117
- # Copy input tensors to buffers
118
- self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
119
- self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
120
- self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
121
- self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
122
- self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
123
- self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
124
- self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)
125
-
126
- # Run the captured graph
127
- self.graph.replay()
128
-
129
- return self.output_buffers["hidden_states"], None, None
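The capture/replay flow above follows the standard torch.cuda.CUDAGraph recipe: warm up on a side stream, capture into fixed buffers, then copy new inputs into those buffers and replay. A generic, self-contained sketch of that pattern (not the runner's exact API):

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_in = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream, as required before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model(static_in)

    static_in.copy_(torch.randn(8, 16, device="cuda"))  # new data into the same buffer
    graph.replay()                                       # static_out now holds the new result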
higgs_audio/model/custom_modules.py DELETED
@@ -1,155 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- class PartiallyFrozenEmbedding(nn.Module):
6
- """Split an existing `nn.Embedding` module that splits the embedding into:
7
-
8
- - A frozen embedding for indices [0..freeze_until_idx].
9
- - A trainable embedding for indices [freeze_until_idx+1..vocab_size-1].
10
-
11
- This should work with both Zero-2 and Zero-3 seamlessly
12
- """
13
-
14
- def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
15
- """
16
- :param original_embedding: An instance of nn.Embedding (the original embedding layer).
17
- :param freeze_until_idx: The index up to which the embedding is frozen (exclusive); the row at freeze_until_idx itself stays trainable.
18
- """
19
- super().__init__()
20
- self.freeze_until_idx = freeze_until_idx
21
- self.original_vocab_size = original_embedding.num_embeddings
22
- self.embedding_dim = original_embedding.embedding_dim
23
-
24
- # Split the original embedding into frozen and trainable parts
25
- self.embedding_frozen = nn.Embedding(
26
- freeze_until_idx,
27
- self.embedding_dim,
28
- dtype=original_embedding.weight.dtype,
29
- device=original_embedding.weight.device,
30
- )
31
- self.embedding_trainable = nn.Embedding(
32
- self.original_vocab_size - freeze_until_idx,
33
- self.embedding_dim,
34
- dtype=original_embedding.weight.dtype,
35
- device=original_embedding.weight.device,
36
- )
37
-
38
- # Copy weights from the original embedding into the frozen and trainable parts
39
- with torch.no_grad():
40
- self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
41
- self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])
42
-
43
- # Freeze the frozen embedding
44
- self.embedding_frozen.weight.requires_grad = False
45
-
46
- def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
47
- """
48
- Forward pass for the split embedding wrapper.
49
- :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1].
50
- """
51
- # Masks to separate frozen and trainable indices
52
- # (bsz, seq_len)
53
- mask_frozen = input_ids < self.freeze_until_idx
54
- mask_trainable = ~mask_frozen
55
-
56
- # Output tensor for embedding results
57
- batch_size, seq_len = input_ids.shape
58
- embeddings = torch.zeros(
59
- batch_size,
60
- seq_len,
61
- self.embedding_dim,
62
- device=input_ids.device,
63
- dtype=self.embedding_frozen.weight.dtype,
64
- )
65
-
66
- # Handle frozen embedding
67
- if mask_frozen.any():
68
- frozen_ids = input_ids[mask_frozen]
69
- frozen_emb = self.embedding_frozen(frozen_ids)
70
- embeddings[mask_frozen] = frozen_emb
71
-
72
- # Handle trainable embedding
73
- if mask_trainable.any():
74
- # Adjust trainable IDs to the local index space of the trainable embedding
75
- trainable_ids = input_ids[mask_trainable] - (self.freeze_until_idx)
76
- trainable_emb = self.embedding_trainable(trainable_ids)
77
- embeddings[mask_trainable] = trainable_emb
78
-
79
- return embeddings
80
-
81
- def to_unsplit(self) -> nn.Embedding:
82
- unsplit_embedding = nn.Embedding(
83
- self.original_vocab_size,
84
- self.embedding_dim,
85
- dtype=self.embedding_frozen.weight.dtype,
86
- device=self.embedding_frozen.weight.device,
87
- )
88
-
89
- with torch.no_grad():
90
- unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
91
- unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
92
-
93
- return unsplit_embedding
94
-
95
-
96
- class PartiallyFrozenLinear(nn.Module):
97
- """A wrapper around nn.Linear to partially freeze part of the weight matrix."""
98
-
99
- def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
100
- """
101
- :param original_linear: The original nn.Linear layer.
102
- :param freeze_until_idx: The row index up to which the weight matrix is frozen (exclusive).
103
- """
104
- super().__init__()
105
- assert original_linear.bias is None, "Currently only support linear module without bias"
106
-
107
- self.freeze_until_idx = freeze_until_idx
108
- self.input_dim = original_linear.in_features
109
- self.output_dim = original_linear.out_features
110
-
111
- # Create frozen and trainable linear layers
112
- self.linear_frozen = nn.Linear(
113
- self.input_dim,
114
- freeze_until_idx,
115
- bias=False,
116
- dtype=original_linear.weight.dtype,
117
- device=original_linear.weight.device,
118
- )
119
- self.linear_trainable = nn.Linear(
120
- self.input_dim,
121
- self.output_dim - freeze_until_idx,
122
- bias=False,
123
- dtype=original_linear.weight.dtype,
124
- device=original_linear.weight.device,
125
- )
126
-
127
- # Copy weights from the original linear layer
128
- with torch.no_grad():
129
- self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx])
130
- self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:])
131
-
132
- # Freeze the frozen linear layer
133
- self.linear_frozen.weight.requires_grad = False
134
-
135
- def forward(self, input_tensor):
136
- # input_tensor: (bsz, seq_len, hidden_state_dim)
137
- frozen_output = self.linear_frozen(input_tensor)
138
- trainable_output = self.linear_trainable(input_tensor)
139
- return torch.cat((frozen_output, trainable_output), dim=-1)
140
-
141
- def to_unsplit(self) -> nn.Linear:
142
- unsplit_linear = nn.Linear(
143
- self.input_dim,
144
- self.output_dim,
145
- bias=False,
146
- dtype=self.linear_frozen.weight.dtype,
147
- device=self.linear_frozen.weight.device,
148
- )
149
-
150
- # Copy weights from the frozen and trainable layers into the unsplit linear layer
151
- with torch.no_grad():
152
- unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
153
- unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)
154
-
155
- return unsplit_linear
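A usage sketch for PartiallyFrozenEmbedding, assuming the removed module is importable as it was before this commit; the vocabulary sizes are toy values chosen only to show the frozen/trainable split.

# Module removed in this commit; import shown for illustration only.
import torch
from torch import nn
from higgs_audio.model.custom_modules import PartiallyFrozenEmbedding

base = nn.Embedding(110, 32)                    # e.g. 100 original rows + 10 newly added rows
split = PartiallyFrozenEmbedding(base, freeze_until_idx=100)

out = split(torch.tensor([[1, 99, 100, 109]]))  # mixes frozen and trainable ids
print(out.shape)                                       # torch.Size([1, 4, 32])
print(split.embedding_frozen.weight.requires_grad)     # False
print(split.embedding_trainable.weight.requires_grad)  # True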
higgs_audio/model/modeling_higgs_audio.py DELETED
The diff for this file is too large to render. See raw diff
 
higgs_audio/model/utils.py DELETED
@@ -1,778 +0,0 @@
1
- import contextlib
2
- from contextlib import contextmanager
3
- from functools import wraps
4
- import torch
5
- from transformers.integrations import is_deepspeed_available
6
-
7
- if is_deepspeed_available():
8
- from deepspeed.utils import groups as deepspeed_groups
9
- from deepspeed.sequence.layer import _SeqAllToAll
10
- else:
11
- deepspeed_groups = None
12
- _SeqAllToAll = None
13
-
14
-
15
- def _ceil_to_nearest(n, round_to):
16
- return (n + round_to - 1) // round_to * round_to
17
-
18
-
19
- def count_parameters(model, trainable_only=True):
20
- if trainable_only:
21
- return sum(p.numel() for p in model.parameters() if p.requires_grad)
22
- else:
23
- return sum(p.numel() for p in model.parameters())
24
-
25
-
26
- # TODO(sxjscience) Consider to move the function to audio_processing/utils.py
27
- def build_delay_pattern_mask(
28
- input_ids: torch.LongTensor,
29
- bos_token_id: int,
30
- pad_token_id: int,
31
- ):
32
- """Implement the delay pattern proposed in "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284
33
-
34
- In the delay pattern, each codebook is offset from the previous codebook by
35
- one step. We insert a special delay token at the start of each delayed codebook row, and append a pad token once the sequence finishes.
36
-
37
- Take the example where there are 4 codebooks and audio sequence length=5. After shifting, the output should have length seq_len + num_codebooks - 1
38
-
39
- - [ *, *, *, *, *, P, P, P]
40
- - [ B, *, *, *, *, *, P, P]
41
- - [ B, B, *, *, *, *, *, P]
42
- - [ B, B, B, *, *, *, *, *]
43
-
44
- where B indicates the delay token id, P is the special padding token id, and `*` indicates an original audio token.
45
-
46
- Now let's consider the case where we have a sequence of audio tokens to condition on.
47
- The audio tokens were originally in the following non-delayed form:
48
-
49
- - [a, b]
50
- - [c, d]
51
- - [e, f]
52
- - [g, h]
53
-
54
- After conversion, we get the following delayed form:
55
- - [a, b, -1, -1, -1]
56
- - [B, c, d, -1, -1]
57
- - [B, B, e, f, -1]
58
- - [B, B, B, g, h]
59
-
60
- Note that we have a special token `-1` that indicates it should be replaced by a new token we see in the generation phase.
61
- In that case, we should override the `-1` tokens in auto-regressive generation.
62
-
63
- Args:
64
- input_ids (:obj:`torch.LongTensor`):
65
- The input ids of the prompt. It will have shape (bsz, num_codebooks, seq_len).
66
- bos_token_id (:obj:`int`):
67
- The id of the special delay token
68
- pad_token_id (:obj:`int`):
69
- The id of the padding token. Should be the same as eos_token_id.
70
-
71
- Returns:
72
- input_ids (:obj:`torch.LongTensor`):
73
- The transformed input ids with delay pattern applied. It will have shape (bsz, num_codebooks, seq_len + num_codebooks - 1).
74
- input_ids_with_gen_mask (:obj:`torch.LongTensor`):
75
- The transformed input ids with delay pattern applied. The -1 in the output indicates new tokens that should be generated.
76
-
77
- """
78
- bsz, num_codebooks, seq_len = input_ids.shape
79
-
80
- new_seq_len = seq_len + num_codebooks - 1
81
- input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device)
82
- bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0
83
- eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0
84
- input_ids_with_gen_mask[bos_mask] = bos_token_id
85
- input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
86
- input_ids = input_ids_with_gen_mask.clone()
87
- input_ids[eos_mask] = pad_token_id
88
- input_ids_with_gen_mask[eos_mask] = -1
89
- return input_ids, input_ids_with_gen_mask
90
-
91
-
92
- def revert_delay_pattern(data):
93
- """Convert samples encoded with delay pattern back to the original form.
94
-
95
- Args:
96
- data (:obj:`torch.Tensor`):
97
- The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).
98
-
99
- Returns:
100
- ret (:obj:`torch.Tensor`):
101
- Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
102
- """
103
- assert len(data.shape) == 2
104
- out_l = []
105
- num_codebooks = data.shape[0]
106
- for i in range(num_codebooks):
107
- out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)])
108
- return torch.cat(out_l, dim=0)
109
-
110
-
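A small round-trip check of the two helpers above, assuming they are importable from the pre-commit module; the ids 1024 and 1025 match the audio_stream_bos_id / audio_stream_eos_id defaults in the removed configuration.

# Module removed in this commit; import shown for illustration only.
import torch
from higgs_audio.model.utils import build_delay_pattern_mask, revert_delay_pattern

codes = torch.arange(8).reshape(1, 4, 2)  # (bsz=1, num_codebooks=4, seq_len=2)
delayed, with_gen_mask = build_delay_pattern_mask(codes, bos_token_id=1024, pad_token_id=1025)
print(delayed.shape)                       # (1, 4, 5) == seq_len + num_codebooks - 1
print(torch.equal(revert_delay_pattern(delayed[0]), codes[0]))  # True: the shift is invertible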
111
- def merge_input_ids_with_audio_features(
112
- audio_features_embed,
113
- audio_features_length,
114
- audio_in_embed,
115
- audio_in_ids_start,
116
- audio_out_embed,
117
- audio_out_ids_start,
118
- audio_in_token_idx,
119
- audio_out_token_idx,
120
- inputs_embeds,
121
- input_ids,
122
- attention_mask,
123
- label_ids,
124
- pad_token_id,
125
- ignore_index=-100,
126
- round_to=8,
127
- left_padding=True,
128
- ):
129
- """
130
- Merge input_ids with audio features into final embeddings.
131
-
132
- Args:
133
- audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
134
- Encoded vectors of all audios in the batch (obtained from the semantic encoder)
135
- audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
136
- The length of audio embeddings of each audio as stacked in `audio_features_embed`
137
- audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
138
- The embeddings of audio-in tokens
139
- audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
140
- The start index of the audio-in tokens for each audio
141
- audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
142
- The embeddings of audio-out tokens
143
- audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
144
- The start index of the audio-out tokens for each audio
145
- audio_in_token_idx
146
- The index of the audio-in token in the vocabulary
147
- audio_out_token_idx
148
- The index of the audio-out token in the vocabulary
149
- inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
150
- Token embeddings before merging with audio embeddings
151
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
152
- Input_ids of tokens, possibly filled with audio token
153
- attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
154
- Mask to avoid performing attention on padding token indices.
155
- label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
156
- labels need to be recalculated to support training (if provided)
157
- pad_token_id (`int`):
158
- The index of the pad token in the vocabulary
159
- ignore_index
160
- The index to ignore in the loss calculation
161
- round_to
162
- The number to round to for padding
163
- left_padding
164
- Whether to apply left padding
165
-
166
- Returns:
167
- final_embedding
168
- The final embeddings after merging audio embeddings with text embeddings.
169
- final_attention_mask
170
- The final attention mask after merging audio embeddings with text embeddings.
171
- final_labels
172
- The labels for the text stream
173
- position_ids
174
- Positional ids for the merged data
175
- final_input_ids
176
- The final input_ids after merging audio embeddings with text embeddings.
177
- final_audio_in_mask
178
- Mask for audio-in embeddings
179
- final_audio_in_discrete_codes_mask
180
- Mask for audio-in discrete tokens
181
- final_audio_out_mask
182
- Mask for audio-out embeddings
183
-
184
- Explanation:
185
- each audio has variable length embeddings, with length specified by
186
- - audio_features_length
187
- - audio_in_ids_start
188
- - audio_out_ids_start
189
-
190
- Task:
191
- - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
192
- - fill each <|AUDIO_OUT|> with the audio-out embeddings
193
-
194
- Example:
195
- <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
196
- <|AUDIO|>: Z (8 tokens)
197
-
198
- X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
199
- if right padding
200
- input_ids: [
201
- a b c d e f X g h i j k Y l m
202
- o p q r Z s t u v _ _ _ _ _ _
203
- ]
204
- input_ids should be: [
205
- a b c d e f X X X X X g h i j k Y Y Y l m
206
- o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
207
- ]
208
- labels should be: [
209
- a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
210
- o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
211
- ]
212
- elif left padding
213
- input_ids: [
214
- a b c d e f X g h i j k Y l m
215
- _ _ _ _ _ _ o p q r Z s t u v
216
- ]
217
- input_ids should be: [
218
- a b c d e f X X X X X g h i j k Y Y Y l m
219
- _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
220
- ]
221
- labels should be: [
222
- a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
223
- _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
224
- ]
225
-
226
- """
227
- if label_ids is None:
228
- skip_labels = True
229
- else:
230
- skip_labels = False
231
- if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
232
- audio_features_embed = None
233
- if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
234
- audio_in_embed = None
235
- if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
236
- audio_out_embed = None
237
-
238
- batch_size, sequence_length, embed_dim = inputs_embeds.shape
239
-
240
- target_device = inputs_embeds.device
241
- if left_padding is None:
242
- left_padding = torch.any(attention_mask[:, 0] == 0)
243
-
244
- audio_in_token_mask = input_ids == audio_in_token_idx
245
- audio_out_token_mask = input_ids == audio_out_token_idx
246
- text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)
247
-
248
- # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
249
- token_placeholder_num = torch.ones_like(input_ids)
250
-
251
- if audio_features_embed is not None:
252
- num_audios, max_audio_tokens, _ = audio_features_embed.shape
253
- audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
254
- audio_features_length.device
255
- ) < audio_features_length.unsqueeze(1)
256
- masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
257
- token_placeholder_num[audio_in_token_mask] = audio_features_length.long()
258
-
259
- if audio_in_embed is not None:
260
- audio_in_codes_length = torch.concat(
261
- [
262
- audio_in_ids_start[1:] - audio_in_ids_start[:-1],
263
- torch.tensor(
264
- [audio_in_embed.shape[0] - audio_in_ids_start[-1]],
265
- device=audio_in_ids_start.device,
266
- dtype=torch.long,
267
- ),
268
- ],
269
- dim=0,
270
- )
271
- if audio_features_embed is not None:
272
- token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
273
- else:
274
- token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()
275
-
276
- if audio_out_embed is not None:
277
- audio_out_codes_length = torch.concat(
278
- [
279
- audio_out_ids_start[1:] - audio_out_ids_start[:-1],
280
- torch.tensor(
281
- [audio_out_embed.shape[0] - audio_out_ids_start[-1]],
282
- device=audio_out_ids_start.device,
283
- dtype=torch.long,
284
- ),
285
- ],
286
- dim=0,
287
- )
288
- token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()
289
-
290
- new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
291
- max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
292
- nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
293
-
294
- if left_padding:
295
- new_token_positions += nb_audio_pad[:, None] # offset for left padding
296
-
297
- # 2. Create the full embedding, already padded to the maximum position
298
- final_embedding = torch.zeros(
299
- (batch_size, max_token_num, embed_dim),
300
- dtype=inputs_embeds.dtype,
301
- device=inputs_embeds.device,
302
- )
303
- final_attention_mask = torch.zeros(
304
- (batch_size, max_token_num),
305
- dtype=attention_mask.dtype,
306
- device=inputs_embeds.device,
307
- )
308
- final_input_ids = torch.full(
309
- (batch_size, max_token_num),
310
- pad_token_id,
311
- dtype=input_ids.dtype,
312
- device=inputs_embeds.device,
313
- )
314
- if skip_labels:
315
- final_labels = None
316
- else:
317
- final_labels = torch.full(
318
- (batch_size, max_token_num),
319
- ignore_index,
320
- dtype=label_ids.dtype,
321
- device=inputs_embeds.device,
322
- )
323
-
324
- final_audio_in_mask = torch.full(
325
- (batch_size, max_token_num),
326
- False,
327
- dtype=torch.bool,
328
- device=inputs_embeds.device,
329
- )
330
- final_audio_in_discrete_codes_mask = torch.full(
331
- (batch_size, max_token_num),
332
- False,
333
- dtype=torch.bool,
334
- device=inputs_embeds.device,
335
- )
336
- final_audio_out_mask = torch.full(
337
- (batch_size, max_token_num),
338
- False,
339
- dtype=torch.bool,
340
- device=inputs_embeds.device,
341
- )
342
- # 3. Get the audio-in token positions and audio-out token positions
343
- batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
344
- audio_in_batch_id = batch_id[audio_in_token_mask] # Shape (num_audio_in,)
345
- audio_out_batch_id = batch_id[audio_out_token_mask] # Shape (num_audio_out,)
346
- audio_features_token_ends = new_token_positions[audio_in_token_mask] # Shape (num_audio_in,)
347
- audio_out_embed_ends = new_token_positions[audio_out_token_mask] # Shape (num_audio_out,)
348
-
349
- if audio_in_embed is not None:
350
- # Fill in the audio-in embeddings
351
- seq_indices = (
352
- torch.arange(max_token_num, device=target_device)
353
- .unsqueeze(0)
354
- .expand(audio_in_ids_start.shape[0], max_token_num)
355
- )
356
- audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
357
- batch_indices, col_indices = torch.where(
358
- (seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
359
- & (seq_indices <= audio_features_token_ends.unsqueeze(1))
360
- )
361
- batch_indices = audio_in_batch_id[batch_indices]
362
- final_embedding[batch_indices, col_indices] = audio_in_embed
363
- final_input_ids[batch_indices, col_indices] = audio_in_token_idx
364
- if not skip_labels:
365
- final_labels[batch_indices, col_indices] = ignore_index
366
- final_audio_in_mask[batch_indices, col_indices] = True
367
- final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
368
- audio_features_token_ends = audio_features_token_ends - audio_in_codes_length
369
-
370
- if audio_features_embed is not None:
371
- # Fill in the audio features
372
- seq_indices = (
373
- torch.arange(max_token_num, device=target_device)
374
- .unsqueeze(0)
375
- .expand(audio_features_embed.shape[0], max_token_num)
376
- )
377
- audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
378
- batch_indices, col_indices = torch.where(
379
- (seq_indices >= audio_features_token_starts.unsqueeze(1))
380
- & (seq_indices <= audio_features_token_ends.unsqueeze(1))
381
- )
382
- batch_indices = audio_in_batch_id[batch_indices]
383
- final_embedding[batch_indices, col_indices] = masked_audio_in_features
384
- final_input_ids[batch_indices, col_indices] = audio_in_token_idx
385
- if not skip_labels:
386
- final_labels[batch_indices, col_indices] = ignore_index
387
- final_audio_in_mask[batch_indices, col_indices] = True
388
-
389
- if audio_out_embed is not None:
390
- # Fill in the audio-out embeddings
391
- seq_indices = (
392
- torch.arange(max_token_num, device=target_device)
393
- .unsqueeze(0)
394
- .expand(audio_out_ids_start.shape[0], max_token_num)
395
- )
396
- audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
397
- batch_indices, col_indices = torch.where(
398
- (seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
399
- & (seq_indices <= audio_out_embed_ends.unsqueeze(1))
400
- )
401
- batch_indices = audio_out_batch_id[batch_indices]
402
- final_embedding[batch_indices, col_indices] = audio_out_embed
403
- final_input_ids[batch_indices, col_indices] = audio_out_token_idx
404
- if not skip_labels:
405
- final_labels[batch_indices, col_indices] = ignore_index
406
- final_audio_out_mask[batch_indices, col_indices] = True
407
-
408
- # Fill in the original text embeddings and labels
409
- batch_indices, non_audio_indices = torch.where(text_token_mask)
410
- text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
411
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
412
- if not skip_labels:
413
- final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
414
- final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
415
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
416
- final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask
417
-
418
- # Trim the tensor if there are redundant padding tokens
419
- if left_padding:
420
- first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
421
- first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
422
- if first_non_zero_loc > 0:
423
- final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
424
- final_embedding = final_embedding[:, first_non_zero_loc:]
425
- if not skip_labels:
426
- final_labels = final_labels[:, first_non_zero_loc:]
427
- final_input_ids = final_input_ids[:, first_non_zero_loc:]
428
- final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
429
- final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
430
- final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
431
- else:
432
- # We have done right padding, so we need to trim the mask
433
- last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
434
- last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
435
- if last_non_zero_loc < max_token_num:
436
- final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
437
- final_embedding = final_embedding[:, :last_non_zero_loc]
438
- if not skip_labels:
439
- final_labels = final_labels[:, :last_non_zero_loc]
440
- final_input_ids = final_input_ids[:, :last_non_zero_loc]
441
- final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
442
- final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
443
- final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]
444
-
445
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
446
- return (
447
- final_embedding,
448
- final_attention_mask,
449
- final_labels,
450
- position_ids,
451
- final_input_ids,
452
- final_audio_in_mask,
453
- final_audio_in_discrete_codes_mask,
454
- final_audio_out_mask,
455
- )
456
-
457
-
458
- def is_deepspeed_ulysses_enabled():
459
- """Check if sequence parallelism is enabled."""
460
- if deepspeed_groups is None:
461
- return False
462
-
463
- return deepspeed_groups._get_sequence_parallel_world_size() > 1
464
-
465
-
466
- def support_deepspeed_ulysses(module):
467
- """A decorator around Pytorch module. It is needed for the module that needs access to sequence parallel info."""
468
- module._sp_size = None
469
- module._sp_rank = None
470
- module._sp_group = None
471
-
472
- @property
473
- def sp_size(self):
474
- if self._sp_size is None:
475
- self._sp_size = 1
476
- if is_deepspeed_ulysses_enabled():
477
- self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
478
- return self._sp_size
479
-
480
- @property
481
- def sp_rank(self):
482
- if self._sp_rank is None:
483
- self._sp_rank = 0
484
- if is_deepspeed_ulysses_enabled():
485
- self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
486
- return self._sp_rank
487
-
488
- @property
489
- def sp_group(self):
490
- if self._sp_group is None and is_deepspeed_ulysses_enabled():
491
- self._sp_group = deepspeed_groups._get_sequence_parallel_group()
492
- return self._sp_group
493
-
494
- module.sp_size = sp_size
495
- module.sp_rank = sp_rank
496
- module.sp_group = sp_group
497
-
498
- return module
499
-
500
-
501
- def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
502
- """Perform all-to-all before and after the attention function."""
503
-
504
- def attention_decorator(attn_func=None):
505
- def wrapped(*args, **kwargs):
506
- if is_deepspeed_ulysses_enabled():
507
- sp_group = deepspeed_groups._get_sequence_parallel_group()
508
- scatter_idx = head_dim # Scatter on num_heads dimension
509
- gather_idx = seq_dim # Gather on seq_len dimension
510
- batch_dim_idx = 0
511
- args = list(args)
512
- args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
513
- args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
514
- args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
515
- args = tuple(args)
516
-
517
- attn_output = attn_func(*args, **kwargs)
518
-
519
- if is_deepspeed_ulysses_enabled():
520
- scatter_idx = seq_dim # Scatter back on seq_len dimension
521
- gather_idx = head_dim # Gather on num_heads dimension
522
- batch_dim_idx = 0
523
- attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)
524
-
525
- return attn_output
526
-
527
- return wrapped
528
-
529
- return attention_decorator
530
-
531
-
532
- def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
533
- """Slice the corresponding cos and sin chunks for rope."""
534
-
535
- def rope_decorator(rope_func=None):
536
- def wrapped(*args, **kwargs):
537
- if is_deepspeed_ulysses_enabled():
538
- sp_rank = deepspeed_groups._get_sequence_parallel_rank()
539
- args = list(args)
540
- seq_chunk_size = args[0].size(state_seq_dim)
541
- args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
542
- args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
543
- args = tuple(args)
544
-
545
- return rope_func(*args, **kwargs)
546
-
547
- return wrapped
548
-
549
- return rope_decorator
550
-
551
-
552
- def _gather_tensors(input_, group=None):
553
- """Gather tensors and concatenate them along a dimension."""
554
- input_ = input_.contiguous()
555
- world_size = torch.distributed.get_world_size(group)
556
- if world_size == 1:
557
- return input_
558
- tensor_shapes = [
559
- torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size)
560
- ]
561
- input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device)
562
- torch.distributed.all_gather(tensor_shapes, input_size, group=group)
563
- gathered_buffers = [
564
- torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size)
565
- ]
566
- torch.distributed.all_gather(gathered_buffers, input_, group=group)
567
- return gathered_buffers
568
-
569
-
570
- def _scatter_tensors(input_, group=None):
571
- """Scatter tensors."""
572
- world_size = torch.distributed.get_world_size(group)
573
- if world_size == 1:
574
- return input_
575
- rank = torch.distributed.get_rank(group)
576
- return input_[rank]
577
-
578
-
579
- class _GatherTensors(torch.autograd.Function):
580
- """All gather tensors among the ranks."""
581
-
582
- @staticmethod
583
- def symbolic(graph, input_, group):
584
- return _gather_tensors(input_, group)
585
-
586
- @staticmethod
587
- def forward(ctx, input_, group):
588
- ctx.group = group
589
- return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged)
590
-
591
- @staticmethod
592
- def backward(ctx, grad_output):
593
- return _scatter_tensors(grad_output, ctx.group), None
594
-
595
-
596
- def all_gather_tensors(input_, size=None, dim=0, group=None):
597
- if torch.distributed.get_world_size(group) == 1:
598
- # no sequence parallelism
599
- return input_
600
- gathered_tensors = _GatherTensors.apply(input_, group)
601
-
602
- if size:
603
- split_gathered_tensors = []
604
- for s, gathered_tensor in zip(size, gathered_tensors):
605
- split_gathered_tensor = torch.split(gathered_tensor, s.tolist())
606
- split_gathered_tensors.append(split_gathered_tensor)
607
-
608
- gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x]
609
-
610
- return torch.cat(gathered_tensors, dim).contiguous()
611
-
612
-
613
- def get_sequence_data_parallel_world_size():
614
- return torch.distributed.get_world_size()
615
-
616
-
617
- def get_sequence_data_parallel_rank():
618
- return torch.distributed.get_rank()
619
-
620
-
621
- def get_sequence_data_parallel_group():
622
- return torch.distributed.group.WORLD
623
-
624
-
625
- if is_deepspeed_available():
626
- deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
627
- deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
628
- deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group
629
-
630
-
631
- def _gather_tokens(input_, dim=0, group=None):
632
- """Gather tensors and concatenate them along a dimension"""
633
- input_ = input_.contiguous()
634
- world_size = torch.distributed.get_world_size(group)
635
- if world_size == 1:
636
- return input_
637
-
638
- gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
639
- torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group)
640
- if dim == 0:
641
- shape = list(input_.size())
642
- shape[0] = shape[0] * world_size
643
- output = gather_buffer.view(shape)
644
- else:
645
- tensor_list = [
646
- gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
647
- ]
648
- # Note: torch.cat already creates a contiguous tensor.
649
- output = torch.cat(tensor_list, dim=dim).contiguous()
650
-
651
- return output
652
-
653
-
654
- def _drop_tokens(input_, dim=0, group=None):
655
- """Divide a tensor among the sequence parallel ranks"""
656
- world_size = torch.distributed.get_world_size(group)
657
- if world_size == 1:
658
- return input_
659
- this_rank = torch.distributed.get_rank(group)
660
- assert input_.shape[dim] % world_size == 0, (
661
- f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
662
- )
663
- chunk_size = input_.shape[dim] // world_size
664
-
665
- return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size)
666
-
667
-
668
- class _DropTokens(torch.autograd.Function):
669
- "Divide tokens equally among the sequence parallel ranks"
670
-
671
- @staticmethod
672
- def symbolic(graph, input_, dim, group, grad_scale):
673
- return _drop_tokens(input_, dim, group)
674
-
675
- @staticmethod
676
- def forward(ctx, input_, dim, group, grad_scale):
677
- ctx.dim = dim
678
- ctx.group = group
679
- ctx.grad_scale = grad_scale
680
- return _drop_tokens(input_, dim, group)
681
-
682
- @staticmethod
683
- def backward(ctx, grad_output):
684
- grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
685
- if ctx.grad_scale != 1:
686
- grad_input /= ctx.grad_scale
687
- return grad_input, None, None, None
688
-
689
-
690
- class _GatherTokens(torch.autograd.Function):
691
- "Gather tokens among the sequence parallel ranks"
692
-
693
- @staticmethod
694
- def symbolic(graph, input_, dim, group, grad_scale):
695
- return _gather_tokens(input_, dim, group)
696
-
697
- @staticmethod
698
- def forward(ctx, input_, dim, group, grad_scale):
699
- ctx.dim = dim
700
- ctx.group = group
701
- ctx.grad_scale = grad_scale
702
- return _gather_tokens(input_, dim, group)
703
-
704
- @staticmethod
705
- def backward(ctx, grad_output):
706
- grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
707
- if ctx.grad_scale != 1:
708
- grad_input *= ctx.grad_scale
709
- return grad_input, None, None, None
710
-
711
-
712
- def drop_tokens(input_, dim=0, group=None, grad_scale=1):
713
- if torch.distributed.get_world_size(group) == 1:
714
- # no sequence parallelism
715
- return input_
716
- return _DropTokens.apply(input_, dim, group, grad_scale)
717
-
718
-
719
- def gather_tokens(input_, dim=0, group=None, grad_scale=1):
720
- if torch.distributed.get_world_size(group) == 1:
721
- # no sequence parallelism
722
- return input_
723
- return _GatherTokens.apply(input_, dim, group, grad_scale)
724
-
725
-
726
- def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
727
- """
728
- Slice the inputs to create chuncks per the sequence parallel rank. This is used for the context parallel training.
729
-
730
- Args:
731
- sp_size (`int`):
732
- Sequence parallel size.
733
- sp_rank (`int`):
734
- Sequence parallel rank for the current process.
735
- dim (`int`):
736
- The dimension to slice
737
- """
738
- if sp_size == 1:
739
- return args[0] if len(args) == 1 else args
740
-
741
- seq_length = args[0].size(dim)
742
- for arg in args[1:]:
743
- assert arg.size(dim) == seq_length, (
744
- f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
745
- )
746
- assert seq_length % sp_size == 0, (
747
- f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
748
- )
749
-
750
- sub_seq_length = seq_length // sp_size
751
- sub_seq_start = sp_rank * sub_seq_length
752
-
753
- output = []
754
- for ind in args:
755
- ind = torch.narrow(ind, dim, sub_seq_start, sub_seq_length)
756
- output.append(ind)
757
-
758
- return tuple(output) if len(output) > 1 else output[0]
759
-
760
-
761
- @contextmanager
762
- def disable_deepspeed_ulysses():
763
- """Disable deepspeed ulysses (sequence parallelism) if it is enabled"""
764
- if is_deepspeed_ulysses_enabled():
765
- _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size
766
-
767
- def _get_sequence_parallel_world_size():
768
- return 1
769
-
770
- deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
771
- try:
772
- yield
773
- finally:
774
- deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size
775
- else:
776
- context = contextlib.nullcontext
777
- with context():
778
- yield
 
higgs_audio/serve/serve_engine.py DELETED
@@ -1,474 +0,0 @@
1
- import asyncio
2
- import base64
3
- import torch
4
- import numpy as np
5
- from io import BytesIO
6
- from dataclasses import dataclass, field
7
- from typing import List, Optional, Union
8
- from copy import deepcopy
9
- from transformers import AutoTokenizer, AutoProcessor
10
- from transformers.cache_utils import StaticCache
11
- from transformers.generation.streamers import BaseStreamer
12
- from transformers.generation.stopping_criteria import StoppingCriteria
13
- from dataclasses import asdict
14
- from loguru import logger
15
- import threading
16
- import librosa
17
-
18
-
19
- from ..dataset.chatml_dataset import (
20
- ChatMLSample,
21
- ChatMLDatasetSample,
22
- prepare_chatml_sample,
23
- )
24
- from ..model import HiggsAudioModel
25
- from ..model.utils import revert_delay_pattern
26
- from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
27
- from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
28
-
29
-
30
- def normalize_chinese_punctuation(text):
31
- """
32
- Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
33
- """
34
- # Mapping of Chinese punctuation to English punctuation
35
- chinese_to_english_punct = {
36
- ",": ",", # comma
37
- "。": ".", # period
38
- ":": ":", # colon
39
- ";": ";", # semicolon
40
- "?": "?", # question mark
41
- "!": "!", # exclamation mark
42
- "(": "(", # left parenthesis
43
- ")": ")", # right parenthesis
44
- "【": "[", # left square bracket
45
- "】": "]", # right square bracket
46
- "《": "<", # left angle quote
47
- "》": ">", # right angle quote
48
- "“": '"', # left double quotation
49
- "”": '"', # right double quotation
50
- "‘": "'", # left single quotation
51
- "’": "'", # right single quotation
52
- "、": ",", # enumeration comma
53
- "—": "-", # em dash
54
- "…": "...", # ellipsis
55
- "·": ".", # middle dot
56
- "「": '"', # left corner bracket
57
- "」": '"', # right corner bracket
58
- "『": '"', # left double corner bracket
59
- "』": '"', # right double corner bracket
60
- }
61
-
62
- # Replace each Chinese punctuation with its English counterpart
63
- for zh_punct, en_punct in chinese_to_english_punct.items():
64
- text = text.replace(zh_punct, en_punct)
65
-
66
- return text
67
-
68
-
69
- @dataclass
70
- class HiggsAudioStreamerDelta:
71
- """Represents a chunk of generated content, either text or audio tokens."""
72
-
73
- text: Optional[str] = None
74
- text_tokens: Optional[torch.Tensor] = None
75
- audio_tokens: Optional[torch.Tensor] = None
76
- finish_reason: Optional[str] = None
77
-
78
-
79
- class AsyncHiggsAudioStreamer(BaseStreamer):
80
- """
81
- Async streamer that handles both text and audio token generation from Higgs-Audio model.
82
- Stores chunks in a queue to be consumed by downstream applications.
83
-
84
- Parameters:
85
- tokenizer (`AutoTokenizer`):
86
- The tokenizer used to decode text tokens.
87
- skip_prompt (`bool`, *optional*, defaults to `False`):
88
- Whether to skip the prompt tokens in generation.
89
- timeout (`float`, *optional*):
90
- The timeout for the queue. If `None`, the queue will block indefinitely.
91
- decode_kwargs (`dict`, *optional*):
92
- Additional keyword arguments to pass to the tokenizer's `decode` method.
93
-
94
- Examples:
95
- ```python
96
- >>> from transformers import AutoTokenizer
97
- >>> from threading import Thread
98
- >>> import asyncio
99
-
100
- >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
101
- >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
102
- >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")
103
-
104
- >>> async def main():
105
- ... streamer = AsyncHiggsAudioStreamer(tokenizer)
106
- ... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
107
- ... thread = Thread(target=model.generate, kwargs=generation_kwargs)
108
- ... thread.start()
109
- ...
110
- ... async for delta in streamer:
111
- ... if delta.text is not None:
112
- ... print("Text:", delta.text)
113
- ... if delta.audio_tokens is not None:
114
- ... print("Audio tokens shape:", delta.audio_tokens.shape)
115
- >>> asyncio.run(main())
116
- ```
117
- """
118
-
119
- def __init__(
120
- self,
121
- tokenizer: "AutoTokenizer",
122
- skip_prompt: bool = False,
123
- timeout: Optional[float] = None,
124
- audio_num_codebooks: int = 1,
125
- **decode_kwargs,
126
- ):
127
- self.tokenizer = tokenizer
128
- self.skip_prompt = skip_prompt
129
- self.timeout = timeout
130
- self.decode_kwargs = decode_kwargs
131
- self.audio_num_codebooks = audio_num_codebooks
132
-
133
- # Queue to store generated chunks
134
- self.queue = asyncio.Queue()
135
- self.stop_signal = None
136
-
137
- # Get running event loop
138
- self.loop = asyncio.get_running_loop()
139
- self.has_asyncio_timeout = hasattr(asyncio, "timeout")
140
-
141
- # State tracking
142
- self.next_tokens_are_prompt = True
143
-
144
- def put(self, value: torch.Tensor):
145
- """
146
- Receives tokens and processes them as either text or audio tokens.
147
- For text tokens, decodes and caches them until complete words are formed.
148
- For audio tokens, directly queues them.
149
- """
150
- if value.shape[0] > 1 and not self.next_tokens_are_prompt:
151
- # This is likely audio tokens (shape: [audio_num_codebooks])
152
- assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch"
153
- delta = HiggsAudioStreamerDelta(audio_tokens=value)
154
- self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
155
- return
156
-
157
- # Skip prompt tokens if configured
158
- if self.skip_prompt and self.next_tokens_are_prompt:
159
- self.next_tokens_are_prompt = False
160
- return
161
-
162
- # Process as text tokens
163
- if len(value.shape) > 1:
164
- value = value[0]
165
-
166
- text = self.tokenizer.decode(value, **self.decode_kwargs)
167
- delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
168
- self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
169
-
170
- def end(self):
171
- """Flushes any remaining text tokens and signals the end of generation."""
172
- self.next_tokens_are_prompt = True
173
- self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)
174
-
175
- def __aiter__(self):
176
- return self
177
-
178
- async def __anext__(self):
179
- try:
180
- if self.has_asyncio_timeout:
181
- async with asyncio.timeout(self.timeout):
182
- value = await self.queue.get()
183
- else:
184
- value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
185
- except asyncio.TimeoutError:
186
- raise TimeoutError()
187
- else:
188
- if value == self.stop_signal:
189
- raise StopAsyncIteration()
190
- else:
191
- return value
192
-
193
-
194
- class AsyncStoppingCriteria(StoppingCriteria):
195
- """
196
- Stopping criteria that checks for stop signal from a threading event.
197
-
198
- Args:
199
- stop_signal (threading.Event): Event that will receive stop signals
200
- """
201
-
202
- def __init__(self, stop_signal: threading.Event):
203
- self.stop_signal = stop_signal
204
-
205
- def __call__(self, input_ids, scores, **kwargs) -> bool:
206
- if self.stop_signal.is_set():
207
- logger.info(f"Stop signal received. Can be caused by client disconnection.")
208
- return True
209
- return False
210
-
211
-
212
- @dataclass
213
- class HiggsAudioResponse:
214
- audio: Optional[np.ndarray] = None
215
- generated_audio_tokens: Optional[np.ndarray] = None
216
- sampling_rate: Optional[int] = None
217
- generated_text: str = ""
218
- generated_text_tokens: np.ndarray = field(default_factory=np.ndarray)
219
- usage: Optional[dict] = None
220
-
221
-
222
- class HiggsAudioServeEngine:
223
- def __init__(
224
- self,
225
- model_name_or_path: str,
226
- audio_tokenizer_name_or_path: str,
227
- tokenizer_name_or_path: Optional[str] = None,
228
- device: str = "cuda",
229
- torch_dtype: Union[torch.dtype, str] = "auto",
230
- kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes
231
- ):
232
- """
233
- Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
234
- The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.
235
-
236
- Args:
237
- model_name_or_path (str):
238
- The name or path of the model to load.
239
- audio_tokenizer_name_or_path (str):
240
- The name or path of the audio tokenizer to load.
241
- tokenizer_name_or_path (str):
242
- The name or path of the tokenizer to load.
243
- device (str):
244
- The device to use for the model.
245
- kv_cache_lengths (List[int]):
246
- The lengths of the KV caches to use for the model. Used for cuda graph capture when device is cuda.
247
- torch_dtype (Union[torch.dtype, str]):
248
- The dtype to use for the model.
249
- """
250
- self.device = device
251
- self.model_name_or_path = model_name_or_path
252
- self.torch_dtype = torch_dtype
253
-
254
- # Initialize model and tokenizer
255
- self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
256
- logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")
257
-
258
- if tokenizer_name_or_path is None:
259
- tokenizer_name_or_path = model_name_or_path
260
- logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
261
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
262
-
263
- logger.info(f"Initializing Higgs Audio Tokenizer")
264
- self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)
265
-
266
- self.audio_num_codebooks = self.model.config.audio_num_codebooks
267
- self.audio_codebook_size = self.model.config.audio_codebook_size
268
- self.audio_tokenizer_tps = self.audio_tokenizer.tps
269
- self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
270
- self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
271
- # Set the audio special tokens
272
- self.model.set_audio_special_tokens(self.tokenizer)
273
-
274
- # Prepare KV caches for different lengths
275
- cache_config = deepcopy(self.model.config.text_config)
276
- cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
277
- if self.model.config.audio_dual_ffn_layers:
278
- cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
279
- # A list of KV caches for different lengths
280
- self.kv_caches = {
281
- length: StaticCache(
282
- config=cache_config,
283
- max_batch_size=1,
284
- max_cache_len=length,
285
- device=self.model.device,
286
- dtype=self.model.dtype,
287
- )
288
- for length in sorted(kv_cache_lengths)
289
- }
290
-
291
- if self.model.config.encode_whisper_embed:
292
- logger.info(f"Loading whisper processor")
293
- whisper_processor = AutoProcessor.from_pretrained(
294
- "openai/whisper-large-v3-turbo",
295
- trust_remote=True,
296
- device=self.device,
297
- )
298
- else:
299
- whisper_processor = None
300
-
301
- # Reuse collator to prepare inference samples
302
- self.collator = HiggsAudioSampleCollator(
303
- whisper_processor=whisper_processor,
304
- encode_whisper_embed=self.model.config.encode_whisper_embed,
305
- audio_in_token_id=self.model.config.audio_in_token_idx,
306
- audio_out_token_id=self.model.config.audio_out_token_idx,
307
- audio_stream_bos_id=self.model.config.audio_stream_bos_id,
308
- audio_stream_eos_id=self.model.config.audio_stream_eos_id,
309
- pad_token_id=self.model.config.pad_token_id,
310
- return_audio_in_tokens=False,
311
- use_delay_pattern=self.model.config.use_delay_pattern,
312
- audio_num_codebooks=self.model.config.audio_num_codebooks,
313
- round_to=1,
314
- )
315
-
316
- # Lock to prevent multiple generations from happening at the same time
317
- self.generate_lock = threading.Lock()
318
-
319
- # Capture CUDA graphs for each KV cache length
320
- # if device == "cuda":
321
- # logger.info(f"Capturing CUDA graphs for each KV cache length")
322
- # self.model.capture_model(self.kv_caches.values())
323
-
324
- def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
325
- input_tokens, _, audio_contents, _ = prepare_chatml_sample(
326
- chat_ml_sample,
327
- self.tokenizer,
328
- )
329
-
330
- postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
331
- if force_audio_gen:
332
- postfix += "<|audio_out_bos|>"
333
- postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
334
- input_tokens.extend(postfix)
335
-
336
- # Configure the audio inputs
337
- audio_ids_l = []
338
- for audio_content in audio_contents:
339
- if audio_content.audio_url not in ["placeholder", ""]:
340
- raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
341
- elif audio_content.raw_audio is not None:
342
- raw_audio, _ = librosa.load(
343
- BytesIO(base64.b64decode(audio_content.raw_audio)),
344
- sr=self.audio_tokenizer.sampling_rate,
345
- )
346
- else:
347
- raw_audio = None
348
-
349
- if raw_audio is not None:
350
- audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
351
- audio_ids_l.append(audio_ids.squeeze(0).cpu())
352
-
353
- if len(audio_ids_l) > 0:
354
- audio_ids_start = torch.tensor(
355
- np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
356
- dtype=torch.long,
357
- device=self.device,
358
- )[0:-1]
359
- audio_ids_concat = torch.cat(audio_ids_l, dim=1)
360
- else:
361
- audio_ids_start = None
362
- audio_ids_concat = None
363
-
364
- sample = ChatMLDatasetSample(
365
- input_ids=torch.LongTensor(input_tokens),
366
- label_ids=None,
367
- audio_ids_concat=audio_ids_concat,
368
- audio_ids_start=audio_ids_start,
369
- audio_waveforms_concat=None,
370
- audio_waveforms_start=None,
371
- audio_sample_rate=None,
372
- audio_speaker_indices=None,
373
- )
374
- data = self.collator([sample])
375
- inputs = asdict(data)
376
- for k, v in inputs.items():
377
- if isinstance(v, torch.Tensor):
378
- inputs[k] = v.to(self.model.device)
379
-
380
- return inputs
381
-
382
- def _prepare_kv_caches(self):
383
- for kv_cache in self.kv_caches.values():
384
- kv_cache.reset()
385
-
386
- def generate(
387
- self,
388
- chat_ml_sample: ChatMLSample,
389
- max_new_tokens: int,
390
- temperature: float = 0.7,
391
- top_k: Optional[int] = None,
392
- top_p: float = 0.95,
393
- stop_strings: Optional[List[str]] = None,
394
- force_audio_gen: bool = False,
395
- ras_win_len: Optional[int] = None,
396
- ras_win_max_num_repeat: int = 2,
397
- ):
398
- """
399
- Generate audio from a chatml sample.
400
- Args:
401
- chat_ml_sample: A chatml sample.
402
- max_new_tokens: The maximum number of new tokens to generate.
403
- temperature: The temperature to use for the generation.
404
- top_p: The top p to use for the generation.
405
- Returns:
406
- A dictionary with the following keys:
407
- audio: The generated audio.
408
- sampling_rate: The sampling rate of the generated audio.
409
- """
410
- # Default stop strings
411
- if stop_strings is None:
412
- stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
413
-
414
- with torch.no_grad(), self.generate_lock:
415
- inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
416
- prompt_token_ids = inputs["input_ids"][0].cpu().numpy()
417
-
418
- self._prepare_kv_caches()
419
-
420
- outputs = self.model.generate(
421
- **inputs,
422
- max_new_tokens=max_new_tokens,
423
- use_cache=True,
424
- stop_strings=stop_strings,
425
- tokenizer=self.tokenizer,
426
- do_sample=False if temperature == 0.0 else True,
427
- temperature=temperature,
428
- top_k=top_k,
429
- top_p=top_p,
430
- past_key_values_buckets=self.kv_caches,
431
- ras_win_len=ras_win_len,
432
- ras_win_max_num_repeat=ras_win_max_num_repeat,
433
- )
434
-
435
- if len(outputs[1]) > 0:
436
- wv_list = []
437
- for output_audio in outputs[1]:
438
- vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
439
- wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
440
- wv_list.append(wv_numpy)
441
- wv_numpy = np.concatenate(wv_list)
442
- else:
443
- wv_numpy = None
444
-
445
- # We only support one request at a time now
446
- generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
447
- generated_text = self.tokenizer.decode(generated_text_tokens)
448
- generated_audio_tokens = outputs[1][0].cpu().numpy()
449
- return HiggsAudioResponse(
450
- audio=wv_numpy,
451
- generated_audio_tokens=generated_audio_tokens,
452
- sampling_rate=self.audio_tokenizer.sampling_rate,
453
- generated_text=generated_text,
454
- generated_text_tokens=generated_text_tokens,
455
- usage={
456
- "prompt_tokens": prompt_token_ids.shape[0],
457
- "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
458
- "total_tokens": (
459
- prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1]
460
- ),
461
- "cached_tokens": 0,
462
- },
463
- )
464
-
465
- def text_normalize(self, text: str) -> str:
466
- """
467
- Normalize the text.
468
- """
469
- # Perform some basic normalization
470
- text = normalize_chinese_punctuation(text)
471
- # Handle parentheses
472
- text = text.replace("(", " ")
473
- text = text.replace(")", " ")
474
- return text
 
higgs_audio/serve/utils.py DELETED
@@ -1,254 +0,0 @@
1
- import uuid
2
- import base64
3
- import re
4
- import regex
5
- from typing import AsyncGenerator, Union
6
- import io
7
- from pydub import AudioSegment
8
- import torch
9
- import numpy as np
10
- from functools import lru_cache
11
-
12
- from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer
13
-
14
-
15
- def random_uuid() -> str:
16
- return str(uuid.uuid4().hex)
17
-
18
-
19
- async def async_generator_wrap(first_element, gen: AsyncGenerator):
20
- """Wrap an async generator with the first element."""
21
- yield first_element
22
- async for item in gen:
23
- yield item
24
-
25
-
26
- @lru_cache(maxsize=50)
27
- def encode_base64_content_from_file(file_path: str) -> str:
28
- """Encode a content from a local file to base64 format."""
29
- # Read the MP3 file as binary and encode it directly to Base64
30
- with open(file_path, "rb") as audio_file:
31
- audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
32
- return audio_base64
33
-
34
-
35
- def pcm16_to_target_format(
36
- np_audio: np.ndarray,
37
- sample_rate: int,
38
- bit_depth: int,
39
- channels: int,
40
- format: str,
41
- target_rate: int,
42
- ):
43
- wav_audio = AudioSegment(
44
- np_audio.tobytes(),
45
- frame_rate=sample_rate,
46
- sample_width=bit_depth // 8,
47
- channels=channels,
48
- )
49
- if target_rate is not None and target_rate != sample_rate:
50
- wav_audio = wav_audio.set_frame_rate(target_rate)
51
-
52
- # Convert WAV to MP3
53
- target_io = io.BytesIO()
54
- wav_audio.export(target_io, format=format)
55
- target_io.seek(0)
56
-
57
- return target_io
58
-
59
-
60
- chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
61
-
62
-
63
- def contains_chinese(text: str):
64
- return bool(chinese_char_pattern.search(text))
65
-
66
-
67
- # remove blank between chinese character
68
- def replace_blank(text: str):
69
- out_str = []
70
- for i, c in enumerate(text):
71
- if c == " ":
72
- if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
73
- out_str.append(c)
74
- else:
75
- out_str.append(c)
76
- return "".join(out_str)
77
-
78
-
79
- def replace_corner_mark(text: str):
80
- text = text.replace("²", "平方")
81
- text = text.replace("³", "立方")
82
- return text
83
-
84
-
85
- # remove meaningless symbol
86
- def remove_bracket(text: str):
87
- text = text.replace("(", "").replace(")", "")
88
- text = text.replace("【", "").replace("】", "")
89
- text = text.replace("`", "").replace("`", "")
90
- text = text.replace("——", " ")
91
- return text
92
-
93
-
94
- # split paragrah logic:
95
- # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
96
- # 2. cal sentence len according to lang
97
- # 3. split sentence according to puncatation
98
- def split_paragraph(
99
- text: str,
100
- tokenize,
101
- lang="zh",
102
- token_max_n=80,
103
- token_min_n=60,
104
- merge_len=20,
105
- comma_split=False,
106
- ):
107
- def calc_utt_length(_text: str):
108
- if lang == "zh":
109
- return len(_text)
110
- else:
111
- return len(tokenize(_text))
112
-
113
- def should_merge(_text: str):
114
- if lang == "zh":
115
- return len(_text) < merge_len
116
- else:
117
- return len(tokenize(_text)) < merge_len
118
-
119
- if lang == "zh":
120
- pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
121
- else:
122
- pounc = [".", "?", "!", ";", ":"]
123
- if comma_split:
124
- pounc.extend([",", ","])
125
-
126
- if text[-1] not in pounc:
127
- if lang == "zh":
128
- text += "。"
129
- else:
130
- text += "."
131
-
132
- st = 0
133
- utts = []
134
- for i, c in enumerate(text):
135
- if c in pounc:
136
- if len(text[st:i]) > 0:
137
- utts.append(text[st:i] + c)
138
- if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
139
- tmp = utts.pop(-1)
140
- utts.append(tmp + text[i + 1])
141
- st = i + 2
142
- else:
143
- st = i + 1
144
-
145
- final_utts = []
146
- cur_utt = ""
147
- for utt in utts:
148
- if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
149
- final_utts.append(cur_utt)
150
- cur_utt = ""
151
- cur_utt = cur_utt + utt
152
- if len(cur_utt) > 0:
153
- if should_merge(cur_utt) and len(final_utts) != 0:
154
- final_utts[-1] = final_utts[-1] + cur_utt
155
- else:
156
- final_utts.append(cur_utt)
157
-
158
- return final_utts
159
-
160
-
161
- def is_only_punctuation(text: str):
162
- # Regular expression: Match strings that consist only of punctuation marks or are empty.
163
- punctuation_pattern = r"^[\p{P}\p{S}]*$"
164
- return bool(regex.fullmatch(punctuation_pattern, text))
165
-
166
-
167
- # spell Arabic numerals
168
- def spell_out_number(text: str, inflect_parser):
169
- new_text = []
170
- st = None
171
- for i, c in enumerate(text):
172
- if not c.isdigit():
173
- if st is not None:
174
- num_str = inflect_parser.number_to_words(text[st:i])
175
- new_text.append(num_str)
176
- st = None
177
- new_text.append(c)
178
- else:
179
- if st is None:
180
- st = i
181
- if st is not None and st < len(text):
182
- num_str = inflect_parser.number_to_words(text[st:])
183
- new_text.append(num_str)
184
- return "".join(new_text)
185
-
186
-
187
- def remove_emoji(text: str):
188
- # Pattern to match emojis and their modifiers
189
- # - Standard emoji range
190
- # - Zero-width joiners (U+200D)
191
- # - Variation selectors (U+FE0F, U+FE0E)
192
- # - Skin tone modifiers (U+1F3FB to U+1F3FF)
193
- emoji_pattern = re.compile(
194
- r"["
195
- r"\U00010000-\U0010FFFF" # Standard emoji range
196
- r"\u200D" # Zero-width joiner
197
- r"\uFE0F\uFE0E" # Variation selectors
198
- r"\U0001F3FB-\U0001F3FF" # Skin tone modifiers
199
- r"]+",
200
- flags=re.UNICODE,
201
- )
202
- return emoji_pattern.sub(r"", text)
203
-
204
-
205
- def remove_repeated_punctuations(text, punctuations):
206
- if len(punctuations) == 0:
207
- return text
208
- pattern = f"[{re.escape(''.join(punctuations))}]" # Create regex pattern for given punctuations
209
- return re.sub(rf"({pattern})\1+", r"\1", text)
210
-
211
-
212
- def full_to_half_width(text: str) -> str:
213
- """Convert full-width punctuation to half-width in a given string."""
214
- full_width = "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
215
- half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
216
- trans_table = str.maketrans(full_width, half_width)
217
- return text.translate(trans_table)
218
-
219
-
220
- def split_interleaved_delayed_audios(
221
- audio_data: Union[list[list[int]], torch.Tensor],
222
- audio_tokenizer: HiggsAudioTokenizer,
223
- audio_stream_eos_id: int,
224
- ) -> list[tuple[list[list[int]], torch.Tensor]]:
225
- separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks
226
-
227
- # Convert separator to numpy array if audio_data is numpy array
228
- if isinstance(audio_data, torch.Tensor):
229
- audio_data = audio_data.transpose(1, 0)
230
- separator = torch.tensor(separator)
231
- # Find the indices where the rows equal the separator
232
- split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
233
- start = 0
234
- groups = []
235
- for idx in split_indices:
236
- groups.append(audio_data[start:idx].transpose(1, 0))
237
- start = idx + 1
238
- if start < len(audio_data):
239
- groups.append(audio_data[start:].transpose(1, 0))
240
- else:
241
- groups = []
242
- current = []
243
- for row in audio_data:
244
- current.append(row)
245
-
246
- if row == separator:
247
- groups.append(current)
248
- current = []
249
-
250
- # Don't forget the last group if there's no trailing separator
251
- if current:
252
- groups.append(current)
253
-
254
- return groups
 
higgs_audio_utils.py DELETED
@@ -1,280 +0,0 @@
1
- from typing import Optional
2
-
3
- # Import HiggsAudio components
4
- from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
5
- from higgs_audio.data_types import ChatMLSample, AudioContent, Message
6
-
7
- import base64
8
- from functools import lru_cache
9
- from loguru import logger
10
- import os
11
- import json
12
- import uuid
13
- import time
14
- import numpy as np
15
- import re
16
-
17
- def process_text_output(text_output: str):
18
- # remove all the continuous <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|>
19
- text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
20
- return text_output
21
-
22
-
23
- def check_return_audio(audio_wv: np.ndarray):
24
- # check if the audio returned is all silent
25
- if np.all(audio_wv == 0):
26
- logger.warning("Audio is silent, returning None")
27
-
28
-
29
- def load_voice_presets():
30
- """Load the voice presets from the voice_examples directory."""
31
- try:
32
- with open(
33
- os.path.join(os.path.dirname(__file__), "examples", "audios", "config.json"),
34
- "r",
35
- ) as f:
36
- voice_dict = json.load(f)
37
- voice_presets = {k: v for k, v in voice_dict.items()}
38
- voice_presets["EMPTY"] = "No reference voice"
39
- logger.info(f"Loaded voice presets: {list(voice_presets.keys())}")
40
- return voice_presets
41
- except FileNotFoundError:
42
- logger.warning("Voice examples config file not found. Using empty voice presets.")
43
- return {"EMPTY": "No reference voice"}
44
- except Exception as e:
45
- logger.error(f"Error loading voice presets: {e}")
46
- return {"EMPTY": "No reference voice"}
47
-
48
-
49
- SAMPLE_RATE = 24000
50
- DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
51
- VOICE_PRESETS = load_voice_presets()
52
-
53
-
54
- def initialize_engine(model_path, audio_tokenizer_path) -> bool:
55
- engine = HiggsAudioServeEngine(
56
- model_name_or_path=model_path,
57
- audio_tokenizer_name_or_path=audio_tokenizer_path,
58
- device="cuda",
59
- )
60
- return engine
61
-
62
- def get_voice_preset(voice_preset):
63
- """Get the voice path and text for a given voice preset."""
64
-
65
- preset_dir = os.path.join(os.path.dirname(__file__), "examples", "audios")
66
- voice_path = os.path.join(preset_dir, VOICE_PRESETS[voice_preset]["audio_file"])
67
-
68
- if not os.path.exists(voice_path):
69
- logger.warning(f"Voice preset file not found: {voice_path}")
70
- return None, "Voice preset not found"
71
-
72
- text = VOICE_PRESETS[voice_preset]["transcript"]
73
- return voice_path, text
74
-
75
-
76
- def normalize_chinese_punctuation(text):
77
- """
78
- Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
79
- """
80
- # Mapping of Chinese punctuation to English punctuation
81
- chinese_to_english_punct = {
82
- ",": ", ", # comma
83
- "。": ".", # period
84
- ":": ":", # colon
85
- ";": ";", # semicolon
86
- "?": "?", # question mark
87
- "!": "!", # exclamation mark
88
- "(": "(", # left parenthesis
89
- ")": ")", # right parenthesis
90
- "【": "[", # left square bracket
91
- "】": "]", # right square bracket
92
- "《": "<", # left angle quote
93
- "》": ">", # right angle quote
94
- "“": '"', # left double quotation
95
- "”": '"', # right double quotation
96
- "‘": "'", # left single quotation
97
- "’": "'", # right single quotation
98
- "、": ",", # enumeration comma
99
- "—": "-", # em dash
100
- "…": "...", # ellipsis
101
- "·": ".", # middle dot
102
- "「": '"', # left corner bracket
103
- "」": '"', # right corner bracket
104
- "『": '"', # left double corner bracket
105
- "』": '"', # right double corner bracket
106
- }
107
-
108
- # Replace each Chinese punctuation with its English counterpart
109
- for zh_punct, en_punct in chinese_to_english_punct.items():
110
- text = text.replace(zh_punct, en_punct)
111
-
112
- return text
113
-
114
-
115
- def normalize_text(transcript: str):
116
- transcript = normalize_chinese_punctuation(transcript)
117
- # Other normalizations (e.g., parentheses and other symbols. Will be improved in the future)
118
- transcript = transcript.replace("(", " ")
119
- transcript = transcript.replace(")", " ")
120
- transcript = transcript.replace("°F", " degrees Fahrenheit")
121
- transcript = transcript.replace("°C", " degrees Celsius")
122
-
123
- for tag, replacement in [
124
- ("[laugh]", "<SE>[Laughter]</SE>"),
125
- ("[humming start]", "<SE>[Humming]</SE>"),
126
- ("[humming end]", "<SE_e>[Humming]</SE_e>"),
127
- ("[music start]", "<SE_s>[Music]</SE_s>"),
128
- ("[music end]", "<SE_e>[Music]</SE_e>"),
129
- ("[music]", "<SE>[Music]</SE>"),
130
- ("[sing start]", "<SE_s>[Singing]</SE_s>"),
131
- ("[sing end]", "<SE_e>[Singing]</SE_e>"),
132
- ("[applause]", "<SE>[Applause]</SE>"),
133
- ("[cheering]", "<SE>[Cheering]</SE>"),
134
- ("[cough]", "<SE>[Cough]</SE>"),
135
- ]:
136
- transcript = transcript.replace(tag, replacement)
137
-
138
- lines = transcript.split("\n")
139
- transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
140
- transcript = transcript.strip()
141
-
142
- if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
143
- transcript += "."
144
-
145
- return transcript
146
-
147
- @lru_cache(maxsize=20)
148
- def encode_audio_file(file_path):
149
- """Encode an audio file to base64."""
150
- with open(file_path, "rb") as audio_file:
151
- return base64.b64encode(audio_file.read()).decode("utf-8")
152
-
153
-
154
- def prepare_chatml_sample(
155
- voice_preset: str,
156
- text: str,
157
- reference_audio: Optional[str] = None,
158
- reference_text: Optional[str] = None,
159
- system_prompt: str = "",
160
- ):
161
- """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
162
- messages = []
163
-
164
- # Add system message if provided
165
- if len(system_prompt) > 0:
166
- messages.append(Message(role="system", content=system_prompt))
167
-
168
- # Add reference audio if provided
169
- audio_base64 = None
170
- ref_text = ""
171
-
172
- if reference_audio:
173
- # Custom reference audio
174
- audio_base64 = encode_audio_file(reference_audio)
175
- ref_text = reference_text or ""
176
- elif voice_preset != "EMPTY":
177
- # Voice preset
178
- voice_path, ref_text = get_voice_preset(voice_preset)
179
- if voice_path is None:
180
- logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
181
- else:
182
- audio_base64 = encode_audio_file(voice_path)
183
-
184
- # Only add reference audio if we have it
185
- if audio_base64 is not None:
186
- # Add user message with reference text
187
- messages.append(Message(role="user", content=ref_text))
188
-
189
- # Add assistant message with audio content
190
- audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
191
- messages.append(Message(role="assistant", content=[audio_content]))
192
-
193
- # Add the main user message
194
- text = normalize_text(text)
195
- messages.append(Message(role="user", content=text))
196
-
197
- return ChatMLSample(messages=messages)
198
-
199
-
200
-
201
- def text_to_speech(
202
- engine,
203
- text,
204
- system_prompt="",
205
- voice_preset="EMPTY",
206
- reference_audio=None,
207
- reference_text=None,
208
- max_completion_tokens=1024,
209
- temperature=1.0,
210
- top_p=0.95,
211
- top_k=50,
212
- stop_strings=None,
213
- ras_win_len=7,
214
- ras_win_max_num_repeat=2,
215
- ):
216
- """
217
- Convert text to speech using HiggsAudioServeEngine.
218
-
219
- Args:
220
- text: The text to convert to speech
221
- voice_preset: The voice preset to use (or "EMPTY" for no preset)
222
- reference_audio: Optional path to reference audio file
223
- reference_text: Optional transcript of the reference audio
224
- max_completion_tokens: Maximum number of tokens to generate
225
- temperature: Sampling temperature for generation
226
- top_p: Top-p sampling parameter
227
- top_k: Top-k sampling parameter
228
- system_prompt: System prompt to guide the model
229
- stop_strings: Dataframe containing stop strings
230
- ras_win_len: Window length for repetition avoidance sampling
231
- ras_win_max_num_repeat: Maximum number of repetitions allowed in the window
232
-
233
- Returns:
234
- Tuple of (generated_text, (sample_rate, audio_data)) where audio_data is int16 numpy array
235
- """
236
-
237
- try:
238
- # Prepare ChatML sample
239
- chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
240
-
241
- # Convert stop strings format
242
- if stop_strings is None:
243
- stop_list = DEFAULT_STOP_STRINGS
244
- else:
245
- stop_list = [s for s in stop_strings["stops"] if s.strip()]
246
-
247
- request_id = f"tts-playground-{str(uuid.uuid4())}"
248
-
249
- start_time = time.time()
250
-
251
- # Generate using the engine
252
- response = engine.generate(
253
- chat_ml_sample=chatml_sample,
254
- max_new_tokens=max_completion_tokens,
255
- temperature=temperature,
256
- top_k=top_k if top_k > 0 else None,
257
- top_p=top_p,
258
- stop_strings=stop_list,
259
- ras_win_len=ras_win_len if ras_win_len > 0 else None,
260
- ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat),
261
- )
262
-
263
- generation_time = time.time() - start_time
264
-
265
- # Process the response
266
- text_output = process_text_output(response.generated_text)
267
-
268
- if response.audio is not None:
269
- # Convert to int16 for Gradio
270
- audio_data = (response.audio * 32767).astype(np.int16)
271
- check_return_audio(audio_data)
272
- return text_output, (response.sampling_rate, audio_data)
273
- else:
274
- logger.warning("No audio generated")
275
- return text_output, None
276
-
277
- except Exception as e:
278
- error_msg = f"Error generating speech: {e}"
279
- logger.error(error_msg)
280
- return f"❌ {error_msg}", None
 
requirements.txt CHANGED
@@ -16,16 +16,5 @@ ninja
  gradio_extendedimage @ https://github.com/OutofAi/gradio-extendedimage/releases/download/0.0.2/gradio_extendedimage-0.0.2-py3-none-any.whl
  gradio_extendedaudio @ https://github.com/OutofAi/gradio-extendedaudio/releases/download/0.0.5/gradio_extendedaudio-0.0.5-py3-none-any.whl
 
- dacite
- boto3==1.35.36
- s3fs
- json_repair
- pandas
- pydantic
- vector_quantize_pytorch
- loguru
- pydub
- ruff==0.12.2
- click
-
- descript-audio-codec
+ flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
+ onnxruntime
 
supertonic.py ADDED
@@ -0,0 +1,364 @@
+ import json
+ import os
+ import time
+ from contextlib import contextmanager
+ from typing import Optional
+ from unicodedata import normalize
+
+ import numpy as np
+ import onnxruntime as ort
+ import soundfile as sf
+ from huggingface_hub import snapshot_download
+
+
+ class UnicodeProcessor:
+     def __init__(self, unicode_indexer_path: str):
+         with open(unicode_indexer_path, "r") as f:
+             self.indexer = json.load(f)
+
+     def _preprocess_text(self, text: str) -> str:
+         # TODO: add more preprocessing
+         text = normalize("NFKD", text)
+         return text
+
+     def _get_text_mask(self, text_ids_lengths: np.ndarray) -> np.ndarray:
+         text_mask = length_to_mask(text_ids_lengths)
+         return text_mask
+
+     def _text_to_unicode_values(self, text: str) -> np.ndarray:
+         unicode_values = np.array(
+             [ord(char) for char in text], dtype=np.uint16
+         )  # 2 bytes
+         return unicode_values
+
+     def __call__(self, text_list: list[str]) -> tuple[np.ndarray, np.ndarray]:
+         text_list = [self._preprocess_text(t) for t in text_list]
+         text_ids_lengths = np.array([len(text) for text in text_list], dtype=np.int64)
+         text_ids = np.zeros((len(text_list), text_ids_lengths.max()), dtype=np.int64)
+         for i, text in enumerate(text_list):
+             unicode_vals = self._text_to_unicode_values(text)
+             text_ids[i, : len(unicode_vals)] = np.array(
+                 [self.indexer[val] for val in unicode_vals], dtype=np.int64
+             )
+         text_mask = self._get_text_mask(text_ids_lengths)
+         return text_ids, text_mask
+
+
+ class Style:
+     def __init__(self, style_ttl_onnx: np.ndarray, style_dp_onnx: np.ndarray):
+         self.ttl = style_ttl_onnx
+         self.dp = style_dp_onnx
+
+
+ class TextToSpeech:
+     def __init__(
+         self,
+         cfgs: dict,
+         text_processor: UnicodeProcessor,
+         dp_ort: ort.InferenceSession,
+         text_enc_ort: ort.InferenceSession,
+         vector_est_ort: ort.InferenceSession,
+         vocoder_ort: ort.InferenceSession,
+     ):
+         self.cfgs = cfgs
+         self.text_processor = text_processor
+         self.dp_ort = dp_ort
+         self.text_enc_ort = text_enc_ort
+         self.vector_est_ort = vector_est_ort
+         self.vocoder_ort = vocoder_ort
+         self.sample_rate = cfgs["ae"]["sample_rate"]
+         self.base_chunk_size = cfgs["ae"]["base_chunk_size"]
+         self.chunk_compress_factor = cfgs["ttl"]["chunk_compress_factor"]
+         self.ldim = cfgs["ttl"]["latent_dim"]
+
+     def sample_noisy_latent(
+         self, duration: np.ndarray
+     ) -> tuple[np.ndarray, np.ndarray]:
+         bsz = len(duration)
+         wav_len_max = duration.max() * self.sample_rate
+         wav_lengths = (duration * self.sample_rate).astype(np.int64)
+         chunk_size = self.base_chunk_size * self.chunk_compress_factor
+         latent_len = ((wav_len_max + chunk_size - 1) / chunk_size).astype(np.int32)
+         latent_dim = self.ldim * self.chunk_compress_factor
+         noisy_latent = np.random.randn(bsz, latent_dim, latent_len).astype(np.float32)
+         latent_mask = get_latent_mask(
+             wav_lengths, self.base_chunk_size, self.chunk_compress_factor
+         )
+         noisy_latent = noisy_latent * latent_mask
+         return noisy_latent, latent_mask
+
+     def _infer(
+         self, text_list: list[str], style: Style, total_step: int, speed: float = 1.05
+     ) -> tuple[np.ndarray, np.ndarray]:
+         assert (
+             len(text_list) == style.ttl.shape[0]
+         ), "Number of texts must match number of style vectors"
+         bsz = len(text_list)
+         text_ids, text_mask = self.text_processor(text_list)
+         dur_onnx, *_ = self.dp_ort.run(
+             None, {"text_ids": text_ids, "style_dp": style.dp, "text_mask": text_mask}
+         )
+         dur_onnx = dur_onnx / speed
+         text_emb_onnx, *_ = self.text_enc_ort.run(
+             None,
+             {"text_ids": text_ids, "style_ttl": style.ttl, "text_mask": text_mask},
+         )  # dur_onnx: [bsz]
+         xt, latent_mask = self.sample_noisy_latent(dur_onnx)
+         total_step_np = np.array([total_step] * bsz, dtype=np.float32)
+         for step in range(total_step):
+             current_step = np.array([step] * bsz, dtype=np.float32)
+             xt, *_ = self.vector_est_ort.run(
+                 None,
+                 {
+                     "noisy_latent": xt,
+                     "text_emb": text_emb_onnx,
+                     "style_ttl": style.ttl,
+                     "text_mask": text_mask,
+                     "latent_mask": latent_mask,
+                     "current_step": current_step,
+                     "total_step": total_step_np,
+                 },
+             )
+         wav, *_ = self.vocoder_ort.run(None, {"latent": xt})
+         return wav, dur_onnx
+
+     def __call__(
+         self,
+         text: str,
+         style: Style,
+         total_step: int,
+         speed: float = 1.05,
+         silence_duration: float = 0.3,
+     ) -> tuple[np.ndarray, np.ndarray]:
+         assert (
+             style.ttl.shape[0] == 1
+         ), "Single speaker text to speech only supports single style"
+         text_list = chunk_text(text)
+         wav_cat = None
+         dur_cat = None
+         for text in text_list:
+             wav, dur_onnx = self._infer([text], style, total_step, speed)
+             if wav_cat is None:
+                 wav_cat = wav
+                 dur_cat = dur_onnx
+             else:
+                 silence = np.zeros(
+                     (1, int(silence_duration * self.sample_rate)), dtype=np.float32
+                 )
+                 wav_cat = np.concatenate([wav_cat, silence, wav], axis=1)
+                 dur_cat += dur_onnx + silence_duration
+         return wav_cat, dur_cat
+
+     def batch(
+         self, text_list: list[str], style: Style, total_step: int, speed: float = 1.05
+     ) -> tuple[np.ndarray, np.ndarray]:
+         return self._infer(text_list, style, total_step, speed)
+
+
+ def length_to_mask(lengths: np.ndarray, max_len: Optional[int] = None) -> np.ndarray:
+     """
+     Convert lengths to binary mask.
+
+     Args:
+         lengths: (B,)
+         max_len: int
+
+     Returns:
+         mask: (B, 1, max_len)
+     """
+     max_len = max_len or lengths.max()
+     ids = np.arange(0, max_len)
+     mask = (ids < np.expand_dims(lengths, axis=1)).astype(np.float32)
+     return mask.reshape(-1, 1, max_len)
+
+
+ def get_latent_mask(
+     wav_lengths: np.ndarray, base_chunk_size: int, chunk_compress_factor: int
+ ) -> np.ndarray:
+     latent_size = base_chunk_size * chunk_compress_factor
+     latent_lengths = (wav_lengths + latent_size - 1) // latent_size
+     latent_mask = length_to_mask(latent_lengths)
+     return latent_mask
+
+
+ def load_onnx(
+     onnx_path: str, opts: ort.SessionOptions, providers: list[str]
+ ) -> ort.InferenceSession:
+     return ort.InferenceSession(onnx_path, sess_options=opts, providers=providers)
+
+
+ def load_onnx_all(
+     onnx_dir: str, opts: ort.SessionOptions, providers: list[str]
+ ) -> tuple[
+     ort.InferenceSession,
+     ort.InferenceSession,
+     ort.InferenceSession,
+     ort.InferenceSession,
+ ]:
+     dp_onnx_path = os.path.join(onnx_dir, "duration_predictor.onnx")
+     text_enc_onnx_path = os.path.join(onnx_dir, "text_encoder.onnx")
+     vector_est_onnx_path = os.path.join(onnx_dir, "vector_estimator.onnx")
+     vocoder_onnx_path = os.path.join(onnx_dir, "vocoder.onnx")
+
+     dp_ort = load_onnx(dp_onnx_path, opts, providers)
+     text_enc_ort = load_onnx(text_enc_onnx_path, opts, providers)
+     vector_est_ort = load_onnx(vector_est_onnx_path, opts, providers)
+     vocoder_ort = load_onnx(vocoder_onnx_path, opts, providers)
+     return dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
+
+
+ def load_cfgs(onnx_dir: str) -> dict:
+     cfg_path = os.path.join(onnx_dir, "tts.json")
+     with open(cfg_path, "r") as f:
+         cfgs = json.load(f)
+     return cfgs
+
+
+ def load_text_processor(onnx_dir: str) -> UnicodeProcessor:
+     unicode_indexer_path = os.path.join(onnx_dir, "unicode_indexer.json")
+     text_processor = UnicodeProcessor(unicode_indexer_path)
+     return text_processor
+
+
+ def load_text_to_speech(onnx_dir: str, use_gpu: bool = False) -> TextToSpeech:
+     opts = ort.SessionOptions()
+     if use_gpu:
+         raise NotImplementedError("GPU mode is not fully tested")
+     else:
+         providers = ["CPUExecutionProvider"]
+         print("Using CPU for inference")
+     cfgs = load_cfgs(onnx_dir)
+     dp_ort, text_enc_ort, vector_est_ort, vocoder_ort = load_onnx_all(
+         onnx_dir, opts, providers
+     )
+     text_processor = load_text_processor(onnx_dir)
+     return TextToSpeech(
+         cfgs, text_processor, dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
+     )
+
+
+ def load_voice_style(voice_style_paths: list[str], verbose: bool = False) -> Style:
+     bsz = len(voice_style_paths)
+
+     # Read first file to get dimensions
+     with open(voice_style_paths[0], "r") as f:
+         first_style = json.load(f)
+     ttl_dims = first_style["style_ttl"]["dims"]
+     dp_dims = first_style["style_dp"]["dims"]
+
+     # Pre-allocate arrays with full batch size
+     ttl_style = np.zeros([bsz, ttl_dims[1], ttl_dims[2]], dtype=np.float32)
+     dp_style = np.zeros([bsz, dp_dims[1], dp_dims[2]], dtype=np.float32)
+
+     # Fill in the data
+     for i, voice_style_path in enumerate(voice_style_paths):
+         with open(voice_style_path, "r") as f:
+             voice_style = json.load(f)
+
+         ttl_data = np.array(
+             voice_style["style_ttl"]["data"], dtype=np.float32
+         ).flatten()
+         ttl_style[i] = ttl_data.reshape(ttl_dims[1], ttl_dims[2])
+
+         dp_data = np.array(voice_style["style_dp"]["data"], dtype=np.float32).flatten()
+         dp_style[i] = dp_data.reshape(dp_dims[1], dp_dims[2])
+
+     if verbose:
+         print(f"Loaded {bsz} voice styles")
+     return Style(ttl_style, dp_style)
+
+
+ @contextmanager
+ def timer(name: str):
+     start = time.time()
+     print(f"{name}...")
+     yield
+     print(f" -> {name} completed in {time.time() - start:.2f} sec")
+
+
+ def sanitize_filename(text: str, max_len: int) -> str:
+     """Sanitize filename by replacing non-alphanumeric characters with underscores"""
+     import re
+
+     prefix = text[:max_len]
+     return re.sub(r"[^a-zA-Z0-9]", "_", prefix)
+
+
+ def chunk_text(text: str, max_len: int = 300) -> list[str]:
+     """
+     Split text into chunks by paragraphs and sentences.
+
+     Args:
+         text: Input text to chunk
+         max_len: Maximum length of each chunk (default: 300)
+
+     Returns:
+         List of text chunks
+     """
+     import re
+
+     # Split by paragraph (two or more newlines)
+     paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text.strip()) if p.strip()]
+
+     chunks = []
+
+     for paragraph in paragraphs:
+         paragraph = paragraph.strip()
+         if not paragraph:
+             continue
+
+         # Split by sentence boundaries (period, question mark, exclamation mark followed by space)
+         # But exclude common abbreviations like Mr., Mrs., Dr., etc. and single capital letters like F.
+         pattern = r"(?<!Mr\.)(?<!Mrs\.)(?<!Ms\.)(?<!Dr\.)(?<!Prof\.)(?<!Sr\.)(?<!Jr\.)(?<!Ph\.D\.)(?<!etc\.)(?<!e\.g\.)(?<!i\.e\.)(?<!vs\.)(?<!Inc\.)(?<!Ltd\.)(?<!Co\.)(?<!Corp\.)(?<!St\.)(?<!Ave\.)(?<!Blvd\.)(?<!\b[A-Z]\.)(?<=[.!?])\s+"
+         sentences = re.split(pattern, paragraph)
+
+         current_chunk = ""
+
+         for sentence in sentences:
+             if len(current_chunk) + len(sentence) + 1 <= max_len:
+                 current_chunk += (" " if current_chunk else "") + sentence
+             else:
+                 if current_chunk:
+                     chunks.append(current_chunk.strip())
+                 current_chunk = sentence
+
+         if current_chunk:
+             chunks.append(current_chunk.strip())
+
+     return chunks
+
+ model_dir = snapshot_download("Supertone/supertonic")
+ onnx_dir = f"{model_dir}/onnx"
+ text_to_speech = load_text_to_speech(onnx_dir, False)
+
+ def generate_speech(text_list, save_dir, voice_style="M1", total_step=5, speed=1.05, n_test=1, batch=None):
+
+     saved_files_list = []
+
+     voice_style_paths = [f"{model_dir}/voice_styles/{voice_style}.json"] * len(text_list)
+
+     assert len(voice_style_paths) == len(
+         text_list
+     ), f"Number of voice styles ({len(voice_style_paths)}) must match number of texts ({len(text_list)})"
+     bsz = len(voice_style_paths)
+
+     style = load_voice_style(voice_style_paths, verbose=True)
+
+     for n in range(n_test):
+         print(f"\n[{n+1}/{n_test}] Starting synthesis...")
+         with timer("Generating speech from text"):
+             if batch:
+                 wav, duration = text_to_speech.batch(text_list, style, total_step, speed)
+             else:
+                 wav, duration = text_to_speech(text_list[0], style, total_step, speed)
+         if not os.path.exists(save_dir):
+             os.makedirs(save_dir)
+         for b in range(bsz):
+             fname = f"{sanitize_filename(text_list[b], 20)}_{n+1}.wav"
+             w = wav[b, : int(text_to_speech.sample_rate * duration[b].item())]  # [T_trim]
+             sf.write(os.path.join(save_dir, fname), w, text_to_speech.sample_rate)
+             saved_files_list.append(f"{save_dir}/{fname}")
+             # print(f"Saved: {save_dir}/{fname}")
+     print("\n=== Synthesis completed successfully! ===")
+
+     return saved_files_list
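
For reference, a minimal way a caller such as app.py might exercise the new supertonic.py module is sketched below. This is a sketch under assumptions, not code from this commit: the output directory, sample sentence, and the printed path are placeholders, while "M1" and total_step=5 simply reuse the module's own defaults.

# Hypothetical usage of the CPU-only Supertonic pipeline added above.
# Importing supertonic runs the module-level snapshot_download and ONNX session setup.
from supertonic import generate_speech

wav_paths = generate_speech(
    text_list=["Hello from the CPU-only Supertonic pipeline."],
    save_dir="outputs",   # placeholder output directory
    voice_style="M1",     # the module's default voice style
    total_step=5,         # vector-estimator steps (module default)
    speed=1.05,
)
print(wav_paths)  # e.g. ["outputs/Hello_from_the_CPU_o_1.wav"]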