sleeper371 committed
Commit 37a9836 · 1 Parent(s): 6e4576a
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +179 -0
  2. LICENSE +21 -0
  3. README.md +92 -14
  4. app.py +191 -0
  5. config.py +12 -0
  6. core/__init__.py +0 -0
  7. core/bark/__init__.py +5 -0
  8. core/bark/constants.py +18 -0
  9. core/bark/custom_context.py +79 -0
  10. core/bark/encodec.py +63 -0
  11. core/bark/generate_audio.py +117 -0
  12. core/bark/generate_audio_semantic_dataset.py +122 -0
  13. core/bark/generate_coarse.py +385 -0
  14. core/bark/generate_fine.py +210 -0
  15. core/bark/generate_semantic.py +361 -0
  16. core/bark/voice_clone.py +104 -0
  17. core/data_model/__init__.py +1 -0
  18. core/data_model/bark.py +337 -0
  19. core/memory/__init__.py +5 -0
  20. core/memory/common.py +187 -0
  21. core/memory/model_manager.py +289 -0
  22. core/memory/models.py +169 -0
  23. core/model/__init__.py +1 -0
  24. core/model/bark.py +425 -0
  25. core/model/hubert.py +237 -0
  26. core/trainer/__init__.py +1 -0
  27. core/trainer/custom_hubert_trainer.py +555 -0
  28. core/utils/__init__.py +7 -0
  29. core/utils/audio.py +104 -0
  30. core/utils/huggingface.py +169 -0
  31. core/utils/read_write_files.py +46 -0
  32. core/utils/text.py +13 -0
  33. event_handlers.py +436 -0
  34. generate_audio_semantic_dataset.py +155 -0
  35. prompts/de_speaker_0.npz +0 -0
  36. prompts/de_speaker_1.npz +0 -0
  37. prompts/de_speaker_2.npz +0 -0
  38. prompts/de_speaker_3.npz +0 -0
  39. prompts/de_speaker_4.npz +0 -0
  40. prompts/de_speaker_5.npz +0 -0
  41. prompts/de_speaker_6.npz +0 -0
  42. prompts/de_speaker_7.npz +0 -0
  43. prompts/de_speaker_8.npz +0 -0
  44. prompts/de_speaker_9.npz +0 -0
  45. prompts/en_speaker_0.npz +0 -0
  46. prompts/en_speaker_1.npz +0 -0
  47. prompts/en_speaker_2.npz +0 -0
  48. prompts/en_speaker_3.npz +0 -0
  49. prompts/en_speaker_4.npz +0 -0
  50. prompts/en_speaker_5.npz +0 -0
.gitignore ADDED
@@ -0,0 +1,179 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+ bark_prompts/
176
+ generated_audio/
177
+
178
+ models/
179
+ .DS_Store
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Hao Huynh Nhat
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,92 @@
1
- ---
2
- title: Bark With Batch Inference
3
- emoji: 🌍
4
- colorFrom: indigo
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.35.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: BARK model from SUNO with batch inference
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Generate Audio from text and clone voice with BARK
2
+
3
+ Generate audio from text with natural-sounding voices and clone any voice (cloning is not perfect).
4
+ ![Screenshot Placeholder](./assets/images/screenshot.png)
5
+
6
+ The code has been tested on Python 3.12 and may also work on other versions.
7
+
8
+ Example generated audio can be found in the `/assets/audio` folder.
9
+
10
+ ## Features
11
+
12
+ - **Text-to-Audio Generation:** Generate speech from text using the BARK model (supports 'small' and 'large' variants).
13
+ - **Parameter Control:** Adjust semantic, coarse, and fine temperature settings for generation diversity. Set a generation seed for reproducibility.
14
+ - **Device Selection:** Run inference on available devices (CPU, CUDA, MPS).
15
+ - **Standard Voice Prompts:** Utilize built-in BARK voice prompts (`.npz` files) located in the `bark_prompts` directory.
16
+ - **Custom Voice Prompt Creation (Voice Cloning):**
17
+ - Upload your own audio file (.wav, .mp3).
18
+ - Generate a BARK-compatible semantic prompt (`.npz` file) using a custom-trained HuBERT model.
19
+ - The generated prompt appears in the "Select Voice Prompt" dropdown for immediate use.
20
+ - **Audio Management:** View, play, and delete generated audio files directly within the interface.
21
+ - **Training Scripts:** Includes scripts to generate the necessary dataset (`generate_audio_semantic_dataset.py`) and train the custom HuBERT model (`train_hubert.py`).
22
+
23
+ ## Custom Voice Cloning Model
24
+
25
+ The core of the custom voice prompt generation relies on a fine-tuned HuBERT model.
26
+
27
+ - **Model:** `sleeper371/hubert-for-bark-semantic` on Hugging Face ([Link](https://huggingface.co/sleeper371/hubert-for-bark-semantic))
28
+ - **Architecture:** This model uses a HuBERT base feature extractor followed by a Transformer decoder head (an illustrative sketch follows this list).
29
+ - **Training:** It was trained on over 4,700 sentence-level pairs mapping audio waveforms to the semantic tokens produced by BARK's semantic model, using a cross-entropy loss objective.
30
+ - **Dataset:** The training dataset is available at `sleeper371/bark-wave-semantic` on Hugging Face ([Link](https://huggingface.co/datasets/sleeper371/bark-wave-semantic)).
31
+ - **Comparison:** This approach is inspired by projects like [gitmylo/bark-data-gen](https://github.com/gitmylo/bark-data-gen), but differs in the head architecture (that project uses an LSTM head, whereas this one uses a Transformer decoder head).
32
+
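+ Below is a minimal, illustrative sketch of this architecture: a HuBERT encoder feeding a Transformer decoder head trained with cross-entropy against BARK semantic tokens. It is not the repository's `core/model/hubert.py` implementation; the class name, layer count, and hyperparameters here are assumptions for illustration only.
+
+ ```python
+ # Illustrative sketch only: HuBERT encoder + Transformer decoder head mapping
+ # audio to BARK semantic tokens. Names and hyperparameters are assumptions.
+ import torch
+ import torch.nn as nn
+ from transformers import HubertModel
+
+
+ class HubertForBarkSemanticSketch(nn.Module):
+     def __init__(self, vocab_size: int = 10_003, d_model: int = 768):
+         super().__init__()
+         self.encoder = HubertModel.from_pretrained("facebook/hubert-base-ls960")
+         layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=8, batch_first=True)
+         self.decoder = nn.TransformerDecoder(layer, num_layers=4)
+         self.token_emb = nn.Embedding(vocab_size, d_model)
+         self.lm_head = nn.Linear(d_model, vocab_size)
+
+     def forward(self, waveform: torch.Tensor, semantic_tokens: torch.Tensor) -> torch.Tensor:
+         # waveform: (B, samples) at 16 kHz; semantic_tokens: (B, T) teacher-forced targets
+         memory = self.encoder(waveform).last_hidden_state        # (B, S, d_model)
+         tgt = self.token_emb(semantic_tokens)                     # (B, T, d_model)
+         causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1))
+         hidden = self.decoder(tgt, memory, tgt_mask=causal_mask)
+         return self.lm_head(hidden)                               # logits for a cross-entropy loss
+ ```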
33
+ ## Setup and Installation
34
+
35
+ Follow these steps to set up the environment and run the application.
36
+
37
+ 1. **Clone the Repository:**
38
+
39
+ 2. **Create a Virtual Environment:**
40
+ It's highly recommended to use a virtual environment to manage dependencies.
41
+
42
+ ```bash
43
+ # For Linux/macOS
44
+ python3 -m venv venv
45
+ source venv/bin/activate
46
+
47
+ # For Windows
48
+ python -m venv venv
49
+ .\venv\Scripts\activate
50
+ ```
51
+
52
+ 3. **Install Requirements:**
53
+ Make sure you have a `requirements.txt` file in the repository root containing all necessary packages (e.g., `gradio`, `torch`, `transformers`, `soundfile`, etc.).
54
+ ```bash
55
+ pip install -r requirements.txt
56
+ ```
57
+
58
+ ## Running the Application
59
+
60
+ Once the setup is complete, run the Gradio application:
61
+
62
+ ```bash
63
+ python app.py
64
+ ```
65
+
66
+ This will launch the Gradio interface, typically accessible at http://127.0.0.1:7860 in your web browser. The console output will provide the exact URL.
67
+
68
+ ## Training Your Own Custom HuBERT Model
69
+
70
+ If you want to train your own HuBERT model for voice cloning:
71
+
72
+ 1. Generate Dataset:
73
+
74
+ - Use the generate_audio_semantic_dataset.py script.
75
+
76
+ 2. Train the Model:
77
+
78
+ - Use the train_hubert.py script.
79
+
80
+ - This script takes the generated dataset (audio paths and semantic token paths) to fine-tune a HuBERT model with a Transformer decoder head.
81
+
82
+ - Configure training parameters (batch size, learning rate, epochs, output directory) within the script or via command-line arguments (if implemented).
83
+
84
+ ## License
85
+
86
+ MIT
87
+
88
+ ## Acknowledgements
89
+
90
+ - Suno AI, for training the original BARK models
91
+
92
+ - gitmylo, whose work inspired using HuBERT to predict semantic tokens from audio
app.py ADDED
@@ -0,0 +1,191 @@
1
+ import gradio as gr
2
+ from config import *
3
+ from event_handlers import *
4
+
5
+
6
+ # --- Gradio UI Definition ---
7
+ # theme = gr.themes.Default(primary_hue=gr.themes.colors.blue).set()
8
+ theme = gr.themes.Ocean(primary_hue=gr.themes.colors.blue).set()
9
+
10
+ with gr.Blocks(
11
+ theme=theme,
12
+ title="grAudio",
13
+ css=".gradio-container { max-width: 95% !important; }",
14
+ ) as app:
15
+
16
+ # --- Global State ---
17
+ initial_audio_list = load_existing_audio()
18
+ audio_list_state = gr.State(value=initial_audio_list)
19
+ newly_generated_state = gr.State([])
20
+ # State to store the index of the selected row in the DataFrame
21
+ selected_index_state = gr.State(-1) # -1 means nothing selected
22
+
23
+ # --- UI Layout ---
24
+ gr.Markdown("# Generate Audio from text")
25
+ with gr.Row(equal_height=False):
26
+ # --- Column 1: Configuration (Left) ---
27
+ with gr.Column(scale=2, min_width=350):
28
+ gr.Markdown("### Generation Configuration")
29
+ with gr.Accordion("Batch size & Temperatures", open=True):
30
+ batch_size_number = gr.Number(
31
+ value=1,
32
+ label="Batch Size",
33
+ minimum=0,
34
+ step=1,
35
+ scale=1,
36
+ )
37
+ semantic_temp_slider = gr.Slider(
38
+ 0.1, 1.0, value=0.7, step=0.1, label="Semantic Temp"
39
+ )
40
+ coarse_temp_slider = gr.Slider(
41
+ 0.1, 1.0, value=0.7, step=0.1, label="Coarse Temp"
42
+ )
43
+ fine_temp_slider = gr.Slider(
44
+ 0.1, 1.0, value=0.7, step=0.1, label="Fine Temp"
45
+ )
46
+ with gr.Accordion("Model, Devices", open=True):
47
+ model_type_dropdown = gr.Dropdown(
48
+ choices=["small", "large"], value="small", label="Model Type"
49
+ )
50
+
51
+ available_devices, best_device = get_available_torch_devices()
52
+ device_dropdown = gr.Dropdown(
53
+ choices=available_devices, value=best_device, label="Device"
54
+ )
55
+ with gr.Accordion("Voice Prompt", open=True):
56
+ prompt_dropdown = gr.Dropdown(
57
+ choices=get_available_prompts(),
58
+ label="Select Voice Prompt",
59
+ info="Optional",
60
+ multiselect=False,
61
+ allow_custom_value=False,
62
+ )
63
+ refresh_prompts_btn = gr.Button(
64
+ "Refresh Prompts", variant="secondary", size="sm"
65
+ )
66
+ with gr.Accordion("Create New Voice Prompt", open=False):
67
+ prompt_audio_upload = gr.File(
68
+ value=None,
69
+ file_count="single",
70
+ label="Upload Audio (.wav, .mp3)",
71
+ file_types=["audio"],
72
+ type="filepath",
73
+ )
74
+ create_prompt_btn = gr.Button("Create Prompt", variant="secondary")
75
+
76
+ # --- Column 2: Text Input & Generate Button (Middle) ---
77
+ with gr.Column(scale=4, min_width=600):
78
+ gr.Markdown("### Text Input")
79
+ text_input_block = gr.Textbox(
80
+ lines=30,
81
+ placeholder="If your text includes multiple long sentences, select a voice prompt for consistent speech.\nAvoid very long sentences; split them into multiple sentences of less than 15 seconds each",
82
+ label="Text Prompts",
83
+ )
84
+ generate_btn = gr.Button("Generate", variant="primary")
85
+ # --- Column 3: Generated Audio Display (Right) - SIMPLIFIED ---
86
+ with gr.Column(scale=2, min_width=250):
87
+ gr.Markdown("### Generated Audio")
88
+ # DataFrame to display the list
89
+ audio_dataframe = gr.DataFrame(
90
+ headers=["File", "Prompt", "Duration (s)"],
91
+ datatype=["str", "str", "str"],
92
+ interactive=True, # Allow row selection
93
+ row_count=(10, "dynamic"), # Show ~10 rows, scroll if more
94
+ col_count=(3, "fixed"),
95
+ # value=format_audio_list_for_dataframe(initial_audio_list) # Set initial value via app.load
96
+ )
97
+ # Single audio player for the selected item
98
+ selected_audio_player = gr.Audio(
99
+ label="Selected Audio",
100
+ type="filepath",
101
+ interactive=False, # Only for playback
102
+ )
103
+ # Single delete button
104
+ delete_selected_btn = gr.Button("Delete Selected Audio", variant="stop")
105
+
106
+ # --- Event Handling ---
107
+
108
+ # 1. Refresh Prompts Button
109
+ refresh_prompts_btn.click(
110
+ fn=update_available_prompts, inputs=None, outputs=[prompt_dropdown]
111
+ )
112
+
113
+ # 2. Create Prompt Button
114
+ create_prompt_btn.click(
115
+ fn=create_audio_prompt,
116
+ inputs=[prompt_audio_upload, device_dropdown],
117
+ outputs=[prompt_dropdown],
118
+ )
119
+
120
+ # 3. Generate Button -> Calls backend -> Outputs to temporary state
121
+ generate_btn.click(
122
+ fn=generate_batch_audio,
123
+ inputs=[
124
+ text_input_block,
125
+ semantic_temp_slider,
126
+ coarse_temp_slider,
127
+ fine_temp_slider,
128
+ batch_size_number,
129
+ model_type_dropdown,
130
+ device_dropdown,
131
+ prompt_dropdown,
132
+ ],
133
+ outputs=[newly_generated_state],
134
+ )
135
+
136
+ # 4. Temporary State Change -> Updates the main audio list state
137
+ newly_generated_state.change(
138
+ fn=update_audio_list,
139
+ inputs=[newly_generated_state, audio_list_state],
140
+ outputs=[audio_list_state],
141
+ show_progress="hidden",
142
+ )
143
+
144
+ # 5. Main Audio List State Change -> Updates the DataFrame display
145
+ # Also clears selection when the list updates.
146
+ audio_list_state.change(
147
+ fn=format_audio_list_for_dataframe,
148
+ inputs=[audio_list_state],
149
+ outputs=[audio_dataframe],
150
+ show_progress="hidden",
151
+ ).then( # Chain: after updating dataframe, clear selection player and index
152
+ fn=lambda: (None, -1), # Function returning values to clear outputs
153
+ inputs=None,
154
+ outputs=[selected_audio_player, selected_index_state],
155
+ show_progress="hidden",
156
+ queue=False,
157
+ )
158
+
159
+ # 6. DataFrame Row Selection -> Updates the selected index and audio player
160
+ audio_dataframe.select(
161
+ fn=handle_row_selection,
162
+ inputs=[audio_list_state], # Pass the full list state to find the filepath
163
+ outputs=[
164
+ selected_audio_player,
165
+ selected_index_state,
166
+ ],
167
+ show_progress="hidden",
168
+ )
169
+
170
+ # 7. Delete Selected Button Click -> Calls delete handler
171
+ delete_selected_btn.click(
172
+ fn=handle_delete_selected,
173
+ inputs=[selected_index_state, audio_list_state], # Pass index and list
174
+ outputs=[
175
+ audio_list_state, # Update the main list state
176
+ selected_index_state, # Clear the selected index
177
+ selected_audio_player, # Clear the audio player
178
+ ],
179
+ show_progress="hidden",
180
+ )
181
+
182
+ # 8. Initial Load: Populate the DataFrame
183
+ app.load(
184
+ fn=format_audio_list_for_dataframe,
185
+ inputs=[audio_list_state], # Use the initial state value
186
+ outputs=[audio_dataframe], # Render initial data into the DataFrame
187
+ )
188
+
189
+
190
+ if __name__ == "__main__":
191
+ app.launch(debug=True, share=False)
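The event wiring above follows a state-chaining pattern: a button writes into a `gr.State`, the state's `.change` listener re-renders the table, and a chained `.then` clears the selection. A minimal standalone sketch of that pattern follows; the components and handlers are illustrative placeholders, not objects from this repo.

```python
# Minimal sketch of the gr.State -> .change -> .then chaining pattern used in app.py.
import gradio as gr

with gr.Blocks() as demo:
    items_state = gr.State([])                         # main list, like audio_list_state
    new_item = gr.Textbox(label="New item")
    add_btn = gr.Button("Add")
    table = gr.DataFrame(headers=["Item"], col_count=(1, "fixed"))
    selected = gr.Textbox(label="Selected", interactive=False)

    # 1) button click mutates the state
    add_btn.click(
        fn=lambda item, items: items + [item],
        inputs=[new_item, items_state],
        outputs=[items_state],
    )
    # 2) state change re-renders the table, then 3) clears the selection display
    items_state.change(
        fn=lambda items: [[i] for i in items],
        inputs=[items_state],
        outputs=[table],
    ).then(
        fn=lambda: "",
        inputs=None,
        outputs=[selected],
    )

if __name__ == "__main__":
    demo.launch()
```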
config.py ADDED
@@ -0,0 +1,12 @@
1
+ import os
2
+
3
+ # --- Configuration ---
4
+ PROMPT_DIR = "./prompts"
5
+ GENERATED_AUDIO_DIR = "./generated_audio"
6
+ os.makedirs(PROMPT_DIR, exist_ok=True)
7
+ os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)
8
+
9
+ # Constants for audio generation
10
+ DEFAULT_AUDIO_SAMPLE_RATE = 24000
11
+ DEFAULT_DURATION = 3
12
+ DEFAULT_FREQ = 440
core/__init__.py ADDED
File without changes
core/bark/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from core.bark.generate_audio import *
2
+
3
+ from core.bark.encodec import *
4
+
5
+ from core.bark.voice_clone import *
core/bark/constants.py ADDED
@@ -0,0 +1,18 @@
1
+ # original BARK semantic vocab size
2
+ SEMANTIC_VOCAB_SIZE = 10_000
3
+ # HuBERT model output vocab size
4
+ HUBERT_OUTPUT_VOCAB_SIZE = 10_003
5
+ CODEBOOK_SIZE = 1024
6
+ N_COARSE_CODEBOOKS = 2
7
+ COARSE_RATE_HZ = 75
8
+ COARSE_SEMANTIC_PAD_TOKEN = 12_048
9
+ COARSE_INFER_TOKEN = 12_050
10
+
11
+ # tokens for the text-to-semantic stage (text is encoded with a BERT tokenizer)
12
+ TEXT_ENCODING_OFFSET = 10_048
13
+ SEMANTIC_PAD_TOKEN = 10_000
14
+ TEXT_PAD_TOKEN = 129_595
15
+ SEMANTIC_INFER_TOKEN = 129_599
16
+ SEMANTIC_RATE_HZ = 49.9
17
+
18
+ N_FINE_CODEBOOKS = 8
core/bark/custom_context.py ADDED
@@ -0,0 +1,79 @@
1
+ import contextlib
2
+ import torch
3
+ import funcy
4
+
5
+
6
+ """
7
+ Custom context managers for PyTorch inference operations.
8
+
9
+ This module provides context managers for controlling:
10
+ - CUDA benchmarking settings
11
+ - Inference mode and gradient calculation
12
+ - Automatic mixed precision (AMP) casting
13
+
14
+ The main context manager `inference_mode()` combines all these settings
15
+ for optimal inference performance.
16
+ """
17
+
18
+
19
+ class InferenceContext:
20
+ """
21
+ Context manager for controlling CUDA benchmarking settings.
22
+
23
+ Args:
24
+ benchmark (bool): Whether to enable cudnn benchmarking. Defaults to False
25
+ since input lengths may vary in inference scenarios.
26
+
27
+ This context manager saves and restores the original cudnn.benchmark setting
28
+ when entering/exiting the context.
29
+ """
30
+
31
+ def __init__(self, benchmark=False):
32
+ # we can't expect inputs to be the same length, so disable benchmarking by default
33
+ self._chosen_cudnn_benchmark = benchmark
34
+ self._cudnn_benchmark = None
35
+
36
+ def __enter__(self):
37
+ self._cudnn_benchmark = torch.backends.cudnn.benchmark
38
+ torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark
39
+
40
+ def __exit__(self, exc_type, exc_value, exc_traceback):
41
+ torch.backends.cudnn.benchmark = self._cudnn_benchmark
42
+
43
+
44
+ if (
45
+ torch.cuda.is_available()
46
+ and hasattr(torch.cuda, "amp")
47
+ and hasattr(torch.cuda.amp, "autocast")
48
+ and hasattr(torch.cuda, "is_bf16_supported")
49
+ and torch.cuda.is_bf16_supported()
50
+ ):
51
+ autocast = funcy.partial(
52
+ torch.amp.autocast, dtype=torch.bfloat16, device_type="cuda"
53
+ )
54
+ """Context manager for automatic mixed precision (AMP) using bfloat16 where supported."""
55
+ else:
56
+
57
+ @contextlib.contextmanager
58
+ def autocast():
59
+ """No-op autocast context manager when bfloat16 is not supported."""
60
+ yield
61
+
62
+
63
+ @contextlib.contextmanager
64
+ def inference_mode():
65
+ """
66
+ Combined context manager for optimal inference performance.
67
+
68
+ Combines:
69
+ - CUDA benchmarking control
70
+ - PyTorch inference mode
71
+ - Disabled gradient calculation
72
+ - Automatic mixed precision casting (where supported)
73
+
74
+ Usage:
75
+ with inference_mode():
76
+ # inference operations here
77
+ """
78
+ with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast():
79
+ yield
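A short usage sketch of the combined context manager defined above; the model and input below are placeholders, not objects from this repository.

```python
# Hedged usage sketch: a forward pass wrapped in inference_mode() runs with
# gradients disabled, cudnn.benchmark pinned, and bfloat16 autocast on supported GPUs.
import torch
from core.bark.custom_context import inference_mode

model = torch.nn.Linear(16, 4).eval()   # placeholder model
x = torch.randn(2, 16)

with inference_mode():
    y = model(x)

print(y.shape)  # torch.Size([2, 4]); y carries no gradient information
```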
core/bark/encodec.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from encodec import EncodecModel
5
+ from encodec.utils import convert_audio
6
+ from core.memory import model_manager, ModelEnum, env
7
+ from core.bark.custom_context import inference_mode
8
+
9
+
10
+ def encodec_decode_fine_tokens_to_audio(fine_tokens: torch.Tensor) -> np.ndarray:
11
+ """
12
+ Decode the given fine_tokens using EnCodec's decoder.
13
+ Expects fine_tokens of shape [batch, codebook_size, timestep], concretely [B, 8, 75*duration_in_sec].
14
+
15
+ Returns:
16
+ np.ndarray of shape (B, C, T), with C = 1 for mono audio
17
+ """
18
+ model_info = ModelEnum.ENCODEC24k.value
19
+
20
+ model_wrapper = model_manager.get_model(model_info)
21
+ model: EncodecModel = model_wrapper.model
22
+
23
+ device = next(model.parameters()).device
24
+
25
+ input_tensor = fine_tokens.transpose(0, 1).to(device)
26
+
27
+ emb = model.quantizer.decode(input_tensor)
28
+
29
+ output: torch.Tensor = model.decoder(emb)
30
+ audio_arr = output.detach().cpu().numpy()
31
+
32
+ del input_tensor, emb, output
33
+
34
+ return audio_arr
35
+
36
+
37
+ def encodec_encode_audio(
38
+ audio_sample: torch.Tensor, audio_sample_rate: int
39
+ ) -> torch.Tensor:
40
+ """
41
+ Encode the given audio sample using the encodec model
42
+ audio_sample expected shape: (channels, sample)
43
+
44
+ Returns codes as a tensor shape [n_q, T]
45
+ where n_q typically is 8 and T is the compressed time step dimension (75 per second for 24khz model)
46
+ """
47
+ model_wrapper = model_manager.get_model(ModelEnum.ENCODEC24k.value)
48
+ model: EncodecModel = model_wrapper.model
49
+
50
+ device = next(model.parameters()).device
51
+
52
+ wav = convert_audio(
53
+ audio_sample, audio_sample_rate, model.sample_rate, model.channels
54
+ )
55
+ wav = wav.unsqueeze(0).float().to(device)
56
+
57
+ # Extract discrete codes from EnCodec
58
+ with inference_mode():
59
+ encoded_frames = model.encode(wav)
60
+
61
+ codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T]
62
+
63
+ return codes[0, :, :]
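A hedged round-trip sketch of the two helpers above, mainly to illustrate the documented shapes; it uses random audio and assumes the 24 kHz EnCodec model can be fetched through `model_manager`, as this module does.

```python
# Round-trip shape check for encodec_encode_audio / encodec_decode_fine_tokens_to_audio.
import torch
from core.bark.encodec import encodec_encode_audio, encodec_decode_fine_tokens_to_audio

audio = torch.randn(1, 44_100)  # one second of mono audio, shape (channels, samples)

codes = encodec_encode_audio(audio, audio_sample_rate=44_100)
print(codes.shape)   # (n_q, T): typically (8, 75) for one second at 24 kHz

# decoding expects a batch dimension: (B, n_q, T)
wave = encodec_decode_fine_tokens_to_audio(codes.unsqueeze(0))
print(wave.shape)    # (B, C, T) with C == 1 and roughly 24_000 samples per second
```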
core/bark/generate_audio.py ADDED
@@ -0,0 +1,117 @@
1
+ import sys
2
+ import logging
3
+ from typing_extensions import Union, List
4
+ import numpy as np
5
+ import torch
6
+ from dataclasses import asdict
7
+
8
+ from core.bark.generate_semantic import generate_semantic_tokens_from_text
9
+ from core.bark.generate_coarse import generate_coarse_tokens_from_semantic
10
+ from core.bark.generate_fine import generate_fine_tokens_from_coarse
11
+
12
+
13
+ from core.data_model.bark import BarkPrompt, BarkGenerationConfig
14
+ from core.bark.encodec import encodec_decode_fine_tokens_to_audio
15
+ from core.bark.constants import SEMANTIC_PAD_TOKEN, SEMANTIC_RATE_HZ
16
+
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format="%(asctime)s - %(levelname)s - %(message)s",
20
+ handlers=[logging.StreamHandler(sys.stdout)],
21
+ )
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def generate_audio(
26
+ texts: List[str],
27
+ prompt: Union[BarkPrompt, None] = None,
28
+ generation_config: BarkGenerationConfig = None,
29
+ silent: bool = False,
30
+ ) -> List[np.ndarray]:
31
+ """
32
+ Generate audio from text with an optional audio prompt
33
+ Args:
34
+ texts (List[str]): input texts to generate audio for; each must be non-empty
35
+ prompt (Union[BarkPrompt, None]): optional BarkPrompt whose tokens are used as the voice prompt
36
+ generation_config (BarkGenerationConfig): generation configuration; BARK defaults are used when None
37
+ silent (bool): if True, suppress progress bars
38
+
39
+ """
40
+ if prompt is not None:
41
+ semantic_prompt = prompt.semantic_prompt
42
+ # if len(semantic_prompt.shape) == 2:
43
+ # semantic_prompt = semantic_prompt[0, :]
44
+ assert (
45
+ len(semantic_prompt.shape) == 1
46
+ ), "expecting semantic prompt as a 1D array"
47
+ else:
48
+ semantic_prompt = None
49
+
50
+ if generation_config is None:
51
+ logger.info("using BARK default generation config")
52
+ generation_config = BarkGenerationConfig()
53
+
54
+ semantic_tokens = generate_semantic_tokens_from_text(
55
+ texts,
56
+ semantic_prompt,
57
+ **asdict(generation_config),
58
+ silent=silent,
59
+ )
60
+
61
+ # Because we generate audio in batches, every audio in a batch has the length
62
+ # of the longest one. We need to remove the random section of the shorter audios
63
+ # after they have ended.
64
+
65
+ # coarse token generation
66
+ coarse_tokens = generate_coarse_tokens_from_semantic(
67
+ semantic_tokens, prompt, **asdict(generation_config), silent=silent
68
+ )
69
+
70
+ # fine token generation
71
+ fine_tokens = generate_fine_tokens_from_coarse(
72
+ coarse_tokens=coarse_tokens,
73
+ history_prompt=prompt,
74
+ temperature=generation_config.generate_fine_temperature,
75
+ use_small_model=generation_config.use_small_model,
76
+ silent=silent,
77
+ )
78
+
79
+ # decoding the codes
80
+ audio_wave = encodec_decode_fine_tokens_to_audio(fine_tokens)
81
+ assert (
82
+ len(audio_wave.shape) == 3
83
+ ), f"expecting audio tensor of shape (B, C, T), received {audio_wave.shape}"
84
+
85
+ audio_wave = audio_wave.squeeze(1) # squeeze the channel dimension
86
+ res = remove_padded_segment_from_audio(audio_wave, semantic_tokens.cpu().numpy())
87
+ return res
88
+
89
+
90
+ def remove_padded_segment_from_audio(
91
+ audio_wave: np.ndarray, semantic_tokens: np.ndarray, audio_sample_rate: int = 24000
92
+ ) -> List[np.ndarray]:
93
+ # Because the semantic token tensor is padded to the length of the longest audio in the batch,
94
+ # shorter audios contain random sound after their true end.
95
+ # This function truncates each waveform at the position corresponding to its end-of-sentence
96
+ # semantic token, so the padded segment does not appear in the generated results.
97
+ # SEMANTIC_PAD_TOKEN is also the end-of-sentence token.
98
+ # This function assumes audio_wave has shape (batch, T).
99
+ assert (
100
+ len(audio_wave.shape) == 2
101
+ ), f"expecting ndarray of shape (B, T), received {audio_wave.shape}"
102
+ mask = semantic_tokens == SEMANTIC_PAD_TOKEN
103
+ semantic_eos_indices = np.argmax(mask.astype(np.int32), axis=1) # Shape [batch]
104
+ wave_eos_indices: np.ndarray = semantic_eos_indices * (
105
+ audio_sample_rate / SEMANTIC_RATE_HZ
106
+ )
107
+ wave_eos_indices = wave_eos_indices.astype(np.int32)
108
+ res = []
109
+ for wave_index in range(audio_wave.shape[0]):
110
+ if wave_eos_indices[wave_index] == 0:
111
+ # zero means this audio is the longest one in the batch and there is no need to cut the padded segment
112
+ res.append(audio_wave[wave_index])
113
+ continue
114
+ start_padding_index = wave_eos_indices[wave_index]
115
+ res.append(audio_wave[wave_index, :start_padding_index])
116
+
117
+ return res
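A hedged usage sketch for `generate_audio`; it assumes the BARK weights are downloadable through `model_manager` and that `BarkGenerationConfig()` provides sensible defaults (the data-model file is not shown in this diff).

```python
# Hedged usage sketch: batch text-to-audio with default config and no voice prompt.
import soundfile as sf
from core.bark.generate_audio import generate_audio

texts = [
    "Hello, this is a test of batched BARK generation.",
    "Each line becomes its own audio clip.",
]

waves = generate_audio(texts, prompt=None, generation_config=None)

for i, wave in enumerate(waves):
    # each element is a 1D float array at 24 kHz with the padded tail removed
    sf.write(f"clip_{i}.wav", wave, samplerate=24_000)
```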
core/bark/generate_audio_semantic_dataset.py ADDED
@@ -0,0 +1,122 @@
1
+ from typing import List
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from dataclasses import asdict
5
+ from core.bark.generate_semantic import generate_semantic_tokens_from_text
6
+ from core.bark.generate_coarse import generate_coarse_tokens_from_semantic
7
+ from core.bark.generate_fine import generate_fine_tokens_from_coarse
8
+ from core.bark.encodec import encodec_decode_fine_tokens_to_audio
9
+ from core.bark.generate_audio import remove_padded_segment_from_audio
10
+ from core.data_model import WavSemantic, WavSemanticDataset, BarkGenerationConfig
11
+ from core.bark.constants import SEMANTIC_PAD_TOKEN
12
+
13
+
14
+ def generate_wav_semantic_dataset(
15
+ text_file_path: str,
16
+ generation_config: BarkGenerationConfig,
17
+ batch_size: int = 16,
18
+ silent: bool = False,
19
+ save_path: str = "./dataset",
20
+ save_data_as_raw_audio: bool = True,
21
+ ) -> None:
22
+ """
23
+ Generate a dataset of (wav, semantic_tokens) for training a model to predict semantic tokens from audio
24
+
25
+ Args
26
+ text_file_path: path to the text file that will be used to generate audio data
27
+ generation_config: the config used to generate data
28
+ batch_size: batch size when generate data
29
+ bark_model_type: either `large` or `small`, the coarse and fine model variant that will be used to generate audio
30
+ max_token_per_example: a criteria to limit the length of an example from text. The text will be tokenized using a BERT tokenizer,
31
+ and the tokenized text will be truncated to not exceed this length
32
+ save_path: path to save the generated dataset
33
+ save_data_as_raw_audio: if True, waves will be saved as raw audio, otherwise it will be saved as compressed .npz file
34
+ """
35
+ texts = read_text_file(text_file_path)
36
+ assert len(texts) > 0, "empty text data"
37
+
38
+ mini_batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
39
+ progress_bar = tqdm(
40
+ total=len(mini_batches), disable=silent, desc="Generating wav-semantic dataset"
41
+ )
42
+ for batch in mini_batches:
43
+ semantic_tokens = generate_semantic_tokens_from_text(
44
+ texts=batch, semantic_prompt=None, silent=True, **asdict(generation_config)
45
+ )
46
+
47
+ coarse = generate_coarse_tokens_from_semantic(
48
+ semantic_tokens=semantic_tokens,
49
+ history_prompt=None,
50
+ silent=True,
51
+ **asdict(generation_config)
52
+ )
53
+
54
+ fine = generate_fine_tokens_from_coarse(
55
+ coarse_tokens=coarse,
56
+ history_prompt=None,
57
+ temperature=generation_config.generate_fine_temperature,
58
+ use_small_model=generation_config.use_small_model,
59
+ silent=True,
60
+ )
61
+
62
+ # generate audio waves from the fine tokens
63
+ waves = encodec_decode_fine_tokens_to_audio(fine)
64
+ # remove the channel dimension
65
+ waves = waves.squeeze(1)
66
+
67
+ waves = remove_padded_segment_from_audio(waves, semantic_tokens.cpu().numpy())
68
+
69
+ save_semantic_wave_data(
70
+ batch,
71
+ waves,
72
+ semantic_tokens.detach().cpu().numpy(),
73
+ 24000,
74
+ generation_config,
75
+ save_path,
76
+ save_data_as_raw_audio,
77
+ )
78
+
79
+ progress_bar.update(1)
80
+ del semantic_tokens, coarse, fine, waves
81
+
82
+
83
+ def save_semantic_wave_data(
84
+ texts: List[str],
85
+ waves: List[np.ndarray],
86
+ semantic_tokens: np.ndarray,
87
+ sample_rate: int,
88
+ generation_config: BarkGenerationConfig,
89
+ save_path: str,
90
+ save_raw_audio: bool,
91
+ ) -> None:
92
+ """
93
+ Save the given data as a WaveSemantic dataset
94
+ """
95
+ examples = []
96
+ assert (
97
+ len(texts) == len(waves) == semantic_tokens.shape[0]
98
+ ), "unexpected array length"
99
+
100
+ model_type = "small" if generation_config.use_small_model else "large"
101
+
102
+ # remove the padding tokens at the end of the semantic sequences
103
+ mask = semantic_tokens == SEMANTIC_PAD_TOKEN
104
+ semantic_padding_indices = np.argmax(mask.astype(np.int32), axis=1)
105
+
106
+ for i, (text, padding_index) in enumerate(zip(texts, semantic_padding_indices)):
107
+ if padding_index == 0:
108
+ padding_index = len(semantic_tokens[i])
109
+ example = WavSemantic(text, waves[i], semantic_tokens[i, :padding_index])
110
+ examples.append(example)
111
+
112
+ dataset = WavSemanticDataset(sample_rate, generation_config, model_type, examples)
113
+
114
+ dataset.save(save_path, save_raw_audio)
115
+
116
+
117
+ def read_text_file(path: str) -> List[str]:
118
+ with open(path, "r") as file:
119
+ lines = file.readlines()
120
+ # Remove newline characters
121
+ lines = [line.strip() for line in lines]
122
+ return lines
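A hedged sketch of driving the dataset generator above; `sentences.txt` is a hypothetical input file with one sentence per line, and `BarkGenerationConfig()` defaults are assumed to be valid.

```python
# Hedged usage sketch for generate_wav_semantic_dataset.
from core.bark.generate_audio_semantic_dataset import generate_wav_semantic_dataset
from core.data_model import BarkGenerationConfig

generate_wav_semantic_dataset(
    text_file_path="sentences.txt",        # hypothetical file, one sentence per line
    generation_config=BarkGenerationConfig(),
    batch_size=16,
    silent=False,
    save_path="./dataset",
    save_data_as_raw_audio=True,
)
```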
core/bark/generate_coarse.py ADDED
@@ -0,0 +1,385 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from tqdm import tqdm
5
+ from typing_extensions import Optional, Union, Tuple
6
+
7
+ from core.bark.constants import *
8
+ from core.model.bark import GPT
9
+ from core.data_model.bark import BarkPrompt
10
+ from core.bark.custom_context import inference_mode
11
+
12
+ from core.memory import model_manager, ModelEnum, env
13
+
14
+ # number of coarse tokens per one semantic token for one second
15
+ num_coarse_per_semantic = (COARSE_RATE_HZ / SEMANTIC_RATE_HZ) * N_COARSE_CODEBOOKS
16
+
17
+
18
+ def generate_coarse_tokens_from_semantic(
19
+ semantic_tokens: torch.Tensor,
20
+ history_prompt: Union[BarkPrompt, None] = None,
21
+ generate_coarse_temperature: Union[float, None] = 0.6,
22
+ coarse_top_k: Union[int, None] = None,
23
+ coarse_top_p: Union[float, None] = None,
24
+ silent: bool = False,
25
+ max_coarse_history: int = 630,
26
+ sliding_window_length: int = 60,
27
+ use_kv_caching: bool = True,
28
+ use_small_model: bool = False,
29
+ **kwargs,
30
+ ) -> torch.Tensor:
31
+ # Validate inputs
32
+ _validate_semantic_tokens(semantic_tokens)
33
+ _validate_history_prompt(history_prompt)
34
+
35
+ assert (
36
+ 60 <= max_coarse_history <= 630
37
+ ), "max_coarse_history must be between 60 and 630"
38
+ assert (
39
+ max_coarse_history + sliding_window_length <= 1024 - 256
40
+ ), "Context exceeds model limit"
41
+
42
+ # align the number of semantic history token with the given max_coarse_history
43
+ max_semantic_history = int(max_coarse_history / num_coarse_per_semantic)
44
+
45
+ # align the length of the provided semantic and coarse history
46
+ semantic_history, coarse_history = _process_history_prompt(
47
+ history_prompt, max_semantic_history, num_coarse_per_semantic
48
+ )
49
+
50
+ # Load coarse model
51
+ coarse_model_info = (
52
+ ModelEnum.BARK_COARSE_SMALL.value
53
+ if use_small_model
54
+ else ModelEnum.BARK_COARSE.value
55
+ )
56
+ model_wrapper = model_manager.get_model(coarse_model_info)
57
+ model: GPT = model_wrapper.model
58
+ assert isinstance(model, GPT), "unexpected model type"
59
+
60
+ # total_steps is the number of coarse tokens the model need to predict
61
+ total_steps = int(
62
+ np.floor(semantic_tokens.size(1) * num_coarse_per_semantic / N_COARSE_CODEBOOKS)
63
+ * N_COARSE_CODEBOOKS
64
+ )
65
+ assert (
66
+ total_steps > 0 and total_steps % N_COARSE_CODEBOOKS == 0
67
+ ), "Invalid step count"
68
+
69
+ batch, T = semantic_tokens.size()
70
+ # expand the semantic history at the batch dimension to match with the semantic_tokens tensor's batch size
71
+ # for the concatenation
72
+ semantic_history = semantic_history[None].expand((batch, -1))
73
+ full_semantic = torch.hstack([semantic_history, semantic_tokens]).to(torch.int32)
74
+ base_semantic_index = semantic_history.size(1)
75
+
76
+ # Generate coarse tokens
77
+ with inference_mode():
78
+ generated_coarse = _generate_coarse_with_sliding_window(
79
+ model,
80
+ full_semantic,
81
+ coarse_history,
82
+ total_steps,
83
+ base_semantic_index,
84
+ max_semantic_history,
85
+ num_coarse_per_semantic,
86
+ generate_coarse_temperature,
87
+ coarse_top_k,
88
+ coarse_top_p,
89
+ silent,
90
+ max_coarse_history,
91
+ sliding_window_length,
92
+ use_kv_caching,
93
+ )
94
+
95
+ # remove the history prompt from the generated tokens
96
+ generated_coarse = generated_coarse[:, coarse_history.size(0) :]
97
+ assert generated_coarse.size(1) == total_steps, "Generated length mismatch"
98
+
99
+ # Reshape and adjust coarse codes
100
+ B, L = generated_coarse.shape
101
+ # Broadcasting subtracts from all elements
102
+ coarse_output = (
103
+ generated_coarse.reshape(B, -1, N_COARSE_CODEBOOKS).transpose(1, 2)
104
+ - SEMANTIC_VOCAB_SIZE
105
+ )
106
+
107
+ for codebook_idx in range(1, N_COARSE_CODEBOOKS):
108
+ coarse_output[:, codebook_idx, :] -= codebook_idx * CODEBOOK_SIZE
109
+
110
+ return coarse_output
111
+
112
+
113
+ def _validate_semantic_tokens(semantic_tokens: torch.Tensor) -> None:
114
+ """
115
+ Validate the input semantic tokens tensor.
116
+
117
+ Args:
118
+ semantic_tokens: Tensor of semantic tokens (1D).
119
+
120
+ Raises:
121
+ AssertionError: If the tensor does not meet expected criteria.
122
+ """
123
+ assert isinstance(
124
+ semantic_tokens, torch.Tensor
125
+ ), "Semantic tokens must be a torch.Tensor"
126
+ assert semantic_tokens.dim() == 2, "Semantic tokens must be 2D"
127
+ assert semantic_tokens.size(1) > 0, "Semantic tokens tensor cannot be empty"
128
+ assert semantic_tokens.min() >= 0, "Semantic tokens must be non-negative"
129
+ assert (
130
+ semantic_tokens.max() <= SEMANTIC_VOCAB_SIZE
131
+ ), "Semantic tokens exceed vocab size"
132
+
133
+
134
+ def _validate_history_prompt(history_prompt: Union[BarkPrompt, None]) -> None:
135
+ """
136
+ Validate the history prompt if provided.
137
+
138
+ Args:
139
+ history_prompt: BarkPrompt object or None.
140
+
141
+ Raises:
142
+ AssertionError: If the prompt does not meet expected criteria.
143
+ """
144
+ if history_prompt is None:
145
+ return
146
+
147
+ assert isinstance(
148
+ history_prompt, BarkPrompt
149
+ ), "History prompt must be a BarkPrompt object"
150
+ semantic = history_prompt.semantic_prompt
151
+ coarse = history_prompt.coarse_prompt
152
+
153
+ assert (
154
+ isinstance(semantic, torch.Tensor) and semantic.dim() == 1
155
+ ), "Semantic prompt must be 1D tensor"
156
+ assert (
157
+ semantic.size(0) > 0
158
+ and semantic.min() >= 0
159
+ and semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
160
+ )
161
+ assert (
162
+ isinstance(coarse, torch.Tensor) and coarse.dim() == 2
163
+ ), "Coarse prompt must be 2D tensor"
164
+ assert (
165
+ coarse.shape[0] == N_COARSE_CODEBOOKS
166
+ ), "Coarse prompt must have correct number of codebooks"
167
+ assert coarse.min() >= 0 and coarse.max() <= CODEBOOK_SIZE - 1
168
+
169
+
170
+ def _process_history_prompt(
171
+ history_prompt: Union[BarkPrompt, None],
172
+ max_semantic_history: int,
173
+ coarse_to_semantic_ratio: float,
174
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
175
+ """
176
+ Process the history prompt into semantic and coarse history tensors.
177
+ Trim on the left (keep the right most tokens)
178
+ Args:
179
+ history_prompt: BarkPrompt object or None.
180
+ max_semantic_history: Maximum number of semantic history tokens.
181
+ coarse_to_semantic_ratio: Ratio of coarse to semantic token rates.
182
+
183
+ Returns:
184
+ Tuple[semantic_history, coarse_history]: Processed history tensors.
185
+ """
186
+ if history_prompt is None:
187
+ return torch.tensor(
188
+ [], dtype=torch.int32, device=torch.device(env.DEVICE)
189
+ ), torch.tensor([], dtype=torch.int32, device=torch.device(env.DEVICE))
190
+
191
+ semantic_history = history_prompt.semantic_prompt
192
+ coarse_history = history_prompt.coarse_prompt
193
+
194
+ # Add offset then "ravel("F")" flatten
195
+ coarse_history = _add_codebook_offset(coarse_history, CODEBOOK_SIZE)
196
+ coarse_history_flat = coarse_history.T.flatten() + SEMANTIC_VOCAB_SIZE
197
+
198
+ # Trim histories to fit max length
199
+ n_semantic_hist = min(
200
+ max_semantic_history,
201
+ semantic_history.size(0) - semantic_history.size(0) % 2, # Ensure even length
202
+ int(coarse_history_flat.size(0) // coarse_to_semantic_ratio),
203
+ )
204
+ n_coarse_hist = int(round(n_semantic_hist * coarse_to_semantic_ratio))
205
+
206
+ semantic_history = semantic_history[-n_semantic_hist:].to(torch.int32)
207
+ coarse_history_flat = coarse_history_flat[-n_coarse_hist:].to(torch.int32)
208
+ coarse_history_flat = coarse_history_flat[:-2] # Original time alignment hack
209
+
210
+ return semantic_history, coarse_history_flat
211
+
212
+
213
+ def _add_codebook_offset(x: torch.Tensor, offset: int) -> torch.Tensor:
214
+ """
215
+ x shape (n_codebook, T)
216
+ n_codebook start from 0 to n, from the second codebook row on we add offset * row_num
217
+ """
218
+ for n in range(1, x.shape[0]):
219
+ x[n, :] += offset * n
220
+ return x
221
+
222
+
223
+ def _sample_coarse_token(
224
+ logits: torch.Tensor,
225
+ temperature: Union[float, None],
226
+ top_k: Optional[int],
227
+ top_p: Optional[float],
228
+ logit_start_idx: int,
229
+ ) -> torch.Tensor:
230
+ """
231
+ Sample a coarse token from model logits with filtering.
232
+
233
+ Args:
234
+ logits: Model output logits (shape [batch, seq, vocab]).
235
+ temperature: Sampling temperature for randomness.
236
+ top_k: Number of top logits to consider, if specified.
237
+ top_p: Nucleus sampling threshold, if specified.
238
+ logit_start_idx: Starting index for coarse token logits.
239
+
240
+ Returns:
241
+ torch.Tensor: Sampled token with offset applied (shape [batch, 1]).
242
+ """
243
+ relevant_logits = logits[:, 0, logit_start_idx : logit_start_idx + CODEBOOK_SIZE]
244
+
245
+ if temperature is None:
246
+ probs = F.softmax(relevant_logits, dim=-1)
247
+ next_token = torch.argmax(probs, dim=-1, keepdim=True).to(torch.int32)
248
+ else:
249
+ if top_p is not None: # this branch is untested
250
+ # Optimize with NumPy for top-p filtering,
251
+ original_device = relevant_logits.device
252
+ logits_np = relevant_logits.detach().cpu().numpy().astype(np.float32)
253
+ sorted_indices = np.argsort(logits_np)[::-1]
254
+ sorted_logits = logits_np[sorted_indices]
255
+ cumulative_probs = np.cumsum(
256
+ F.softmax(torch.from_numpy(sorted_logits), dim=-1).numpy()
257
+ )
258
+ indices_to_remove = cumulative_probs > top_p
259
+ indices_to_remove[1:] = indices_to_remove[:-1].copy()
260
+ indices_to_remove[0] = False
261
+ logits_np[sorted_indices[indices_to_remove]] = -np.inf
262
+ relevant_logits = torch.from_numpy(logits_np).to(original_device)
263
+
264
+ if top_k is not None:
265
+ top_values, _ = torch.topk(
266
+ relevant_logits, min(top_k, relevant_logits.size(-1))
267
+ )
268
+ relevant_logits[relevant_logits < top_values[:, [-1]]] = -float("Inf")
269
+
270
+ probs = F.softmax(relevant_logits / temperature, dim=-1)
271
+ next_token = torch.multinomial(probs, num_samples=1).to(torch.int32)
272
+ return next_token + logit_start_idx
273
+
274
+
275
+ def _generate_coarse_with_sliding_window(
276
+ model: GPT,
277
+ full_semantic: torch.Tensor,
278
+ coarse_history: torch.Tensor,
279
+ total_steps: int,
280
+ base_semantic_index: int,
281
+ max_semantic_history: int,
282
+ coarse_per_semantic: float,
283
+ temperature: float,
284
+ top_k: Optional[int],
285
+ top_p: Optional[float],
286
+ silent: bool,
287
+ max_coarse_history: int,
288
+ sliding_window_length: int,
289
+ use_kv_caching: bool,
290
+ ) -> torch.Tensor:
291
+ """
292
+ Generate coarse tokens using a sliding window approach.
293
+
294
+ Args:
295
+ model: GPT model for coarse token generation.
296
+ full_semantic: 2D tensor of Concatenated semantic history and input tokens.
297
+ coarse_history: 1D tensor, Initial coarse history tokens.
298
+ total_steps: Total number of coarse tokens to generate.
299
+ base_semantic_index: Start index of input semantic tokens.
300
+ max_semantic_history: Maximum semantic history length.
301
+ coarse_per_semantic: Coarse-to-semantic token ratio.
302
+ temperature: Sampling temperature.
303
+ top_k: Top-k filtering parameter.
304
+ top_p: Top-p filtering parameter.
305
+ silent: Suppresses progress bar if True.
306
+ max_coarse_history: Maximum coarse history length.
307
+ sliding_window_length: Tokens per window.
308
+ use_kv_caching: Enables KV caching.
309
+
310
+ Returns:
311
+ torch.Tensor: Generated coarse tokens (1D).
312
+ """
313
+ device = next(model.parameters()).device
314
+ semantic_tensor = full_semantic.to(device) # Add batch dimension
315
+ coarse_tensor = (
316
+ coarse_history[None].expand((semantic_tensor.shape[0], -1)).to(device)
317
+ )
318
+
319
+ window_count = int(np.ceil(total_steps / sliding_window_length))
320
+ progress_bar = tqdm(
321
+ total=window_count, disable=silent, desc="Generating coarse tokens"
322
+ )
323
+ step_counter = 0 # equivalent to the number of coarse tokens generated so far
324
+
325
+ for _ in range(window_count):
326
+ current_semantic_idx = base_semantic_index + int(
327
+ round(step_counter / coarse_per_semantic)
328
+ )
329
+
330
+ window_start = max(0, current_semantic_idx - max_semantic_history)
331
+ semantic_window = semantic_tensor[:, window_start : window_start + 256]
332
+ semantic_window = F.pad(
333
+ semantic_window,
334
+ (0, 256 - semantic_window.shape[-1]),
335
+ "constant",
336
+ COARSE_SEMANTIC_PAD_TOKEN,
337
+ )
338
+
339
+ input_tensor = torch.hstack(
340
+ [
341
+ semantic_window,
342
+ torch.tensor([COARSE_INFER_TOKEN], device=device)[None].expand(
343
+ (semantic_window.shape[0], -1)
344
+ ),
345
+ coarse_tensor[:, -max_coarse_history:],
346
+ ]
347
+ )
348
+
349
+ kv_cache = None
350
+ for _ in range(sliding_window_length):
351
+ if step_counter >= total_steps:
352
+ break
353
+
354
+ is_first_codebook = step_counter % N_COARSE_CODEBOOKS == 0
355
+ logit_start_idx = (
356
+ SEMANTIC_VOCAB_SIZE + (1 - int(is_first_codebook)) * CODEBOOK_SIZE
357
+ )
358
+
359
+ model_input = (
360
+ input_tensor[:, [-1]]
361
+ if use_kv_caching and kv_cache is not None
362
+ else input_tensor
363
+ )
364
+ logits, kv_cache = model(
365
+ model_input, use_cache=use_kv_caching, past_kv=kv_cache
366
+ )
367
+ next_token = _sample_coarse_token(
368
+ logits,
369
+ temperature,
370
+ top_k,
371
+ top_p,
372
+ logit_start_idx,
373
+ )
374
+
375
+ coarse_tensor = torch.cat((coarse_tensor, next_token), dim=1)
376
+ input_tensor = torch.cat((input_tensor, next_token), dim=1)
377
+
378
+ step_counter += 1
379
+ del logits, next_token
380
+
381
+ del input_tensor
382
+ progress_bar.update(1)
383
+
384
+ progress_bar.close()
385
+ return coarse_tensor
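A worked example of the rate arithmetic this module relies on, with constants taken from `core/bark/constants.py`; the 499-token sequence is just an illustration.

```python
# How many coarse tokens correspond to one semantic token, and how total_steps is sized.
import numpy as np

COARSE_RATE_HZ = 75
SEMANTIC_RATE_HZ = 49.9
N_COARSE_CODEBOOKS = 2

num_coarse_per_semantic = (COARSE_RATE_HZ / SEMANTIC_RATE_HZ) * N_COARSE_CODEBOOKS
print(round(num_coarse_per_semantic, 3))  # ~3.006 coarse tokens per semantic token

# For 499 semantic tokens (~10 s at 49.9 Hz), the model must generate:
total_steps = int(
    np.floor(499 * num_coarse_per_semantic / N_COARSE_CODEBOOKS) * N_COARSE_CODEBOOKS
)
print(total_steps)  # 1500 coarse tokens = 750 timesteps x 2 codebooks = 10 s at 75 Hz
```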
core/bark/generate_fine.py ADDED
@@ -0,0 +1,210 @@
1
+ from typing import Union
2
+
3
+ import torch
4
+ from tqdm import tqdm
5
+ import torch.nn.functional as F
6
+
7
+ from core.data_model.bark import BarkPrompt
8
+ from core.bark.custom_context import inference_mode
9
+ from core.model import FineGPT
10
+ from core.memory import ModelEnum, model_manager
11
+ from core.bark.constants import *
12
+
13
+
14
+ def generate_fine_tokens_from_coarse(
15
+ coarse_tokens: torch.Tensor,
16
+ history_prompt: Union[BarkPrompt, None] = None,
17
+ temperature: float = 0.5,
18
+ use_small_model: bool = True,
19
+ silent: bool = False,
20
+ ) -> torch.Tensor:
21
+ """
22
+ Generate fine-grained audio codes from coarse audio codes using the BARK fine model.
23
+
24
+ This function takes coarse tokens (representing a partial set of audio codebooks) and
25
+ autoregressively predicts the remaining fine tokens, optionally conditioning on a history
26
+ prompt. The process involves sliding a context window over the sequence, predicting 512
27
+ timesteps at a time based on a 1024-timestep input.
28
+
29
+ Prompt tokens are trimmed on the left (the right-most tokens are kept).
30
+
31
+ Args:
32
+ coarse_tokens (torch.Tensor): Coarse audio codes with shape (batch, n_coarse, sequence_length),
33
+ where n_coarse <= N_FINE_CODEBOOKS - 1 and values are in [0, CODEBOOK_SIZE - 1].
34
+ history_prompt (BarkPrompt, optional): Historical fine tokens for conditioning, or None.
35
+ temperature (float): Sampling temperature for fine token prediction; if None, uses argmax.
36
+ silent (bool): If True, suppresses progress bar output.
37
+
38
+ Returns:
39
+ torch.Tensor: Fine audio codes with shape (batch, N_FINE_CODEBOOKS, sequence_length),
40
+ matching the input sequence_length.
41
+
42
+ Raises:
43
+ AssertionError: If input validation fails for coarse_tokens or history_prompt.
44
+ """
45
+ # Validate inputs
46
+ _validate_coarse_tokens(coarse_tokens=coarse_tokens)
47
+ history_fine_tokens = _validate_and_load_history(history_prompt=history_prompt)
48
+ batch, n_coarse, sequence_length = coarse_tokens.shape
49
+
50
+ # Load the fine model
51
+ model_info = (
52
+ ModelEnum.BARK_FINE_SMALL.value
53
+ if use_small_model
54
+ else ModelEnum.BARK_FINE.value
55
+ )
56
+ model_wrapper = model_manager.get_model(model_info)
57
+ model: FineGPT = model_wrapper.model
58
+ assert isinstance(model, FineGPT), "Expected FineGPT model type"
59
+ device = next(model.parameters()).device
60
+ coarse_tokens = coarse_tokens.to(device)
61
+ # stack coarse tokens with padding for remaining codebooks across the codebook dimension
62
+ # e.g. original coarse_tokens shape (B, 2, T); after concatenation along the codebook dim: (B, 8, T), where N_FINE_CODEBOOKS = 8
63
+ pad_tensor = torch.full(
64
+ (batch, N_FINE_CODEBOOKS - n_coarse, sequence_length),
65
+ CODEBOOK_SIZE,
66
+ dtype=torch.int32,
67
+ device=device,
68
+ )
69
+
70
+ input_tensor = torch.cat((coarse_tokens, pad_tensor), dim=1)
71
+
72
+ # Prepend history if provided. Maximum history time step is 512
73
+ # this is a horizontal prepend on the left of the previous padded input tensor
74
+ # output tensor: (batch, 8, history_timestep + coarse_timestep), history_timestep <= 512
75
+ n_history = 0
76
+ if history_fine_tokens is not None:
77
+ history_fine_tokens = history_fine_tokens.expand((batch, N_FINE_CODEBOOKS, -1))
78
+ history_limit = min(history_fine_tokens.shape[-1], 512)
79
+ history_slice = history_fine_tokens[:, :, -history_limit:].to(
80
+ device, dtype=torch.int32
81
+ )
82
+ input_tensor = torch.cat((history_slice, input_tensor), dim=-1)
83
+ n_history = history_limit # number of time step dimension in the prompt
84
+
85
+ # right Pad if total_length (history_timestep + coarse_timestep) is less than model context (1024)
86
+ total_length = input_tensor.shape[-1]
87
+ padding_needed = max(0, 1024 - total_length)
88
+ if padding_needed > 0:
89
+ padding = torch.full(
90
+ (batch, N_FINE_CODEBOOKS, padding_needed),
91
+ CODEBOOK_SIZE,
92
+ dtype=torch.int32,
93
+ device=device,
94
+ )
95
+ input_tensor = torch.cat((input_tensor, padding), dim=2)
96
+ total_length = input_tensor.shape[-1]
97
+
98
+ # Calculate number of prediction loops
99
+ context_window = 1024 # Model's input context size
100
+ prediction_step = 512 # Number of new timesteps predicted per loop
101
+ remaining_length = max(0, sequence_length - (context_window - n_history))
102
+ extra_loops = (remaining_length + prediction_step - 1) // prediction_step
103
+ n_loops = 1 + extra_loops # Total loops: initial + extra
104
+
105
+ # Process sequence in sliding windows
106
+ input_tensor = input_tensor.transpose(
107
+ -2, -1
108
+ ) # Shape: (batch, total_length, N_FINE_CODEBOOKS)
109
+ with inference_mode():
110
+ for loop_idx in tqdm(
111
+ range(n_loops), disable=silent, desc="Generating fine tokens"
112
+ ):
113
+ # Define window boundaries
114
+ # the last loop, by using window_start = (total_length - context_window),
115
+ # the input will be: input_tensor[:, -1024:, :], the last context_window timestep of the input
116
+ window_start = min(
117
+ loop_idx * prediction_step, total_length - context_window
118
+ )
119
+
120
+ fill_start = min(
121
+ n_history + loop_idx * prediction_step, total_length - prediction_step
122
+ )
123
+ fill_offset = fill_start - window_start
124
+ window_end = window_start + context_window
125
+
126
+ # Extract input window
127
+ # Shape: (batch, 1024, N_FINE_CODEBOOKS)
128
+ input_window = input_tensor[:, window_start:window_end, :]
129
+
130
+ # Predict fine codebooks autoregressively
131
+ for codebook_idx in range(n_coarse, N_FINE_CODEBOOKS):
132
+ # Shape: (batch, 1024, vocab_size)
133
+ logits = model(codebook_idx, input_window)
134
+ if temperature is None:
135
+ preds = torch.argmax(
136
+ logits[:, fill_offset:, :CODEBOOK_SIZE], dim=-1
137
+ )
138
+ else:
139
+ scaled_logits = logits[:, :, :CODEBOOK_SIZE] / temperature
140
+ probs = F.softmax(scaled_logits, dim=-1)
141
+ probs = probs[:, fill_offset:, :]
142
+ # Reshape to [B * N, C] for multinomial
143
+ B, N, C = probs.shape # B=batch, N=512-fill_offset, C=CODEBOOK_SIZE
144
+ probs_2d = probs.reshape(-1, C) # Shape: [B * N, C]
145
+
146
+ # Perform multinomial sampling
147
+ # Shape: [B * N, 1]
148
+ preds = torch.multinomial(probs_2d, num_samples=1)
149
+
150
+ # Reshape back to [B, N] after squeezing
151
+ preds = preds.squeeze(-1).reshape(B, N)
152
+
153
+ input_window[:, fill_offset:, codebook_idx] = preds.to(torch.int32)
154
+
155
+ # Update main tensor with predictions
156
+ fill_length = min(prediction_step, total_length - fill_start)
157
+ input_tensor[:, fill_start : fill_start + fill_length, codebook_idx] = (
158
+ input_window[
159
+ :, fill_offset : fill_offset + fill_length, codebook_idx
160
+ ]
161
+ )
162
+
163
+ # Extract final result, removing history and padding
164
+ # Shape: (B, N_FINE_CODEBOOKS, sequence_length)
165
+ fine_tokens = input_tensor.transpose(-1, -2)[
166
+ :, :, n_history : n_history + sequence_length
167
+ ]
168
+
169
+ # Verify output shape matches input sequence length
170
+ assert fine_tokens.shape[-1] == sequence_length, "Output length mismatch"
171
+
172
+ return fine_tokens
173
+
174
+
175
+ def _validate_coarse_tokens(coarse_tokens: torch.Tensor) -> None:
176
+ """Validate coarse token tensor properties."""
177
+ assert isinstance(
178
+ coarse_tokens, torch.Tensor
179
+ ), "coarse_tokens must be a torch.Tensor"
180
+ assert len(coarse_tokens.shape) == 3, "coarse_tokens must be 3D"
181
+ assert (
182
+ 1 <= coarse_tokens.shape[1] <= N_FINE_CODEBOOKS - 1
183
+ ), "Invalid number of coarse codebooks"
184
+ assert coarse_tokens.shape[-1] > 0, "Sequence length must be positive"
185
+ assert (
186
+ coarse_tokens.min() >= 0 and coarse_tokens.max() <= CODEBOOK_SIZE - 1
187
+ ), "Token values out of range"
188
+
189
+
190
+ def _validate_and_load_history(
191
+ history_prompt: Union[BarkPrompt, None],
192
+ ) -> Union[torch.Tensor, None]:
193
+ """Validate and load history prompt if provided."""
194
+ if history_prompt is None:
195
+ return None
196
+
197
+ history_fine_tokens = history_prompt.fine_prompt
198
+ assert isinstance(
199
+ history_fine_tokens, torch.Tensor
200
+ ), "history_prompt.fine_prompt must be a torch.Tensor"
201
+ assert len(history_fine_tokens.shape) == 2, "History must be 2D"
202
+ assert (
203
+ history_fine_tokens.shape[0] == N_FINE_CODEBOOKS
204
+ ), "History must have all fine codebooks"
205
+ assert history_fine_tokens.shape[1] > 0, "History must not be empty"
206
+ assert (
207
+ history_fine_tokens.min() >= 0
208
+ and history_fine_tokens.max() <= CODEBOOK_SIZE - 1
209
+ ), "History values out of range"
210
+ return history_fine_tokens
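
A quick sanity check of the sliding-window arithmetic used above, as a standalone sketch: the constants mirror this file's 1024-step model context and 512-step prediction window, and the helper name n_fine_loops is illustrative only.

    # Standalone sketch of the loop-count math used by the fine-token generator above.
    CONTEXT_WINDOW = 1024   # model input context size
    PREDICTION_STEP = 512   # new timesteps filled per loop

    def n_fine_loops(sequence_length: int, n_history: int) -> int:
        # timesteps that do not fit into the first window
        remaining = max(0, sequence_length - (CONTEXT_WINDOW - n_history))
        extra_loops = (remaining + PREDICTION_STEP - 1) // PREDICTION_STEP
        return 1 + extra_loops

    # 1536 coarse timesteps with a 512-step history prompt:
    # remaining = 1536 - (1024 - 512) = 1024 -> 2 extra loops -> 3 loops total
    assert n_fine_loops(1536, 512) == 3
    # a short clip that fits in a single window needs only one pass
    assert n_fine_loops(300, 0) == 1
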
core/bark/generate_semantic.py ADDED
@@ -0,0 +1,361 @@
1
+ from typing import List, Tuple, Optional, Union
2
+ import re
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from transformers import BertTokenizer
9
+
10
+ from core.memory import model_manager, ModelEnum, env
11
+ from core.bark.custom_context import inference_mode
12
+ from core.bark.constants import *
13
+ from core.model import GPT
14
+
15
+ SEMANTIC_EOS_TOKEN = 10_000
16
+
17
+
18
+ def generate_semantic_tokens_from_text(
19
+ texts: List[str],
20
+ semantic_prompt: Union[torch.Tensor, None] = None,
21
+ temperature: Union[float, None] = 0.7,
22
+ semantic_top_k: Union[int, None] = None,
23
+ semantic_top_p: Union[float, None] = None,
24
+ min_eos_p: float = 0.2,
25
+ max_gen_duration_second: Union[float, None] = None,
26
+ allow_early_stop: bool = True,
27
+ use_kv_caching: bool = True,
28
+ use_small_model: bool = True,
29
+ silent: Union[bool, None] = False,
30
+ max_token_ids_per_sentence: int = 256,
31
+ **kwargs,
32
+ ) -> torch.Tensor:
33
+ # trim white spaces and replace redundant white space characters
34
+ texts = _preprocess_texts(texts)
35
+ assert all([len(text) > 0 for text in texts]), f"invalid input text {texts}"
36
+
37
+ if semantic_prompt is None:
38
+ semantic_prompt = torch.tensor([])
39
+ else:
40
+ assert isinstance(
41
+ semantic_prompt, torch.Tensor
42
+ ), f"expecting semantic_prompt of type torch.Tensor, received {type(semantic_prompt)}"
43
+ assert semantic_prompt.dim() == 1, "expect 1D tensor as semantic_prompt"
44
+
45
+ # load the GPT-style model that generate semantic token from text
46
+ # and the BERT tokenizer to memory
47
+ text_model_info = (
48
+ ModelEnum.BARK_TEXT_SMALL.value
49
+ if use_small_model
50
+ else ModelEnum.BARK_TEXT.value
51
+ )
52
+
53
+ text_model = model_manager.get_model(text_model_info)
54
+ assert text_model.model is not None, "text model is None"
55
+ assert text_model.preprocessor is not None, "tokenizer for the text model is None"
56
+
57
+ assert isinstance(
58
+ text_model.model, GPT
59
+ ), f"expecting model of type GPT, got {type(text_model.model)}"
60
+
61
+ assert isinstance(
62
+ text_model.preprocessor, BertTokenizer
63
+ ), f"expecting preprocessor of type BertTokenizer, got {type(text_model.preprocessor)}"
64
+
65
+ model: GPT = text_model.model
66
+ tokenizer: BertTokenizer = text_model.preprocessor
67
+ device = next(model.parameters()).device
68
+
69
+ # tokenize the given text using the BERT tokenizer
70
+ token_ids = [tokenizer.encode(text, add_special_tokens=False) for text in texts]
71
+
72
+ # for each token_ids of each sentence, append an encoding offset token
73
+ token_ids = [np.array(sentence) + TEXT_ENCODING_OFFSET for sentence in token_ids]
74
+
75
+ # encoded_text's length must has length 256 as from the original implementation
76
+ # pad to the right if the token_ids of the sentence is shorter, trim on the right if it is longer than 256 tokens
77
+ token_ids = [
78
+ trim_or_pad_array(sentence, TEXT_PAD_TOKEN, max_token_ids_per_sentence)
79
+ for sentence in token_ids
80
+ ]
81
+
82
+ token_ids_tensor = torch.vstack(token_ids).to(dtype=torch.int32, device=device)
83
+
84
+ # defensive check: ensure the stacked token ids tensor is 2D (batch, T),
85
+ # even when only a single sentence is passed in
86
+ if len(token_ids_tensor.shape) == 1:
87
+ token_ids_tensor = token_ids_tensor.unsqueeze(0)
88
+ # semantic prompt also need to be an array of 256 discrete tokens
89
+ semantic_prompt = trim_or_pad_array(semantic_prompt, SEMANTIC_PAD_TOKEN, 256)
90
+
91
+ # need to replicate the semantic_prompt array to match the shape of the token_ids for concatenation
92
+ semantic_prompt = (
93
+ semantic_prompt.unsqueeze(0).expand((token_ids_tensor.shape[0], -1)).to(device)
94
+ )
95
+
96
+ # final input is the concatenation of the token_ids and the semantic tokens array
97
+ input_tensor = torch.cat(
98
+ [
99
+ token_ids_tensor, # shape (batch_size, T)
100
+ semantic_prompt,
101
+ torch.tensor([SEMANTIC_INFER_TOKEN], device=device)
102
+ .unsqueeze(0)
103
+ .expand((token_ids_tensor.shape[0], -1)),
104
+ ],
105
+ dim=1,
106
+ ).to(torch.int64)
107
+
108
+ # 256 token_ids, 256 prompt tokens, 1 semantic_infer token as the last column
109
+ assert (
110
+ input_tensor.shape[1] == 256 + 256 + 1
111
+ ), f"expecting tensor shape [batch, 513], received {input_tensor.shape}"
112
+
113
+ with inference_mode():
114
+ output: torch.Tensor = _generate_semantic(
115
+ model=model,
116
+ x=input_tensor,
117
+ temperature=temperature,
118
+ top_k=semantic_top_k,
119
+ top_p=semantic_top_p,
120
+ min_eos_p=min_eos_p,
121
+ max_gen_duration_s=max_gen_duration_second,
122
+ allow_early_stop=allow_early_stop,
123
+ use_kv_caching=use_kv_caching,
124
+ silent=silent,
125
+ )
126
+
127
+ validate_semantic_token_output(output)
128
+ return output
129
+
130
+
131
+ def _generate_semantic(
132
+ model: GPT,
133
+ x: torch.Tensor,
134
+ temperature: float = 0.7,
135
+ top_k: Optional[int] = None,
136
+ top_p: Optional[float] = None,
137
+ min_eos_p: float = 0.2,
138
+ max_gen_duration_s: Optional[float] = None,
139
+ allow_early_stop: bool = True,
140
+ use_kv_caching: bool = False,
141
+ silent: bool = False,
142
+ ) -> torch.Tensor:
143
+ # Maximum number of tokens to generate
144
+ max_steps = 2048
145
+
146
+ # Initialize progress bar for user feedback (custom due to unpredictable stopping)
147
+ progress_bar = tqdm(
148
+ total=max_steps, disable=silent, desc="Generating semantic tokens"
149
+ )
150
+ last_progress = 0
151
+
152
+ # Key-value cache for attention optimization
153
+ kv_cache = None
154
+
155
+ # Autoregressive generation loop
156
+ for step in range(max_steps):
157
+ # Determine input based on KV caching
158
+ if use_kv_caching and kv_cache is not None:
159
+ # Use only the last token with cached attention states
160
+ x_input = x[:, [-1]] # Shape [1, 1]
161
+ else:
162
+ # Use full sequence (recomputes attention each time)
163
+ x_input = x # Shape [1, seq_len]
164
+
165
+ # Forward pass through the model
166
+ logits, kv_cache = model(
167
+ x_input,
168
+ merge_context=True, # Merges text and semantic history context
169
+ past_kv=kv_cache, # Previous attention states
170
+ use_cache=use_kv_caching, # Enables caching if requested
171
+ )
172
+
173
+ # Sample the next token and check for early stopping
174
+ next_token, should_stop = _sample_next_token(
175
+ logits=logits,
176
+ temperature=temperature,
177
+ top_k=top_k,
178
+ top_p=top_p,
179
+ semantic_eos_token=SEMANTIC_EOS_TOKEN,
180
+ allow_early_stop=allow_early_stop,
181
+ min_eos_p=min_eos_p,
182
+ )
183
+
184
+ # Check stopping conditions
185
+ # only stop if all generations in the batch reached the stopping condition
186
+ if torch.all(should_stop):
187
+ progress_bar.update(step - last_progress + 1)
188
+ break
189
+
190
+ if step == max_steps - 1:
191
+ progress_bar.update()
192
+ break
193
+
194
+ # Append the new token to the sequence
195
+ x = torch.cat((x, next_token), dim=1)
196
+
197
+ # Update duration and progress
198
+ # total_duration_s += duration_per_step
199
+ if step > last_progress:
200
+ progress_bar.update(step - last_progress)
201
+ last_progress = step
202
+
203
+ # Clean up tensors to manage memory
204
+ del logits, next_token
205
+
206
+ # Finalize progress bar
207
+ progress_bar.total = step + 1
208
+ progress_bar.close()
209
+
210
+ # Extract generated tokens (skip initial 513 context tokens)
211
+ output = x[:, 256 + 256 + 1 :].detach()
212
+
213
+ return output
214
+
215
+
216
+ def _sample_next_token(
217
+ logits: torch.Tensor, # what is the shape of logits?
218
+ temperature: float,
219
+ top_k: Optional[int],
220
+ top_p: Optional[float],
221
+ semantic_eos_token: int,
222
+ allow_early_stop: bool,
223
+ min_eos_p: Optional[float],
224
+ ) -> Tuple[torch.Tensor, torch.BoolTensor]:
225
+ """
226
+ Sample the next token from logits with optional top-k, top-p filtering and early stopping.
227
+
228
+ Args:
229
+ logits: Tensor of shape [batch, seq, vocab_size] containing model predictions.
230
+ temperature: Controls randomness of sampling (lower = more deterministic).
231
+ top_k: If set, keeps only the top-k logits.
232
+ top_p: If set, applies nucleus (top-p) filtering.
233
+ semantic_eos_token: Token ID representing end-of-sequence (one past the last semantic vocab index).
234
+ allow_early_stop: Whether to check for EOS token or probability threshold.
235
+ min_eos_p: Minimum probability for EOS to trigger early stop.
237
+
238
+ Returns:
239
+ Tuple[next_token, should_stop]:
240
+ - next_token: Sampled token ids (shape [batch, 1]).
241
+ - should_stop: Whether to stop generation (EOS detected).
242
+ """
243
+ # Extract logits for the last position in the sequence
244
+ relevant_logits = logits[:, -1, :semantic_eos_token]
245
+
246
+ # Append EOS logit if early stopping is allowed
247
+ if allow_early_stop:
248
+ eos_logit = logits[:, -1, [semantic_eos_token]]
249
+ relevant_logits = torch.hstack((relevant_logits, eos_logit))
250
+
251
+ # select the token with the highest probability
252
+ if temperature is None:
253
+ # next_token shape (B, 1)
254
+ probs = F.softmax(relevant_logits, dim=-1)
255
+ next_token = torch.argmax(probs, dim=-1, keepdim=True)
256
+ # with argmax the model sometimes predicts token_id 206 and then keeps repeating it,
257
+ # we will intentionally avoid that token_id here
258
+ if torch.any(next_token == 206):
259
+ next_token = anything_but(probs, 206)
260
+
261
+ # do some maneuvers to introduce diversity in the sampling of the next token
262
+ else:
263
+ # Apply top-p (nucleus) filtering for diversity
264
+ if top_p is not None: # this if branch is untested
265
+ # Convert to NumPy for faster sorting (optimization from original)
266
+ original_device = relevant_logits.device
267
+ logits_np = relevant_logits.detach().cpu().type(torch.float32).numpy()
268
+ sorted_indices = np.argsort(logits_np)[::-1] # Descending order
269
+ sorted_logits = logits_np[sorted_indices]
270
+ cumulative_probs = np.cumsum(
271
+ F.softmax(torch.from_numpy(sorted_logits), dim=-1).numpy()
272
+ )
273
+ indices_to_remove = cumulative_probs > top_p
274
+ # Shift to keep at least one
275
+ indices_to_remove[1:] = indices_to_remove[:-1].copy()
276
+ indices_to_remove[0] = False # Ensure top token stays
277
+ logits_np[sorted_indices[indices_to_remove]] = -np.inf
278
+ relevant_logits = torch.from_numpy(logits_np).to(original_device)
279
+
280
+ # Apply top-k filtering for diversity
281
+ if top_k is not None:
282
+ top_values, _ = torch.topk(
283
+ relevant_logits, min(top_k, relevant_logits.size(-1))
284
+ )
285
+ # compare the whole logit tensor to its k_th largest value, batch wise
286
+ relevant_logits[relevant_logits < top_values[:, [-1]]] = -float("Inf")
287
+
288
+ # Compute probabilities with temperature scaling
289
+ probs = F.softmax(relevant_logits / temperature, dim=-1)
290
+
291
+ # Sample the next token
292
+ next_token = torch.multinomial(probs, num_samples=1).to(torch.int32)
293
+
294
+ # Check for early stopping conditions for each sequence in the batch
296
+ # default to "keep generating" so should_stop is always defined, even with early stop disabled
+ should_stop = torch.zeros(next_token.shape[0], dtype=torch.bool, device=next_token.device)
+ if allow_early_stop:
297
+ # EOS token index is semantic_eos_token (last column of the reduced distribution)
298
+ is_eos_token = (next_token == semantic_eos_token).flatten()
299
+ eos_prob_high = (
+ probs[:, -1] >= min_eos_p
+ if min_eos_p is not None
+ else torch.zeros_like(is_eos_token)
+ )
300
+ should_stop = torch.logical_or(is_eos_token, eos_prob_high)
300
+
301
+ # when batch dimension is 1, next_token is a 1D array, need to make it 2D
302
+ if len(next_token.shape) == 1:
303
+ next_token = next_token.unsqueeze(0)
304
+ return next_token, should_stop
305
+
306
+
307
+ # select the second largest probability token if the argmax is the avoided token
308
+ # otherwise select the argmax token
309
+ def anything_but(probs: torch.Tensor, avoid_id: int) -> torch.Tensor:
310
+ # probs shape (B, C)
311
+ # return tensor shape (B, 1)
312
+ values, indices = torch.topk(probs, 2, dim=-1)
313
+ selected = []
314
+ # loop over the batch dimension
315
+ for b in range(probs.shape[0]):
316
+ if indices[b, 0] == avoid_id:
317
+ selected.append(indices[b, 1])
318
+ continue
319
+ selected.append(indices[b, 0])
320
+ return torch.tensor(selected, dtype=torch.int32, device=probs.device).unsqueeze(1)
321
+
322
+
323
+ def validate_semantic_token_output(output: torch.Tensor) -> None:
324
+ assert torch.all(
325
+ (0 <= output) & (output <= SEMANTIC_VOCAB_SIZE)
326
+ ), "unexpected output tokens"
327
+
328
+
329
+ # preprocess the texts for the generate_text_semantic model
330
+ def _preprocess_texts(texts: List[str]) -> List[str]:
331
+ return [re.sub(r"\s+", " ", text).strip() for text in texts]
332
+
333
+
334
+ def trim_or_pad_array(
335
+ array: Union[np.ndarray, torch.Tensor], pad_token: int, max_length: int = 256
336
+ ) -> torch.Tensor:
337
+ """
338
+ Trim on the left (keep the right most tokens), pad on the right
339
+ """
340
+ # Convert np.ndarray to torch.Tensor if necessary
341
+ if isinstance(array, np.ndarray):
342
+ tensor = torch.from_numpy(array).to(device=torch.device(env.DEVICE))
343
+ else: # Already a torch.Tensor
344
+ tensor = array
345
+
346
+ # Get the current length
347
+ current_length = tensor.shape[0]
348
+
349
+ if current_length > max_length:
350
+ # Trim from the end (last max_length elements)
351
+ return tensor[-max_length:]
352
+
353
+ elif current_length < max_length:
354
+ # pad with pad_token on the right only, up to max_length
355
+ padding = (0, max_length - current_length)
356
+ return torch.nn.functional.pad(
357
+ tensor, padding, mode="constant", value=pad_token
358
+ )
359
+
360
+ # If length equals max_length, just return as is
361
+ return tensor
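
To make the fixed-width input layout above concrete, here is a small usage sketch of trim_or_pad_array, assuming the package is importable; the pad token value and lengths are illustrative. The helper keeps the right-most tokens when trimming and pads on the right otherwise, which is what keeps the 256-token text slot and the 256-token semantic-prompt slot fixed-width before the SEMANTIC_INFER_TOKEN is appended.

    import torch
    from core.bark.generate_semantic import trim_or_pad_array

    PAD = 10_000  # illustrative pad token id for this sketch
    short = torch.arange(10)   # fewer than 256 tokens
    long = torch.arange(300)   # more than 256 tokens

    padded = trim_or_pad_array(short, PAD, 256)
    trimmed = trim_or_pad_array(long, PAD, 256)

    assert padded.shape[0] == 256 and bool(padded[10:].eq(PAD).all())  # right-padded
    assert trimmed.shape[0] == 256 and int(trimmed[0]) == 44           # keeps the last 256 values
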
core/bark/voice_clone.py ADDED
@@ -0,0 +1,104 @@
1
+ import numpy as np
+ import torch
2
+ import torchaudio
3
+ from typing import Optional
4
+
5
+ from core.utils import read_audio_file
6
+ from core.bark import encodec_encode_audio
7
+
8
+ from core.model.hubert import HuBERTForBarkSemantic
9
+ from core.memory import model_manager, ModelEnum
10
+ from core.bark.custom_context import InferenceContext
11
+ from core.data_model import *
12
+
13
+
14
+ HUBERT_SAMPLE_RATE = 16000
15
+
16
+
17
+ def generate_semantic_tokens_from_hubert(
18
+ waves: torch.Tensor,
19
+ audio_sample_rate: int,
20
+ temperature: float,
21
+ eos_p: float,
22
+ max_length: int,
23
+ device: Optional[torch.device],
24
+ inference_dtype: torch.dtype = torch.float32,
25
+ ) -> torch.Tensor:
26
+ """
27
+ Generate semantic tokens from audio using the HuBERT model.
28
+
29
+ Args:
30
+ audio: 2D tensor of raw audio samples (shape: [B, T], where T is the number of samples)
31
+ sample_rate: Sample rate of the input audio (default: 24000, matching EnCodec in BARK)
32
+ hubert_model_name: Name of the HuBERT model from Hugging Face (default: facebook/hubert-large-ls960-ft)
33
+ device: Torch device to run the model on (defaults to CUDA if available, else CPU)
34
+ max_length: Maximum length of semantic tokens to return (optional, for truncation)
35
+
36
+ Returns:
37
+ torch.Tensor: 1D tensor of semantic tokens (e.g., shape [N], where N is the sequence length)
38
+
39
+ Raises:
40
+ RuntimeError: If HuBERT model loading or processing fails
41
+ """
42
+ assert (
43
+ len(waves.shape) == 2
44
+ ), f"expecting a tensor of shape [B, T], got {waves.shape}"
45
+ waves = waves.to(device)
46
+
47
+ # HuBERT expects audio at 16 kHz, resample if necessary
48
+ if audio_sample_rate != HUBERT_SAMPLE_RATE:
49
+ resampler = torchaudio.transforms.Resample(
50
+ orig_freq=audio_sample_rate, new_freq=HUBERT_SAMPLE_RATE
51
+ ).to(device)
52
+ waves = resampler(waves)
53
+
54
+ model = model_manager.get_model(ModelEnum.HuBERTBaseForBarkSemantic.value).model
55
+
56
+ assert isinstance(
57
+ model, HuBERTForBarkSemantic
58
+ ), f"expecting HuBERTForBarkSemantic model type, received {type(model)}"
59
+
60
+ waves = waves.to(dtype=inference_dtype)
61
+ model = model.to(dtype=inference_dtype)
62
+
63
+ with InferenceContext():
64
+ predictions: torch.Tensor = model.generate(
65
+ wav_input=waves, temperature=temperature, eos_p=eos_p, max_length=max_length
66
+ )
67
+
68
+ return predictions
69
+
70
+
71
+ def create_bark_prompt(
72
+ audio_file: AudioFile, temperature: float, eos_p: float, device: torch.device
73
+ ) -> BarkPrompt:
74
+ """
75
+ Turn raw audio into valid BARK prompt. When given a raw audio file, use this function
76
+ to generate a valid BARK prompt
77
+ """
78
+ # Read the audio
79
+ raw_audio = read_audio_file(
80
+ path=audio_file.audio_file_path,
81
+ target_sample_rate=HUBERT_SAMPLE_RATE,
82
+ channels=1,
83
+ max_duration=15,
84
+ )
85
+
86
+ audio_tensor = torch.tensor(raw_audio.astype(np.float32), device=device)
87
+ # Generate semantic tokens from audio using HuBERT
88
+ semantic_tokens: torch.Tensor = generate_semantic_tokens_from_hubert(
89
+ waves=audio_tensor.unsqueeze(0),
90
+ audio_sample_rate=16000,
91
+ temperature=temperature,
92
+ eos_p=eos_p,
93
+ max_length=600,
94
+ device=device,
95
+ )
96
+
97
+ # Generate codebook tokens using EnCodec
98
+ codes = encodec_encode_audio(
99
+ audio_sample=torch.from_numpy(raw_audio[None]),
100
+ audio_sample_rate=HUBERT_SAMPLE_RATE,
101
+ )
102
+
103
+ # Assuming codes has shape [num_codebooks, T], typically 8 codebooks for 24kHz
104
+ return BarkPrompt(semantic_tokens, codes[:2, :], codes[:, :])
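
A minimal end-to-end sketch of how the helpers above are meant to be combined to build a voice-clone prompt, assuming the package and its model dependencies are available; the audio path is hypothetical and the temperature/eos_p values are illustrative.

    import torch
    from core.data_model import AudioFile
    from core.bark.voice_clone import create_bark_prompt

    reference = AudioFile(audio_file_path="./my_voice.wav", max_duration=15)  # hypothetical clip

    prompt = create_bark_prompt(
        audio_file=reference,
        temperature=0.7,  # sampling temperature for the HuBERT semantic decoder
        eos_p=0.2,        # early-stop probability threshold
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )

    # persist the semantic/coarse/fine token triple for later text-to-audio runs
    prompt.save_prompt("./prompts/my_voice.json")
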
core/data_model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from core.data_model.bark import *
core/data_model/bark.py ADDED
@@ -0,0 +1,337 @@
1
+ import os
2
+ import json
3
+ from pathlib import Path
4
+ import torch
5
+
6
+ from dataclasses import dataclass, asdict, fields
7
+ import numpy as np
8
+ from enum import Enum
9
+ from pydantic import BaseModel, Field
10
+ from typing import Optional, Union, List, Literal
11
+ from datetime import datetime
12
+ from core.utils import save_audio_file, read_audio_file
13
+
14
+
15
+ @dataclass
16
+ class BarkGenerationConfig:
17
+ semantic_top_k: Union[int, None] = 1000 # a tenth of the semantic vocab size
18
+ coarse_top_k: Union[int, None] = 100 # a tenth of the coarse codebook size
19
+ semantic_top_p: Union[float, None] = None
20
+ coarse_top_p: Union[float, None] = None
21
+ min_eos_p: float = 0.5
22
+ max_gen_duration_second: Union[float, None] = None
23
+ allow_early_stop: bool = True
24
+ use_kv_caching: bool = True
25
+ max_coarse_history: int = 630
26
+ sliding_window_length: int = 60
27
+ max_token_per_example: int = 256
28
+ # set to None to use argmax sampling
29
+ temperature: float = 0.6
30
+ generate_coarse_temperature: float = 0.6
31
+ # set this to None if you want to use argmax to generate fine token
32
+ generate_fine_temperature: float = 0.6
33
+ use_small_model: bool = True
34
+
35
+ def __init__(self, **kwargs):
36
+ # Get field names from dataclass
37
+ valid_fields = {f.name for f in fields(self)}
38
+ # Set only known fields
39
+ for key, value in kwargs.items():
40
+ if key in valid_fields:
41
+ setattr(self, key, value)
42
+
43
+ def to_dict(self) -> dict:
44
+ return asdict(self)
45
+
46
+ @classmethod
47
+ def from_dict(cls, data: dict) -> "BarkGenerationConfig":
48
+ return cls(**data)
49
+
50
+
51
+ @dataclass
52
+ class BarkPrompt:
53
+ """
54
+ semantic_prompt shape: (T)
55
+ coarse_prompt shape: (2, T)
56
+ fine_prompt shape: (8, T)
57
+ those T are different depends on the rate of token type per second
58
+ """
59
+
60
+ semantic_prompt: torch.Tensor
61
+ coarse_prompt: torch.Tensor
62
+ fine_prompt: torch.Tensor
63
+
64
+ def save_prompt(self, file_path: str) -> bool:
65
+ """
66
+ Save all 3 prompts to disk as JSON. Return True if success, False if error
67
+ """
68
+ # Ensure the directory exists
69
+ directory = os.path.dirname(file_path)
70
+ if directory: # If there's a directory component
71
+ os.makedirs(directory, exist_ok=True)
72
+
73
+ data = {
74
+ "semantic_prompt": self.semantic_prompt.detach().cpu().tolist(),
75
+ "coarse_prompt": self.coarse_prompt.detach().cpu().tolist(),
76
+ "fine_prompt": self.fine_prompt.detach().cpu().tolist(),
77
+ }
78
+
79
+ if not file_path.endswith(".json"):
80
+ file_path += ".json"
81
+
82
+ try:
83
+ with open(file_path, "w", encoding="utf-8") as f:
84
+ json.dump(data, f)
85
+ return True
86
+ except Exception:
87
+ return False
88
+
89
+ @classmethod
90
+ def load_prompt(cls, file_path: str, device: torch.device) -> "BarkPrompt":
91
+ """
92
+ Load a prompt from disk. File to load can be either a .json or .npz file
93
+ """
94
+ try:
95
+ if file_path.endswith(".json"):
96
+ with open(file_path, "r", encoding="utf-8") as f:
97
+ prompt = json.load(f)
98
+
99
+ assert (
100
+ "semantic_prompt" in prompt
101
+ and "coarse_prompt" in prompt
102
+ and "fine_prompt" in prompt
103
+ ), f"invalid prompt data {prompt}"
104
+
105
+ semantic_prompt = torch.tensor(prompt["semantic_prompt"])
106
+ coarse_prompt = torch.tensor(prompt["coarse_prompt"])
107
+ fine_prompt = torch.tensor(prompt["fine_prompt"])
108
+
109
+ elif file_path.endswith(".npz"):
110
+ with np.load(file_path) as data:
111
+ assert (
112
+ "semantic_prompt" in data
113
+ and "coarse_prompt" in data
114
+ and "fine_prompt" in data
115
+ ), f"invalid prompt data in NPZ file"
116
+
117
+ semantic_prompt = torch.from_numpy(data["semantic_prompt"])
118
+ coarse_prompt = torch.from_numpy(data["coarse_prompt"])
119
+ fine_prompt = torch.from_numpy(data["fine_prompt"])
120
+
121
+ else:
122
+ raise ValueError("Unsupported file format. Use .json or .npz")
123
+
124
+ # Convert to device and dtype after loading
125
+ semantic_prompt = semantic_prompt.to(device=device, dtype=torch.int32)
126
+ coarse_prompt = coarse_prompt.to(device=device, dtype=torch.int32)
127
+ fine_prompt = fine_prompt.to(device=device, dtype=torch.int32)
128
+
129
+ # Shape checks remain the same
130
+ if len(semantic_prompt.shape) == 2:
131
+ semantic_prompt = semantic_prompt[0, :]
132
+ assert (
133
+ len(semantic_prompt.shape) == 1
134
+ ), "expecting semantic_prompt as a 1D array"
135
+
136
+ assert (
137
+ coarse_prompt.shape[0] == 2
138
+ ), "expecting coarse_prompt has 2 code book dimension"
139
+
140
+ assert (
141
+ fine_prompt.shape[0] == 8
142
+ ), "expecting fine_prompt has 8 code book dimension"
143
+
144
+ return cls(semantic_prompt, coarse_prompt, fine_prompt)
145
+
146
+ except Exception as e:
147
+ raise ValueError(f"Failed to load file: {str(e)}")
148
+
149
+
150
+ class AudioFile(BaseModel):
151
+ """Model for validating raw audio prompt inputs."""
152
+
153
+ audio_file_path: str = Field(..., description="Path to the audio file")
154
+ max_duration: int = Field(
155
+ ..., ge=1, description="Maximum duration of the audio in seconds"
156
+ )
157
+
158
+ def get_default_prompt_name(self) -> str:
159
+ audio_file_name = Path(self.audio_file_path).name
160
+ return f"{audio_file_name}_{datetime.now().strftime('%Y_%m_%d_%H_%M')}"
161
+
162
+
163
+ class TextToAudioInput(BaseModel):
164
+ """Model for validating inputs to the text-to-audio generation function."""
165
+
166
+ texts: List[str] = Field(
167
+ ..., min_items=1, description="List of text strings to convert to audio"
168
+ )
169
+ audio_prompt: Optional[Union[AudioFile, str]] = Field(
170
+ None, description="Optional audio prompt (raw or file path)"
171
+ )
172
+ sample_rate: int = Field(
173
+ default=24000, ge=1, description="Sample rate for generated audio"
174
+ )
175
+ device: Optional[str] = Field(
176
+ None, description="Device to use for generation (e.g., 'cuda', 'cpu')"
177
+ )
178
+ save_path: str = Field(
179
+ default="./artifact", description="Directory to save generated audio files"
180
+ )
181
+
182
+
183
+ class TextToAudioModel(Enum):
184
+ BARK = "BARK"
185
+
186
+
187
+ @dataclass
188
+ class WavSemantic:
189
+ """
190
+ An example of a pair (wav, semantic) for training a model to predict semantic from audio
191
+ """
192
+
193
+ text: str
194
+ wav: np.ndarray
195
+ semantic: np.ndarray
196
+
197
+
198
+ @dataclass
199
+ class WavSemanticDataset:
200
+ sample_rate: int
201
+ semantic_generation_config: BarkGenerationConfig
202
+ bark_model_type: Literal["small", "large"]
203
+ data: List[WavSemantic]
204
+
205
+ def save(self, save_path: str, save_raw_audio: bool) -> None:
206
+ """
207
+ Save this WavSemanticDataset instance to disk at the specified path with compression.
208
+
209
+ Args:
210
+ save_path: Directory path where the dataset will be saved.
+ save_raw_audio: If True, store each sample as audio.wav plus semantic.json; otherwise store a compressed data.npz.
211
+ """
212
+ # Ensure the save directory exists
213
+ save_dir = Path(save_path)
214
+ save_dir.mkdir(parents=True, exist_ok=True)
215
+
216
+ # this allows continuous saving of data, e.g save every new batch of data generated
217
+ if not os.path.exists(save_dir / "metadata.json"):
218
+ # Prepare metadata dictionary using instance attributes
219
+ metadata = {
220
+ "sample_rate": self.sample_rate,
221
+ "semantic_generation_config": self.semantic_generation_config.to_dict(),
222
+ "bark_model_type": self.bark_model_type,
223
+ }
224
+
225
+ # Save metadata as JSON
226
+ with open(save_dir / "metadata.json", "w") as f:
227
+ json.dump(metadata, f, indent=2)
228
+
229
+ next_index = self._get_latest_saved_file_index(save_path) + 1
230
+ # Save each WavSemantic sample
231
+ for i, sample in enumerate(self.data):
232
+ sample_dir = save_dir / f"sample_{i+next_index}"
233
+ sample_dir.mkdir(exist_ok=True)
234
+
235
+ # Save text
236
+ with open(sample_dir / "text.txt", "w") as f:
237
+ f.write(sample.text)
238
+
239
+ # Save wav and semantic in a single compressed .npz file
240
+ if save_raw_audio:
241
+ save_audio_file(
242
+ sample.wav, self.sample_rate, str(sample_dir / "audio.wav")
243
+ )
244
+ with open(sample_dir / "semantic.json", "w") as f:
245
+ json.dump(sample.semantic.tolist(), f)
246
+ else:
247
+ np.savez_compressed(
248
+ sample_dir / "data.npz", wav=sample.wav, semantic=sample.semantic
249
+ )
250
+
251
+ @staticmethod
252
+ def _get_latest_saved_file_index(dataset_path: str) -> int:
253
+ file_names = os.listdir(dataset_path)
254
+ file_names.remove("metadata.json")
255
+ if len(file_names) == 0:
256
+ return -1
257
+
258
+ indices = [
259
+ int(file_name.split("_")[-1].split(".")[0]) for file_name in file_names
260
+ ]
261
+
262
+ return max(indices)
263
+
264
+ @classmethod
265
+ def load(cls, load_path: str, num_samples: int = 5000) -> "WavSemanticDataset":
266
+ """
267
+ Load a WavSemanticDataset from disk at the specified path.
268
+
269
+ Args:
270
+ load_path: Directory path where the dataset is saved.
271
+ num_samples: maximum number of samples to load from the folder
272
+ Returns:
273
+ A new WavSemanticDataset instance loaded from disk.
274
+ """
275
+ load_dir = Path(load_path)
276
+ if not load_dir.exists():
277
+ raise FileNotFoundError(f"Directory {load_path} does not exist")
278
+
279
+ filenames = os.listdir(load_dir)
280
+ if len(filenames) == 1:
281
+ # when there is a folder inside the load_path folder, step into it
282
+ load_dir = load_dir / filenames[0]
283
+ filenames = os.listdir(load_dir)
284
+
285
+ # Load metadata
286
+ with open(load_dir / "metadata.json", "r") as f:
287
+ metadata = json.load(f)
288
+
289
+ # Reconstruct semantic_generation_config
290
+ config = BarkGenerationConfig.from_dict(metadata["semantic_generation_config"])
291
+
292
+ # Load each WavSemantic sample
293
+ data = []
294
+ for i, filename in enumerate(filenames):
295
+ if not "sample" in filename:
296
+ continue
297
+ sample_dir = load_dir / filename
298
+
299
+ # Load text
300
+ with open(sample_dir / "text.txt", "r") as f:
301
+ text = f.read()
302
+
303
+ # Load compressed wav and semantic from .npz file
304
+ if os.path.isfile(sample_dir / "data.npz"):
305
+ with np.load(sample_dir / "data.npz") as npz_data:
306
+ wav = npz_data["wav"]
307
+ semantic = npz_data["semantic"]
308
+ # assuming audio wave file was stored separately from the semantic file
309
+ else:
310
+ # assuming "audio.wav" and "semantic.npz" exist in the folder
311
+ wav = read_audio_file(
312
+ sample_dir / "audio.wav", metadata["sample_rate"], 1, False, None
313
+ )
314
+ if os.path.isfile(sample_dir / "semantic.npz"):
315
+ with np.load(sample_dir / "semantic.npz") as npz_data:
316
+ semantic = npz_data["semantic"]
317
+ elif os.path.isfile(sample_dir / "semantic.json"):
318
+ with open(sample_dir / "semantic.json") as f:
319
+ semantic = np.array(json.load(f))
320
+
321
+ data.append(WavSemantic(text=text, wav=wav, semantic=semantic))
322
+ if i > num_samples:
323
+ break
324
+
325
+ # Reconstruct and return the dataset
326
+ return cls(
327
+ sample_rate=metadata["sample_rate"],
328
+ semantic_generation_config=config,
329
+ bark_model_type=metadata["bark_model_type"],
330
+ data=data,
331
+ )
332
+
333
+ def __getitem__(self, idx: int) -> WavSemantic:
334
+ return self.data[idx]
335
+
336
+ def __len__(self) -> int:
337
+ return len(self.data)
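
A short round-trip sketch of the BarkPrompt shape contract documented above, assuming the package is importable; the tensor contents are random placeholders and only the shapes matter.

    import torch
    from core.data_model import BarkPrompt

    prompt = BarkPrompt(
        semantic_prompt=torch.randint(0, 10_000, (120,)),  # (T,)
        coarse_prompt=torch.randint(0, 1024, (2, 300)),    # (2, T)
        fine_prompt=torch.randint(0, 1024, (8, 300)),      # (8, T)
    )

    assert prompt.save_prompt("./artifact/demo_prompt.json")

    reloaded = BarkPrompt.load_prompt("./artifact/demo_prompt.json", torch.device("cpu"))
    assert reloaded.semantic_prompt.dim() == 1
    assert reloaded.coarse_prompt.shape[0] == 2 and reloaded.fine_prompt.shape[0] == 8
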
core/memory/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from core.memory.model_manager import *
2
+
3
+ from core.memory.models import *
4
+
5
+ from core.memory.common import *
core/memory/common.py ADDED
@@ -0,0 +1,187 @@
1
+ import os
2
+ import logging
3
+ from dotenv import load_dotenv
4
+ from enum import Enum
5
+ from typing import ClassVar, Dict, Any
6
+ import torch
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ # Configure logging with a default level (will be updated by EnvVars)
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class LogLevel(Enum):
15
+ """Enumeration of valid logging levels."""
16
+
17
+ DEBUG = "DEBUG"
18
+ INFO = "INFO"
19
+ WARNING = "WARNING"
20
+ ERROR = "ERROR"
21
+ CRITICAL = "CRITICAL"
22
+
23
+
24
+ def grab_best_device(use_gpu: bool, enable_mps: bool) -> str:
25
+ """
26
+ Determine the best available device for PyTorch operations.
27
+
28
+ Args:
29
+ use_gpu (bool): Whether to prioritize GPU/MPS over CPU.
30
+ enable_mps (bool): Whether to allow MPS (Metal Performance Shaders) on Apple Silicon.
31
+
32
+ Returns:
33
+ str: Device identifier ("cuda", "mps", or "cpu").
34
+ """
35
+ if use_gpu and torch.cuda.is_available():
36
+ device = "cuda"
37
+ logger.debug("Selected CUDA device (GPU available)")
38
+ elif use_gpu and enable_mps and torch.backends.mps.is_available():
39
+ device = "mps"
40
+ logger.debug("Selected MPS device (Apple Silicon GPU available)")
41
+ else:
42
+ device = "cpu"
43
+ logger.debug("Selected CPU device (no GPU/MPS available or disabled)")
44
+ return device
45
+
46
+
47
+ class EnvVars:
48
+ """
49
+ Class to manage and expose environment variables with type safety and runtime configurability.
50
+
51
+ Loads variables from a .env file or system environment, applies defaults if not found, and allows updates
52
+ at runtime. Variables are stored as instance attributes rather than polluting the global namespace.
53
+ """
54
+
55
+ # Default values for environment variables
56
+ _DEFAULTS: ClassVar[Dict[str, Any]] = {
57
+ "GLOBAL_ENABLE_MPS": True, # Enable PyTorch's Metal Performance Shaders on Apple Silicon
58
+ "AUDIO_SAMPLE_RATE": 24000, # Default sample rate for audio processing (in Hz)
59
+ "SUNO_USE_SMALL_MODELS": True, # Use smaller Bark models if True
60
+ "CACHE_DIR": "./models",
61
+ "LOG_LEVEL": LogLevel.INFO, # Default logging level
62
+ "USE_GPU": True, # Whether to prioritize GPU/MPS over CPU
63
+ }
64
+
65
+ def __init__(self) -> None:
66
+ """Initialize the EnvVars instance and load variables."""
67
+ self._vars: Dict[str, Any] = {}
68
+ self._load_env_vars()
69
+ self._update_attributes()
70
+
71
+ def _load_env_vars(self) -> None:
72
+ """Load environment variables from .env file or system, falling back to defaults."""
73
+ load_dotenv() # Load .env file into os.environ
74
+ for var_name, default_value in self._DEFAULTS.items():
75
+ value = os.getenv(var_name)
76
+ if value is None:
77
+ logger.info(
78
+ f"{var_name} not found in environment, using default: {default_value}"
79
+ )
80
+ self._vars[var_name] = default_value
81
+ else:
82
+ # Convert value to the appropriate type based on default
83
+ if isinstance(default_value, bool):
84
+ self._vars[var_name] = value.lower() in ("true", "1", "t")
85
+ elif isinstance(default_value, int):
86
+ self._vars[var_name] = int(value)
87
+ elif isinstance(default_value, float):
88
+ self._vars[var_name] = float(value)
89
+ elif isinstance(default_value, LogLevel):
90
+ self._vars[var_name] = LogLevel(value.upper())
91
+ else:
92
+ self._vars[var_name] = value
93
+ logger.info(
94
+ f"{var_name} loaded from environment: {self._vars[var_name]}"
95
+ )
96
+
97
+ def _update_attributes(self) -> None:
98
+ """Update instance attributes and apply settings (e.g., logging level, device)."""
99
+ # Set instance attributes
100
+ self.GLOBAL_ENABLE_MPS: bool = self._vars["GLOBAL_ENABLE_MPS"]
101
+ self.AUDIO_SAMPLE_RATE: int = self._vars["AUDIO_SAMPLE_RATE"]
102
+ self.SUNO_USE_SMALL_MODELS: bool = self._vars["SUNO_USE_SMALL_MODELS"]
103
+ self.CACHE_DIR: str = self._vars["CACHE_DIR"]
104
+ self.LOG_LEVEL: LogLevel = self._vars["LOG_LEVEL"]
105
+ self.USE_GPU: bool = self._vars["USE_GPU"]
106
+ self.DEVICE: str = grab_best_device(self.USE_GPU, self.GLOBAL_ENABLE_MPS)
107
+ logging.getLogger().setLevel(self.LOG_LEVEL.value)
108
+
109
+ def update(self, var_name: str, value: Any) -> None:
110
+ """
111
+ Update an environment variable at runtime and reapply settings.
112
+
113
+ Args:
114
+ var_name (str): Name of the variable to update (must be in _DEFAULTS).
115
+ value (Any): New value for the variable.
116
+
117
+ Raises:
118
+ KeyError: If var_name is not a recognized environment variable.
119
+ """
120
+ if var_name not in self._DEFAULTS:
121
+ raise KeyError(f"Unknown environment variable: {var_name}")
122
+
123
+ # Convert value to the appropriate type based on default
124
+ default_type = type(self._DEFAULTS[var_name])
125
+ if default_type is bool:
126
+ self._vars[var_name] = bool(
127
+ value.lower() in ("true", "1", "t") if isinstance(value, str) else value
128
+ )
129
+ elif default_type is int:
130
+ self._vars[var_name] = int(value)
131
+ elif default_type is float:
132
+ self._vars[var_name] = float(value)
133
+ elif default_type is LogLevel:
134
+ self._vars[var_name] = LogLevel(
135
+ value.upper() if isinstance(value, str) else value
136
+ )
137
+ else:
138
+ self._vars[var_name] = value
139
+
140
+ logger.info(f"Updated {var_name} to {self._vars[var_name]}")
141
+ self._update_attributes()
142
+
143
+
144
+ # Create global instance to access environment variables
145
+ env = EnvVars()
146
+
147
+
148
+ def get_cached_or_download_model_from_hf(
149
+ repo_id: str, file_name: str, cache_dir: str = env.CACHE_DIR
150
+ ) -> str:
151
+ """
152
+ Download a model from Hugging Face Hub if not already cached.
153
+
154
+ Args:
155
+ repo_id (str): The repository ID on Hugging Face Hub (e.g., 'suno/bark').
156
+ file_name (str): The name of the model file to download (e.g., 'text.pt').
157
+ cache_dir (str): Directory to store cached models (defaults to env.CACHE_DIR).
158
+
159
+ Returns:
160
+ str: The full path to the downloaded or cached model file.
161
+
162
+ Raises:
163
+ OSError: If the cache directory cannot be created.
164
+ RuntimeError: If the download from Hugging Face fails.
165
+ """
166
+ # Ensure cache directory exists
167
+ try:
168
+ os.makedirs(cache_dir, exist_ok=True)
169
+ except OSError as e:
170
+ logger.error(f"Failed to create cache directory {cache_dir}: {str(e)}")
171
+ raise
172
+
173
+ # Check if file is already cached
174
+ cached_path = os.path.join(cache_dir, file_name)
175
+ if os.path.exists(cached_path):
176
+ logger.debug(f"Model found in cache: {cached_path}")
177
+ return cached_path
178
+
179
+ # Download from Hugging Face if not cached
180
+ logger.info(f"Downloading model {repo_id}/{file_name} to {cache_dir}")
181
+ try:
182
+ hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=cache_dir)
183
+ logger.debug(f"Model downloaded successfully to {cached_path}")
184
+ return cached_path
185
+ except Exception as e:
186
+ logger.error(f"Failed to download model {repo_id}/{file_name}: {str(e)}")
187
+ raise RuntimeError(f"Failed to download model {repo_id}/{file_name}: {str(e)}")
core/memory/model_manager.py ADDED
@@ -0,0 +1,289 @@
1
+ import psutil
2
+ import logging
3
+ from typing import Dict, Optional, Callable, Any, Literal
4
+ from collections import OrderedDict
5
+ from threading import RLock
6
+ import torch
7
+ from transformers import BertTokenizer
8
+ from encodec import EncodecModel
9
+
10
+ from core.memory.common import get_cached_or_download_model_from_hf, env
11
+ from core.model.bark import GPTConfig, FineGPTConfig, GPT, FineGPT
12
+ from core.memory.models import *
13
+
14
+ # Configure logging for this module
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def clear_cuda_cache() -> None:
19
+ """
20
+ Clear the CUDA memory cache if GPU is available.
21
+
22
+ Raises:
23
+ RuntimeError: If CUDA operations fail unexpectedly.
24
+ """
25
+ if torch.cuda.is_available():
26
+ try:
27
+ torch.cuda.empty_cache()
28
+ torch.cuda.synchronize()
29
+ logger.debug("CUDA cache cleared successfully")
30
+ except RuntimeError as e:
31
+ logger.error(f"Failed to clear CUDA cache: {str(e)}")
32
+ raise RuntimeError(f"CUDA cache clear failed: {str(e)}")
33
+
34
+
35
+ class ModelManager:
36
+ """
37
+ Manager class for loading, caching, and unloading PyTorch models with memory management.
38
+
39
+ Prioritizes GPU memory when available, with an optional `offload_to_cpu` flag to use CPU RAM instead.
40
+ Uses an LRU (Least Recently Used) cache to keep only the most recently used models in memory.
41
+ Automatically unloads models when memory usage (GPU or CPU, depending on config) exceeds a threshold
42
+ or the maximum number of cached models is reached.
43
+ """
44
+
45
+ def __init__(self, max_models: int = 10, offload_to_cpu: bool = False):
46
+ """
47
+ Initialize the model manager.
48
+
49
+ Args:
50
+ max_models (int): Maximum number of models to keep in memory before unloading (default: 10)
51
+ offload_to_cpu (bool): If True, use CPU RAM instead of GPU memory (default: False)
52
+ """
53
+ self._models: OrderedDict = OrderedDict() # LRU cache for loaded models
54
+ self._lock = RLock() # Re-entrant lock: get_model may evict (and re-acquire the lock) while already holding it
55
+ self._max_models = max_models # Max number of models to cache
56
+ # Whether to offload models to CPU instead of GPU
57
+ self._offload_to_cpu = offload_to_cpu
58
+ self._device = torch.device(env.DEVICE) # Device to load models onto
59
+ logger.info(f"Model manager initialized with device: {self._device}")
60
+
61
+ def _check_memory(self) -> bool:
62
+ """
63
+ Check if current memory usage is below the threshold, focusing on GPU unless offloaded to CPU.
64
+
65
+ Returns:
66
+ bool: True if memory usage is safe, False if it exceeds the threshold
67
+ """
68
+ if self._offload_to_cpu or not torch.cuda.is_available():
69
+ # Check CPU memory usage
70
+ mem = psutil.virtual_memory() # System memory stats
71
+ total_mem_used = mem.used / 1e9 # CPU memory used in GB
72
+ total_mem_available = mem.total / 1e9 # Total CPU memory in GB
73
+ else:
74
+ # Check GPU memory usage
75
+ total_mem_used = (
76
+ torch.cuda.memory_allocated() / 1e9
77
+ ) # GPU memory used in GB
78
+ total_mem_available = (
79
+ torch.cuda.get_device_properties(0).total_memory / 1e9
80
+ ) # Total GPU memory in GB
81
+
82
+ usage_ratio = total_mem_used / total_mem_available
83
+ logger.debug(
84
+ f"Memory usage on {self._device}: {usage_ratio:.2%} (threshold: {MEMORY_THRESHOLD})"
85
+ )
86
+ return usage_ratio < MEMORY_THRESHOLD
87
+
88
+ def _unload_lru_model(self):
89
+ """Unload the least recently used model to free memory."""
90
+ with self._lock:
91
+ if self._models:
92
+ # Remove oldest entry
93
+ model_info, model_instance = self._models.popitem(last=False)
94
+ logger.info(
95
+ f"Unloading model {model_info} from {self._device} to free memory"
96
+ )
97
+ # Move model to CPU before deletion to ensure GPU memory is freed
98
+ if not self._offload_to_cpu and torch.cuda.is_available():
99
+ model_instance.model = model_instance.model.cpu()
100
+ del model_instance # Explicitly delete reference
101
+ logger.debug(f"Memory freed from {self._device}")
102
+
103
+ def get_model(self, model_info: ModelInfo) -> Model:
104
+ """
105
+ Retrieve or load a model, managing memory constraints on the chosen device (GPU or CPU).
106
+
107
+ Args:
108
+ model_info (ModelInfo): Metadata for the model to load
109
+
110
+ Returns:
111
+ Model: The loaded model instance with config and preprocessor
112
+
113
+ Raises:
114
+ ValueError: If model_info is invalid
115
+ """
116
+ assert isinstance(
117
+ model_info, ModelInfo
118
+ ), f"invalid model_info type {type(model_info)}"
119
+ with self._lock:
120
+ # If model is already loaded, move it to the end (most recently used) and return it
121
+ if model_info in self._models:
122
+ self._models.move_to_end(model_info)
123
+ return self._models[model_info]
124
+
125
+ # Ensure memory is available by unloading models if necessary
126
+ # (guard on self._models so this cannot loop forever once the cache is already empty)
+ while self._models and (
+ not self._check_memory() or len(self._models) >= self._max_models
+ ):
127
+ self._unload_lru_model()
128
+
129
+ if model_info.load_model is not None:
130
+ model = model_info.load_model(model_info, torch.device(env.DEVICE))
131
+ elif model_info.checkpoint_name is not None:
132
+ model = load_transformers_model(model_info, self._device)
133
+ elif model_info.repo_id is not None and model_info.file_name is not None:
134
+ model_file_path = get_cached_or_download_model_from_hf(
135
+ repo_id=model_info.repo_id, file_name=model_info.file_name
136
+ )
137
+ model = load_model_from_file(model_info, model_file_path, self._device)
138
+ else:
139
+ raise ValueError(
140
+ "Invalid model info: must provide checkpoint_name or repo_id/file_name"
141
+ )
142
+
143
+ # Cache the loaded model
144
+ self._models[model_info] = model
145
+ clear_cuda_cache()
146
+ logger.info(f"Loaded and cached model {model_info} on {self._device}")
147
+ return model
148
+
149
+ def unload_model(self, model_info: ModelInfo):
150
+ """
151
+ Manually unload a specific model from memory.
152
+
153
+ Args:
154
+ model_info (ModelInfo): Metadata of the model to unload
155
+ """
156
+ with self._lock:
157
+ if model_info in self._models:
158
+ model_instance = self._models[model_info]
159
+ # Move model to CPU before deletion if on GPU
160
+ if not self._offload_to_cpu and torch.cuda.is_available():
161
+ model_instance.model = model_instance.model.cpu()
162
+ del self._models[model_info]
163
+ logger.info(f"Manually unloaded model {model_info} from {self._device}")
164
+
165
+
166
+ def load_model_from_file(
167
+ model_info: ModelInfo, model_file_path: str, device: torch.device
168
+ ) -> Model:
169
+ """
170
+ Load a model from a file (e.g., custom weights from Hugging Face).
171
+
172
+ Args:
173
+ model_info (ModelInfo): Metadata for the model
174
+ model_file_path (str): Path to the model weights file
175
+ device (torch.device): Device to load the model onto (CPU or GPU)
176
+
177
+ Returns:
178
+ Model: Loaded model instance
179
+ """
180
+ if model_info.repo_id == "suno/bark":
181
+ return load_bark_model(model_info, model_file_path, device)
182
+ if model_info.model_type == "custom_hubert_tokenizer":
183
+ return load_custom_hubert_tokenizer(model_info, model_file_path, device)
184
+ raise ValueError(f"Unknown how to load model {model_info}")
185
+
186
+
187
+ # temporarily disabled: this custom hubert tokenizer loader currently returns an empty Model
188
+ def load_custom_hubert_tokenizer(
189
+ model_info: ModelInfo, model_file_path: str, device: torch.device
190
+ ) -> Model:
191
+ # Automatically uses the right layers
192
+ # tokenizer = HuBERTForBarkSemantic.load_from_checkpoint(
193
+ # model_file_path, torch.device(env.DEVICE)
194
+ # ).to(device)
195
+
196
+ # return Model(model=tokenizer)
197
+ return Model(model=None)
198
+
199
+
200
+ def load_transformers_model(model_info: ModelInfo, device: torch.device) -> Model:
201
+ """
202
+ Load a model using Hugging Face's transformers library.
203
+
204
+ Args:
205
+ model_info (ModelInfo): Metadata for the model
206
+ device (torch.device): Device to load the model onto (CPU or GPU)
207
+
208
+ Returns:
209
+ Model: Loaded model instance
210
+ """
211
+ if model_info.checkpoint_name == "facebook/encodec_24khz":
212
+ model = EncodecModel.encodec_model_24khz()
213
+ model.set_target_bandwidth(6.0)
+ model.eval()
214
+ model = model.to(device)
215
+ return Model(model)
216
+ raise NotImplementedError("Only Encodec 24k supported for now")
217
+
218
+
219
+ def load_bark_model(
220
+ model_info: ModelInfo, model_file_path: str, device: torch.device
221
+ ) -> Model:
222
+ """
223
+ Load a Bark model from a file.
224
+
225
+ Args:
226
+ model_info (ModelInfo): Metadata for the Bark model
227
+ model_file_path (str): Path to the model weights file
228
+ device (torch.device): Device to load the model onto (CPU or GPU)
229
+
230
+ Returns:
231
+ Model: Loaded Bark model instance with config and optional tokenizer
232
+ """
233
+ # Load checkpoint directly to the specified device
234
+ # weights_only = False only for trusted source
235
+ checkpoint = torch.load(model_file_path, map_location=device, weights_only=False)
236
+ ConfigClass, ModelClass = (
237
+ (GPTConfig, GPT)
238
+ if model_info.model_type in ["text", "coarse"]
239
+ else (FineGPTConfig, FineGPT)
240
+ )
241
+
242
+ model_args = preprocess_model_args(checkpoint["model_args"])
243
+
244
+ conf = ConfigClass(**model_args)
245
+ model = ModelClass(conf)
246
+ state_dict = _update_bark_state_dict(model, checkpoint["model"])
247
+ model.load_state_dict(state_dict, strict=False)
248
+
249
+ model = model.to(device) # Ensure model is on the correct device
250
+ model.eval()
251
+ logger.info(f"Loaded Bark model: {model_info} on {device}")
252
+
253
+ # Add tokenizer for text models (tokenizer stays on CPU as it doesn't require GPU)
254
+ preprocessor = (
255
+ BertTokenizer.from_pretrained("bert-base-multilingual-cased")
256
+ if model_info.model_type == "text"
257
+ else None
258
+ )
259
+ return Model(model, conf, preprocessor)
260
+
261
+
262
+ def preprocess_model_args(model_args: dict) -> dict:
263
+ if "input_vocab_size" not in model_args:
264
+ model_args["input_vocab_size"] = model_args["vocab_size"]
265
+ model_args["output_vocab_size"] = model_args["vocab_size"]
266
+ del model_args["vocab_size"]
267
+ return model_args
268
+
269
+
270
+ def _update_bark_state_dict(model: GPT, state_dict: Dict[str, Any]) -> Dict[str, Any]:
271
+ """
272
+ Update the state dictionary by removing unwanted prefixes (specific to Bark models).
273
+
274
+ Args:
275
+ model (GPT): The model instance to align the state dict with
276
+ state_dict (Dict[str, Any]): The loaded state dictionary
277
+
278
+ Returns:
279
+ Dict[str, Any]: Updated state dictionary
280
+ """
281
+ unwanted_prefix = "_orig_mod."
282
+ for key in list(state_dict.keys()):
283
+ if key.startswith(unwanted_prefix):
284
+ state_dict[key[len(unwanted_prefix) :]] = state_dict.pop(key)
285
+ return state_dict
286
+
287
+
288
+ # Instantiate the global model manager with default GPU priority
289
+ model_manager = ModelManager(offload_to_cpu=False if env.USE_GPU else True)
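
A minimal sketch of the call pattern the generation modules in this commit use against the manager instantiated above, assuming the package and its model dependencies are available.

    from core.memory import model_manager, ModelEnum

    # the first call downloads/loads the checkpoint and caches it; later calls hit the LRU cache
    text_model = model_manager.get_model(ModelEnum.BARK_TEXT_SMALL.value)
    gpt = text_model.model               # GPT instance on env.DEVICE
    tokenizer = text_model.preprocessor  # BertTokenizer used for the text stage

    # entries are evicted automatically under memory pressure, or explicitly:
    model_manager.unload_model(ModelEnum.BARK_TEXT_SMALL.value)
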
core/memory/models.py ADDED
@@ -0,0 +1,169 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ from dataclasses import asdict
5
+ from typing_extensions import Optional, Callable
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from transformers import BertTokenizer
9
+ from encodec import EncodecModel
10
+
11
+ import torch
12
+ from core.model.bark import GPT
13
+ from core.memory.common import env
14
+ from core.utils import download_file_from_hf
15
+ from core.model.hubert import HuBERTForBarkSemantic, HubertForBarkSemanticConfig
16
+
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format="%(asctime)s - %(levelname)s - %(message)s",
20
+ handlers=[logging.StreamHandler(sys.stdout)],
21
+ )
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Memory threshold (in percentage) to trigger unloading of models when memory usage gets too high
25
+ # 90% of available memory; applies to GPU unless offloaded to CPU
26
+ MEMORY_THRESHOLD = 0.9
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class ModelInfo:
31
+ """Data structure to hold metadata about a model."""
32
+
33
+ # Hugging Face repository ID (e.g., "suno/bark")
34
+ repo_id: Optional[str] = None
35
+ # Filename of the model weights (e.g., "text.pt")
36
+ file_name: Optional[str] = None
37
+ # Pretrained checkpoint name (e.g., "facebook/encodec_24khz")
38
+ checkpoint_name: Optional[str] = None
39
+ # Configuration class for the model
40
+ config_class: Optional[type] = None
41
+ # Model class to instantiate
42
+ model_class: Optional[type] = None
43
+ # Preprocessor class (e.g., tokenizer)
44
+ preprocessor_class: Optional[type] = None
45
+ # Type of model (e.g., "text", "coarse", "encodec")
46
+ model_type: Optional[str] = None
47
+ # define the function that load the model
48
+ load_model: Optional[Callable] = None
49
+
50
+
51
+ @dataclass
52
+ class Model:
53
+ """Container for a loaded model, its configuration, and preprocessor."""
54
+
55
+ model: Callable # The PyTorch model instance
56
+ config: Optional[Callable] = None # Model configuration object
57
+ # Preprocessor (e.g., tokenizer for text models)
58
+ preprocessor: Optional[Callable] = None
59
+
60
+
61
+ def _load_encodec_model(model_info: ModelInfo, device: torch.device) -> Model:
62
+ model = EncodecModel.encodec_model_24khz()
63
+ model.set_target_bandwidth(6.0)
64
+ model.eval()
65
+ model.to(device)
66
+ return Model(model)
67
+
68
+
69
+ def _load_hubert_base_for_bark_semantic(
70
+ model_info: ModelInfo, device: torch.device
71
+ ) -> "Model":
72
+ os.makedirs(env.CACHE_DIR, exist_ok=True)
73
+ local_file_path = os.path.join(env.CACHE_DIR, model_info.file_name)
74
+ if not os.path.isfile(local_file_path):
75
+ logger.info(
76
+ f"Downloading {model_info.file_name} model from {model_info.repo_id}"
77
+ )
78
+ download_file_from_hf(
79
+ model_info.repo_id, "model", model_info.file_name, env.CACHE_DIR
80
+ )
81
+
82
+ checkpoint = torch.load(local_file_path, map_location=device)
83
+
84
+ assert isinstance(
85
+ checkpoint, dict
86
+ ), "expecting a dictionary, got {type(checkpoint)}"
87
+
88
+ state_dict = checkpoint.get("model_state_dict", None)
89
+ assert (
90
+ state_dict is not None
91
+ ), f"model_state_dict not in checkpoint, {checkpoint.keys()}"
92
+
93
+ model_config = checkpoint.get("config", None)
94
+ assert model_config is not None, "not found model config in checkpoint"
95
+
96
+ config = HubertForBarkSemanticConfig(**model_config)
97
+ model = HuBERTForBarkSemantic(
98
+ config=config, load_hubert_pretrained_weights=False, device=device
99
+ )
100
+ model.load_state_dict(state_dict=state_dict, strict=True)
101
+
102
+ return Model(model=model, config=config, preprocessor=None)
103
+
104
+
105
+ # TODO: refactor this class, each ModelInfo should have its own _load_model function for consistency
106
+ # and avoid complicated if-else paths
107
+ class ModelEnum(Enum):
108
+ """
109
+ Enumeration of supported models with their metadata.
110
+ Each entry maps to a ModelInfo object defining how to load the model.
111
+ """
112
+
113
+ BARK_TEXT_SMALL = ModelInfo(
114
+ repo_id="suno/bark",
115
+ file_name="text.pt",
116
+ model_type="text",
117
+ model_class=GPT,
118
+ preprocessor_class=BertTokenizer,
119
+ )
120
+ BARK_COARSE_SMALL = ModelInfo(
121
+ repo_id="suno/bark", file_name="coarse.pt", model_type="coarse"
122
+ )
123
+ BARK_FINE_SMALL = ModelInfo(
124
+ repo_id="suno/bark", file_name="fine.pt", model_type="fine"
125
+ )
126
+
127
+ BARK_TEXT = ModelInfo(repo_id="suno/bark", file_name="text_2.pt", model_type="text")
128
+ BARK_COARSE = ModelInfo(
129
+ repo_id="suno/bark", file_name="coarse_2.pt", model_type="coarse"
130
+ )
131
+ BARK_FINE = ModelInfo(repo_id="suno/bark", file_name="fine_2.pt", model_type="fine")
132
+
133
+ CustomHuBERTTokenizer = ModelInfo(
134
+ repo_id="GitMylo/bark-voice-cloning",
135
+ file_name="quantifier_hubert_base_ls960_14.pth",
136
+ model_type="custom_hubert_tokenizer",
137
+ )
138
+
139
+ ENCODEC24k = ModelInfo(
140
+ checkpoint_name="facebook/encodec_24khz",
141
+ model_type="encodec",
142
+ load_model=_load_encodec_model,
143
+ )
144
+
145
+ HuBERTBaseForBarkSemantic = ModelInfo(
146
+ checkpoint_name="facebook/hubert-base-ls960",
147
+ repo_id="sleeper371/hubert-for-bark-semantic",
148
+ file_name="hubert_epoch_30_2025_04_06_03_23_eval_loss_0.5520355800787607_acc_0.8344086021505376.pt",
149
+ load_model=_load_hubert_base_for_bark_semantic,
150
+ )
151
+
152
+ @classmethod
153
+ def get_model_info(cls, model_name: str) -> ModelInfo:
154
+ """
155
+ Retrieve ModelInfo for a given model name.
156
+
157
+ Args:
158
+ model_name (str): Name of the model (e.g., "BARK_TEXT_SMALL")
159
+
160
+ Returns:
161
+ ModelInfo: Metadata for the requested model
162
+
163
+ Raises:
164
+ ValueError: If the model name is not recognized
165
+ """
166
+ try:
167
+ return cls[model_name].value
168
+ except KeyError:
169
+ raise ValueError(f"Unknown model name: {model_name}")
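Usage sketch for the registry above (illustrative only; the module path core.memory.models and the Model/ModelInfo names are assumed from this commit's file list, and only entries that define load_model can be loaded this way):

import torch
from core.memory.models import ModelEnum

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
info = ModelEnum.get_model_info("ENCODEC24k")  # raises ValueError for unknown names
if info.load_model is not None:
    wrapper = info.load_model(info, device)    # returns a Model container
    print(type(wrapper.model).__name__)        # EncodecModel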
core/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from core.model.bark import *
core/model/bark.py ADDED
@@ -0,0 +1,425 @@
1
+ """
2
+ codes adapted from https://github.com/suno-ai/bark
3
+ """
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+
13
+ @dataclass
14
+ class GPTConfig:
15
+ block_size: int = 1024
16
+ input_vocab_size: int = 10_048
17
+ output_vocab_size: int = 10_048
18
+ n_layer: int = 12
19
+ n_head: int = 12
20
+ n_embd: int = 768
21
+ dropout: float = 0.0
22
+ bias: bool = (
23
+ True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
24
+ )
25
+
26
+
27
+ @dataclass
28
+ class FineGPTConfig(GPTConfig):
29
+ n_codes_total: int = 8
30
+ n_codes_given: int = 1
31
+
32
+
33
+ class LayerNorm(nn.Module):
34
+ """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
35
+
36
+ def __init__(self, ndim: int, bias: bool) -> None:
37
+ super().__init__()
38
+ self.weight = nn.Parameter(torch.ones(ndim))
39
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
40
+
41
+ def forward(self, input):
42
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
43
+
44
+
45
+ class MLP(nn.Module):
46
+
47
+ def __init__(self, config: GPTConfig):
48
+ super().__init__()
49
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
50
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
51
+ self.dropout = nn.Dropout(config.dropout)
52
+ self.gelu = nn.GELU()
53
+
54
+ def forward(self, x) -> torch.Tensor:
55
+ x = self.c_fc(x)
56
+ x = self.gelu(x)
57
+ x = self.c_proj(x)
58
+ x = self.dropout(x)
59
+ return x
60
+
61
+
62
+ class CausalSelfAttention(nn.Module):
63
+ def __init__(self, config: GPTConfig) -> None:
64
+ super().__init__()
65
+ assert config.n_embd % config.n_head == 0
66
+
67
+ # key, query, value projections for all heads, but in a batch
68
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
69
+ # output projection
70
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
71
+ # regularization
72
+ self.attn_dropout = nn.Dropout(config.dropout)
73
+ self.resid_dropout = nn.Dropout(config.dropout)
74
+ self.n_head = config.n_head
75
+ self.n_embd = config.n_embd
76
+ self.dropout = config.dropout
77
+ # flash attention makes the GPU go brrrrr, but support requires PyTorch >= 2.0
78
+ self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
79
+ if not self.flash:
80
+ # print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
81
+ # causal mask to ensure that attention is only applied to the left in the input sequence
82
+ self.register_buffer(
83
+ "bias",
84
+ torch.tril(torch.ones(config.block_size, config.block_size)).view(
85
+ 1, 1, config.block_size, config.block_size
86
+ ),
87
+ )
88
+
89
+ def forward(
90
+ self, x: torch.Tensor, past_kv: torch.Tensor = None, use_cache: bool = False
91
+ ):
92
+ B, T, C = (
93
+ x.size()
94
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
95
+
96
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
97
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
98
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(
99
+ 1, 2
100
+ ) # (B, nh, T, hs)
101
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(
102
+ 1, 2
103
+ ) # (B, nh, T, hs)
104
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(
105
+ 1, 2
106
+ ) # (B, nh, T, hs)
107
+
108
+ if past_kv is not None:
109
+ past_key = past_kv[0]
110
+ past_value = past_kv[1]
111
+ k = torch.cat((past_key, k), dim=-2)
112
+ v = torch.cat((past_value, v), dim=-2)
113
+
114
+ FULL_T = k.shape[-2]
115
+
116
+ if use_cache is True:
117
+ present = (k, v)
118
+ else:
119
+ present = None
120
+
121
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
122
+ if self.flash:
123
+ # efficient attention using Flash Attention CUDA kernels
124
+ if past_kv is not None:
125
+ # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
126
+ # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
127
+ # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
128
+ # to work around this we set is_causal=False.
129
+ is_causal = False
130
+ else:
131
+ is_causal = True
132
+
133
+ y = torch.nn.functional.scaled_dot_product_attention(
134
+ q, k, v, dropout_p=self.dropout, is_causal=is_causal
135
+ )
136
+ else:
137
+ # manual implementation of attention
138
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
139
+ att = att.masked_fill(
140
+ self.bias[:, :, FULL_T - T : FULL_T, :FULL_T] == 0, float("-inf")
141
+ )
142
+ att = F.softmax(att, dim=-1)
143
+ att = self.attn_dropout(att)
144
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
145
+ y = (
146
+ y.transpose(1, 2).contiguous().view(B, T, C)
147
+ ) # re-assemble all head outputs side by side
148
+
149
+ # output projection
150
+ y = self.resid_dropout(self.c_proj(y))
151
+ return (y, present)
152
+
153
+
154
+ class Block(nn.Module):
155
+
156
+ def __init__(self, config: GPTConfig, layer_idx: int) -> None:
157
+ super().__init__()
158
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
159
+ self.attn = CausalSelfAttention(config)
160
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
161
+ self.mlp = MLP(config)
162
+ self.layer_idx = layer_idx
163
+
164
+ def forward(
165
+ self, x: torch.Tensor, past_kv: torch.Tensor = None, use_cache: bool = False
166
+ ):
167
+ attn_output, prev_kvs = self.attn(
168
+ self.ln_1(x), past_kv=past_kv, use_cache=use_cache
169
+ )
170
+ x = x + attn_output
171
+ x = x + self.mlp(self.ln_2(x))
172
+ return (x, prev_kvs)
173
+
174
+
175
+ class GPT(nn.Module):
176
+ def __init__(self, config: GPTConfig):
177
+ super().__init__()
178
+ assert config.input_vocab_size is not None
179
+ assert config.output_vocab_size is not None
180
+ assert config.block_size is not None
181
+ self.config = config
182
+
183
+ self.transformer = nn.ModuleDict(
184
+ dict(
185
+ wte=nn.Embedding(config.input_vocab_size, config.n_embd),
186
+ wpe=nn.Embedding(config.block_size, config.n_embd),
187
+ drop=nn.Dropout(config.dropout),
188
+ h=nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]),
189
+ ln_f=LayerNorm(config.n_embd, bias=config.bias),
190
+ )
191
+ )
192
+ self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
193
+ # Note: lm_head has no bias (as in GPT-2); unlike nanoGPT, its weight is not tied to wte here
194
+
195
+ def get_num_params(self, non_embedding: bool = True) -> int:
196
+ """
197
+ Return the number of parameters in the model.
198
+ For non-embedding count (default), the position embeddings get subtracted.
199
+ The token embeddings would too, except due to the parameter sharing these
200
+ params are actually used as weights in the final layer, so we include them.
201
+ """
202
+ n_params = sum(p.numel() for p in self.parameters())
203
+ if non_embedding:
204
+ n_params -= self.transformer.wte.weight.numel()
205
+ n_params -= self.transformer.wpe.weight.numel()
206
+ return n_params
207
+
208
+ def forward(
209
+ self,
210
+ idx: torch.Tensor,
211
+ merge_context: bool = False,
212
+ past_kv: torch.Tensor = None,
213
+ position_ids: torch.Tensor = None,
214
+ use_cache: bool = False,
215
+ ):
216
+ device = idx.device
217
+ b, t = idx.size()
218
+ if past_kv is not None:
219
+ # When past_kv is provided, this is optimized for autoregressive generation
220
+ assert (
221
+ t == 1
222
+ ), "should only pass in the last token of the sequence when using kv_cache"
223
+ # Shape: (b, 1, n_embd), single token case
224
+ tok_emb = self.transformer.wte(idx)
225
+ else:
226
+ if merge_context:
227
+ # Custom feature: assumes first 256 tokens are one context, next 256 another, rest is sequence
228
+ assert idx.shape[1] >= 256 + 256 + 1
229
+ t = idx.shape[1] - 256 # Adjusts t for merged context length
230
+ else:
231
+ assert (
232
+ t <= self.config.block_size
233
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
234
+
235
+ if merge_context:
236
+ # Merges two contexts by adding their embeddings, not a standard GPT behavior
237
+ tok_emb = torch.cat(
238
+ [
239
+ self.transformer.wte(idx[:, :256])
240
+ + self.transformer.wte(idx[:, 256 : 256 + 256]),
241
+ self.transformer.wte(idx[:, 256 + 256 :]),
242
+ ],
243
+ dim=1,
244
+ )
245
+ else:
246
+ tok_emb = self.transformer.wte(idx)
247
+
248
+ if past_kv is None:
249
+ past_length = 0
250
+ # Empty cache for each layer
251
+ past_kv = tuple([None] * len(self.transformer.h))
252
+ else:
253
+ # Infers prior sequence length from cache
254
+ past_length = past_kv[0][0].size(-2)
255
+
256
+ if position_ids is None:
257
+ position_ids = torch.arange(
258
+ past_length, t + past_length, dtype=torch.long, device=device
259
+ )
260
+ position_ids = position_ids.unsqueeze(0)
261
+ assert position_ids.shape == (1, t)
262
+
263
+ pos_emb = self.transformer.wpe(position_ids)
264
+
265
+ x = self.transformer.drop(tok_emb + pos_emb)
266
+
267
+ # Prepares cache for key-value pairs if enabled
268
+ new_kv = () if use_cache else None
269
+
270
+ for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
271
+ x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
272
+ if use_cache:
273
+ new_kv = new_kv + (kv,) # Accumulates new key-value pairs for caching
274
+
275
+ x = self.transformer.ln_f(x)
276
+
277
+ # Optimization: only computes logits for the last token, efficient for generation
278
+ logits = self.lm_head(x[:, [-1], :]) # Preserves time dim with [-1]
279
+
280
+ return (
281
+ logits,
282
+ new_kv,
283
+ ) # Returns tuple: logits for next token, cache if requested
284
+
285
+
286
+ class NonCausalSelfAttention(nn.Module):
287
+ def __init__(self, config):
288
+ super().__init__()
289
+ assert config.n_embd % config.n_head == 0
290
+ # key, query, value projections for all heads, but in a batch
291
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
292
+ # output projection
293
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
294
+ # regularization
295
+ self.attn_dropout = nn.Dropout(config.dropout)
296
+ self.resid_dropout = nn.Dropout(config.dropout)
297
+ self.n_head = config.n_head
298
+ self.n_embd = config.n_embd
299
+ self.dropout = config.dropout
300
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
301
+ self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
302
+
303
+ def forward(self, x):
304
+ B, T, C = (
305
+ x.size()
306
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
307
+
308
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
309
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
310
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(
311
+ 1, 2
312
+ ) # (B, nh, T, hs)
313
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(
314
+ 1, 2
315
+ ) # (B, nh, T, hs)
316
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(
317
+ 1, 2
318
+ ) # (B, nh, T, hs)
319
+
320
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
321
+ if self.flash:
322
+ # efficient attention using Flash Attention CUDA kernels
323
+ y = torch.nn.functional.scaled_dot_product_attention(
324
+ q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
325
+ )
326
+ else:
327
+ # manual implementation of attention
328
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
329
+ att = F.softmax(att, dim=-1)
330
+ att = self.attn_dropout(att)
331
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
332
+ y = (
333
+ y.transpose(1, 2).contiguous().view(B, T, C)
334
+ ) # re-assemble all head outputs side by side
335
+
336
+ # output projection
337
+ y = self.resid_dropout(self.c_proj(y))
338
+ return y
339
+
340
+
341
+ class FineBlock(nn.Module):
342
+ def __init__(self, config):
343
+ super().__init__()
344
+ self.ln_1 = nn.LayerNorm(config.n_embd)
345
+ self.attn = NonCausalSelfAttention(config)
346
+ self.ln_2 = nn.LayerNorm(config.n_embd)
347
+ self.mlp = MLP(config)
348
+
349
+ def forward(self, x):
350
+ x = x + self.attn(self.ln_1(x))
351
+ x = x + self.mlp(self.ln_2(x))
352
+ return x
353
+
354
+
355
+ class FineGPT(GPT):
356
+ def __init__(self, config):
357
+ super().__init__(config)
358
+ del self.lm_head
359
+ self.config = config
360
+ self.n_codes_total = config.n_codes_total
361
+ self.transformer = nn.ModuleDict(
362
+ dict(
363
+ wtes=nn.ModuleList(
364
+ [
365
+ nn.Embedding(config.input_vocab_size, config.n_embd)
366
+ for _ in range(config.n_codes_total)
367
+ ]
368
+ ),
369
+ wpe=nn.Embedding(config.block_size, config.n_embd),
370
+ drop=nn.Dropout(config.dropout),
371
+ h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]),
372
+ ln_f=nn.LayerNorm(config.n_embd),
373
+ )
374
+ )
375
+ self.lm_heads = nn.ModuleList(
376
+ [
377
+ nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
378
+ for _ in range(config.n_codes_given, self.n_codes_total)
379
+ ]
380
+ )
381
+ for i in range(self.n_codes_total - config.n_codes_given):
382
+ self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight
383
+
384
+ def forward(self, pred_idx, idx):
385
+ device = idx.device
386
+ b, t, codes = idx.size()
387
+ assert (
388
+ t <= self.config.block_size
389
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
390
+ assert pred_idx > 0, "cannot predict 0th codebook"
391
+ assert codes == self.n_codes_total, (b, t, codes)
392
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
393
+ 0
394
+ ) # shape (1, t)
395
+
396
+ # forward the GPT model itself
397
+ tok_embs = [
398
+ wte(idx[:, :, i]).unsqueeze(-1)
399
+ for i, wte in enumerate(self.transformer.wtes)
400
+ ] # token embeddings of shape (b, t, n_embd)
401
+ tok_emb = torch.cat(tok_embs, dim=-1)
402
+ pos_emb = self.transformer.wpe(
403
+ pos
404
+ ) # position embeddings of shape (1, t, n_embd)
405
+ x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1)
406
+ x = self.transformer.drop(x + pos_emb)
407
+ for block in self.transformer.h:
408
+ x = block(x)
409
+ x = self.transformer.ln_f(x)
410
+ logits = self.lm_heads[pred_idx - self.config.n_codes_given](x)
411
+ return logits
412
+
413
+ def get_num_params(self, non_embedding=True):
414
+ """
415
+ Return the number of parameters in the model.
416
+ For non-embedding count (default), the position embeddings get subtracted.
417
+ The token embeddings would too, except due to the parameter sharing these
418
+ params are actually used as weights in the final layer, so we include them.
419
+ """
420
+ n_params = sum(p.numel() for p in self.parameters())
421
+ if non_embedding:
422
+ for wte in self.transformer.wtes:
423
+ n_params -= wte.weight.numel()
424
+ n_params -= self.transformer.wpe.weight.numel()
425
+ return n_params
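Since the causal GPT above only returns logits for the last position and supports a key/value cache, a shape-only sketch of incremental decoding may help (random weights, arbitrary config values; the import path core.model.bark is assumed from the file layout):

import torch
from core.model.bark import GPT, GPTConfig

config = GPTConfig(block_size=128, input_vocab_size=1000, output_vocab_size=1000,
                   n_layer=2, n_head=2, n_embd=64)
model = GPT(config).eval()

idx = torch.randint(0, 1000, (1, 16))                        # (batch, time)
logits, kv = model(idx, use_cache=True)                      # logits: (1, 1, 1000)
next_token = logits.argmax(dim=-1)                           # (1, 1)
logits, kv = model(next_token, past_kv=kv, use_cache=True)   # one incremental decoding step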
core/model/hubert.py ADDED
@@ -0,0 +1,237 @@
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional, Tuple, Union, Literal
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers.modeling_outputs import BaseModelOutput
8
+ from transformers import HubertModel, AutoConfig, AutoModel
9
+
10
+
11
+ @dataclass
12
+ class CustomHubertConfig:
13
+ """Configuration class for CustomHubert model."""
14
+
15
+ # e.g., "facebook/hubert-base-ls960" or "facebook/hubert-large-ll60k"
16
+ checkpoint_name: str
17
+ # Layer to extract features from (0-indexed, e.g., 9 for 10th layer)
18
+ feature_layer: int = 11
19
+ # Target audio sample rate in Hz
20
+ target_sample_rate: int = 16000
21
+ # Optional length multiple for audio trimming
22
+ seq_len_multiple_of: Optional[int] = None
23
+
24
+
25
+ @dataclass
26
+ class HubertForBarkSemanticConfig:
27
+ """Configuration for HuBERTForBarkSemantic."""
28
+
29
+ # HuBERT checkpoint used for the feature extractor
30
+ checkpoint_name: Literal["facebook/hubert-base-ls960", "hubert-large-ls960-ft"]
31
+ vocab_size: int
32
+ # Layer to extract features from
33
+ feature_layer: int = 11
34
+ # the last three ids of the vocabulary are reserved for the SOS, EOS and PAD tokens
35
+ # maximum target sequence length
36
+ max_target_length: int = 2000
37
+ num_decoder_layer: int = 12
38
+ sos_token_id: int = 10000
39
+ eos_token_id: int = 10001
40
+
41
+
42
+ class HubertFeatureExtractor(nn.Module):
43
+ """
44
+ A custom HuBERT model that loads a pretrained model from transformers and extracts
45
+ features from a specified layer. Processes raw audio waveforms and returns hidden states.
46
+
47
+ Args:
48
+ config (CustomHubertConfig): Configuration specifying checkpoint, layer, and audio settings.
49
+ device (torch.device, optional): Device to run the model on (e.g., "cuda" or "cpu").
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ config: CustomHubertConfig,
55
+ load_pretrained_weights: bool,
56
+ device: Optional[torch.device] = None,
57
+ ):
58
+ super().__init__()
59
+ self.config = config
60
+ self.target_sample_rate = config.target_sample_rate
61
+
62
+ # Load pretrained HuBERT model from transformers
63
+ self.hubert_config = AutoConfig.from_pretrained(config.checkpoint_name)
64
+ if load_pretrained_weights:
65
+ self.model = HubertModel.from_pretrained(config.checkpoint_name)
66
+ else:
67
+ # don't download the pretrained weights, init the model from the config
68
+ self.model = AutoModel.from_config(self.hubert_config)
69
+
70
+ # Validate feature_layer
71
+ # e.g., 12 for BASE, 24 for LARGE
72
+ num_layers = self.model.config.num_hidden_layers
73
+ if not (0 <= config.feature_layer < num_layers):
74
+ raise ValueError(
75
+ f"feature_layer must be between 0 and {num_layers - 1}, got {config.feature_layer}"
76
+ )
77
+ self.feature_layer = config.feature_layer
78
+
79
+ # Move to device if specified
80
+ if device is not None:
81
+ self.to(device)
82
+
83
+ @property
84
+ def hidden_size(self) -> int:
85
+ """Returns the hidden size of the HuBERT model (e.g., 768 for BASE, 1024 for LARGE)."""
86
+ return self.model.config.hidden_size
87
+
88
+ def forward(
89
+ self,
90
+ wav_input: torch.Tensor,
91
+ ) -> torch.Tensor:
92
+ """
93
+ Processes raw audio waveforms through HuBERT and extracts features from the specified layer.
94
+ Input audio is expected at a 16 kHz sample rate.
95
+
96
+ Args:
97
+ wav_input (torch.Tensor): Raw audio waveforms, shape [batch_size, audio_length].
98
+
99
+ Returns:
100
+ torch.Tensor: Features from the specified layer,
101
+ shape [batch_size, seq_length, hidden_size].
104
+ """
105
+
106
+ # Forward pass through HuBERT
107
+ # output_hidden_states=True returns all layer outputs
108
+ outputs: BaseModelOutput = self.model(
109
+ input_values=wav_input, output_hidden_states=True, return_dict=True
110
+ )
111
+
112
+ # Extract features from the specified layer (0-indexed)
113
+ # hidden_states is a tuple of [batch_size, seq_length, hidden_size] for each layer
114
+ features = outputs.hidden_states[self.feature_layer] # e.g., [2, 500, 768]
115
+ features = features.contiguous()
116
+ return features
117
+
118
+
119
+ class HuBERTForBarkSemantic(nn.Module):
120
+ def __init__(
121
+ self,
122
+ config: HubertForBarkSemanticConfig,
123
+ load_hubert_pretrained_weights: bool = True,
124
+ device: Optional[torch.device] = None,
125
+ ):
126
+ super().__init__()
127
+ self.config = config
128
+
129
+ # HuBERT feature extractor
130
+ hubert_config = CustomHubertConfig(
131
+ checkpoint_name=config.checkpoint_name,
132
+ feature_layer=config.feature_layer,
133
+ )
134
+ self.hubert = HubertFeatureExtractor(
135
+ config=hubert_config,
136
+ load_pretrained_weights=load_hubert_pretrained_weights,
137
+ device=device,
138
+ )
139
+
140
+ # e.g., 768 for BASE
141
+ input_size = self.hubert.model.config.hidden_size
142
+
143
+ # Transformer Decoder
144
+ self.decoder_embedding = nn.Embedding(config.vocab_size, input_size)
145
+ self.pos_embedding = nn.Parameter(
146
+ torch.zeros(1, config.max_target_length, input_size)
147
+ )
148
+ self.decoder = nn.TransformerDecoder(
149
+ nn.TransformerDecoderLayer(
150
+ d_model=input_size,
151
+ nhead=8,
152
+ dim_feedforward=2048,
153
+ dropout=0.1,
154
+ batch_first=True,
155
+ ),
156
+ num_layers=config.num_decoder_layer, # Adjust as needed
157
+ )
158
+ self.fc = nn.Linear(input_size, config.vocab_size)
159
+
160
+ if device is not None:
161
+ self.to(device)
162
+
163
+ def save_state_dict(self, save_path: str):
164
+ torch.save(self.state_dict(), save_path)
165
+
166
+ def forward(self, wav_input: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
167
+ """
168
+ Forward pass: Extracts HuBERT features and predicts semantic token probabilities.
169
+
170
+ Args:
171
+ wav_input: [batch_size, audio_length] (e.g., [2, 160000])
172
+ tgt: the target sequence
173
+
174
+ Returns:
175
+ [batch_size, target_length, vocab_size] logits over the semantic vocabulary
176
+ """
177
+ memory: torch.Tensor = self.hubert(wav_input) # [B, T, 768]
178
+ B, T_tgt = tgt.shape
179
+ tgt_emb = self.decoder_embedding(tgt) + self.pos_embedding[:, :T_tgt, :]
180
+ tgt_mask = nn.Transformer.generate_square_subsequent_mask(T_tgt).to(tgt.device)
181
+
182
+ output: torch.Tensor = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
183
+ logits = self.fc(output)
184
+ return logits
185
+
186
+ @torch.no_grad
187
+ def generate(
188
+ self,
189
+ wav_input: torch.Tensor,
190
+ temperature: Optional[float] = 0.8,
191
+ eos_p: Optional[float] = 0.5,
192
+ max_length: int = 600,
193
+ ) -> torch.Tensor:
194
+ """
195
+ Inference: autoregressive generation of semantic tokens.
196
+ Assumes wav_input audio is sampled at 16 kHz."""
197
+ self.eval()
198
+ memory = self.hubert(wav_input)
199
+ B = wav_input.shape[0]
200
+ tgt = torch.full(
201
+ size=(B, 1), fill_value=self.config.sos_token_id, device=wav_input.device
202
+ )
203
+
204
+ for _ in range(max_length):
205
+ tgt_emb = (
206
+ self.decoder_embedding(tgt) + self.pos_embedding[:, : tgt.shape[1], :]
207
+ )
208
+ tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.shape[1]).to(
209
+ tgt.device
210
+ )
211
+
212
+ output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
213
+ # logits shape (B, T', vocab_size)
214
+ logits: torch.Tensor = self.fc(output[:, -1, :])
215
+
216
+ if temperature is not None and temperature > 0:
217
+ probs = torch.softmax(input=logits / temperature, dim=-1)
218
+ next_token = torch.multinomial(input=probs, num_samples=1)
219
+ else:
220
+ probs = torch.softmax(input=logits, dim=-1)
221
+ next_token = logits.argmax(dim=-1, keepdim=True)
222
+
223
+ # stop if the EOS token probabilities are higher than the provided eos_p
224
+ if eos_p is not None and eos_p > 0:
225
+ if torch.all(probs[:, self.config.eos_token_id] > eos_p):
226
+ break
227
+
228
+ # early stopping
229
+ if torch.all(next_token == self.config.eos_token_id):
230
+ break
231
+
232
+ tgt = torch.cat([tgt, next_token], dim=1)
233
+ if (next_token == self.config.eos_token_id).all():
234
+ break
235
+
236
+ # remove the [SOS] token from the generated semantic sequences
237
+ return tgt[:, 1:]
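A hedged usage sketch of the semantic tokenizer above; the vocab size below is an assumption (10_000 semantic ids plus SOS/EOS/PAD) and the model is randomly initialized, so only the shapes are meaningful:

import torch
from core.model.hubert import HuBERTForBarkSemantic, HubertForBarkSemanticConfig

config = HubertForBarkSemanticConfig(
    checkpoint_name="facebook/hubert-base-ls960",
    vocab_size=10_003,  # assumption: 10_000 semantic ids + SOS + EOS + PAD
)
model = HuBERTForBarkSemantic(config, load_hubert_pretrained_weights=False,
                              device=torch.device("cpu"))

wav = torch.randn(1, 16_000)                       # one second of 16 kHz audio
tokens = model.generate(wav, temperature=0.8, max_length=32)
print(tokens.shape)                                # (1, <=32) generated semantic ids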
core/trainer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from core.trainer.custom_hubert_trainer import *
core/trainer/custom_hubert_trainer.py ADDED
@@ -0,0 +1,555 @@
1
+ import os
2
+ from pathlib import Path
3
+ from datetime import datetime
4
+ import logging
5
+ import sys
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.optim import Adam
10
+ from torch.optim.lr_scheduler import LRScheduler, LinearLR
11
+ from torch.utils.data import Dataset, DataLoader, random_split
12
+ import torchaudio
13
+ from tqdm import tqdm
14
+
15
+ from typing import Literal, List, Optional, Tuple, Dict, Callable, Union, Any
16
+ from core.data_model import WavSemantic, WavSemanticDataset
17
+ from core.utils import read_audio_file, upload_file_to_hf
18
+
19
+ # cuDNN raised an error about non-contiguous input at the LSTM layer; disabling it fixed the issue
20
+ torch.backends.cudnn.enabled = False
21
+
22
+ # Set up logging
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format="%(asctime)s - %(levelname)s - %(message)s",
26
+ handlers=[logging.StreamHandler(sys.stdout)],
27
+ )
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ HUBERT_SAMPLE_RATE = 16000
32
+ # 10_000 and 10_001 are for SOS and EOS tokens
33
+ SEMANTIC_PADDING_TOKEN = 10002
34
+ SOS_TOKEN = 10_000
35
+ EOS_TOKEN = 10_001
36
+
37
+
38
+ class WavSemanticTorchDataset(Dataset):
39
+ """PyTorch Dataset for WavSemantic data with resampling and noise augmentation.
40
+ Padding is carried out in a collator function.
41
+
42
+ Args:
43
+ samples: List of WavSemantic objects (speech data).
44
+ orig_sample_rate: Original sample rate of the audio.
45
+ target_sample_rate: Desired sample rate (default: 16000 Hz).
46
+ device: Device to move tensors to (optional).
47
+ noises: List of noise waveforms as NumPy arrays (optional, for augmentation).
48
+ noise audio must already be at target_sample_rate; this class does not resample it
49
+ augment_prob: Probability of applying noise augmentation (default: 0.5).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ samples: List["WavSemantic"],
55
+ orig_sample_rate: int,
56
+ target_sample_rate: Optional[int] = 16000,
57
+ device: Optional[torch.device] = None,
58
+ noises: Optional[List[np.ndarray]] = None,
59
+ augment_prob: float = 0.5,
60
+ ):
61
+ self.samples = samples
62
+ self.orig_sample_rate = orig_sample_rate
63
+ self.target_sample_rate = target_sample_rate
64
+ self.device = device
65
+ self.noises = noises
66
+ self.augment_prob = augment_prob
67
+ self.resampler = torchaudio.transforms.Resample(
68
+ orig_freq=orig_sample_rate, new_freq=target_sample_rate
69
+ )
70
+
71
+ def __len__(self) -> int:
72
+ return len(self.samples)
73
+
74
+ def _normalize_waveform(self, wav: torch.Tensor) -> torch.Tensor:
75
+ """Normalize waveform to [-1, 1]."""
76
+ max_val = wav.abs().max()
77
+ if max_val > 0:
78
+ wav = wav / max_val
79
+ return wav
80
+
81
+ def _add_time_varying_noise(
82
+ self, speech: torch.Tensor, noise: torch.Tensor, snr_db: float
83
+ ) -> torch.Tensor:
84
+ """Add noise to a random segment of the speech with fade-in/fade-out."""
85
+ speech_len = speech.size(0)
86
+ noise_len = noise.size(0)
87
+
88
+ # Match noise length (loop or trim)
89
+ if noise_len < speech_len:
90
+ repeats = int(np.ceil(speech_len / noise_len))
91
+ noise = noise.repeat(repeats)[:speech_len]
92
+ else:
93
+ noise = noise[:speech_len]
94
+
95
+ # Random segment (50%-100% of speech length)
96
+ seg_len = int(speech_len * np.random.uniform(0.5, 1.0))
97
+ start = np.random.randint(0, speech_len - seg_len + 1)
98
+ end = start + seg_len
99
+
100
+ # Compute noise scaling based on SNR
101
+ speech_energy = torch.mean(speech[start:end] ** 2)
102
+ noise_energy = torch.mean(noise[start:end] ** 2)
103
+ snr_linear = 10 ** (snr_db / 10.0)
104
+ noise_scale = torch.sqrt(speech_energy / (noise_energy * snr_linear + 1e-10))
105
+
106
+ # Apply noise to segment with fade-in/fade-out
107
+ fade_len = min(1000, seg_len // 4) # Fade over 1000 samples or 1/4 segment
108
+ fade_in = torch.linspace(0, 1, fade_len)
109
+ fade_out = torch.linspace(1, 0, fade_len)
110
+ mask = torch.ones(seg_len)
111
+ if fade_len > 0:
112
+ mask[:fade_len] = fade_in
113
+ mask[-fade_len:] = fade_out
114
+
115
+ noisy_segment = speech[start:end] + (noise_scale * noise[start:end] * mask)
116
+ noisy_speech = speech.clone()
117
+ noisy_speech[start:end] = noisy_segment
118
+
119
+ return torch.clamp(noisy_speech, -1, 1)
120
+
121
+ def _augment_with_noise(self, wav: torch.Tensor) -> torch.Tensor:
122
+ """Augment waveform with random noise mixture."""
123
+ if not self.noises or len(self.noises) == 0:
124
+ return wav
125
+
126
+ # Decide how many noises to mix (1 or 2)
127
+ num_noises = np.random.randint(1, 3) # 1 or 2 noises
128
+ random_indices = np.random.randint(0, len(self.noises), size=num_noises)
129
+ selected_noises = [self.noises[i] for i in random_indices]
130
+ noisy_wav = wav.clone()
131
+ for noise_np in selected_noises:
132
+ noise = torch.from_numpy(noise_np).float()
133
+ noise = self._normalize_waveform(noise) # Normalize noise
134
+ snr_db = np.random.uniform(0, 20) # Random SNR between 0-20 dB
135
+ noisy_wav = self._add_time_varying_noise(noisy_wav, noise, snr_db)
136
+
137
+ # Volume normalization: re-normalize after mixing
138
+ noisy_wav = self._normalize_waveform(noisy_wav)
139
+ return noisy_wav
140
+
141
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
142
+ sample = self.samples[idx]
143
+
144
+ # Convert NumPy wav to torch tensor and resample
145
+ wav_tensor = torch.from_numpy(sample.wav).float()
146
+ if self.orig_sample_rate != self.target_sample_rate:
147
+ wav_tensor = self.resampler(wav_tensor)
148
+
149
+ # Normalize to [-1, 1]
150
+ wav_tensor = self._normalize_waveform(wav_tensor)
151
+
152
+ # Apply noise augmentation with probability
153
+ if self.noises and np.random.rand() < self.augment_prob:
154
+ wav_tensor = self._augment_with_noise(wav_tensor)
155
+
156
+ # Convert semantic to torch tensor (assuming integer tokens for CTC)
157
+ semantic_tensor = torch.from_numpy(sample.semantic).long()
158
+
159
+ # Move to device if specified
160
+ if self.device is not None:
161
+ wav_tensor = wav_tensor.to(self.device)
162
+ semantic_tensor = semantic_tensor.to(self.device)
163
+
164
+ return wav_tensor, semantic_tensor
165
+
166
+
167
+ def wav_semantic_collate_fn(
168
+ batch: List[Tuple[torch.Tensor, torch.Tensor]],
169
+ sos_token: int = SOS_TOKEN, # Adjust based on your vocab
170
+ eos_token: int = EOS_TOKEN, # Adjust based on your vocab
171
+ padding_token: int = SEMANTIC_PADDING_TOKEN, # Adjust based on your vocab
172
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
173
+ """
174
+ Collate function for wav and semantic token pairs, adding <SOS> and <EOS> to targets.
175
+
176
+ Args:
177
+ batch: List of (wav_tensor, semantic_tensor) tuples.
178
+ sos_token: Index of the <SOS> token.
179
+ eos_token: Index of the <EOS> token.
180
+ padding_token: Index of the padding token.
181
+
182
+ Returns:
183
+ Tuple of (padded_wavs, padded_targets, wav_lengths, target_lengths).
184
+ - padded_wavs: [B, max_wav_len]
185
+ - padded_targets: [B, max_target_len] with <SOS> and <EOS>
186
+ - wav_lengths: [B] (original wav lengths)
187
+ - target_lengths: [B] (original semantic lengths + 2 for <SOS> and <EOS>)
188
+ """
189
+ waves, semantics = zip(*batch)
190
+ # Add <SOS> and <EOS> to each semantic sequence
191
+ semantics_with_tokens = [
192
+ torch.cat(
193
+ [
194
+ torch.tensor([sos_token], dtype=torch.long, device=semantic.device),
195
+ semantic,
196
+ torch.tensor([eos_token], dtype=torch.long, device=semantic.device),
197
+ ]
198
+ )
199
+ for semantic in semantics
200
+ ]
201
+
202
+ # Compute lengths *after* adding <SOS> and <EOS>
203
+ wav_lengths = torch.tensor([wav.size(0) for wav in waves], dtype=torch.long)
204
+ target_lengths = torch.tensor(
205
+ [semantic.size(0) for semantic in semantics_with_tokens], dtype=torch.long
206
+ )
207
+
208
+ # Pad waves and targets to max length in batch
209
+ max_wav_len = max(wav_lengths).item()
210
+ max_target_len = max(target_lengths).item()
211
+
212
+ padded_wavs = torch.zeros(size=(len(waves), max_wav_len), device=waves[0].device)
213
+ padded_targets = torch.full(
214
+ size=(len(semantics), max_target_len),
215
+ fill_value=padding_token,
216
+ dtype=torch.long,
217
+ device=semantics[0].device,
218
+ )
219
+
220
+ for i, (wav, semantic) in enumerate(zip(waves, semantics_with_tokens)):
221
+ padded_wavs[i, : wav.size(0)] = wav
222
+ padded_targets[i, : semantic.size(0)] = semantic
223
+
224
+ return padded_wavs, padded_targets, wav_lengths, target_lengths
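A minimal sketch of what the collator above returns (token values are arbitrary):

import torch

batch = [
    (torch.randn(16_000), torch.randint(0, 10_000, (50,))),
    (torch.randn(24_000), torch.randint(0, 10_000, (80,))),
]
wavs, targets, wav_lens, tgt_lens = wav_semantic_collate_fn(batch)
# wavs: (2, 24000) zero-padded, targets: (2, 82) padded with SEMANTIC_PADDING_TOKEN
# tgt_lens == tensor([52, 82]) because <SOS> and <EOS> are added to every sequence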
225
+
226
+
227
+ def load_train_val_dataloaders(
228
+ dataset: WavSemanticDataset,
229
+ train_ratio: float,
230
+ batch_size: int,
231
+ target_sample_rate: int = 16000,
232
+ noises: List[np.ndarray] = None,
233
+ augment_prob: float = 0.5,
234
+ device: Optional[torch.device] = None,
235
+ ) -> Tuple[DataLoader, DataLoader]:
236
+ """
237
+ Load train and validation DataLoaders from a WavSemanticDataset with dynamic batch padding.
238
+
239
+ Args:
240
+ dataset: The WavSemanticDataset instance to split and load.
241
+ train_ratio: Fraction of data to use for training (0 to 1).
242
+ batch_size: Number of samples per batch.
243
+ target_sample_rate: Target sample rate for resampling (default: 16000 Hz).
244
+ device: Optional device to move tensors to (default: None, stays on CPU).
245
+
246
+ Returns:
247
+ Tuple of (train_dataloader, val_dataloader).
248
+ """
249
+ # Split dataset into train and val
250
+ total_samples = len(dataset.data)
251
+ train_size = int(train_ratio * total_samples)
252
+ val_size = total_samples - train_size
253
+ train_data, val_data = random_split(dataset.data, [train_size, val_size])
254
+
255
+ # Create datasets without fixed max_sequence_length
256
+ train_dataset = WavSemanticTorchDataset(
257
+ samples=train_data,
258
+ orig_sample_rate=dataset.sample_rate,
259
+ target_sample_rate=target_sample_rate,
260
+ device=device,
261
+ noises=noises,
262
+ augment_prob=augment_prob,
263
+ )
264
+ val_dataset = WavSemanticTorchDataset(
265
+ samples=val_data,
266
+ orig_sample_rate=dataset.sample_rate,
267
+ target_sample_rate=target_sample_rate,
268
+ device=device,
269
+ noises=noises,
270
+ augment_prob=augment_prob,
271
+ )
272
+
273
+ # Create dataloaders with custom collate function
274
+ train_dataloader = DataLoader(
275
+ train_dataset,
276
+ batch_size=batch_size,
277
+ shuffle=True,
278
+ num_workers=0, # Increase if you have multiple cores
279
+ collate_fn=wav_semantic_collate_fn,
280
+ )
281
+ val_dataloader = DataLoader(
282
+ val_dataset,
283
+ batch_size=batch_size,
284
+ shuffle=False,
285
+ num_workers=0,
286
+ collate_fn=wav_semantic_collate_fn,
287
+ )
288
+
289
+ return train_dataloader, val_dataloader
290
+
291
+
292
+ def train_hubert_one_epoch(
293
+ model: nn.Module,
294
+ optimizer: torch.optim.Optimizer,
295
+ criterion: nn.CrossEntropyLoss,
296
+ train_dataloader: DataLoader,
297
+ grad_scaler: torch.cuda.amp.GradScaler,
298
+ device: torch.device,
299
+ progress_bar: Optional[tqdm] = None,
300
+ enable_autocast: bool = False,
301
+ ) -> Dict[str, float]:
302
+ """
303
+ Train the HuBERT model for one epoch using mixed-precision training with CrossEntropyLoss.
304
+
305
+ Args:
306
+ model: The HuBERT model with Transformer decoder.
307
+ optimizer: Optimizer for updating model parameters.
308
+ criterion: CrossEntropyLoss function.
309
+ train_dataloader: DataLoader for training data.
310
+ grad_scaler: Gradient scaler for mixed-precision training.
311
+ device: Device to train on (e.g., 'cuda', 'mps', 'cpu').
312
+ progress_bar: Optional tqdm progress bar.
313
+
314
+ Returns:
315
+ Dict with 'loss' metric.
316
+ """
317
+ model.train()
318
+ total_loss = 0.0
319
+ for batch in train_dataloader:
320
+ # DataLoader already moves data to device
321
+ waves, targets = batch[0], batch[1]
322
+ optimizer.zero_grad()
323
+ with torch.autocast(
324
+ device_type=device.type, dtype=torch.bfloat16, enabled=enable_autocast
325
+ ):
326
+
327
+ logits: torch.Tensor = model(waves, targets)
328
+
329
+ loss = criterion(logits[:, :-1, :].transpose(1, 2), targets[:, 1:])
330
+
331
+ total_loss += loss.detach().item()
332
+
333
+ # Mixed-precision backward/step via the scaler (a no-op when the scaler is disabled)
334
+ grad_scaler.scale(loss).backward()
335
+ grad_scaler.step(optimizer)
336
+ grad_scaler.update()
337
+
338
+ if progress_bar is not None:
339
+ progress_bar.update(1)
340
+
341
+ avg_loss = total_loss / len(train_dataloader)
342
+ return {"loss": avg_loss}
343
+
344
+
345
+ def eval_hubert(
346
+ model: nn.Module,
347
+ criterion: nn.CrossEntropyLoss,
348
+ val_dataloader: DataLoader,
349
+ device: torch.device,
350
+ sos_token: int = SOS_TOKEN,
351
+ eos_token: int = EOS_TOKEN,
352
+ padding_token: int = SEMANTIC_PADDING_TOKEN,
353
+ ) -> Dict[str, float]:
354
+ """
355
+ Evaluate the updated HuBERT model with Transformer decoder on the validation set.
356
+
357
+ Args:
358
+ model: The HuBERT model with Transformer decoder.
359
+ criterion: CrossEntropyLoss function.
360
+ val_dataloader: DataLoader for validation data (waves, targets).
361
+ device: Device to evaluate on.
362
+ sos_token: Index of the <SOS> token.
363
+ eos_token: Index of the <EOS> token.
364
+ padding_token: Index of the padding token.
365
+
366
+ Returns:
367
+ Dict with 'loss', 'accuracy', and 'num_tokens' metrics.
368
+ """
369
+ model.eval()
370
+ total_loss = 0.0
371
+ total_correct = 0
372
+ total_tokens = 0
373
+ num_batches = 0
374
+
375
+ for batch in val_dataloader:
376
+ # targets: [B, T'] with <SOS> and <EOS>
377
+ waves, targets = batch[0].to(device), batch[1].to(device)
378
+
379
+ with torch.no_grad(), torch.autocast(
380
+ device_type=device.type, dtype=torch.bfloat16
381
+ ):
382
+ # [B, T', semantic_vocab_size]
383
+ # transformers use batch_first=True
384
+ # targets is a tensor of [B, T'], all including [SOS] and [EOS] tokens
385
+ logits: torch.Tensor = model(waves, targets)
386
+
387
+ # remove the last token predictions from the logits
388
+ # remove the first token, which is the SOS token, from the targets
389
+ # transpose the logits tensor from (B, T, C) to (B, C, T)
390
+ loss = criterion(logits[:, :-1, :].transpose(1, 2), targets[:, 1:])
391
+
392
+ # Calculate accuracy (ignoring padding tokens)
393
+ preds = logits.argmax(dim=-1)[:, :-1]
394
+ target_shifted = targets[:, 1:]
395
+ mask = target_shifted != padding_token
396
+ total_correct += (preds[mask] == target_shifted[mask]).sum().item()
397
+ total_tokens += mask.sum().item()
398
+
399
+ total_loss += loss.item()
400
+ num_batches += 1
401
+
402
+ avg_loss = total_loss / num_batches
403
+ accuracy = total_correct / total_tokens if total_tokens > 0 else 0.0
404
+
405
+ return {"loss": avg_loss, "accuracy": accuracy, "num_tokens": total_tokens}
406
+
407
+
408
+ def _load_noise_dataset(data_path: str, target_sample_rate: int) -> List[np.ndarray]:
409
+ data = []
410
+ # Add more extensions as needed ".flac", ".ogg", ".aiff"
411
+ audio_extensions = (".wav", ".mp3")
412
+
413
+ # Walk through all directories and subdirectories
414
+ for root, dirs, files in os.walk(data_path):
415
+ for filename in files:
416
+ # Check if the file has an audio extension
417
+ if filename.lower().endswith(audio_extensions):
418
+ filepath = os.path.join(root, filename)
419
+ try:
420
+ audio = read_audio_file(
421
+ filepath,
422
+ target_sample_rate=target_sample_rate,
423
+ channels=1,
424
+ normalize=False,
425
+ )
426
+ data.append(audio)
427
+ except Exception as e:
428
+ print(f"Warning: Could not load {filepath}: {str(e)}")
429
+ continue
430
+
431
+ if len(data) == 0:
432
+ raise RuntimeError(f"No audio files found in {data_path} or its subdirectories")
433
+
434
+ return data
435
+
436
+
437
+ def train_hubert_quantizer(
438
+ model: nn.Module,
439
+ model_config: Dict[str, Any],
440
+ lr: float,
441
+ num_epoch: int,
442
+ train_ratio: float = 0.8,
443
+ batch_size: int = 64,
444
+ data_path: str = "./wav_semantic_dataset",
445
+ checkpoint_path: str = "./checkpoints",
446
+ save_checkpoint_every: int = 2,
447
+ enable_grad_scaler: bool = False,
448
+ augment_data_with_noise: bool = False,
449
+ augment_prob: float = 0.5,
450
+ noise_data_path: str = "./noise_dataset",
451
+ publish_hf: bool = False,
452
+ publish_to_repo: str = "",
453
+ num_samples: int = 5000,
454
+ device: torch.device = torch.device("cuda"),
455
+ ) -> nn.Module:
456
+ """
457
+ Train a HuBERT model with mixed-precision training and save checkpoints.
458
+
459
+ Args:
460
+ model: The HuBERT model to train.
461
+ lr: Learning rate for the optimizer.
462
+ num_epoch: Number of epochs to train.
463
+ train_ratio: Fraction of data for training.
464
+ batch_size: Batch size for DataLoaders.
465
+ data_path: Path to the saved dataset.
466
+ checkpoint_path: Directory to save checkpoints.
467
+ save_checkpoint_every: Save checkpoint every N epochs.
468
+ augment_data_with_noise: whether to add random noise to training audio
469
+ augment_prob: probability that a sample will be augmented with noise
470
+ num_samples: maximum number of samples to load from the dataset
471
+ Returns:
472
+ The trained model.
473
+ """
474
+
475
+ # "mps" (torch.backends.mps.is_available()) is deliberately not used here:
476
+ # mixed-precision training fails on the MPS device at the grad_scaler.step(optimizer) step
477
+ # for quick testing, just run on CPU
478
+ model.to(device)
479
+
480
+ # Load dataset and create dataloaders
481
+ dataset = WavSemanticDataset.load(data_path, num_samples=num_samples)
482
+ noises = None
483
+ if augment_data_with_noise:
484
+ logger.info(f"reading noise data from {noise_data_path}")
485
+ noises = _load_noise_dataset(noise_data_path, target_sample_rate=16000)
486
+
487
+ train_dataloader, val_dataloader = load_train_val_dataloaders(
488
+ dataset,
489
+ train_ratio=train_ratio,
490
+ batch_size=batch_size,
491
+ target_sample_rate=HUBERT_SAMPLE_RATE,
492
+ noises=noises,
493
+ augment_prob=augment_prob,
494
+ device=device,
495
+ )
496
+
497
+ optimizer = Adam(model.parameters(), lr=lr)
498
+ criterion = nn.CrossEntropyLoss(ignore_index=SEMANTIC_PADDING_TOKEN)
499
+ grad_scaler = torch.amp.GradScaler(device.type, enabled=enable_grad_scaler)
500
+ progress_bar = tqdm(total=num_epoch * len(train_dataloader), desc="Training HuBERT")
501
+ # scheduler = LinearLR(
502
+ # optimizer, start_factor=1, end_factor=0.5, total_iters=(num_epoch / 2)
503
+ # )
504
+ scheduler = None
505
+ Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
506
+
507
+ for epoch in range(num_epoch):
508
+ train_result = train_hubert_one_epoch(
509
+ model=model,
510
+ optimizer=optimizer,
511
+ criterion=criterion,
512
+ train_dataloader=train_dataloader,
513
+ grad_scaler=grad_scaler,
514
+ device=device,
515
+ progress_bar=progress_bar,
516
+ enable_autocast=enable_grad_scaler,
517
+ )
518
+ with torch.no_grad():
519
+ eval_result = eval_hubert(
520
+ model=model,
521
+ criterion=criterion,
522
+ val_dataloader=val_dataloader,
523
+ device=device,
524
+ )
525
+
526
+ if scheduler is not None:
527
+ scheduler.step()
528
+
529
+ logger.info(
530
+ f"Epoch {epoch + 1}/{num_epoch}, Train: {train_result}, Eval: {eval_result}"
531
+ )
532
+
533
+ if (epoch + 1) % save_checkpoint_every == 0:
534
+ checkpoint_file = os.path.join(
535
+ checkpoint_path,
536
+ f"hubert_epoch_{epoch + 1}_{datetime.now().strftime('%Y_%m_%d_%H_%M')}_eval_loss_{eval_result.get('loss', 0)}_acc_{eval_result.get('accuracy', 0)}.pt",
537
+ )
538
+ torch.save(
539
+ {  # the model configuration is saved alongside the weights for later loading
540
+ "epoch": epoch + 1,
541
+ "model_state_dict": model.state_dict(),
542
+ # "optimizer_state_dict": optimizer.state_dict(),
543
+ "train_result": train_result,
544
+ "eval_result": eval_result,
545
+ "config": model_config,
546
+ },
547
+ checkpoint_file,
548
+ )
549
+ logger.info(f"Saved checkpoint to {checkpoint_file}")
550
+
551
+ if publish_hf:
552
+ upload_file_to_hf(checkpoint_file, publish_to_repo, "model")
553
+
554
+ progress_bar.close()
555
+ return model
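A hypothetical entry point into the trainer above (paths, hyperparameters and the vocab size are placeholders; the dataset directory must already exist in the WavSemanticDataset format):

import torch
from core.model.hubert import HuBERTForBarkSemantic, HubertForBarkSemanticConfig
from core.trainer.custom_hubert_trainer import train_hubert_quantizer

model_config = dict(checkpoint_name="facebook/hubert-base-ls960", vocab_size=10_003)
model = HuBERTForBarkSemantic(HubertForBarkSemanticConfig(**model_config))

train_hubert_quantizer(
    model=model,
    model_config=model_config,   # stored in each checkpoint so it can be reloaded
    lr=1e-4,
    num_epoch=10,
    batch_size=16,
    data_path="./wav_semantic_dataset",
    checkpoint_path="./checkpoints",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)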
core/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from core.utils.audio import *
2
+
3
+ from core.utils.text import *
4
+
5
+ from core.utils.read_write_files import *
6
+
7
+ from core.utils.huggingface import *
core/utils/audio.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ Helpful functions to process audio
3
+ """
4
+
5
+ import numpy as np
6
+ import soundfile as sf
7
+
8
+ from typing_extensions import Annotated, Literal, Optional
9
+ import torchaudio
10
+ import torch
11
+
12
+ AudioChannel = Literal[1, 2]
13
+
14
+
15
+ def read_audio_file(
16
+ path: str,
17
+ target_sample_rate: int = 16000,
18
+ channels: int = 1,
19
+ normalize: bool = True,
20
+ max_duration: Optional[float] = None,
21
+ ) -> np.ndarray:
22
+ """Read and resample audio file
23
+ If target_sample_rate is different than the audio's sample rate, this function will resample it
24
+ If GPU is available, the resampling will be on GPU.
25
+
26
+ Args:
27
+ path: Path to the audio file (supports WAV, FLAC, OGG)
28
+ target_sample_rate: Target sample rate (default: 24000)
29
+ channels: Number of output channels (1 for mono, 2 for stereo)
30
+ normalize: Whether to normalize audio to [-1, 1]
31
+ max_duration: Maximum duration in seconds (truncates longer files)
32
+ device: Device to process on ("cuda" or "cpu", defaults to cuda if available)
33
+
34
+ Returns:
35
+ np.ndarray: Processed audio samples as a numpy array
36
+
37
+ Raises:
38
+ RuntimeError: If the file cannot be read or processing fails
39
+ """
40
+ try:
41
+ # Load audio file with torchaudio
42
+ waveform, original_sample_rate = torchaudio.load(path) # [channels, samples]
43
+
44
+ # Truncate to max_duration before resampling
45
+ if max_duration is not None:
46
+ max_samples = int(max_duration * original_sample_rate)
47
+ if waveform.size(1) > max_samples:
48
+ waveform = waveform[:, :max_samples]
49
+
50
+ # Downmix to desired channels
51
+ if waveform.size(0) > channels:
52
+ if channels == 1:
53
+ waveform = waveform.mean(dim=0, keepdim=True) # Mono: average channels
54
+ elif channels == 2:
55
+ waveform = waveform[:2, :] # Stereo: take first 2 channels
56
+
57
+ # Resample if needed
58
+ if original_sample_rate != target_sample_rate:
59
+ device = "cuda" if torch.cuda.is_available() else "cpu"
60
+ waveform = waveform.to(device)
61
+ resampler = torchaudio.transforms.Resample(
62
+ orig_freq=original_sample_rate,
63
+ new_freq=target_sample_rate,
64
+ resampling_method="sinc_interp_kaiser", # Fast and high-quality
65
+ ).to(device)
66
+ waveform = resampler(waveform)
67
+
68
+ # Normalize to [-1, 1] if requested
69
+ if normalize:
70
+ max_val = waveform.abs().max()
71
+ if max_val > 0:
72
+ waveform = waveform / max_val
73
+
74
+ # Move back to CPU and convert to numpy
75
+ data = waveform.cpu().numpy()
76
+
77
+ # Ensure correct shape (remove extra dim if mono)
78
+ if channels == 1 and data.shape[0] == 1:
79
+ data = data[0, :]
80
+
81
+ return data
82
+
83
+ except Exception as e:
84
+ raise RuntimeError(f"Failed to read audio file {path}: {str(e)}")
85
+
86
+
87
+ def save_audio_file(
88
+ audio_array: np.ndarray, sample_rate: int, file_path: str, format="WAV"
89
+ ):
90
+ """
91
+ Save an audio array to a file.
92
+
93
+ Parameters:
94
+ - audio_array: numpy array or list containing the audio samples
95
+ - sample_rate: int, the sample rate of the audio (e.g., 44100 Hz)
96
+ - file_path: str, path where the file will be saved (e.g., 'output.wav')
97
+ - format: str, audio file format (e.g., 'WAV', 'FLAC', 'OGG'), default is 'WAV'
98
+ """
99
+ try:
100
+ if not file_path.endswith(".wav"):
101
+ file_path += ".wav"
102
+ sf.write(file_path, audio_array, sample_rate, format=format)
103
+ except Exception as e:
104
+ print(f"Error saving audio file at {file_path}: {e}")
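A short round trip through the helpers above (the input path is a placeholder):

from core.utils.audio import read_audio_file, save_audio_file

samples = read_audio_file("speech.wav", target_sample_rate=16000, channels=1)
save_audio_file(samples, sample_rate=16000, file_path="speech_16k.wav")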
core/utils/huggingface.py ADDED
@@ -0,0 +1,169 @@
1
+ import logging
2
+ import sys
3
+ from typing import Optional, Literal
4
+ import os
5
+ import shutil
6
+ from zipfile import ZipFile
7
+ from pathlib import Path
8
+ from huggingface_hub import hf_hub_download, upload_file
9
+
10
+ # Set up logging
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ format="%(asctime)s - %(levelname)s - %(message)s",
14
+ handlers=[logging.StreamHandler(sys.stdout)],
15
+ )
16
+ logger = logging.getLogger(__name__)
17
+
18
+ __all__ = ["download_dataset_from_hf", "upload_file_to_hf", "download_file_from_hf"]
19
+
20
+
21
+ def download_dataset_from_hf(
22
+ repo_id: str,
23
+ filename: str,
24
+ dest_path: str,
25
+ token: str = None,
26
+ local_dir: str = "./downloads",
27
+ remove_downloaded_file: bool = True,
28
+ ) -> None:
29
+ """
30
+ Download a file from Hugging Face repository and unzip it to destination path
31
+
32
+ Args:
33
+ repo_id (str): Hugging Face repository ID (username/repo_name)
34
+ filename (str): Name of the file to download from the repository
35
+ dest_path (str): Destination path where contents will be unzipped
36
+ token (str, optional): Hugging Face token, if None will prompt for login
37
+ """
38
+ # Ensure destination directory exists
39
+ os.makedirs(dest_path, exist_ok=True)
40
+ if token is None:
41
+ logger.info("reading HF_TOKEN variable from environment")
42
+ token = os.getenv("HF_TOKEN")
43
+
44
+ # Download the file
45
+ downloaded_file = hf_hub_download(
46
+ repo_id=repo_id,
47
+ filename=filename,
48
+ repo_type="dataset", # Specify dataset repository
49
+ local_dir=local_dir, # Temporary download location
50
+ token=token,
51
+ )
52
+ logger.info(f"Downloaded {filename} to {downloaded_file}")
53
+
54
+ # Check if it's a zip file
55
+ if filename.endswith(".zip"):
56
+ # Extract the zip file
57
+ with ZipFile(downloaded_file, "r") as zip_ref:
58
+ zip_ref.extractall(dest_path)
59
+ logger.info(f"Unzipped contents to {dest_path}")
60
+
61
+ # Clean up the downloaded zip file
62
+ if remove_downloaded_file:
63
+ os.remove(downloaded_file)
64
+ logger.info(f"Cleaned up temporary file: {downloaded_file}")
65
+ else:
66
+ # If not a zip, just move the file
67
+ final_path = os.path.join(dest_path, filename)
68
+ shutil.move(downloaded_file, final_path)
69
+ logger.info(f"Moved {filename} to {final_path}")
70
+
71
+
72
+ def download_file_from_hf(
73
+ repo_id: str,
74
+ repo_type: Literal["model", "dataset"],
75
+ filename: str,
76
+ dest_path: str,
77
+ token: str = None,
78
+ ) -> None:
79
+ """
80
+ Download a file from Hugging Face repository and unzip it to destination path
81
+
82
+ Args:
83
+ repo_id (str): Hugging Face repository ID (username/repo_name)
84
+ repo_type: model for model repo, dataset for dataset repo
85
+ filename (str): Name of the file to download from the repository
86
+ dest_path (str): Destination path where contents will be unzipped
87
+ token (str, optional): Hugging Face token, if None will prompt for login
88
+
89
+ """
90
+ # Ensure destination directory exists
91
+ os.makedirs(dest_path, exist_ok=True)
92
+ if token is None:
93
+ logger.info("reading HF_TOKEN variable from environment")
94
+ token = os.getenv("HF_TOKEN")
95
+
96
+ # Download the file
97
+ downloaded_file = hf_hub_download(
98
+ repo_id=repo_id,
99
+ filename=filename,
100
+ repo_type=repo_type,
101
+ local_dir="./downloads", # Temporary download location
102
+ token=token,
103
+ )
104
+ logger.info(f"Downloaded {filename} to {downloaded_file}")
105
+
106
+ # Check if it's a zip file
107
+ if filename.endswith(".zip"):
108
+ # Extract the zip file
109
+ with ZipFile(downloaded_file, "r") as zip_ref:
110
+ zip_ref.extractall(dest_path)
111
+ logger.info(f"Unzipped contents to {dest_path}")
112
+
113
+ # Clean up the downloaded zip file
114
+ os.remove(downloaded_file)
115
+ logger.info(f"Cleaned up temporary file: {downloaded_file}")
116
+ else:
117
+ # If not a zip, just move the file
118
+ final_path = os.path.join(dest_path, filename)
119
+ shutil.move(downloaded_file, final_path)
120
+ logger.info(f"Moved {filename} to {final_path}")
121
+
122
+
123
+ def upload_file_to_hf(
124
+ local_file_path: str,
125
+ repo_id: str,
126
+ repo_type: Literal["model", "dataset"],
127
+ token: Optional[str] = None,
128
+ path_in_repo: Optional[str] = None,
129
+ commit_message: str = "Upload file",
130
+ ) -> None:
131
+ """
132
+ Upload a file to Hugging Face hub.
133
+
134
+ Args:
135
+ local_file_path (str): Path to the local file to upload
136
+ repo_id (str): Repository ID in format "username/repo_name"
137
+ repo_type (str): Type of repository, either "model" or "dataset"
138
+ token (str, optional): Hugging Face authentication token; read from the HF_TOKEN environment variable if not provided
139
+ path_in_repo (str, optional): Destination path in the repository.
140
+ Defaults to the filename from local_file_path
141
+ commit_message (str, optional): Commit message for the upload
142
+
143
+ Raises:
144
+ FileNotFoundError: If the local file doesn't exist
145
+ ValueError: If the repository ID is invalid
146
+ """
147
+ # Validate file exists
148
+ if not os.path.isfile(local_file_path):
149
+ raise FileNotFoundError(f"File not found: {local_file_path}")
150
+
151
+ # Use filename as default path_in_repo if not specified
152
+ if path_in_repo is None:
153
+ path_in_repo = Path(local_file_path).name
154
+
155
+ if token is None:
156
+ logger.info("reading HF_TOKEN variable from environment")
157
+ token = os.getenv("HF_TOKEN")
158
+ if token is None:
159
+ raise RuntimeError("HF_TOKEN environment variable not found")
160
+
161
+ upload_file(
162
+ path_or_fileobj=local_file_path,
163
+ path_in_repo=path_in_repo,
164
+ repo_id=repo_id,
165
+ repo_type=repo_type,
166
+ token=token,
167
+ commit_message=commit_message,
168
+ )
169
+ logger.info(f"Successfully uploaded {local_file_path} to {repo_id}/{path_in_repo}")
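
A minimal usage sketch for the two helpers above: the repository IDs, file names and local paths are placeholders, HF_TOKEN is assumed to be set in the environment, and the import path assumes these helpers live in core/utils/huggingface.py:

    from core.utils.huggingface import download_file_from_hf, upload_file_to_hf

    # Fetch a zipped checkpoint from a model repo; zip archives are extracted into dest_path.
    download_file_from_hf(
        repo_id="username/bark-finetune",   # placeholder repo
        repo_type="model",
        filename="checkpoint.zip",          # placeholder file
        dest_path="./checkpoints",
    )

    # Push a local file back to the same repo.
    upload_file_to_hf(
        local_file_path="./checkpoints/hubert.pt",  # placeholder path
        repo_id="username/bark-finetune",
        repo_type="model",
        commit_message="Upload fine-tuned checkpoint",
    )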
core/utils/read_write_files.py ADDED
@@ -0,0 +1,46 @@
 
1
+ import os
2
+ import zipfile
3
+
4
+
5
+ def zip_folder(folder_path: str, output_path: str) -> bool:
6
+ """
7
+ Zip a folder and its contents to a zip file.
8
+
9
+ Args:
10
+ folder_path (str): Path to the folder to be zipped
11
+ output_path (str): Path where the zip file will be created
12
+
13
+ Returns:
14
+ bool: True if successful, False otherwise
15
+ """
16
+ try:
17
+ # Ensure the folder exists
18
+ if not os.path.isdir(folder_path):
19
+ print(f"Error: {folder_path} is not a valid directory")
20
+ return False
21
+
22
+ # Get the absolute path of the folder
23
+ abs_folder_path = os.path.abspath(folder_path)
24
+
25
+ # Create a ZipFile object in write mode
26
+ with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
27
+ # Walk through the folder
28
+ for root, dirs, files in os.walk(abs_folder_path):
29
+ for file in files:
30
+ # Get the absolute path of the file
31
+ abs_file_path = os.path.join(root, file)
32
+
33
+ # Calculate relative path for the file inside the zip
34
+ rel_path = os.path.relpath(
35
+ abs_file_path, os.path.dirname(abs_folder_path)
36
+ )
37
+
38
+ # Add file to zip
39
+ zipf.write(abs_file_path, rel_path)
40
+
41
+ print(f"Successfully created zip file at {output_path}")
42
+ return True
43
+
44
+ except Exception as e:
45
+ print(f"Error creating zip file: {e}")
46
+ return False
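
A short sketch of how zip_folder can be used, for example before uploading a dataset archive; the paths are placeholders, and the import assumes the helper is re-exported from core.utils as it is in generate_audio_semantic_dataset.py:

    from core.utils import zip_folder

    if zip_folder("./dataset", "./dataset.zip"):
        print("Archive created at ./dataset.zip")
    else:
        print("Zipping failed")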
core/utils/text.py ADDED
@@ -0,0 +1,13 @@
 
1
+ def normalize_whitespace(text: str) -> str:
2
+ """
3
+ Normalize whitespace in text by:
4
+ 1. Removing leading and trailing whitespace
5
+ 2. Replacing any sequence of whitespace characters with a single space
6
+
7
+ Args:
8
+ text: Input string to normalize
9
+
10
+ Returns:
11
+ String with normalized whitespace
12
+ """
13
+ return ' '.join(text.split())
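
For illustration, and assuming the helper is importable from core.utils.text, the normalization collapses internal whitespace runs and trims both ends:

    from core.utils.text import normalize_whitespace

    assert normalize_whitespace("  Hello,\n\tworld!  ") == "Hello, world!"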
event_handlers.py ADDED
@@ -0,0 +1,436 @@
 
1
+ from typing import List, Tuple, Optional, Dict, Any
2
+ import traceback
3
+ import torch
4
+ import gradio as gr
5
+ import numpy as np
6
+ import time
7
+ import os
8
+ import re
9
+ import wave
10
+ import contextlib
11
+ import logging
12
+ import pandas as pd
13
+ import gc
14
+
15
+ import nltk
16
+
17
+ nltk.download("punkt")
18
+ from nltk.tokenize import sent_tokenize
19
+
20
+ from core.data_model import AudioFile
21
+ from core.bark.voice_clone import create_bark_prompt
22
+ from core.bark.generate_audio import generate_audio
23
+ from core.data_model import BarkPrompt, BarkGenerationConfig
24
+ from core.utils.audio import save_audio_file
25
+ from config import *
26
+
27
+ # Set up logging
28
+ logging.basicConfig(
29
+ level=logging.INFO,
30
+ format="%(asctime)s - %(levelname)s - %(message)s",
31
+ )
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ # return list of available devices and the best device to be used as default for all inference
36
+ def get_available_torch_devices() -> Tuple[List[str], str]:
37
+ devices = ["cpu"]
38
+ best_device = "cpu"
39
+ if torch.backends.mps.is_available():
40
+ devices.append("mps")
41
+ best_device = "mps"
42
+ if torch.cuda.is_available():
43
+ devices.append("cuda")
44
+ best_device = "cuda"
45
+
46
+ return devices, best_device
47
+
48
+
49
+ # --- Helper Functions ---
50
+ # (get_wav_duration, load_existing_audio, get_safe_filename,
51
+ # get_available_prompts, create_audio_prompt and the other helpers below
52
+ # are mostly backend logic.)
53
+ def get_wav_duration(filepath):
54
+ """Gets the duration of a WAV file in seconds."""
55
+ try:
56
+ with contextlib.closing(wave.open(filepath, "r")) as f:
57
+ frames = f.getnframes()
58
+ rate = f.getframerate()
59
+ if rate > 0:
60
+ duration = frames / float(rate)
61
+ return duration
62
+ else:
63
+ logger.warning(f"Framerate is 0 for {filepath}")
64
+ return 0
65
+ except wave.Error as e:
66
+ logger.warning(f"Could not read wave file header for {filepath}: {e}")
67
+ return 0
68
+ except Exception as e:
69
+ logger.warning(f"Could not get duration for {filepath}: {e}")
70
+ return 0
71
+
72
+
73
+ def load_existing_audio() -> List[Dict[str, Any]]:
74
+ """Scans the audio directory and loads metadata for existing WAV files."""
75
+ logger.info("\n--- Loading Existing Audio Files ---")
76
+ existing_files_metadata = []
77
+ if not os.path.isdir(GENERATED_AUDIO_DIR):
78
+ logger.info(f"Directory not found: {GENERATED_AUDIO_DIR}")
79
+ return []
80
+
81
+ try:
82
+ for filename in os.listdir(GENERATED_AUDIO_DIR):
83
+ if filename.lower().endswith(".wav"):
84
+ filepath = os.path.join(GENERATED_AUDIO_DIR, filename)
85
+ if not os.path.isfile(filepath):
86
+ continue
87
+
88
+ match = re.match(r"^(.*)_(\d{13})\.wav$", filename)
89
+ text_guess = "Unknown (from filename)"
90
+ timestamp_ms = 0
91
+ if match:
92
+ text_guess = match.group(1).replace("_", " ")
93
+ try:
94
+ timestamp_ms = int(match.group(2))
95
+ except ValueError:
96
+ timestamp_ms = 0
97
+ else:
98
+ text_guess = os.path.splitext(filename)[0].replace("_", " ")
99
+
100
+ timestamp_sec = (
101
+ timestamp_ms / 1000.0
102
+ if timestamp_ms > 0
103
+ else os.path.getmtime(filepath)
104
+ )
105
+ duration = get_wav_duration(filepath)
106
+
107
+ metadata = {
108
+ "text": text_guess,
109
+ "path": filepath,
110
+ "duration": duration,
111
+ "timestamp": timestamp_sec,
112
+ }
113
+ existing_files_metadata.append(metadata)
114
+
115
+ except Exception as e:
116
+ logger.error(f"Error loading existing audio files: {e}")
117
+
118
+ existing_files_metadata.sort(key=lambda x: x.get("timestamp", 0))
119
+ logger.info(
120
+ f"--- Finished Loading {len(existing_files_metadata)} Existing Files ---"
121
+ )
122
+ return existing_files_metadata
123
+
124
+
125
+ def get_safe_filename(base_name: str, extension: str, directory: str) -> str:
126
+ """Creates a safe and unique filename in the target directory."""
127
+ safe_base = "".join(
128
+ c if c.isalnum() or c in ["_", "-"] else "_" for c in base_name[:50]
129
+ )
130
+ timestamp = int(time.time() * 1000)
131
+ filename = f"{safe_base}_{timestamp}.{extension}"
132
+ filepath = os.path.join(directory, filename)
133
+ counter = 1
134
+ while os.path.exists(filepath):
135
+ filename = f"{safe_base}_{timestamp}_{counter}.{extension}"
136
+ filepath = os.path.join(directory, filename)
137
+ counter += 1
138
+ return filepath
139
+
140
+
141
+ def update_audio_list(
142
+ newly_generated_metadata: List[Dict[str, Any]],
143
+ current_audio_list: List[Dict[str, Any]],
144
+ ) -> List[Dict[str, Any]]:
145
+ """Appends new metadata to the list and sorts it by timestamp."""
146
+ logger.info(f"\n--- Updating Audio List State ---")
147
+ if not isinstance(current_audio_list, list):
148
+ logger.info("Current audio list was not a list, initializing.")
149
+ current_audio_list = []
150
+ if not isinstance(newly_generated_metadata, list):
151
+ logger.info("Newly generated metadata is not a list, skipping update.")
152
+ return current_audio_list
153
+
154
+ logger.info(f"Current list size: {len(current_audio_list)}")
155
+ logger.info(f"Adding {len(newly_generated_metadata)} new items.")
156
+ updated_list = current_audio_list + newly_generated_metadata
157
+ updated_list.sort(key=lambda x: x.get("timestamp", 0))
158
+ logger.info(f"Updated list state size: {len(updated_list)}")
159
+ logger.info("--- Finished Updating Audio List State ---")
160
+ return updated_list
161
+
162
+
163
+ def format_audio_list_for_dataframe(audio_list: List[Dict[str, Any]]) -> pd.DataFrame:
164
+ """Converts the list of audio metadata dicts into a pandas DataFrame for display."""
165
+ logger.info("\n--- Formatting List for DataFrame ---")
166
+ if not audio_list:
167
+ logger.info("Audio list is empty, returning empty DataFrame.")
168
+ # Return empty DataFrame with correct columns
169
+ return pd.DataFrame(columns=["File", "Prompt", "Duration (s)"])
170
+
171
+ display_data = []
172
+ for item in audio_list:
173
+ filepath = item.get("path", "N/A")
174
+ filename = os.path.basename(filepath) if filepath != "N/A" else "N/A"
175
+ # Truncate long text prompts for display in the table
176
+ text_prompt = item.get("text", "N/A")
177
+ display_text = (
178
+ (text_prompt[:75] + "...") if len(text_prompt) > 75 else text_prompt
179
+ )
180
+ duration = item.get("duration", 0)
181
+ display_data.append(
182
+ {
183
+ "File": filename,
184
+ "Prompt": display_text,
185
+ "Duration (s)": f"{duration:.2f}" if duration else "N/A",
186
+ # Store the full path implicitly by list order, not shown in df
187
+ }
188
+ )
189
+
190
+ df = pd.DataFrame(display_data)
191
+ logger.info(f"Created DataFrame with {len(df)} rows.")
192
+ logger.info("--- Finished Formatting List for DataFrame ---")
193
+ return df
194
+
195
+
196
+ def handle_row_selection(
197
+ audio_list: List[Dict[str, Any]], evt: gr.SelectData
198
+ ) -> Tuple[Optional[str], int]:
199
+ """
200
+ Handles the selection event from the DataFrame.
201
+ Updates the audio player with the selected file's path.
202
+ Returns the filepath and the selected index.
203
+ """
204
+ logger.info("\n--- Handling Row Selection ---")
205
+ selected_index = evt.index[0] if evt.index else None # Get row index
206
+ logger.info(f"DataFrame row selected. Event data: {evt}")
207
+
208
+ if selected_index is not None and 0 <= selected_index < len(audio_list):
209
+ selected_item = audio_list[selected_index]
210
+ filepath = selected_item.get("path")
211
+ logger.info(f"Selected item at index {selected_index}: {selected_item}")
212
+ if filepath and os.path.exists(filepath):
213
+ logger.info(f"Updating audio player with: {filepath}")
214
+ logger.info("--- Finished Handling Row Selection (Success) ---")
215
+ return filepath, selected_index
216
+ else:
217
+ logger.info(f"File not found for selected item: {filepath}")
218
+ gr.Warning(
219
+ f"File not found for selected row: {os.path.basename(filepath or 'N/A')}"
220
+ )
221
+ logger.info("--- Finished Handling Row Selection (File Not Found) ---")
222
+ return None, selected_index # Keep index, but clear player
223
+ else:
224
+ logger.info("Invalid selection index or empty list.")
225
+ logger.info("--- Finished Handling Row Selection (Invalid Index) ---")
226
+ return None, -1 # Clear player and indicate no valid selection
227
+
228
+
229
+ def handle_delete_selected(
230
+ selected_index: int, current_audio_list: List[Dict[str, Any]]
231
+ ) -> Tuple[List[Dict[str, Any]], int, Optional[str]]:
232
+ """
233
+ Deletes the audio file corresponding to the selected index.
234
+ Updates the main audio list state.
235
+ Clears the selection index and audio player.
236
+ """
237
+ logger.info("\n--- Handling Delete Selected ---")
238
+ logger.info(f"Attempting deletion for selected index: {selected_index}")
239
+
240
+ if (
241
+ selected_index is None
242
+ or selected_index < 0
243
+ or selected_index >= len(current_audio_list)
244
+ ):
245
+ gr.Warning("No valid audio selected for deletion.")
246
+ logger.info("No valid index provided.")
247
+ # Return current list, clear index, clear player
248
+ return current_audio_list, -1, None
249
+
250
+ item_to_delete = current_audio_list[selected_index]
251
+ filepath_to_delete = item_to_delete.get("path")
252
+ logger.info(f"Item to delete: {item_to_delete}")
253
+
254
+ # Create the new list excluding the item
255
+ # Corrected slicing logic: include elements before and after the index
256
+ new_audio_list = (
257
+ current_audio_list[:selected_index] + current_audio_list[selected_index + 1 :]
258
+ )
259
+ logger.info(f"New list size after filtering: {len(new_audio_list)}")
260
+
261
+ # Try to delete the file from disk
262
+ deletion_successful_on_disk = False
263
+ try:
264
+ if filepath_to_delete and os.path.exists(filepath_to_delete):
265
+ os.remove(filepath_to_delete)
266
+ logger.info(f"Successfully deleted file: {filepath_to_delete}")
267
+ gr.Info(f"Deleted {os.path.basename(filepath_to_delete)}")
268
+ deletion_successful_on_disk = True
269
+ elif filepath_to_delete:
270
+ logger.info(f"File not found for deletion: {filepath_to_delete}")
271
+ gr.Warning("Audio entry removed from list, but file was not found on disk.")
272
+ deletion_successful_on_disk = True # Consider list update successful
273
+ else:
274
+ logger.info("Invalid filepath in selected item.")
275
+ gr.Warning("Could not delete: Invalid file path associated with selection.")
276
+ # Revert list change if filepath was invalid from the start? Or keep it removed?
277
+ # Let's keep it removed from the list for consistency.
278
+ deletion_successful_on_disk = True # Treat as success for list update
279
+
280
+ except OSError as e:
281
+ logger.error(f"Error deleting file {filepath_to_delete}: {e}")
282
+ traceback.print_exc()
283
+ gr.Error(f"Error deleting file: {e}")
284
+ # If file deletion fails, we still return the updated list (item removed).
285
+ # If you want to revert the list change on OS error, return `current_audio_list` here.
286
+
287
+ logger.info("--- Finished Deleting Selected Item ---")
288
+ # Return the updated list, clear the selected index, clear the audio player
289
+ return new_audio_list, -1, None
290
+
291
+
292
+ def get_available_prompts() -> List[str]:
293
+ """Loads available prompt file names."""
294
+ try:
295
+ prompts = [
296
+ f
297
+ for f in os.listdir(PROMPT_DIR)
298
+ if os.path.isfile(os.path.join(PROMPT_DIR, f))
299
+ and f.lower().endswith((".npz", ".npy", ".json"))
300
+ ]
301
+
302
+ if len(prompts) == 0:
303
+ gr.Info("No prompts found.", duration=3)
304
+
305
+ return ["None"] + prompts
306
+ except Exception as e:
307
+ logger.info(f"Error loading prompts: {e}")
308
+ gr.Info(f"Error loading prompts {e}", duration=3, title="Error")
309
+ return ["None"]
310
+
311
+
312
+ def update_available_prompts() -> gr.update:
313
+ try:
314
+ prompts = [
315
+ f
316
+ for f in os.listdir(PROMPT_DIR)
317
+ if os.path.isfile(os.path.join(PROMPT_DIR, f))
318
+ and f.lower().endswith((".npz", ".npy", ".json"))
319
+ ]
320
+
321
+ if len(prompts) == 0:
322
+ gr.Info("No prompts found.", duration=3)
323
+
324
+ return gr.update(choices=["None"] + prompts)
325
+ except Exception as e:
326
+ logger.info(f"Error loading prompts: {e}")
327
+ gr.Info(f"Error loading prompts {e}", duration=3, title="Error")
328
+ return gr.update()
329
+
330
+
331
+ def generate_batch_audio(
332
+ text: str,
333
+ semantic_temp: float,
334
+ coarse_temp: float,
335
+ fine_temp: float,
336
+ manual_seed: int,
337
+ model_type: str,
338
+ inference_device: str,
339
+ selected_prompt_name: Optional[str],
340
+ ) -> List[Dict[str, Any]]:
341
+ """
342
+ Generates Bark audio for the input text: sentences are synthesized and concatenated into one file.
343
+ Returns metadata for the generated audio file.
344
+ """
345
+ gc.collect()
346
+
347
+ torch.manual_seed(manual_seed)
348
+ if not text:
349
+ gr.Warning("No valid text prompts provided.")
350
+ return []
351
+
352
+ generated_metadata = []
353
+
354
+ bark_prompt = None
355
+ if selected_prompt_name and selected_prompt_name != "None":
356
+ gr.Info("Loading audio prompt...")
357
+ prompt_path = os.path.join(PROMPT_DIR, selected_prompt_name)
358
+ bark_prompt = BarkPrompt.load_prompt(
359
+ prompt_path, torch.device(inference_device)
360
+ )
361
+
362
+ generation_config = BarkGenerationConfig(
363
+ temperature=semantic_temp,
364
+ generate_coarse_temperature=coarse_temp,
365
+ generate_fine_temperature=fine_temp,
366
+ use_small_model=True if model_type == "small" else False,
367
+ )
368
+
369
+ # split the text into sentences
370
+ sentences = sent_tokenize(text)
371
+
372
+ gr.Info("Generating Audio....", duration=120)
373
+ waves = generate_audio(
374
+ texts=sentences,
375
+ prompt=bark_prompt,
376
+ generation_config=generation_config,
377
+ silent=True,
378
+ )
379
+ audio = np.concatenate(waves, axis=-1)
380
+
381
+ output_filepath = get_safe_filename(text, "wav", GENERATED_AUDIO_DIR)
382
+ save_audio_file(audio, DEFAULT_AUDIO_SAMPLE_RATE, output_filepath)
383
+ duration_sec = audio.shape[0] / DEFAULT_AUDIO_SAMPLE_RATE
384
+ metadata = {
385
+ "text": text,
386
+ "path": output_filepath,
387
+ "duration": duration_sec,
388
+ "timestamp": time.time(),
389
+ }
390
+ generated_metadata.append(metadata)
391
+ gr.Info("Done!", duration=5)
392
+ return generated_metadata
393
+
394
+
395
+ def create_audio_prompt(
396
+ uploaded_audio_file: Optional[str],
397
+ device: str,
398
+ progress: gr.Progress = gr.Progress(),
399
+ ) -> gr.update:
400
+ """Processes an uploaded audio file to create a Bark voice prompt file."""
401
+ logger.info("\n--- Starting Prompt Creation ---")
402
+ if uploaded_audio_file is None or len(uploaded_audio_file) == 0:
403
+ gr.Warning("No audio file uploaded!")
404
+ return gr.update()
405
+
406
+ logger.info(f"Processing uploaded file: {uploaded_audio_file}")
407
+
408
+ try:
409
+ progress(0, desc="Starting prompt creation...")
410
+ new_prompt_filename = None
411
+ progress(0.2, desc="Extracting prompt features...")
412
+ audio_file = AudioFile(audio_file_path=uploaded_audio_file, max_duration=10)
413
+ prompt = create_bark_prompt(
414
+ audio_file=audio_file, temperature=1, eos_p=0.2, device=torch.device(device)
415
+ )
416
+
417
+ progress(0.8, desc="Saving prompt file...")
418
+ original_basename = os.path.splitext(os.path.basename(uploaded_audio_file))[0]
419
+ prompt_filepath = get_safe_filename(original_basename, "json", PROMPT_DIR)
420
+ new_prompt_filename = os.path.basename(prompt_filepath)
421
+
422
+ ok = prompt.save_prompt(prompt_filepath)
423
+ if ok:
424
+ progress(1.0, desc="Prompt creation complete.")
425
+
426
+ else:
427
+ progress(1.0, desc="Error when saving prompt")
428
+
429
+ new_choices = get_available_prompts()
430
+
431
+ return gr.update(choices=new_choices, value=new_prompt_filename)
432
+
433
+ except Exception as e:
434
+ logger.error(f"Error creating prompt: {e}")
435
+ gr.Error(f"Prompt creation failed: {e}")
436
+ return gr.update()
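
The synthesis path wired into generate_batch_audio can also be exercised outside Gradio. A condensed sketch, assuming the NLTK punkt data has been downloaded and that Bark's output rate matches DEFAULT_AUDIO_SAMPLE_RATE from config.py; the demo text and output filename are placeholders:

    import numpy as np
    from nltk.tokenize import sent_tokenize

    from core.bark.generate_audio import generate_audio
    from core.data_model import BarkGenerationConfig
    from core.utils.audio import save_audio_file
    from config import DEFAULT_AUDIO_SAMPLE_RATE

    config = BarkGenerationConfig(
        temperature=0.7,
        generate_coarse_temperature=0.7,
        generate_fine_temperature=0.5,
        use_small_model=True,
    )
    # Split the text into sentences, synthesize each, then join the waveforms.
    sentences = sent_tokenize("Hello there. This is a short Bark demo.")
    waves = generate_audio(texts=sentences, prompt=None, generation_config=config, silent=True)
    audio = np.concatenate(waves, axis=-1)
    save_audio_file(audio, DEFAULT_AUDIO_SAMPLE_RATE, "demo.wav")  # placeholder path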
generate_audio_semantic_dataset.py ADDED
@@ -0,0 +1,155 @@
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ from typing import Optional
5
+
6
+ from core.bark.generate_audio_semantic_dataset import (
7
+ generate_wav_semantic_dataset,
8
+ BarkGenerationConfig,
9
+ )
10
+ from core.utils import upload_file_to_hf, zip_folder
11
+
12
+
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format="%(asctime)s - %(levelname)s - %(message)s",
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def parse_dataset_args(args_list=None):
21
+ """Parse arguments specific to dataset creation."""
22
+ parser = argparse.ArgumentParser(description="Audio Semantic Dataset Creation")
23
+
24
+ parser.add_argument(
25
+ "--text-file",
26
+ type=str,
27
+ default="data/test_data.txt",
28
+ help="Path to text file for dataset generation",
29
+ )
30
+ parser.add_argument(
31
+ "--batch-size",
32
+ type=int,
33
+ default=2,
34
+ help="Batch size for processing (default: 2)",
35
+ )
36
+
37
+ parser.add_argument(
38
+ "--output-dir",
39
+ type=str,
40
+ default="./dataset",
41
+ help="Output directory for generated files (default: ./dataset)",
42
+ )
43
+ parser.add_argument(
44
+ "--max-tokens",
45
+ type=int,
46
+ default=256,
47
+ help="Maximum tokens per example (default: 256)",
48
+ )
49
+ parser.add_argument(
50
+ "--use-small-model",
51
+ action="store_true",
52
+ help="Use small model for generation",
53
+ )
54
+ parser.add_argument(
55
+ "--save-raw-audio",
56
+ action="store_true",
57
+ help="Store generated audio as .wav instead of .npz",
58
+ )
59
+ parser.add_argument(
60
+ "--publish-hf",
61
+ action="store_true",
62
+ help="Publish dataset to HuggingFace Hub",
63
+ )
64
+ parser.add_argument(
65
+ "--repo-id",
66
+ type=str,
67
+ help="HuggingFace repo ID to publish to",
68
+ )
69
+ parser.add_argument(
70
+ "--path-in-repo",
71
+ type=str,
72
+ help="Path in HF repo",
73
+ default=None,
74
+ )
75
+ parser.add_argument(
76
+ "--silent", action="store_true", help="Suppress progress output"
77
+ )
78
+
79
+ return parser.parse_args(args_list)
80
+
81
+
82
+ def create_audio_semantic_dataset(
83
+ text_file: str,
84
+ output_dir: str = "./dataset",
85
+ batch_size: int = 1,
86
+ max_tokens: int = 256,
87
+ use_small_model: bool = False,
88
+ save_raw_audio: bool = False,
89
+ publish_hf: bool = False,
90
+ repo_id: Optional[str] = None,
91
+ path_in_repo: Optional[str] = None,
92
+ silent: bool = False,
93
+ ) -> None:
94
+ """Create audio semantic dataset from text file.
95
+
96
+ Can be called directly with parameters or via command line using parse_dataset_args().
97
+
98
+ Args:
99
+ text_file: Path to input text file
100
+ output_dir: Directory to save generated dataset
101
+ batch_size: Batch size for processing
102
+ max_tokens: Maximum tokens per example
103
+ use_small_model: Whether to use small model
104
+ save_raw_audio: Save as raw audio (.wav) instead of .npz
105
+ publish_hf: Whether to publish to HuggingFace Hub
106
+ repo_id: HF repo ID to publish to
107
+ path_in_repo: Path in HF repo
108
+ silent: Suppress progress output
109
+ """
110
+ os.makedirs(output_dir, exist_ok=True)
111
+
112
+ if not os.path.isfile(text_file):
113
+ raise FileNotFoundError(f"Text file not found: {text_file}")
114
+
115
+ logger.info(f"Starting dataset generation from {text_file}")
116
+ generation_config = BarkGenerationConfig(
117
+ temperature=None,
118
+ generate_coarse_temperature=None,
119
+ generate_fine_temperature=None,
120
+ use_small_model=use_small_model,
121
+ )
122
+
123
+ generate_wav_semantic_dataset(
124
+ text_file_path=text_file,
125
+ generation_config=generation_config,
126
+ batch_size=batch_size,
127
+ save_path=output_dir,
128
+ save_data_as_raw_audio=save_raw_audio,
129
+ silent=silent,
130
+ )
131
+ logger.info("Dataset generation completed")
132
+
133
+ if publish_hf and repo_id:
134
+ logger.info("Publishing dataset to huggingface hub")
135
+ zip_path = "./dataset.zip"
136
+ success = zip_folder(output_dir, zip_path)
137
+ if not success:
138
+ raise RuntimeError(f"Unable to zip folder {output_dir}")
139
+ upload_file_to_hf(zip_path, repo_id, "dataset", path_in_repo=path_in_repo)
140
+
141
+
142
+ if __name__ == "__main__":
143
+ args = parse_dataset_args()
144
+ create_audio_semantic_dataset(
145
+ text_file=args.text_file,
146
+ output_dir=args.output_dir,
147
+ batch_size=args.batch_size,
148
+ max_tokens=args.max_tokens,
149
+ use_small_model=args.use_small_model,
150
+ save_raw_audio=args.save_raw_audio,
151
+ publish_hf=args.publish_hf,
152
+ repo_id=args.repo_id,
153
+ path_in_repo=args.path_in_repo,
154
+ silent=args.silent,
155
+ )
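
The same entry point can be driven programmatically instead of through argparse; the values below mirror the CLI defaults, and the repo ID in the comment is a placeholder:

    from generate_audio_semantic_dataset import create_audio_semantic_dataset

    # Equivalent CLI:
    #   python generate_audio_semantic_dataset.py --text-file data/test_data.txt \
    #       --batch-size 2 --use-small-model --output-dir ./dataset
    create_audio_semantic_dataset(
        text_file="data/test_data.txt",
        output_dir="./dataset",
        batch_size=2,
        use_small_model=True,
        publish_hf=False,  # set True and pass repo_id="username/repo" to upload the zipped dataset
    )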
prompts/de_speaker_0.npz ADDED
Binary file (39.6 kB). View file
 
prompts/de_speaker_1.npz ADDED
Binary file (27.5 kB). View file
 
prompts/de_speaker_2.npz ADDED
Binary file (24.7 kB). View file
 
prompts/de_speaker_3.npz ADDED
Binary file (31.3 kB). View file
 
prompts/de_speaker_4.npz ADDED
Binary file (30.7 kB). View file
 
prompts/de_speaker_5.npz ADDED
Binary file (31.3 kB). View file
 
prompts/de_speaker_6.npz ADDED
Binary file (23.2 kB). View file
 
prompts/de_speaker_7.npz ADDED
Binary file (40.1 kB). View file
 
prompts/de_speaker_8.npz ADDED
Binary file (28.5 kB). View file
 
prompts/de_speaker_9.npz ADDED
Binary file (51.1 kB). View file
 
prompts/en_speaker_0.npz ADDED
Binary file (28.1 kB). View file
 
prompts/en_speaker_1.npz ADDED
Binary file (25.2 kB). View file
 
prompts/en_speaker_2.npz ADDED
Binary file (26.2 kB). View file
 
prompts/en_speaker_3.npz ADDED
Binary file (35 kB). View file
 
prompts/en_speaker_4.npz ADDED
Binary file (23.8 kB). View file
 
prompts/en_speaker_5.npz ADDED
Binary file (24.7 kB). View file